def process_utterance(mel_dir: str, linear_dir: str, wav_dir: str,
                      basename: str, wav_path: str, hp: hparams):
    """
  Preprocesses a single utterance wav/text pair

  this writes the mel scale spectogram to disk and return a tuple to write
  to the train.txt file

  Args:
    - mel_dir: the directory to write the mel spectograms into
    - linear_dir: the directory to write the linear spectrograms into
    - wav_dir: the directory to write the preprocessed wav into
    - basename: the numeric index to use in the spectogram filename
    - wav_path: path to the audio file containing the speech input
    - text: text spoken in the input audio file
    - hp: hyper parameters

  Returns:
    - A tuple: (audio_filename, mel_filename, linear_filename, time_steps, mel_frames, linear_frames, text)
  """
    try:
        # Load the audio as numpy array
        wav = audio.load_wav(wav_path, sr=hp.sample_rate)
    except FileNotFoundError:  #catch missing wav exception
        print(
            'file {} present in csv metadata is not present in wav folder. skipping!'
            .format(wav_path))
        return None

    #Trim lead/trail silences
    if hp.trim_silence:
        wav = audio.trim_silence(wav, hp)

    #Pre-emphasize
    preem_wav = audio.preemphasis(wav, hp.preemphasis, hp.preemphasize)

    #rescale wav
    if hp.rescale:
        wav = wav / np.abs(wav).max() * hp.rescaling_max
        preem_wav = preem_wav / np.abs(preem_wav).max() * hp.rescaling_max

        #Assert all audio is in [-1, 1]
        if (wav > 1.).any() or (wav < -1.).any():
            raise RuntimeError('wav has invalid value: {}'.format(wav_path))
        if (preem_wav > 1.).any() or (preem_wav < -1.).any():
            raise RuntimeError('wav has invalid value: {}'.format(wav_path))

    #Mu-law quantize
    if is_mulaw_quantize(hp.input_type):
        #[0, quantize_channels)
        out = mulaw_quantize(wav, hp.quantize_channels)

        #Trim silences
        start, end = audio.start_and_end_indices(out, hp.silence_threshold)
        wav = wav[start:end]
        preem_wav = preem_wav[start:end]
        out = out[start:end]

        constant_values = mulaw_quantize(0, hp.quantize_channels)
        out_dtype = np.int16

    elif is_mulaw(hp.input_type):
        #[-1, 1]
        out = mulaw(wav, hp.quantize_channels)
        constant_values = mulaw(0., hp.quantize_channels)
        out_dtype = np.float32

    else:
        #[-1, 1]
        out = wav
        constant_values = 0.
        out_dtype = np.float32

    # Compute the mel scale spectrogram from the wav
    mel_spectrogram = audio.melspectrogram(preem_wav, hp).astype(np.float32)
    mel_frames = mel_spectrogram.shape[1]

    if mel_frames > hp.max_mel_frames and hp.clip_mels_length:
        return None

    #Compute the linear scale spectrogram from the wav
    linear_spectrogram = audio.linearspectrogram(preem_wav,
                                                 hp).astype(np.float32)
    linear_frames = linear_spectrogram.shape[1]

    #sanity check
    assert linear_frames == mel_frames

    if hp.use_lws:
        #Ensure time resolution adjustment between audio and mel-spectrogram
        fft_size = hp.n_fft if hp.win_size is None else hp.win_size
        l, r = audio.pad_lr(wav, fft_size, audio.get_hop_size(hp))

        #Zero pad audio signal
        out = np.pad(out, (l, r),
                     mode='constant',
                     constant_values=constant_values)
    else:
        #Ensure time resolution adjustment between audio and mel-spectrogram
        l_pad, r_pad = audio.librosa_pad_lr(wav, hp.n_fft,
                                            audio.get_hop_size(hp),
                                            hp.wavenet_pad_sides)

        #Constant pad audio signal on both sides (matching librosa's framing to avoid frame inconsistency)
        out = np.pad(out, (l_pad, r_pad),
                     mode='constant',
                     constant_values=constant_values)

    assert len(out) >= mel_frames * audio.get_hop_size(hp)

    #time resolution adjustement
    #ensure length of raw audio is multiple of hop size so that we can use
    #transposed convolution to upsample
    out = out[:mel_frames * audio.get_hop_size(hp)]
    assert len(out) % audio.get_hop_size(hp) == 0
    time_steps = len(out)
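    # Illustrative arithmetic (assumption, not repo code): with a typical
    # hop_size of 256 and, say, mel_frames = 400, out is trimmed to
    # 400 * 256 = 102400 samples. Since len(out) is then an exact multiple of
    # the hop size, an upsampling stack of transposed convolutions whose
    # strides multiply to hop_size maps each mel frame onto exactly hop_size
    # samples with no remainder.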

    # Write the spectrogram and audio to disk (the three files share the same
    # basename and are distinguished by their output directories)
    audio_filename = '{}.npy'.format(basename)
    mel_filename = '{}.npy'.format(basename)
    linear_filename = '{}.npy'.format(basename)
    np.save(os.path.join(wav_dir, audio_filename),
            out.astype(out_dtype),
            allow_pickle=False)
    np.save(os.path.join(mel_dir, mel_filename),
            mel_spectrogram.T,
            allow_pickle=False)
    np.save(os.path.join(linear_dir, linear_filename),
            linear_spectrogram.T,
            allow_pickle=False)

    # Return a tuple describing this training example: the shared basename of
    # the saved files, the audio length in samples and the number of mel frames
    return (basename, time_steps, mel_frames)
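

# Hedged usage sketch (not part of the original module): one plausible way a
# preprocessing driver could call process_utterance() over a metadata file and
# collect the returned tuples for train.txt. The function name build_from_path,
# the metadata.csv layout and the directory names are illustrative assumptions.
def build_from_path(input_dir, mel_dir, linear_dir, wav_dir, hp):
    metadata = []
    with open(os.path.join(input_dir, 'metadata.csv'), encoding='utf-8') as f:
        for line in f:
            basename = line.strip().split('|')[0]
            wav_path = os.path.join(input_dir, 'wavs', '{}.wav'.format(basename))
            result = process_utterance(mel_dir, linear_dir, wav_dir, basename,
                                       wav_path, hp)
            # None means the wav was missing or the utterance was clipped out
            if result is not None:
                metadata.append(result)

    # Each entry is (basename, time_steps, mel_frames)
    with open(os.path.join(input_dir, 'train.txt'), 'w', encoding='utf-8') as f:
        for m in metadata:
            f.write('|'.join(str(x) for x in m) + '\n')
    return metadata
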
def _process_utterance(mel_dir, wav_dir, index, wav_path, hparams):
    """
	Preprocesses a single utterance wav/text pair

	this writes the mel scale spectogram to disk and return a tuple to write
	to the train.txt file

	Args:
		- mel_dir: the directory to write the mel spectograms into
		- linear_dir: the directory to write the linear spectrograms into
		- wav_dir: the directory to write the preprocessed wav into
		- index: the numeric index to use in the spectrogram filename
		- wav_path: path to the audio file containing the speech input
		- text: text spoken in the input audio file
		- hparams: hyper parameters

	Returns:
		- A tuple: (audio_filename, mel_filename, linear_filename, time_steps, mel_frames, linear_frames, text)
	"""
    try:
        # Load the audio as numpy array
        wav = audio.load_wav(wav_path, sr=hparams.sample_rate)
    except FileNotFoundError:  #catch missing wav exception
        print(
            'file {} present in csv metadata is not present in wav folder. skipping!'
            .format(wav_path))
        return None

    #Trim lead/trail silences (M-AILABS recordings contain extra silence)
    if hparams.trim_silence:
        wav = audio.trim_silence(wav, hparams)

    #Pre-emphasize
    preem_wav = audio.preemphasis(wav, hparams.preemphasis,
                                  hparams.preemphasize)

    #rescale wav
    if hparams.rescale:
        wav = wav / np.abs(wav).max() * hparams.rescaling_max
        preem_wav = preem_wav / np.abs(preem_wav).max() * hparams.rescaling_max

        #Assert all audio is in [-1, 1]
        if (wav > 1.).any() or (wav < -1.).any():
            raise RuntimeError('wav has invalid value: {}'.format(wav_path))
        if (preem_wav > 1.).any() or (preem_wav < -1.).any():
            raise RuntimeError('wav has invalid value: {}'.format(wav_path))

    #Mu-law quantize
    if is_mulaw_quantize(hparams.input_type):
        #[0, quantize_channels)
        out = mulaw_quantize(wav, hparams.quantize_channels)

        #Trim silences
        start, end = audio.start_and_end_indices(out,
                                                 hparams.silence_threshold)
        wav = wav[start:end]
        preem_wav = preem_wav[start:end]
        out = out[start:end]

        constant_values = mulaw_quantize(0, hparams.quantize_channels)
        out_dtype = np.int16

    elif is_mulaw(hparams.input_type):
        #[-1, 1]
        out = mulaw(wav, hparams.quantize_channels)
        constant_values = mulaw(0., hparams.quantize_channels)
        out_dtype = np.float32

    else:
        #[-1, 1]
        out = wav
        constant_values = 0.
        out_dtype = np.float32

    # Compute the mel scale spectrogram from the wav
    mel_spectrogram = audio.melspectrogram(preem_wav,
                                           hparams).astype(np.float32)
    mel_frames = mel_spectrogram.shape[1]

    if mel_frames > hparams.max_mel_frames and hparams.clip_mels_length:
        return None

    if hparams.use_lws:
        #Ensure time resolution adjustment between audio and mel-spectrogram
        fft_size = hparams.n_fft if hparams.win_size is None else hparams.win_size
        l, r = audio.pad_lr(wav, fft_size, audio.get_hop_size(hparams))

        #Zero pad audio signal
        out = np.pad(out, (l, r),
                     mode='constant',
                     constant_values=constant_values)
    else:
        #Ensure time resolution adjustment between audio and mel-spectrogram
        l_pad, r_pad = audio.librosa_pad_lr(wav, hparams.n_fft,
                                            audio.get_hop_size(hparams))

        #Constant pad audio signal on both sides (matching librosa's framing to avoid frame inconsistency)
        out = np.pad(out, (l_pad, r_pad),
                     mode='constant',
                     constant_values=constant_values)

    assert len(out) >= mel_frames * audio.get_hop_size(hparams)

    #time resolution adjustement
    #ensure length of raw audio is multiple of hop size so that we can use
    #transposed convolution to upsample
    out = out[:mel_frames * audio.get_hop_size(hparams)]
    assert len(out) % audio.get_hop_size(hparams) == 0
    time_steps = len(out)

    # Write the spectrogram and audio to disk
    audio_filename = os.path.join(wav_dir, 'audio-{}.npy'.format(index))
    mel_filename = os.path.join(mel_dir, 'mel-{}.npy'.format(index))
    np.save(audio_filename, out.astype(out_dtype), allow_pickle=False)
    np.save(mel_filename, mel_spectrogram.T, allow_pickle=False)

    #global condition features
    if hparams.gin_channels > 0:
        raise RuntimeError(
            'When activating global conditions, please set your speaker_id rules in line 129 of datasets/wavenet_preprocessor.py to use them during training'
        )
        speaker_id = '<no_g>'  #put the rule to determine how to assign speaker ids (using file names maybe? file basenames are available in "index" variable)
    else:
        speaker_id = '<no_g>'

    # Return a tuple describing this training example (mel_filename
    # intentionally fills both the mel and the linear-spectrogram slots, since
    # no linear spectrogram is produced here)
    return (audio_filename, mel_filename, mel_filename, speaker_id, time_steps,
            mel_frames)
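

# Reference sketch (assumption, not the repo's wavenet_vocoder util code): the
# mu-law companding and quantization used above usually follow the classic
# formulas below; the exact scaling of the real mulaw()/mulaw_quantize()
# helpers (e.g. whether mu is quantize_channels or quantize_channels - 1) may
# differ slightly.
def _mulaw_sketch(x, quantize_channels=256):
    # Compand x in [-1, 1]: sign(x) * ln(1 + mu * |x|) / ln(1 + mu)
    mu = quantize_channels - 1
    return np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)


def _mulaw_quantize_sketch(x, quantize_channels=256):
    # Map the companded signal from [-1, 1] to integer bins in [0, quantize_channels)
    mu = quantize_channels - 1
    y = _mulaw_sketch(x, quantize_channels)
    return np.rint((y + 1) / 2 * mu).astype(np.int16)
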
def eval_step(sess, global_step, model, plot_dir, wav_dir, summary_writer,
              hparams, model_name):
    '''Evaluate the model during training.
    Assumes that model variables have been averaged.
    '''
    start_time = time.time()
    y_hat, y_target, loss, input_mel, upsampled_features = sess.run([
        model.tower_y_hat[0], model.tower_y_target[0], model.eval_loss,
        model.tower_eval_c[0], model.tower_eval_upsampled_local_features[0]
    ])
    duration = time.time() - start_time
    log('Time Evaluation: Generation of {} audio frames took {:.3f} sec ({:.3f} frames/sec)'
        .format(len(y_target), duration,
                len(y_target) / duration))

    #Make audio and plot paths
    pred_wav_path = os.path.join(wav_dir,
                                 'step-{}-pred.wav'.format(global_step))
    target_wav_path = os.path.join(wav_dir,
                                   'step-{}-real.wav'.format(global_step))
    plot_path = os.path.join(plot_dir,
                             'step-{}-waveplot.png'.format(global_step))
    mel_path = os.path.join(
        plot_dir,
        'step-{}-reconstruction-mel-spectrogram.png'.format(global_step))
    upsampled_path = os.path.join(
        plot_dir, 'step-{}-upsampled-features.png'.format(global_step))

    #Save figure
    util.waveplot(plot_path,
                  y_hat,
                  y_target,
                  model._hparams,
                  title='{}, {}, step={}, loss={:.5f}'.format(
                      model_name, time_string(), global_step, loss))
    log('Eval loss for global step {}: {:.3f}'.format(global_step, loss))

    #Compare generated wav mel with original input mel to evaluate wavenet audio reconstruction performance
    #Both mels should match on low frequency information, wavenet mel should contain more high frequency detail when compared to Tacotron mels.
    T2_output_range = (-hparams.max_abs_value,
                       hparams.max_abs_value) if hparams.symmetric_mels else (
                           0, hparams.max_abs_value)
    generated_mel = _interp(melspectrogram(y_hat, hparams).T, T2_output_range)
    util.plot_spectrogram(
        generated_mel,
        mel_path,
        title='Local Condition vs Reconst. Mel-Spectrogram, step={}, loss={:.5f}'
        .format(global_step, loss),
        target_spectrogram=input_mel.T)
    util.plot_spectrogram(
        upsampled_features.T,
        upsampled_path,
        title='Upsampled Local Condition features, step={}, loss={:.5f}'.
        format(global_step, loss),
        auto_aspect=True)

    #Save Audio
    save_wavenet_wav(y_hat,
                     pred_wav_path,
                     sr=hparams.sample_rate,
                     inv_preemphasize=hparams.preemphasize,
                     k=hparams.preemphasis)
    save_wavenet_wav(y_target,
                     target_wav_path,
                     sr=hparams.sample_rate,
                     inv_preemphasize=hparams.preemphasize,
                     k=hparams.preemphasis)

    #Write eval summary to tensorboard
    log('Writing eval summary!')
    add_test_stats(summary_writer, global_step, loss, hparams=hparams)
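

# Hedged sketch (assumption, not the repo's implementation): add_test_stats()
# presumably pushes the eval loss to TensorBoard through a TF1-style manual
# summary, roughly as below. The tag name is illustrative and this assumes
# tensorflow is imported as tf at module level.
def _add_test_stats_sketch(summary_writer, global_step, eval_loss):
    values = [tf.Summary.Value(tag='Wavenet_eval_model/eval_loss',
                               simple_value=eval_loss)]
    summary_writer.add_summary(tf.Summary(value=values), global_step)
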
def save_log(sess, global_step, model, plot_dir, wav_dir, hparams, model_name):
    log('\nSaving intermediate states at step {}'.format(global_step))
    idx = 0
    y_hat, y, loss, length, input_mel, upsampled_features = sess.run([
        model.tower_y_hat_log[0][idx], model.tower_y_log[0][idx], model.loss,
        model.tower_input_lengths[0][idx], model.tower_c[0][idx],
        model.tower_upsampled_local_features[0][idx]
    ])

    #mask by length
    y_hat[length:] = 0
    y[length:] = 0

    #Make audio and plot paths
    pred_wav_path = os.path.join(wav_dir,
                                 'step-{}-pred.wav'.format(global_step))
    target_wav_path = os.path.join(wav_dir,
                                   'step-{}-real.wav'.format(global_step))
    plot_path = os.path.join(plot_dir,
                             'step-{}-waveplot.png'.format(global_step))
    mel_path = os.path.join(
        plot_dir,
        'step-{}-reconstruction-mel-spectrogram.png'.format(global_step))
    upsampled_path = os.path.join(
        plot_dir, 'step-{}-upsampled-features.png'.format(global_step))

    #Save figure
    util.waveplot(plot_path,
                  y_hat,
                  y,
                  hparams,
                  title='{}, {}, step={}, loss={:.5f}'.format(
                      model_name, time_string(), global_step, loss))

    #Compare generated wav mel with original input mel to evaluate wavenet audio reconstruction performance
    #Both mels should match on low frequency information, wavenet mel should contain more high frequency detail when compared to Tacotron mels.
    T2_output_range = (-hparams.max_abs_value,
                       hparams.max_abs_value) if hparams.symmetric_mels else (
                           0, hparams.max_abs_value)
    generated_mel = _interp(melspectrogram(y_hat, hparams).T, T2_output_range)
    util.plot_spectrogram(
        generated_mel,
        mel_path,
        title='Local Condition vs Reconst. Mel-Spectrogram, step={}, loss={:.5f}'
        .format(global_step, loss),
        target_spectrogram=input_mel.T)
    util.plot_spectrogram(
        upsampled_features.T,
        upsampled_path,
        title='Upsampled Local Condition features, step={}, loss={:.5f}'.
        format(global_step, loss),
        auto_aspect=True)

    #Save audio
    save_wavenet_wav(y_hat,
                     pred_wav_path,
                     sr=hparams.sample_rate,
                     inv_preemphasize=hparams.preemphasize,
                     k=hparams.preemphasis)
    save_wavenet_wav(y,
                     target_wav_path,
                     sr=hparams.sample_rate,
                     inv_preemphasize=hparams.preemphasize,
                     k=hparams.preemphasis)
    def synthesize(self, mel_spectrograms, speaker_ids, basenames, out_dir,
                   log_dir):
        hparams = self._hparams
        local_cond, global_cond = self._check_conditions()

        #Swap in debug mels when running in debug mode
        if self.synth_debug:
            assert len(hparams.wavenet_debug_mels) == len(
                hparams.wavenet_debug_wavs)
            mel_spectrograms = [
                np.load(mel_file) for mel_file in hparams.wavenet_debug_mels
            ]

        #Get true length of audio to be synthesized: audio_len = mel_len * hop_size
        audio_lengths = [
            len(x) * get_hop_size(self._hparams) for x in mel_spectrograms
        ]

        #Prepare local condition batch
        maxlen = max([len(x) for x in mel_spectrograms])
        #[-max, max] or [0,max]
        T2_output_range = (
            -self._hparams.max_abs_value,
            self._hparams.max_abs_value) if self._hparams.symmetric_mels else (
                0, self._hparams.max_abs_value)

        if self._hparams.clip_for_wavenet:
            mel_spectrograms = [
                np.clip(x, T2_output_range[0], T2_output_range[1])
                for x in mel_spectrograms
            ]

        c_batch = np.stack([
            _pad_inputs(x, maxlen, _pad=T2_output_range[0])
            for x in mel_spectrograms
        ]).astype(np.float32)

        if self._hparams.normalize_for_wavenet:
            #rerange to [0, 1]
            c_batch = _interp(c_batch, T2_output_range).astype(np.float32)

        g = None if speaker_ids is None else np.asarray(
            speaker_ids, dtype=np.int32).reshape(len(c_batch), 1)
        feed_dict = {}

        if local_cond:
            feed_dict[self.local_conditions] = c_batch
        else:
            feed_dict[self.synthesis_length] = 100

        if global_cond:
            feed_dict[self.global_conditions] = g

        if self.synth_debug:
            debug_wavs = hparams.wavenet_debug_wavs
            assert len(debug_wavs) % hparams.wavenet_num_gpus == 0
            test_wavs = [
                np.load(debug_wav).reshape(-1, 1) for debug_wav in debug_wavs
            ]

            #pad wavs to same length
            max_test_len = max([len(x) for x in test_wavs])
            test_wavs = np.stack([
                _pad_inputs(x, max_test_len) for x in test_wavs
            ]).astype(np.float32)

            assert len(test_wavs) == len(debug_wavs)
            feed_dict[self.targets] = test_wavs.reshape(
                len(test_wavs), max_test_len, 1)
            feed_dict[self.input_lengths] = np.asarray([test_wavs.shape[1]])

        #Generate wavs and clip extra padding to select real speech parts
        generated_wavs, upsampled_features = self.session.run(
            [
                self.model.tower_y_hat,
                self.model.tower_synth_upsampled_local_features
            ],
            feed_dict=feed_dict)

        #Linearize outputs (n_gpus -> 1D)
        generated_wavs = [
            wav for gpu_wavs in generated_wavs for wav in gpu_wavs
        ]
        upsampled_features = [
            feat for gpu_feats in upsampled_features for feat in gpu_feats
        ]

        generated_wavs = [
            generated_wav[:length]
            for generated_wav, length in zip(generated_wavs, audio_lengths)
        ]
        upsampled_features = [
            upsampled_feature[:, :length] for upsampled_feature, length in zip(
                upsampled_features, audio_lengths)
        ]

        audio_filenames = []
        for i, (generated_wav, input_mel, upsampled_feature) in enumerate(
                zip(generated_wavs, mel_spectrograms, upsampled_features)):
            #Save wav to disk
            audio_filename = os.path.join(
                out_dir, 'wavenet-audio-{}.wav'.format(basenames[i]))
            save_wavenet_wav(generated_wav,
                             audio_filename,
                             sr=hparams.sample_rate,
                             inv_preemphasize=hparams.preemphasize,
                             k=hparams.preemphasis)
            audio_filenames.append(audio_filename)

            #Compare generated wav mel with original input mel to evaluate wavenet audio reconstruction performance
            #Both mels should match on low frequency information, wavenet mel should contain more high frequency detail when compared to Tacotron mels.
            generated_mel = melspectrogram(generated_wav, hparams).T
            util.plot_spectrogram(
                generated_mel,
                os.path.join(
                    log_dir,
                    'wavenet-mel-spectrogram-{}.png'.format(basenames[i])),
                title=
                'Local Condition vs Reconstructed Audio Mel-Spectrogram analysis',
                target_spectrogram=input_mel)
            #Save upsampled features to visualize checkerboard artifacts.
            util.plot_spectrogram(
                upsampled_feature.T,
                os.path.join(
                    log_dir,
                    'wavenet-upsampled_features-{}.png'.format(basenames[i])),
                title='Upsampled Local Condition features',
                auto_aspect=True)

            #Save waveplot to disk
            if log_dir is not None:
                plot_filename = os.path.join(
                    log_dir, 'wavenet-waveplot-{}.png'.format(basenames[i]))
                util.waveplot(plot_filename,
                              generated_wav,
                              None,
                              hparams,
                              title='WaveNet generated Waveform.')

        return audio_filenames
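

# Hedged sketches (assumptions, not the repo's code) of two helpers that
# synthesize() relies on: _pad_inputs() pads each conditioning matrix or wav to
# a common length, and _interp() linearly rescales features from the Tacotron
# output range to [0, 1] (consistent with the "rerange to [0, 1]" comment
# above). The real implementations may differ in detail.
def _pad_inputs_sketch(x, maxlen, _pad=0):
    # Pad along the first (time) axis up to maxlen with the constant _pad value
    return np.pad(x, [(0, maxlen - len(x))] + [(0, 0)] * (x.ndim - 1),
                  mode='constant', constant_values=_pad)


def _interp_sketch(feats, in_range):
    # Map values from [in_range[0], in_range[1]] to [0, 1]
    return (feats - in_range[0]) / (in_range[1] - in_range[0])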