def sents_convert_to_ipa(sentences: SentenceList, text_symbols: SymbolIdDict, ignore_tones: bool, ignore_arcs: bool, mode: Optional[EngToIpaMode], consider_ipa_annotations: bool, logger: Logger) -> Tuple[SymbolIdDict, SentenceList]:

  sents_new_symbols = []
  for sentence in sentences.items(True):
    if sentence.lang == Language.ENG and mode is None:
      msg = "Please specify the IPA conversion mode for English sentences."
      logger.error(msg)
      raise ValueError(msg)
    new_symbols, new_accent_ids = symbols_to_ipa(
      symbols=text_symbols.get_symbols(sentence.serialized_symbols),
      lang=sentence.lang,
      accent_ids=deserialize_list(sentence.serialized_accents),
      ignore_arcs=ignore_arcs,
      ignore_tones=ignore_tones,
      mode=mode,
      replace_unknown_with=DEFAULT_PADDING_SYMBOL,
      consider_ipa_annotations=consider_ipa_annotations,
      logger=logger,
    )
    assert len(new_symbols) == len(new_accent_ids)
    sentence.lang = Language.IPA
    sentence.serialized_accents = serialize_list(new_accent_ids)
    sents_new_symbols.append(new_symbols)
    assert len(sentence.get_accent_ids()) == len(new_symbols)

  return update_symbols_and_text(sentences, sents_new_symbols)
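
A minimal usage sketch for the conversion above. The construction of sentences and text_symbols is assumed to have happened earlier, and EngToIpaMode.EPITRAN is an assumed member name used for illustration only:

import logging

logger = logging.getLogger(__name__)
text_symbols, sentences = sents_convert_to_ipa(
  sentences=sentences,            # an existing SentenceList
  text_symbols=text_symbols,      # the SymbolIdDict matching it
  ignore_tones=True,
  ignore_arcs=True,
  mode=EngToIpaMode.EPITRAN,      # assumed enum member; required for English
  consider_ipa_annotations=False,
  logger=logger,
)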
Example #2
def _remove_unused_accents(self) -> None:
    all_accent_ids: Set[int] = set()
    for entry in self.data.items():
        all_accent_ids |= set(deserialize_list(
            entry.serialized_accent_ids))
    unused_accent_ids = self.accent_ids.get_all_ids().difference(
        all_accent_ids)
    # unused_symbols = unused_symbols.difference({PADDING_SYMBOL})
    self.accent_ids.remove_ids(unused_accent_ids)
def sents_map(sentences: SentenceList, text_symbols: SymbolIdDict, symbols_map: SymbolsMap, ignore_arcs: bool, logger: Logger) -> Tuple[SymbolIdDict, SentenceList]:
  sents_new_symbols = []
  result = SentenceList()
  new_sent_id = 0

  ipa_settings = IPAExtractionSettings(
    ignore_tones=False,
    ignore_arcs=ignore_arcs,
    replace_unknown_ipa_by=DEFAULT_PADDING_SYMBOL,
  )

  for sentence in sentences.items():
    symbols = text_symbols.get_symbols(sentence.serialized_symbols)
    accent_ids = deserialize_list(sentence.serialized_accents)

    mapped_symbols = symbols_map.apply_to_symbols(symbols)

    text = SymbolIdDict.symbols_to_text(mapped_symbols)
    # an empty resulting text causes no problems (it simply yields no sentences)
    sents = text_to_sentences(
      text=text,
      lang=sentence.lang,
      logger=logger,
    )

    for new_sent_text in sents:
      new_symbols = text_to_symbols(
        new_sent_text,
        lang=sentence.lang,
        ipa_settings=ipa_settings,
        logger=logger,
      )

      if len(accent_ids) > 0:
        # mapping can change the symbol count, so the sentence's first
        # accent id is reused for every new symbol
        new_accent_ids = [accent_ids[0]] * len(new_symbols)
      else:
        new_accent_ids = []

      assert len(new_accent_ids) == len(new_symbols)

      new_sent_id += 1
      tmp = Sentence(
        sent_id=new_sent_id,
        text=new_sent_text,
        lang=sentence.lang,
        orig_lang=sentence.orig_lang,
        # not exact, but the closest approximation currently available
        original_text=sentence.original_text,
        serialized_accents=serialize_list(new_accent_ids),
        serialized_symbols=""
      )
      sents_new_symbols.append(new_symbols)

      assert len(tmp.get_accent_ids()) == len(new_symbols)
      result.append(tmp)

  return update_symbols_and_text(result, sents_new_symbols)
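
A hypothetical call of sents_map; how the SymbolsMap is created depends on the project, so it is assumed to exist already:

text_symbols, sentences = sents_map(
  sentences=sentences,
  text_symbols=text_symbols,
  symbols_map=symbols_map,   # maps old symbols to new ones, e.g. "ä" -> "a"
  ignore_arcs=True,
  logger=logger,
)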
def sents_accent_apply(sentences: SentenceList, accented_symbols: AccentedSymbolList, accent_ids: AccentsDict) -> SentenceList:
  current_index = 0
  for sent in sentences.items():
    accent_ids_count = len(deserialize_list(sent.serialized_accents))
    assert len(accented_symbols) >= current_index + accent_ids_count
    accented_symbol_selection: List[AccentedSymbol] = accented_symbols[current_index:current_index + accent_ids_count]
    current_index += accent_ids_count
    new_accent_ids = accent_ids.get_ids([x.accent for x in accented_symbol_selection])
    sent.serialized_accents = serialize_list(new_accent_ids)
    assert len(sent.get_accent_ids()) == len(sent.get_symbol_ids())
  return sentences
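
The indexing above walks one flat list of accented symbols across all sentences; a toy illustration of the slicing with made-up values:

flat_accented = ["a", "a", "b", "c", "c"]   # 5 accents over two sentences
accents_per_sent = [3, 2]
index = 0
for count in accents_per_sent:
  selection = flat_accented[index:index + count]
  index += count
  print(selection)   # ['a', 'a', 'b'] then ['c', 'c']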
def sents_normalize(sentences: SentenceList, text_symbols: SymbolIdDict, logger: Logger) -> Tuple[SymbolIdDict, SentenceList]:
  # TODO: maybe log when something was unknown
  sents_new_symbols = []
  for sentence in sentences.items():
    new_symbols, new_accent_ids = symbols_normalize(
      symbols=text_symbols.get_symbols(sentence.serialized_symbols),
      lang=sentence.lang,
      accent_ids=deserialize_list(sentence.serialized_accents),
      logger=logger,
    )
    # TODO: check if new sentences resulted and then split them.
    sentence.serialized_accents = serialize_list(new_accent_ids)
    sents_new_symbols.append(new_symbols)

  return update_symbols_and_text(sentences, sents_new_symbols)
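
Normalization is typically run before the IPA conversion above; a hypothetical two-step pipeline (both calls return the updated SymbolIdDict together with the SentenceList):

text_symbols, sentences = sents_normalize(sentences, text_symbols, logger)
text_symbols, sentences = sents_convert_to_ipa(
  sentences, text_symbols,
  ignore_tones=True, ignore_arcs=True,
  mode=None,   # a non-None EngToIpaMode is required for English sentences
  consider_ipa_annotations=False, logger=logger,
)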
Example #6
    def __init__(self, data: PreparedDataList, hparams: HParams,
                 logger: Logger):
        # random.seed(hparams.seed)
        # random.shuffle(data)
        self.use_saved_mels = hparams.use_saved_mels
        if not hparams.use_saved_mels:
            self.mel_parser = TacotronSTFT(hparams, logger)

        logger.info("Reading files...")
        self.data: Dict[int, Tuple[IntTensor, IntTensor, str, int]] = {}
        for i, values in enumerate(data.items(True)):
            symbol_ids = deserialize_list(values.serialized_symbol_ids)
            accent_ids = deserialize_list(values.serialized_accent_ids)

            model_symbol_ids = get_accent_symbol_ids(
                symbol_ids, accent_ids, hparams.n_symbols,
                hparams.accents_use_own_symbols, SHARED_SYMBOLS_COUNT)

            symbols_tensor = IntTensor(model_symbol_ids)
            accents_tensor = IntTensor(accent_ids)

            if hparams.use_saved_mels:
                self.data[i] = (symbols_tensor, accents_tensor,
                                values.mel_path, values.speaker_id)
            else:
                self.data[i] = (symbols_tensor, accents_tensor,
                                values.wav_path, values.speaker_id)

        if hparams.use_saved_mels and hparams.cache_mels:
            logger.info("Loading mels into memory...")
            self.cache: Dict[int, Tensor] = {}
            vals: tuple
            for i, vals in tqdm(self.data.items()):
                # index 2 of each tuple is the mel path (see the loop above);
                # index 1 is the accents tensor and cannot be torch.load-ed
                mel_tensor = torch.load(vals[2], map_location='cpu')
                self.cache[i] = mel_tensor
        self.use_cache: bool = hparams.cache_mels
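
Judging by its arguments, get_accent_symbol_ids remaps symbol ids into per-accent id ranges when accents_use_own_symbols is set, keeping a small shared prefix. A plausible sketch of that idea, purely illustrative (the project's real implementation may differ):

def accent_symbol_ids_sketch(symbol_ids, accent_ids, n_symbols,
                             use_own_symbols, shared_count):
    # without per-accent symbols, ids pass through unchanged
    if not use_own_symbols:
        return list(symbol_ids)
    result = []
    for sym_id, acc_id in zip(symbol_ids, accent_ids):
        if sym_id < shared_count:
            result.append(sym_id)   # shared symbols (e.g. padding) keep their ids
        else:
            # shift the rest into the id block belonging to this accent
            result.append(sym_id + acc_id * (n_symbols - shared_count))
    return result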
Example #7
def filter_symbols(data: MergedDataset, symbols: SymbolIdDict,
                   accent_ids: AccentsDict, speakers: SpeakersDict,
                   allowed_symbol_ids: Set[int],
                   logger: Logger) -> MergedDatasetContainer:
    # TODO: maybe validate that all symbol ids exist before filtering
    allowed_symbols = [symbols.get_symbol(x) for x in allowed_symbol_ids]
    not_allowed_symbols = [
        symbols.get_symbol(x) for x in symbols.get_all_symbol_ids()
        if x not in allowed_symbol_ids
    ]
    logger.info(
        f"Keep utterances with these symbols: {' '.join(allowed_symbols)}")
    logger.info(
        f"Remove utterances with these symbols: {' '.join(not_allowed_symbols)}"
    )
    logger.info("Statistics before filtering:")
    log_stats(data, symbols, accent_ids, speakers, logger)
    result = MergedDataset([
        x for x in data.items() if contains_only_allowed_symbols(
            deserialize_list(x.serialized_symbol_ids), allowed_symbol_ids)
    ])
    if len(result) > 0:
        logger.info(
            f"Removed {len(data) - len(result)} from {len(data)} total entries and got {len(result)} entries ({len(result)/len(data)*100:.2f}%)."
        )
    else:
        logger.info("Removed all utterances!")
    new_symbol_ids = update_symbols(result, symbols)
    new_accent_ids = update_accents(result, accent_ids)
    new_speaker_ids = update_speakers(result, speakers)
    logger.info("Statistics after filtering:")
    log_stats(result, new_symbol_ids, new_accent_ids, new_speaker_ids, logger)

    res = MergedDatasetContainer(
        name=None,
        data=result,
        accent_ids=new_accent_ids,
        speaker_ids=new_speaker_ids,
        symbol_ids=new_symbol_ids,
    )
    return res
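
A hypothetical invocation of filter_symbols; the get_id helper used to build the allow-list is an assumed SymbolIdDict method:

allowed_ids = {symbols.get_id(s) for s in ("a", "e", "n")}   # assumed helper
container = filter_symbols(
    data=data, symbols=symbols, accent_ids=accent_ids,
    speakers=speakers, allowed_symbol_ids=allowed_ids, logger=logger,
)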
def get_accent_ids(self):
  return deserialize_list(self.serialized_accents)

def get_symbol_ids(self):
  return deserialize_list(self.serialized_symbols)
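
Both getters depend on deserialize_list; a minimal stand-in pair, assuming ids are stored as comma-separated strings (the project's actual format may differ):

def serialize_list_sketch(ids):
  # "5,0,12"-style storage, assumed for illustration
  return ",".join(str(i) for i in ids)

def deserialize_list_sketch(serialized):
  return [int(x) for x in serialized.split(",")] if serialized else []

assert deserialize_list_sketch(serialize_list_sketch([5, 0, 12])) == [5, 0, 12]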
Example #10
def validate(checkpoint: CheckpointTacotron, data: PreparedDataList,
             trainset: PreparedDataList,
             custom_hparams: Optional[Dict[str, str]],
             entry_ids: Optional[Set[int]], speaker_name: Optional[str],
             train_name: str, full_run: bool,
             save_callback: Optional[Callable[[PreparedData, ValidationEntryOutput], None]],
             max_decoder_steps: int, fast: bool,
             mcd_no_of_coeffs_per_frame: int, repetitions: int,
             seed: Optional[int], select_best_from: Optional[pd.DataFrame],
             logger: Logger) -> ValidationEntries:
    model_symbols = checkpoint.get_symbols()
    model_accents = checkpoint.get_accents()
    model_speakers = checkpoint.get_speakers()

    seeds: List[int]
    validation_data: PreparedDataList

    if full_run:
        validation_data = data
        assert seed is not None
        seeds = [seed for _ in data]
    elif select_best_from is not None:
        assert entry_ids is not None
        logger.info("Finding best seeds...")
        validation_data = PreparedDataList(
            [x for x in data.items() if x.entry_id in entry_ids])
        if len(validation_data) != len(entry_ids):
            logger.error("Not all entry_id's were found!")
            assert False
        entry_ids_order_from_valdata = OrderedSet(
            [x.entry_id for x in validation_data.items()])
        seeds = get_best_seeds(select_best_from, entry_ids_order_from_valdata,
                               checkpoint.iteration, logger)
    #   entry_ids = [entry_id for entry_id, _ in entry_ids_w_seed]
    #   have_no_double_entries = len(set(entry_ids)) == len(entry_ids)
    #   assert have_no_double_entries
    #   validation_data = PreparedDataList([x for x in data.items() if x.entry_id in entry_ids])
    #   seeds = [s for _, s in entry_ids_w_seed]
    #   if len(validation_data) != len(entry_ids):
    #     logger.error("Not all entry_id's were found!")
    #     assert False
    elif entry_ids is not None:
        validation_data = PreparedDataList(
            [x for x in data.items() if x.entry_id in entry_ids])
        if len(validation_data) != len(entry_ids):
            logger.error("Not all entry_id's were found!")
            assert False
        assert seed is not None
        seeds = [seed for _ in validation_data]
    elif speaker_name is not None:
        speaker_id = model_speakers.get_id(speaker_name)
        relevant_entries = [
            x for x in data.items() if x.speaker_id == speaker_id
        ]
        assert len(relevant_entries) > 0
        assert seed is not None
        random.seed(seed)
        entry = random.choice(relevant_entries)
        validation_data = PreparedDataList([entry])
        seeds = [seed]
    else:
        assert seed is not None
        random.seed(seed)
        entry = random.choice(data)
        validation_data = PreparedDataList([entry])
        seeds = [seed]

    assert len(seeds) == len(validation_data)

    if len(validation_data) == 0:
        logger.info("Nothing to synthesize!")
        # return an empty result of the annotated return type
        return ValidationEntries()

    train_onegram_rarities = get_ngram_rarity(validation_data, trainset,
                                              model_symbols, 1)
    train_twogram_rarities = get_ngram_rarity(validation_data, trainset,
                                              model_symbols, 2)
    train_threegram_rarities = get_ngram_rarity(validation_data, trainset,
                                                model_symbols, 3)

    synth = Synthesizer(
        checkpoint=checkpoint,
        custom_hparams=custom_hparams,
        logger=logger,
    )

    taco_stft = TacotronSTFT(synth.hparams, logger=logger)
    validation_entries = ValidationEntries()

    for repetition in range(repetitions):
        # rep_seed = seed + repetition
        rep_human_readable = repetition + 1
        logger.info(f"Starting repetition: {rep_human_readable}/{repetitions}")
        for entry, entry_seed in zip(validation_data.items(True), seeds):
            rep_seed = entry_seed + repetition
            logger.info(
                f"Current --> entry_id: {entry.entry_id}; seed: {rep_seed}; iteration: {checkpoint.iteration}; rep: {rep_human_readable}/{repetitions}"
            )

            infer_sent = InferSentence(
                sent_id=1,
                symbols=model_symbols.get_symbols(entry.serialized_symbol_ids),
                accents=model_accents.get_accents(entry.serialized_accent_ids),
                original_text=entry.text_original,
            )

            speaker_name = model_speakers.get_speaker(entry.speaker_id)
            inference_result = synth.infer(
                sentence=infer_sent,
                speaker=speaker_name,
                ignore_unknown_symbols=False,
                max_decoder_steps=max_decoder_steps,
                seed=rep_seed,
            )

            symbol_count = len(deserialize_list(entry.serialized_symbol_ids))
            unique_symbols = set(
                model_symbols.get_symbols(entry.serialized_symbol_ids))
            unique_symbols_str = " ".join(list(sorted(unique_symbols)))
            unique_symbols_count = len(unique_symbols)
            timepoint = f"{datetime.datetime.now():%Y/%m/%d %H:%M:%S}"

            mel_orig: np.ndarray = taco_stft.get_mel_tensor_from_file(
                entry.wav_path).cpu().numpy()

            target_frames = mel_orig.shape[1]
            predicted_frames = inference_result.mel_outputs_postnet.shape[1]
            diff_frames = predicted_frames - target_frames
            frame_deviation_percent = (predicted_frames / target_frames) - 1

            dtw_mcd, dtw_penalty, dtw_frames = get_mcd_between_mel_spectograms(
                mel_1=mel_orig,
                mel_2=inference_result.mel_outputs_postnet,
                n_mfcc=mcd_no_of_coeffs_per_frame,
                take_log=False,
                use_dtw=True,
            )

            padded_mel_orig, padded_mel_postnet = make_same_dim(
                mel_orig, inference_result.mel_outputs_postnet)

            aligned_mel_orig, aligned_mel_postnet, mel_dtw_dist, _, path_mel_postnet = align_mels_with_dtw(
                mel_orig, inference_result.mel_outputs_postnet)

            mel_aligned_length = len(aligned_mel_postnet)
            msd = get_msd(mel_dtw_dist, mel_aligned_length)

            padded_cosine_similarity = cosine_dist_mels(
                padded_mel_orig, padded_mel_postnet)
            aligned_cosine_similarity = cosine_dist_mels(
                aligned_mel_orig, aligned_mel_postnet)

            padded_mse = mean_squared_error(padded_mel_orig,
                                            padded_mel_postnet)
            aligned_mse = mean_squared_error(aligned_mel_orig,
                                             aligned_mel_postnet)

            padded_structural_similarity = None
            aligned_structural_similarity = None

            if not fast:
                padded_mel_orig_img_raw_1, padded_mel_orig_img = plot_melspec_np(
                    padded_mel_orig, title="padded_mel_orig")
                padded_mel_outputs_postnet_img_raw_1, padded_mel_outputs_postnet_img = plot_melspec_np(
                    padded_mel_postnet, title="padded_mel_postnet")

                # imageio.imsave("/tmp/padded_mel_orig_img_raw_1.png", padded_mel_orig_img_raw_1)
                # imageio.imsave("/tmp/padded_mel_outputs_postnet_img_raw_1.png", padded_mel_outputs_postnet_img_raw_1)

                padded_structural_similarity, padded_mel_postnet_diff_img_raw = calculate_structual_similarity_np(
                    img_a=padded_mel_orig_img_raw_1,
                    img_b=padded_mel_outputs_postnet_img_raw_1,
                )

                # imageio.imsave("/tmp/padded_mel_diff_img_raw_1.png", padded_mel_postnet_diff_img_raw)

                aligned_mel_orig_img_raw, aligned_mel_orig_img = plot_melspec_np(
                    aligned_mel_orig, title="aligned_mel_orig")
                aligned_mel_postnet_img_raw, aligned_mel_postnet_img = plot_melspec_np(
                    aligned_mel_postnet, title="aligned_mel_postnet")

                # imageio.imsave("/tmp/aligned_mel_orig_img_raw.png", aligned_mel_orig_img_raw)
                # imageio.imsave("/tmp/aligned_mel_postnet_img_raw.png", aligned_mel_postnet_img_raw)

                aligned_structural_similarity, aligned_mel_diff_img_raw = calculate_structual_similarity_np(
                    img_a=aligned_mel_orig_img_raw,
                    img_b=aligned_mel_postnet_img_raw,
                )

                # imageio.imsave("/tmp/aligned_mel_diff_img_raw.png", aligned_mel_diff_img_raw)

            train_combined_rarity = train_onegram_rarities[entry.entry_id] + \
                train_twogram_rarities[entry.entry_id] + train_threegram_rarities[entry.entry_id]

            val_entry = ValidationEntry(
                entry_id=entry.entry_id,
                repetition=rep_human_readable,
                repetitions=repetitions,
                seed=rep_seed,
                ds_entry_id=entry.ds_entry_id,
                text_original=entry.text_original,
                text=entry.text,
                wav_path=entry.wav_path,
                wav_duration_s=entry.duration_s,
                speaker_id=entry.speaker_id,
                speaker_name=speaker_name,
                iteration=checkpoint.iteration,
                unique_symbols=unique_symbols_str,
                unique_symbols_count=unique_symbols_count,
                symbol_count=symbol_count,
                timepoint=timepoint,
                train_name=train_name,
                sampling_rate=synth.get_sampling_rate(),
                reached_max_decoder_steps=inference_result.reached_max_decoder_steps,
                inference_duration_s=inference_result.inference_duration_s,
                predicted_frames=predicted_frames,
                target_frames=target_frames,
                diff_frames=diff_frames,
                frame_deviation_percent=frame_deviation_percent,
                padded_cosine_similarity=padded_cosine_similarity,
                mfcc_no_coeffs=mcd_no_of_coeffs_per_frame,
                mfcc_dtw_mcd=dtw_mcd,
                mfcc_dtw_penalty=dtw_penalty,
                mfcc_dtw_frames=dtw_frames,
                padded_structural_similarity=padded_structural_similarity,
                padded_mse=padded_mse,
                msd=msd,
                aligned_cosine_similarity=aligned_cosine_similarity,
                aligned_mse=aligned_mse,
                aligned_structural_similarity=aligned_structural_similarity,
                global_one_gram_rarity=entry.one_gram_rarity,
                global_two_gram_rarity=entry.two_gram_rarity,
                global_three_gram_rarity=entry.three_gram_rarity,
                global_combined_rarity=entry.combined_rarity,
                train_one_gram_rarity=train_onegram_rarities[entry.entry_id],
                train_two_gram_rarity=train_twogram_rarities[entry.entry_id],
                train_three_gram_rarity=train_threegram_rarities[
                    entry.entry_id],
                train_combined_rarity=train_combined_rarity,
            )

            validation_entries.append(val_entry)

            if not fast:
                orig_sr, orig_wav = read(entry.wav_path)

                _, padded_mel_postnet_diff_img = calculate_structual_similarity_np(
                    img_a=padded_mel_orig_img,
                    img_b=padded_mel_outputs_postnet_img,
                )

                _, aligned_mel_postnet_diff_img = calculate_structual_similarity_np(
                    img_a=aligned_mel_orig_img,
                    img_b=aligned_mel_postnet_img,
                )

                _, mel_orig_img = plot_melspec_np(mel_orig, title="mel_orig")
                # alignments_img = plot_alignment_np(inference_result.alignments)
                _, post_mel_img = plot_melspec_np(
                    inference_result.mel_outputs_postnet, title="mel_postnet")
                _, mel_img = plot_melspec_np(inference_result.mel_outputs,
                                             title="mel")
                _, alignments_img = plot_alignment_np_new(
                    inference_result.alignments, title="alignments")

                aligned_alignments = inference_result.alignments[
                    path_mel_postnet]
                _, aligned_alignments_img = plot_alignment_np_new(
                    aligned_alignments, title="aligned_alignments")
                # imageio.imsave("/tmp/alignments.png", alignments_img)

                validation_entry_output = ValidationEntryOutput(
                    repetition=rep_human_readable,
                    wav_orig=orig_wav,
                    mel_orig=mel_orig,
                    orig_sr=orig_sr,
                    mel_postnet=inference_result.mel_outputs_postnet,
                    mel_postnet_sr=inference_result.sampling_rate,
                    mel_orig_img=mel_orig_img,
                    mel_postnet_img=post_mel_img,
                    mel_postnet_diff_img=padded_mel_postnet_diff_img,
                    alignments_img=alignments_img,
                    mel_img=mel_img,
                    mel_postnet_aligned_diff_img=aligned_mel_postnet_diff_img,
                    mel_orig_aligned=aligned_mel_orig,
                    mel_orig_aligned_img=aligned_mel_orig_img,
                    mel_postnet_aligned=aligned_mel_postnet,
                    mel_postnet_aligned_img=aligned_mel_postnet_img,
                    alignments_aligned_img=aligned_alignments_img,
                )

                assert save_callback is not None

                save_callback(entry, validation_entry_output)

            logger.info(f"MFCC MCD DTW: {val_entry.mfcc_dtw_mcd}")

    return validation_entries
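
A hypothetical invocation of the Tacotron validation above; all keyword names mirror the signature, the values are made up. With fast=True the image-based metrics are skipped, so save_callback may stay None here:

entries = validate(
    checkpoint=checkpoint, data=val_data, trainset=train_data,
    custom_hparams=None, entry_ids={3, 17}, speaker_name=None,
    train_name="baseline", full_run=False, save_callback=None,
    max_decoder_steps=1000, fast=True, mcd_no_of_coeffs_per_frame=16,
    repetitions=3, seed=1234, select_best_from=None, logger=logger,
)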
Example #11
def validate(checkpoint: CheckpointWaveglow, data: PreparedDataList, custom_hparams: Optional[Dict[str, str]], denoiser_strength: float, sigma: float, entry_ids: Optional[Set[int]], train_name: str, full_run: bool, save_callback: Callable[[PreparedData, ValidationEntryOutput], None], seed: int, logger: Logger):
  validation_entries = ValidationEntries()

  if full_run:
    entries = data
  else:
    speaker_id: Optional[int] = None
    entries = PreparedDataList(data.get_for_validation(entry_ids, speaker_id))

  if len(entries) == 0:
    logger.info("Nothing to synthesize!")
    return validation_entries

  synth = Synthesizer(
    checkpoint=checkpoint,
    custom_hparams=custom_hparams,
    logger=logger
  )

  taco_stft = TacotronSTFT(synth.hparams, logger=logger)

  for entry in entries.items(True):
    mel = taco_stft.get_mel_tensor_from_file(entry.wav_path)
    # torch.autograd.Variable is a no-op since PyTorch 0.4; move the mel
    # to the GPU and add a batch dimension directly
    mel_var = mel.cuda().unsqueeze(0)

    inference_result = synth.infer(mel_var, sigma, denoiser_strength, seed=seed)
    wav_inferred_denoised_normalized = normalize_wav(inference_result.wav_denoised)

    symbol_count = len(deserialize_list(entry.serialized_symbol_ids))
    unique_symbols_count = len(set(deserialize_list(entry.serialized_symbol_ids)))
    timepoint = f"{datetime.datetime.now():%Y/%m/%d %H:%M:%S}"

    val_entry = ValidationEntry(
      entry_id=entry.entry_id,
      ds_entry_id=entry.ds_entry_id,
      text_original=entry.text_original,
      text=entry.text,
      seed=seed,
      wav_path=entry.wav_path,
      original_duration_s=entry.duration_s,
      speaker_id=entry.speaker_id,
      iteration=checkpoint.iteration,
      unique_symbols_count=unique_symbols_count,
      symbol_count=symbol_count,
      timepoint=timepoint,
      train_name=train_name,
      sampling_rate=inference_result.sampling_rate,
      inference_duration_s=inference_result.inference_duration_s,
      was_overamplified=inference_result.was_overamplified,
      denoising_duration_s=inference_result.denoising_duration_s,
      inferred_duration_s=get_duration_s(
        inference_result.wav_denoised, inference_result.sampling_rate),
      denoiser_strength=denoiser_strength,
      sigma=sigma,
    )

    val_entry.diff_duration_s = val_entry.inferred_duration_s - val_entry.original_duration_s

    mel_orig = mel.cpu().numpy()

    # this tensor holds the denoised waveform (not a mel), so name it accordingly
    wav_inferred_denoised_tensor = torch.FloatTensor(wav_inferred_denoised_normalized)
    mel_inferred_denoised = taco_stft.get_mel_tensor(wav_inferred_denoised_tensor)
    mel_inferred_denoised = mel_inferred_denoised.numpy()

    wav_orig, orig_sr = wav_to_float32(entry.wav_path)

    validation_entry_output = ValidationEntryOutput(
      mel_orig=mel_orig,
      inferred_sr=inference_result.sampling_rate,
      mel_inferred_denoised=mel_inferred_denoised,
      wav_inferred_denoised=wav_inferred_denoised_normalized,
      wav_orig=wav_orig,
      orig_sr=orig_sr,
      wav_inferred=normalize_wav(inference_result.wav),
      mel_denoised_diff_img=None,
      mel_inferred_denoised_img=None,
      mel_orig_img=None,
    )

    mcd_dtw, penalty_dtw, final_frame_number_dtw = get_mcd_between_mel_spectograms(
      mel_1=mel_orig,
      mel_2=mel_inferred_denoised,
      n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
      take_log=False,
      use_dtw=True,
    )

    val_entry.diff_frames = mel_inferred_denoised.shape[1] - mel_orig.shape[1]
    val_entry.mcd_dtw = mcd_dtw
    val_entry.mcd_dtw_penalty = penalty_dtw
    val_entry.mcd_dtw_frames = final_frame_number_dtw

    mcd, penalty, final_frame_number = get_mcd_between_mel_spectograms(
      mel_1=mel_orig,
      mel_2=mel_inferred_denoised,
      n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
      take_log=False,
      use_dtw=False,
    )

    val_entry.mcd = mcd
    val_entry.mcd_penalty = penalty
    val_entry.mcd_frames = final_frame_number

    cosine_similarity = cosine_dist_mels(mel_orig, mel_inferred_denoised)
    val_entry.cosine_similarity = cosine_similarity

    mel_original_img_raw, mel_original_img = plot_melspec_np(mel_orig)
    mel_inferred_denoised_img_raw, mel_inferred_denoised_img = plot_melspec_np(
      mel_inferred_denoised)

    validation_entry_output.mel_orig_img = mel_original_img
    validation_entry_output.mel_inferred_denoised_img = mel_inferred_denoised_img

    mel_original_img_raw_same_dim, mel_inferred_denoised_img_raw_same_dim = make_same_width_by_filling_white(
      img_a=mel_original_img_raw,
      img_b=mel_inferred_denoised_img_raw,
    )

    mel_original_img_same_dim, mel_inferred_denoised_img_same_dim = make_same_width_by_filling_white(
      img_a=mel_original_img,
      img_b=mel_inferred_denoised_img,
    )

    structural_similarity_raw, mel_difference_denoised_img_raw = calculate_structual_similarity_np(
        img_a=mel_original_img_raw_same_dim,
        img_b=mel_inferred_denoised_img_raw_same_dim,
    )
    val_entry.structural_similarity = structural_similarity_raw

    structural_similarity, mel_denoised_diff_img = calculate_structual_similarity_np(
        img_a=mel_original_img_same_dim,
        img_b=mel_inferred_denoised_img_same_dim,
    )
    validation_entry_output.mel_denoised_diff_img = mel_denoised_diff_img

    # debug output: dump the raw spectrogram images for manual inspection
    imageio.imsave("/tmp/mel_original_img_raw.png", mel_original_img_raw)
    imageio.imsave("/tmp/mel_inferred_img_raw.png", mel_inferred_denoised_img_raw)
    imageio.imsave("/tmp/mel_difference_denoised_img_raw.png", mel_difference_denoised_img_raw)

    # logger.info(val_entry)
    logger.info(f"Current: {val_entry.entry_id}")
    logger.info(f"MCD DTW: {val_entry.mcd_dtw}")
    logger.info(f"MCD DTW penalty: {val_entry.mcd_dtw_penalty}")
    logger.info(f"MCD DTW frames: {val_entry.mcd_dtw_frames}")

    logger.info(f"MCD: {val_entry.mcd}")
    logger.info(f"MCD penalty: {val_entry.mcd_penalty}")
    logger.info(f"MCD frames: {val_entry.mcd_frames}")

    # logger.info(f"MCD DTW V2: {val_entry.mcd_dtw_v2}")
    logger.info(f"Structural Similarity: {val_entry.structural_similarity}")
    logger.info(f"Cosine Similarity: {val_entry.cosine_similarity}")
    save_callback(entry, validation_entry_output)
    validation_entries.append(val_entry)
    #score, diff_img = compare_mels(a, b)

  return validation_entries
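
A hypothetical invocation of the WaveGlow validation; the values are made up, and my_save_callback is assumed to persist the wavs and images:

entries = validate(
  checkpoint=wg_checkpoint, data=val_data, custom_hparams=None,
  denoiser_strength=0.01, sigma=0.666, entry_ids={3, 17},
  train_name="wg_baseline", full_run=False,
  save_callback=my_save_callback, seed=1234, logger=logger,
)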