def evaluate(self, model: WaveRNN, val_set: Dataset) -> float:
    model.eval()
    val_loss = 0
    device = next(model.parameters()).device
    for i, (x, y, m) in enumerate(val_set, 1):
        x, m, y = x.to(device), m.to(device), y.to(device)
        with torch.no_grad():
            y_hat = model(x, m)
            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            y = y.unsqueeze(-1)
            loss = self.loss_func(y_hat, y)
            val_loss += loss.item()
    return val_loss / len(val_set)
# Variant of evaluate() that consumes dict-style batches and moves them to
# the model's device with a to_device() helper instead of unpacking tuples.
def evaluate(self, model: WaveRNN, val_set: Dataset) -> float:
    model.eval()
    val_loss = 0
    device = next(model.parameters()).device
    for i, batch in enumerate(val_set, 1):
        batch = to_device(batch, device=device)
        x, y, m = batch['x'], batch['y'], batch['mel']
        with torch.no_grad():
            y_hat = model(x, m)
            if model.mode == 'RAW':
                y_hat = y_hat.transpose(1, 2).unsqueeze(-1)
            elif model.mode == 'MOL':
                y = y.float()
            y = y.unsqueeze(-1)
            loss = self.loss_func(y_hat, y)
            val_loss += loss.item()
    return val_loss / len(val_set)
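# --- Illustrative usage sketch (not from the original source) ---
# How evaluate() might be hooked into a training loop: run it periodically,
# log the result, and keep the best checkpoint. `trainer` is assumed to be an
# instance of the class defining evaluate() above; `eval_interval` and
# `save_path` are hypothetical parameters, not names from the repo.
import torch

def maybe_evaluate(trainer, model, val_set, step, eval_interval, best_loss, save_path):
    # Returns the (possibly updated) best validation loss seen so far.
    if step % eval_interval == 0:
        val_loss = trainer.evaluate(model, val_set)
        trainer.writer.add_scalar('Loss/val_loss', val_loss, global_step=step)
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), save_path)  # keep the best model
    return best_loss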
def generate_samples(self, model: WaveRNN, session: VocSession) -> Tuple[float, list]: """ Generates audio samples to cherry-pick models. To evaluate audio quality we calculate the l1 distance between mels of predictions and targets. """ model.eval() mel_losses = [] gen_wavs = [] device = next(model.parameters()).device for i, sample in enumerate(session.val_set_samples, 1): m, x = sample['mel'], sample['x'] if i > self.train_cfg['num_gen_samples']: break x = x[0].numpy() bits = 16 if self.dsp.voc_mode == 'MOL' else self.dsp.bits if self.dsp.mu_law and self.dsp.voc_mode != 'MOL': x = DSP.decode_mu_law(x, 2**bits, from_labels=True) else: x = DSP.label_2_float(x, bits) gen_wav = model.generate(mels=m, batched=self.train_cfg['gen_batched'], target=self.train_cfg['target'], overlap=self.train_cfg['overlap'], mu_law=self.dsp.mu_law, silent=True) gen_wavs.append(gen_wav) y_mel = self.dsp.wav_to_mel(x.squeeze(), normalize=False) y_mel = torch.tensor(y_mel).to(device) y_hat_mel = self.dsp.wav_to_mel(gen_wav, normalize=False) y_hat_mel = torch.tensor(y_hat_mel).to(device) loss = F.l1_loss(y_hat_mel, y_mel) mel_losses.append(loss.item()) self.writer.add_audio(tag=f'Validation_Samples/target_{i}', snd_tensor=x, global_step=model.step, sample_rate=self.dsp.sample_rate) self.writer.add_audio(tag=f'Validation_Samples/generated_{i}', snd_tensor=gen_wav, global_step=model.step, sample_rate=self.dsp.sample_rate) return sum(mel_losses) / len(mel_losses), gen_wavs[0]
def generate_samples(self, model: WaveRNN, session: VocSession) -> Tuple[float, list]: """ Generates audio samples to cherry-pick models. To evaluate audio quality we calculate the l1 distance between mels of predictions and targets. """ model.eval() mel_losses = [] gen_wavs = [] device = next(model.parameters()).device for i, (m, x) in enumerate(session.val_set_samples, 1): if i > hp.voc_gen_num_samples: break x = x[0].numpy() bits = 16 if hp.voc_mode == 'MOL' else hp.bits if hp.mu_law and hp.voc_mode != 'MOL': x = decode_mu_law(x, 2**bits, from_labels=True) else: x = label_2_float(x, bits) gen_wav = model.generate(mels=m, save_path=None, batched=hp.voc_gen_batched, target=hp.voc_target, overlap=hp.voc_overlap, mu_law=hp.mu_law, silent=True) gen_wavs.append(gen_wav) y_mel = raw_melspec(x.squeeze()) y_mel = torch.tensor(y_mel).to(device) y_hat_mel = raw_melspec(gen_wav) y_hat_mel = torch.tensor(y_hat_mel).to(device) loss = F.l1_loss(y_hat_mel, y_mel) mel_losses.append(loss.item()) self.writer.add_audio(tag=f'Validation_Samples/target_{i}', snd_tensor=x, global_step=model.step, sample_rate=hp.sample_rate) self.writer.add_audio(tag=f'Validation_Samples/generated_{i}', snd_tensor=gen_wav, global_step=model.step, sample_rate=hp.sample_rate) return sum(mel_losses) / len(mel_losses), gen_wavs[0]
                upsample_factors=hp.voc_upsample_factors,
                feat_dims=hp.num_mels,
                compute_dims=hp.voc_compute_dims,
                res_out_dims=hp.voc_res_out_dims,
                res_blocks=hp.voc_res_blocks,
                hop_length=hp.hop_length,
                sample_rate=hp.sample_rate,
                pad_val=hp.voc_pad_val,
                mode=hp.voc_mode).cuda()

paths = Paths(hp.data_path, hp.voc_model_id, hp.tts_model_id)

restore_path = args.weights if args.weights else paths.voc_latest_weights
model.restore(restore_path)
model.eval()

if hp.amp:
    model, _ = amp.initialize(model, [], opt_level='O3')

simple_table([('Generation Mode', 'Batched' if batched else 'Unbatched'),
              ('Target Samples', target if batched else 'N/A'),
              ('Overlap Samples', overlap if batched else 'N/A')])

k = model.get_step() // 1000

for file_name in os.listdir(args.dir):
    if file_name.endswith('.npy'):
        mel = np.load(os.path.join(args.dir, file_name))
        mel = torch.tensor(mel).unsqueeze(0)
        batch_str = f'gen_batched_target{target}_overlap{overlap}' if batched else 'gen_NOT_BATCHED'
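# --- Illustrative sketch (not from the original source) ---
# What 'Batched' generation with `target` and `overlap` means conceptually:
# the conditioning sequence is split into overlapping folds that can be
# decoded in parallel and later cross-faded back together. This is a
# simplified sketch of the folding step, not the repo's exact implementation.
import torch
import torch.nn.functional as F

def fold_with_overlap(x: torch.Tensor, target: int, overlap: int) -> torch.Tensor:
    # x: (1, time, features) -> (num_folds, target + 2 * overlap, features)
    _, total_len, features = x.size()
    num_folds = (total_len - overlap) // (target + overlap)
    extended_len = num_folds * (target + overlap) + overlap
    remaining = total_len - extended_len
    if remaining != 0:
        # Pad the time dimension so the final fold is full-length.
        num_folds += 1
        padding = target + 2 * overlap - remaining
        x = F.pad(x, (0, 0, 0, padding))
    folded = torch.zeros(num_folds, target + 2 * overlap, features)
    for i in range(num_folds):
        start = i * (target + overlap)
        folded[i] = x[0, start:start + target + 2 * overlap, :]
    return folded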
def main():
    # Parse Arguments
    parser = argparse.ArgumentParser(description='TTS Generator')
    parser.add_argument('--mel', type=str, help='[string/path] path to test mel file')
    parser.add_argument('--hp_file', metavar='FILE', default='hparams.py',
                        help='The file to use for the hyperparameters')
    parser.add_argument('--batched', '-b', dest='batched', action='store_true',
                        help='Fast Batched Generation')
    parser.add_argument('--voc_weights', type=str,
                        help='[string/path] Load in different WaveRNN weights',
                        default='pretrained/wave_800K.pyt')
    args = parser.parse_args()

    if not os.path.exists('onnx'):
        os.mkdir('onnx')

    hp.configure(args.hp_file)

    device = torch.device('cpu')
    print('Using device:', device)

    print('\nInitialising WaveRNN Model...\n')
    # Instantiate WaveRNN Model
    voc_model = WaveRNN(rnn_dims=hp.voc_rnn_dims,
                        fc_dims=hp.voc_fc_dims,
                        bits=hp.bits,
                        pad=hp.voc_pad,
                        upsample_factors=hp.voc_upsample_factors,
                        feat_dims=hp.num_mels,
                        compute_dims=hp.voc_compute_dims,
                        res_out_dims=hp.voc_res_out_dims,
                        res_blocks=hp.voc_res_blocks,
                        hop_length=hp.hop_length,
                        sample_rate=hp.sample_rate,
                        mode=hp.voc_mode).to(device)

    voc_load_path = args.voc_weights
    voc_model.load(voc_load_path)

    # Wrap the model into the two ONNX-exportable sub-modules.
    voc_upsampler = WaveRNNUpsamplerONNX(voc_model, args.batched, hp.voc_target, hp.voc_overlap)
    voc_infer = WaveRNNONNX(voc_model)

    voc_model.eval()
    voc_upsampler.eval()
    voc_infer.eval()

    opset_version = 11

    with torch.no_grad():
        mels = np.load(args.mel)
        mels = torch.from_numpy(mels)
        mels = mels.unsqueeze(0)
        mels = voc_upsampler.pad_tensor(mels)

        # Export the mel upsampler graph.
        mels_onnx = mels.clone()
        torch.onnx.export(voc_upsampler,
                          mels_onnx,
                          './onnx/wavernn_upsampler.onnx',
                          opset_version=opset_version,
                          do_constant_folding=True,
                          input_names=['mels'],
                          output_names=['upsample_mels', 'aux'])

        mels, aux = voc_upsampler(mels)
        mels = mels[:, 550:-550, :]
        mels, aux = voc_upsampler.fold(mels, aux)

        h1, h2, x = voc_infer.get_initial_parameters(mels)
        aux_split = voc_infer.split_aux(aux)

        b_size, seq_len, _ = mels.size()

        if seq_len:
            # Export the single-step autoregressive RNN graph, traced on the
            # inputs for the first time step.
            m_t = mels[:, 0, :]
            a1_t, a2_t, a3_t, a4_t = (a[:, 0, :] for a in aux_split)
            rnn_input = (m_t, a1_t, a2_t, a3_t, a4_t, h1, h2, x)
            torch.onnx.export(voc_infer,
                              rnn_input,
                              './onnx/wavernn_rnn.onnx',
                              opset_version=opset_version,
                              do_constant_folding=True,
                              input_names=['m_t', 'a1_t', 'a2_t', 'a3_t', 'a4_t', 'h1', 'h2', 'x'],
                              output_names=['logits', 'h1', 'h2'])

    print('Done!')
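# --- Illustrative sketch (not from the original source) ---
# One way to sanity-check the two exported graphs with onnxruntime. The input
# and output names match the export above; 'sample_mel.npy' is a hypothetical
# mel file of shape (num_mels, frames). Note that the upsampler graph was
# exported on a padded tensor, so real inputs need the same pad_tensor()
# padding first. The full vocoder would then run the RNN graph step by step,
# sampling from 'logits' and feeding the result back in as 'x'.
import numpy as np
import onnxruntime as ort

upsampler = ort.InferenceSession('./onnx/wavernn_upsampler.onnx',
                                 providers=['CPUExecutionProvider'])
rnn = ort.InferenceSession('./onnx/wavernn_rnn.onnx',
                           providers=['CPUExecutionProvider'])

mels = np.load('sample_mel.npy').astype(np.float32)[np.newaxis, ...]
upsample_mels, aux = upsampler.run(['upsample_mels', 'aux'], {'mels': mels})
print('upsampled mels:', upsample_mels.shape, 'aux:', aux.shape)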