def train(run_id: str, metadata_fpath: str, models_dir: str, save_every: int,
          backup_every: int, force_restart: bool, hparams):
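    """Train the Tacotron synthesizer (PyTorch).

    Checkpoints and outputs are written under models_dir/run_id. `save_every`
    and `backup_every` are step intervals for overwriting the latest checkpoint
    and for writing separate numbered backups (0 disables either); if
    `force_restart` is True, any existing checkpoint is ignored and training
    starts from scratch. `hparams` must expose the tts_* fields used below.
    """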

    models_dir = Path(models_dir)
    models_dir.mkdir(exist_ok=True)

    model_dir = models_dir.joinpath(run_id)
    plot_dir = model_dir.joinpath("plots")
    wav_dir = model_dir.joinpath("wavs")
    mel_output_dir = model_dir.joinpath("mel-spectrograms")
    meta_folder = model_dir.joinpath("metas")
    model_dir.mkdir(exist_ok=True)
    plot_dir.mkdir(exist_ok=True)
    wav_dir.mkdir(exist_ok=True)
    mel_output_dir.mkdir(exist_ok=True)
    meta_folder.mkdir(exist_ok=True)
    
    weights_fpath = model_dir.joinpath(run_id).with_suffix(".pt")
    
    print("Checkpoint path: {}".format(weights_fpath))
    print("Loading training data from: {}".format(metadata_fpath))
    print("Using model: Tacotron")
    
    # Bookkeeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    
    # From WaveRNN/train_tacotron.py
    if torch.cuda.is_available():
        device = torch.device("cuda")

        for session in hparams.tts_schedule:
            _, _, _, batch_size = session
            if batch_size % torch.cuda.device_count() != 0:
                raise ValueError("`batch_size` must be evenly divisible by n_gpus!")
    else:
        device = torch.device("cpu")
    print("Using device:", device)

    # Instantiate Tacotron Model
    print("\nInitialising Tacotron Model...\n")
    model = Tacotron(embed_dims=hparams.tts_embed_dims,
                     num_chars=len(symbols),
                     encoder_dims=hparams.tts_encoder_dims,
                     decoder_dims=hparams.tts_decoder_dims,
                     n_mels=hparams.num_mels,
                     fft_bins=hparams.num_mels,
                     postnet_dims=hparams.tts_postnet_dims,
                     encoder_K=hparams.tts_encoder_K,
                     lstm_dims=hparams.tts_lstm_dims,
                     postnet_K=hparams.tts_postnet_K,
                     num_highways=hparams.tts_num_highways,
                     dropout=hparams.tts_dropout,
                     stop_threshold=hparams.tts_stop_threshold,
                     speaker_embedding_size=hparams.speaker_embedding_size).to(device)

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters())
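    # Note: Adam is created with its default learning rate here; the rate is
    # overridden per training session in the schedule loop below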

    # Load the weights
    if force_restart or not weights_fpath.exists():
        print("\nStarting the training of Tacotron from scratch\n")
        model.save(weights_fpath)

        # Embeddings metadata
        char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv")
        with open(char_embedding_fpath, "w", encoding="utf-8") as f:
            for symbol in symbols:
                if symbol == " ":
                    symbol = "\\s"  # For visual purposes, swap space with \s

                f.write("{}\n".format(symbol))

    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(weights_fpath, optimizer)
        print("Tacotron weights loaded from step %d" % model.step)
    
    # Initialize the dataset
    dataset = SynthesizerDataset(metadata_fpath, hparams)

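    # hparams.tts_schedule defines a sequence of training sessions: each entry is
    # a tuple (r, lr, max_step, batch_size), where r is the number of mel frames
    # generated per decoder step (the "reduction factor")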
    for i, session in enumerate(hparams.tts_schedule):
        current_step = model.get_step()

        r, lr, max_step, batch_size = session

        training_steps = max_step - current_step

        # Do we need to change to the next session?
        if current_step >= max_step:
            # Is this the last session?
            if i == len(hparams.tts_schedule) - 1:
                # We have completed training. Save the model and exit
                model.save(weights_fpath, optimizer)
                break
            else:
                # There is a following session, go to it
                continue

        model.r = r

        # Begin the training
        simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
                      ("Batch Size", batch_size),
                      ("Learning Rate", lr),
                      ("Outputs/Step (r)", model.r)])

        for p in optimizer.param_groups:
            p["lr"] = lr

        data_loader = DataLoader(dataset,
                                 collate_fn=lambda batch: collate_synthesizer(batch, r, hparams),
                                 batch_size=batch_size,
                                 num_workers=2,
                                 shuffle=True,
                                 pin_memory=True)

        total_iters = len(dataset) 
        steps_per_epoch = np.ceil(total_iters / batch_size).astype(np.int32)
        epochs = np.ceil(training_steps / steps_per_epoch).astype(np.int32)

        for epoch in range(1, epochs+1):
            for i, (texts, mels, embeds, idx) in enumerate(data_loader, 1):
                start_time = time.time()
                start = time.perf_counter()

                # Generate stop tokens for training
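                # (targets are 0 while real mel frames remain and 1 from the final
                # frame onward; metadata column 3 holds each clip's mel frame count)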
                stop = torch.ones(mels.shape[0], mels.shape[2])
                for j, k in enumerate(idx):
                    stop[j, :int(dataset.metadata[k][3])-1] = 0

                texts = texts.to(device)
                mels = mels.to(device)
                embeds = embeds.to(device)
                stop = stop.to(device)

                # Forward pass
                # Parallelize the model onto the GPUs, using a workaround for a known Python bug
                if device.type == "cuda" and torch.cuda.device_count() > 1:
                    m1_hat, m2_hat, attention, stop_pred = data_parallel_workaround(model, texts,
                                                                                    mels, embeds)
                else:
                    m1_hat, m2_hat, attention, stop_pred = model(texts, mels, embeds)

                # Backward pass
                m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
                m2_loss = F.mse_loss(m2_hat, mels)
                stop_loss = F.binary_cross_entropy(stop_pred, stop)

                loss = m1_loss + m2_loss + stop_loss

                optimizer.zero_grad()
                loss.backward()

                # Clip gradients to keep updates stable (the clipped norm is also logged below)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.tts_clip_grad_norm)
                if np.isnan(grad_norm.cpu()):
                    print("grad_norm was NaN!")

                optimizer.step()

                time_window.append(time.time() - start_time)
                loss_window.append(loss.item())

                step = model.get_step()
                k = step // 1000

                msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | {1./time_window.average:#.2} steps/s | Step: {k}k | "
                stream(msg)

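                # `good_logger` is assumed to be a Tacotron2-style training logger
                # created elsewhere in this script (it is not defined in this example)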
                if step % 10 == 0:
                    good_logger.log_training(reduced_loss=loss.item(),
                                             reduced_mel_loss=loss.item() - stop_loss.item(),
                                             reduced_gate_loss=stop_loss.item(),
                                             grad_norm=grad_norm,
                                             learning_rate=optimizer.param_groups[0]["lr"],
                                             duration=time.perf_counter() - start,
                                             iteration=step)

                # Back up or save the model as appropriate
                if backup_every != 0 and step % backup_every == 0:
                    backup_fpath = Path("{}/{}_{}k.pt".format(str(weights_fpath.parent), run_id, k))
                    model.save(backup_fpath, optimizer)

                if save_every != 0 and step % save_every == 0:
                    # Must save latest optimizer state to ensure that resuming training
                    # doesn't produce artifacts
                    model.save(weights_fpath, optimizer)

                # Evaluate model to generate samples
                epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch  # If epoch is done
                step_eval = hparams.tts_eval_interval > 0 and step % hparams.tts_eval_interval == 0  # Every N steps
                if epoch_eval or step_eval:
                    for sample_idx in range(hparams.tts_eval_num_samples):
                        # At most, generate samples equal to number in the batch
                        if sample_idx + 1 <= len(texts):
                            # Remove padding from mels using frame length in metadata
                            mel_length = int(dataset.metadata[idx[sample_idx]][3])
                            mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
                            target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
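                            # the attention matrix has one decoder step per r output frames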
                            attention_len = mel_length // model.r

                            eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
                                       mel_prediction=mel_prediction,
                                       target_spectrogram=target_spectrogram,
                                       input_seq=np_now(texts[sample_idx]),
                                       step=step,
                                       plot_dir=plot_dir,
                                       mel_output_dir=mel_output_dir,
                                       wav_dir=wav_dir,
                                       sample_num=sample_idx + 1,
                                       loss=loss,
                                       hparams=hparams)

                # Break out of loop to update training schedule
                if step >= max_step:
                    break

            # Add line break after every epoch
            print("")
def train(log_dir, args, hparams):
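    """Train the TensorFlow Tacotron model: build the data feeder plus train and
    eval graphs, restore any existing checkpoint, then run the training loop with
    periodic summaries, evaluation, and checkpointing."""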
    save_dir = os.path.join(log_dir, "taco_pretrained")
    plot_dir = os.path.join(log_dir, "plots")
    wav_dir = os.path.join(log_dir, "wavs")
    mel_dir = os.path.join(log_dir, "mel-spectrograms")
    eval_dir = os.path.join(log_dir, "eval-dir")
    eval_plot_dir = os.path.join(eval_dir, "plots")
    eval_wav_dir = os.path.join(eval_dir, "wavs")
    tensorboard_dir = os.path.join(log_dir, "tacotron_events")
    meta_folder = os.path.join(log_dir, "metas")
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)
    os.makedirs(wav_dir, exist_ok=True)
    os.makedirs(mel_dir, exist_ok=True)
    os.makedirs(eval_dir, exist_ok=True)
    os.makedirs(eval_plot_dir, exist_ok=True)
    os.makedirs(eval_wav_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)
    os.makedirs(meta_folder, exist_ok=True)

    checkpoint_fpath = os.path.join(save_dir, "tacotron_model.ckpt")
    if hparams.if_use_speaker_classifier:
        metadata_fpath = os.path.join(args.synthesizer_root,
                                      "train_augment_speaker.txt")
    else:
        metadata_fpath = os.path.join(args.synthesizer_root, "train.txt")

    log("Checkpoint path: {}".format(checkpoint_fpath))
    log("Loading training data from: {}".format(metadat_fpath))
    log("Using model: Tacotron")
    log(hparams_debug_string())

    # Start by setting a seed for repeatability
    tf.set_random_seed(hparams.tacotron_random_seed)

    # Set up data feeder
    coord = tf.train.Coordinator()
    with tf.variable_scope("datafeeder") as scope:
        feeder = Feeder(coord, metadata_fpath, hparams)

    # Set up model:
    global_step = tf.Variable(0, name="global_step", trainable=False)
    model, stats = model_train_mode(args, feeder, hparams, global_step)
    eval_model = model_test_mode(args, feeder, hparams, global_step)
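    # (model_test_mode presumably builds an evaluation graph sharing weights with
    # the training graph; both are constructed from the same feeder)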

    # Embeddings metadata
    char_embedding_meta = os.path.join(meta_folder, "CharacterEmbeddings.tsv")
    if not os.path.isfile(char_embedding_meta):
        with open(char_embedding_meta, "w", encoding="utf-8") as f:
            for symbol in symbols:
                if symbol == " ":
                    symbol = "\\s"  # For visual purposes, swap space with \s

                f.write("{}\n".format(symbol))

    char_embedding_meta = char_embedding_meta.replace(log_dir, "..")

    # Bookkeeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
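    # Keep only the five most recent checkpoints on disk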
    saver = tf.train.Saver(max_to_keep=5)

    log("Tacotron training set to a maximum of {} steps".format(
        args.tacotron_train_steps))

    # Memory allocation on the GPU as needed
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    # Train
    with tf.Session(config=config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(tensorboard_dir, sess.graph)

            sess.run(tf.global_variables_initializer())

            # saved model restoring
            if args.restore:
                # Restore saved model if the user requested it, default = True
                try:
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)

                    if checkpoint_state and checkpoint_state.model_checkpoint_path:
                        log("Loading checkpoint {}".format(
                            checkpoint_state.model_checkpoint_path),
                            slack=True)
                        saver.restore(sess,
                                      checkpoint_state.model_checkpoint_path)

                    else:
                        log("No model to load at {}".format(save_dir),
                            slack=True)
                        saver.save(sess,
                                   checkpoint_fpath,
                                   global_step=global_step)

                except tf.errors.OutOfRangeError as e:
                    log("Cannot restore checkpoint: {}".format(e), slack=True)
            else:
                log("Starting new training!", slack=True)
                saver.save(sess, checkpoint_fpath, global_step=global_step)

            # initializing feeder
            feeder.start_threads(sess)

            # Training loop
            while not coord.should_stop() and step < args.tacotron_train_steps:
                start_time = time.time()
                step, loss, adversial_loss, opt = sess.run([
                    global_step, model.loss, model.adversial_loss,
                    model.optimize
                ])
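                # Report the synthesis loss with the adversarial (speaker-classifier)
                # term removed; the adversarial loss itself is logged separately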
                loss -= adversial_loss
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = "Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}, adv_loss={:.5f}]".format(
                    step, time_window.average, loss, loss_window.average,
                    adversial_loss)
                log(message,
                    end="\r",
                    slack=(step % args.checkpoint_interval == 0))
                print(message)

                if loss > 100 or np.isnan(loss):
                    log("Loss exploded to {:.5f} at step {}".format(
                        loss, step))
                    raise Exception("Loss exploded")

                if step % args.summary_interval == 0:
                    log("\nWriting summary at step {}".format(step))
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.eval_interval == 0:
                    # Run eval and save eval stats
                    log("\nRunning evaluation at step {}".format(step))

                    eval_losses = []
                    before_losses = []
                    after_losses = []
                    stop_token_losses = []
                    linear_losses = []
                    linear_loss = None
                    adversial_losses = []

                    if hparams.predict_linear:
                        for i in tqdm(range(feeder.test_steps)):
                            (eloss, before_loss, after_loss, stop_token_loss,
                             linear_loss, mel_p, mel_t, t_len, align,
                             lin_p, lin_t) = sess.run([
                                 eval_model.tower_loss[0],
                                 eval_model.tower_before_loss[0],
                                 eval_model.tower_after_loss[0],
                                 eval_model.tower_stop_token_loss[0],
                                 eval_model.tower_linear_loss[0],
                                 eval_model.tower_mel_outputs[0][0],
                                 eval_model.tower_mel_targets[0][0],
                                 eval_model.tower_targets_lengths[0][0],
                                 eval_model.tower_alignments[0][0],
                                 eval_model.tower_linear_outputs[0][0],
                                 eval_model.tower_linear_targets[0][0],
                             ])
                            eval_losses.append(eloss)
                            before_losses.append(before_loss)
                            after_losses.append(after_loss)
                            stop_token_losses.append(stop_token_loss)
                            linear_losses.append(linear_loss)
                        linear_loss = sum(linear_losses) / len(linear_losses)

                        wav = audio.inv_linear_spectrogram(lin_p.T, hparams)
                        audio.save_wav(
                            wav,
                            os.path.join(
                                eval_wav_dir,
                                "step-{}-eval-wave-from-linear.wav".format(
                                    step)),
                            sr=hparams.sample_rate)

                    else:
                        for i in tqdm(range(feeder.test_steps)):
                            (eloss, before_loss, after_loss, stop_token_loss,
                             adversial_loss, mel_p, mel_t, t_len,
                             align) = sess.run([
                                 eval_model.tower_loss[0],
                                 eval_model.tower_before_loss[0],
                                 eval_model.tower_after_loss[0],
                                 eval_model.tower_stop_token_loss[0],
                                 eval_model.tower_adversial_loss[0],
                                 eval_model.tower_mel_outputs[0][0],
                                 eval_model.tower_mel_targets[0][0],
                                 eval_model.tower_targets_lengths[0][0],
                                 eval_model.tower_alignments[0][0],
                             ])
                            eval_losses.append(eloss)
                            before_losses.append(before_loss)
                            after_losses.append(after_loss)
                            stop_token_losses.append(stop_token_loss)
                            adversial_losses.append(adversial_loss)

                    eval_loss = sum(eval_losses) / len(eval_losses)
                    before_loss = sum(before_losses) / len(before_losses)
                    after_loss = sum(after_losses) / len(after_losses)
                    stop_token_loss = sum(stop_token_losses) / len(
                        stop_token_losses)
                    adversial_loss = sum(adversial_losses) / len(
                        adversial_losses)

                    log("Saving eval log to {}..".format(eval_dir))
                    # Save sample outputs to monitor model improvement on the same unseen sequence
                    wav = audio.inv_mel_spectrogram(mel_p.T, hparams)
                    audio.save_wav(
                        wav,
                        os.path.join(
                            eval_wav_dir,
                            "step-{}-eval-wave-from-mel.wav".format(step)),
                        sr=hparams.sample_rate)

                    plot.plot_alignment(
                        align,
                        os.path.join(eval_plot_dir,
                                     "step-{}-eval-align.png".format(step)),
                        title="{}, {}, step={}, loss={:.5f}".format(
                            "Tacotron", time_string(), step, eval_loss),
                        max_len=t_len // hparams.outputs_per_step)
                    plot.plot_spectrogram(
                        mel_p,
                        os.path.join(
                            eval_plot_dir,
                            "step-{}-eval-mel-spectrogram.png".format(step)),
                        title="{}, {}, step={}, loss={:.5f}".format(
                            "Tacotron", time_string(), step, eval_loss),
                        target_spectrogram=mel_t,
                        max_len=t_len)

                    if hparams.predict_linear:
                        plot.plot_spectrogram(
                            lin_p,
                            os.path.join(
                                eval_plot_dir,
                                "step-{}-eval-linear-spectrogram.png".format(
                                    step)),
                            title="{}, {}, step={}, loss={:.5f}".format(
                                "Tacotron", time_string(), step, eval_loss),
                            target_spectrogram=lin_t,
                            max_len=t_len,
                            auto_aspect=True)

                    log("Eval loss for global step {}: {:.3f}".format(
                        step, eval_loss))
                    log("Writing eval summary!")
                    add_eval_stats(summary_writer, step, linear_loss,
                                   before_loss, after_loss, stop_token_loss,
                                   adversial_loss, eval_loss)

                if step % args.checkpoint_interval == 0 or step == args.tacotron_train_steps or \
                        step == 300:
                    # Save model and current global step
                    saver.save(sess, checkpoint_fpath, global_step=global_step)

                    log("\nSaving alignment, Mel-Spectrograms and griffin-lim inverted waveform.."
                        )
                    input_seq, mel_prediction, alignment, target, target_length = sess.run(
                        [
                            model.tower_inputs[0][0],
                            model.tower_mel_outputs[0][0],
                            model.tower_alignments[0][0],
                            model.tower_mel_targets[0][0],
                            model.tower_targets_lengths[0][0],
                        ])

                    # save predicted mel spectrogram to disk (debug)
                    mel_filename = "mel-prediction-step-{}.npy".format(step)
                    np.save(os.path.join(mel_dir, mel_filename),
                            mel_prediction.T,
                            allow_pickle=False)

                    # save griffin lim inverted wav for debug (mel -> wav)
                    wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
                    audio.save_wav(
                        wav,
                        os.path.join(wav_dir,
                                     "step-{}-wave-from-mel.wav".format(step)),
                        sr=hparams.sample_rate)

                    # save alignment plot to disk (control purposes)
                    plot.plot_alignment(
                        alignment,
                        os.path.join(plot_dir,
                                     "step-{}-align.png".format(step)),
                        title="{}, {}, step={}, loss={:.5f}".format(
                            "Tacotron", time_string(), step, loss),
                        max_len=target_length // hparams.outputs_per_step)
                    # save real and predicted mel-spectrogram plot to disk (control purposes)
                    plot.plot_spectrogram(
                        mel_prediction,
                        os.path.join(
                            plot_dir,
                            "step-{}-mel-spectrogram.png".format(step)),
                        title="{}, {}, step={}, loss={:.5f}".format(
                            "Tacotron", time_string(), step, loss),
                        target_spectrogram=target,
                        max_len=target_length)
                    #log("Input at step {}: {}".format(step, sequence_to_text(input_seq)))

                if step % args.embedding_interval == 0 or step == args.tacotron_train_steps or step == 1:
                    # Get current checkpoint state
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)

                    # Update Projector
                    #log("\nSaving Model Character Embeddings visualization..")
                    #add_embedding_stats(summary_writer, [model.embedding_table.name],
                    #                    [char_embedding_meta],
                    #                    checkpoint_state.model_checkpoint_path)
                    #log("Tacotron Character embeddings have been updated on tensorboard!")

            log("Tacotron training complete after {} global steps!".format(
                args.tacotron_train_steps),
                slack=True)
            return save_dir

        except Exception as e:
            log("Exiting due to exception: {}".format(e), slack=True)
            traceback.print_exc()
            coord.request_stop(e)
Example #3
def train(log_dir, args, hparams):
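    """Variant of the TensorFlow trainer above with evaluation disabled (the
    eval-interval branch is a no-op) and at most two checkpoints retained."""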
    save_dir = os.path.join(log_dir, "taco_pretrained")
    plot_dir = os.path.join(log_dir, "plots")
    wav_dir = os.path.join(log_dir, "wavs")
    mel_dir = os.path.join(log_dir, "mel-spectrograms")
    eval_dir = os.path.join(log_dir, "eval-dir")
    eval_plot_dir = os.path.join(eval_dir, "plots")
    eval_wav_dir = os.path.join(eval_dir, "wavs")
    tensorboard_dir = os.path.join(log_dir, "tacotron_events")
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)
    os.makedirs(wav_dir, exist_ok=True)
    os.makedirs(mel_dir, exist_ok=True)
    os.makedirs(eval_dir, exist_ok=True)
    os.makedirs(eval_plot_dir, exist_ok=True)
    os.makedirs(eval_wav_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)

    checkpoint_fpath = os.path.join(save_dir, "tacotron_model.ckpt")

    log("Checkpoint path: {}".format(checkpoint_fpath))
    log("Using model: Tacotron")
    log(hparams_debug_string())

    # Start by setting a seed for repeatability
    tf.set_random_seed(hparams.tacotron_random_seed)

    # Set up data feeder
    coord = tf.train.Coordinator()
    with tf.variable_scope("datafeeder") as scope:
        feeder = Feeder(coord, hparams)

    # Set up model:
    global_step = tf.Variable(0, name="global_step", trainable=False)
    model, stats = model_train_mode(args, feeder, hparams, global_step)
    #eval_model = model_test_mode(args, feeder, hparams, global_step)

    # Bookkeeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=2)

    log("Tacotron training set to a maximum of {} steps".format(
        args.tacotron_train_steps))

    # Memory allocation on the GPU as needed
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    # Train
    with tf.Session(config=config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(tensorboard_dir, sess.graph)

            sess.run(tf.global_variables_initializer())

            # saved model restoring
            if args.restore:
                # Restore saved model if the user requested it, default = True
                try:
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)

                    if checkpoint_state and checkpoint_state.model_checkpoint_path:
                        log("Loading checkpoint {}".format(
                            checkpoint_state.model_checkpoint_path),
                            slack=True)
                        saver.restore(sess,
                                      checkpoint_state.model_checkpoint_path)

                    else:
                        log("No model to load at {}".format(save_dir),
                            slack=True)
                        saver.save(sess,
                                   checkpoint_fpath,
                                   global_step=global_step)

                except tf.errors.OutOfRangeError as e:
                    log("Cannot restore checkpoint: {}".format(e), slack=True)
            else:
                log("Starting new training!", slack=True)
                saver.save(sess, checkpoint_fpath, global_step=global_step)

            # initializing feeder
            feeder.start_threads(sess)
            print("Feeder is intialized and model is ready to train.......")

            # Training loop
            while not coord.should_stop() and step < args.tacotron_train_steps:
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = "Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]".format(
                    step, time_window.average, loss, loss_window.average)
                log(message,
                    end="\r",
                    slack=(step % args.checkpoint_interval == 0))
                print(message)

                if loss > 100 or np.isnan(loss):
                    log("Loss exploded to {:.5f} at step {}".format(
                        loss, step))
                    raise Exception("Loss exploded")

                if step % args.summary_interval == 0:
                    log("\nWriting summary at step {}".format(step))
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.eval_interval == 0:
                    pass  # evaluation is disabled in this variant

                if step % args.checkpoint_interval == 0 or step == args.tacotron_train_steps or \
                        step == 300:
                    # Save model and current global step
                    saver.save(sess, checkpoint_fpath, global_step=global_step)

                    log("\nSaving alignment, Mel-Spectrograms and griffin-lim inverted waveform.."
                        )
                    input_seq, mel_prediction, alignment, target, target_length = sess.run(
                        [
                            model.tower_inputs[0][0],
                            model.tower_mel_outputs[0][0],
                            model.tower_alignments[0][0],
                            model.tower_mel_targets[0][0],
                            model.tower_targets_lengths[0][0],
                        ])

                    # save predicted mel spectrogram to disk (debug)
                    mel_filename = "mel-prediction-step-{}.npy".format(step)
                    np.save(os.path.join(mel_dir, mel_filename),
                            mel_prediction.T,
                            allow_pickle=False)

                    # save griffin lim inverted wav for debug (mel -> wav)
                    wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
                    audio.save_wav(
                        wav,
                        os.path.join(wav_dir,
                                     "step-{}-wave-from-mel.wav".format(step)),
                        sr=hparams.sample_rate)

                    # save alignment plot to disk (control purposes)
                    plot.plot_alignment(
                        alignment,
                        os.path.join(plot_dir,
                                     "step-{}-align.png".format(step)),
                        title="{}, {}, step={}, loss={:.5f}".format(
                            "Tacotron", time_string(), step, loss),
                        max_len=target_length // hparams.outputs_per_step)
                    # save real and predicted mel-spectrogram plot to disk (control purposes)
                    plot.plot_spectrogram(
                        mel_prediction,
                        os.path.join(
                            plot_dir,
                            "step-{}-mel-spectrogram.png".format(step)),
                        title="{}, {}, step={}, loss={:.5f}".format(
                            "Tacotron", time_string(), step, loss),
                        target_spectrogram=target,
                        max_len=target_length)

                if step % args.embedding_interval == 0 or step == args.tacotron_train_steps or step == 1:
                    # Get current checkpoint state (the embedding-projector update is disabled in this variant)
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)

            log("Tacotron training complete after {} global steps!".format(
                args.tacotron_train_steps),
                slack=True)
            return save_dir

        except Exception as e:
            log("Exiting due to exception: {}".format(e), slack=True)
            traceback.print_exc()
            coord.request_stop(e)