def __getitem__(self, index):
    # Read audio
    filename = self.audio_files[index]
    audio_mel, sampling_rate = load_wav_to_torch(self.dir_normal + '/' + filename)
    audio, sampling_rate = load_wav_to_torch(self.dir_hi + '/' + filename)
    if sampling_rate != self.sampling_rate:
        raise ValueError("{} SR doesn't match target {} SR".format(
            sampling_rate, self.sampling_rate))

    # Take segment
    if audio.size(0) >= self.segment_length:
        max_audio_start = audio.size(0) - self.segment_length
        audio_start = random.randint(0, max_audio_start)
        audio = audio[audio_start:audio_start + self.segment_length]
        audio_mel = audio_mel[audio_start:audio_start + self.segment_length]
    else:
        audio = torch.nn.functional.pad(
            audio, (0, self.segment_length - audio.size(0)), 'constant').data
        # pad audio_mel by its own length so both signals end up segment_length long
        audio_mel = torch.nn.functional.pad(
            audio_mel, (0, self.segment_length - audio_mel.size(0)), 'constant').data

    mel = self.get_mel(audio_mel)
    audio = audio / MAX_WAV_VALUE
    return (mel, audio)
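# Example (sketch): a Dataset exposing the __getitem__ above is normally consumed
# through a DataLoader; "dataset" here stands for an instance of the (unnamed)
# dataset class and is not defined in the original code.
#
#     from torch.utils.data import DataLoader
#     loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)
#     for mel, audio in loader:
#         pass  # mel: [B, n_mel, T'], audio: [B, segment_length], scaled to roughly [-1, 1]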
def load_buffer(self):
    num_files = len(self.audio_files)
    for i in tqdm(range(num_files)):
        filename = self.audio_files[i]
        audio, sampling_rate = ms.load_wav_to_torch(filename)
        if sampling_rate != self.sampling_rate:
            raise ValueError("{} SR doesn't match target {} SR".format(
                sampling_rate, self.sampling_rate))
        self.all_length += audio.size(0)
        if i == 0:
            self.all_audio = audio
        else:
            self.all_audio = torch.cat((self.all_audio, audio), 0)
    print("All audio has been loaded, total length: {}".format(
        self.all_length))
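# Note (sketch): torch.cat inside the loop above re-copies the whole buffer on every
# file, which is O(n^2) in total audio length. An equivalent variant (not part of the
# original code) collects the tensors first and concatenates once:
#
#     def load_buffer(self):
#         chunks = []
#         for i in tqdm(range(len(self.audio_files))):
#             audio, sampling_rate = ms.load_wav_to_torch(self.audio_files[i])
#             if sampling_rate != self.sampling_rate:
#                 raise ValueError("{} SR doesn't match target {} SR".format(
#                     sampling_rate, self.sampling_rate))
#             self.all_length += audio.size(0)
#             chunks.append(audio)
#         self.all_audio = torch.cat(chunks, 0)  # single concatenation at the end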
def DNN_test(checkpoint_path, filename):
    model = DNNnet(**DNN_net_config)
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model_for_loading = checkpoint_dict['model']
    model.load_state_dict(model_for_loading.state_dict())
    print("Loaded checkpoint '{}'".format(checkpoint_path))
    model.eval()

    Data_gen = spec2load(**DNN_data_config)
    audio, sr = ms.load_wav_to_torch(filename)

    # Take segment
    if audio.size(0) >= 2240:
        max_audio_start = audio.size(0) - 2240
        audio_start = random.randint(0, max_audio_start)
        audio = audio[audio_start:audio_start + 2240]
    else:
        audio = torch.nn.functional.pad(
            audio, (2240 - audio.size(0), 0), 'constant').data

    spec = Data_gen.get_spec(audio)
    feed_spec = spec[:, :-2]
    targ_spec = spec[:, -2:]
    feed_spec = feed_spec.unsqueeze(0)
    gener_spec = model.forward(feed_spec)
    # gener_linear = gener_linear.squeeze().view(-1, 2).T
    gener_spec = gener_spec.detach().numpy()
    targ_spec = torch.cat((targ_spec[:, 0], targ_spec[:, 1]), 0)
    targ_spec = targ_spec.numpy()

    # targ_mel_db = librosa.power_to_db(targ_mel[0], ref=np.max)
    # gener_mel_db = librosa.power_to_db(gener_mel[0], ref=np.max)
    plt.figure()
    plt.plot(targ_spec)
    # librosa.display.specshow(targ_mel, x_axis='time', y_axis='mel')
    plt.plot(gener_spec[0], 'r')
    # librosa.display.specshow(gener_mel, x_axis='time', y_axis='mel')
    plt.show()
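# Example usage (sketch): the checkpoint and wav paths below are hypothetical
# placeholders; DNN_net_config and DNN_data_config are assumed to be defined at
# module level, as used inside DNN_test above.
if __name__ == "__main__":
    DNN_test('checkpoints/DNN_latest.pt', 'test_audio/example.wav')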
def validate(model, loader_STFT, STFTs, logger, iteration, validation_files, speaker_lookup,
             sigma, output_directory, data_config, save_audio=True, max_length_s=3,
             files_to_process=9):
    print("Validating... ", end="")
    from mel2samp import DTW
    val_sigma = sigma * 1.00
    model.eval()
    val_start_time = time.time()
    STFT_elapsed = samples_processed = 0
    with torch.no_grad():
        with torch.random.fork_rng(devices=[0, ]):
            torch.random.manual_seed(0)  # use the same Z / random seed during validation so results are more consistent and comparable.

            with open(validation_files, encoding='utf-8') as f:
                audiopaths_and_melpaths = [line.strip().split('|') for line in f]

            if next(model.parameters()).type() == "torch.cuda.HalfTensor":
                model_type = "half"
            else:
                model_type = "float"

            timestr = time.strftime("%Y_%m_%d-%H_%M_%S")
            total_MAE = total_MSE = total = files_processed = 0
            input_mels = []
            gt_mels = []
            pred_mels = []
            MAE_specs = []
            for i, (audiopath, melpath, *remaining) in enumerate(audiopaths_and_melpaths):
                if files_processed >= files_to_process:  # number of validation files to run.
                    break
                audio = load_wav_to_torch(audiopath)[0] / 32768.0  # load audio from wav file to tensor
                if audio.shape[0] > (data_config['sampling_rate'] * max_length_s):
                    continue  # ignore audio over max_length_seconds

                gt_mel = loader_STFT.mel_spectrogram(audio.unsqueeze(0)).cuda()  # [T] -> [1, T] -> [1, n_mel, T]

                if 'load_hidden_from_disk' in data_config.keys() and data_config['load_hidden_from_disk']:
                    mel = None
                    hidden_path = remaining[1]
                    model_input = np.load(hidden_path)  # load tacotron hidden state from file into numpy arr
                    model_input = torch.from_numpy(model_input).unsqueeze(0).cuda()  # from numpy arr to tensor on GPU
                else:
                    if loader_STFT and data_config['load_mel_from_disk'] < 0.2:
                        mel = None
                        model_input = gt_mel.clone()
                    else:
                        mel = np.load(melpath)  # load mel from file into numpy arr
                        mel = torch.from_numpy(mel).unsqueeze(0).cuda()  # from numpy arr to tensor on GPU
                        assert mel[:, :gt_mel.shape[1], :].shape == gt_mel.shape, f'shapes of {mel[:, :gt_mel.shape[1], :].shape} and {gt_mel.shape} do not match.'
                        #if torch.nn.functional.mse_loss(mel[:, :gt_mel.shape[1], :], gt_mel) > 0.3:
                        #    continue  # skip validation files that significantly vary from the target.
                        if model.has_logvar_channels:
                            if mel.shape[1] == model.n_mel_channels * 2:
                                mel, logvar = mel.chunk(2, dim=1)  # [1, n_mel*2, T] -> [1, n_mel, T], [1, n_mel, T]
                                mel = DTW(mel, gt_mel, scale_factor=8, range_=5)
                                model_input = torch.cat((mel, logvar), dim=1)
                            else:
                                raise Exception("Loaded mel from disk has wrong shape.")
                        else:
                            if mel.shape[1] == model.n_mel_channels * 2:
                                mel = mel.chunk(2, dim=1)[0]  # [1, n_mel*2, T] -> [1, n_mel, T]
                            #mel = DTW(mel, gt_mel, scale_factor=8, range_=5)
                            model_input = mel

                if hasattr(model, 'multispeaker') and model.multispeaker == True:
                    assert len(remaining), f"Speaker ID missing while multispeaker == True.\nLine: {i}\n'{'|'.join([audiopath, melpath])}'"
                    speaker_id = remaining[0]
                    speaker_id = torch.IntTensor([speaker_lookup[int(speaker_id)], ])
                    speaker_id = speaker_id.cuda(non_blocking=True).long()
                else:
                    speaker_id = None

                if model_type == "half":
                    model_input = model_input.half()  # for fp16 training

                audio_waveglow = model.infer(model_input, speaker_id, sigma=val_sigma).cpu().float()

                audio = audio.squeeze().unsqueeze(0)  # crush extra dimensions and shape for STFT
                audio_waveglow = audio_waveglow.squeeze().unsqueeze(0).clamp(min=-0.999, max=0.999)  # [1, T] crush extra dimensions and shape for STFT
                audio_waveglow[torch.isnan(audio_waveglow) | torch.isinf(audio_waveglow)] = 0.0  # and clamp any values over/under |1.0| (which should only exist very early in training)

                STFT_start_time = time.time()
                for j, STFT in enumerate(STFTs):  # check Spectrogram Error with multiple window sizes
                    input_mels.append(mel)
                    mel_GT = STFT.mel_spectrogram(audio)  # [1, T] -> [1, n_mel, T//hop_len]
                    gt_mels.append(mel_GT)
                    mel_waveglow = STFT.mel_spectrogram(audio_waveglow)[:, :, :mel_GT.shape[-1]]  # [1, T] -> [1, n_mel, T//hop_len]
                    pred_mels.append(mel_waveglow)

                    MSE = (torch.nn.MSELoss()(mel_waveglow, mel_GT)).item()  # get MSE (Mean Squared Error) between Ground Truth and WaveGlow inferred spectrograms.
                    MAE_spec = torch.nn.L1Loss(reduction='none')(mel_waveglow, mel_GT)
                    MAE = (MAE_spec.mean()).item()  # get MAE (Mean Absolute Error) between Ground Truth and WaveGlow inferred spectrograms.
                    MAE_specs.append(MAE_spec)

                    total_MAE += MAE
                    total_MSE += MSE
                    total += 1
                STFT_elapsed += time.time() - STFT_start_time

                if save_audio:
                    audio_path = os.path.join(output_directory, "samples", str(iteration) + "-" + timestr, os.path.basename(audiopath))  # Write audio to checkpoint_directory/iteration/audiofilename.wav
                    os.makedirs(os.path.join(output_directory, "samples", str(iteration) + "-" + timestr), exist_ok=True)
                    sf.write(audio_path, audio_waveglow.squeeze().cpu().numpy(), data_config['sampling_rate'], "PCM_16")  # save waveglow sample

                    audio_path = os.path.join(output_directory, "samples", "Ground Truth", os.path.basename(audiopath))  # Write audio to checkpoint_directory/iteration/audiofilename.wav
                    if not os.path.exists(audio_path):
                        os.makedirs(os.path.join(output_directory, "samples", "Ground Truth"), exist_ok=True)
                        sf.write(audio_path, audio.squeeze().cpu().numpy(), data_config['sampling_rate'], "PCM_16")  # save ground truth

                files_processed += 1
                samples_processed += audio_waveglow.shape[-1]

    if total:
        average_MSE = total_MSE / total
        average_MAE = total_MAE / total
        logger.add_scalar('val_MSE', average_MSE, iteration)
        logger.add_scalar('val_MAE', average_MAE, iteration)

        for idx, (gt_mel, pred_mel, input_mel, mae_mel) in enumerate(zip(gt_mels[-6:], pred_mels[-6:], input_mels[-6:], MAE_specs[-6:])):
            logger.add_image(f'mel_{idx}/pred',
                             plot_spectrogram_to_numpy(pred_mel[0].data.cpu().numpy(), range=[-11.5, 2.0]),
                             iteration, dataformats='HWC')
            if mae_mel is not None:
                logger.add_image(f'mel_{idx}/zmae',
                                 plot_spectrogram_to_numpy(mae_mel[0].data.cpu().numpy(), range=[0.0, 2.5]),
                                 iteration, dataformats='HWC')
            if iteration % 10000 == 0:  # target doesn't change unless batch size or dataset changes so only needs to be plotted once in a while.
                logger.add_image(f'mel_{idx}/target',
                                 plot_spectrogram_to_numpy(gt_mel[0].data.cpu().numpy(), range=[-11.5, 2.0]),
                                 iteration, dataformats='HWC')
                if input_mel is not None:
                    logger.add_image(f'mel_{idx}/input',
                                     plot_spectrogram_to_numpy(input_mel[0].data.cpu().numpy(), range=[-11.5, 2.0]),
                                     iteration, dataformats='HWC')

        time_elapsed = time.time() - val_start_time
        time_elapsed_without_stft = time_elapsed - STFT_elapsed
        samples_per_second = samples_processed / time_elapsed_without_stft
        print(f"[Avg MSE: {average_MSE:.6f} MAE: {average_MAE:.6f}]",
              f"[{time_elapsed_without_stft:.3f}s]",
              f"[{time_elapsed:.3f}s_stft]",
              f"[{time_elapsed_without_stft/files_processed:.3f}s/file]",
              f"[{samples_per_second/data_config['sampling_rate']:.3f}rtf]",
              f"[{samples_per_second:,.0f}samples/s]")
        logger.add_scalar('val_rtf', samples_per_second / data_config['sampling_rate'], iteration)
    else:
        average_MSE = 1e3
        average_MAE = 1e3
        print("Average MSE: N/A", "Average MAE: N/A")

    for convinv in model.convinv:
        if hasattr(convinv, 'W_inverse'):
            delattr(convinv, "W_inverse")  # clear Inverse Weights.
    if hasattr(model, 'iso226'):
        delattr(model, 'iso226')
    mel = speaker_id = None  # del GPU based tensors.
    torch.cuda.empty_cache()  # clear cache for next training
    model.train()
    return average_MSE, average_MAE
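# Example usage (sketch): in a training script, validate() above would typically be
# called periodically with the same objects used for training; every argument name
# below refers to the function's own parameters, nothing new is assumed.
#
#     val_MSE, val_MAE = validate(model, loader_STFT, STFTs, logger, iteration,
#                                 validation_files, speaker_lookup, sigma,
#                                 output_directory, data_config, save_audio=True)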
def main(files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16,
         denoiser_strength, args):
    #mel_files = files_to_list(mel_files)
    #print(mel_files)
    files = ['/local-scratch/fuyang/cmpt726/final_project/cremad/1091_WSI_SAD_XX.wav']
    #files = ['/local-scratch/fuyang/cmpt726/waveglow/data/LJSpeech-1.1/LJ001-0001.wav']
    with open('config.json') as f:
        data = f.read()
    config = json.loads(data)
    waveglow_config = config["waveglow_config"]
    model = WaveGlow(**waveglow_config)
    checkpoint_dict = torch.load('waveglow_256channels_universal_v5.pt', map_location='cpu')
    model_for_loading = checkpoint_dict['model']
    model.load_state_dict(model_for_loading.state_dict())
    model.cuda()
    #waveglow = torch.load(waveglow_path)['model']
    #waveglow = waveglow.remove_weightnorm(waveglow)
    #waveglow.cuda()
    waveglow = model
    if is_fp16:
        from apex import amp
        waveglow, _ = amp.initialize(waveglow, [], opt_level="O1")

    if denoiser_strength > 0:
        denoiser = Denoiser(waveglow).cuda()

    mel_extractor = Get_mel(1024, 256, 1024, args.sampling_rate, 0.0, 8000.0)

    for i, file_path in enumerate(files):
        audio, rate = load_wav_to_torch(file_path)
        if rate != sampling_rate:
            audio = resampy.resample(audio.numpy(), rate, sampling_rate)
            audio = torch.from_numpy(audio).float()
        #if audio.size(0) >= args.segment_length:
        #    max_audio_start = audio.size(0) - args.segment_length
        #    audio_start = random.randint(0, max_audio_start)
        #    audio = audio[audio_start:audio_start+args.segment_length]
        #else:
        #    audio = torch.nn.functional.pad(audio, (0, args.segment_length-audio.size(0)), 'constant').data
        mel = mel_extractor.get_mel(audio)
        audio = audio / MAX_WAV_VALUE

        mel = torch.autograd.Variable(mel.cuda().unsqueeze(0))
        audio = torch.autograd.Variable(audio.cuda().unsqueeze(0))
        audio = audio.half() if is_fp16 else audio
        mel = mel.half() if is_fp16 else mel
        outputs = waveglow((mel, audio))
        z = outputs[0][:, 4:]
        print(outputs)

        mel_up = waveglow.upsample(mel)
        time_cutoff = waveglow.upsample.kernel_size[0] - waveglow.upsample.stride[0]
        mel_up = mel_up[:, :, :-time_cutoff]
        #mel_up = mel_up[:,:,:-(time_cutoff+128)]
        mel_up = mel_up.unfold(2, waveglow.n_group, waveglow.n_group).permute(0, 2, 1, 3)
        mel_up = mel_up.contiguous().view(mel_up.size(0), mel_up.size(1), -1).permute(0, 2, 1)
        audio = z
        mel_up = mel_up[:, :, :audio.size(2)]

        sigma = 0.7
        z_i = 0
        for k in reversed(range(waveglow.n_flows)):
            n_half = int(audio.size(1) / 2)
            audio_0 = audio[:, :n_half, :]
            audio_1 = audio[:, n_half:, :]
            output = waveglow.WN[k]((audio_0, mel_up))
            s = output[:, n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = (audio_1 - b) / torch.exp(s)
            audio = torch.cat([audio_0, audio_1], 1)
            audio = waveglow.convinv[k](audio, reverse=True)
            if k % waveglow.n_early_every == 0 and k > 0:
                z = outputs[0][:, 2 - z_i:4 - z_i]
                #if mel_up.type() == 'torch.cuda.HalfTensor':
                #    z = torch.cuda.HalfTensor(mel_up.size(0), waveglow.n_early_size, mel_up.size(2)).normal_()
                #else:
                #    z = torch.cuda.FloatTensor(mel_up.size(0), waveglow.n_early_size, mel_up.size(2)).normal_()
                audio = torch.cat((sigma * z, audio), 1)

        audio = audio.permute(0, 2, 1).contiguous().view(audio.size(0), -1).data
        audio = audio * MAX_WAV_VALUE
        audio = audio.squeeze()
        audio = audio.cpu().numpy()
        audio = audio.astype('int16')
        audio_path = os.path.join(
            output_dir, "{}_synthesis.wav".format('fuyangz'))
        write(audio_path, sampling_rate, audio)
        print(audio_path)
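# Note (sketch): the reversed-flow loop above mirrors what WaveGlow's infer() does,
# except that the latent z is recovered from a real utterance (outputs[0]) instead of
# being sampled from N(0, sigma^2); that is what makes later latent manipulation
# possible. Plain resynthesis from noise would be (assuming the standard infer() API):
#
#     with torch.no_grad():
#         audio_hat = waveglow.infer(mel, sigma=0.7)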
def main(style, waveglow_path, sigma, output_dir, sampling_rate, is_fp16,
         denoiser_strength, args):
    #mel_files = files_to_list(mel_files)
    #print(mel_files)
    dataset = voice_dataset(dataBase={
        'ravdess': './our_data/ravdess',
        'cremad': './our_data/cremad'
    }, style=('happy', 'sad', 'angry'))
    #print(len(dataset.final_data['happy']))
    #sample = dataset.pick_one_random_sample('happy')
    files = dataset.final_data[style]
    #files = ['/local-scratch/fuyang/cmpt726/waveglow/data/LJSpeech-1.1/LJ001-0001.wav']
    with open('config.json') as f:
        data = f.read()
    config = json.loads(data)
    waveglow_config = config["waveglow_config"]
    model = WaveGlow(**waveglow_config)
    checkpoint_dict = torch.load('waveglow_256channels_universal_v5.pt', map_location='cpu')
    model_for_loading = checkpoint_dict['model']
    model.load_state_dict(model_for_loading.state_dict())
    model.cuda()
    waveglow = model
    if is_fp16:
        from apex import amp
        waveglow, _ = amp.initialize(waveglow, [], opt_level="O1")

    if denoiser_strength > 0:
        denoiser = Denoiser(waveglow).cuda()

    mel_extractor = Get_mel(1024, 256, 1024, args.sampling_rate, 0.0, 8000.0)

    avg_z = np.zeros(8)
    _count = 0
    for i, (_, file_path) in enumerate(files):
        if i > 50:
            break
        try:
            audio, rate = load_wav_to_torch(file_path)
            if rate != sampling_rate:
                audio = resampy.resample(audio.numpy(), rate, sampling_rate)
                audio = torch.from_numpy(audio).float()
            #if audio.size(0) >= args.segment_length:
            #    max_audio_start = audio.size(0) - args.segment_length
            #    audio_start = random.randint(0, max_audio_start)
            #    audio = audio[audio_start:audio_start+args.segment_length]
            #else:
            #    audio = torch.nn.functional.pad(audio, (0, args.segment_length-audio.size(0)), 'constant').data
            mel = mel_extractor.get_mel(audio)
            audio = audio / MAX_WAV_VALUE

            mel = torch.autograd.Variable(mel.cuda().unsqueeze(0))
            audio = torch.autograd.Variable(audio.cuda().unsqueeze(0))
            audio = audio.half() if is_fp16 else audio
            mel = mel.half() if is_fp16 else mel
            outputs = waveglow((mel, audio))
            avg_z += outputs[0].squeeze(0).mean(1).detach().cpu().numpy()
            _count += 1
            z = outputs[0][:, 4:]
            #print(outputs)

            mel_up = waveglow.upsample(mel)
            time_cutoff = waveglow.upsample.kernel_size[0] - waveglow.upsample.stride[0]
            mel_up = mel_up[:, :, :-time_cutoff]
            #mel_up = mel_up[:,:,:-(time_cutoff+128)]
            mel_up = mel_up.unfold(2, waveglow.n_group, waveglow.n_group).permute(0, 2, 1, 3)
            mel_up = mel_up.contiguous().view(mel_up.size(0), mel_up.size(1), -1).permute(0, 2, 1)
            audio = z
            mel_up = mel_up[:, :, :audio.size(2)]

            sigma = 0.7
            z_i = 0
            for k in reversed(range(waveglow.n_flows)):
                n_half = int(audio.size(1) / 2)
                audio_0 = audio[:, :n_half, :]
                audio_1 = audio[:, n_half:, :]
                output = waveglow.WN[k]((audio_0, mel_up))
                s = output[:, n_half:, :]
                b = output[:, :n_half, :]
                audio_1 = (audio_1 - b) / torch.exp(s)
                audio = torch.cat([audio_0, audio_1], 1)
                audio = waveglow.convinv[k](audio, reverse=True)
                if k % waveglow.n_early_every == 0 and k > 0:
                    z = outputs[0][:, 2 - z_i:4 - z_i]
                    #if mel_up.type() == 'torch.cuda.HalfTensor':
                    #    z = torch.cuda.HalfTensor(mel_up.size(0), waveglow.n_early_size, mel_up.size(2)).normal_()
                    #else:
                    #    z = torch.cuda.FloatTensor(mel_up.size(0), waveglow.n_early_size, mel_up.size(2)).normal_()
                    audio = torch.cat((sigma * z, audio), 1)

            audio = audio.permute(0, 2, 1).contiguous().view(audio.size(0), -1).data
            audio = audio * MAX_WAV_VALUE
            audio = audio.squeeze()
            audio = audio.cpu().numpy()
            audio = audio.astype('int16')
            audio_path = os.path.join(
                output_dir, "{}_synthesis.wav".format(file_path[:-4]))
            if not os.path.exists(os.path.join(*audio_path.split('/')[:-1])):
                os.makedirs(os.path.join(*audio_path.split('/')[:-1]), exist_ok=True)
            write(audio_path, sampling_rate, audio)
            print(audio_path)
        except:
            continue

    avg_z = avg_z / _count
    np.save(style, avg_z)
def test_mel2samp():
    """Test mel2samp modules on example data."""
    from mel2samp import Mel2Samp
    hparams = hparams_class()
    passed = 0

    # test filelist loader
    try:
        from mel2samp import load_filepaths_and_text
        audio_files = load_filepaths_and_text("code_tests/test_materials/filelists/validation_utf8.txt")
        passed += 1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ Load Filepaths and Text (UTF-8)")
        traceback.print_exc(file=sys.stdout)
        print("\n")

    # test filelist checker
    try:
        assert audio_files
        from mel2samp import check_files
        audio_files = check_files(audio_files, hparams)
        assert len(audio_files) == 1
        passed += 1
        print("--PASSED--\n")
        del audio_files
    except Exception as ex:
        print("--EXCEPTION-- @ Check Files")
        traceback.print_exc(file=sys.stdout)
        print("\n")

    # test initialization
    try:
        trainset = Mel2Samp(hparams)
        passed += 1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ Mel2Samp Initialization")
        traceback.print_exc(file=sys.stdout)
        print("\n")

    # test 16-BIT .wav to torch
    try:
        from mel2samp import load_wav_to_torch
        x, sr = load_wav_to_torch("code_tests/test_materials/audio_0/example_16bits.wav")
        assert len(x)
        assert x.max() <= 2**15
        assert x.min() >= -(2**15)
        assert sr == 48000
        passed += 1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ Load 16-BIT .wav to Pytorch")
        traceback.print_exc(file=sys.stdout)
        print("\n")

    # test 24-BIT .wav to torch
    try:
        from mel2samp import load_wav_to_torch
        x, sr = load_wav_to_torch("code_tests/test_materials/audio_0/example_24bits.wav")
        assert len(x)
        assert x.max() <= 2**23
        assert x.min() >= -(2**23)
        assert sr == 48000
        passed += 1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ Load 24-BIT .wav to Pytorch")
        traceback.print_exc(file=sys.stdout)
        print("\n")

    # test 32-BIT .wav to torch
    try:
        from mel2samp import load_wav_to_torch
        x, sr = load_wav_to_torch("code_tests/test_materials/audio_0/example_32bits.wav")
        assert len(x)
        assert x.max() <= 2**31
        assert x.min() >= -(2**31)
        assert sr == 48000
        passed += 1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ Load 32-BIT .wav to Pytorch")
        traceback.print_exc(file=sys.stdout)
        print("\n")

    # test 32-BIT .mp3 to torch
    try:
        from mel2samp import load_wav_to_torch
        x, sr = load_wav_to_torch("code_tests/test_materials/audio_0/example_32bits.mp3")
        assert len(x)
        assert x.max() <= 2**31
        assert x.min() >= -(2**31)
        assert sr == 48000
        passed += 1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ Load 32-BIT .mp3 to Pytorch")
        traceback.print_exc(file=sys.stdout)
        print("\n")

    # test 16-BIT .wav to mel
    try:
        x, sr = load_wav_to_torch("code_tests/test_materials/audio_0/example_16bits.wav")
        x = trainset.get_mel(x)
        passed += 1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ 16-BIT .wav to Mel-spec")
        traceback.print_exc(file=sys.stdout)
        print("\n")

    # test 24-BIT .wav to mel
    try:
        x, sr = load_wav_to_torch("code_tests/test_materials/audio_0/example_24bits.wav")
        x = trainset.get_mel(x)
        passed += 1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ 24-BIT .wav to Mel-spec")
        traceback.print_exc(file=sys.stdout)
        print("\n")

    # test 32-BIT .wav to mel
    try:
        x, sr = load_wav_to_torch("code_tests/test_materials/audio_0/example_32bits.wav")
        x = trainset.get_mel(x)
        passed += 1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ 32-BIT .wav to Mel-spec")
        traceback.print_exc(file=sys.stdout)
        print("\n")

    # test 32-BIT .mp3 to mel
    try:
        x, sr = load_wav_to_torch("code_tests/test_materials/audio_0/example_32bits.mp3")
        x = trainset.get_mel(x)
        passed += 1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ 32-BIT .mp3 to Mel-spec")
        traceback.print_exc(file=sys.stdout)
        print("\n")

    # test __getitem__ with load_mel_from_disk = False
    try:
        assert trainset  # This test will fail if Mel2Samp cannot initialize
        trainset.load_mel_from_disk = False
        trainset.__getitem__(0)
        passed += 1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ __getitem__ with load_mel_from_disk = False")
        traceback.print_exc(file=sys.stdout)
        print("\n")

    # test __getitem__ with load_mel_from_disk = True
    try:
        assert trainset  # This test will fail if Mel2Samp cannot initialize
        trainset.load_mel_from_disk = True
        trainset.__getitem__(0)
        passed += 1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ __getitem__ with load_mel_from_disk = True")
        traceback.print_exc(file=sys.stdout)
        print("\n")

    # test initialization with Pre-empthasis
    try:
        trainset = None
        hparams.preempthasis = 0.98
        trainset = Mel2Samp(hparams)
        passed += 1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ Mel2Samp with Pre-empthasis Initialization")
        traceback.print_exc(file=sys.stdout)
        print("\n")

    # test __getitem__ with Pre-empthasis
    try:
        assert trainset  # This test will fail if Mel2Samp cannot initialize
        trainset.load_mel_from_disk = False
        trainset.__getitem__(0)
        passed += 1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ __getitem__ with Pre-empthasis")
        traceback.print_exc(file=sys.stdout)
        print("\n")
def validate(model, loader_STFT, STFTs, logger, iteration, validation_files, speaker_lookup,
             sigma, output_directory, data_config, save_audio=True, max_length_s=5):
    from mel2samp import load_wav_to_torch
    from scipy import signal
    val_sigma = sigma * 0.9
    model.eval()
    with torch.no_grad():
        with open(validation_files, encoding='utf-8') as f:
            audiopaths_and_melpaths = [line.strip().split('|') for line in f]

        if list(model.parameters())[0].type() == "torch.cuda.HalfTensor":
            model_type = "half"
        else:
            model_type = "float"

        timestr = time.strftime("%Y_%m_%d-%H_%M_%S")
        total_MAE = total_MSE = total = 0
        for i, (audiopath, melpath, *remaining) in enumerate(audiopaths_and_melpaths):
            if i > 60:
                break  # debug
            audio = load_wav_to_torch(audiopath)[0] / 32768.0  # load audio from wav file to tensor
            if audio.shape[0] > (data_config['sampling_rate'] * max_length_s):
                continue  # ignore audio over max_length_seconds

            if loader_STFT:
                mel = loader_STFT.mel_spectrogram(audio.unsqueeze(0)).cuda()
            else:
                mel = np.load(melpath)  # load mel from file into numpy arr
                mel = torch.from_numpy(mel).unsqueeze(0).cuda()  # from numpy arr to tensor on GPU

            if hasattr(model, 'multispeaker') and model.multispeaker == True:
                assert len(remaining), f"Speaker ID missing while multispeaker == True.\nLine: {i}\n'{'|'.join([audiopath, melpath])}'"
                speaker_id = remaining[0]
                speaker_id = torch.IntTensor([speaker_lookup[int(speaker_id)]])
                speaker_id = speaker_id.cuda(non_blocking=True).long()
            else:
                speaker_id = None

            if model_type == "half":
                mel = mel.half()  # for fp16 training

            audio_waveglow = model.infer(mel, speaker_id, sigma=val_sigma)
            audio_waveglow = audio_waveglow.cpu().float()

            if data_config['preempthasis']:  # inverse-preempthasis
                audio_waveglow = audio_waveglow.squeeze()
                audio_waveglow = torch.from_numpy(
                    signal.lfilter([1], [1, -float(data_config['preempthasis'])], audio_waveglow.numpy())
                ).float()  # de-preempthasis (scipy signal is faster than the pytorch implementation for some reason /shrug)

            audio = audio.squeeze().unsqueeze(0)  # crush extra dimensions and shape for STFT
            audio_waveglow = audio_waveglow.squeeze().unsqueeze(0)  # crush extra dimensions and shape for STFT
            audio_waveglow = audio_waveglow.clamp(-1, 1)  # clamp any values over/under |1.0| (which should only exist very early in training)

            for STFT in STFTs:  # check Spectrogram Error with multiple window sizes
                mel_GT = STFT.mel_spectrogram(audio)
                try:
                    mel_waveglow = STFT.mel_spectrogram(audio_waveglow)[:, :, :mel_GT.shape[-1]]
                except AssertionError as ex:
                    print(ex)
                    continue

                MSE = (torch.nn.MSELoss()(mel_waveglow, mel_GT)).item()  # get MSE (Mean Squared Error) between Ground Truth and WaveGlow inferred spectrograms.
                MAE = (torch.nn.L1Loss()(mel_waveglow, mel_GT)).item()  # get MAE (Mean Absolute Error) between Ground Truth and WaveGlow inferred spectrograms.
                total_MAE += MAE
                total_MSE += MSE
                total += 1

            if save_audio:
                audio_path = os.path.join(output_directory, "samples", str(iteration) + "-" + timestr, os.path.basename(audiopath))  # Write audio to checkpoint_directory/iteration/audiofilename.wav
                os.makedirs(os.path.join(output_directory, "samples", str(iteration) + "-" + timestr), exist_ok=True)
                sf.write(audio_path, audio_waveglow.squeeze().cpu().numpy(), data_config['sampling_rate'], "PCM_16")  # save waveglow sample

                audio_path = os.path.join(output_directory, "samples", "Ground Truth", os.path.basename(audiopath))  # Write audio to checkpoint_directory/iteration/audiofilename.wav
                if not os.path.exists(audio_path):
                    os.makedirs(os.path.join(output_directory, "samples", "Ground Truth"), exist_ok=True)
                    sf.write(audio_path, audio.squeeze().cpu().numpy(), data_config['sampling_rate'], "PCM_16")  # save ground truth

    for convinv in model.convinv:
        if hasattr(convinv, 'W_inverse'):
            delattr(convinv, "W_inverse")  # clear Inverse Weights.

    if total:
        average_MSE = total_MSE / total
        average_MAE = total_MAE / total
        logger.add_scalar('val_MSE', average_MSE, iteration)
        logger.add_scalar('val_MAE', average_MAE, iteration)
        print("Average MSE:", average_MSE, "Average MAE:", average_MAE)
    else:
        average_MSE = 1e3
        average_MAE = 1e3
        print("Average MSE: N/A", "Average MAE: N/A")

    model.train()
    return average_MSE, average_MAE
def main(style, waveglow_path, sigma, output_dir, sampling_rate, is_fp16,
         denoiser_strength, args):
    #mel_files = files_to_list(mel_files)
    #print(mel_files)
    dataset = voice_dataset(dataBase={
        'ravdess': './our_data/ravdess',
        'cremad': './our_data/cremad'
    }, style=('happy', 'sad', 'angry'))
    #print(len(dataset.final_data['happy']))
    #sample = dataset.pick_one_random_sample('happy')
    styles = ['happy', 'sad', 'angry']
    with open('config.json') as f:
        data = f.read()
    config = json.loads(data)
    waveglow_config = config["waveglow_config"]
    model = WaveGlow(**waveglow_config)
    checkpoint_dict = torch.load('waveglow_256channels_universal_v5.pt', map_location='cpu')
    model_for_loading = checkpoint_dict['model']
    model.load_state_dict(model_for_loading.state_dict())
    model.cuda()
    waveglow = model
    if is_fp16:
        from apex import amp
        waveglow, _ = amp.initialize(waveglow, [], opt_level="O1")

    if denoiser_strength > 0:
        denoiser = Denoiser(waveglow).cuda()

    mel_extractor = Get_mel(1024, 256, 1024, args.sampling_rate, 0.0, 8000.0)

    vector_all = {}
    for style in styles:
        files = dataset.final_data[style].copy()
        random.shuffle(files)
        vectors = []
        for i, (_, file_path) in enumerate(files):
            if i > 200:
                break
            try:
                audio, rate = load_wav_to_torch(file_path)
                if rate != sampling_rate:
                    audio = resampy.resample(audio.numpy(), rate, sampling_rate)
                    audio = torch.from_numpy(audio).float()
                #if audio.size(0) >= args.segment_length:
                #    max_audio_start = audio.size(0) - args.segment_length
                #    audio_start = random.randint(0, max_audio_start)
                #    audio = audio[audio_start:audio_start+args.segment_length]
                #else:
                #    audio = torch.nn.functional.pad(audio, (0, args.segment_length-audio.size(0)), 'constant').data
                mel = mel_extractor.get_mel(audio)
                audio = audio / MAX_WAV_VALUE

                mel = torch.autograd.Variable(mel.cuda().unsqueeze(0))
                audio = torch.autograd.Variable(audio.cuda().unsqueeze(0))
                audio = audio.half() if is_fp16 else audio
                mel = mel.half() if is_fp16 else mel
                outputs = waveglow((mel, audio))
                vectors.append(outputs[0].squeeze(0).mean(1).detach().cpu().numpy())
                print(style, i)
            except:
                continue
        vector_all[style] = vectors

    np.save('all_style_vector', vector_all)
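# Note (sketch): np.save() pickles the dict of per-style latent vectors into
# 'all_style_vector.npy'. Loading it back therefore needs allow_pickle=True and
# .item() to recover the dict (keys follow the styles list above):
#
#     vector_all = np.load('all_style_vector.npy', allow_pickle=True).item()
#     happy_vectors = vector_all['happy']  # list of per-utterance mean-z vectors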
def validate(model, STFTs, logger, iteration, speaker_lookup, hparams, output_directory,
             save_audio=True, max_length_s=5):
    from mel2samp import load_wav_to_torch
    val_sigma = hparams.sigma * 0.9
    model.eval()
    with torch.no_grad():
        with open(hparams.validation_files, encoding='utf-8') as f:
            audiopaths_and_melpaths = [line.strip().split('|') for line in f]

        if list(model.parameters())[0].type() == "torch.cuda.HalfTensor":
            model_type = "half"
        else:
            model_type = "float"

        timestr = time.strftime("%Y_%m_%d-%H_%M_%S")
        total_MAE = total_MSE = total = 0
        for i, (audiopath, melpath, *remaining) in enumerate(audiopaths_and_melpaths):
            if i > 30:
                break  # debug
            audio = load_wav_to_torch(audiopath)[0] / 32768.0  # load audio from wav file to tensor
            if audio.shape[0] > (hparams.sampling_rate * max_length_s):
                continue  # ignore audio over max_length_seconds
            mel = np.load(melpath)  # load mel from file into numpy arr
            mel = torch.from_numpy(mel).unsqueeze(0).cuda()  # from numpy arr to tensor on GPU
            #mel = (mel+5.2)*0.5  # shift values between approx -4 and 4

            if hasattr(model, 'multispeaker') and model.multispeaker == True:
                assert len(remaining), f"Speaker ID missing while multispeaker == True.\nLine: {i}\n'{'|'.join([audiopath, melpath])}'"
                speaker_id = remaining[0]
                assert int(speaker_id) in speaker_lookup.keys(), f"Validation speaker ID:{speaker_id} not found in training filelist.\n(This speaker does not have an embedding, either use single-speaker models or provide an example of this speaker in the training data)."
                speaker_id = torch.IntTensor([speaker_lookup[int(speaker_id)]])
                speaker_id = speaker_id.cuda(non_blocking=True).long()
            else:
                speaker_id = None

            if model_type == "half":
                mel = mel.half()  # for fp16 training

            audio_waveglow = model.infer(mel, speaker_id, sigma=val_sigma)
            audio_waveglow = audio_waveglow.cpu().float()

            audio = audio.squeeze().unsqueeze(0)  # crush extra dimensions and shape for STFT
            audio_waveglow = audio_waveglow.squeeze().unsqueeze(0)  # crush extra dimensions and shape for STFT
            audio_waveglow = audio_waveglow.clamp(-1, 1)  # clamp any values over/under |1.0| (which should only exist very early in training)

            for STFT in STFTs:  # check Spectrogram Error with multiple window sizes
                mel_GT = STFT.mel_spectrogram(audio)
                try:
                    mel_waveglow = STFT.mel_spectrogram(audio_waveglow)[:, :, :mel_GT.shape[-1]]
                except AssertionError as ex:
                    cprint(ex, b_tqdm=hparams.tqdm)
                    continue

                MSE = (torch.nn.MSELoss()(mel_waveglow, mel_GT)).item()  # get MSE (Mean Squared Error) between Ground Truth and WaveGlow inferred spectrograms.
                MAE = (torch.nn.L1Loss()(mel_waveglow, mel_GT)).item()  # get MAE (Mean Absolute Error) between Ground Truth and WaveGlow inferred spectrograms.
                total_MAE += MAE
                total_MSE += MSE
                total += 1

            if save_audio:
                audio_path = os.path.join(output_directory, "samples", str(iteration) + "-" + timestr, os.path.basename(audiopath))  # Write audio to checkpoint_directory/iteration/audiofilename.wav
                os.makedirs(os.path.join(output_directory, "samples", str(iteration) + "-" + timestr), exist_ok=True)
                sf.write(audio_path, audio_waveglow.squeeze().cpu().numpy(), hparams.sampling_rate, "PCM_16")  # save waveglow sample

                audio_path = os.path.join(output_directory, "samples", "Ground Truth", os.path.basename(audiopath))  # Write audio to checkpoint_directory/iteration/audiofilename.wav
                if not os.path.exists(audio_path):
                    os.makedirs(os.path.join(output_directory, "samples", "Ground Truth"), exist_ok=True)
                    sf.write(audio_path, audio.squeeze().cpu().numpy(), hparams.sampling_rate, "PCM_16")  # save ground truth

    for convinv in model.convinv:
        if hasattr(convinv, 'W_inverse'):
            delattr(convinv, "W_inverse")  # clear Inverse Weights.

    if total:
        average_MSE = total_MSE / total
        average_MAE = total_MAE / total
        logger.add_scalar('val_MSE', average_MSE, iteration)
        logger.add_scalar('val_MAE', average_MAE, iteration)
        cprint("Average MSE:", average_MSE, "Average MAE:", average_MAE, b_tqdm=hparams.tqdm)
    else:
        average_MSE = 1e3
        average_MAE = 1e3
        cprint("Average MSE: N/A", "Average MAE: N/A", b_tqdm=hparams.tqdm)

    model.train()
    return average_MSE, average_MAE