def get_chunks(path, hparams):
    """Segment a wav file into syllables and build a spectrogram per chunk.

    Args:
        path: location of the wav file to segment.
        hparams: project hyper-parameter object (sr, n_fft, hop/win lengths,
            ref_level_db, preemphasis, mel edge frequencies, ...).

    Returns:
        Tuple ``(waveform, chunks_mS, start_samples, end_samples)`` where
        ``chunks_mS`` holds one processed spectrogram per kept syllable and the
        two sample lists give each syllable's boundaries in ``waveform``;
        returns ``None`` when segmentation fails.
    """
    # mel basis
    mel_basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
    # load file
    x_s, _ = prepare_wav(wav_loc=path, hparams=hparams, debug=False)

    # Segmentation params
    min_level_db_floor = -30
    db_delta = 5
    silence_threshold = 0.01
    min_silence_for_spec = 0.05
    # FIX: the original read `max_vocal_for_spec = 1.0,` — the stray trailing
    # comma made this a 1-tuple (1.0,) instead of the float 1.0.
    max_vocal_for_spec = 1.0
    min_syllable_length_s = 0.05

    # segment
    results = dynamic_threshold_segmentation(
        x_s,
        hparams.sr,
        n_fft=hparams.n_fft,
        hop_length=hparams.hop_length_samples,
        win_length=hparams.win_length_samples,
        min_level_db_floor=min_level_db_floor,
        db_delta=db_delta,
        ref_level_db=hparams.ref_level_db,
        pre=hparams.preemphasis,
        min_silence_for_spec=min_silence_for_spec,
        max_vocal_for_spec=max_vocal_for_spec,
        silence_threshold=silence_threshold,
        verbose=True,
        min_syllable_length_s=min_syllable_length_s,
        spectral_range=[
            hparams.mel_lower_edge_hertz, hparams.mel_upper_edge_hertz
        ],
    )
    if results is None:
        print('Cannot segment the input file')
        return

    # chunks: convert onset/offset times (seconds) to sample indices
    start_times = results["onsets"]
    end_times = results["offsets"]
    chunks_mS = []
    start_samples = []
    end_samples = []
    for start_time, end_time in zip(start_times, end_times):
        start_sample = int(start_time * hparams.sr)
        end_sample = int(end_time * hparams.sr)
        syl = x_s[start_sample:end_sample]
        # To avoid mistakes, reproduce the whole preprocessing pipeline,
        # even (here useless) int casting
        _, mS, _ = process_syllable(syl,
                                    hparams,
                                    mel_basis=mel_basis,
                                    debug=False)
        if mS is None:
            # syllable rejected by the preprocessing pipeline; skip it
            continue
        mS_int = (mS * 255).astype('uint8')
        sample = SpectroDataset.process_mSp(mS_int)
        chunks_mS.append(sample)
        start_samples.append(start_sample)
        end_samples.append(end_sample)
    return x_s, chunks_mS, start_samples, end_samples
def plot_generation(model, hparams, num_examples, savepath):
    """Sample `num_examples` spectrograms from the model, save a plot of them
    and, when `hparams` is provided, invert each one to audio.

    Returns a dict with the generated audio list and the raw spectrograms.
    """
    # Sample from the model (no gradients kept past this point)
    model.eval()
    generated = model.generate(batch_dim=num_examples).cpu().detach().numpy()

    # One column of axes per generated example
    spec_dims = generated.shape[2:]
    plt.clf()
    fig, axes = plt.subplots(ncols=num_examples)
    for idx in range(num_examples):
        axes[idx].matshow(generated[idx].reshape(spec_dims), origin="lower")
    for ax in fig.get_axes():
        ax.set_xticks([])
        ax.set_yticks([])
    plt.savefig(f'{savepath}/spectro.pdf')
    plt.close('all')

    # Invert each spectrogram back to a waveform and write one wav per example
    audios = []
    if hparams is not None:
        basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
        inversion_basis = build_mel_inversion_basis(basis)
        for idx in range(num_examples):
            gen_audio = inv_spectrogram_sp(
                generated[idx, 0],
                n_fft=hparams.n_fft,
                win_length=hparams.win_length_samples,
                hop_length=hparams.hop_length_samples,
                ref_level_db=hparams.ref_level_db,
                power=hparams.power,
                mel_inversion_basis=inversion_basis)
            audios.append(gen_audio)
            sf.write(f'{savepath}/{idx}.wav', gen_audio,
                     samplerate=hparams.sr)
    return {'audios': audios, 'spectros': generated}
def calibrate_db_file(sr, num_mel_bins, n_fft, mel_lower_edge_hertz,
                      mel_upper_edge_hertz, hop_length_ms, win_length_ms,
                      power, ref_level_db, wav_loc):
    """Run the spectrogram pipeline on one wav file to observe its dB range.

    Builds an HParams with the given settings, computes the spectrogram of
    `wav_loc`, and returns ``{'max_db': ..., 'min_db': ...}`` from the debug
    info (or ``None`` when no debug info was produced). Useful for picking
    normalisation bounds before building a dataset.
    """
    hparams = HParams(
        sr=sr,
        num_mel_bins=num_mel_bins,
        n_fft=n_fft,
        mel_lower_edge_hertz=mel_lower_edge_hertz,
        mel_upper_edge_hertz=mel_upper_edge_hertz,
        power=power,
        # for spectral inversion
        butter_lowcut=mel_lower_edge_hertz,
        butter_highcut=mel_upper_edge_hertz,
        ref_level_db=ref_level_db,  # 20
        mask_spec=False,
        win_length_ms=win_length_ms,
        hop_length_ms=hop_length_ms,
        mask_spec_kwargs={
            "spec_thresh": 0.9,
            "offset": 1e-10
        },
        n_jobs=1,
        verbosity=1,
        reduce_noise=True,
        noise_reduce_kwargs={
            "n_std_thresh": 2.0,
            "prop_decrease": 0.8
        },
        griffin_lim_iters=50)
    suffix = hparams.__repr__()
    # FIX: was `'dump' / f'{suffix}'`, which raises TypeError because str does
    # not support the `/` operator; use an f-string path as done elsewhere in
    # this file.
    dump_folder = f'dump/{suffix}'
    if os.path.isdir(dump_folder):
        shutil.rmtree(dump_folder)
    os.makedirs(dump_folder)

    # Read wave
    data, _ = prepare_wav(wav_loc, hparams, debug=True)

    # create spec; skip the mel projection entirely when no mel bins are asked
    if num_mel_bins is not None:
        mel_basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
    else:
        mel_basis = None
    # NOTE(review): hparams.preemphasis is read here but not passed to the
    # HParams constructor above — presumably HParams provides a default;
    # confirm.
    _, debug_info = spectrogram_sp(y=data,
                                   sr=hparams.sr,
                                   n_fft=hparams.n_fft,
                                   win_length_ms=hparams.win_length_ms,
                                   hop_length_ms=hparams.hop_length_ms,
                                   ref_level_db=hparams.ref_level_db,
                                   _mel_basis=mel_basis,
                                   pre_emphasis=hparams.preemphasis,
                                   power=hparams.power,
                                   debug=True)

    return {
        'max_db': debug_info['max_db'],
        'min_db': debug_info['min_db']
    } if debug_info is not None else None
def plot_interpolations(model, hparams, dataloader, savepath,
                        num_interpolated_points, method, custom_data):
    """Interpolate between pairs of latent codes, plot the decoded
    spectrograms as a (example x step) grid and write the inverted audio.

    Args:
        model: autoencoder exposing encode / reparameterize / decode.
        hparams: hyper-parameters used for audio inversion; when None, no
            audio is produced.
        dataloader: used to draw one batch when `custom_data` is None.
        savepath: directory where the pdf and wav files are written.
        num_interpolated_points: number of interpolation steps (inclusive of
            both endpoints).
        method: 'linear' or 'constant_radius'; anything else raises
            NotImplementedError.
        custom_data: optional dict with 'start_data' / 'end_data' arrays to
            interpolate between instead of dataloader samples.

    Returns:
        Dict with the audio array (or None), the interpolated spectrograms
        and the spectrogram dims.
    """
    # Forward pass
    model.eval()
    if custom_data is None:
        # Use the first batch only
        for _, data in enumerate(dataloader):
            x_cuda = cuda_variable(data['input'])
            # Get z
            mu, logvar = model.encode(x_cuda)
            z = model.reparameterize(mu, logvar)
            # Arbitrarily choose start and end points as batch_ind and batch_ind + 1
            start_z = z[:-1]
            end_z = z[1:]
            batch_dim, rgb_dim, h_dim, w_dim = x_cuda.shape
            num_examples = batch_dim - 1
            break
    else:
        # Encode the provided start and end examples separately
        x_cuda = cuda_variable(torch.tensor(custom_data['start_data']))
        # Get z
        mu, logvar = model.encode(x_cuda)
        start_z = model.reparameterize(mu, logvar)
        x_cuda = cuda_variable(torch.tensor(custom_data['end_data']))
        # Get z
        mu, logvar = model.encode(x_cuda)
        end_z = model.reparameterize(mu, logvar)
        batch_dim, rgb_dim, h_dim, w_dim = x_cuda.shape
        num_examples = batch_dim

    # Accumulate decoded spectrograms; last axis indexes the interpolation step
    x_interpolation = np.zeros(
        (num_examples, rgb_dim, h_dim, w_dim, num_interpolated_points))
    ind_interp = 0
    for t in np.linspace(start=0, stop=1, num=num_interpolated_points):
        # Perform interp
        if method == 'linear':
            this_z = start_z * (1 - t) + end_z * t
        elif method == 'constant_radius':
            this_z = constant_radius_interpolation(start_z, end_z, t)
        else:
            raise NotImplementedError
        # Decode z
        x_recon = model.decode(this_z).cpu().detach().numpy()
        x_interpolation[:, :, :, :, ind_interp] = x_recon
        ind_interp = ind_interp + 1

    # Plot: one row per example pair, one column per interpolation step
    dims = h_dim, w_dim
    plt.clf()
    fig, axes = plt.subplots(nrows=num_examples,
                             ncols=num_interpolated_points)
    for ind_example in range(num_examples):
        for ind_interp in range(num_interpolated_points):
            # show the image
            axes[ind_example,
                 ind_interp].matshow(x_interpolation[ind_example, :, :, :,
                                                     ind_interp].reshape(dims),
                                     origin="lower")
    for ax in fig.get_axes():
        ax.set_xticks([])
        ax.set_yticks([])
    plt.savefig(f'{savepath}/spectro.pdf')
    plt.close('all')

    # audio: invert every interpolated spectrogram; the buffer is allocated
    # lazily once the audio length is known
    audios = None
    if hparams is not None:
        mel_basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
        mel_inversion_basis = build_mel_inversion_basis(mel_basis)
        for ind_example in range(num_examples):
            for ind_interp in range(num_interpolated_points):
                audio = inv_spectrogram_sp(
                    x_interpolation[ind_example, 0, :, :, ind_interp],
                    n_fft=hparams.n_fft,
                    win_length=hparams.win_length_samples,
                    hop_length=hparams.hop_length_samples,
                    ref_level_db=hparams.ref_level_db,
                    power=hparams.power,
                    mel_inversion_basis=mel_inversion_basis)
                if audios is None:
                    audios = np.zeros(
                        (num_examples, num_interpolated_points, len(audio)))
                audios[ind_example, ind_interp] = audio
                sf.write(f'{savepath}/{ind_example}_{ind_interp}.wav',
                         audio,
                         samplerate=hparams.sr)
    return {'audios': audios, 'spectros': x_interpolation, 'dims': dims}
def main(config, load, train):
    """Entry point: optionally train the model, then produce evaluation
    artefacts (reconstructions, samples, interpolations, latent-space plots).

    Args:
        config: path/name of the config used by get_model_and_dataset (rebound
            to the loaded config dict by that call).
        load: epoch to load from, falsy for a fresh run.
        train: when truthy, run the training loop before generating.
    """
    # Init model and dataset
    model, dataset_train, dataset_val, optimizer, hparams, config, model_path, config_path = get_model_and_dataset(
        config=config, loading_epoch=load)

    # Training
    # NOTE(review): best_val_loss is never updated or read — the val-loss
    # checkpoint gate below is commented out.
    best_val_loss = float('inf')
    num_examples_plot = 10
    if train:
        print('##### Train')
        # Copy config file in the save directory before training
        if not load:
            if not os.path.exists(model_path):
                os.makedirs(model_path)
            os.mkdir(f'{model.model_dir}/training_plots/')
            os.mkdir(f'{model.model_dir}/plots/')
            shutil.copy(config_path, f'{model_path}/config.py')
        writer = SummaryWriter(f'{model.model_dir}')
        # Epochs
        for ind_epoch in range(config['num_epochs']):
            # Fresh dataloaders every epoch (shuffled)
            train_dataloader = get_dataloader(dataset_type=config['dataset'],
                                              dataset=dataset_train,
                                              batch_size=config['batch_size'],
                                              shuffle=True)
            val_dataloader = get_dataloader(dataset_type=config['dataset'],
                                            dataset=dataset_val,
                                            batch_size=config['batch_size'],
                                            shuffle=True)

            train_loss = epoch(model,
                               optimizer,
                               train_dataloader,
                               num_batches=config['num_batches'],
                               training=True)
            val_loss = epoch(model,
                             optimizer,
                             val_dataloader,
                             num_batches=config['num_batches'],
                             training=False)

            writer.add_scalar('train_loss', train_loss, ind_epoch)
            writer.add_scalar('val_loss', val_loss, ind_epoch)

            print(f'Epoch {ind_epoch}:')
            print(f'Train loss {train_loss}:')
            print(f'Val loss {val_loss}:')

            del train_dataloader, val_dataloader

            # if (val_loss < best_val_loss) and (ind_epoch % 200 == 0):
            if ind_epoch % 10 == 0:
                # Save model
                model.save(name=ind_epoch)

                # Plots: dump reconstructions and samples for this checkpoint
                if not os.path.isdir(
                        f'{model.model_dir}/training_plots/{ind_epoch}'):
                    os.mkdir(f'{model.model_dir}/training_plots/{ind_epoch}')
                test_dataloader = get_dataloader(
                    dataset_type=config['dataset'],
                    dataset=dataset_val,
                    batch_size=num_examples_plot,
                    shuffle=True)
                savepath = f'{model.model_dir}/training_plots/{ind_epoch}/reconstructions'
                os.mkdir(savepath)
                plot_reconstruction(model,
                                    hparams,
                                    test_dataloader,
                                    savepath,
                                    custom_data=None)
                savepath = f'{model.model_dir}/training_plots/{ind_epoch}/generations'
                os.mkdir(savepath)
                plot_generation(model, hparams, num_examples_plot, savepath)
                del test_dataloader

    # Generations
    print('##### Generate')
    USE_CUSTOM_SAMPLE = False
    # Use own samples for generation
    if USE_CUSTOM_SAMPLE:
        # Hard-coded start/end wav pairs for the interpolation plots
        data_source = {
            0: {
                'start': {'path': 'data/raw/source_generation/0.wav'},
                'end': {'path': 'data/raw/source_generation/1.wav'},
            },
            1: {
                'start': {'path': 'data/raw/source_generation/0.wav'},
                'end': {'path': 'data/raw/source_generation/2.wav'},
            },
            2: {
                'start': {'path': 'data/raw/source_generation/1.wav'},
                'end': {'path': 'data/raw/source_generation/2.wav'},
            }
        }
        start_data = []
        end_data = []
        for example_ind, examples_dict in data_source.items():
            for name in ['start', 'end']:
                # read file
                syl, _ = prepare_wav(wav_loc=examples_dict[name]['path'],
                                     hparams=hparams,
                                     debug=False)
                mel_basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
                # process syllable
                sn, mSp, _ = process_syllable(syl=syl,
                                              hparams=hparams,
                                              mel_basis=mel_basis,
                                              debug=False)
                if name == 'start':
                    start_data.append(SpectroDataset.process_mSp(mSp))
                elif name == 'end':
                    end_data.append(SpectroDataset.process_mSp(mSp))
        all_data = start_data + end_data
        # Batchify
        start_data = np.stack(start_data)
        end_data = np.stack(end_data)
        all_data = np.stack(all_data)
        custom_data = {
            'start_data': start_data,
            'end_data': end_data,
            'all_data': all_data,
        }
    else:
        custom_data = None

    # Dataloader from the dataset, used for latent space visualisation and to feed data for reconstruction
    # and interpolation if not custom examples are provided
    gen_dataloader = get_dataloader(dataset_type=config['dataset'],
                                    dataset=dataset_val,
                                    batch_size=num_examples_plot,
                                    shuffle=True)

    if not os.path.isdir(f'{model.model_dir}/plots'):
        os.mkdir(f'{model.model_dir}/plots')

    # Reconstructions
    savepath = f'{model.model_dir}/plots/reconstructions'
    if not os.path.isdir(savepath):
        os.mkdir(savepath)
    plot_reconstruction(model,
                        hparams,
                        gen_dataloader,
                        savepath,
                        custom_data=custom_data)

    # Sampling
    savepath = f'{model.model_dir}/plots/generations'
    if not os.path.isdir(savepath):
        os.mkdir(savepath)
    plot_generation(model, hparams, num_examples_plot, savepath)

    # Linear interpolations
    savepath = f'{model.model_dir}/plots/linear_interpolations'
    if not os.path.isdir(savepath):
        os.mkdir(savepath)
    plot_interpolations(model,
                        hparams,
                        gen_dataloader,
                        savepath,
                        num_interpolated_points=10,
                        method='linear',
                        custom_data=custom_data)

    # Constant r interpolations
    savepath = f'{model.model_dir}/plots/constant_r_interpolations'
    if not os.path.isdir(savepath):
        os.mkdir(savepath)
    plot_interpolations(model,
                        hparams,
                        gen_dataloader,
                        savepath,
                        num_interpolated_points=10,
                        method='constant_radius',
                        custom_data=custom_data)

    # TODO
    # Translations

    # Check geometric organistation of the latent space per species
    savepath = f'{model.model_dir}/plots/stats'
    if not os.path.isdir(savepath):
        os.mkdir(savepath)
    # latent_space_stats_per_species(model, gen_dataloader, savepath)
    plot_tsne_latent(model, gen_dataloader, savepath)
def main(config_path, loading_epoch, source_path, contamination_path,
         contamination_parameters, method):
    """Contaminate syllables of a source recording with syllables from a
    second recording by interpolating in the model's latent space, then write
    the spliced waveform to ``<model_dir>/plots/contaminations``.

    Args:
        config_path: config used to load the model.
        loading_epoch: checkpoint epoch to load.
        source_path: wav whose syllables get contaminated.
        contamination_path: wav providing the contaminating syllables.
        contamination_parameters: dict with key 'p_contamination', the
            probability of contaminating each source syllable.
        method: 'linear' or 'constant_radius' latent interpolation.
    """
    # load model
    model, _, _, _, hparams, _, _, config_path = get_model_and_dataset(
        config=config_path, loading_epoch=loading_epoch)
    # set savepath
    savepath = f'{model.model_dir}/plots/contaminations'
    if not os.path.isdir(savepath):
        os.makedirs(savepath)

    # load files
    # NOTE(review): get_chunks returns None when segmentation fails, which
    # would make these unpackings raise TypeError — confirm inputs are always
    # segmentable.
    source = {}
    source['path'] = source_path
    waveform, chunks, start_samples, end_samples = get_chunks(
        path=source['path'], hparams=hparams)
    source['waveform'] = waveform
    source['chunks'] = chunks
    source['start_samples'] = start_samples
    source['end_samples'] = end_samples

    contamination = {}
    contamination['path'] = contamination_path
    waveform, chunks, start_samples, end_samples = get_chunks(
        path=contamination['path'], hparams=hparams)
    contamination['waveform'] = waveform
    contamination['chunks'] = chunks
    contamination['start_samples'] = start_samples
    contamination['end_samples'] = end_samples

    # Choose which samples to contaminate and by which degree
    contamination_indices = []
    contamination_degrees = []
    xs = []
    ys = []
    p_contamination = contamination_parameters['p_contamination']
    for index, chunk in enumerate(source['chunks']):
        if random.random() < p_contamination:
            contamination_indices.append(index)
            contamination_degrees.append(random.random())
            xs.append(chunk)
            # choose (randomly?) a contaminating syllable
            ys.append(random.choice(contamination['chunks']))

    # NOTE(review): np.stack raises ValueError on an empty list — this fails
    # when no syllable was selected for contamination.
    xs_cuda = cuda_variable(torch.tensor(np.stack(xs)))
    ys_cuda = cuda_variable(torch.tensor(np.stack(ys)))

    # Encode both batches to latent codes
    mu, logvar = model.encode(xs_cuda)
    x_z = model.reparameterize(mu, logvar)
    mu, logvar = model.encode(ys_cuda)
    y_z = model.reparameterize(mu, logvar)
    z_out = torch.zeros_like(x_z)

    # Contaminate: interpolate each pair by its own degree t
    for batch_ind, t in enumerate(contamination_degrees):
        if method == 'linear':
            z_out[batch_ind] = x_z[batch_ind] * (1 - t) + y_z[batch_ind] * t
        elif method == 'constant_radius':
            z_out[batch_ind] = constant_radius_interpolation(
                x_z[batch_ind], y_z[batch_ind], t)

    # Decode z
    x_recon = model.decode(z_out).cpu().detach().numpy()

    # Replace contamined samples in original wave
    out_wave = source['waveform']
    mel_basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
    mel_inversion_basis = build_mel_inversion_basis(mel_basis)
    for batch_index, contamination_index in enumerate(contamination_indices):
        new_chunk = x_recon[batch_index, 0]
        new_audio = inv_spectrogram_sp(
            new_chunk,
            n_fft=hparams.n_fft,
            win_length=hparams.win_length_samples,
            hop_length=hparams.hop_length_samples,
            ref_level_db=hparams.ref_level_db,
            power=hparams.power,
            mel_inversion_basis=mel_inversion_basis)
        start_sample = source['start_samples'][contamination_index]
        end_sample = source['end_samples'][contamination_index]
        length_sample = end_sample - start_sample
        # TODO: apply a (cross-)fade here (original note: "FAIRE UN FADE ICI")
        # NOTE(review): assumes new_audio is at least length_sample long —
        # a shorter inversion would make the slice assignment fail; confirm.
        out_wave[start_sample:end_sample] = new_audio[:length_sample]

    sf.write(f'{savepath}/contamination.wav', out_wave, samplerate=hparams.sr)
def main(debug, sr, num_mel_bins, n_fft, chunk_len, mel_lower_edge_hertz,
         mel_upper_edge_hertz, hop_length_ms, win_length_ms, ref_level_db,
         power):
    """Build a pickled syllable dataset from annotated wav files.

    Rounds the requested chunk length to a whole number of STFT windows,
    creates an HParams, segments every file of the dataset via its JSON
    annotations, processes each syllable to a spectrogram and dumps one
    pickle per kept syllable plus the hparams used.

    Args:
        debug: when truthy, dump intermediate audio/plots for a few syllables.
        chunk_len: dict {'type': 'ms'|'samples'|'stft_win', 'value': ...}.
        (remaining args: standard STFT / mel parameters.)
    """
    # DATASET_ID = 'BIRD_DB_CATH'
    # DATASET_ID = 'Bird_all'
    # DATASET_ID = 'Test'
    # DATASET_ID = 'voizo_all'
    DATASET_ID = 'voizo_all_segmented'
    # DATASET_ID = 'voizo_chunks_test_segmented'

    # Syllable counters at which debug artefacts are dumped
    ind_examples = [200, 400, 500, 600, 800, 1000, 1200, 1400, 1600, 1800]

    # STFT time parameters: fall back to n_fft / quarter-window defaults
    if win_length_ms is None:
        win_length = n_fft
    else:
        win_length = ms_to_sample(win_length_ms, sr)
    if hop_length_ms is None:
        hop_length = win_length // 4
    else:
        hop_length = ms_to_sample(hop_length_ms, sr)

    ################################################################################
    # Round the chunk length to an integer number of STFT windows
    if chunk_len['type'] == 'ms':
        chunk_len_ms = chunk_len['value']
        chunk_len_samples_not_rounded = ms_to_sample(chunk_len_ms, sr)
        chunk_len_win = round(
            (chunk_len_samples_not_rounded - win_length) / hop_length) + 1
        chunk_len_samples = (chunk_len_win - 1) * hop_length + win_length
    elif chunk_len['type'] == 'samples':
        chunk_len_samples_not_rounded = chunk_len['value']
        chunk_len_win = round(
            (chunk_len_samples_not_rounded - win_length) / hop_length) + 1
        chunk_len_samples = (chunk_len_win - 1) * hop_length + win_length
    elif chunk_len['type'] == 'stft_win':
        chunk_len_win = chunk_len['value']
        chunk_len_samples = (chunk_len_win - 1) * hop_length + win_length
    ################################################################################
    print('Chunk length is automatically set to match STFT windows/hop sizes')
    print(
        f'STFT win length: {win_length} samples, {1000 * win_length / sr} ms')
    print(f'STFT hop length: {hop_length} samples, {1000* hop_length / sr} ms')
    print(
        f'Chunk length: {chunk_len_samples} samples, {chunk_len_win} win, {chunk_len_samples * 1000 / sr} ms'
    )

    ################################################################################
    print('Create dataset')
    hparams = HParams(
        sr=sr,
        num_mel_bins=num_mel_bins,
        n_fft=n_fft,
        chunk_len_samples=chunk_len_samples,
        mel_lower_edge_hertz=mel_lower_edge_hertz,
        mel_upper_edge_hertz=mel_upper_edge_hertz,
        power=power,
        # for spectral inversion
        butter_lowcut=mel_lower_edge_hertz,
        butter_highcut=mel_upper_edge_hertz,
        ref_level_db=ref_level_db,
        preemphasis=0.97,
        mask_spec=False,
        win_length_samples=win_length,
        hop_length_samples=hop_length,
        mask_spec_kwargs={
            "spec_thresh": 0.9,
            "offset": 1e-10
        },
        reduce_noise=True,
        noise_reduce_kwargs={
            "n_std_thresh": 2.0,
            "prop_decrease": 0.8
        },
        n_jobs=1,
        verbosity=1,
    )
    suffix = hparams.__repr__()

    # Debug artefacts go to a fresh dump folder named after the hparams
    if debug:
        dump_folder = f'dump/{suffix}'
        if os.path.isdir(dump_folder):
            shutil.rmtree(dump_folder)
        os.makedirs(dump_folder)
    else:
        dump_folder = None

    dataset = DataSet(DATASET_ID, hparams=hparams)
    print(f'Number files: {len(dataset.data_files)}')

    ################################################################################
    print('Create a dataset based upon JSON')
    verbosity = 10
    with Parallel(n_jobs=1, verbose=verbosity) as parallel:
        syllable_dfs = parallel(
            delayed(create_label_df)(
                dataset.data_files[key].data,
                hparams=dataset.hparams,
                labels_to_retain=[],
                unit="syllables",
                dict_features_to_retain=['species'],
                key=key,
            ) for key in tqdm(dataset.data_files.keys()))
    syllable_df = pd.concat(syllable_dfs)

    ################################################################################
    print('Get audio for dataset')
    mel_basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
    mel_inversion_basis = build_mel_inversion_basis(mel_basis)
    counter = 0
    save_loc = DATA_DIR / 'syllables' / f'{DATASET_ID}_{suffix}'
    if os.path.isdir(save_loc):
        raise Exception('already exists')
    os.makedirs(save_loc)
    skipped_counter = 0
    for key in syllable_df.key.unique():
        # load audio (key.unique is for loading large wavfiles only once)
        this_syllable_df = syllable_df[syllable_df.key == key]
        wav_loc = dataset.data_files[key].data['wav_loc']
        print(f'{wav_loc}')
        data, _ = prepare_wav(wav_loc, hparams, debug=debug)
        data = data.astype('float32')
        # process each syllable
        for syll_ind, (st, et) in enumerate(
                zip(this_syllable_df.start_time.values,
                    this_syllable_df.end_time.values)):
            s = data[int(st * hparams.sr):int(et * hparams.sr)]
            sn, mS, debug_info = process_syllable(syl=s,
                                                  hparams=hparams,
                                                  mel_basis=mel_basis,
                                                  debug=debug)
            if sn is None:
                # syllable rejected by preprocessing
                skipped_counter += 1
                continue
            # Save as uint to save space
            mS_int = (mS * 255).astype('uint8')
            # NOTE(review): `this_syllable_df.indv[syll_ind]` is label-based
            # Series indexing while syll_ind is a positional counter; after
            # pd.concat the index may not be 0..n-1 per key — confirm `.iloc`
            # is not needed here.
            save_dict = {
                'mS_int': mS_int,
                'sn': sn,
                'indv': this_syllable_df.indv[syll_ind],
                'label': this_syllable_df.species[syll_ind]
            }
            fname = save_loc / str(counter)
            with open(fname, 'wb') as ff:
                pkl.dump(save_dict, ff)
            counter += 1
            if debug and (counter in ind_examples):
                # normalised audio
                sf.write(f'{dump_folder}/{counter}_sn.wav',
                         sn,
                         samplerate=hparams.sr)
                # Padded mel db norm spectro
                plt.clf()
                plt.matshow(mS, origin="lower")
                plt.savefig(f'{dump_folder}/{counter}_mS.pdf')
                plt.close()
                audio_reconstruct = inv_spectrogram_sp(
                    mS,
                    n_fft=hparams.n_fft,
                    win_length=hparams.win_length_samples,
                    hop_length=hparams.hop_length_samples,
                    ref_level_db=hparams.ref_level_db,
                    power=hparams.power,
                    mel_inversion_basis=mel_inversion_basis)
                sf.write(f'{dump_folder}/{counter}_mS.wav',
                         audio_reconstruct,
                         samplerate=hparams.sr)
    print(f'Skipped counter: {skipped_counter}')

    # Save hparams
    print("Save hparams")
    hparams_loc = f'{save_loc}_hparams.pkl'
    with open(hparams_loc, 'wb') as ff:
        pkl.dump(hparams, ff)
    # Print hparams
    print("Print hparams")
    hparams_loc = f'{save_loc}_hparams.txt'
    with open(hparams_loc, 'w') as ff:
        for k, v in hparams.__dict__.items():
            ff.write(f'{k}: {v}\n')
def single_file_test(sr, num_mel_bins, n_fft, mel_lower_edge_hertz,
                     mel_upper_edge_hertz, hop_length_ms, win_length_ms, power,
                     ref_level_db, wav_loc, index=0):
    """Debug helper: run the spectrogram pipeline on the first 10 s of one wav
    file and dump every intermediate representation (plots + resynthesised
    audio) into a dump folder named after the hparams.

    Args:
        wav_loc: wav file to analyse.
        index: prefix used for the dumped file names.
        (remaining args: standard STFT / mel parameters.)
    """
    duration_sec = 10  # only the first 10 seconds are analysed
    hparams = HParams(
        sr=sr,
        num_mel_bins=num_mel_bins,
        n_fft=n_fft,
        mel_lower_edge_hertz=mel_lower_edge_hertz,
        mel_upper_edge_hertz=mel_upper_edge_hertz,
        power=power,
        # for spectral inversion
        butter_lowcut=mel_lower_edge_hertz,
        butter_highcut=mel_upper_edge_hertz,
        ref_level_db=ref_level_db,
        mask_spec=False,
        win_length_ms=win_length_ms,
        hop_length_ms=hop_length_ms,
        mask_spec_kwargs={
            "spec_thresh": 0.9,
            "offset": 1e-10
        },
        n_jobs=1,
        verbosity=1,
        reduce_noise=True,
        noise_reduce_kwargs={
            "n_std_thresh": 2.0,
            "prop_decrease": 0.8
        },
    )
    suffix = hparams.__repr__()
    # FIX: was `'dump' / f'{suffix}'`, which raises TypeError because str does
    # not support the `/` operator; use an f-string path as done elsewhere in
    # this file.
    dump_folder = f'dump/{suffix}'
    if not os.path.isdir(dump_folder):
        os.makedirs(dump_folder)

    # Read wave and keep only the first duration_sec seconds
    data, _ = prepare_wav(wav_loc, hparams, debug=True)
    data = data[:hparams.sr * duration_sec]

    # create spec
    mel_basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
    mel_inversion_basis = build_mel_inversion_basis(mel_basis)
    # mel_basis = None
    # mel_inversion_basis = None
    plt.clf()
    plt.matshow(mel_basis)
    plt.savefig(f'{dump_folder}/mel_basis.pdf')
    plt.close()
    plt.clf()
    plt.matshow(mel_inversion_basis)
    plt.savefig(f'{dump_folder}/mel_inversion_basis.pdf')
    plt.close()

    # NOTE(review): hparams.preemphasis is read here but not passed to the
    # HParams constructor above — presumably HParams provides a default;
    # confirm.
    _, debug_info = spectrogram_sp(y=data,
                                   sr=hparams.sr,
                                   n_fft=hparams.n_fft,
                                   win_length_ms=hparams.win_length_ms,
                                   hop_length_ms=hparams.hop_length_ms,
                                   ref_level_db=hparams.ref_level_db,
                                   _mel_basis=mel_basis,
                                   pre_emphasis=hparams.preemphasis,
                                   power=hparams.power,
                                   debug=True)
    if debug_info is None:
        print('chunk too short')
        return

    from avgn.signalprocessing.spectrogramming_scipy import griffinlim_sp, _mel_to_linear

    sf.write(f'{dump_folder}/{index}_syllable.wav',
             data,
             samplerate=hparams.sr)

    # preemphasis y
    sf.write(f'{dump_folder}/{index}_preemphasis_y.wav',
             debug_info['preemphasis_y'],
             samplerate=hparams.sr)

    # S_abs
    plt.clf()
    plt.matshow(debug_info['S_abs'][:, :200], origin="lower")
    plt.savefig(f'{dump_folder}/{index}_S_abs.pdf')
    plt.close()
    # NOTE(review): inverts debug_info['S'] although the file is named
    # `_S_abs.wav` — confirm 'S' (not 'S_abs') is intended here.
    S_abs_inv = griffinlim_sp(debug_info['S'], hparams.sr, hparams)
    sf.write(f'{dump_folder}/{index}_S_abs.wav',
             S_abs_inv,
             samplerate=hparams.sr)

    # mel
    plt.clf()
    plt.matshow(debug_info['mel'][:, :200], origin="lower")
    plt.savefig(f'{dump_folder}/{index}_mel.pdf')
    plt.close()
    if mel_basis is not None:
        mel_inv = griffinlim_sp(
            _mel_to_linear(debug_info['mel'],
                           _mel_inverse_basis=mel_inversion_basis), hparams.sr,
            hparams)
        sf.write(f'{dump_folder}/{index}_mel.wav',
                 mel_inv,
                 samplerate=hparams.sr)

    # mel_db: undo the dB scaling, project back to linear, then Griffin-Lim
    plt.clf()
    plt.matshow(debug_info['mel_db'][:, :200], origin="lower")
    plt.savefig(f'{dump_folder}/{index}_mel_db.pdf')
    plt.close()
    aa_ = debug_info['mel_db'] + hparams.ref_level_db
    bb_ = _db_to_amplitude(aa_)
    if mel_basis is not None:
        cc_ = _mel_to_linear(bb_, _mel_inverse_basis=mel_inversion_basis)
    else:
        cc_ = bb_
    mel_db_inv = griffinlim_sp(cc_, hparams.sr, hparams)
    sf.write(f'{dump_folder}/{index}_mel_db.wav',
             mel_db_inv,
             samplerate=hparams.sr)

    ##################################
    ##################################
    # Sanity check: re-analyse the reconstructed audio and plot its STFT
    f, t, S = signal.stft(x=mel_db_inv,
                          fs=hparams.sr,
                          window='hann',
                          nperseg=1024,
                          noverlap=256,
                          nfft=None,
                          detrend=False,
                          return_onesided=True,
                          boundary='zeros',
                          padded=True,
                          axis=-1)
    S_abs = np.abs(S)
    S_view = _amplitude_to_db(S_abs)
    plt.clf()
    plt.matshow(S_view)
    plt.savefig(f'{dump_folder}/{index}_mel_db_RECON.pdf')
    plt.close()
    ##################################
    ##################################

    # mel_db_norm: the representation actually stored in the dataset
    plt.clf()
    plt.matshow(debug_info['mel_db_norm'][:, :200], origin="lower")
    plt.savefig(f'{dump_folder}/{index}_mel_db_norm.pdf')
    plt.close()
    mel_db_norm_inv = inv_spectrogram_sp(
        debug_info['mel_db_norm'],
        hparams.sr,
        hparams,
        mel_inversion_basis=mel_inversion_basis)
    sf.write(f'{dump_folder}/{index}_mel_db_norm.wav',
             mel_db_norm_inv,
             samplerate=hparams.sr)
def plot_reconstruction(model, hparams, dataloader, savepath, custom_data):
    """Plot originals against their reconstructions and, when `hparams` is
    given, invert both to audio files under `savepath`.

    Uses `custom_data['all_data']` when provided, otherwise the first batch
    of `dataloader`. Returns the audio lists and both spectrogram arrays.
    """
    model.eval()
    # Forward pass: custom examples take precedence over the dataloader
    if custom_data is not None:
        x_orig = custom_data['all_data']
        x_cuda = cuda_variable(torch.tensor(custom_data['all_data']))
        x_recon = model.reconstruct(x_cuda).cpu().detach().numpy()
    else:
        for _, data in enumerate(dataloader):
            x_orig = data['input'].numpy()
            x_cuda = cuda_variable(data['input'])
            x_recon = model.reconstruct(x_cuda).cpu().detach().numpy()
            break

    # Plot grid: row 0 holds originals, row 1 the reconstructions
    dims = x_recon.shape[2:]
    num_examples = x_recon.shape[0]
    plt.clf()
    fig, axes = plt.subplots(nrows=2, ncols=num_examples)
    for ex in range(num_examples):
        axes[0, ex].matshow(x_orig[ex].reshape(dims), origin="lower")
        axes[1, ex].matshow(x_recon[ex].reshape(dims), origin="lower")
    for ax in fig.get_axes():
        ax.set_xticks([])
        ax.set_yticks([])
    plt.savefig(f'{savepath}/spectro.pdf')
    plt.close('all')

    # Audio: invert original and reconstructed spectrograms with one shared
    # parameter set
    original_audios = []
    reconstruction_audios = []
    if hparams is not None:
        basis = build_mel_basis(hparams, hparams.sr, hparams.sr)
        inv_kwargs = dict(
            n_fft=hparams.n_fft,
            win_length=hparams.win_length_samples,
            hop_length=hparams.hop_length_samples,
            ref_level_db=hparams.ref_level_db,
            power=hparams.power,
            mel_inversion_basis=build_mel_inversion_basis(basis))
        for ex in range(num_examples):
            original_audio = inv_spectrogram_sp(x_orig[ex, 0], **inv_kwargs)
            recon_audio = inv_spectrogram_sp(x_recon[ex, 0], **inv_kwargs)
            sf.write(f'{savepath}/{ex}_original.wav',
                     original_audio,
                     samplerate=hparams.sr)
            sf.write(f'{savepath}/{ex}_recon.wav',
                     recon_audio,
                     samplerate=hparams.sr)
            original_audios.append(original_audio)
            reconstruction_audios.append(recon_audio)
    return {
        'original_audios': original_audios,
        'reconstruction_audios': reconstruction_audios,
        'original_spectros': x_orig,
        'reconstruction_spectros': x_recon
    }