Пример #1
0
class MelLoader(Dataset):
    """Dataset yielding ``(mel_tensor, wav_tensor)`` pairs.

    Each item is a segment of a wav file together with the mel spectrogram
    computed from that segment. Wav tensors are either read from disk on
    every access or, when ``hparams.cache_wavs`` is set, loaded once during
    construction and served from memory.
    """

    def __init__(self, prepare_ds_data: PreparedDataList, hparams: HParams,
                 logger: Logger):
        """Index the wav paths of ``prepare_ds_data`` and optionally cache them.

        Side effects: seeds the module-level ``random`` generator with
        ``hparams.seed`` and shuffles ``prepare_ds_data`` in place.
        """
        self.taco_stft = TacotronSTFT(hparams, logger=logger)
        self.hparams = hparams
        self._logger = logger

        # The caller's list object itself is reordered here (no copy is made).
        random.seed(hparams.seed)
        random.shuffle(prepare_ds_data)

        # Map the position in the shuffled data to the wav path of the entry.
        self.wav_paths = {
            position: entry.wav_path
            for position, entry in enumerate(prepare_ds_data.items())
        }

        if hparams.cache_wavs:
            self._logger.info("Loading wavs into memory...")
            loaded_wavs = {
                position: self.taco_stft.get_wav_tensor_from_file(path)
                for position, path in tqdm(self.wav_paths.items())
            }
            self._logger.info("Done")
            self.cache = loaded_wavs

    def __getitem__(self, index):
        """Return the ``(mel_tensor, wav_tensor)`` pair stored at ``index``."""
        if not self.hparams.cache_wavs:
            full_wav = self.taco_stft.get_wav_tensor_from_file(
                self.wav_paths[index])
        else:
            # Work on a detached copy of the cached tensor.
            full_wav = self.cache[index].clone().detach()

        segment = get_wav_tensor_segment(full_wav,
                                         self.hparams.segment_length)
        return (self.taco_stft.get_mel_tensor(segment), segment)

    def __len__(self):
        """Return the number of indexed wav files."""
        return len(self.wav_paths)
Пример #2
0
def infer(
    mel_entries: List[InferMelEntry], checkpoint: CheckpointWaveglow,
    custom_hparams: Optional[Dict[str, str]], denoiser_strength: float,
    sigma: float, sentence_pause_s: float,
    save_callback: Callable[[InferenceEntryOutput],
                            None], concatenate: bool, seed: int, logger: Logger
) -> Tuple[InferenceEntries, Tuple[Optional[np.ndarray], int]]:
    """Synthesize audio for the given mel spectrograms and collect metrics.

    Every mel is moved to the GPU (``.cuda()``), passed through the
    synthesizer and the denoised result is compared against the input mel
    (MCD with and without DTW, cosine distance helper, structural similarity
    of the plotted spectrograms).

    Args:
        mel_entries: Entries providing ``mel``, ``sr`` and ``identifier``.
            All entries must have the sampling rate of the loaded model.
        checkpoint: Checkpoint handed to ``Synthesizer``; its ``iteration``
            is recorded in every result entry.
        custom_hparams: Optional overrides handed to ``Synthesizer``.
        denoiser_strength: Forwarded to the synthesizer and recorded.
        sigma: Forwarded to the synthesizer and recorded.
        sentence_pause_s: Pause passed to ``concatenate_audios`` when
            ``concatenate`` is set.
        save_callback: Called once per entry with its
            ``InferenceEntryOutput``.
        concatenate: If true, additionally build one normalized wav from
            all denoised outputs.
        seed: Forwarded to the synthesizer.
        logger: Receives progress and metric messages.

    Returns:
        ``(entries, (complete_wav, sampling_rate))``. ``complete_wav`` is
        ``None`` unless ``concatenate`` is set. For empty input the result
        is ``(entries, (None, 0))`` because no model is loaded and hence no
        sampling rate is known.

    Raises:
        AssertionError: If an entry's sampling rate differs from the one
            the model was trained with.
    """
    inference_entries = InferenceEntries()

    if len(mel_entries) == 0:
        logger.info("Nothing to synthesize!")
        # Return the same shape as the regular path so callers that unpack
        # the result do not fail on empty input.
        return inference_entries, (None, 0)

    synth = Synthesizer(checkpoint=checkpoint,
                        custom_hparams=custom_hparams,
                        logger=logger)
    sampling_rate = synth.hparams.sampling_rate

    # Check mels have the same sampling rate as trained waveglow model
    for mel_entry in mel_entries:
        assert mel_entry.sr == sampling_rate, (
            f"Sampling rate mismatch: mel has {mel_entry.sr}, "
            f"model expects {sampling_rate}")

    taco_stft = TacotronSTFT(synth.hparams, logger=logger)

    # Add a leading batch dimension and move every mel to the GPU.
    mels_torch_prepared = []
    for mel_entry in mel_entries:
        mel_var = torch.autograd.Variable(mel_to_torch(mel_entry.mel))
        mel_var = mel_var.cuda()
        mels_torch_prepared.append(mel_var.unsqueeze(0))

    inference_results = synth.infer_all(mels_torch_prepared,
                                        sigma,
                                        denoiser_strength,
                                        seed=seed)

    complete_wav_denoised: Optional[np.ndarray] = None
    if concatenate:
        has_results = len(inference_results) >= 1
        if has_results:
            logger.info("Concatening audios...")
        complete_wav_denoised = concatenate_audios(
            [x.wav_denoised for x in inference_results], sentence_pause_s,
            sampling_rate)
        complete_wav_denoised = normalize_wav(complete_wav_denoised)
        if has_results:
            logger.info("Done.")

    inference_result: InferenceResult
    mel_entry: InferMelEntry
    # zip() has no length, so tell tqdm how many items to expect.
    for mel_entry, inference_result in tqdm(zip(mel_entries,
                                                inference_results),
                                            total=len(mel_entries)):
        wav_inferred_denoised_normalized = normalize_wav(
            inference_result.wav_denoised)
        timepoint = f"{datetime.datetime.now():%Y/%m/%d %H:%M:%S}"

        val_entry = InferenceEntry(
            identifier=mel_entry.identifier,
            iteration=checkpoint.iteration,
            timepoint=timepoint,
            sampling_rate=inference_result.sampling_rate,
            inference_duration_s=inference_result.inference_duration_s,
            was_overamplified=inference_result.was_overamplified,
            denoising_duration_s=inference_result.denoising_duration_s,
            inferred_duration_s=get_duration_s(inference_result.wav_denoised,
                                               inference_result.sampling_rate),
            denoiser_strength=denoiser_strength,
            sigma=sigma,
        )

        mel_orig = mel_entry.mel

        # Recompute a mel from the synthesized audio for the comparison.
        wav_inferred_denoised_normalized_tensor = torch.FloatTensor(
            wav_inferred_denoised_normalized)
        mel_inferred_denoised = taco_stft.get_mel_tensor(
            wav_inferred_denoised_normalized_tensor)
        mel_inferred_denoised = mel_inferred_denoised.numpy()

        validation_entry_output = InferenceEntryOutput(
            identifier=mel_entry.identifier,
            mel_orig=mel_orig,
            inferred_sr=inference_result.sampling_rate,
            mel_inferred_denoised=mel_inferred_denoised,
            wav_inferred_denoised=wav_inferred_denoised_normalized,
            orig_sr=mel_entry.sr,
            wav_inferred=normalize_wav(inference_result.wav),
            mel_denoised_diff_img=None,
            mel_inferred_denoised_img=None,
            mel_orig_img=None,
        )

        mcd_dtw, penalty_dtw, final_frame_number_dtw = get_mcd_between_mel_spectograms(
            mel_1=mel_orig,
            mel_2=mel_inferred_denoised,
            n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
            take_log=False,
            use_dtw=True,
        )

        val_entry.diff_frames = mel_inferred_denoised.shape[
            1] - mel_orig.shape[1]
        val_entry.mcd_dtw = mcd_dtw
        val_entry.mcd_dtw_penalty = penalty_dtw
        val_entry.mcd_dtw_frames = final_frame_number_dtw

        mcd, penalty, final_frame_number = get_mcd_between_mel_spectograms(
            mel_1=mel_orig,
            mel_2=mel_inferred_denoised,
            n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
            take_log=False,
            use_dtw=False,
        )

        val_entry.mcd = mcd
        val_entry.mcd_penalty = penalty
        val_entry.mcd_frames = final_frame_number

        val_entry.cosine_similarity = cosine_dist_mels(mel_orig,
                                                       mel_inferred_denoised)

        mel_original_img_raw, mel_original_img = plot_melspec_np(mel_orig)
        mel_inferred_denoised_img_raw, mel_inferred_denoised_img = plot_melspec_np(
            mel_inferred_denoised)

        validation_entry_output.mel_orig_img = mel_original_img
        validation_entry_output.mel_inferred_denoised_img = mel_inferred_denoised_img

        # Both images need the same width before they can be compared.
        mel_original_img_raw_same_dim, mel_inferred_denoised_img_raw_same_dim = make_same_width_by_filling_white(
            img_a=mel_original_img_raw,
            img_b=mel_inferred_denoised_img_raw,
        )

        mel_original_img_same_dim, mel_inferred_denoised_img_same_dim = make_same_width_by_filling_white(
            img_a=mel_original_img,
            img_b=mel_inferred_denoised_img,
        )

        # The reported score comes from the raw images ...
        structural_similarity_raw, mel_difference_denoised_img_raw = calculate_structual_similarity_np(
            img_a=mel_original_img_raw_same_dim,
            img_b=mel_inferred_denoised_img_raw_same_dim,
        )
        val_entry.structural_similarity = structural_similarity_raw

        # ... whereas only the difference image of the second pair is kept.
        _, mel_denoised_diff_img = calculate_structual_similarity_np(
            img_a=mel_original_img_same_dim,
            img_b=mel_inferred_denoised_img_same_dim,
        )
        validation_entry_output.mel_denoised_diff_img = mel_denoised_diff_img

        # NOTE(review): fixed paths, overwritten on every iteration; looks
        # like debugging output -- confirm whether it is still needed.
        imageio.imsave("/tmp/mel_original_img_raw.png", mel_original_img_raw)
        imageio.imsave("/tmp/mel_inferred_img_raw.png",
                       mel_inferred_denoised_img_raw)
        imageio.imsave("/tmp/mel_difference_denoised_img_raw.png",
                       mel_difference_denoised_img_raw)

        logger.info(f"Current: {val_entry.identifier}")
        logger.info(f"MCD DTW: {val_entry.mcd_dtw}")
        logger.info(f"MCD DTW penalty: {val_entry.mcd_dtw_penalty}")
        logger.info(f"MCD DTW frames: {val_entry.mcd_dtw_frames}")

        logger.info(f"MCD: {val_entry.mcd}")
        logger.info(f"MCD penalty: {val_entry.mcd_penalty}")
        logger.info(f"MCD frames: {val_entry.mcd_frames}")

        logger.info(
            f"Structural Similarity: {val_entry.structural_similarity}")
        logger.info(f"Cosine Similarity: {val_entry.cosine_similarity}")
        save_callback(validation_entry_output)
        inference_entries.append(val_entry)

    return inference_entries, (complete_wav_denoised, sampling_rate)
Пример #3
0
def validate(
  checkpoint: CheckpointWaveglow,
  data: PreparedDataList,
  custom_hparams: Optional[Dict[str, str]],
  denoiser_strength: float,
  sigma: float,
  entry_ids: Optional[Set[int]],
  train_name: str,
  full_run: bool,
  save_callback: Callable[[PreparedData, ValidationEntryOutput], None],
  seed: int,
  logger: Logger,
):
  """Re-synthesize dataset entries from their own mels and collect metrics.

  For every selected entry the mel spectrogram is computed from the entry's
  wav file, moved to the GPU (``.cuda()``), synthesized back to audio and the
  denoised result is compared against the original mel (MCD with and without
  DTW, cosine distance helper, structural similarity of the plotted
  spectrograms).

  Args:
    checkpoint: Checkpoint handed to ``Synthesizer``; its ``iteration`` is
      recorded in every result entry.
    data: Entries to choose from.
    custom_hparams: Optional overrides handed to ``Synthesizer``.
    denoiser_strength: Forwarded to the synthesizer and recorded.
    sigma: Forwarded to the synthesizer and recorded.
    entry_ids: Passed to ``data.get_for_validation`` when ``full_run`` is
      false.
    train_name: Recorded in every result entry.
    full_run: If true, all of ``data`` is processed; otherwise only what
      ``data.get_for_validation`` returns.
    save_callback: Called once per entry with the entry and its
      ``ValidationEntryOutput``.
    seed: Forwarded to the synthesizer and recorded.
    logger: Receives progress and metric messages.

  Returns:
    The ``ValidationEntries`` collected, empty if nothing was selected.
  """
  validation_entries = ValidationEntries()

  if full_run:
    entries = data
  else:
    speaker_id: Optional[int] = None
    entries = PreparedDataList(data.get_for_validation(entry_ids, speaker_id))

  if len(entries) == 0:
    logger.info("Nothing to synthesize!")
    return validation_entries

  synth = Synthesizer(
    checkpoint=checkpoint,
    custom_hparams=custom_hparams,
    logger=logger
  )

  taco_stft = TacotronSTFT(synth.hparams, logger=logger)

  for entry in entries.items(True):
    mel = taco_stft.get_mel_tensor_from_file(entry.wav_path)
    # Add a leading batch dimension and move the mel to the GPU.
    mel_var = torch.autograd.Variable(mel)
    mel_var = mel_var.cuda()
    mel_var = mel_var.unsqueeze(0)

    inference_result = synth.infer(mel_var, sigma, denoiser_strength, seed=seed)
    wav_inferred_denoised_normalized = normalize_wav(inference_result.wav_denoised)

    # Deserialize once and derive both counts from the same list.
    symbol_ids = deserialize_list(entry.serialized_symbol_ids)
    symbol_count = len(symbol_ids)
    unique_symbols_count = len(set(symbol_ids))
    timepoint = f"{datetime.datetime.now():%Y/%m/%d %H:%M:%S}"

    val_entry = ValidationEntry(
      entry_id=entry.entry_id,
      ds_entry_id=entry.ds_entry_id,
      text_original=entry.text_original,
      text=entry.text,
      seed=seed,
      wav_path=entry.wav_path,
      original_duration_s=entry.duration_s,
      speaker_id=entry.speaker_id,
      iteration=checkpoint.iteration,
      unique_symbols_count=unique_symbols_count,
      symbol_count=symbol_count,
      timepoint=timepoint,
      train_name=train_name,
      sampling_rate=inference_result.sampling_rate,
      inference_duration_s=inference_result.inference_duration_s,
      was_overamplified=inference_result.was_overamplified,
      denoising_duration_s=inference_result.denoising_duration_s,
      inferred_duration_s=get_duration_s(
        inference_result.wav_denoised, inference_result.sampling_rate),
      denoiser_strength=denoiser_strength,
      sigma=sigma,
    )

    val_entry.diff_duration_s = val_entry.inferred_duration_s - val_entry.original_duration_s

    mel_orig = mel.cpu().numpy()

    # Recompute a mel from the synthesized audio for the comparison.
    wav_inferred_denoised_tensor = torch.FloatTensor(wav_inferred_denoised_normalized)
    mel_inferred_denoised = taco_stft.get_mel_tensor(wav_inferred_denoised_tensor)
    mel_inferred_denoised = mel_inferred_denoised.numpy()

    wav_orig, orig_sr = wav_to_float32(entry.wav_path)

    validation_entry_output = ValidationEntryOutput(
      mel_orig=mel_orig,
      inferred_sr=inference_result.sampling_rate,
      mel_inferred_denoised=mel_inferred_denoised,
      wav_inferred_denoised=wav_inferred_denoised_normalized,
      wav_orig=wav_orig,
      orig_sr=orig_sr,
      wav_inferred=normalize_wav(inference_result.wav),
      mel_denoised_diff_img=None,
      mel_inferred_denoised_img=None,
      mel_orig_img=None,
    )

    mcd_dtw, penalty_dtw, final_frame_number_dtw = get_mcd_between_mel_spectograms(
      mel_1=mel_orig,
      mel_2=mel_inferred_denoised,
      n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
      take_log=False,
      use_dtw=True,
    )

    val_entry.diff_frames = mel_inferred_denoised.shape[1] - mel_orig.shape[1]
    val_entry.mcd_dtw = mcd_dtw
    val_entry.mcd_dtw_penalty = penalty_dtw
    val_entry.mcd_dtw_frames = final_frame_number_dtw

    mcd, penalty, final_frame_number = get_mcd_between_mel_spectograms(
      mel_1=mel_orig,
      mel_2=mel_inferred_denoised,
      n_mfcc=MCD_NO_OF_COEFFS_PER_FRAME,
      take_log=False,
      use_dtw=False,
    )

    val_entry.mcd = mcd
    val_entry.mcd_penalty = penalty
    val_entry.mcd_frames = final_frame_number

    val_entry.cosine_similarity = cosine_dist_mels(mel_orig, mel_inferred_denoised)

    mel_original_img_raw, mel_original_img = plot_melspec_np(mel_orig)
    mel_inferred_denoised_img_raw, mel_inferred_denoised_img = plot_melspec_np(
      mel_inferred_denoised)

    validation_entry_output.mel_orig_img = mel_original_img
    validation_entry_output.mel_inferred_denoised_img = mel_inferred_denoised_img

    # Both images need the same width before they can be compared.
    mel_original_img_raw_same_dim, mel_inferred_denoised_img_raw_same_dim = make_same_width_by_filling_white(
      img_a=mel_original_img_raw,
      img_b=mel_inferred_denoised_img_raw,
    )

    mel_original_img_same_dim, mel_inferred_denoised_img_same_dim = make_same_width_by_filling_white(
      img_a=mel_original_img,
      img_b=mel_inferred_denoised_img,
    )

    # The reported score comes from the raw images ...
    structural_similarity_raw, mel_difference_denoised_img_raw = calculate_structual_similarity_np(
      img_a=mel_original_img_raw_same_dim,
      img_b=mel_inferred_denoised_img_raw_same_dim,
    )
    val_entry.structural_similarity = structural_similarity_raw

    # ... whereas only the difference image of the second pair is kept.
    _, mel_denoised_diff_img = calculate_structual_similarity_np(
      img_a=mel_original_img_same_dim,
      img_b=mel_inferred_denoised_img_same_dim,
    )
    validation_entry_output.mel_denoised_diff_img = mel_denoised_diff_img

    # NOTE(review): fixed paths, overwritten on every iteration; looks like
    # debugging output -- confirm whether it is still needed.
    imageio.imsave("/tmp/mel_original_img_raw.png", mel_original_img_raw)
    imageio.imsave("/tmp/mel_inferred_img_raw.png", mel_inferred_denoised_img_raw)
    imageio.imsave("/tmp/mel_difference_denoised_img_raw.png", mel_difference_denoised_img_raw)

    logger.info(f"Current: {val_entry.entry_id}")
    logger.info(f"MCD DTW: {val_entry.mcd_dtw}")
    logger.info(f"MCD DTW penalty: {val_entry.mcd_dtw_penalty}")
    logger.info(f"MCD DTW frames: {val_entry.mcd_dtw_frames}")

    logger.info(f"MCD: {val_entry.mcd}")
    logger.info(f"MCD penalty: {val_entry.mcd_penalty}")
    logger.info(f"MCD frames: {val_entry.mcd_frames}")

    logger.info(f"Structural Similarity: {val_entry.structural_similarity}")
    logger.info(f"Cosine Similarity: {val_entry.cosine_similarity}")
    save_callback(entry, validation_entry_output)
    validation_entries.append(val_entry)

  return validation_entries