def sents_convert_to_ipa(sentences: SentenceList, text_symbols: SymbolIdDict, ignore_tones: bool, ignore_arcs: bool, mode: Optional[EngToIpaMode], consider_ipa_annotations: bool, logger: Logger) -> Tuple[SymbolIdDict, SentenceList]:
  """Convert every sentence to IPA in-place.

  Parameters follow `symbols_to_ipa`; `mode` is mandatory as soon as an
  English sentence occurs.

  Returns the updated symbol dict together with the sentence list
  (via `update_symbols_and_text`).

  Raises:
    ValueError: if an English sentence is encountered while `mode` is None.
  """
  sents_new_symbols = []
  for sentence in sentences.items(True):
    if sentence.lang == Language.ENG and mode is None:
      # FIX: logger.exception() is only meaningful inside an `except` block;
      # use logger.error() and raise a specific exception type instead of
      # a bare Exception (callers catching Exception still work).
      msg = "Please specify the ipa conversion mode."
      logger.error(msg)
      raise ValueError(msg)
    new_symbols, new_accent_ids = symbols_to_ipa(
      symbols=text_symbols.get_symbols(sentence.serialized_symbols),
      lang=sentence.lang,
      accent_ids=deserialize_list(sentence.serialized_accents),
      ignore_arcs=ignore_arcs,
      ignore_tones=ignore_tones,
      mode=mode,
      replace_unknown_with=DEFAULT_PADDING_SYMBOL,
      consider_ipa_annotations=consider_ipa_annotations,
      logger=logger,
    )
    # Accents must stay parallel to the converted symbols.
    assert len(new_symbols) == len(new_accent_ids)
    sentence.lang = Language.IPA
    sentence.serialized_accents = serialize_list(new_accent_ids)
    sents_new_symbols.append(new_symbols)
    assert len(sentence.get_accent_ids()) == len(new_symbols)
  return update_symbols_and_text(sentences, sents_new_symbols)
def _remove_unused_accents(self) -> None:
  """Drop every accent id from ``self.accent_ids`` that no data entry references."""
  referenced_ids: Set[int] = set()
  for entry in self.data.items():
    referenced_ids.update(deserialize_list(entry.serialized_accent_ids))
  stale_ids = self.accent_ids.get_all_ids().difference(referenced_ids)
  self.accent_ids.remove_ids(stale_ids)
def sents_map(sentences: SentenceList, text_symbols: SymbolIdDict, symbols_map: SymbolsMap, ignore_arcs: bool, logger: Logger) -> Tuple[SymbolIdDict, SentenceList]:
  """Apply a symbol map to every sentence, re-split the mapped text into
  sentences, and return the updated symbol dict plus the new sentence list."""
  collected_symbols = []
  mapped_result = SentenceList()
  sent_counter = 0
  extraction_settings = IPAExtractionSettings(
    ignore_tones=False,
    ignore_arcs=ignore_arcs,
    replace_unknown_ipa_by=DEFAULT_PADDING_SYMBOL,
  )
  for sentence in sentences.items():
    source_symbols = text_symbols.get_symbols(sentence.serialized_symbols)
    source_accent_ids = deserialize_list(sentence.serialized_accents)
    translated_symbols = symbols_map.apply_to_symbols(source_symbols)
    translated_text = SymbolIdDict.symbols_to_text(translated_symbols)
    # a resulting empty text would make no problems
    split_texts = text_to_sentences(
      text=translated_text,
      lang=sentence.lang,
      logger=logger,
    )
    for part_text in split_texts:
      part_symbols = text_to_symbols(
        part_text,
        lang=sentence.lang,
        ipa_settings=extraction_settings,
        logger=logger,
      )
      # Reuse the first accent of the source sentence for all new symbols.
      part_accent_ids = [source_accent_ids[0]] * len(part_symbols) if source_accent_ids else []
      assert len(part_accent_ids) == len(part_symbols)
      sent_counter += 1
      fresh_sentence = Sentence(
        sent_id=sent_counter,
        text=part_text,
        lang=sentence.lang,
        orig_lang=sentence.orig_lang,  # this is not correct but nearest possible currently
        original_text=sentence.original_text,
        serialized_accents=serialize_list(part_accent_ids),
        serialized_symbols=""
      )
      collected_symbols.append(part_symbols)
      assert len(fresh_sentence.get_accent_ids()) == len(part_symbols)
      mapped_result.append(fresh_sentence)
  return update_symbols_and_text(mapped_result, collected_symbols)
def sents_accent_apply(sentences: SentenceList, accented_symbols: AccentedSymbolList, accent_ids: AccentsDict) -> SentenceList:
  """Write the accents from ``accented_symbols`` back onto the sentences.

  The accented symbols are consumed in order: each sentence takes exactly as
  many entries as it has serialized accent ids.
  """
  offset = 0
  for sent in sentences.items():
    count = len(deserialize_list(sent.serialized_accents))
    # There must be enough accented symbols left for this sentence.
    assert len(accented_symbols) >= offset + count
    window: List[AccentedSymbol] = accented_symbols[offset:offset + count]
    offset += count
    window_accents = [accented.accent for accented in window]
    sent.serialized_accents = serialize_list(accent_ids.get_ids(window_accents))
    assert len(sent.get_accent_ids()) == len(sent.get_symbol_ids())
  return sentences
def sents_normalize(sentences: SentenceList, text_symbols: SymbolIdDict, logger: Logger) -> Tuple[SymbolIdDict, SentenceList]:
  """Normalize the symbols of every sentence in-place and return the
  updated symbol dict with the sentence list."""
  # Maybe add info if something was unknown
  normalized_symbols = []
  for sentence in sentences.items():
    current_symbols = text_symbols.get_symbols(sentence.serialized_symbols)
    new_symbols, new_accent_ids = symbols_normalize(
      symbols=current_symbols,
      lang=sentence.lang,
      accent_ids=deserialize_list(sentence.serialized_accents),
      logger=logger,
    )
    # TODO: check if new sentences resulted and then split them.
    sentence.serialized_accents = serialize_list(new_accent_ids)
    normalized_symbols.append(new_symbols)
  return update_symbols_and_text(sentences, normalized_symbols)
def __init__(self, data: PreparedDataList, hparams: HParams, logger: Logger):
  """Build the dataset: tensorize symbol/accent ids per entry and optionally
  preload all mel spectrograms into memory.

  Each entry of ``self.data`` is a tuple
  ``(symbols_tensor, accents_tensor, path, speaker_id)`` where ``path`` is the
  mel path when ``hparams.use_saved_mels`` is set, else the wav path.
  """
  # random.seed(hparams.seed)
  # random.shuffle(data)
  self.use_saved_mels = hparams.use_saved_mels
  if not hparams.use_saved_mels:
    self.mel_parser = TacotronSTFT(hparams, logger)

  logger.info("Reading files...")
  self.data: Dict[int, Tuple[IntTensor, IntTensor, str, int]] = {}
  for i, values in enumerate(data.items(True)):
    symbol_ids = deserialize_list(values.serialized_symbol_ids)
    accent_ids = deserialize_list(values.serialized_accent_ids)
    model_symbol_ids = get_accent_symbol_ids(
      symbol_ids, accent_ids, hparams.n_symbols, hparams.accents_use_own_symbols,
      SHARED_SYMBOLS_COUNT)
    symbols_tensor = IntTensor(model_symbol_ids)
    accents_tensor = IntTensor(accent_ids)
    if hparams.use_saved_mels:
      self.data[i] = (symbols_tensor, accents_tensor, values.mel_path, values.speaker_id)
    else:
      self.data[i] = (symbols_tensor, accents_tensor, values.wav_path, values.speaker_id)

  if hparams.use_saved_mels and hparams.cache_mels:
    logger.info("Loading mels into memory...")
    self.cache: Dict[int, Tensor] = {}
    vals: tuple
    for i, vals in tqdm(self.data.items()):
      # BUG FIX: the mel path is the third tuple element (index 2); index 1
      # is the accents tensor, which torch.load cannot open as a file.
      mel_tensor = torch.load(vals[2], map_location='cpu')
      self.cache[i] = mel_tensor
  self.use_cache: bool = hparams.cache_mels
def filter_symbols(data: MergedDataset, symbols: SymbolIdDict, accent_ids: AccentsDict, speakers: SpeakersDict, allowed_symbol_ids: Set[int], logger: Logger) -> MergedDatasetContainer:
  """Keep only utterances that consist solely of allowed symbols and rebuild
  the symbol, accent and speaker id mappings for the filtered data."""
  # maybe check all symbol ids are valid before
  allowed_symbols = [symbols.get_symbol(sid) for sid in allowed_symbol_ids]
  not_allowed_symbols = [
    symbols.get_symbol(sid)
    for sid in symbols.get_all_symbol_ids()
    if sid not in allowed_symbol_ids
  ]
  logger.info(
    f"Keep utterances with these symbols: {' '.join(allowed_symbols)}")
  logger.info(
    f"Remove utterances with these symbols: {' '.join(not_allowed_symbols)}")

  logger.info("Statistics before filtering:")
  log_stats(data, symbols, accent_ids, speakers, logger)

  kept_entries = [
    entry for entry in data.items()
    if contains_only_allowed_symbols(
      deserialize_list(entry.serialized_symbol_ids), allowed_symbol_ids)
  ]
  result = MergedDataset(kept_entries)

  if len(result) > 0:
    logger.info(
      f"Removed {len(data) - len(result)} from {len(data)} total entries and got {len(result)} entries ({len(result)/len(data)*100:.2f}%).")
  else:
    logger.info("Removed all utterances!")

  new_symbol_ids = update_symbols(result, symbols)
  new_accent_ids = update_accents(result, accent_ids)
  new_speaker_ids = update_speakers(result, speakers)

  logger.info("Statistics after filtering:")
  log_stats(result, new_symbol_ids, new_accent_ids, new_speaker_ids, logger)

  return MergedDatasetContainer(
    name=None,
    data=result,
    accent_ids=new_accent_ids,
    speaker_ids=new_speaker_ids,
    symbol_ids=new_symbol_ids,
  )
def get_accent_ids(self):
  """Deserialize and return this sentence's accent ids as a list."""
  ids = deserialize_list(self.serialized_accents)
  return ids
def get_symbol_ids(self):
  """Deserialize and return this sentence's symbol ids as a list."""
  ids = deserialize_list(self.serialized_symbols)
  return ids
def validate(checkpoint: CheckpointTacotron, data: PreparedDataList, trainset: PreparedDataList, custom_hparams: Optional[Dict[str, str]], entry_ids: Optional[Set[int]], speaker_name: Optional[str], train_name: str, full_run: bool, save_callback: Optional[Callable[[PreparedData, ValidationEntryOutput], None]], max_decoder_steps: int, fast: bool, mcd_no_of_coeffs_per_frame: int, repetitions: int, seed: Optional[int], select_best_from: Optional[pd.DataFrame], logger: Logger) -> ValidationEntries:
  """Synthesize validation entries with Tacotron and compute mel/audio metrics.

  Entry selection (exactly one mode applies, in this order):
    * ``full_run``: every entry of ``data`` with the fixed ``seed``.
    * ``select_best_from``: the given ``entry_ids``, each with its best seed
      looked up via ``get_best_seeds``.
    * ``entry_ids``: the given entries with the fixed ``seed``.
    * ``speaker_name``: one random entry of that speaker.
    * otherwise: one random entry of ``data``.

  Each selected entry is synthesized ``repetitions`` times; the per-run seed
  is ``entry_seed + repetition``. When ``fast`` is False, plots and aligned
  outputs are additionally produced and handed to ``save_callback``.

  Returns:
    ValidationEntries with one entry per (utterance, repetition).
  """
  model_symbols = checkpoint.get_symbols()
  model_accents = checkpoint.get_accents()
  model_speakers = checkpoint.get_speakers()
  seeds: List[int]
  validation_data: PreparedDataList

  if full_run:
    validation_data = data
    assert seed is not None
    seeds = [seed for _ in data]
  elif select_best_from is not None:
    assert entry_ids is not None
    logger.info("Finding best seeds...")
    validation_data = PreparedDataList(
      [x for x in data.items() if x.entry_id in entry_ids])
    if len(validation_data) != len(entry_ids):
      logger.error("Not all entry_id's were found!")
      assert False
    # Seeds must be ordered like the validation data.
    entry_ids_order_from_valdata = OrderedSet(
      [x.entry_id for x in validation_data.items()])
    seeds = get_best_seeds(select_best_from, entry_ids_order_from_valdata,
                           checkpoint.iteration, logger)
  elif entry_ids is not None:
    validation_data = PreparedDataList(
      [x for x in data.items() if x.entry_id in entry_ids])
    if len(validation_data) != len(entry_ids):
      logger.error("Not all entry_id's were found!")
      assert False
    assert seed is not None
    seeds = [seed for _ in validation_data]
  elif speaker_name is not None:
    speaker_id = model_speakers.get_id(speaker_name)
    relevant_entries = [x for x in data.items() if x.speaker_id == speaker_id]
    assert len(relevant_entries) > 0
    assert seed is not None
    random.seed(seed)
    entry = random.choice(relevant_entries)
    validation_data = PreparedDataList([entry])
    seeds = [seed]
  else:
    assert seed is not None
    random.seed(seed)
    entry = random.choice(data)
    validation_data = PreparedDataList([entry])
    seeds = [seed]

  assert len(seeds) == len(validation_data)

  if len(validation_data) == 0:
    logger.info("Nothing to synthesize!")
    # BUG FIX: previously returned ``validation_data`` (a PreparedDataList),
    # contradicting the declared return type; return empty ValidationEntries
    # as the Waveglow counterpart does.
    return ValidationEntries()

  train_onegram_rarities = get_ngram_rarity(validation_data, trainset, model_symbols, 1)
  train_twogram_rarities = get_ngram_rarity(validation_data, trainset, model_symbols, 2)
  train_threegram_rarities = get_ngram_rarity(validation_data, trainset, model_symbols, 3)

  synth = Synthesizer(
    checkpoint=checkpoint,
    custom_hparams=custom_hparams,
    logger=logger,
  )
  taco_stft = TacotronSTFT(synth.hparams, logger=logger)
  validation_entries = ValidationEntries()

  for repetition in range(repetitions):
    rep_human_readable = repetition + 1
    logger.info(f"Starting repetition: {rep_human_readable}/{repetitions}")

    for entry, entry_seed in zip(validation_data.items(True), seeds):
      rep_seed = entry_seed + repetition
      logger.info(
        f"Current --> entry_id: {entry.entry_id}; seed: {rep_seed}; iteration: {checkpoint.iteration}; rep: {rep_human_readable}/{repetitions}")

      infer_sent = InferSentence(
        sent_id=1,
        symbols=model_symbols.get_symbols(entry.serialized_symbol_ids),
        accents=model_accents.get_accents(entry.serialized_accent_ids),
        original_text=entry.text_original,
      )

      # FIX: use a distinct local name instead of shadowing the
      # ``speaker_name`` parameter.
      entry_speaker_name = model_speakers.get_speaker(entry.speaker_id)
      inference_result = synth.infer(
        sentence=infer_sent,
        speaker=entry_speaker_name,
        ignore_unknown_symbols=False,
        max_decoder_steps=max_decoder_steps,
        seed=rep_seed,
      )

      symbol_count = len(deserialize_list(entry.serialized_symbol_ids))
      unique_symbols = set(model_symbols.get_symbols(entry.serialized_symbol_ids))
      unique_symbols_str = " ".join(list(sorted(unique_symbols)))
      unique_symbols_count = len(unique_symbols)
      timepoint = f"{datetime.datetime.now():%Y/%m/%d %H:%M:%S}"

      mel_orig: np.ndarray = taco_stft.get_mel_tensor_from_file(
        entry.wav_path).cpu().numpy()
      target_frames = mel_orig.shape[1]
      predicted_frames = inference_result.mel_outputs_postnet.shape[1]
      diff_frames = predicted_frames - target_frames
      frame_deviation_percent = (predicted_frames / target_frames) - 1

      # MCD on DTW-aligned MFCCs between original and predicted mels.
      dtw_mcd, dtw_penalty, dtw_frames = get_mcd_between_mel_spectograms(
        mel_1=mel_orig,
        mel_2=inference_result.mel_outputs_postnet,
        n_mfcc=mcd_no_of_coeffs_per_frame,
        take_log=False,
        use_dtw=True,
      )

      padded_mel_orig, padded_mel_postnet = make_same_dim(
        mel_orig, inference_result.mel_outputs_postnet)
      aligned_mel_orig, aligned_mel_postnet, mel_dtw_dist, _, path_mel_postnet = align_mels_with_dtw(
        mel_orig, inference_result.mel_outputs_postnet)

      mel_aligned_length = len(aligned_mel_postnet)
      msd = get_msd(mel_dtw_dist, mel_aligned_length)

      padded_cosine_similarity = cosine_dist_mels(padded_mel_orig, padded_mel_postnet)
      aligned_cosine_similarity = cosine_dist_mels(aligned_mel_orig, aligned_mel_postnet)
      padded_mse = mean_squared_error(padded_mel_orig, padded_mel_postnet)
      aligned_mse = mean_squared_error(aligned_mel_orig, aligned_mel_postnet)

      padded_structural_similarity = None
      aligned_structural_similarity = None

      if not fast:
        # Structural similarity needs rendered spectrogram images.
        padded_mel_orig_img_raw_1, padded_mel_orig_img = plot_melspec_np(
          padded_mel_orig, title="padded_mel_orig")
        padded_mel_outputs_postnet_img_raw_1, padded_mel_outputs_postnet_img = plot_melspec_np(
          padded_mel_postnet, title="padded_mel_postnet")
        padded_structural_similarity, padded_mel_postnet_diff_img_raw = calculate_structual_similarity_np(
          img_a=padded_mel_orig_img_raw_1,
          img_b=padded_mel_outputs_postnet_img_raw_1,
        )

        aligned_mel_orig_img_raw, aligned_mel_orig_img = plot_melspec_np(
          aligned_mel_orig, title="aligned_mel_orig")
        aligned_mel_postnet_img_raw, aligned_mel_postnet_img = plot_melspec_np(
          aligned_mel_postnet, title="aligned_mel_postnet")
        aligned_structural_similarity, aligned_mel_diff_img_raw = calculate_structual_similarity_np(
          img_a=aligned_mel_orig_img_raw,
          img_b=aligned_mel_postnet_img_raw,
        )

      train_combined_rarity = train_onegram_rarities[entry.entry_id] + \
        train_twogram_rarities[entry.entry_id] + train_threegram_rarities[entry.entry_id]

      val_entry = ValidationEntry(
        entry_id=entry.entry_id,
        repetition=rep_human_readable,
        repetitions=repetitions,
        seed=rep_seed,
        ds_entry_id=entry.ds_entry_id,
        text_original=entry.text_original,
        text=entry.text,
        wav_path=entry.wav_path,
        wav_duration_s=entry.duration_s,
        speaker_id=entry.speaker_id,
        speaker_name=entry_speaker_name,
        iteration=checkpoint.iteration,
        unique_symbols=unique_symbols_str,
        unique_symbols_count=unique_symbols_count,
        symbol_count=symbol_count,
        timepoint=timepoint,
        train_name=train_name,
        sampling_rate=synth.get_sampling_rate(),
        reached_max_decoder_steps=inference_result.reached_max_decoder_steps,
        inference_duration_s=inference_result.inference_duration_s,
        predicted_frames=predicted_frames,
        target_frames=target_frames,
        diff_frames=diff_frames,
        frame_deviation_percent=frame_deviation_percent,
        padded_cosine_similarity=padded_cosine_similarity,
        mfcc_no_coeffs=mcd_no_of_coeffs_per_frame,
        mfcc_dtw_mcd=dtw_mcd,
        mfcc_dtw_penalty=dtw_penalty,
        mfcc_dtw_frames=dtw_frames,
        padded_structural_similarity=padded_structural_similarity,
        padded_mse=padded_mse,
        msd=msd,
        aligned_cosine_similarity=aligned_cosine_similarity,
        aligned_mse=aligned_mse,
        aligned_structural_similarity=aligned_structural_similarity,
        global_one_gram_rarity=entry.one_gram_rarity,
        global_two_gram_rarity=entry.two_gram_rarity,
        global_three_gram_rarity=entry.three_gram_rarity,
        global_combined_rarity=entry.combined_rarity,
        train_one_gram_rarity=train_onegram_rarities[entry.entry_id],
        train_two_gram_rarity=train_twogram_rarities[entry.entry_id],
        train_three_gram_rarity=train_threegram_rarities[entry.entry_id],
        train_combined_rarity=train_combined_rarity,
      )
      validation_entries.append(val_entry)

      if not fast:
        orig_sr, orig_wav = read(entry.wav_path)

        _, padded_mel_postnet_diff_img = calculate_structual_similarity_np(
          img_a=padded_mel_orig_img,
          img_b=padded_mel_outputs_postnet_img,
        )
        _, aligned_mel_postnet_diff_img = calculate_structual_similarity_np(
          img_a=aligned_mel_orig_img,
          img_b=aligned_mel_postnet_img,
        )

        _, mel_orig_img = plot_melspec_np(mel_orig, title="mel_orig")
        _, post_mel_img = plot_melspec_np(
          inference_result.mel_outputs_postnet, title="mel_postnet")
        _, mel_img = plot_melspec_np(inference_result.mel_outputs, title="mel")
        _, alignments_img = plot_alignment_np_new(
          inference_result.alignments, title="alignments")

        # Re-index the attention alignments with the DTW path of the postnet mel.
        aligned_alignments = inference_result.alignments[path_mel_postnet]
        _, aligned_alignments_img = plot_alignment_np_new(
          aligned_alignments, title="aligned_alignments")

        validation_entry_output = ValidationEntryOutput(
          repetition=rep_human_readable,
          wav_orig=orig_wav,
          mel_orig=mel_orig,
          orig_sr=orig_sr,
          mel_postnet=inference_result.mel_outputs_postnet,
          mel_postnet_sr=inference_result.sampling_rate,
          mel_orig_img=mel_orig_img,
          mel_postnet_img=post_mel_img,
          mel_postnet_diff_img=padded_mel_postnet_diff_img,
          alignments_img=alignments_img,
          mel_img=mel_img,
          mel_postnet_aligned_diff_img=aligned_mel_postnet_diff_img,
          mel_orig_aligned=aligned_mel_orig,
          mel_orig_aligned_img=aligned_mel_orig_img,
          mel_postnet_aligned=aligned_mel_postnet,
          mel_postnet_aligned_img=aligned_mel_postnet_img,
          alignments_aligned_img=aligned_alignments_img,
        )

        assert save_callback is not None
        save_callback(entry, validation_entry_output)

      logger.info(f"MFCC MCD DTW: {val_entry.mfcc_dtw_mcd}")

  return validation_entries
def validate(checkpoint: CheckpointWaveglow, data: PreparedDataList, custom_hparams: Optional[Dict[str, str]], denoiser_strength: float, sigma: float, entry_ids: Optional[Set[int]], train_name: str, full_run: bool, save_callback: Callable[[PreparedData, ValidationEntryOutput], None], seed: int, logger: Logger) -> ValidationEntries:
  """Vocode validation entries with Waveglow and compute quality metrics.

  For each selected entry the original wav is converted to a mel, re-synthesized
  through the vocoder, denoised, and compared against the original via MCD
  (with and without DTW), cosine similarity and structural similarity of the
  plotted spectrograms. Results are handed to ``save_callback`` per entry and
  collected into the returned ValidationEntries.

  NOTE(review): requires CUDA (``mel_var.cuda()``) — confirm CPU-only runs are
  not expected here.
  """
  validation_entries = ValidationEntries()

  # Select the entries to validate: everything, or a filtered subset.
  if full_run:
    entries = data
  else:
    speaker_id: Optional[int] = None
    entries = PreparedDataList(data.get_for_validation(entry_ids, speaker_id))

  if len(entries) == 0:
    logger.info("Nothing to synthesize!")
    return validation_entries

  synth = Synthesizer(
    checkpoint=checkpoint,
    custom_hparams=custom_hparams,
    logger=logger
  )

  taco_stft = TacotronSTFT(synth.hparams, logger=logger)

  for entry in entries.items(True):
    # Extract the ground-truth mel from the original wav and vocode it.
    mel = taco_stft.get_mel_tensor_from_file(entry.wav_path)
    mel_var = torch.autograd.Variable(mel)
    mel_var = mel_var.cuda()
    mel_var = mel_var.unsqueeze(0)

    inference_result = synth.infer(mel_var, sigma, denoiser_strength, seed=seed)
    wav_inferred_denoised_normalized = normalize_wav(inference_result.wav_denoised)

    symbol_count = len(deserialize_list(entry.serialized_symbol_ids))
    unique_symbols_count = len(set(deserialize_list(entry.serialized_symbol_ids)))
    timepoint = f"{datetime.datetime.now():%Y/%m/%d %H:%M:%S}"

    val_entry = ValidationEntry(
      entry_id=entry.entry_id,
      ds_entry_id=entry.ds_entry_id,
      text_original=entry.text_original,
      text=entry.text,
      seed=seed,
      wav_path=entry.wav_path,
      original_duration_s=entry.duration_s,
      speaker_id=entry.speaker_id,
      iteration=checkpoint.iteration,
      unique_symbols_count=unique_symbols_count,
      symbol_count=symbol_count,
      timepoint=timepoint,
      train_name=train_name,
      sampling_rate=inference_result.sampling_rate,
      inference_duration_s=inference_result.inference_duration_s,
      was_overamplified=inference_result.was_overamplified,
      denoising_duration_s=inference_result.denoising_duration_s,
      inferred_duration_s=get_duration_s(
        inference_result.wav_denoised, inference_result.sampling_rate),
      denoiser_strength=denoiser_strength,
      sigma=sigma,
    )

    # Metric fields below are filled in after construction.
    val_entry.diff_duration_s = val_entry.inferred_duration_s - val_entry.original_duration_s

    mel_orig = mel.cpu().numpy()
    # Re-extract a mel from the denoised, normalized inferred wav for comparison.
    mel_inferred_denoised_tensor = torch.FloatTensor(wav_inferred_denoised_normalized)
    mel_inferred_denoised = taco_stft.get_mel_tensor(mel_inferred_denoised_tensor)
    mel_inferred_denoised = mel_inferred_denoised.numpy()

    wav_orig, orig_sr = wav_to_float32(entry.wav_path)

    validation_entry_output = ValidationEntryOutput(
      mel_orig=mel_orig,
      inferred_sr=inference_result.sampling_rate,
      mel_inferred_denoised=mel_inferred_denoised,
      wav_inferred_denoised=wav_inferred_denoised_normalized,
      wav_orig=wav_orig,
      orig_sr=orig_sr,
      wav_inferred=normalize_wav(inference_result.wav),
      mel_denoised_diff_img=None,
      mel_inferred_denoised_img=None,
      mel_orig_img=None,
    )

    # MCD with DTW alignment between original and inferred mel.
    mcd_dtw, penalty_dtw, final_frame_number_dtw = get_mcd_between_mel_spectograms(
      mel_1=mel_orig,
      mel_2=mel_inferred_denoised,
      n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
      take_log=False,
      use_dtw=True,
    )

    val_entry.diff_frames = mel_inferred_denoised.shape[1] - mel_orig.shape[1]
    val_entry.mcd_dtw = mcd_dtw
    val_entry.mcd_dtw_penalty = penalty_dtw
    val_entry.mcd_dtw_frames = final_frame_number_dtw

    # MCD without DTW (frame-wise).
    mcd, penalty, final_frame_number = get_mcd_between_mel_spectograms(
      mel_1=mel_orig,
      mel_2=mel_inferred_denoised,
      n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
      take_log=False,
      use_dtw=False,
    )

    val_entry.mcd = mcd
    val_entry.mcd_penalty = penalty
    val_entry.mcd_frames = final_frame_number

    cosine_similarity = cosine_dist_mels(mel_orig, mel_inferred_denoised)
    val_entry.cosine_similarity = cosine_similarity

    # Render spectrogram images; pad to equal width before structural similarity.
    mel_original_img_raw, mel_original_img = plot_melspec_np(mel_orig)
    mel_inferred_denoised_img_raw, mel_inferred_denoised_img = plot_melspec_np(
      mel_inferred_denoised)

    validation_entry_output.mel_orig_img = mel_original_img
    validation_entry_output.mel_inferred_denoised_img = mel_inferred_denoised_img

    mel_original_img_raw_same_dim, mel_inferred_denoised_img_raw_same_dim = make_same_width_by_filling_white(
      img_a=mel_original_img_raw,
      img_b=mel_inferred_denoised_img_raw,
    )

    mel_original_img_same_dim, mel_inferred_denoised_img_same_dim = make_same_width_by_filling_white(
      img_a=mel_original_img,
      img_b=mel_inferred_denoised_img,
    )

    structural_similarity_raw, mel_difference_denoised_img_raw = calculate_structual_similarity_np(
      img_a=mel_original_img_raw_same_dim,
      img_b=mel_inferred_denoised_img_raw_same_dim,
    )
    val_entry.structural_similarity = structural_similarity_raw

    structural_similarity, mel_denoised_diff_img = calculate_structual_similarity_np(
      img_a=mel_original_img_same_dim,
      img_b=mel_inferred_denoised_img_same_dim,
    )
    validation_entry_output.mel_denoised_diff_img = mel_denoised_diff_img

    # NOTE(review): debug image dumps to /tmp left enabled — confirm intended.
    imageio.imsave("/tmp/mel_original_img_raw.png", mel_original_img_raw)
    imageio.imsave("/tmp/mel_inferred_img_raw.png", mel_inferred_denoised_img_raw)
    imageio.imsave("/tmp/mel_difference_denoised_img_raw.png", mel_difference_denoised_img_raw)

    # logger.info(val_entry)
    logger.info(f"Current: {val_entry.entry_id}")
    logger.info(f"MCD DTW: {val_entry.mcd_dtw}")
    logger.info(f"MCD DTW penalty: {val_entry.mcd_dtw_penalty}")
    logger.info(f"MCD DTW frames: {val_entry.mcd_dtw_frames}")
    logger.info(f"MCD: {val_entry.mcd}")
    logger.info(f"MCD penalty: {val_entry.mcd_penalty}")
    logger.info(f"MCD frames: {val_entry.mcd_frames}")
    # logger.info(f"MCD DTW V2: {val_entry.mcd_dtw_v2}")
    logger.info(f"Structural Similarity: {val_entry.structural_similarity}")
    logger.info(f"Cosine Similarity: {val_entry.cosine_similarity}")
    save_callback(entry, validation_entry_output)

    validation_entries.append(val_entry)

    #score, diff_img = compare_mels(a, b)

  return validation_entries