Example #1
    def __getitem__(self, index):
        # Read audio
        filename = self.audio_files[index]
        audio_mel, sampling_rate = load_wav_to_torch(self.dir_normal + '/' +
                                                     filename)
        audio, sampling_rate = load_wav_to_torch(self.dir_hi + '/' + filename)
        if sampling_rate != self.sampling_rate:
            raise ValueError("{} SR doesn't match target {} SR".format(
                sampling_rate, self.sampling_rate))

        # Take segment
        if audio.size(0) >= self.segment_length:
            max_audio_start = audio.size(0) - self.segment_length
            audio_start = random.randint(0, max_audio_start)
            audio = audio[audio_start:audio_start + self.segment_length]
            audio_mel = audio_mel[audio_start:audio_start +
                                  self.segment_length]
        else:
            audio = torch.nn.functional.pad(
                audio, (0, self.segment_length - audio.size(0)),
                'constant').data
            audio_mel = torch.nn.functional.pad(
                audio_mel, (0, self.segment_length - audio.size(0)),
                'constant').data

        mel = self.get_mel(audio_mel)
        audio = audio / MAX_WAV_VALUE

        return (mel, audio)
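Note: every example on this page relies on a load_wav_to_torch helper. Its exact implementation varies between these repositories; the sketch below assumes the common scipy-based variant, which returns the raw integer samples as a FloatTensor together with the file's sample rate.

import numpy as np
import torch
from scipy.io.wavfile import read

MAX_WAV_VALUE = 32768.0  # full-scale magnitude of 16-bit PCM, used for the normalization above

def load_wav_to_torch(full_path):
    # Read a wav file and return (samples as FloatTensor, sampling_rate).
    sampling_rate, data = read(full_path)
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate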
Example #2
 def load_buffer(self):
     num_files = len(self.audio_files)
     for i in tqdm(range(num_files)):
         filename = self.audio_files[i]
         audio, sampling_rate = ms.load_wav_to_torch(filename)
         if sampling_rate != self.sampling_rate:
             raise ValueError("{} SR doesn't match target {} SR".format(
                 sampling_rate, self.sampling_rate))
         self.all_length += audio.size(0)
         if i == 0:
             self.all_audio = audio
         else:
             self.all_audio = torch.cat((self.all_audio, audio), 0)
     print("All audio has been loaded, totally length: {}".format(
         self.all_length))
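Note: the torch.cat inside the loop above copies the whole accumulated buffer on every iteration. A sketch of an equivalent variant that collects chunks in a list and concatenates once (assuming the same attributes as above: ms.load_wav_to_torch, self.audio_files, self.sampling_rate, self.all_length):

 def load_buffer(self):
     chunks = []
     for filename in tqdm(self.audio_files):
         audio, sampling_rate = ms.load_wav_to_torch(filename)
         if sampling_rate != self.sampling_rate:
             raise ValueError("{} SR doesn't match target {} SR".format(
                 sampling_rate, self.sampling_rate))
         chunks.append(audio)
         self.all_length += audio.size(0)
     self.all_audio = torch.cat(chunks, 0)  # single concatenation at the end
     print("All audio has been loaded, total length: {}".format(
         self.all_length))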
Example #3
def DNN_test(checkpoint_path, filename):
    model = DNNnet(**DNN_net_config)
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model_for_loading = checkpoint_dict['model']
    model.load_state_dict(model_for_loading.state_dict())
    print("Loaded checkpoint '{}' ".format(checkpoint_path))

    model.eval()

    Data_gen = spec2load(**DNN_data_config)

    audio, sr = ms.load_wav_to_torch(filename)
    # Take segment
    if audio.size(0) >= 2240:
        max_audio_start = audio.size(0) - 2240
        audio_start = random.randint(0, max_audio_start)
        audio = audio[audio_start:audio_start + 2240]
    else:
        audio = torch.nn.functional.pad(audio, (2240 - audio.size(0), 0),
                                        'constant').data

    spec = Data_gen.get_spec(audio)
    feed_spec = spec[:, :-2]
    targ_spec = spec[:, -2:]
    feed_spec = feed_spec.unsqueeze(0)
    gener_spec = model.forward(feed_spec)

    # gener_linear = gener_linear.squeeze().view(-1, 2).T
    gener_spec = gener_spec.detach().numpy()
    targ_spec = torch.cat((targ_spec[:, 0], targ_spec[:, 1]), 0)
    targ_spec = targ_spec.numpy()
    # targ_mel_db = librosa.power_to_db(targ_mel[0], ref=np.max)
    # gener_mel_db = librosa.power_to_db(gener_mel[0], ref=np.max)
    plt.figure()

    plt.plot(targ_spec)
    # librosa.display.specshow(targ_mel, x_axis='time', y_axis='mel')

    plt.plot(gener_spec[0], 'r')
    # librosa.display.specshow(gener_mel, x_axis='time', y_axis='mel')

    plt.show()
Example #4
def validate(model, loader_STFT, STFTs, logger, iteration, validation_files, speaker_lookup, sigma, output_directory, data_config, save_audio=True, max_length_s=3, files_to_process=9):
    print("Validating... ", end="")
    from mel2samp import DTW
    val_sigma = sigma * 1.00
    model.eval()
    val_start_time = time.time()
    STFT_elapsed = samples_processed = 0
    with torch.no_grad():
        with torch.random.fork_rng(devices=[0,]):
            torch.random.manual_seed(0)# use same Z / random seed during validation so results are more consistent and comparable.
            
            with open(validation_files, encoding='utf-8') as f:
                audiopaths_and_melpaths = [line.strip().split('|') for line in f]
            
            if next(model.parameters()).type() == "torch.cuda.HalfTensor":
                model_type = "half"
            else:
                model_type = "float"
            
            timestr = time.strftime("%Y_%m_%d-%H_%M_%S")
            total_MAE = total_MSE = total = files_processed = 0
            input_mels = []
            gt_mels = []
            pred_mels = []
            MAE_specs = []
            for i, (audiopath, melpath, *remaining) in enumerate(audiopaths_and_melpaths):
                if files_processed >= files_to_process: # number of validation files to run.
                    break
                audio = load_wav_to_torch(audiopath)[0]/32768.0 # load audio from wav file to tensor
                if audio.shape[0] > (data_config['sampling_rate']*max_length_s):
                    continue # ignore audio over max_length_seconds
                
                gt_mel = loader_STFT.mel_spectrogram(audio.unsqueeze(0)).cuda()# [T] -> [1, T] -> [1, n_mel, T]
                if 'load_hidden_from_disk' in data_config.keys() and data_config['load_hidden_from_disk']:
                    mel = None
                    hidden_path = remaining[1]
                    model_input = np.load(hidden_path) # load tacotron hidden from file into numpy arr
                    model_input = torch.from_numpy(model_input).unsqueeze(0).cuda() # from numpy arr to tensor on GPU
                else:
                    if loader_STFT and data_config['load_mel_from_disk'] < 0.2:
                        mel = None
                        model_input = gt_mel.clone()
                    else:
                        mel = np.load(melpath) # load mel from file into numpy arr
                        mel = torch.from_numpy(mel).unsqueeze(0).cuda() # from numpy arr to tensor on GPU
                        assert mel[:, :gt_mel.shape[1], :].shape == gt_mel.shape, f'shapes of {mel[:, :gt_mel.shape[1], :].shape} and {gt_mel.shape} do not match.'
                        #if torch.nn.functional.mse_loss(mel[:, :gt_mel.shape[1], :], gt_mel) > 0.3:
                        #    continue # skip validation files that significantly vary from the target.
                        if model.has_logvar_channels:
                            if mel.shape[1] == model.n_mel_channels*2:
                                mel, logvar = mel.chunk(2, dim=1)# [1, n_mel*2, T] -> [1, n_mel, T], [1, n_mel, T]
                                mel = DTW(mel, gt_mel, scale_factor=8, range_=5)
                                model_input = torch.cat((mel, logvar), dim=1)
                            else:
                                raise Exception("Loaded mel from disk has wrong shape.")
                        else:
                            if mel.shape[1] == model.n_mel_channels*2:
                                mel = mel.chunk(2, dim=1)[0]# [1, n_mel*2, T] -> [1, n_mel, T]
                            #mel = DTW(mel, gt_mel, scale_factor=8, range_=5)
                            model_input = mel
                
                if hasattr(model, 'multispeaker') and model.multispeaker == True:
                    assert len(remaining), f"Speaker ID missing while multispeaker == True.\nLine: {i}\n'{'|'.join([audiopath, melpath])}'"
                    speaker_id = remaining[0]
                    speaker_id = torch.IntTensor([speaker_lookup[int(speaker_id)],])
                    speaker_id = speaker_id.cuda(non_blocking=True).long()
                else:
                    speaker_id = None
                
                if model_type == "half":
                    model_input = model_input.half() # for fp16 training
                
                audio_waveglow = model.infer(model_input, speaker_id, sigma=val_sigma).cpu().float()
                
                audio = audio.squeeze().unsqueeze(0) # crush extra dimensions and shape for STFT
                audio_waveglow = audio_waveglow.squeeze().unsqueeze(0).clamp(min=-0.999, max=0.999) # [1, T] crush extra dimensions, shape for STFT and clamp values to |0.999|
                audio_waveglow[torch.isnan(audio_waveglow) | torch.isinf(audio_waveglow)] = 0.0 # zero out any NaN/Inf values (which should only exist very early in training)
                
                STFT_start_time = time.time()
                for j, STFT in enumerate(STFTs): # check Spectrogram Error with multiple window sizes
                    input_mels.append(mel)
                    mel_GT = STFT.mel_spectrogram(audio)# [1, T] -> [1, n_mel, T//hop_len]
                    gt_mels.append(mel_GT)
                    
                    mel_waveglow = STFT.mel_spectrogram(audio_waveglow)[:,:,:mel_GT.shape[-1]]# [1, T] -> [1, n_mel, T//hop_len]
                    pred_mels.append(mel_waveglow)
                    
                    MSE = (torch.nn.MSELoss()(mel_waveglow, mel_GT)).item() # get MSE (Mean Squared Error) between Ground Truth and WaveGlow inferred spectrograms.
                    MAE_spec = torch.nn.L1Loss(reduction='none')(mel_waveglow, mel_GT)
                    MAE = (MAE_spec.mean()).item() # get MAE (Mean Absolute Error) between Ground Truth and WaveGlow inferred spectrograms.
                    MAE_specs.append(MAE_spec)
                    
                    total_MAE+=MAE
                    total_MSE+=MSE
                    total+=1
                STFT_elapsed += time.time()-STFT_start_time
                
                if save_audio:
                    audio_path = os.path.join(output_directory, "samples", str(iteration)+"-"+timestr, os.path.basename(audiopath)) # Write audio to checkpoint_directory/iteration/audiofilename.wav
                    os.makedirs(os.path.join(output_directory, "samples", str(iteration)+"-"+timestr), exist_ok=True)
                    sf.write(audio_path, audio_waveglow.squeeze().cpu().numpy(), data_config['sampling_rate'], "PCM_16") # save waveglow sample
                    
                    audio_path = os.path.join(output_directory, "samples", "Ground Truth", os.path.basename(audiopath)) # Write audio to checkpoint_directory/iteration/audiofilename.wav
                    if not os.path.exists(audio_path):
                        os.makedirs(os.path.join(output_directory, "samples", "Ground Truth"), exist_ok=True)
                        sf.write(audio_path, audio.squeeze().cpu().numpy(), data_config['sampling_rate'], "PCM_16") # save ground truth
                files_processed+=1
                samples_processed+=audio_waveglow.shape[-1]
    
    if total:
        average_MSE = total_MSE/total
        average_MAE = total_MAE/total
        logger.add_scalar('val_MSE', average_MSE, iteration)
        logger.add_scalar('val_MAE', average_MAE, iteration)
        
        for idx, (gt_mel, pred_mel, input_mel, mae_mel) in enumerate(zip(gt_mels[-6:], pred_mels[-6:], input_mels[-6:], MAE_specs[-6:])):
            logger.add_image(f'mel_{idx}/pred',
                        plot_spectrogram_to_numpy(pred_mel[0].data.cpu().numpy(), range=[-11.5, 2.0]),
                        iteration, dataformats='HWC')
            if mae_mel is not None:
                logger.add_image(f'mel_{idx}/zmae',
                        plot_spectrogram_to_numpy(mae_mel[0].data.cpu().numpy(), range=[0.0, 2.5]),
                        iteration, dataformats='HWC')
            if iteration % 10000 == 0: # target doesn't change unless batch size or dataset changes so only needs to be plotted once in a while.
                logger.add_image(f'mel_{idx}/target',
                            plot_spectrogram_to_numpy(gt_mel[0].data.cpu().numpy(), range=[-11.5, 2.0]),
                            iteration, dataformats='HWC')
                if input_mel is not None:
                    logger.add_image(f'mel_{idx}/input',
                            plot_spectrogram_to_numpy(input_mel[0].data.cpu().numpy(), range=[-11.5, 2.0]),
                            iteration, dataformats='HWC')
        
        time_elapsed = time.time()-val_start_time
        time_elapsed_without_stft = time_elapsed-STFT_elapsed
        samples_per_second = samples_processed/time_elapsed_without_stft
        print(f"[Avg MSE: {average_MSE:.6f} MAE: {average_MAE:.6f}]",
            f"[{time_elapsed_without_stft:.3f}s]",
            f"[{time_elapsed:.3f}s_stft]",
            f"[{time_elapsed_without_stft/files_processed:.3f}s/file]",
            f"[{samples_per_second/data_config['sampling_rate']:.3f}rtf]",
            f"[{samples_per_second:,.0f}samples/s]"
          )
        logger.add_scalar('val_rtf', samples_per_second/data_config['sampling_rate'], iteration)
    else:
        average_MSE = 1e3
        average_MAE = 1e3
        print("Average MSE: N/A", "Average MAE: N/A")
    
    for convinv in model.convinv:
        if hasattr(convinv, 'W_inverse'):
            delattr(convinv, "W_inverse") # clear Inverse Weights.
    if hasattr(model, 'iso226'):
        delattr(model, 'iso226')
    mel = speaker_id = None # del GPU based tensors.
    torch.cuda.empty_cache() # clear cache for next training
    model.train()
    
    return average_MSE, average_MAE
Example #5
def main(files, waveglow_path, sigma, output_dir, sampling_rate, is_fp16,
         denoiser_strength, args):
    #mel_files = files_to_list(mel_files)
    #print(mel_files)
    files = ['/local-scratch/fuyang/cmpt726/final_project/cremad/1091_WSI_SAD_XX.wav']
    #files = ['/local-scratch/fuyang/cmpt726/waveglow/data/LJSpeech-1.1/LJ001-0001.wav']
    with open('config.json') as f:
        data = f.read()
    config = json.loads(data)
    waveglow_config = config["waveglow_config"]
    model = WaveGlow(**waveglow_config)
    checkpoint_dict = torch.load('waveglow_256channels_universal_v5.pt', map_location='cpu')
    model_for_loading = checkpoint_dict['model']
    model.load_state_dict(model_for_loading.state_dict())
    model.cuda()
    #waveglow = torch.load(waveglow_path)['model']
    #waveglow = waveglow.remove_weightnorm(waveglow)
    #waveglow.cuda()
    waveglow = model
    if is_fp16:
        from apex import amp
        waveglow, _ = amp.initialize(waveglow, [], opt_level="O1")

    if denoiser_strength > 0:
        denoiser = Denoiser(waveglow).cuda()

    mel_extractor = Get_mel(1024, 256, 1024, args.sampling_rate, 0.0, 8000.0)

    for i, file_path in enumerate(files):
        audio, rate = load_wav_to_torch(file_path)
        if rate != sampling_rate:
            audio = resampy.resample(audio.numpy(), rate, sampling_rate)
            audio = torch.from_numpy(audio).float()
        #if audio.size(0) >= args.segment_length:
        #    max_audio_start = audio.size(0) - args.segment_length
        #    audio_start = random.randint(0, max_audio_start)
        #    audio = audio[audio_start:audio_start+args.segment_length]
        #else:
        #    audio = torch.nn.functional.pad(audio, (0, args.segment_length-audio.size(0)), 'constant').data
        mel = mel_extractor.get_mel(audio)
        audio = audio / MAX_WAV_VALUE

        mel = torch.autograd.Variable(mel.cuda().unsqueeze(0))
        audio = torch.autograd.Variable(audio.cuda().unsqueeze(0))
        audio = audio.half() if is_fp16 else audio
        mel = mel.half() if is_fp16 else mel
        outputs = waveglow((mel, audio))
        z = outputs[0][:,4:]
        print(outputs)
        mel_up = waveglow.upsample(mel)
        time_cutoff = waveglow.upsample.kernel_size[0]-waveglow.upsample.stride[0]
        mel_up = mel_up[:,:,:-time_cutoff]
        #mel_up = mel_up[:,:,:-(time_cutoff+128)]

        mel_up = mel_up.unfold(2, waveglow.n_group, waveglow.n_group).permute(0,2,1,3)
        mel_up = mel_up.contiguous().view(mel_up.size(0), mel_up.size(1), -1).permute(0, 2, 1)
        audio = z
        mel_up = mel_up[:,:,:audio.size(2)]

        sigma = 0.7
        z_i = 0
        for k in reversed(range(waveglow.n_flows)):
            n_half = int(audio.size(1)/2)
            audio_0 = audio[:,:n_half, :]
            audio_1 = audio[:, n_half:, :]

            output = waveglow.WN[k]((audio_0, mel_up))

            s = output[:,n_half:, :]
            b = output[:, :n_half, :]
            audio_1 = (audio_1-b)/torch.exp(s)
            audio = torch.cat([audio_0, audio_1],1)

            audio = waveglow.convinv[k](audio, reverse=True)

            if k % waveglow.n_early_every == 0 and k > 0:
                z = outputs[0][:, 2-z_i:4-z_i]
                #if mel_up.type() == 'torch.cuda.HalfTensor':
                #    z = torch.cuda.HalfTensor(mel_up.size(0), waveglow.n_early_size, mel_up.size(2)).normal_()
                #else:
                #    z = torch.cuda.FloatTensor(mel_up.size(0), waveglow.n_early_size, mel_up.size(2)).normal_()
                audio = torch.cat((sigma*z, audio),1)
        audio = audio.permute(0,2,1).contiguous().view(audio.size(0), -1).data
        audio = audio * MAX_WAV_VALUE
        audio = audio.squeeze()
        audio = audio.cpu().numpy()
        audio = audio.astype('int16')
        audio_path = os.path.join(
            output_dir, "{}_synthesis.wav".format('fuyangz'))
        write(audio_path, sampling_rate, audio)
        print(audio_path)
Example #6
def main(style, waveglow_path, sigma, output_dir, sampling_rate, is_fp16,
         denoiser_strength, args):
    #mel_files = files_to_list(mel_files)
    #print(mel_files)
    dataset = voice_dataset(dataBase={
        'ravdess': './our_data/ravdess',
        'cremad': './our_data/cremad'
    },
                            style=('happy', 'sad', 'angry'))
    #print(len(dataset.final_data['happy']))

    #sample = dataset.pick_one_random_sample('happy')
    files = dataset.final_data[style]
    #files = ['/local-scratch/fuyang/cmpt726/waveglow/data/LJSpeech-1.1/LJ001-0001.wav']
    with open('config.json') as f:
        data = f.read()
    config = json.loads(data)
    waveglow_config = config["waveglow_config"]
    model = WaveGlow(**waveglow_config)
    checkpoint_dict = torch.load('waveglow_256channels_universal_v5.pt',
                                 map_location='cpu')
    model_for_loading = checkpoint_dict['model']
    model.load_state_dict(model_for_loading.state_dict())
    model.cuda()
    waveglow = model
    if is_fp16:
        from apex import amp
        waveglow, _ = amp.initialize(waveglow, [], opt_level="O1")

    if denoiser_strength > 0:
        denoiser = Denoiser(waveglow).cuda()

    mel_extractor = Get_mel(1024, 256, 1024, args.sampling_rate, 0.0, 8000.0)
    avg_z = np.zeros(8)
    _count = 0
    for i, (_, file_path) in enumerate(files):
        if i > 50:
            break
        try:
            audio, rate = load_wav_to_torch(file_path)
            if rate != sampling_rate:
                audio = resampy.resample(audio.numpy(), rate, sampling_rate)
                audio = torch.from_numpy(audio).float()
            #if audio.size(0) >= args.segment_length:
            #    max_audio_start = audio.size(0) - args.segment_length
            #    audio_start = random.randint(0, max_audio_start)
            #    audio = audio[audio_start:audio_start+args.segment_length]
            #else:
            #    audio = torch.nn.functional.pad(audio, (0, args.segment_length-audio.size(0)), 'constant').data
            mel = mel_extractor.get_mel(audio)
            audio = audio / MAX_WAV_VALUE

            mel = torch.autograd.Variable(mel.cuda().unsqueeze(0))
            audio = torch.autograd.Variable(audio.cuda().unsqueeze(0))
            audio = audio.half() if is_fp16 else audio
            mel = mel.half() if is_fp16 else mel
            outputs = waveglow((mel, audio))
            avg_z += outputs[0].squeeze(0).mean(1).detach().cpu().numpy()
            _count += 1
            z = outputs[0][:, 4:]

            #print(outputs)
            mel_up = waveglow.upsample(mel)
            time_cutoff = waveglow.upsample.kernel_size[
                0] - waveglow.upsample.stride[0]
            mel_up = mel_up[:, :, :-time_cutoff]
            #mel_up = mel_up[:,:,:-(time_cutoff+128)]

            mel_up = mel_up.unfold(2, waveglow.n_group,
                                   waveglow.n_group).permute(0, 2, 1, 3)
            mel_up = mel_up.contiguous().view(mel_up.size(0), mel_up.size(1),
                                              -1).permute(0, 2, 1)
            audio = z
            mel_up = mel_up[:, :, :audio.size(2)]

            sigma = 0.7
            z_i = 0
            for k in reversed(range(waveglow.n_flows)):
                n_half = int(audio.size(1) / 2)
                audio_0 = audio[:, :n_half, :]
                audio_1 = audio[:, n_half:, :]

                output = waveglow.WN[k]((audio_0, mel_up))

                s = output[:, n_half:, :]
                b = output[:, :n_half, :]
                audio_1 = (audio_1 - b) / torch.exp(s)
                audio = torch.cat([audio_0, audio_1], 1)

                audio = waveglow.convinv[k](audio, reverse=True)

                if k % waveglow.n_early_every == 0 and k > 0:
                    z = outputs[0][:, 2 - z_i:4 - z_i]
                    #if mel_up.type() == 'torch.cuda.HalfTensor':
                    #    z = torch.cuda.HalfTensor(mel_up.size(0), waveglow.n_early_size, mel_up.size(2)).normal_()
                    #else:
                    #    z = torch.cuda.FloatTensor(mel_up.size(0), waveglow.n_early_size, mel_up.size(2)).normal_()
                    audio = torch.cat((sigma * z, audio), 1)
            audio = audio.permute(0, 2,
                                  1).contiguous().view(audio.size(0), -1).data
            audio = audio * MAX_WAV_VALUE
            audio = audio.squeeze()
            audio = audio.cpu().numpy()
            audio = audio.astype('int16')
            audio_path = os.path.join(
                output_dir, "{}_synthesis.wav".format(file_path[:-4]))
            if os.path.exists(
                    os.path.join(*audio_path.split('/')[:-1])) is False:
                os.makedirs(os.path.join(*audio_path.split('/')[:-1]),
                            exist_ok=True)
            write(audio_path, sampling_rate, audio)
            print(audio_path)
        except Exception:
            continue

    avg_z = avg_z / _count
    np.save(style, avg_z)
Example #7
def test_mel2samp():
    """Test mel2samp modules on example data."""
    from mel2samp import Mel2Samp
    
    hparams = hparams_class()
    
    passed = 0
    
    
    # test filelist loader
    try:
        from mel2samp import load_filepaths_and_text
        audio_files = load_filepaths_and_text("code_tests/test_materials/filelists/validation_utf8.txt")
        passed+=1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ Load Filepaths and Text (UTF-8)")
        traceback.print_exc(file=sys.stdout)
        print("\n")
    
    
    # test filelist checker
    try:
        assert audio_files
        from mel2samp import check_files
        audio_files = check_files(audio_files, hparams)
        assert len(audio_files) == 1
        passed+=1
        print("--PASSED--\n")
        del audio_files
    except Exception as ex:
        print("--EXCEPTION-- @ Load Filepaths and Text (UTF-8)")
        traceback.print_exc(file=sys.stdout)
        print("\n")
    
    
    # test initialization
    try:
        trainset = Mel2Samp(hparams)
        passed+=1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ Mel2Samp Initialization")
        traceback.print_exc(file=sys.stdout)
        print("\n")
    
    
    # test 16-BIT .wav to torch
    try:
        from mel2samp import load_wav_to_torch
        x, sr = load_wav_to_torch("code_tests/test_materials/audio_0/example_16bits.wav")
        assert len(x)
        assert x.max() <= 2**15
        assert x.min() >= -(2**15)
        assert sr == 48000
        passed+=1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ Load 16-BIT .wav to Pytorch")
        traceback.print_exc(file=sys.stdout)
        print("\n")
    
    
    # test 24-BIT .wav to torch
    try:
        from mel2samp import load_wav_to_torch
        x, sr = load_wav_to_torch("code_tests/test_materials/audio_0/example_24bits.wav")
        assert len(x)
        assert x.max() <= 2**23
        assert x.min() >= -(2**23)
        assert sr == 48000
        passed+=1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ Load 24-BIT .wav to Pytorch")
        traceback.print_exc(file=sys.stdout)
        print("\n")
    
    
    # test 32-BIT .wav to torch
    try:
        from mel2samp import load_wav_to_torch
        x, sr = load_wav_to_torch("code_tests/test_materials/audio_0/example_32bits.wav")
        assert len(x)
        assert x.max() <= 2**31
        assert x.min() >= -(2**31)
        assert sr == 48000
        passed+=1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ Load 32-BIT .wav to Pytorch")
        traceback.print_exc(file=sys.stdout)
        print("\n")
    
    
    # test 32-BIT .mp3 to torch
    try:
        from mel2samp import load_wav_to_torch
        x, sr = load_wav_to_torch("code_tests/test_materials/audio_0/example_32bits.mp3")
        assert len(x)
        assert x.max() <= 2**31
        assert x.min() >= -(2**31)
        assert sr == 48000
        passed+=1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ Load 32-BIT .mp3 to Pytorch")
        traceback.print_exc(file=sys.stdout)
        print("\n")
    
    
    # test 16-BIT .wav to mel
    try:
        x, sr = load_wav_to_torch("code_tests/test_materials/audio_0/example_16bits.wav")
        x = trainset.get_mel(x)
        passed+=1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ 16-BIT .wav to Mel-spec")
        traceback.print_exc(file=sys.stdout)
        print("\n")
    
    
    # test 24-BIT .wav to mel
    try:
        x, sr = load_wav_to_torch("code_tests/test_materials/audio_0/example_24bits.wav")
        x = trainset.get_mel(x)
        passed+=1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ 24-BIT .wav to Mel-spec")
        traceback.print_exc(file=sys.stdout)
        print("\n")
    
    
    # test 32-BIT .wav to mel
    try:
        x, sr = load_wav_to_torch("code_tests/test_materials/audio_0/example_32bits.wav")
        x = trainset.get_mel(x)
        passed+=1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ 32-BIT .wav to Mel-spec")
        traceback.print_exc(file=sys.stdout)
        print("\n")
    
    
    # test 32-BIT .mp3 to mel
    try:
        x, sr = load_wav_to_torch("code_tests/test_materials/audio_0/example_32bits.mp3")
        x = trainset.get_mel(x)
        passed+=1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ 32-BIT .mp3 to Mel-spec")
        traceback.print_exc(file=sys.stdout)
        print("\n")
    
    
    # test __getitem__ with load_mel_from_disk = False
    try:
        assert trainset # This test will fail if Mel2Samp cannot initialize
        trainset.load_mel_from_disk = False
        trainset.__getitem__(0)
        passed+=1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @  __getitem__ with load_mel_from_disk = False")
        traceback.print_exc(file=sys.stdout)
        print("\n")
    
    
    # test __getitem__ with load_mel_from_disk = True
    try:
        assert trainset # This test will fail if Mel2Samp cannot initialize
        trainset.load_mel_from_disk = True
        trainset.__getitem__(0)
        passed+=1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @  __getitem__ with load_mel_from_disk = True")
        traceback.print_exc(file=sys.stdout)
        print("\n")
    
    
    # test initialization with Pre-emphasis
    try:
        trainset = None
        hparams.preempthasis = 0.98
        trainset = Mel2Samp(hparams)
        passed+=1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @ Mel2Samp with Pre-empthasis Initialization")
        traceback.print_exc(file=sys.stdout)
        print("\n")
    
    
    # test __getitem__ with Pre-emphasis
    try:
        assert trainset # This test will fail if Mel2Samp cannot initialize
        trainset.load_mel_from_disk = False
        trainset.__getitem__(0)
        passed+=1
        print("--PASSED--\n")
    except Exception as ex:
        print("--EXCEPTION-- @  __getitem__ with Pre-empthasis")
        traceback.print_exc(file=sys.stdout)
        print("\n")
Example #8
def validate(model,
             loader_STFT,
             STFTs,
             logger,
             iteration,
             validation_files,
             speaker_lookup,
             sigma,
             output_directory,
             data_config,
             save_audio=True,
             max_length_s=5):
    from mel2samp import load_wav_to_torch
    from scipy import signal
    val_sigma = sigma * 0.9
    model.eval()
    with torch.no_grad():
        with open(validation_files, encoding='utf-8') as f:
            audiopaths_and_melpaths = [line.strip().split('|') for line in f]

        if list(model.parameters())[0].type() == "torch.cuda.HalfTensor":
            model_type = "half"
        else:
            model_type = "float"

        timestr = time.strftime("%Y_%m_%d-%H_%M_%S")
        total_MAE = total_MSE = total = 0
        for i, (audiopath, melpath,
                *remaining) in enumerate(audiopaths_and_melpaths):
            if i > 60: break  # debug
            audio = load_wav_to_torch(
                audiopath)[0] / 32768.0  # load audio from wav file to tensor
            if audio.shape[0] > (data_config['sampling_rate'] * max_length_s):
                continue  # ignore audio over max_length_seconds

            if loader_STFT:
                mel = loader_STFT.mel_spectrogram(audio.unsqueeze(0)).cuda()
            else:
                mel = np.load(melpath)  # load mel from file into numpy arr
                mel = torch.from_numpy(mel).unsqueeze(
                    0).cuda()  # from numpy arr to tensor on GPU

            if hasattr(model, 'multispeaker') and model.multispeaker == True:
                assert len(
                    remaining
                ), f"Speaker ID missing while multispeaker == True.\nLine: {i}\n'{'|'.join([autiopath, melpath])}'"
                speaker_id = remaining[0]
                speaker_id = torch.IntTensor([speaker_lookup[int(speaker_id)]])
                speaker_id = speaker_id.cuda(non_blocking=True).long()
            else:
                speaker_id = None

            if model_type == "half":
                mel = mel.half()  # for fp16 training

            audio_waveglow = model.infer(mel, speaker_id, sigma=val_sigma)
            audio_waveglow = audio_waveglow.cpu().float()

            if data_config['preempthasis']:  # apply inverse pre-emphasis
                audio_waveglow = audio_waveglow.squeeze()
                audio_waveglow = torch.from_numpy(
                    signal.lfilter([1],
                                   [1, -float(data_config['preempthasis'])],
                                   audio_waveglow.numpy())
                ).float(
                )  # undo pre-emphasis (scipy.signal is faster than the pytorch implementation for some reason /shrug )

            audio = audio.squeeze().unsqueeze(
                0)  # crush extra dimensions and shape for STFT
            audio_waveglow = audio_waveglow.squeeze().unsqueeze(
                0)  # crush extra dimensions and shape for STFT
            audio_waveglow = audio_waveglow.clamp(
                -1, 1
            )  # clamp any values over/under |1.0| (which should only exist very early in training)

            for STFT in STFTs:  # check Spectrogram Error with multiple window sizes
                mel_GT = STFT.mel_spectrogram(audio)
                try:
                    mel_waveglow = STFT.mel_spectrogram(
                        audio_waveglow)[:, :, :mel_GT.shape[-1]]
                except AssertionError as ex:
                    print(ex)
                    continue

                MSE = (torch.nn.MSELoss()(mel_waveglow, mel_GT)).item(
                )  # get MSE (Mean Squared Error) between Ground Truth and WaveGlow inferred spectrograms.
                MAE = (torch.nn.L1Loss()(mel_waveglow, mel_GT)).item(
                )  # get MAE (Mean Absolute Error) between Ground Truth and WaveGlow inferred spectrograms.

                total_MAE += MAE
                total_MSE += MSE
                total += 1

            if save_audio:
                audio_path = os.path.join(
                    output_directory, "samples",
                    str(iteration) + "-" + timestr, os.path.basename(audiopath)
                )  # Write audio to checkpoint_directory/iteration/audiofilename.wav
                os.makedirs(os.path.join(output_directory, "samples",
                                         str(iteration) + "-" + timestr),
                            exist_ok=True)
                sf.write(audio_path,
                         audio_waveglow.squeeze().cpu().numpy(),
                         data_config['sampling_rate'],
                         "PCM_16")  # save waveglow sample

                audio_path = os.path.join(
                    output_directory, "samples", "Ground Truth",
                    os.path.basename(audiopath)
                )  # Write audio to checkpoint_directory/iteration/audiofilename.wav
                if not os.path.exists(audio_path):
                    os.makedirs(os.path.join(output_directory, "samples",
                                             "Ground Truth"),
                                exist_ok=True)
                    sf.write(audio_path,
                             audio.squeeze().cpu().numpy(),
                             data_config['sampling_rate'],
                             "PCM_16")  # save ground truth

    for convinv in model.convinv:
        if hasattr(convinv, 'W_inverse'):
            delattr(convinv, "W_inverse")  # clear Inverse Weights.

    if total:
        average_MSE = total_MSE / total
        average_MAE = total_MAE / total
        logger.add_scalar('val_MSE', average_MSE, iteration)
        logger.add_scalar('val_MAE', average_MAE, iteration)
        print("Average MSE:", average_MSE, "Average MAE:", average_MAE)
    else:
        average_MSE = 1e3
        average_MAE = 1e3
        print("Average MSE: N/A", "Average MAE: N/A")

    model.train()
    return average_MSE, average_MAE
Example #9
def main(style, waveglow_path, sigma, output_dir, sampling_rate, is_fp16,
         denoiser_strength, args):
    #mel_files = files_to_list(mel_files)
    #print(mel_files)
    dataset = voice_dataset(dataBase={
        'ravdess': './our_data/ravdess',
        'cremad': './our_data/cremad'
    },
                            style=('happy', 'sad', 'angry'))
    #print(len(dataset.final_data['happy']))

    #sample = dataset.pick_one_random_sample('happy')
    styles = ['happy', 'sad', 'angry']
    with open('config.json') as f:
        data = f.read()
    config = json.loads(data)
    waveglow_config = config["waveglow_config"]
    model = WaveGlow(**waveglow_config)
    checkpoint_dict = torch.load('waveglow_256channels_universal_v5.pt',
                                 map_location='cpu')
    model_for_loading = checkpoint_dict['model']
    model.load_state_dict(model_for_loading.state_dict())
    model.cuda()
    waveglow = model
    if is_fp16:
        from apex import amp
        waveglow, _ = amp.initialize(waveglow, [], opt_level="O1")

    if denoiser_strength > 0:
        denoiser = Denoiser(waveglow).cuda()

    mel_extractor = Get_mel(1024, 256, 1024, args.sampling_rate, 0.0, 8000.0)

    vector_all = {}
    for style in styles:
        files = dataset.final_data[style].copy()
        random.shuffle(files)

        vectors = []
        for i, (_, file_path) in enumerate(files):
            if i > 200:
                break
            try:
                audio, rate = load_wav_to_torch(file_path)
                if rate != sampling_rate:
                    audio = resampy.resample(audio.numpy(), rate,
                                             sampling_rate)
                    audio = torch.from_numpy(audio).float()
                #if audio.size(0) >= args.segment_length:
                #    max_audio_start = audio.size(0) - args.segment_length
                #    audio_start = random.randint(0, max_audio_start)
                #    audio = audio[audio_start:audio_start+args.segment_length]
                #else:
                #    audio = torch.nn.functional.pad(audio, (0, args.segment_length-audio.size(0)), 'constant').data
                mel = mel_extractor.get_mel(audio)
                audio = audio / MAX_WAV_VALUE

                mel = torch.autograd.Variable(mel.cuda().unsqueeze(0))
                audio = torch.autograd.Variable(audio.cuda().unsqueeze(0))
                audio = audio.half() if is_fp16 else audio
                mel = mel.half() if is_fp16 else mel
                outputs = waveglow((mel, audio))
                vectors.append(
                    outputs[0].squeeze(0).mean(1).detach().cpu().numpy())
                print(style, i)
            except Exception:
                continue

        vector_all[style] = vectors

    np.save('all_style_vector', vector_all)
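Because np.save pickles the dict of per-style vectors, reading it back requires allow_pickle=True and .item(); a hypothetical follow-up that recovers an average latent vector per style:

import numpy as np

vector_all = np.load('all_style_vector.npy', allow_pickle=True).item()
for style, vectors in vector_all.items():
    style_mean = np.stack(vectors).mean(axis=0)  # mean WaveGlow latent per style
    print(style, style_mean.shape)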
Example #10
def validate(model,
             STFTs,
             logger,
             iteration,
             speaker_lookup,
             hparams,
             output_directory,
             save_audio=True,
             max_length_s=5):
    from mel2samp import load_wav_to_torch
    val_sigma = hparams.sigma * 0.9
    model.eval()
    with torch.no_grad():
        with open(hparams.validation_files, encoding='utf-8') as f:
            audiopaths_and_melpaths = [line.strip().split('|') for line in f]

        if list(model.parameters())[0].type() == "torch.cuda.HalfTensor":
            model_type = "half"
        else:
            model_type = "float"

        timestr = time.strftime("%Y_%m_%d-%H_%M_%S")
        total_MAE = total_MSE = total = 0
        for i, (audiopath, melpath,
                *remaining) in enumerate(audiopaths_and_melpaths):
            if i > 30: break  # debug
            audio = load_wav_to_torch(
                audiopath)[0] / 32768.0  # load audio from wav file to tensor
            if audio.shape[0] > (hparams.sampling_rate * max_length_s):
                continue  # ignore audio over max_length_seconds
            mel = np.load(melpath)  # load mel from file into numpy arr
            mel = torch.from_numpy(mel).unsqueeze(
                0).cuda()  # from numpy arr to tensor on GPU
            #mel = (mel+5.2)*0.5 # shift values between approx -4 and 4
            if hasattr(model, 'multispeaker') and model.multispeaker == True:
                assert len(
                    remaining
                ), f"Speaker ID missing while multispeaker == True.\nLine: {i}\n'{'|'.join([autiopath, melpath])}'"
                speaker_id = remaining[0]
                assert int(speaker_id) in speaker_lookup.keys(
                ), f"Validation speaker ID:{speaker_id} not found in training filelist.\n(This speaker does not have an embedding, either use single-speaker models or provide an example of this speaker in the training data)."
                speaker_id = torch.IntTensor([speaker_lookup[int(speaker_id)]])
                speaker_id = speaker_id.cuda(non_blocking=True).long()
            else:
                speaker_id = None

            if model_type == "half":
                mel = mel.half()  # for fp16 training

            audio_waveglow = model.infer(mel, speaker_id, sigma=val_sigma)
            audio_waveglow = audio_waveglow.cpu().float()

            audio = audio.squeeze().unsqueeze(
                0)  # crush extra dimensions and shape for STFT
            audio_waveglow = audio_waveglow.squeeze().unsqueeze(
                0)  # crush extra dimensions and shape for STFT
            audio_waveglow = audio_waveglow.clamp(
                -1, 1
            )  # clamp any values over/under |1.0| (which should only exist very early in training)

            for STFT in STFTs:  # check Spectrogram Error with multiple window sizes
                mel_GT = STFT.mel_spectrogram(audio)
                try:
                    mel_waveglow = STFT.mel_spectrogram(
                        audio_waveglow)[:, :, :mel_GT.shape[-1]]
                except AssertionError as ex:
                    cprint(ex, b_tqdm=hparams.tqdm)
                    continue

                MSE = (torch.nn.MSELoss()(mel_waveglow, mel_GT)).item(
                )  # get MSE (Mean Squared Error) between Ground Truth and WaveGlow inferred spectrograms.
                MAE = (torch.nn.L1Loss()(mel_waveglow, mel_GT)).item(
                )  # get MAE (Mean Absolute Error) between Ground Truth and WaveGlow inferred spectrograms.

                total_MAE += MAE
                total_MSE += MSE
                total += 1

            if save_audio:
                audio_path = os.path.join(
                    output_directory, "samples",
                    str(iteration) + "-" + timestr, os.path.basename(audiopath)
                )  # Write audio to checkpoint_directory/iteration/audiofilename.wav
                os.makedirs(os.path.join(output_directory, "samples",
                                         str(iteration) + "-" + timestr),
                            exist_ok=True)
                sf.write(audio_path,
                         audio_waveglow.squeeze().cpu().numpy(),
                         hparams.sampling_rate,
                         "PCM_16")  # save waveglow sample

                audio_path = os.path.join(
                    output_directory, "samples", "Ground Truth",
                    os.path.basename(audiopath)
                )  # Write audio to checkpoint_directory/iteration/audiofilename.wav
                if not os.path.exists(audio_path):
                    os.makedirs(os.path.join(output_directory, "samples",
                                             "Ground Truth"),
                                exist_ok=True)
                    sf.write(audio_path,
                             audio.squeeze().cpu().numpy(),
                             hparams.sampling_rate,
                             "PCM_16")  # save ground truth

    for convinv in model.convinv:
        if hasattr(convinv, 'W_inverse'):
            delattr(convinv, "W_inverse")  # clear Inverse Weights.

    if total:
        average_MSE = total_MSE / total
        average_MAE = total_MAE / total
        logger.add_scalar('val_MSE', average_MSE, iteration)
        logger.add_scalar('val_MAE', average_MAE, iteration)
        cprint("Average MSE:",
               average_MSE,
               "Average MAE:",
               average_MAE,
               b_tqdm=hparams.tqdm)
    else:
        average_MSE = 1e3
        average_MAE = 1e3
        cprint("Average MSE: N/A", "Average MAE: N/A", b_tqdm=hparams.tqdm)

    model.train()
    return average_MSE, average_MAE