def process(data: WavDataList, ds: DsDataList, wav_dir: Path, custom_hparams: Optional[Dict[str, str]], save_callback: Callable[[WavData, DsData], Path]) -> List[Path]:
  """Compute a mel tensor for every (wav, ds) entry pair and persist each via save_callback.

  Returns the paths produced by save_callback, in input order.
  """
  hparams = overwrite_custom_hparams(TSTFTHParams(), custom_hparams)
  mel_parser = TacotronSTFT(hparams, logger=getLogger())

  saved_paths: List[Path] = []
  for wav_entry, ds_entry in zip(data.items(True), ds.items(True)):
    wav_path = wav_dir / wav_entry.wav_relative_path
    mel_tensor = mel_parser.get_mel_tensor_from_file(wav_path)
    saved_paths.append(save_callback(
      wav_entry=wav_entry, ds_entry=ds_entry, mel_tensor=mel_tensor))
  return saved_paths
class MelLoader(Dataset):
  """Dataset that yields (mel spectrogram, audio segment) pairs.

  Wav files are loaded on demand, or pre-loaded into RAM when
  hparams.cache_wavs is set.
  """

  def __init__(self, prepare_ds_data: PreparedDataList, hparams: HParams, logger: Logger):
    self.taco_stft = TacotronSTFT(hparams, logger=logger)
    self.hparams = hparams
    self._logger = logger

    entries = prepare_ds_data
    # Deterministic shuffle so every run sees the same order for a given seed.
    random.seed(hparams.seed)
    random.shuffle(entries)

    self.wav_paths = {idx: entry.wav_path for idx, entry in enumerate(entries.items())}

    if hparams.cache_wavs:
      self._logger.info("Loading wavs into memory...")
      loaded = {
        idx: self.taco_stft.get_wav_tensor_from_file(path)
        for idx, path in tqdm(self.wav_paths.items())
      }
      self._logger.info("Done")
      self.cache = loaded

  def __getitem__(self, index):
    if self.hparams.cache_wavs:
      # Clone so callers cannot mutate the cached tensor.
      wav_tensor = self.cache[index].clone().detach()
    else:
      wav_tensor = self.taco_stft.get_wav_tensor_from_file(self.wav_paths[index])
    wav_tensor = get_wav_tensor_segment(wav_tensor, self.hparams.segment_length)
    mel_tensor = self.taco_stft.get_mel_tensor(wav_tensor)
    return mel_tensor, wav_tensor

  def __len__(self):
    return len(self.wav_paths)
def remove_silence_plot(wav_path: Path, out_path: Path, chunk_size: int, threshold_start: float, threshold_end: float, buffer_start_ms: float, buffer_end_ms: float):
  """Trim silence from a wav file and return (original, trimmed) mel tensors for plotting."""
  remove_silence_file(
    in_path=wav_path,
    out_path=out_path,
    chunk_size=chunk_size,
    threshold_start=threshold_start,
    threshold_end=threshold_end,
    buffer_start_ms=buffer_start_ms,
    buffer_end_ms=buffer_end_ms,
  )

  # Use the file's own sampling rate so the STFT parameters match the audio.
  sampling_rate, _ = read(wav_path)
  hparams = TSTFTHParams()
  hparams.sampling_rate = sampling_rate
  plotter = TacotronSTFT(hparams, logger=getLogger())

  return plotter.get_mel_tensor_from_file(wav_path), plotter.get_mel_tensor_from_file(out_path)
def __init__(self, prepare_ds_data: PreparedDataList, hparams: HParams, logger: Logger):
  """Build the index -> wav-path mapping; optionally preload all wav tensors into RAM."""
  self.taco_stft = TacotronSTFT(hparams, logger=logger)
  self.hparams = hparams
  self._logger = logger

  entries = prepare_ds_data
  # Seeded shuffle keeps the ordering reproducible across runs.
  random.seed(hparams.seed)
  random.shuffle(entries)

  self.wav_paths = {idx: entry.wav_path for idx, entry in enumerate(entries.items())}

  if hparams.cache_wavs:
    self._logger.info("Loading wavs into memory...")
    loaded = {
      idx: self.taco_stft.get_wav_tensor_from_file(path)
      for idx, path in tqdm(self.wav_paths.items())
    }
    self._logger.info("Done")
    self.cache = loaded
def process(data: WavDataList, wav_dir: Path, custom_hparams: Optional[Dict[str, str]], save_callback: Callable[[WavData, Tensor], str], n_jobs: int) -> MelDataList:
  """Compute and save mel tensors for all wav entries in parallel using a thread pool."""
  hparams = overwrite_custom_hparams(TSTFTHParams(), custom_hparams)
  mel_parser = TacotronSTFT(hparams, logger=getLogger())

  worker = partial(
    process_entry,
    wav_dir=wav_dir,
    mel_parser=mel_parser,
    save_callback=save_callback,
  )

  with ThreadPoolExecutor(max_workers=n_jobs) as executor:
    # ex.map preserves input order, so the resulting list matches data.items().
    return MelDataList(tqdm(executor.map(worker, data.items()), total=len(data)))
def __init__(self, data: PreparedDataList, hparams: HParams, logger: Logger):
  """Index the prepared entries and optionally preload saved mels into memory.

  For each entry stores a tuple (symbols_tensor, accents_tensor, path,
  speaker_id), where `path` is the mel path when hparams.use_saved_mels is
  set and the wav path otherwise.
  """
  self.use_saved_mels = hparams.use_saved_mels
  if not hparams.use_saved_mels:
    # Mels will be computed on the fly from wav files.
    self.mel_parser = TacotronSTFT(hparams, logger)

  logger.info("Reading files...")
  self.data: Dict[int, Tuple[IntTensor, IntTensor, str, int]] = {}
  for i, values in enumerate(data.items(True)):
    symbol_ids = deserialize_list(values.serialized_symbol_ids)
    accent_ids = deserialize_list(values.serialized_accent_ids)
    model_symbol_ids = get_accent_symbol_ids(
      symbol_ids, accent_ids, hparams.n_symbols,
      hparams.accents_use_own_symbols, SHARED_SYMBOLS_COUNT)
    symbols_tensor = IntTensor(model_symbol_ids)
    accents_tensor = IntTensor(accent_ids)
    path = values.mel_path if hparams.use_saved_mels else values.wav_path
    self.data[i] = (symbols_tensor, accents_tensor, path, values.speaker_id)

  if hparams.use_saved_mels and hparams.cache_mels:
    logger.info("Loading mels into memory...")
    self.cache: Dict[int, Tensor] = {}
    vals: tuple
    for i, vals in tqdm(self.data.items()):
      # BUGFIX: the path is element 2 of the tuple; element 1 is the accents
      # tensor, which torch.load cannot open as a file.
      mel_tensor = torch.load(vals[2], map_location='cpu')
      self.cache[i] = mel_tensor
  self.use_cache: bool = hparams.cache_mels
def infer(
    mel_entries: List[InferMelEntry],
    checkpoint: CheckpointWaveglow,
    custom_hparams: Optional[Dict[str, str]],
    denoiser_strength: float,
    sigma: float,
    sentence_pause_s: float,
    save_callback: Callable[[InferenceEntryOutput], None],
    concatenate: bool,
    seed: int,
    logger: Logger
) -> Tuple[InferenceEntries, Tuple[Optional[np.ndarray], int]]:
  """Vocode mel spectrograms with WaveGlow, evaluate each result, and save outputs.

  For every mel entry the denoised waveform is synthesized, compared against
  the original mel (MCD with and without DTW, cosine similarity, structural
  similarity of plots) and handed to save_callback.

  Returns (inference_entries, (concatenated_denoised_wav_or_None, sampling_rate)).
  """
  inference_entries = InferenceEntries()

  if len(mel_entries) == 0:
    logger.info("Nothing to synthesize!")
    # BUGFIX: always return the declared (entries, (wav, sr)) tuple so callers
    # can unpack it; the sampling rate is unknown before the synthesizer is
    # built, so 0 is used as a sentinel.
    return inference_entries, (None, 0)

  synth = Synthesizer(checkpoint=checkpoint, custom_hparams=custom_hparams, logger=logger)

  # Check mels have the same sampling rate as trained waveglow model.
  for mel_entry in mel_entries:
    assert mel_entry.sr == synth.hparams.sampling_rate

  taco_stft = TacotronSTFT(synth.hparams, logger=logger)

  mels_torch_prepared = []
  for mel_entry in mel_entries:
    mel_var = torch.autograd.Variable(mel_to_torch(mel_entry.mel))
    mel_var = mel_var.cuda()
    mel_var = mel_var.unsqueeze(0)
    mels_torch_prepared.append(mel_var)

  inference_results = synth.infer_all(
    mels_torch_prepared, sigma, denoiser_strength, seed=seed)

  complete_wav_denoised: Optional[np.ndarray] = None
  if concatenate:
    if len(inference_results) >= 1:
      logger.info("Concatening audios...")
      complete_wav_denoised = concatenate_audios(
        [x.wav_denoised for x in inference_results], sentence_pause_s,
        synth.hparams.sampling_rate)
      complete_wav_denoised = normalize_wav(complete_wav_denoised)
    if len(inference_results) >= 1:
      logger.info("Done.")

  inference_result: InferenceResult
  mel_entry: InferMelEntry
  for mel_entry, inference_result in tqdm(zip(mel_entries, inference_results)):
    wav_inferred_denoised_normalized = normalize_wav(inference_result.wav_denoised)
    timepoint = f"{datetime.datetime.now():%Y/%m/%d %H:%M:%S}"

    val_entry = InferenceEntry(
      identifier=mel_entry.identifier,
      iteration=checkpoint.iteration,
      timepoint=timepoint,
      sampling_rate=inference_result.sampling_rate,
      inference_duration_s=inference_result.inference_duration_s,
      was_overamplified=inference_result.was_overamplified,
      denoising_duration_s=inference_result.denoising_duration_s,
      inferred_duration_s=get_duration_s(
        inference_result.wav_denoised, inference_result.sampling_rate),
      denoiser_strength=denoiser_strength,
      sigma=sigma,
    )

    mel_orig = mel_entry.mel
    # Re-extract a mel from the synthesized wav so both sides are comparable.
    wav_inferred_denoised_normalized_tensor = torch.FloatTensor(
      wav_inferred_denoised_normalized)
    mel_inferred_denoised = taco_stft.get_mel_tensor(
      wav_inferred_denoised_normalized_tensor)
    mel_inferred_denoised = mel_inferred_denoised.numpy()

    validation_entry_output = InferenceEntryOutput(
      identifier=mel_entry.identifier,
      mel_orig=mel_orig,
      inferred_sr=inference_result.sampling_rate,
      mel_inferred_denoised=mel_inferred_denoised,
      wav_inferred_denoised=wav_inferred_denoised_normalized,
      orig_sr=mel_entry.sr,
      wav_inferred=normalize_wav(inference_result.wav),
      mel_denoised_diff_img=None,
      mel_inferred_denoised_img=None,
      mel_orig_img=None,
    )

    mcd_dtw, penalty_dtw, final_frame_number_dtw = get_mcd_between_mel_spectograms(
      mel_1=mel_orig,
      mel_2=mel_inferred_denoised,
      n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
      take_log=False,
      use_dtw=True,
    )
    val_entry.diff_frames = mel_inferred_denoised.shape[1] - mel_orig.shape[1]
    val_entry.mcd_dtw = mcd_dtw
    val_entry.mcd_dtw_penalty = penalty_dtw
    val_entry.mcd_dtw_frames = final_frame_number_dtw

    mcd, penalty, final_frame_number = get_mcd_between_mel_spectograms(
      mel_1=mel_orig,
      mel_2=mel_inferred_denoised,
      n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
      take_log=False,
      use_dtw=False,
    )
    val_entry.mcd = mcd
    val_entry.mcd_penalty = penalty
    val_entry.mcd_frames = final_frame_number

    cosine_similarity = cosine_dist_mels(mel_orig, mel_inferred_denoised)
    val_entry.cosine_similarity = cosine_similarity

    mel_original_img_raw, mel_original_img = plot_melspec_np(mel_orig)
    mel_inferred_denoised_img_raw, mel_inferred_denoised_img = plot_melspec_np(
      mel_inferred_denoised)
    validation_entry_output.mel_orig_img = mel_original_img
    validation_entry_output.mel_inferred_denoised_img = mel_inferred_denoised_img

    # Pad images to equal width so the pixelwise structural comparison is valid.
    mel_original_img_raw_same_dim, mel_inferred_denoised_img_raw_same_dim = make_same_width_by_filling_white(
      img_a=mel_original_img_raw,
      img_b=mel_inferred_denoised_img_raw,
    )
    mel_original_img_same_dim, mel_inferred_denoised_img_same_dim = make_same_width_by_filling_white(
      img_a=mel_original_img,
      img_b=mel_inferred_denoised_img,
    )

    structural_similarity_raw, mel_difference_denoised_img_raw = calculate_structual_similarity_np(
      img_a=mel_original_img_raw_same_dim,
      img_b=mel_inferred_denoised_img_raw_same_dim,
    )
    val_entry.structural_similarity = structural_similarity_raw

    structural_similarity, mel_denoised_diff_img = calculate_structual_similarity_np(
      img_a=mel_original_img_same_dim,
      img_b=mel_inferred_denoised_img_same_dim,
    )
    validation_entry_output.mel_denoised_diff_img = mel_denoised_diff_img

    # NOTE(review): hard-coded debug dumps to /tmp; consider removing or
    # making configurable.
    imageio.imsave("/tmp/mel_original_img_raw.png", mel_original_img_raw)
    imageio.imsave("/tmp/mel_inferred_img_raw.png", mel_inferred_denoised_img_raw)
    imageio.imsave("/tmp/mel_difference_denoised_img_raw.png",
                   mel_difference_denoised_img_raw)

    logger.info(f"Current: {val_entry.identifier}")
    logger.info(f"MCD DTW: {val_entry.mcd_dtw}")
    logger.info(f"MCD DTW penalty: {val_entry.mcd_dtw_penalty}")
    logger.info(f"MCD DTW frames: {val_entry.mcd_dtw_frames}")
    logger.info(f"MCD: {val_entry.mcd}")
    logger.info(f"MCD penalty: {val_entry.mcd_penalty}")
    logger.info(f"MCD frames: {val_entry.mcd_frames}")
    logger.info(f"Structural Similarity: {val_entry.structural_similarity}")
    logger.info(f"Cosine Similarity: {val_entry.cosine_similarity}")

    save_callback(validation_entry_output)
    inference_entries.append(val_entry)

  return inference_entries, (complete_wav_denoised, synth.hparams.sampling_rate)
def process_entry(entry: WavData, wav_dir: Path, mel_parser: TacotronSTFT, save_callback: Callable[[WavData, Tensor], str]) -> MelData:
  """Compute the mel tensor for a single wav entry, save it, and describe it as MelData."""
  wav_path = wav_dir / entry.wav_relative_path
  tensor = mel_parser.get_mel_tensor_from_file(wav_path)
  saved_path = save_callback(wav_entry=entry, mel_tensor=tensor)
  return MelData(entry.entry_id, saved_path, mel_parser.n_mel_channels)
def validate(checkpoint: CheckpointTacotron, data: PreparedDataList, trainset: PreparedDataList, custom_hparams: Optional[Dict[str, str]], entry_ids: Optional[Set[int]], speaker_name: Optional[str], train_name: str, full_run: bool, save_callback: Optional[Callable[[PreparedData, ValidationEntryOutput], None]], max_decoder_steps: int, fast: bool, mcd_no_of_coeffs_per_frame: int, repetitions: int, seed: Optional[int], select_best_from: Optional[pd.DataFrame], logger: Logger) -> ValidationEntries:
  """Synthesize selected validation entries with Tacotron and score them.

  Entry selection: all entries (full_run), best-seed lookup (select_best_from),
  explicit entry_ids, a random entry of a speaker (speaker_name), or one
  random entry. Each entry is synthesized `repetitions` times; per repetition
  MCD (DTW and padded), cosine similarity, MSE and (unless `fast`) structural
  similarity of the plotted spectrograms are computed. Outputs are passed to
  save_callback when `fast` is False.

  Returns the collected ValidationEntries.
  """
  model_symbols = checkpoint.get_symbols()
  model_accents = checkpoint.get_accents()
  model_speakers = checkpoint.get_speakers()

  seeds: List[int]
  validation_data: PreparedDataList

  if full_run:
    validation_data = data
    assert seed is not None
    seeds = [seed for _ in data]
  elif select_best_from is not None:
    assert entry_ids is not None
    logger.info("Finding best seeds...")
    validation_data = PreparedDataList(
      [x for x in data.items() if x.entry_id in entry_ids])
    if len(validation_data) != len(entry_ids):
      logger.error("Not all entry_id's were found!")
      assert False
    entry_ids_order_from_valdata = OrderedSet(
      [x.entry_id for x in validation_data.items()])
    seeds = get_best_seeds(select_best_from, entry_ids_order_from_valdata,
                           checkpoint.iteration, logger)
  elif entry_ids is not None:
    validation_data = PreparedDataList(
      [x for x in data.items() if x.entry_id in entry_ids])
    if len(validation_data) != len(entry_ids):
      logger.error("Not all entry_id's were found!")
      assert False
    assert seed is not None
    seeds = [seed for _ in validation_data]
  elif speaker_name is not None:
    speaker_id = model_speakers.get_id(speaker_name)
    relevant_entries = [x for x in data.items() if x.speaker_id == speaker_id]
    assert len(relevant_entries) > 0
    assert seed is not None
    random.seed(seed)
    entry = random.choice(relevant_entries)
    validation_data = PreparedDataList([entry])
    seeds = [seed]
  else:
    assert seed is not None
    random.seed(seed)
    entry = random.choice(data)
    validation_data = PreparedDataList([entry])
    seeds = [seed]

  assert len(seeds) == len(validation_data)

  if len(validation_data) == 0:
    logger.info("Nothing to synthesize!")
    # BUGFIX: return the declared ValidationEntries type; the original
    # returned the (empty) PreparedDataList here.
    return ValidationEntries()

  train_onegram_rarities = get_ngram_rarity(validation_data, trainset, model_symbols, 1)
  train_twogram_rarities = get_ngram_rarity(validation_data, trainset, model_symbols, 2)
  train_threegram_rarities = get_ngram_rarity(validation_data, trainset, model_symbols, 3)

  synth = Synthesizer(
    checkpoint=checkpoint,
    custom_hparams=custom_hparams,
    logger=logger,
  )
  taco_stft = TacotronSTFT(synth.hparams, logger=logger)
  validation_entries = ValidationEntries()

  for repetition in range(repetitions):
    rep_human_readable = repetition + 1
    logger.info(f"Starting repetition: {rep_human_readable}/{repetitions}")

    for entry, entry_seed in zip(validation_data.items(True), seeds):
      # Vary the seed per repetition so repetitions are not identical.
      rep_seed = entry_seed + repetition
      logger.info(
        f"Current --> entry_id: {entry.entry_id}; seed: {rep_seed}; iteration: {checkpoint.iteration}; rep: {rep_human_readable}/{repetitions}"
      )

      infer_sent = InferSentence(
        sent_id=1,
        symbols=model_symbols.get_symbols(entry.serialized_symbol_ids),
        accents=model_accents.get_accents(entry.serialized_accent_ids),
        original_text=entry.text_original,
      )

      speaker_name = model_speakers.get_speaker(entry.speaker_id)
      inference_result = synth.infer(
        sentence=infer_sent,
        speaker=speaker_name,
        ignore_unknown_symbols=False,
        max_decoder_steps=max_decoder_steps,
        seed=rep_seed,
      )

      symbol_count = len(deserialize_list(entry.serialized_symbol_ids))
      unique_symbols = set(model_symbols.get_symbols(entry.serialized_symbol_ids))
      unique_symbols_str = " ".join(list(sorted(unique_symbols)))
      unique_symbols_count = len(unique_symbols)
      timepoint = f"{datetime.datetime.now():%Y/%m/%d %H:%M:%S}"

      mel_orig: np.ndarray = taco_stft.get_mel_tensor_from_file(
        entry.wav_path).cpu().numpy()

      target_frames = mel_orig.shape[1]
      predicted_frames = inference_result.mel_outputs_postnet.shape[1]
      diff_frames = predicted_frames - target_frames
      frame_deviation_percent = (predicted_frames / target_frames) - 1

      dtw_mcd, dtw_penalty, dtw_frames = get_mcd_between_mel_spectograms(
        mel_1=mel_orig,
        mel_2=inference_result.mel_outputs_postnet,
        n_mfcc=mcd_no_of_coeffs_per_frame,
        take_log=False,
        use_dtw=True,
      )

      padded_mel_orig, padded_mel_postnet = make_same_dim(
        mel_orig, inference_result.mel_outputs_postnet)

      aligned_mel_orig, aligned_mel_postnet, mel_dtw_dist, _, path_mel_postnet = align_mels_with_dtw(
        mel_orig, inference_result.mel_outputs_postnet)

      mel_aligned_length = len(aligned_mel_postnet)
      msd = get_msd(mel_dtw_dist, mel_aligned_length)

      padded_cosine_similarity = cosine_dist_mels(padded_mel_orig, padded_mel_postnet)
      aligned_cosine_similarity = cosine_dist_mels(aligned_mel_orig, aligned_mel_postnet)

      padded_mse = mean_squared_error(padded_mel_orig, padded_mel_postnet)
      aligned_mse = mean_squared_error(aligned_mel_orig, aligned_mel_postnet)

      padded_structural_similarity = None
      aligned_structural_similarity = None

      if not fast:
        # Structural similarity is computed on the plotted spectrogram images.
        padded_mel_orig_img_raw_1, padded_mel_orig_img = plot_melspec_np(
          padded_mel_orig, title="padded_mel_orig")
        padded_mel_outputs_postnet_img_raw_1, padded_mel_outputs_postnet_img = plot_melspec_np(
          padded_mel_postnet, title="padded_mel_postnet")

        padded_structural_similarity, padded_mel_postnet_diff_img_raw = calculate_structual_similarity_np(
          img_a=padded_mel_orig_img_raw_1,
          img_b=padded_mel_outputs_postnet_img_raw_1,
        )

        aligned_mel_orig_img_raw, aligned_mel_orig_img = plot_melspec_np(
          aligned_mel_orig, title="aligned_mel_orig")
        aligned_mel_postnet_img_raw, aligned_mel_postnet_img = plot_melspec_np(
          aligned_mel_postnet, title="aligned_mel_postnet")

        aligned_structural_similarity, aligned_mel_diff_img_raw = calculate_structual_similarity_np(
          img_a=aligned_mel_orig_img_raw,
          img_b=aligned_mel_postnet_img_raw,
        )

      train_combined_rarity = train_onegram_rarities[entry.entry_id] + \
          train_twogram_rarities[entry.entry_id] + train_threegram_rarities[entry.entry_id]

      val_entry = ValidationEntry(
        entry_id=entry.entry_id,
        repetition=rep_human_readable,
        repetitions=repetitions,
        seed=rep_seed,
        ds_entry_id=entry.ds_entry_id,
        text_original=entry.text_original,
        text=entry.text,
        wav_path=entry.wav_path,
        wav_duration_s=entry.duration_s,
        speaker_id=entry.speaker_id,
        speaker_name=speaker_name,
        iteration=checkpoint.iteration,
        unique_symbols=unique_symbols_str,
        unique_symbols_count=unique_symbols_count,
        symbol_count=symbol_count,
        timepoint=timepoint,
        train_name=train_name,
        sampling_rate=synth.get_sampling_rate(),
        reached_max_decoder_steps=inference_result.reached_max_decoder_steps,
        inference_duration_s=inference_result.inference_duration_s,
        predicted_frames=predicted_frames,
        target_frames=target_frames,
        diff_frames=diff_frames,
        frame_deviation_percent=frame_deviation_percent,
        padded_cosine_similarity=padded_cosine_similarity,
        mfcc_no_coeffs=mcd_no_of_coeffs_per_frame,
        mfcc_dtw_mcd=dtw_mcd,
        mfcc_dtw_penalty=dtw_penalty,
        mfcc_dtw_frames=dtw_frames,
        padded_structural_similarity=padded_structural_similarity,
        padded_mse=padded_mse,
        msd=msd,
        aligned_cosine_similarity=aligned_cosine_similarity,
        aligned_mse=aligned_mse,
        aligned_structural_similarity=aligned_structural_similarity,
        global_one_gram_rarity=entry.one_gram_rarity,
        global_two_gram_rarity=entry.two_gram_rarity,
        global_three_gram_rarity=entry.three_gram_rarity,
        global_combined_rarity=entry.combined_rarity,
        train_one_gram_rarity=train_onegram_rarities[entry.entry_id],
        train_two_gram_rarity=train_twogram_rarities[entry.entry_id],
        train_three_gram_rarity=train_threegram_rarities[entry.entry_id],
        train_combined_rarity=train_combined_rarity,
      )

      validation_entries.append(val_entry)

      if not fast:
        orig_sr, orig_wav = read(entry.wav_path)

        _, padded_mel_postnet_diff_img = calculate_structual_similarity_np(
          img_a=padded_mel_orig_img,
          img_b=padded_mel_outputs_postnet_img,
        )
        _, aligned_mel_postnet_diff_img = calculate_structual_similarity_np(
          img_a=aligned_mel_orig_img,
          img_b=aligned_mel_postnet_img,
        )

        _, mel_orig_img = plot_melspec_np(mel_orig, title="mel_orig")
        _, post_mel_img = plot_melspec_np(
          inference_result.mel_outputs_postnet, title="mel_postnet")
        _, mel_img = plot_melspec_np(inference_result.mel_outputs, title="mel")
        _, alignments_img = plot_alignment_np_new(
          inference_result.alignments, title="alignments")

        # Re-index the alignment matrix with the DTW path of the postnet mel.
        aligned_alignments = inference_result.alignments[path_mel_postnet]
        _, aligned_alignments_img = plot_alignment_np_new(
          aligned_alignments, title="aligned_alignments")

        validation_entry_output = ValidationEntryOutput(
          repetition=rep_human_readable,
          wav_orig=orig_wav,
          mel_orig=mel_orig,
          orig_sr=orig_sr,
          mel_postnet=inference_result.mel_outputs_postnet,
          mel_postnet_sr=inference_result.sampling_rate,
          mel_orig_img=mel_orig_img,
          mel_postnet_img=post_mel_img,
          mel_postnet_diff_img=padded_mel_postnet_diff_img,
          alignments_img=alignments_img,
          mel_img=mel_img,
          mel_postnet_aligned_diff_img=aligned_mel_postnet_diff_img,
          mel_orig_aligned=aligned_mel_orig,
          mel_orig_aligned_img=aligned_mel_orig_img,
          mel_postnet_aligned=aligned_mel_postnet,
          mel_postnet_aligned_img=aligned_mel_postnet_img,
          alignments_aligned_img=aligned_alignments_img,
        )

        assert save_callback is not None
        save_callback(entry, validation_entry_output)

      logger.info(f"MFCC MCD DTW: {val_entry.mfcc_dtw_mcd}")

  return validation_entries
def validate(checkpoint: CheckpointWaveglow, data: PreparedDataList, custom_hparams: Optional[Dict[str, str]], denoiser_strength: float, sigma: float, entry_ids: Optional[Set[int]], train_name: str, full_run: bool, save_callback: Callable[[PreparedData, ValidationEntryOutput], None], seed: int, logger: Logger):
  """Re-synthesize validation wavs with WaveGlow and score them against the originals.

  For each selected entry: extract the ground-truth mel, vocode it, re-extract
  a mel from the denoised output, then compute MCD (with and without DTW),
  cosine similarity and structural similarity of the plotted spectrograms.
  Each result is passed to save_callback and collected into the returned
  ValidationEntries.
  """
  validation_entries = ValidationEntries()

  # Entry selection: everything, or the subset matching entry_ids.
  if full_run:
    entries = data
  else:
    speaker_id: Optional[int] = None
    entries = PreparedDataList(data.get_for_validation(entry_ids, speaker_id))

  if len(entries) == 0:
    logger.info("Nothing to synthesize!")
    return validation_entries

  synth = Synthesizer(
    checkpoint=checkpoint,
    custom_hparams=custom_hparams,
    logger=logger
  )

  taco_stft = TacotronSTFT(synth.hparams, logger=logger)

  for entry in entries.items(True):
    # Ground-truth mel, prepared as a batched CUDA variable for the vocoder.
    mel = taco_stft.get_mel_tensor_from_file(entry.wav_path)
    mel_var = torch.autograd.Variable(mel)
    mel_var = mel_var.cuda()
    mel_var = mel_var.unsqueeze(0)

    inference_result = synth.infer(mel_var, sigma, denoiser_strength, seed=seed)
    wav_inferred_denoised_normalized = normalize_wav(inference_result.wav_denoised)

    symbol_count = len(deserialize_list(entry.serialized_symbol_ids))
    unique_symbols_count = len(set(deserialize_list(entry.serialized_symbol_ids)))
    timepoint = f"{datetime.datetime.now():%Y/%m/%d %H:%M:%S}"

    val_entry = ValidationEntry(
      entry_id=entry.entry_id,
      ds_entry_id=entry.ds_entry_id,
      text_original=entry.text_original,
      text=entry.text,
      seed=seed,
      wav_path=entry.wav_path,
      original_duration_s=entry.duration_s,
      speaker_id=entry.speaker_id,
      iteration=checkpoint.iteration,
      unique_symbols_count=unique_symbols_count,
      symbol_count=symbol_count,
      timepoint=timepoint,
      train_name=train_name,
      sampling_rate=inference_result.sampling_rate,
      inference_duration_s=inference_result.inference_duration_s,
      was_overamplified=inference_result.was_overamplified,
      denoising_duration_s=inference_result.denoising_duration_s,
      inferred_duration_s=get_duration_s(
        inference_result.wav_denoised, inference_result.sampling_rate),
      denoiser_strength=denoiser_strength,
      sigma=sigma,
    )

    val_entry.diff_duration_s = val_entry.inferred_duration_s - val_entry.original_duration_s

    mel_orig = mel.cpu().numpy()
    # Re-extract a mel from the synthesized wav so both sides are comparable.
    # NOTE(review): despite its name this tensor holds the wav samples, not a mel.
    mel_inferred_denoised_tensor = torch.FloatTensor(wav_inferred_denoised_normalized)
    mel_inferred_denoised = taco_stft.get_mel_tensor(mel_inferred_denoised_tensor)
    mel_inferred_denoised = mel_inferred_denoised.numpy()

    wav_orig, orig_sr = wav_to_float32(entry.wav_path)

    validation_entry_output = ValidationEntryOutput(
      mel_orig=mel_orig,
      inferred_sr=inference_result.sampling_rate,
      mel_inferred_denoised=mel_inferred_denoised,
      wav_inferred_denoised=wav_inferred_denoised_normalized,
      wav_orig=wav_orig,
      orig_sr=orig_sr,
      wav_inferred=normalize_wav(inference_result.wav),
      mel_denoised_diff_img=None,
      mel_inferred_denoised_img=None,
      mel_orig_img=None,
    )

    mcd_dtw, penalty_dtw, final_frame_number_dtw = get_mcd_between_mel_spectograms(
      mel_1=mel_orig,
      mel_2=mel_inferred_denoised,
      n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
      take_log=False,
      use_dtw=True,
    )

    val_entry.diff_frames = mel_inferred_denoised.shape[1] - mel_orig.shape[1]
    val_entry.mcd_dtw = mcd_dtw
    val_entry.mcd_dtw_penalty = penalty_dtw
    val_entry.mcd_dtw_frames = final_frame_number_dtw

    mcd, penalty, final_frame_number = get_mcd_between_mel_spectograms(
      mel_1=mel_orig,
      mel_2=mel_inferred_denoised,
      n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
      take_log=False,
      use_dtw=False,
    )

    val_entry.mcd = mcd
    val_entry.mcd_penalty = penalty
    val_entry.mcd_frames = final_frame_number

    cosine_similarity = cosine_dist_mels(mel_orig, mel_inferred_denoised)
    val_entry.cosine_similarity = cosine_similarity

    mel_original_img_raw, mel_original_img = plot_melspec_np(mel_orig)
    mel_inferred_denoised_img_raw, mel_inferred_denoised_img = plot_melspec_np(
      mel_inferred_denoised)

    validation_entry_output.mel_orig_img = mel_original_img
    validation_entry_output.mel_inferred_denoised_img = mel_inferred_denoised_img

    # Pad images to equal width so the pixelwise structural comparison is valid.
    mel_original_img_raw_same_dim, mel_inferred_denoised_img_raw_same_dim = make_same_width_by_filling_white(
      img_a=mel_original_img_raw,
      img_b=mel_inferred_denoised_img_raw,
    )

    mel_original_img_same_dim, mel_inferred_denoised_img_same_dim = make_same_width_by_filling_white(
      img_a=mel_original_img,
      img_b=mel_inferred_denoised_img,
    )

    structural_similarity_raw, mel_difference_denoised_img_raw = calculate_structual_similarity_np(
      img_a=mel_original_img_raw_same_dim,
      img_b=mel_inferred_denoised_img_raw_same_dim,
    )
    val_entry.structural_similarity = structural_similarity_raw

    structural_similarity, mel_denoised_diff_img = calculate_structual_similarity_np(
      img_a=mel_original_img_same_dim,
      img_b=mel_inferred_denoised_img_same_dim,
    )
    validation_entry_output.mel_denoised_diff_img = mel_denoised_diff_img

    # NOTE(review): hard-coded debug dumps to /tmp; consider removing or
    # making configurable.
    imageio.imsave("/tmp/mel_original_img_raw.png", mel_original_img_raw)
    imageio.imsave("/tmp/mel_inferred_img_raw.png", mel_inferred_denoised_img_raw)
    imageio.imsave("/tmp/mel_difference_denoised_img_raw.png",
                   mel_difference_denoised_img_raw)

    # logger.info(val_entry)
    logger.info(f"Current: {val_entry.entry_id}")
    logger.info(f"MCD DTW: {val_entry.mcd_dtw}")
    logger.info(f"MCD DTW penalty: {val_entry.mcd_dtw_penalty}")
    logger.info(f"MCD DTW frames: {val_entry.mcd_dtw_frames}")
    logger.info(f"MCD: {val_entry.mcd}")
    logger.info(f"MCD penalty: {val_entry.mcd_penalty}")
    logger.info(f"MCD frames: {val_entry.mcd_frames}")
    # logger.info(f"MCD DTW V2: {val_entry.mcd_dtw_v2}")
    logger.info(f"Structural Similarity: {val_entry.structural_similarity}")
    logger.info(f"Cosine Similarity: {val_entry.cosine_similarity}")

    save_callback(entry, validation_entry_output)
    validation_entries.append(val_entry)

    #score, diff_img = compare_mels(a, b)

  return validation_entries
class SymbolsMelLoader(Dataset):
  """
  1) loads audio,text pairs
  2) normalizes text and converts them to sequences of one-hot vectors
  3) computes mel-spectrograms from audio files.
  """

  def __init__(self, data: PreparedDataList, hparams: HParams, logger: Logger):
    """Index the prepared entries and optionally preload saved mels into memory.

    For each entry stores (symbols_tensor, accents_tensor, path, speaker_id),
    where `path` is the mel path when hparams.use_saved_mels is set and the
    wav path otherwise.
    """
    self.use_saved_mels = hparams.use_saved_mels
    if not hparams.use_saved_mels:
      # Mels will be computed on the fly from wav files.
      self.mel_parser = TacotronSTFT(hparams, logger)

    logger.info("Reading files...")
    self.data: Dict[int, Tuple[IntTensor, IntTensor, str, int]] = {}
    for i, values in enumerate(data.items(True)):
      symbol_ids = deserialize_list(values.serialized_symbol_ids)
      accent_ids = deserialize_list(values.serialized_accent_ids)
      model_symbol_ids = get_accent_symbol_ids(
        symbol_ids, accent_ids, hparams.n_symbols,
        hparams.accents_use_own_symbols, SHARED_SYMBOLS_COUNT)
      symbols_tensor = IntTensor(model_symbol_ids)
      accents_tensor = IntTensor(accent_ids)
      path = values.mel_path if hparams.use_saved_mels else values.wav_path
      self.data[i] = (symbols_tensor, accents_tensor, path, values.speaker_id)

    if hparams.use_saved_mels and hparams.cache_mels:
      logger.info("Loading mels into memory...")
      self.cache: Dict[int, Tensor] = {}
      vals: tuple
      for i, vals in tqdm(self.data.items()):
        # BUGFIX: the path is element 2 of the tuple; element 1 is the accents
        # tensor, which torch.load cannot open as a file.
        mel_tensor = torch.load(vals[2], map_location='cpu')
        self.cache[i] = mel_tensor
    self.use_cache: bool = hparams.cache_mels

  def __getitem__(self, index: int) -> Tuple[IntTensor, IntTensor, Tensor, int]:
    """Return cloned (symbols, accents, mel, speaker_id) for one entry."""
    symbols_tensor, accents_tensor, path, speaker_id = self.data[index]

    if self.use_saved_mels:
      if self.use_cache:
        mel_tensor = self.cache[index].clone().detach()
      else:
        mel_tensor: Tensor = torch.load(path, map_location='cpu')
    else:
      mel_tensor = self.mel_parser.get_mel_tensor_from_file(path)

    # Clone so DataLoader workers cannot mutate the stored tensors.
    symbols_tensor_cloned = symbols_tensor.clone().detach()
    accents_tensor_cloned = accents_tensor.clone().detach()
    return symbols_tensor_cloned, accents_tensor_cloned, mel_tensor, speaker_id

  def __len__(self):
    return len(self.data)