def process(data: WavDataList, ds: DsDataList, wav_dir: Path,
            custom_hparams: Optional[Dict[str, str]],
            save_callback: Callable[[WavData, DsData, Tensor], Path]) -> List[Path]:
    hparams = TSTFTHParams()
    hparams = overwrite_custom_hparams(hparams, custom_hparams)
    mel_parser = TacotronSTFT(hparams, logger=getLogger())

    all_paths: List[Path] = []
    for wav_entry, ds_entry in zip(data.items(True), ds.items(True)):
        absolute_wav_path = wav_dir / wav_entry.wav_relative_path
        mel_tensor = mel_parser.get_mel_tensor_from_file(absolute_wav_path)
        absolute_path = save_callback(wav_entry=wav_entry,
                                      ds_entry=ds_entry,
                                      mel_tensor=mel_tensor)
        all_paths.append(absolute_path)
    return all_paths
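
A minimal usage sketch (not from the original source): the callback below persists each mel tensor with torch.save. The names output_dir, wav_data, and ds_data are placeholders; the entry_id field access is inferred from how entries are used in the later examples.

import torch

def save_mel(wav_entry: WavData, ds_entry: DsData, mel_tensor: Tensor) -> Path:
    # Hypothetical callback: store each mel under a placeholder output_dir.
    out_path = output_dir / f"{wav_entry.entry_id}.pt"
    torch.save(mel_tensor, out_path)
    return out_path

mel_paths = process(wav_data, ds_data, wav_dir, custom_hparams=None,
                    save_callback=save_mel)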
Example #2
class MelLoader(Dataset):
    """
  This is the main class that calculates the spectrogram and returns the
  spectrogram, audio pair.
  """
    def __init__(self, prepare_ds_data: PreparedDataList, hparams: HParams,
                 logger: Logger):
        self.taco_stft = TacotronSTFT(hparams, logger=logger)
        self.hparams = hparams
        self._logger = logger

        data = prepare_ds_data
        random.seed(hparams.seed)
        random.shuffle(data)

        wav_paths = {}
        for i, values in enumerate(data.items()):
            wav_paths[i] = values.wav_path
        self.wav_paths = wav_paths

        if hparams.cache_wavs:
            self._logger.info("Loading wavs into memory...")
            cache = {}
            for i, wav_path in tqdm(wav_paths.items()):
                cache[i] = self.taco_stft.get_wav_tensor_from_file(wav_path)
            self._logger.info("Done")
            self.cache = cache

    def __getitem__(self, index):
        if self.hparams.cache_wavs:
            wav_tensor = self.cache[index].clone().detach()
        else:
            wav_tensor = self.taco_stft.get_wav_tensor_from_file(
                self.wav_paths[index])
        wav_tensor = get_wav_tensor_segment(wav_tensor,
                                            self.hparams.segment_length)
        mel_tensor = self.taco_stft.get_mel_tensor(wav_tensor)
        return (mel_tensor, wav_tensor)

    def __len__(self):
        return len(self.wav_paths)
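
A hedged usage sketch: assuming get_wav_tensor_segment cuts or pads every wav to hparams.segment_length, all (mel, wav) pairs share a shape and the default DataLoader collation works. The batch size and variable names below are placeholders.

from torch.utils.data import DataLoader

loader = DataLoader(
    MelLoader(prepared_data, hparams, logger),
    batch_size=8,    # placeholder
    shuffle=False,   # the loader already shuffles internally with hparams.seed
    num_workers=0,   # the in-memory wav cache lives in this process
)
for mel_batch, wav_batch in loader:
    ...  # e.g. feed a vocoder training step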
def remove_silence_plot(wav_path: Path, out_path: Path, chunk_size: int,
                        threshold_start: float, threshold_end: float,
                        buffer_start_ms: float, buffer_end_ms: float):
  remove_silence_file(
    in_path=wav_path,
    out_path=out_path,
    chunk_size=chunk_size,
    threshold_start=threshold_start,
    threshold_end=threshold_end,
    buffer_start_ms=buffer_start_ms,
    buffer_end_ms=buffer_end_ms
  )

  sampling_rate, _ = read(wav_path)

  hparams = TSTFTHParams()
  hparams.sampling_rate = sampling_rate
  plotter = TacotronSTFT(hparams, logger=getLogger())

  mel_orig = plotter.get_mel_tensor_from_file(wav_path)
  mel_trimmed = plotter.get_mel_tensor_from_file(out_path)

  return mel_orig, mel_trimmed
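
A small sketch of how the returned pair might be inspected with the plot_melspec_np helper that appears in the later examples; the paths, thresholds, and buffer values here are purely illustrative.

mel_orig, mel_trimmed = remove_silence_plot(
    wav_path=Path("input.wav"), out_path=Path("trimmed.wav"),
    chunk_size=1024, threshold_start=-25, threshold_end=-35,
    buffer_start_ms=100, buffer_end_ms=150,
)
_, orig_img = plot_melspec_np(mel_orig.numpy(), title="mel_orig")
_, trimmed_img = plot_melspec_np(mel_trimmed.numpy(), title="mel_trimmed")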
Example #4
def process(data: WavDataList, wav_dir: Path,
            custom_hparams: Optional[Dict[str, str]],
            save_callback: Callable[[WavData, Tensor], str],
            n_jobs: int) -> MelDataList:
  hparams = TSTFTHParams()
  hparams = overwrite_custom_hparams(hparams, custom_hparams)
  mel_parser = TacotronSTFT(hparams, logger=getLogger())
  mt_method = partial(
    process_entry,
    wav_dir=wav_dir,
    mel_parser=mel_parser,
    save_callback=save_callback,
  )

  with ThreadPoolExecutor(max_workers=n_jobs) as ex:
    result = MelDataList(tqdm(ex.map(mt_method, data.items()), total=len(data)))

  return result
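
The partial binds the shared, read-only arguments once, so the executor only fans out over the entries; threads (rather than processes) are plausible here because the heavy STFT work happens inside PyTorch, which generally releases the GIL. A hypothetical call, with a save callback matching the Callable[[WavData, Tensor], str] annotation:

mel_data = process(
    data=wav_data,               # placeholder WavDataList
    wav_dir=Path("wavs"),
    custom_hparams=None,
    save_callback=save_wav_mel,  # hypothetical Callable[[WavData, Tensor], str]
    n_jobs=4,
)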
Example #7
def infer(
    mel_entries: List[InferMelEntry], checkpoint: CheckpointWaveglow,
    custom_hparams: Optional[Dict[str, str]], denoiser_strength: float,
    sigma: float, sentence_pause_s: float,
    save_callback: Callable[[InferenceEntryOutput], None],
    concatenate: bool, seed: int, logger: Logger
) -> Tuple[InferenceEntries, Tuple[Optional[np.ndarray], int]]:
    inference_entries = InferenceEntries()

    if len(mel_entries) == 0:
        logger.info("Nothing to synthesize!")
        # Return an empty result that still matches the annotated return type.
        return inference_entries, (None, 0)

    synth = Synthesizer(checkpoint=checkpoint,
                        custom_hparams=custom_hparams,
                        logger=logger)

    # Check mels have the same sampling rate as trained waveglow model
    for mel_entry in mel_entries:
        assert mel_entry.sr == synth.hparams.sampling_rate

    taco_stft = TacotronSTFT(synth.hparams, logger=logger)
    mels_torch = []
    mels_torch_prepared = []
    for mel_entry in mel_entries:
        mel_torch = mel_to_torch(mel_entry.mel)
        mels_torch.append(mel_torch)
        mel_var = torch.autograd.Variable(mel_torch)
        mel_var = mel_var.cuda()
        mel_var = mel_var.unsqueeze(0)
        mels_torch_prepared.append(mel_var)

    inference_results = synth.infer_all(mels_torch_prepared,
                                        sigma,
                                        denoiser_strength,
                                        seed=seed)

    complete_wav_denoised: Optional[np.ndarray] = None
    if concatenate:
        if len(inference_results) >= 1:
            logger.info("Concatening audios...")
        complete_wav_denoised = concatenate_audios(
            [x.wav_denoised for x in inference_results], sentence_pause_s,
            synth.hparams.sampling_rate)
        complete_wav_denoised = normalize_wav(complete_wav_denoised)
        if len(inference_results) >= 1:
            logger.info("Done.")

    inference_result: InferenceResult
    mel_entry: InferMelEntry
    for mel_entry, inference_result in tqdm(zip(mel_entries, inference_results),
                                            total=len(mel_entries)):
        wav_inferred_denoised_normalized = normalize_wav(
            inference_result.wav_denoised)
        timepoint = f"{datetime.datetime.now():%Y/%m/%d %H:%M:%S}"

        val_entry = InferenceEntry(
            identifier=mel_entry.identifier,
            iteration=checkpoint.iteration,
            timepoint=timepoint,
            sampling_rate=inference_result.sampling_rate,
            inference_duration_s=inference_result.inference_duration_s,
            was_overamplified=inference_result.was_overamplified,
            denoising_duration_s=inference_result.denoising_duration_s,
            inferred_duration_s=get_duration_s(inference_result.wav_denoised,
                                               inference_result.sampling_rate),
            denoiser_strength=denoiser_strength,
            sigma=sigma,
        )

        mel_orig = mel_entry.mel

        wav_inferred_denoised_normalized_tensor = torch.FloatTensor(
            wav_inferred_denoised_normalized)
        mel_inferred_denoised = taco_stft.get_mel_tensor(
            wav_inferred_denoised_normalized_tensor)
        mel_inferred_denoised = mel_inferred_denoised.numpy()

        validation_entry_output = InferenceEntryOutput(
            identifier=mel_entry.identifier,
            mel_orig=mel_orig,
            inferred_sr=inference_result.sampling_rate,
            mel_inferred_denoised=mel_inferred_denoised,
            wav_inferred_denoised=wav_inferred_denoised_normalized,
            orig_sr=mel_entry.sr,
            wav_inferred=normalize_wav(inference_result.wav),
            mel_denoised_diff_img=None,
            mel_inferred_denoised_img=None,
            mel_orig_img=None,
        )

        mcd_dtw, penalty_dtw, final_frame_number_dtw = get_mcd_between_mel_spectograms(
            mel_1=mel_orig,
            mel_2=mel_inferred_denoised,
            n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
            take_log=False,
            use_dtw=True,
        )

        val_entry.diff_frames = mel_inferred_denoised.shape[1] - mel_orig.shape[1]
        val_entry.mcd_dtw = mcd_dtw
        val_entry.mcd_dtw_penalty = penalty_dtw
        val_entry.mcd_dtw_frames = final_frame_number_dtw

        mcd, penalty, final_frame_number = get_mcd_between_mel_spectograms(
            mel_1=mel_orig,
            mel_2=mel_inferred_denoised,
            n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
            take_log=False,
            use_dtw=False,
        )

        val_entry.mcd = mcd
        val_entry.mcd_penalty = penalty
        val_entry.mcd_frames = final_frame_number

        cosine_similarity = cosine_dist_mels(mel_orig, mel_inferred_denoised)
        val_entry.cosine_similarity = cosine_similarity

        mel_original_img_raw, mel_original_img = plot_melspec_np(mel_orig)
        mel_inferred_denoised_img_raw, mel_inferred_denoised_img = plot_melspec_np(
            mel_inferred_denoised)

        validation_entry_output.mel_orig_img = mel_original_img
        validation_entry_output.mel_inferred_denoised_img = mel_inferred_denoised_img

        mel_original_img_raw_same_dim, mel_inferred_denoised_img_raw_same_dim = make_same_width_by_filling_white(
            img_a=mel_original_img_raw,
            img_b=mel_inferred_denoised_img_raw,
        )

        mel_original_img_same_dim, mel_inferred_denoised_img_same_dim = make_same_width_by_filling_white(
            img_a=mel_original_img,
            img_b=mel_inferred_denoised_img,
        )

        structural_similarity_raw, mel_difference_denoised_img_raw = calculate_structual_similarity_np(
            img_a=mel_original_img_raw_same_dim,
            img_b=mel_inferred_denoised_img_raw_same_dim,
        )
        val_entry.structural_similarity = structural_similarity_raw

        structural_similarity, mel_denoised_diff_img = calculate_structual_similarity_np(
            img_a=mel_original_img_same_dim,
            img_b=mel_inferred_denoised_img_same_dim,
        )
        validation_entry_output.mel_denoised_diff_img = mel_denoised_diff_img

        imageio.imsave("/tmp/mel_original_img_raw.png", mel_original_img_raw)
        imageio.imsave("/tmp/mel_inferred_img_raw.png",
                       mel_inferred_denoised_img_raw)
        imageio.imsave("/tmp/mel_difference_denoised_img_raw.png",
                       mel_difference_denoised_img_raw)

        # logger.info(val_entry)
        logger.info(f"Current: {val_entry.identifier}")
        logger.info(f"MCD DTW: {val_entry.mcd_dtw}")
        logger.info(f"MCD DTW penalty: {val_entry.mcd_dtw_penalty}")
        logger.info(f"MCD DTW frames: {val_entry.mcd_dtw_frames}")

        logger.info(f"MCD: {val_entry.mcd}")
        logger.info(f"MCD penalty: {val_entry.mcd_penalty}")
        logger.info(f"MCD frames: {val_entry.mcd_frames}")

        # logger.info(f"MCD DTW V2: {val_entry.mcd_dtw_v2}")
        logger.info(
            f"Structural Similarity: {val_entry.structural_similarity}")
        logger.info(f"Cosine Similarity: {val_entry.cosine_similarity}")
        save_callback(validation_entry_output)
        inference_entries.append(val_entry)

    return inference_entries, (complete_wav_denoised,
                               synth.hparams.sampling_rate)
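
A hedged invocation sketch: the InferMelEntry fields (identifier, mel, sr) are inferred from the attribute accesses above, and the sigma/denoiser_strength values are common WaveGlow defaults rather than values taken from this source.

entries = [InferMelEntry(identifier="sent-1", mel=mel_np, sr=22050)]
inference_entries, (full_wav, sr) = infer(
    mel_entries=entries,
    checkpoint=waveglow_checkpoint,  # placeholder CheckpointWaveglow
    custom_hparams=None,
    denoiser_strength=0.0005,
    sigma=0.666,
    sentence_pause_s=0.2,
    save_callback=lambda out: None,  # persist each InferenceEntryOutput as needed
    concatenate=True,
    seed=1234,
    logger=getLogger(),
)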
def process_entry(entry: WavData, wav_dir: Path, mel_parser: TacotronSTFT,
                  save_callback: Callable[[WavData, Tensor], str]) -> MelData:
  absolute_wav_path = wav_dir / entry.wav_relative_path
  mel_tensor = mel_parser.get_mel_tensor_from_file(absolute_wav_path)
  path = save_callback(wav_entry=entry, mel_tensor=mel_tensor)
  mel_data = MelData(entry.entry_id, path, mel_parser.n_mel_channels)
  return mel_data
Example #9
def validate(checkpoint: CheckpointTacotron, data: PreparedDataList,
             trainset: PreparedDataList,
             custom_hparams: Optional[Dict[str, str]],
             entry_ids: Optional[Set[int]], speaker_name: Optional[str],
             train_name: str, full_run: bool,
             save_callback: Optional[Callable[[PreparedData, ValidationEntryOutput], None]],
             max_decoder_steps: int, fast: bool,
             mcd_no_of_coeffs_per_frame: int, repetitions: int,
             seed: Optional[int], select_best_from: Optional[pd.DataFrame],
             logger: Logger) -> ValidationEntries:
    model_symbols = checkpoint.get_symbols()
    model_accents = checkpoint.get_accents()
    model_speakers = checkpoint.get_speakers()

    seeds: List[int]
    validation_data: PreparedDataList

    if full_run:
        validation_data = data
        assert seed is not None
        seeds = [seed for _ in data]
    elif select_best_from is not None:
        assert entry_ids is not None
        logger.info("Finding best seeds...")
        validation_data = PreparedDataList(
            [x for x in data.items() if x.entry_id in entry_ids])
        if len(validation_data) != len(entry_ids):
            logger.error("Not all entry_id's were found!")
            assert False
        entry_ids_order_from_valdata = OrderedSet(
            [x.entry_id for x in validation_data.items()])
        seeds = get_best_seeds(select_best_from, entry_ids_order_from_valdata,
                               checkpoint.iteration, logger)
    #   entry_ids = [entry_id for entry_id, _ in entry_ids_w_seed]
    #   have_no_double_entries = len(set(entry_ids)) == len(entry_ids)
    #   assert have_no_double_entries
    #   validation_data = PreparedDataList([x for x in data.items() if x.entry_id in entry_ids])
    #   seeds = [s for _, s in entry_ids_w_seed]
    #   if len(validation_data) != len(entry_ids):
    #     logger.error("Not all entry_id's were found!")
    #     assert False
    elif entry_ids is not None:
        validation_data = PreparedDataList(
            [x for x in data.items() if x.entry_id in entry_ids])
        if len(validation_data) != len(entry_ids):
            logger.error("Not all entry_id's were found!")
            assert False
        assert seed is not None
        seeds = [seed for _ in validation_data]
    elif speaker_name is not None:
        speaker_id = model_speakers.get_id(speaker_name)
        relevant_entries = [
            x for x in data.items() if x.speaker_id == speaker_id
        ]
        assert len(relevant_entries) > 0
        assert seed is not None
        random.seed(seed)
        entry = random.choice(relevant_entries)
        validation_data = PreparedDataList([entry])
        seeds = [seed]
    else:
        assert seed is not None
        random.seed(seed)
        entry = random.choice(data)
        validation_data = PreparedDataList([entry])
        seeds = [seed]

    assert len(seeds) == len(validation_data)

    if len(validation_data) == 0:
        logger.info("Nothing to synthesize!")
        return ValidationEntries()

    train_onegram_rarities = get_ngram_rarity(validation_data, trainset,
                                              model_symbols, 1)
    train_twogram_rarities = get_ngram_rarity(validation_data, trainset,
                                              model_symbols, 2)
    train_threegram_rarities = get_ngram_rarity(validation_data, trainset,
                                                model_symbols, 3)

    synth = Synthesizer(
        checkpoint=checkpoint,
        custom_hparams=custom_hparams,
        logger=logger,
    )

    taco_stft = TacotronSTFT(synth.hparams, logger=logger)
    validation_entries = ValidationEntries()

    for repetition in range(repetitions):
        # rep_seed = seed + repetition
        rep_human_readable = repetition + 1
        logger.info(f"Starting repetition: {rep_human_readable}/{repetitions}")
        for entry, entry_seed in zip(validation_data.items(True), seeds):
            rep_seed = entry_seed + repetition
            logger.info(
                f"Current --> entry_id: {entry.entry_id}; seed: {rep_seed}; iteration: {checkpoint.iteration}; rep: {rep_human_readable}/{repetitions}"
            )

            infer_sent = InferSentence(
                sent_id=1,
                symbols=model_symbols.get_symbols(entry.serialized_symbol_ids),
                accents=model_accents.get_accents(entry.serialized_accent_ids),
                original_text=entry.text_original,
            )

            speaker_name = model_speakers.get_speaker(entry.speaker_id)
            inference_result = synth.infer(
                sentence=infer_sent,
                speaker=speaker_name,
                ignore_unknown_symbols=False,
                max_decoder_steps=max_decoder_steps,
                seed=rep_seed,
            )

            symbol_count = len(deserialize_list(entry.serialized_symbol_ids))
            unique_symbols = set(
                model_symbols.get_symbols(entry.serialized_symbol_ids))
            unique_symbols_str = " ".join(list(sorted(unique_symbols)))
            unique_symbols_count = len(unique_symbols)
            timepoint = f"{datetime.datetime.now():%Y/%m/%d %H:%M:%S}"

            mel_orig: np.ndarray = taco_stft.get_mel_tensor_from_file(
                entry.wav_path).cpu().numpy()

            target_frames = mel_orig.shape[1]
            predicted_frames = inference_result.mel_outputs_postnet.shape[1]
            diff_frames = predicted_frames - target_frames
            frame_deviation_percent = (predicted_frames / target_frames) - 1

            dtw_mcd, dtw_penalty, dtw_frames = get_mcd_between_mel_spectograms(
                mel_1=mel_orig,
                mel_2=inference_result.mel_outputs_postnet,
                n_mfcc=mcd_no_of_coeffs_per_frame,
                take_log=False,
                use_dtw=True,
            )

            padded_mel_orig, padded_mel_postnet = make_same_dim(
                mel_orig, inference_result.mel_outputs_postnet)

            aligned_mel_orig, aligned_mel_postnet, mel_dtw_dist, _, path_mel_postnet = align_mels_with_dtw(
                mel_orig, inference_result.mel_outputs_postnet)

            mel_aligned_length = len(aligned_mel_postnet)
            msd = get_msd(mel_dtw_dist, mel_aligned_length)

            padded_cosine_similarity = cosine_dist_mels(
                padded_mel_orig, padded_mel_postnet)
            aligned_cosine_similarity = cosine_dist_mels(
                aligned_mel_orig, aligned_mel_postnet)

            padded_mse = mean_squared_error(padded_mel_orig,
                                            padded_mel_postnet)
            aligned_mse = mean_squared_error(aligned_mel_orig,
                                             aligned_mel_postnet)

            padded_structural_similarity = None
            aligned_structural_similarity = None

            if not fast:
                padded_mel_orig_img_raw_1, padded_mel_orig_img = plot_melspec_np(
                    padded_mel_orig, title="padded_mel_orig")
                padded_mel_outputs_postnet_img_raw_1, padded_mel_outputs_postnet_img = plot_melspec_np(
                    padded_mel_postnet, title="padded_mel_postnet")

                # imageio.imsave("/tmp/padded_mel_orig_img_raw_1.png", padded_mel_orig_img_raw_1)
                # imageio.imsave("/tmp/padded_mel_outputs_postnet_img_raw_1.png", padded_mel_outputs_postnet_img_raw_1)

                padded_structural_similarity, padded_mel_postnet_diff_img_raw = calculate_structual_similarity_np(
                    img_a=padded_mel_orig_img_raw_1,
                    img_b=padded_mel_outputs_postnet_img_raw_1,
                )

                # imageio.imsave("/tmp/padded_mel_diff_img_raw_1.png", padded_mel_postnet_diff_img_raw)

                aligned_mel_orig_img_raw, aligned_mel_orig_img = plot_melspec_np(
                    aligned_mel_orig, title="aligned_mel_orig")
                aligned_mel_postnet_img_raw, aligned_mel_postnet_img = plot_melspec_np(
                    aligned_mel_postnet, title="aligned_mel_postnet")

                # imageio.imsave("/tmp/aligned_mel_orig_img_raw.png", aligned_mel_orig_img_raw)
                # imageio.imsave("/tmp/aligned_mel_postnet_img_raw.png", aligned_mel_postnet_img_raw)

                aligned_structural_similarity, aligned_mel_diff_img_raw = calculate_structual_similarity_np(
                    img_a=aligned_mel_orig_img_raw,
                    img_b=aligned_mel_postnet_img_raw,
                )

                # imageio.imsave("/tmp/aligned_mel_diff_img_raw.png", aligned_mel_diff_img_raw)

            train_combined_rarity = train_onegram_rarities[entry.entry_id] + \
                train_twogram_rarities[entry.entry_id] + train_threegram_rarities[entry.entry_id]

            val_entry = ValidationEntry(
                entry_id=entry.entry_id,
                repetition=rep_human_readable,
                repetitions=repetitions,
                seed=rep_seed,
                ds_entry_id=entry.ds_entry_id,
                text_original=entry.text_original,
                text=entry.text,
                wav_path=entry.wav_path,
                wav_duration_s=entry.duration_s,
                speaker_id=entry.speaker_id,
                speaker_name=speaker_name,
                iteration=checkpoint.iteration,
                unique_symbols=unique_symbols_str,
                unique_symbols_count=unique_symbols_count,
                symbol_count=symbol_count,
                timepoint=timepoint,
                train_name=train_name,
                sampling_rate=synth.get_sampling_rate(),
                reached_max_decoder_steps=inference_result.reached_max_decoder_steps,
                inference_duration_s=inference_result.inference_duration_s,
                predicted_frames=predicted_frames,
                target_frames=target_frames,
                diff_frames=diff_frames,
                frame_deviation_percent=frame_deviation_percent,
                padded_cosine_similarity=padded_cosine_similarity,
                mfcc_no_coeffs=mcd_no_of_coeffs_per_frame,
                mfcc_dtw_mcd=dtw_mcd,
                mfcc_dtw_penalty=dtw_penalty,
                mfcc_dtw_frames=dtw_frames,
                padded_structural_similarity=padded_structural_similarity,
                padded_mse=padded_mse,
                msd=msd,
                aligned_cosine_similarity=aligned_cosine_similarity,
                aligned_mse=aligned_mse,
                aligned_structural_similarity=aligned_structural_similarity,
                global_one_gram_rarity=entry.one_gram_rarity,
                global_two_gram_rarity=entry.two_gram_rarity,
                global_three_gram_rarity=entry.three_gram_rarity,
                global_combined_rarity=entry.combined_rarity,
                train_one_gram_rarity=train_onegram_rarities[entry.entry_id],
                train_two_gram_rarity=train_twogram_rarities[entry.entry_id],
                train_three_gram_rarity=train_threegram_rarities[
                    entry.entry_id],
                train_combined_rarity=train_combined_rarity,
            )

            validation_entries.append(val_entry)

            if not fast:
                orig_sr, orig_wav = read(entry.wav_path)

                _, padded_mel_postnet_diff_img = calculate_structual_similarity_np(
                    img_a=padded_mel_orig_img,
                    img_b=padded_mel_outputs_postnet_img,
                )

                _, aligned_mel_postnet_diff_img = calculate_structual_similarity_np(
                    img_a=aligned_mel_orig_img,
                    img_b=aligned_mel_postnet_img,
                )

                _, mel_orig_img = plot_melspec_np(mel_orig, title="mel_orig")
                # alignments_img = plot_alignment_np(inference_result.alignments)
                _, post_mel_img = plot_melspec_np(
                    inference_result.mel_outputs_postnet, title="mel_postnet")
                _, mel_img = plot_melspec_np(inference_result.mel_outputs,
                                             title="mel")
                _, alignments_img = plot_alignment_np_new(
                    inference_result.alignments, title="alignments")

                aligned_alignments = inference_result.alignments[
                    path_mel_postnet]
                _, aligned_alignments_img = plot_alignment_np_new(
                    aligned_alignments, title="aligned_alignments")
                # imageio.imsave("/tmp/alignments.png", alignments_img)

                validation_entry_output = ValidationEntryOutput(
                    repetition=rep_human_readable,
                    wav_orig=orig_wav,
                    mel_orig=mel_orig,
                    orig_sr=orig_sr,
                    mel_postnet=inference_result.mel_outputs_postnet,
                    mel_postnet_sr=inference_result.sampling_rate,
                    mel_orig_img=mel_orig_img,
                    mel_postnet_img=post_mel_img,
                    mel_postnet_diff_img=padded_mel_postnet_diff_img,
                    alignments_img=alignments_img,
                    mel_img=mel_img,
                    mel_postnet_aligned_diff_img=aligned_mel_postnet_diff_img,
                    mel_orig_aligned=aligned_mel_orig,
                    mel_orig_aligned_img=aligned_mel_orig_img,
                    mel_postnet_aligned=aligned_mel_postnet,
                    mel_postnet_aligned_img=aligned_mel_postnet_img,
                    alignments_aligned_img=aligned_alignments_img,
                )

                assert save_callback is not None

                save_callback(entry, validation_entry_output)

            logger.info(f"MFCC MCD DTW: {val_entry.mfcc_dtw_mcd}")

    return validation_entries
Example #10
def validate(checkpoint: CheckpointWaveglow, data: PreparedDataList,
             custom_hparams: Optional[Dict[str, str]],
             denoiser_strength: float, sigma: float,
             entry_ids: Optional[Set[int]], train_name: str, full_run: bool,
             save_callback: Callable[[PreparedData, ValidationEntryOutput], None],
             seed: int, logger: Logger):
  validation_entries = ValidationEntries()

  if full_run:
    entries = data
  else:
    speaker_id: Optional[int] = None
    entries = PreparedDataList(data.get_for_validation(entry_ids, speaker_id))

  if len(entries) == 0:
    logger.info("Nothing to synthesize!")
    return validation_entries

  synth = Synthesizer(
    checkpoint=checkpoint,
    custom_hparams=custom_hparams,
    logger=logger
  )

  taco_stft = TacotronSTFT(synth.hparams, logger=logger)

  for entry in entries.items(True):
    mel = taco_stft.get_mel_tensor_from_file(entry.wav_path)
    mel_var = torch.autograd.Variable(mel)
    mel_var = mel_var.cuda()
    mel_var = mel_var.unsqueeze(0)

    inference_result = synth.infer(mel_var, sigma, denoiser_strength, seed=seed)
    wav_inferred_denoised_normalized = normalize_wav(inference_result.wav_denoised)

    symbol_count = len(deserialize_list(entry.serialized_symbol_ids))
    unique_symbols_count = len(set(deserialize_list(entry.serialized_symbol_ids)))
    timepoint = f"{datetime.datetime.now():%Y/%m/%d %H:%M:%S}"

    val_entry = ValidationEntry(
      entry_id=entry.entry_id,
      ds_entry_id=entry.ds_entry_id,
      text_original=entry.text_original,
      text=entry.text,
      seed=seed,
      wav_path=entry.wav_path,
      original_duration_s=entry.duration_s,
      speaker_id=entry.speaker_id,
      iteration=checkpoint.iteration,
      unique_symbols_count=unique_symbols_count,
      symbol_count=symbol_count,
      timepoint=timepoint,
      train_name=train_name,
      sampling_rate=inference_result.sampling_rate,
      inference_duration_s=inference_result.inference_duration_s,
      was_overamplified=inference_result.was_overamplified,
      denoising_duration_s=inference_result.denoising_duration_s,
      inferred_duration_s=get_duration_s(
        inference_result.wav_denoised, inference_result.sampling_rate),
      denoiser_strength=denoiser_strength,
      sigma=sigma,
    )

    val_entry.diff_duration_s = val_entry.inferred_duration_s - val_entry.original_duration_s

    mel_orig = mel.cpu().numpy()

    wav_inferred_denoised_tensor = torch.FloatTensor(wav_inferred_denoised_normalized)
    mel_inferred_denoised = taco_stft.get_mel_tensor(wav_inferred_denoised_tensor)
    mel_inferred_denoised = mel_inferred_denoised.numpy()

    wav_orig, orig_sr = wav_to_float32(entry.wav_path)

    validation_entry_output = ValidationEntryOutput(
      mel_orig=mel_orig,
      inferred_sr=inference_result.sampling_rate,
      mel_inferred_denoised=mel_inferred_denoised,
      wav_inferred_denoised=wav_inferred_denoised_normalized,
      wav_orig=wav_orig,
      orig_sr=orig_sr,
      wav_inferred=normalize_wav(inference_result.wav),
      mel_denoised_diff_img=None,
      mel_inferred_denoised_img=None,
      mel_orig_img=None,
    )

    mcd_dtw, penalty_dtw, final_frame_number_dtw = get_mcd_between_mel_spectograms(
      mel_1=mel_orig,
      mel_2=mel_inferred_denoised,
      n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
      take_log=False,
      use_dtw=True,
    )

    val_entry.diff_frames = mel_inferred_denoised.shape[1] - mel_orig.shape[1]
    val_entry.mcd_dtw = mcd_dtw
    val_entry.mcd_dtw_penalty = penalty_dtw
    val_entry.mcd_dtw_frames = final_frame_number_dtw

    mcd, penalty, final_frame_number = get_mcd_between_mel_spectograms(
      mel_1=mel_orig,
      mel_2=mel_inferred_denoised,
      n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
      take_log=False,
      use_dtw=False,
    )

    val_entry.mcd = mcd
    val_entry.mcd_penalty = penalty
    val_entry.mcd_frames = final_frame_number

    cosine_similarity = cosine_dist_mels(mel_orig, mel_inferred_denoised)
    val_entry.cosine_similarity = cosine_similarity

    mel_original_img_raw, mel_original_img = plot_melspec_np(mel_orig)
    mel_inferred_denoised_img_raw, mel_inferred_denoised_img = plot_melspec_np(
      mel_inferred_denoised)

    validation_entry_output.mel_orig_img = mel_original_img
    validation_entry_output.mel_inferred_denoised_img = mel_inferred_denoised_img

    mel_original_img_raw_same_dim, mel_inferred_denoised_img_raw_same_dim = make_same_width_by_filling_white(
      img_a=mel_original_img_raw,
      img_b=mel_inferred_denoised_img_raw,
    )

    mel_original_img_same_dim, mel_inferred_denoised_img_same_dim = make_same_width_by_filling_white(
      img_a=mel_original_img,
      img_b=mel_inferred_denoised_img,
    )

    structural_similarity_raw, mel_difference_denoised_img_raw = calculate_structual_similarity_np(
        img_a=mel_original_img_raw_same_dim,
        img_b=mel_inferred_denoised_img_raw_same_dim,
    )
    val_entry.structural_similarity = structural_similarity_raw

    structural_similarity, mel_denoised_diff_img = calculate_structual_similarity_np(
        img_a=mel_original_img_same_dim,
        img_b=mel_inferred_denoised_img_same_dim,
    )
    validation_entry_output.mel_denoised_diff_img = mel_denoised_diff_img

    imageio.imsave("/tmp/mel_original_img_raw.png", mel_original_img_raw)
    imageio.imsave("/tmp/mel_inferred_img_raw.png", mel_inferred_denoised_img_raw)
    imageio.imsave("/tmp/mel_difference_denoised_img_raw.png", mel_difference_denoised_img_raw)

    # logger.info(val_entry)
    logger.info(f"Current: {val_entry.entry_id}")
    logger.info(f"MCD DTW: {val_entry.mcd_dtw}")
    logger.info(f"MCD DTW penalty: {val_entry.mcd_dtw_penalty}")
    logger.info(f"MCD DTW frames: {val_entry.mcd_dtw_frames}")

    logger.info(f"MCD: {val_entry.mcd}")
    logger.info(f"MCD penalty: {val_entry.mcd_penalty}")
    logger.info(f"MCD frames: {val_entry.mcd_frames}")

    # logger.info(f"MCD DTW V2: {val_entry.mcd_dtw_v2}")
    logger.info(f"Structural Similarity: {val_entry.structural_similarity}")
    logger.info(f"Cosine Similarity: {val_entry.cosine_similarity}")
    save_callback(entry, validation_entry_output)
    validation_entries.append(val_entry)
    #score, diff_img = compare_mels(a, b)

  return validation_entries
Example #11
class SymbolsMelLoader(Dataset):
    """
    1) loads audio,text pairs
    2) normalizes text and converts them to sequences of one-hot vectors
    3) computes mel-spectrograms from audio files.
  """
    def __init__(self, data: PreparedDataList, hparams: HParams,
                 logger: Logger):
        # random.seed(hparams.seed)
        # random.shuffle(data)
        self.use_saved_mels = hparams.use_saved_mels
        if not hparams.use_saved_mels:
            self.mel_parser = TacotronSTFT(hparams, logger)

        logger.info("Reading files...")
        self.data: Dict[int, Tuple[IntTensor, IntTensor, str, int]] = {}
        for i, values in enumerate(data.items(True)):
            symbol_ids = deserialize_list(values.serialized_symbol_ids)
            accent_ids = deserialize_list(values.serialized_accent_ids)

            model_symbol_ids = get_accent_symbol_ids(
                symbol_ids, accent_ids, hparams.n_symbols,
                hparams.accents_use_own_symbols, SHARED_SYMBOLS_COUNT)

            symbols_tensor = IntTensor(model_symbol_ids)
            accents_tensor = IntTensor(accent_ids)

            if hparams.use_saved_mels:
                self.data[i] = (symbols_tensor, accents_tensor,
                                values.mel_path, values.speaker_id)
            else:
                self.data[i] = (symbols_tensor, accents_tensor,
                                values.wav_path, values.speaker_id)

        if hparams.use_saved_mels and hparams.cache_mels:
            logger.info("Loading mels into memory...")
            self.cache: Dict[int, Tensor] = {}
            vals: tuple
            for i, vals in tqdm(self.data.items()):
                mel_tensor = torch.load(vals[1], map_location='cpu')
                self.cache[i] = mel_tensor
        self.use_cache: bool = hparams.cache_mels

    def __getitem__(self,
                    index: int) -> Tuple[IntTensor, IntTensor, Tensor, int]:
        # return self.cache[index]
        # debug_logger.debug(f"getitem called {index}")
        symbols_tensor, accents_tensor, path, speaker_id = self.data[index]
        if self.use_saved_mels:
            if self.use_cache:
                mel_tensor = self.cache[index].clone().detach()
            else:
                mel_tensor: Tensor = torch.load(path, map_location='cpu')
        else:
            mel_tensor = self.mel_parser.get_mel_tensor_from_file(path)

        symbols_tensor_cloned = symbols_tensor.clone().detach()
        accents_tensor_cloned = accents_tensor.clone().detach()
        # debug_logger.debug(f"getitem finished {index}")
        return symbols_tensor_cloned, accents_tensor_cloned, mel_tensor, speaker_id

    def __len__(self):
        return len(self.data)
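
Unlike MelLoader, __getitem__ here returns variable-length symbol, accent, and mel tensors, so the default DataLoader collation fails for batch sizes above one. A minimal padding collate is sketched below as an assumption; the original repository presumably ships its own, which would typically also sort by length and build gate targets.

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def pad_collate(batch):
    symbols, accents, mels, speaker_ids = zip(*batch)
    return (
        pad_sequence(symbols, batch_first=True),              # (B, max_symbols)
        pad_sequence(accents, batch_first=True),              # (B, max_symbols)
        pad_sequence([m.T for m in mels], batch_first=True),  # (B, max_frames, n_mels)
        torch.tensor(speaker_ids),
    )

loader = DataLoader(SymbolsMelLoader(prepared_data, hparams, logger),
                    batch_size=8, collate_fn=pad_collate)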