class MelLoader(Dataset):
    """Dataset that yields ``(mel_tensor, wav_tensor)`` pairs.

    For every item a segment is cut out of the wav and the mel spectrogram
    of exactly that segment is computed with ``TacotronSTFT``.
    """

    def __init__(self, prepare_ds_data: PreparedDataList, hparams: HParams, logger: Logger):
        """Shuffle the entries with ``hparams.seed`` and index their wav paths.

        If ``hparams.cache_wavs`` is set, all wavs are loaded once and kept in
        ``self.cache``; otherwise they are read from disk on every access.
        """
        self.taco_stft = TacotronSTFT(hparams, logger=logger)
        self.hparams = hparams
        self._logger = logger

        # NOTE: this seeds the global ``random`` module and shuffles the
        # passed-in data object in place.
        random.seed(hparams.seed)
        random.shuffle(prepare_ds_data)

        # Map a running index to the wav path of each (shuffled) entry.
        self.wav_paths = {
            position: entry.wav_path
            for position, entry in enumerate(prepare_ds_data.items())
        }

        if hparams.cache_wavs:
            self._logger.info("Loading wavs into memory...")
            load_wav = self.taco_stft.get_wav_tensor_from_file
            wav_cache = {
                position: load_wav(path)
                for position, path in tqdm(self.wav_paths.items())
            }
            self._logger.info("Done")
            self.cache = wav_cache

    def __getitem__(self, index):
        """Return the ``(mel_tensor, wav_tensor)`` pair for ``index``."""
        if not self.hparams.cache_wavs:
            full_wav = self.taco_stft.get_wav_tensor_from_file(self.wav_paths[index])
        else:
            # Hand out a detached copy of the cached tensor.
            full_wav = self.cache[index].clone().detach()

        segment = get_wav_tensor_segment(full_wav, self.hparams.segment_length)
        return self.taco_stft.get_mel_tensor(segment), segment

    def __len__(self):
        """Return the number of indexed wav files."""
        return len(self.wav_paths)
def infer(
    mel_entries: List[InferMelEntry],
    checkpoint: CheckpointWaveglow,
    custom_hparams: Optional[Dict[str, str]],
    denoiser_strength: float,
    sigma: float,
    sentence_pause_s: float,
    save_callback: Callable[[InferenceEntryOutput], None],
    concatenate: bool,
    seed: int,
    logger: Logger,
) -> Tuple[InferenceEntries, Tuple[Optional[np.ndarray], int]]:
    """Synthesize audio for the given mel spectrograms and score each result.

    For every mel entry the denoised wav is normalized, converted back to a
    mel spectrogram and compared with the input mel (MCD with and without
    DTW, cosine similarity, structural similarity of the plotted
    spectrograms). The per-entry outputs are passed to ``save_callback``.

    Args:
        mel_entries: Mel spectrograms to synthesize; each entry's ``sr`` must
            equal the sampling rate of the model's hparams.
        checkpoint: Checkpoint handed to ``Synthesizer``; its ``iteration``
            is recorded on every entry.
        custom_hparams: Optional hparams handed to ``Synthesizer``.
        denoiser_strength: Forwarded to ``synth.infer_all`` and recorded.
        sigma: Forwarded to ``synth.infer_all`` and recorded.
        sentence_pause_s: Forwarded to ``concatenate_audios`` when
            ``concatenate`` is true.
        save_callback: Called once per entry with its ``InferenceEntryOutput``.
        concatenate: Whether to also build one normalized wav out of all
            denoised results.
        seed: Forwarded to ``synth.infer_all``.
        logger: Logger used for progress and metric output.

    Returns:
        ``(inference_entries, (complete_wav_denoised, sampling_rate))``.
        ``complete_wav_denoised`` is ``None`` unless ``concatenate`` is true.
        For empty ``mel_entries`` the result is ``(entries, (None, 0))``: no
        model is loaded in that case, so ``0`` is only a placeholder.

    Raises:
        AssertionError: If a mel entry's sampling rate differs from the
            model's sampling rate.
    """
    inference_entries = InferenceEntries()

    if len(mel_entries) == 0:
        logger.info("Nothing to synthesize!")
        # Return the same shape as the normal path so callers that unpack
        # the result do not fail on empty input.
        return inference_entries, (None, 0)

    synth = Synthesizer(checkpoint=checkpoint, custom_hparams=custom_hparams, logger=logger)

    # Check mels have the same sampling rate as trained waveglow model
    for mel_entry in mel_entries:
        assert mel_entry.sr == synth.hparams.sampling_rate, (
            f"Sampling rate of mel ({mel_entry.sr}) does not match the "
            f"model ({synth.hparams.sampling_rate})"
        )

    taco_stft = TacotronSTFT(synth.hparams, logger=logger)

    # Move every mel to the GPU and add a leading batch dimension.
    mels_torch_prepared = []
    for mel_entry in mel_entries:
        mel_var = torch.autograd.Variable(mel_to_torch(mel_entry.mel))
        mel_var = mel_var.cuda()
        mel_var = mel_var.unsqueeze(0)
        mels_torch_prepared.append(mel_var)

    inference_results = synth.infer_all(
        mels_torch_prepared, sigma, denoiser_strength, seed=seed)

    complete_wav_denoised: Optional[np.ndarray] = None
    if concatenate:
        if len(inference_results) >= 1:
            logger.info("Concatening audios...")
        complete_wav_denoised = concatenate_audios(
            [x.wav_denoised for x in inference_results],
            sentence_pause_s,
            synth.hparams.sampling_rate,
        )
        complete_wav_denoised = normalize_wav(complete_wav_denoised)
        if len(inference_results) >= 1:
            logger.info("Done.")

    inference_result: InferenceResult
    mel_entry: InferMelEntry
    # ``zip`` has no length, so give tqdm the total explicitly.
    for mel_entry, inference_result in tqdm(
            zip(mel_entries, inference_results), total=len(mel_entries)):
        wav_inferred_denoised_normalized = normalize_wav(inference_result.wav_denoised)

        timepoint = f"{datetime.datetime.now():%Y/%m/%d %H:%M:%S}"

        val_entry = InferenceEntry(
            identifier=mel_entry.identifier,
            iteration=checkpoint.iteration,
            timepoint=timepoint,
            sampling_rate=inference_result.sampling_rate,
            inference_duration_s=inference_result.inference_duration_s,
            was_overamplified=inference_result.was_overamplified,
            denoising_duration_s=inference_result.denoising_duration_s,
            inferred_duration_s=get_duration_s(
                inference_result.wav_denoised, inference_result.sampling_rate),
            denoiser_strength=denoiser_strength,
            sigma=sigma,
        )

        mel_orig = mel_entry.mel

        # Recompute a mel from the synthesized audio so it can be compared
        # with the input mel.
        wav_inferred_denoised_normalized_tensor = torch.FloatTensor(
            wav_inferred_denoised_normalized)
        mel_inferred_denoised = taco_stft.get_mel_tensor(
            wav_inferred_denoised_normalized_tensor)
        mel_inferred_denoised = mel_inferred_denoised.numpy()

        validation_entry_output = InferenceEntryOutput(
            identifier=mel_entry.identifier,
            mel_orig=mel_orig,
            inferred_sr=inference_result.sampling_rate,
            mel_inferred_denoised=mel_inferred_denoised,
            wav_inferred_denoised=wav_inferred_denoised_normalized,
            orig_sr=mel_entry.sr,
            wav_inferred=normalize_wav(inference_result.wav),
            mel_denoised_diff_img=None,
            mel_inferred_denoised_img=None,
            mel_orig_img=None,
        )

        # MCD with dynamic time warping.
        mcd_dtw, penalty_dtw, final_frame_number_dtw = get_mcd_between_mel_spectograms(
            mel_1=mel_orig,
            mel_2=mel_inferred_denoised,
            n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
            take_log=False,
            use_dtw=True,
        )

        val_entry.diff_frames = mel_inferred_denoised.shape[1] - mel_orig.shape[1]
        val_entry.mcd_dtw = mcd_dtw
        val_entry.mcd_dtw_penalty = penalty_dtw
        val_entry.mcd_dtw_frames = final_frame_number_dtw

        # MCD without dynamic time warping.
        mcd, penalty, final_frame_number = get_mcd_between_mel_spectograms(
            mel_1=mel_orig,
            mel_2=mel_inferred_denoised,
            n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
            take_log=False,
            use_dtw=False,
        )

        val_entry.mcd = mcd
        val_entry.mcd_penalty = penalty
        val_entry.mcd_frames = final_frame_number

        cosine_similarity = cosine_dist_mels(mel_orig, mel_inferred_denoised)
        val_entry.cosine_similarity = cosine_similarity

        mel_original_img_raw, mel_original_img = plot_melspec_np(mel_orig)
        mel_inferred_denoised_img_raw, mel_inferred_denoised_img = plot_melspec_np(
            mel_inferred_denoised)

        validation_entry_output.mel_orig_img = mel_original_img
        validation_entry_output.mel_inferred_denoised_img = mel_inferred_denoised_img

        # Bring both plots to the same width before comparing them.
        mel_original_img_raw_same_dim, mel_inferred_denoised_img_raw_same_dim = make_same_width_by_filling_white(
            img_a=mel_original_img_raw,
            img_b=mel_inferred_denoised_img_raw,
        )

        mel_original_img_same_dim, mel_inferred_denoised_img_same_dim = make_same_width_by_filling_white(
            img_a=mel_original_img,
            img_b=mel_inferred_denoised_img,
        )

        # The reported similarity score comes from the raw plots ...
        structural_similarity_raw, mel_difference_denoised_img_raw = calculate_structual_similarity_np(
            img_a=mel_original_img_raw_same_dim,
            img_b=mel_inferred_denoised_img_raw_same_dim,
        )
        val_entry.structural_similarity = structural_similarity_raw

        # ... while the stored difference image comes from the non-raw plots.
        _, mel_denoised_diff_img = calculate_structual_similarity_np(
            img_a=mel_original_img_same_dim,
            img_b=mel_inferred_denoised_img_same_dim,
        )
        validation_entry_output.mel_denoised_diff_img = mel_denoised_diff_img

        # NOTE(review): fixed /tmp paths, overwritten on every iteration;
        # looks like leftover debug output - confirm whether still needed.
        imageio.imsave("/tmp/mel_original_img_raw.png", mel_original_img_raw)
        imageio.imsave("/tmp/mel_inferred_img_raw.png", mel_inferred_denoised_img_raw)
        imageio.imsave("/tmp/mel_difference_denoised_img_raw.png",
                       mel_difference_denoised_img_raw)

        logger.info(f"Current: {val_entry.identifier}")
        logger.info(f"MCD DTW: {val_entry.mcd_dtw}")
        logger.info(f"MCD DTW penalty: {val_entry.mcd_dtw_penalty}")
        logger.info(f"MCD DTW frames: {val_entry.mcd_dtw_frames}")

        logger.info(f"MCD: {val_entry.mcd}")
        logger.info(f"MCD penalty: {val_entry.mcd_penalty}")
        logger.info(f"MCD frames: {val_entry.mcd_frames}")

        logger.info(f"Structural Similarity: {val_entry.structural_similarity}")
        logger.info(f"Cosine Similarity: {val_entry.cosine_similarity}")

        save_callback(validation_entry_output)
        inference_entries.append(val_entry)

    return inference_entries, (complete_wav_denoised, synth.hparams.sampling_rate)
def validate(
    checkpoint: CheckpointWaveglow,
    data: PreparedDataList,
    custom_hparams: Optional[Dict[str, str]],
    denoiser_strength: float,
    sigma: float,
    entry_ids: Optional[Set[int]],
    train_name: str,
    full_run: bool,
    save_callback: Callable[[PreparedData, ValidationEntryOutput], None],
    seed: int,
    logger: Logger,
) -> ValidationEntries:
    """Re-synthesize dataset entries from their own mels and score the result.

    For each selected entry the mel spectrogram is computed from the entry's
    wav file, audio is synthesized from that mel, and the mel of the
    normalized denoised audio is compared with the original mel (MCD with
    and without DTW, cosine similarity, structural similarity of the plotted
    spectrograms). The per-entry outputs are passed to ``save_callback``.

    Args:
        checkpoint: Checkpoint handed to ``Synthesizer``; its ``iteration``
            is recorded on every entry.
        data: Entries to choose from.
        custom_hparams: Optional hparams handed to ``Synthesizer``.
        denoiser_strength: Forwarded to ``synth.infer`` and recorded.
        sigma: Forwarded to ``synth.infer`` and recorded.
        entry_ids: Passed to ``data.get_for_validation`` when ``full_run``
            is false.
        train_name: Recorded on every entry.
        full_run: If true, all of ``data`` is validated.
        save_callback: Called once per entry with the entry and its
            ``ValidationEntryOutput``.
        seed: Forwarded to ``synth.infer`` and recorded.
        logger: Logger used for progress and metric output.

    Returns:
        The collected ``ValidationEntries`` (empty if nothing was selected).
    """
    validation_entries = ValidationEntries()

    if full_run:
        entries = data
    else:
        speaker_id: Optional[int] = None
        entries = PreparedDataList(data.get_for_validation(entry_ids, speaker_id))

    if len(entries) == 0:
        logger.info("Nothing to synthesize!")
        return validation_entries

    synth = Synthesizer(
        checkpoint=checkpoint,
        custom_hparams=custom_hparams,
        logger=logger,
    )

    taco_stft = TacotronSTFT(synth.hparams, logger=logger)

    for entry in entries.items(True):
        # Move the mel to the GPU and add a leading batch dimension.
        mel = taco_stft.get_mel_tensor_from_file(entry.wav_path)
        mel_var = torch.autograd.Variable(mel)
        mel_var = mel_var.cuda()
        mel_var = mel_var.unsqueeze(0)

        inference_result = synth.infer(mel_var, sigma, denoiser_strength, seed=seed)
        wav_inferred_denoised_normalized = normalize_wav(inference_result.wav_denoised)

        # Deserialize once and derive both counts from the same list.
        symbol_ids = deserialize_list(entry.serialized_symbol_ids)
        symbol_count = len(symbol_ids)
        unique_symbols_count = len(set(symbol_ids))
        timepoint = f"{datetime.datetime.now():%Y/%m/%d %H:%M:%S}"

        val_entry = ValidationEntry(
            entry_id=entry.entry_id,
            ds_entry_id=entry.ds_entry_id,
            text_original=entry.text_original,
            text=entry.text,
            seed=seed,
            wav_path=entry.wav_path,
            original_duration_s=entry.duration_s,
            speaker_id=entry.speaker_id,
            iteration=checkpoint.iteration,
            unique_symbols_count=unique_symbols_count,
            symbol_count=symbol_count,
            timepoint=timepoint,
            train_name=train_name,
            sampling_rate=inference_result.sampling_rate,
            inference_duration_s=inference_result.inference_duration_s,
            was_overamplified=inference_result.was_overamplified,
            denoising_duration_s=inference_result.denoising_duration_s,
            inferred_duration_s=get_duration_s(
                inference_result.wav_denoised, inference_result.sampling_rate),
            denoiser_strength=denoiser_strength,
            sigma=sigma,
        )

        val_entry.diff_duration_s = val_entry.inferred_duration_s - val_entry.original_duration_s

        mel_orig = mel.cpu().numpy()

        # Recompute a mel from the synthesized audio so it can be compared
        # with the original mel.
        wav_inferred_denoised_normalized_tensor = torch.FloatTensor(
            wav_inferred_denoised_normalized)
        mel_inferred_denoised = taco_stft.get_mel_tensor(
            wav_inferred_denoised_normalized_tensor)
        mel_inferred_denoised = mel_inferred_denoised.numpy()

        wav_orig, orig_sr = wav_to_float32(entry.wav_path)

        validation_entry_output = ValidationEntryOutput(
            mel_orig=mel_orig,
            inferred_sr=inference_result.sampling_rate,
            mel_inferred_denoised=mel_inferred_denoised,
            wav_inferred_denoised=wav_inferred_denoised_normalized,
            wav_orig=wav_orig,
            orig_sr=orig_sr,
            wav_inferred=normalize_wav(inference_result.wav),
            mel_denoised_diff_img=None,
            mel_inferred_denoised_img=None,
            mel_orig_img=None,
        )

        # MCD with dynamic time warping.
        mcd_dtw, penalty_dtw, final_frame_number_dtw = get_mcd_between_mel_spectograms(
            mel_1=mel_orig,
            mel_2=mel_inferred_denoised,
            n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
            take_log=False,
            use_dtw=True,
        )

        val_entry.diff_frames = mel_inferred_denoised.shape[1] - mel_orig.shape[1]
        val_entry.mcd_dtw = mcd_dtw
        val_entry.mcd_dtw_penalty = penalty_dtw
        val_entry.mcd_dtw_frames = final_frame_number_dtw

        # MCD without dynamic time warping.
        mcd, penalty, final_frame_number = get_mcd_between_mel_spectograms(
            mel_1=mel_orig,
            mel_2=mel_inferred_denoised,
            n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
            take_log=False,
            use_dtw=False,
        )

        val_entry.mcd = mcd
        val_entry.mcd_penalty = penalty
        val_entry.mcd_frames = final_frame_number

        cosine_similarity = cosine_dist_mels(mel_orig, mel_inferred_denoised)
        val_entry.cosine_similarity = cosine_similarity

        mel_original_img_raw, mel_original_img = plot_melspec_np(mel_orig)
        mel_inferred_denoised_img_raw, mel_inferred_denoised_img = plot_melspec_np(
            mel_inferred_denoised)

        validation_entry_output.mel_orig_img = mel_original_img
        validation_entry_output.mel_inferred_denoised_img = mel_inferred_denoised_img

        # Bring both plots to the same width before comparing them.
        mel_original_img_raw_same_dim, mel_inferred_denoised_img_raw_same_dim = make_same_width_by_filling_white(
            img_a=mel_original_img_raw,
            img_b=mel_inferred_denoised_img_raw,
        )

        mel_original_img_same_dim, mel_inferred_denoised_img_same_dim = make_same_width_by_filling_white(
            img_a=mel_original_img,
            img_b=mel_inferred_denoised_img,
        )

        # The reported similarity score comes from the raw plots ...
        structural_similarity_raw, mel_difference_denoised_img_raw = calculate_structual_similarity_np(
            img_a=mel_original_img_raw_same_dim,
            img_b=mel_inferred_denoised_img_raw_same_dim,
        )
        val_entry.structural_similarity = structural_similarity_raw

        # ... while the stored difference image comes from the non-raw plots.
        _, mel_denoised_diff_img = calculate_structual_similarity_np(
            img_a=mel_original_img_same_dim,
            img_b=mel_inferred_denoised_img_same_dim,
        )
        validation_entry_output.mel_denoised_diff_img = mel_denoised_diff_img

        # NOTE(review): fixed /tmp paths, overwritten on every iteration;
        # looks like leftover debug output - confirm whether still needed.
        imageio.imsave("/tmp/mel_original_img_raw.png", mel_original_img_raw)
        imageio.imsave("/tmp/mel_inferred_img_raw.png", mel_inferred_denoised_img_raw)
        imageio.imsave("/tmp/mel_difference_denoised_img_raw.png",
                       mel_difference_denoised_img_raw)

        logger.info(f"Current: {val_entry.entry_id}")
        logger.info(f"MCD DTW: {val_entry.mcd_dtw}")
        logger.info(f"MCD DTW penalty: {val_entry.mcd_dtw_penalty}")
        logger.info(f"MCD DTW frames: {val_entry.mcd_dtw_frames}")

        logger.info(f"MCD: {val_entry.mcd}")
        logger.info(f"MCD penalty: {val_entry.mcd_penalty}")
        logger.info(f"MCD frames: {val_entry.mcd_frames}")

        logger.info(f"Structural Similarity: {val_entry.structural_similarity}")
        logger.info(f"Cosine Similarity: {val_entry.cosine_similarity}")

        save_callback(entry, validation_entry_output)
        validation_entries.append(val_entry)

    return validation_entries