Code example #1
def save_current_model(args, checkpoint_path, global_step, hparams, loss,
                       model, plot_dir, saver, sess, step, wav_dir):
    # Save model and current global step
    saver.save(sess, checkpoint_path, global_step=global_step)
    log('\nSaving alignment, Mel-Spectrograms and griffin-lim inverted waveform..')
    (input_seq, mel_prediction, linear_prediction, attention_mask_sample,
     targets_mel, target_length, linear_target) = sess.run([
        model.inputs[0],
        model.post_net_predictions[0],
        model.mag_pred[0],
        model.alignments[0],
        model.targets_mel[0],
        model.targets_length[0],
        model.targets_mag[0],
    ])
    alignments, alignment_titles = get_alignments(attention_mask_sample)
    # save griffin lim inverted wav for debug (linear -> wav)
    wav = audio.inv_linear_spectrogram(linear_prediction.T, hparams)
    audio.save_wav(wav,
                   os.path.join(wav_dir, '{}-linear.wav'.format(step)),
                   sr=hparams.sample_rate)
    # Save real and predicted linear-spectrogram plot to disk (control purposes)
    plot.plot_spectrogram(
        linear_prediction,
        os.path.join(plot_dir, '{}-linear-spectrogram.png'.format(step)),
        title='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(),
                                                    step, loss),
        target_spectrogram=linear_target,
        max_len=target_length,
        auto_aspect=True)
    # save griffin lim inverted wav for debug (mel -> wav)
    wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
    audio.save_wav(wav,
                   os.path.join(wav_dir, '{}-mel.wav'.format(step)),
                   sr=hparams.sample_rate)
    # save alignment plot to disk (control purposes)
    for i in range(len(alignments)):
        plot.plot_alignment(
            alignments[i],
            os.path.join(plot_dir,
                         '{}_{}-align.png'.format(step, alignment_titles[i])),
            title='{}, {}, step={}, loss={:.5f}'.format(
                args.model, time_string(), step, loss),
            max_len=target_length // hparams.reduction_factor)
    # save real and predicted mel-spectrogram plot to disk (control purposes)
    plot.plot_spectrogram(mel_prediction,
                          os.path.join(plot_dir,
                                       '{}-mel-spectrogram.png'.format(step)),
                          title='{}, {}, step={}, loss={:.5f}'.format(
                              args.model, time_string(), step, loss),
                          target_spectrogram=targets_mel,
                          max_len=target_length)
    log('Input at step {}: {}'.format(step, sequence_to_text(input_seq)))
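
For context, a call site for save_current_model inside a training loop might look like the following minimal sketch; the surrounding names (args.checkpoint_interval, step, loss, and the directory variables) are assumptions carried over from this snippet and from code example #3, not confirmed project code.

if step % args.checkpoint_interval == 0:
    # Hypothetical call site: save the model and its debug artifacts
    # every args.checkpoint_interval steps.
    save_current_model(args, checkpoint_path, global_step, hparams, loss,
                       model, plot_dir, saver, sess, step, wav_dir)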
Code example #2
# -*- coding: utf-8 -*-
import os
import numpy as np
from utils import audio
from hparams import hparams as hps

linear_path = './data/linear-000001.npy'
linear_name = os.path.splitext(os.path.basename(linear_path))[0]
linear_p = np.load(linear_path)

mel_path = './data/mel-000001.npy'
mel_name = os.path.splitext(os.path.basename(mel_path))[0]
mel_p = np.load(mel_path)

# Invert the linear spectrogram and save it as a wav
wav = audio.inv_linear_spectrogram(linear_p.T, hps)
audio.save_wav(wav, os.path.join("./data", "{}.wav".format(linear_name)), hps)

# Invert the mel spectrogram and save it as a wav
wav = audio.inv_mel_spectrogram(mel_p.T, hps)
audio.save_wav(wav, os.path.join("./data", "{}.wav".format(mel_name)), hps)
Code example #3
File: train.py  Project: samprate1st/Tacotron-3
def train(log_dir, args):
    save_dir = os.path.join(log_dir, 'pretrained/')
    checkpoint_path = os.path.join(save_dir, 'model.ckpt')
    input_path = os.path.join(args.base_dir, args.input)
    plot_dir = os.path.join(log_dir, 'plots')
    wav_dir = os.path.join(log_dir, 'wavs')
    mel_dir = os.path.join(log_dir, 'mel-spectrograms')
    os.makedirs(plot_dir, exist_ok=True)
    os.makedirs(wav_dir, exist_ok=True)
    os.makedirs(mel_dir, exist_ok=True)
    log('Checkpoint path: {}'.format(checkpoint_path))
    log('Loading training data from: {}'.format(input_path))
    log('Using model: {}'.format(args.model))
    log(hparams_debug_string())

    #Set up data feeder
    coord = tf.train.Coordinator()
    with tf.variable_scope('datafeeder') as scope:
        feeder = Feeder(coord, input_path, hparams)

    #Set up model:
    step_count = 0
    try:
        # simple text file to keep count of the global step
        with open(os.path.join(log_dir, 'step_counter.txt'), 'r') as file:
            step_count = int(file.read())
    except (FileNotFoundError, ValueError):
        log('No step_counter file found, assuming there is no saved checkpoint')

    global_step = tf.Variable(step_count, name='global_step', trainable=False)
    with tf.variable_scope('model') as scope:
        model = create_model(args.model, hparams)
        model.initialize(feeder.inputs, feeder.input_lengths,
                         feeder.mel_targets, feeder.token_targets)
        model.add_loss()
        model.add_optimizer(global_step)
        stats = add_stats(model)

    #Book keeping
    step = 0
    save_step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=5)

    #Memory allocation on the GPU as needed
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True

    #Train
    with tf.Session(config=config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
            sess.run(tf.global_variables_initializer())

            # saved model restoring
            checkpoint_state = None
            if args.restore:
                # Restore the saved model if the user requested it. Default = True.
                try:
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)
                except tf.errors.OutOfRangeError as e:
                    log('Cannot restore checkpoint: {}'.format(e))

            if checkpoint_state and checkpoint_state.model_checkpoint_path:
                log('Loading checkpoint {}'.format(
                    checkpoint_state.model_checkpoint_path))
                saver.restore(sess, checkpoint_state.model_checkpoint_path)

            else:
                if not args.restore:
                    log('Starting new training!')
                else:
                    log('No model to load at {}'.format(save_dir))

            #initiating feeder
            feeder.start_in_session(sess)

            #Training loop
            while not coord.should_stop():
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = 'Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]'.format(
                    step, time_window.average, loss, loss_window.average)
                log(message, end='\r')

                if loss > 100 or np.isnan(loss):
                    log('Loss exploded to {:.5f} at step {}'.format(
                        loss, step))
                    raise Exception('Loss exploded')

                if step % args.summary_interval == 0:
                    log('\nWriting summary at step: {}'.format(step))
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.checkpoint_interval == 0:
                    with open(os.path.join(log_dir, 'step_counter.txt'),
                              'w') as file:
                        file.write(str(step))
                    log('Saving checkpoint to: {}-{}'.format(
                        checkpoint_path, step))
                    saver.save(sess, checkpoint_path, global_step=step)
                    save_step = step

                    log('Saving alignment, Mel-Spectrograms and griffin-lim inverted waveform..')
                    input_seq, prediction, alignment, target = sess.run([
                        model.inputs[0],
                        model.mel_outputs[0],
                        model.alignments[0],
                        model.mel_targets[0],
                    ])
                    #save predicted spectrogram to disk (for plot and manual evaluation purposes)
                    mel_filename = 'ljspeech-mel-prediction-step-{}.npy'.format(
                        step)
                    np.save(os.path.join(mel_dir, mel_filename),
                            prediction.T,
                            allow_pickle=False)

                    #save griffin lim inverted wav for debug.
                    wav = audio.inv_mel_spectrogram(prediction.T)
                    audio.save_wav(
                        wav,
                        os.path.join(wav_dir,
                                     'step-{}-waveform.wav'.format(step)))

                    #save alignment plot to disk (control purposes)
                    plot.plot_alignment(
                        alignment,
                        os.path.join(plot_dir,
                                     'step-{}-align.png'.format(step)),
                        info='{}, {}, step={}, loss={:.5f}'.format(
                            args.model, time_string(), step, loss))
                    #save real mel-spectrogram plot to disk (control purposes)
                    plot.plot_spectrogram(
                        target,
                        os.path.join(
                            plot_dir,
                            'step-{}-real-mel-spectrogram.png'.format(step)),
                        info='{}, {}, step={}, Real'.format(
                            args.model, time_string(), step))
                    #save predicted mel-spectrogram plot to disk (control purposes)
                    plot.plot_spectrogram(
                        prediction,
                        os.path.join(
                            plot_dir,
                            'step-{}-pred-mel-spectrogram.png'.format(step)),
                        info='{}, {}, step={}, loss={:.5f}'.format(
                            args.model, time_string(), step, loss))
                    log('Input at step {}: {}'.format(
                        step, sequence_to_text(input_seq)))

        except Exception as e:
            log('Exiting due to exception: {}'.format(e), slack=True)
            traceback.print_exc()
            coord.request_stop(e)
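
The flags read inside train() (args.base_dir, args.input, args.model, args.restore, args.summary_interval, args.checkpoint_interval) imply an entry point along the following lines. This is a plausible sketch, not the project's actual main(); the default values are assumptions.

import argparse

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--base_dir', default='')
    parser.add_argument('--input', default='training/train.txt')
    parser.add_argument('--model', default='Tacotron')
    # restore defaults to True, matching the comment in train()
    parser.add_argument('--no_restore', dest='restore', action='store_false')
    parser.add_argument('--summary_interval', type=int, default=100)
    parser.add_argument('--checkpoint_interval', type=int, default=1000)
    args = parser.parse_args()
    log_dir = os.path.join(args.base_dir, 'logs-{}'.format(args.model))
    os.makedirs(log_dir, exist_ok=True)
    train(log_dir, args)

if __name__ == '__main__':
    main()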
Code example #4
from utils.audio import melspectrogram, inv_mel_spectrogram, load_wav, save_wav

# Round-trip sanity check: wav -> mel spectrogram -> Griffin-Lim inverted wav
wav_path = "LJ001-0008.wav"
raw_wav = load_wav(wav_path)
mel_spec = melspectrogram(raw_wav)
inv_wav = inv_mel_spectrogram(mel_spec)
save_wav(inv_wav, "inv.wav")

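Griffin-Lim inversion estimates phase from magnitude alone, so inv.wav will not match the input sample for sample. A quick sketch of measuring the reconstruction error over the overlapping region:

import numpy as np

# Sketch: compare original and reconstructed audio over their common length.
n = min(len(raw_wav), len(inv_wav))
mse = np.mean((raw_wav[:n] - inv_wav[:n]) ** 2)
print('round-trip MSE over {} samples: {:.6f}'.format(n, mse))
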
Code example #5
def run_eval(args, eval_dir, eval_model, eval_plot_dir, eval_wav_dir, feeder,
             hparams, sess, step, summary_writer):
    # Run eval and save eval stats
    log('\nRunning evaluation at step {}'.format(step))
    sum_eval_loss = 0.0
    sum_mel_loss = 0.0
    sum_stop_token_loss = 0.0
    sum_linear_loss = 0.0
    count = 0.0
    mel_p = None
    mel_t = None
    t_len = None
    attention_mask_sample = None
    lin_p = None
    lin_t = None
    for _ in tqdm(range(feeder.test_steps)):
        (test_eloss, test_mel_loss, test_stop_token_loss, test_linear_loss,
         mel_p, mel_t, t_len, attention_mask_sample, lin_p, lin_t) = sess.run([
            eval_model.loss,
            eval_model.mel_loss,
            eval_model.stop_token_loss,
            eval_model.linear_loss,
            eval_model.post_net_predictions[0],
            eval_model.targets_mel[0],
            eval_model.targets_length[0],
            eval_model.alignments[0],
            eval_model.mag_pred[0],
            eval_model.targets_mag[0],
        ])
        sum_eval_loss += test_eloss
        sum_mel_loss += test_mel_loss
        sum_stop_token_loss += test_stop_token_loss
        sum_linear_loss += test_linear_loss
        count += 1.0
    wav = audio.inv_linear_spectrogram(lin_p.T, hparams)
    audio.save_wav(wav,
                   os.path.join(eval_wav_dir,
                                '{}-eval-linear.wav'.format(step)),
                   sr=hparams.sample_rate)
    if count > 0.0:
        eval_loss = sum_eval_loss / count
        mel_loss = sum_mel_loss / count
        stop_token_loss = sum_stop_token_loss / count
        linear_loss = sum_linear_loss / count
    else:
        eval_loss = sum_eval_loss
        mel_loss = sum_mel_loss
        stop_token_loss = sum_stop_token_loss
        linear_loss = sum_linear_loss
    log('Saving eval log to {}..'.format(eval_dir))
    # Save some log to monitor model improvement on same unseen sequence
    wav = audio.inv_mel_spectrogram(mel_p.T, hparams)
    audio.save_wav(wav,
                   os.path.join(eval_wav_dir, '{}-eval-mel.wav'.format(step)),
                   sr=hparams.sample_rate)
    alignments, alignment_titles = get_alignments(attention_mask_sample)
    for i in range(len(alignments)):
        plot.plot_alignment(alignments[i],
                            os.path.join(
                                eval_plot_dir, '{}_{}-eval-align.png'.format(
                                    step, alignment_titles[i])),
                            title='{}, {}, step={}, loss={:.5f}'.format(
                                args.model, time_string(), step, eval_loss),
                            max_len=t_len // hparams.reduction_factor)
    plot.plot_spectrogram(
        mel_p,
        os.path.join(eval_plot_dir,
                     '{}-eval-mel-spectrogram.png'.format(step)),
        title='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(),
                                                    step, eval_loss),
        target_spectrogram=mel_t,
        max_len=t_len)
    plot.plot_spectrogram(
        lin_p,
        os.path.join(eval_plot_dir,
                     '{}-eval-linear-spectrogram.png'.format(step)),
        title='{}, {}, step={}, loss={:.5f}'.format(args.model, time_string(),
                                                    step, eval_loss),
        target_spectrogram=lin_t,
        max_len=t_len,
        auto_aspect=True)
    log('Eval loss for global step {}: {:.3f}'.format(step, eval_loss))
    log('Writing eval summary!')
    add_eval_stats(summary_writer, step, linear_loss, mel_loss,
                   stop_token_loss, eval_loss)
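
add_eval_stats is not shown here. Because the averaged eval losses are plain Python floats rather than graph tensors, a common TF 1.x pattern is to build a tf.Summary protobuf by hand; a minimal sketch under that assumption (the tag names are hypothetical):

def add_eval_stats(summary_writer, step, linear_loss, mel_loss,
                   stop_token_loss, eval_loss):
    # Sketch: write scalar eval losses without adding ops to the graph.
    values = [
        tf.Summary.Value(tag='eval_stats/linear_loss', simple_value=linear_loss),
        tf.Summary.Value(tag='eval_stats/mel_loss', simple_value=mel_loss),
        tf.Summary.Value(tag='eval_stats/stop_token_loss', simple_value=stop_token_loss),
        tf.Summary.Value(tag='eval_stats/eval_loss', simple_value=eval_loss),
    ]
    summary_writer.add_summary(tf.Summary(value=values), step)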
Code example #6
    def synthesize(self, texts, basenames, log_dir, mel_filenames):
        hparams = self._hparams

        # Repeat the last sample until the number of samples is divisible by the synthesis batch size (last-batch scenario)
        while len(texts) % hparams.synthesis_batch_size != 0:
            texts.append(texts[-1])
            basenames.append(basenames[-1])
            if mel_filenames is not None:
                mel_filenames.append(mel_filenames[-1])
        sequences = [np.asarray(text_to_sequence(text)) for text in texts]
        input_lengths = [len(seq) for seq in sequences]
        seqs, max_seq_len = self._prepare_inputs(sequences)

        feed_dict = {
            self.inputs: seqs,
            self.input_lengths: np.asarray(input_lengths, dtype=np.int32)
        }

        linears, mels, alignments, audio_length = self.session.run(
            [self.linear_outputs, self.mel_outputs, self.alignments[0], self.audio_length],
            feed_dict=feed_dict)
        # Natural batch synthesis
        # Get Mel/Linear lengths for the entire batch from stop_tokens predictions
        target_lengths = audio_length

        if basenames is None:
            # Generate wav and read it
            wav = audio.inv_mel_spectrogram(mels[0].T, hparams)
            audio.save_wav(wav, 'temp.wav', sr=hparams.sample_rate)  # Find a better way

            if platform.system() == 'Linux':
                # Linux wav reader
                os.system('aplay temp.wav')

            elif platform.system() == 'Windows':
                # windows wav reader
                os.system('start /min mplay32 /play /close temp.wav')

            else:
                raise RuntimeError(
                    'Your OS type is not supported yet, please add it to "centaur/synthesizer.py, line-165" and feel free to make a Pull Request ;) Thanks!')

            return

        for i, mel in enumerate(mels):

            if log_dir is not None:
                # save wav (mel -> wav)
                wav = audio.inv_mel_spectrogram(mel.T, hparams)
                audio.save_wav(wav, os.path.join(log_dir, 'wavs/wav-{}-mel.wav'.format(basenames[i])),
                               sr=hparams.sample_rate)
                alignments_samples, alignment_titles = self.get_alignments(alignments)
                for idx in range(len(alignments_samples)):
                    # save alignments
                    plot.plot_alignment(
                        alignments_samples[idx],
                        os.path.join(log_dir, 'plots/{}.png'.format(alignment_titles[idx])),
                        title='{}'.format(texts[i]), split_title=True,
                        max_len=target_lengths[i])

                # save mel spectrogram plot
                plot.plot_spectrogram(mel,
                                      os.path.join(log_dir, 'plots/mel-{}.png'.format(basenames[i])),
                                      title='{}'.format(texts[i]), split_title=True)

                # save wav (linear -> wav)

                wav = audio.inv_linear_spectrogram(linears[i].T, hparams)
                audio.save_wav(wav,
                               os.path.join(log_dir, 'wavs/wav-{}-linear.wav'.format(basenames[i])),
                               sr=hparams.sample_rate)

                # save linear spectrogram plot
                plot.plot_spectrogram(linears[i],
                                      os.path.join(log_dir, 'plots/linear-{}.png'.format(basenames[i])),
                                      title='{}'.format(texts[i]), split_title=True, auto_aspect=True)
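
synthesize() relies on a _prepare_inputs helper that pads the encoded sequences to a common length before batching. Its body is not shown above; a minimal sketch of what it plausibly does (zero right-padding is an assumption):

def _prepare_inputs(self, inputs):
    # Sketch: right-pad every sequence with zeros to the batch maximum.
    max_len = max(len(x) for x in inputs)
    padded = np.stack(
        [np.pad(x, (0, max_len - len(x)), mode='constant') for x in inputs])
    return padded, max_len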
Code example #7


import torch

max_length = 100
targets = torch.zeros([1, 1, 80]).to(device)  # start from a single all-zero mel frame
with torch.no_grad():
  for i in range(max_length):
    # the mask must cover the current target length: targets.size(1), not
    # len(targets), which is the batch dimension and is always 1 here
    t_mask = torch.ones([1, targets.size(1)]).long().to(device)
    predMel, enc_dec_attn_list = model(inputs, targets, t_mask)
    # replace the trailing zero frame with the newly predicted frame, then
    # append a fresh zero frame as the next decoder input
    targets = torch.cat([targets[:, :-1, :], predMel[:, -1:, :]], dim=1)
    targets = torch.cat([targets, torch.zeros([1, 1, 80]).to(device)], dim=1)


import matplotlib.pyplot as plt

bi = 0  # batch index to visualize
plt.figure()
# grid of encoder-decoder attention maps, assuming 4 layers with 8 heads each
for layer_idx, enc_dec_attn in enumerate(enc_dec_attn_list):
  for head_idx, attn in enumerate(enc_dec_attn[bi].detach().cpu().numpy()):
    idx = layer_idx * 8 + head_idx + 1
    plt.subplot(4, 8, idx)
    plt.imshow(attn)

plt.show()


from utils.audio import melspectrogram, inv_mel_spectrogram, load_wav, save_wav

# drop the trailing zero frame, reshape to (n_mels, T), then invert with Griffin-Lim
syn_melSpec = targets[0, :-1, :].contiguous().transpose(0, 1).cpu().numpy()
syn_wav = inv_mel_spectrogram(syn_melSpec)
save_wav(syn_wav, "测试/syn.wav")  # "测试" is a test output directory
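
The inputs tensor fed to model() above is assumed to already exist. One plausible way to build it from raw text, assuming a text_to_sequence helper like the one used in code example #6:

# Sketch (assumed helper): encode text as a [1, N] LongTensor of symbol ids.
text = "hello world"
inputs = torch.LongTensor(text_to_sequence(text)).unsqueeze(0).to(device)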