Code Example #1
File: synthesize.py Project: iPRoBe-lab/DeepTalk
def run_eval(args, checkpoint_path, output_dir, hparams, sentences):
    eval_dir = os.path.join(output_dir, "eval")
    log_dir = os.path.join(output_dir, "logs-eval")
    
    # Create output path if it doesn't exist
    os.makedirs(eval_dir, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(os.path.join(log_dir, "wavs"), exist_ok=True)
    os.makedirs(os.path.join(log_dir, "plots"), exist_ok=True)
    
    log(hparams_debug_string())
    synth = Tacotron2(checkpoint_path, hparams)
    
    # Set inputs batch-wise
    sentences = [sentences[i: i+hparams.tacotron_synthesis_batch_size] for i 
                 in range(0, len(sentences), hparams.tacotron_synthesis_batch_size)]
    
    log("Starting Synthesis")
    with open(os.path.join(eval_dir, "map.txt"), "w") as file:
        for i, texts in enumerate(tqdm(sentences)):
            start = time.time()
            basenames = ["batch_{}_sentence_{}".format(i, j) for j in range(len(texts))]
            mel_filenames, speaker_ids = synth.synthesize(texts, basenames, eval_dir, log_dir, None)
            
            for elems in zip(texts, mel_filenames, speaker_ids):
                file.write("|".join([str(x) for x in elems]) + "\n")
    log("synthesized mel spectrograms at {}".format(eval_dir))
    return eval_dir
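
The list-slicing pattern above, which splits sentences into fixed-size batches, can be checked in isolation. A minimal standalone sketch of the same idiom (the chunk helper and the toy data are illustrative, not part of the project):

# Minimal sketch of the batching idiom used in run_eval; chunk and
# batch_size are illustrative names, not part of the project.
def chunk(items, batch_size):
    """Split items into consecutive batches of at most batch_size elements."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

sentences = ["first sentence", "second sentence", "third sentence"]
print(chunk(sentences, batch_size=2))
# [['first sentence', 'second sentence'], ['third sentence']]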
Code Example #2
File: synthesize_ph.py Project: peter-yh-wu/dl
def run_synthesis(in_dir, out_dir, model_dir, hparams):
    synth_dir = os.path.join(out_dir, "mels_gta")
    os.makedirs(synth_dir, exist_ok=True)
    metadata_filename = os.path.join(in_dir, "train_ph.txt")
    print(hparams_debug_string())

    # Load the model in memory
    weights_dir = os.path.join(model_dir, "taco_pretrained")
    checkpoint_fpath = tf.train.get_checkpoint_state(
        weights_dir).model_checkpoint_path
    synth = Tacotron2(checkpoint_fpath, hparams, gta=True)

    # Load the metadata
    with open(metadata_filename, encoding="utf-8") as f:
        metadata = [line.strip().split("|") for line in f]
        frame_shift_ms = hparams.hop_size / hparams.sample_rate
        hours = sum([int(x[4]) for x in metadata]) * frame_shift_ms / 3600
        print("Loaded metadata for {} examples ({:.2f} hours)".format(
            len(metadata), hours))

    # Set inputs batch-wise
    metadata = [
        metadata[i:i + hparams.tacotron_synthesis_batch_size]
        for i in range(0, len(metadata), hparams.tacotron_synthesis_batch_size)
    ]
    # Drop the last, smaller batch so that all batches have the same size
    metadata = metadata[:-1]

    print("Starting Synthesis")
    mel_dir = os.path.join(in_dir, "stft")
    embed_dir = os.path.join(in_dir, "speaker_emb")
    text_embed_dir = os.path.join(in_dir, "devise")
    meta_out_fpath = os.path.join(out_dir, "synthesized.txt")
    with open(meta_out_fpath, "w") as file:
        for i, meta in enumerate(tqdm(metadata)):
            texts = [m[5] for m in meta]
            mel_filenames = [mel_dir + m[1] for m in meta]
            embed_filenames = [
                os.path.join(embed_dir, m[2].replace('embed', 'mbed'))
                for m in meta
            ]
            text_embed_filenames = [
                os.path.join(text_embed_dir, m[2]) for m in meta
            ]

            basenames = [
                os.path.basename(m).replace(".npy", "").replace("mel-", "")
                for m in mel_filenames
            ]
            synth.synthesize(texts, basenames, synth_dir, None, mel_filenames,
                             embed_filenames, text_embed_filenames)

            for elems in meta:
                file.write("|".join([str(x) for x in elems]) + "\n")

    print("Synthesized mel spectrograms at {}".format(synth_dir))
    return meta_out_fpath
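
Note that mel_filenames above is built with plain string concatenation (mel_dir + m[1]) rather than os.path.join, so it only resolves to a valid path if m[1] already starts with a separator or mel_dir ends with one. A small illustrative sketch of the difference (the paths are made up for the example):

import os

# Illustrative paths only; concatenation omits the separator that
# os.path.join inserts between components.
mel_dir = "training_data/stft"
name = "mel-0001.npy"
print(mel_dir + name)               # training_data/stftmel-0001.npy
print(os.path.join(mel_dir, name))  # training_data/stft/mel-0001.npy (POSIX)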
Code Example #3
def run_synthesis(in_dir, out_dir, model_dir, hparams):
    synth_dir = os.path.join(out_dir, "mels_gta")
    os.makedirs(synth_dir, exist_ok=True)
    metadata_filename = os.path.join(in_dir, "train.txt")
    print(hparams_debug_string())

    # Load the model in memory
    weights_dir = os.path.join(model_dir, "taco_pretrained")
    checkpoint_fpath = tf.train.get_checkpoint_state(
        weights_dir).model_checkpoint_path
    checkpoint_fpath = checkpoint_fpath.replace(
        '/ssd_scratch/cvit/rudra/SV2TTS/', '')
    checkpoint_fpath = checkpoint_fpath.replace('logs-', '')
    synth = Tacotron2(checkpoint_fpath, hparams, gta=True)

    # Load the metadata
    with open(metadata_filename, encoding="utf-8") as f:
        metadata = [line.strip().split("|") for line in f][:149736]
        frame_shift_ms = hparams.hop_size / hparams.sample_rate
        hours = sum([int(x[4]) for x in metadata]) * frame_shift_ms / 3600
        print("Loaded metadata for {} examples ({:.2f} hours)".format(
            len(metadata), hours))

    # Set inputs batch-wise
    metadata = [
        metadata[i:i + hparams.tacotron_synthesis_batch_size]
        for i in range(0, len(metadata), hparams.tacotron_synthesis_batch_size)
    ]
    # TODO: handle the final batch properly instead of discarding it
    # Quick and dirty fix to make sure that all batches have the same size
    metadata = metadata[:-1]

    print("Starting Synthesis")
    mel_dir = os.path.join(in_dir, "mels")
    embed_dir = os.path.join(in_dir, "embeds")
    meta_out_fpath = os.path.join(out_dir, "synthesized.txt")
    with open(meta_out_fpath, "w") as file:
        for i, meta in enumerate(tqdm(metadata)):
            texts = [m[5] for m in meta]
            mel_filenames = [os.path.join(mel_dir, m[1]) for m in meta]
            embed_filenames = [os.path.join(embed_dir, m[2]) for m in meta]
            basenames = [
                os.path.basename(m).replace(".npy", "").replace("mel-", "")
                for m in mel_filenames
            ]
            synth.synthesize(texts, basenames, synth_dir, None, mel_filenames,
                             embed_filenames)

            for elems in meta:
                file.write("|".join([str(x) for x in elems]) + "\n")

    print("Synthesized mel spectrograms at {}".format(synth_dir))
    return meta_out_fpath
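
The quick fix above discards the final, smaller batch entirely (metadata = metadata[:-1]). One hedged alternative, sketched below, pads the last batch by repeating its final entry so every batch has the same size; pad_last_batch is a hypothetical helper, and the padded entries would need to be skipped when writing synthesized.txt:

# Hypothetical alternative to discarding the final batch: pad it by
# repeating its last element so all batches share the same size.
def pad_last_batch(batches, batch_size):
    if batches and len(batches[-1]) < batch_size:
        last = batches[-1]
        last.extend([last[-1]] * (batch_size - len(last)))
    return batches

batches = [["a", "b", "c"], ["d", "e"]]
print(pad_last_batch(batches, batch_size=3))
# [['a', 'b', 'c'], ['d', 'e', 'e']]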
Code Example #4
def run_live(args, checkpoint_path, hparams):
    # Log to Terminal without keeping any records in files
    log(hparams_debug_string())
    synth = Synthesizer()
    synth.load(checkpoint_path, hparams)

    # Generate a fast greeting message
    greetings = "Hello, welcome to the Live testing tool. Please type a message and I will try " \
                "to read it!"
    log(greetings)
    generate_fast(synth, greetings)

    # Interaction loop
    while True:
        try:
            text = input()
            generate_fast(synth, text)

        except KeyboardInterrupt:
            leave = "Thank you for testing our features. See you soon."
            log(leave)
            generate_fast(synth, leave)
            sleep(2)
            break
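
The generate_fast helper called above is not shown in this snippet. A plausible minimal sketch is given below; the synthesize call and its argument list are assumptions about the Synthesizer interface, not code taken from the project:

# Hypothetical sketch of generate_fast; the synthesize(...) signature is
# an assumption, not the project's actual implementation.
def generate_fast(synth, text):
    # Synthesize a single utterance without writing wavs or plots to disk.
    synth.synthesize(text, None, None, None, None)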
Code Example #5
def train(log_dir, args, hparams):
    save_dir = os.path.join(log_dir, "taco_pretrained")
    plot_dir = os.path.join(log_dir, "plots")
    wav_dir = os.path.join(log_dir, "wavs")
    mel_dir = os.path.join(log_dir, "mel-spectrograms")
    eval_dir = os.path.join(log_dir, "eval-dir")
    eval_plot_dir = os.path.join(eval_dir, "plots")
    eval_wav_dir = os.path.join(eval_dir, "wavs")
    tensorboard_dir = os.path.join(log_dir, "tacotron_events")
    meta_folder = os.path.join(log_dir, "metas")
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)
    os.makedirs(wav_dir, exist_ok=True)
    os.makedirs(mel_dir, exist_ok=True)
    os.makedirs(eval_dir, exist_ok=True)
    os.makedirs(eval_plot_dir, exist_ok=True)
    os.makedirs(eval_wav_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)
    os.makedirs(meta_folder, exist_ok=True)

    checkpoint_fpath = os.path.join(save_dir, "tacotron_model.ckpt")
    if hparams.if_use_speaker_classifier:
        metadat_fpath = os.path.join(args.synthesizer_root,
                                     "train_augment_speaker.txt")
    else:
        metadat_fpath = os.path.join(args.synthesizer_root, "train.txt")

    log("Checkpoint path: {}".format(checkpoint_fpath))
    log("Loading training data from: {}".format(metadat_fpath))
    log("Using model: Tacotron")
    log(hparams_debug_string())

    # Start by setting a seed for repeatability
    tf.set_random_seed(hparams.tacotron_random_seed)

    # Set up data feeder
    coord = tf.train.Coordinator()
    with tf.variable_scope("datafeeder") as scope:
        feeder = Feeder(coord, metadat_fpath, hparams)

    # Set up model:
    global_step = tf.Variable(0, name="global_step", trainable=False)
    model, stats = model_train_mode(args, feeder, hparams, global_step)
    eval_model = model_test_mode(args, feeder, hparams, global_step)

    # Embeddings metadata
    char_embedding_meta = os.path.join(meta_folder, "CharacterEmbeddings.tsv")
    if not os.path.isfile(char_embedding_meta):
        with open(char_embedding_meta, "w", encoding="utf-8") as f:
            for symbol in symbols:
                if symbol == " ":
                    symbol = "\\s"  # For visual purposes, swap space with \s

                f.write("{}\n".format(symbol))

    char_embedding_meta = char_embedding_meta.replace(log_dir, "..")

    # Book keeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=5)

    log("Tacotron training set to a maximum of {} steps".format(
        args.tacotron_train_steps))

    # Memory allocation on the GPU as needed
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    # Train
    with tf.Session(config=config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(tensorboard_dir, sess.graph)

            sess.run(tf.global_variables_initializer())

            # saved model restoring
            if args.restore:
                # Restore saved model if the user requested it, default = True
                try:
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)

                    if checkpoint_state and checkpoint_state.model_checkpoint_path:
                        log("Loading checkpoint {}".format(
                            checkpoint_state.model_checkpoint_path),
                            slack=True)
                        saver.restore(sess,
                                      checkpoint_state.model_checkpoint_path)

                    else:
                        log("No model to load at {}".format(save_dir),
                            slack=True)
                        saver.save(sess,
                                   checkpoint_fpath,
                                   global_step=global_step)

                except tf.errors.OutOfRangeError as e:
                    log("Cannot restore checkpoint: {}".format(e), slack=True)
            else:
                log("Starting new training!", slack=True)
                saver.save(sess, checkpoint_fpath, global_step=global_step)

            # initializing feeder
            feeder.start_threads(sess)

            # Training loop
            while not coord.should_stop() and step < args.tacotron_train_steps:
                start_time = time.time()
                step, loss, adversial_loss, opt = sess.run([
                    global_step, model.loss, model.adversial_loss,
                    model.optimize
                ])
                loss -= adversial_loss
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = "Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}, adv_loss={:.5f}]".format(
                    step, time_window.average, loss, loss_window.average,
                    adversial_loss)
                log(message,
                    end="\r",
                    slack=(step % args.checkpoint_interval == 0))
                print(message)

                if loss > 100 or np.isnan(loss):
                    log("Loss exploded to {:.5f} at step {}".format(
                        loss, step))
                    raise Exception("Loss exploded")

                if step % args.summary_interval == 0:
                    log("\nWriting summary at step {}".format(step))
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.eval_interval == 0:
                    # Run eval and save eval stats
                    log("\nRunning evaluation at step {}".format(step))

                    eval_losses = []
                    before_losses = []
                    after_losses = []
                    stop_token_losses = []
                    linear_losses = []
                    linear_loss = None
                    adversial_losses = []

                    if hparams.predict_linear:
                        for i in tqdm(range(feeder.test_steps)):
                            (eloss, before_loss, after_loss, stop_token_loss,
                             linear_loss, mel_p, mel_t, t_len, align,
                             lin_p, lin_t) = sess.run([
                                 eval_model.tower_loss[0],
                                 eval_model.tower_before_loss[0],
                                 eval_model.tower_after_loss[0],
                                 eval_model.tower_stop_token_loss[0],
                                 eval_model.tower_linear_loss[0],
                                 eval_model.tower_mel_outputs[0][0],
                                 eval_model.tower_mel_targets[0][0],
                                 eval_model.tower_targets_lengths[0][0],
                                 eval_model.tower_alignments[0][0],
                                 eval_model.tower_linear_outputs[0][0],
                                 eval_model.tower_linear_targets[0][0],
                             ])
                            eval_losses.append(eloss)
                            before_losses.append(before_loss)
                            after_losses.append(after_loss)
                            stop_token_losses.append(stop_token_loss)
                            linear_losses.append(linear_loss)
                        linear_loss = sum(linear_losses) / len(linear_losses)

                        wav = audio.inv_linear_spectrogram(lin_p.T, hparams)
                        audio.save_wav(
                            wav,
                            os.path.join(
                                eval_wav_dir,
                                "step-{}-eval-wave-from-linear.wav".format(
                                    step)),
                            sr=hparams.sample_rate)

                    else:
                        for i in tqdm(range(feeder.test_steps)):
                            (eloss, before_loss, after_loss, stop_token_loss,
                             adversial_loss, mel_p, mel_t, t_len,
                             align) = sess.run([
                                 eval_model.tower_loss[0],
                                 eval_model.tower_before_loss[0],
                                 eval_model.tower_after_loss[0],
                                 eval_model.tower_stop_token_loss[0],
                                 eval_model.tower_adversial_loss[0],
                                 eval_model.tower_mel_outputs[0][0],
                                 eval_model.tower_mel_targets[0][0],
                                 eval_model.tower_targets_lengths[0][0],
                                 eval_model.tower_alignments[0][0],
                             ])
                            eval_losses.append(eloss)
                            before_losses.append(before_loss)
                            after_losses.append(after_loss)
                            stop_token_losses.append(stop_token_loss)
                            adversial_losses.append(adversial_loss)

                    eval_loss = sum(eval_losses) / len(eval_losses)
                    before_loss = sum(before_losses) / len(before_losses)
                    after_loss = sum(after_losses) / len(after_losses)
                    stop_token_loss = sum(stop_token_losses) / len(
                        stop_token_losses)
                    adversial_loss = sum(adversial_losses) / len(
                        adversial_losses)

                    log("Saving eval log to {}..".format(eval_dir))
                    # Save some log to monitor model improvement on same unseen sequence
                    wav = audio.inv_mel_spectrogram(mel_p.T, hparams)
                    audio.save_wav(
                        wav,
                        os.path.join(
                            eval_wav_dir,
                            "step-{}-eval-wave-from-mel.wav".format(step)),
                        sr=hparams.sample_rate)

                    plot.plot_alignment(
                        align,
                        os.path.join(eval_plot_dir,
                                     "step-{}-eval-align.png".format(step)),
                        title="{}, {}, step={}, loss={:.5f}".format(
                            "Tacotron", time_string(), step, eval_loss),
                        max_len=t_len // hparams.outputs_per_step)
                    plot.plot_spectrogram(
                        mel_p,
                        os.path.join(
                            eval_plot_dir,
                            "step-{}-eval-mel-spectrogram.png".format(step)),
                        title="{}, {}, step={}, loss={:.5f}".format(
                            "Tacotron", time_string(), step, eval_loss),
                        target_spectrogram=mel_t,
                        max_len=t_len)

                    if hparams.predict_linear:
                        plot.plot_spectrogram(
                            lin_p,
                            os.path.join(
                                eval_plot_dir,
                                "step-{}-eval-linear-spectrogram.png".format(
                                    step)),
                            title="{}, {}, step={}, loss={:.5f}".format(
                                "Tacotron", time_string(), step, eval_loss),
                            target_spectrogram=lin_t,
                            max_len=t_len,
                            auto_aspect=True)

                    log("Eval loss for global step {}: {:.3f}".format(
                        step, eval_loss))
                    log("Writing eval summary!")
                    add_eval_stats(summary_writer, step, linear_loss,
                                   before_loss, after_loss, stop_token_loss,
                                   adversial_loss, eval_loss)

                if step % args.checkpoint_interval == 0 or step == args.tacotron_train_steps or \
                        step == 300:
                    # Save model and current global step
                    saver.save(sess, checkpoint_fpath, global_step=global_step)

                    log("\nSaving alignment, Mel-Spectrograms and griffin-lim inverted waveform.."
                        )
                    input_seq, mel_prediction, alignment, target, target_length = sess.run(
                        [
                            model.tower_inputs[0][0],
                            model.tower_mel_outputs[0][0],
                            model.tower_alignments[0][0],
                            model.tower_mel_targets[0][0],
                            model.tower_targets_lengths[0][0],
                        ])

                    # save predicted mel spectrogram to disk (debug)
                    mel_filename = "mel-prediction-step-{}.npy".format(step)
                    np.save(os.path.join(mel_dir, mel_filename),
                            mel_prediction.T,
                            allow_pickle=False)

                    # save griffin lim inverted wav for debug (mel -> wav)
                    wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
                    audio.save_wav(
                        wav,
                        os.path.join(wav_dir,
                                     "step-{}-wave-from-mel.wav".format(step)),
                        sr=hparams.sample_rate)

                    # save alignment plot to disk (control purposes)
                    plot.plot_alignment(
                        alignment,
                        os.path.join(plot_dir,
                                     "step-{}-align.png".format(step)),
                        title="{}, {}, step={}, loss={:.5f}".format(
                            "Tacotron", time_string(), step, loss),
                        max_len=target_length // hparams.outputs_per_step)
                    # save real and predicted mel-spectrogram plot to disk (control purposes)
                    plot.plot_spectrogram(
                        mel_prediction,
                        os.path.join(
                            plot_dir,
                            "step-{}-mel-spectrogram.png".format(step)),
                        title="{}, {}, step={}, loss={:.5f}".format(
                            "Tacotron", time_string(), step, loss),
                        target_spectrogram=target,
                        max_len=target_length)
                    #log("Input at step {}: {}".format(step, sequence_to_text(input_seq)))

                if step % args.embedding_interval == 0 or step == args.tacotron_train_steps or step == 1:
                    # Get current checkpoint state
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)

                    # Update Projector
                    #log("\nSaving Model Character Embeddings visualization..")
                    #add_embedding_stats(summary_writer, [model.embedding_table.name],
                    #                    [char_embedding_meta],
                    #                    checkpoint_state.model_checkpoint_path)
                    #log("Tacotron Character embeddings have been updated on tensorboard!")

            log("Tacotron training complete after {} global steps!".format(
                args.tacotron_train_steps),
                slack=True)
            return save_dir

        except Exception as e:
            log("Exiting due to exception: {}".format(e), slack=True)
            traceback.print_exc()
            coord.request_stop(e)
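
The ValueWindow(100) objects above keep a sliding window over the most recent step times and losses so that the logged averages are smoothed. A minimal sketch of such a class, written here only as an illustration (the project's own implementation may differ in detail):

# Illustrative sliding-window average, similar in spirit to the
# ValueWindow used for time_window and loss_window above.
class ValueWindow:
    def __init__(self, window_size=100):
        self._window_size = window_size
        self._values = []

    def append(self, x):
        # Keep only the most recent window_size values.
        self._values = (self._values + [x])[-self._window_size:]

    @property
    def average(self):
        return sum(self._values) / max(len(self._values), 1)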
Code Example #6
def run_synthesis(in_dir, out_dir, model_dir, hparams):
    # This generates ground truth-aligned mels for vocoder training
    synth_dir = Path(out_dir).joinpath("mels_gta")
    synth_dir.mkdir(exist_ok=True)
    print(hparams_debug_string(hparams))

    # Check for GPU
    if torch.cuda.is_available():
        device = torch.device("cuda")
        if hparams.synthesis_batch_size % torch.cuda.device_count() != 0:
            raise ValueError("`hparams.synthesis_batch_size` must be evenly divisible by n_gpus!")
    else:
        device = torch.device("cpu")
    print("Synthesizer using device:", device)

    # Instantiate Tacotron model
    model = Tacotron(embed_dims=hparams.tts_embed_dims,
                     num_chars=len(symbols),
                     encoder_dims=hparams.tts_encoder_dims,
                     decoder_dims=hparams.tts_decoder_dims,
                     n_mels=hparams.num_mels,
                     fft_bins=hparams.num_mels,
                     postnet_dims=hparams.tts_postnet_dims,
                     encoder_K=hparams.tts_encoder_K,
                     lstm_dims=hparams.tts_lstm_dims,
                     postnet_K=hparams.tts_postnet_K,
                     num_highways=hparams.tts_num_highways,
                     dropout=0., # Use zero dropout for gta mels
                     stop_threshold=hparams.tts_stop_threshold,
                     speaker_embedding_size=hparams.speaker_embedding_size).to(device)

    # Load the weights
    model_dir = Path(model_dir)
    model_fpath = model_dir.joinpath(model_dir.stem).with_suffix(".pt")
    print("\nLoading weights at %s" % model_fpath)
    model.load(model_fpath)
    print("Tacotron weights loaded from step %d" % model.step)

    # Synthesize using same reduction factor as the model is currently trained
    r = np.int32(model.r)

    # Set model to eval mode (disable gradient and zoneout)
    model.eval()

    # Initialize the dataset
    in_dir = Path(in_dir)
    metadata_fpath = in_dir.joinpath("train.txt")
    mel_dir = in_dir.joinpath("mels")
    embed_dir = in_dir.joinpath("embeds")

    dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
    data_loader = DataLoader(dataset,
                             collate_fn=lambda batch: collate_synthesizer(batch, r, hparams),
                             batch_size=hparams.synthesis_batch_size,
                             num_workers=0,  # On Windows, num_workers must be 0 to avoid "Can't pickle local object 'run_synthesis.<locals>.<lambda>'" errors
                             shuffle=False,
                             pin_memory=True)

    # Generate GTA mels
    meta_out_fpath = Path(out_dir).joinpath("synthesized.txt")
    with open(meta_out_fpath, "w") as file:
        for i, (texts, mels, embeds, idx) in tqdm(enumerate(data_loader), total=len(data_loader)):
            texts = texts.to(device)
            mels = mels.to(device)
            embeds = embeds.to(device)

            # Parallelize the model onto GPUs using a workaround for a Python bug
            if device.type == "cuda" and torch.cuda.device_count() > 1:
                _, mels_out, _, _ = data_parallel_workaround(model, texts, mels, embeds)
            else:
                _, mels_out, _, _ = model(texts, mels, embeds)

            for j, k in enumerate(idx):
                # Note: outputs mel-spectrogram files and target ones have same names, just different folders
                mel_filename = Path(synth_dir).joinpath(dataset.metadata[k][1])
                mel_out = mels_out[j].detach().cpu().numpy().T

                # Use the length of the ground truth mel to remove padding from the generated mels
                mel_out = mel_out[:int(dataset.metadata[k][4])]

                # Write the spectrogram to disk
                np.save(mel_filename, mel_out, allow_pickle=False)

                # Write metadata into the synthesized file
                file.write("|".join(dataset.metadata[k]))
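
Each GTA mel is saved as a .npy file that keeps the name of its ground-truth counterpart, so an output can be sanity-checked directly. A short hedged sketch (the file name below is a placeholder, not a real file from the run):

import numpy as np

# Placeholder file name; real files follow the names recorded in the
# metadata written to synthesized.txt.
mel = np.load("mels_gta/mel-example.npy")
# Shape is expected to be (n_frames, num_mels) given the transpose above.
print(mel.shape)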
Code Example #7
def train(log_dir, args, hparams):
    save_dir = os.path.join(log_dir, "taco_pretrained")
    plot_dir = os.path.join(log_dir, "plots")
    wav_dir = os.path.join(log_dir, "wavs")
    mel_dir = os.path.join(log_dir, "mel-spectrograms")
    eval_dir = os.path.join(log_dir, "eval-dir")
    eval_plot_dir = os.path.join(eval_dir, "plots")
    eval_wav_dir = os.path.join(eval_dir, "wavs")
    tensorboard_dir = os.path.join(log_dir, "tacotron_events")
    os.makedirs(save_dir, exist_ok=True)
    os.makedirs(plot_dir, exist_ok=True)
    os.makedirs(wav_dir, exist_ok=True)
    os.makedirs(mel_dir, exist_ok=True)
    os.makedirs(eval_dir, exist_ok=True)
    os.makedirs(eval_plot_dir, exist_ok=True)
    os.makedirs(eval_wav_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)

    checkpoint_fpath = os.path.join(save_dir, "tacotron_model.ckpt")

    log("Checkpoint path: {}".format(checkpoint_fpath))
    log("Using model: Tacotron")
    log(hparams_debug_string())

    # Start by setting a seed for repeatability
    tf.set_random_seed(hparams.tacotron_random_seed)

    # Set up data feeder
    coord = tf.train.Coordinator()
    with tf.variable_scope("datafeeder") as scope:
        feeder = Feeder(coord, hparams)

    # Set up model:
    global_step = tf.Variable(0, name="global_step", trainable=False)
    model, stats = model_train_mode(args, feeder, hparams, global_step)
    #eval_model = model_test_mode(args, feeder, hparams, global_step)

    # Book keeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)
    saver = tf.train.Saver(max_to_keep=2)

    log("Tacotron training set to a maximum of {} steps".format(
        args.tacotron_train_steps))

    # Memory allocation on the GPU as needed
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.allow_soft_placement = True

    # Train
    with tf.Session(config=config) as sess:
        try:
            summary_writer = tf.summary.FileWriter(tensorboard_dir, sess.graph)

            sess.run(tf.global_variables_initializer())

            # saved model restoring
            if args.restore:
                # Restore saved model if the user requested it, default = True
                try:
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)

                    if checkpoint_state and checkpoint_state.model_checkpoint_path:
                        log("Loading checkpoint {}".format(
                            checkpoint_state.model_checkpoint_path),
                            slack=True)
                        saver.restore(sess,
                                      checkpoint_state.model_checkpoint_path)

                    else:
                        log("No model to load at {}".format(save_dir),
                            slack=True)
                        saver.save(sess,
                                   checkpoint_fpath,
                                   global_step=global_step)

                except tf.errors.OutOfRangeError as e:
                    log("Cannot restore checkpoint: {}".format(e), slack=True)
            else:
                log("Starting new training!", slack=True)
                saver.save(sess, checkpoint_fpath, global_step=global_step)

            # initializing feeder
            feeder.start_threads(sess)
            print("Feeder is intialized and model is ready to train.......")

            # Training loop
            while not coord.should_stop() and step < args.tacotron_train_steps:
                start_time = time.time()
                step, loss, opt = sess.run(
                    [global_step, model.loss, model.optimize])
                time_window.append(time.time() - start_time)
                loss_window.append(loss)
                message = "Step {:7d} [{:.3f} sec/step, loss={:.5f}, avg_loss={:.5f}]".format(
                    step, time_window.average, loss, loss_window.average)
                log(message,
                    end="\r",
                    slack=(step % args.checkpoint_interval == 0))
                print(message)

                if loss > 100 or np.isnan(loss):
                    log("Loss exploded to {:.5f} at step {}".format(
                        loss, step))
                    raise Exception("Loss exploded")

                if step % args.summary_interval == 0:
                    log("\nWriting summary at step {}".format(step))
                    summary_writer.add_summary(sess.run(stats), step)

                if step % args.eval_interval == 0:
                    # Evaluation is disabled in this variant (eval_model is commented out above)
                    pass


                if step % args.checkpoint_interval == 0 or step == args.tacotron_train_steps or \
                        step == 300:
                    # Save model and current global step
                    saver.save(sess, checkpoint_fpath, global_step=global_step)

                    log("\nSaving alignment, Mel-Spectrograms and griffin-lim inverted waveform.."
                        )
                    input_seq, mel_prediction, alignment, target, target_length = sess.run(
                        [
                            model.tower_inputs[0][0],
                            model.tower_mel_outputs[0][0],
                            model.tower_alignments[0][0],
                            model.tower_mel_targets[0][0],
                            model.tower_targets_lengths[0][0],
                        ])

                    # save predicted mel spectrogram to disk (debug)
                    mel_filename = "mel-prediction-step-{}.npy".format(step)
                    np.save(os.path.join(mel_dir, mel_filename),
                            mel_prediction.T,
                            allow_pickle=False)

                    # save griffin lim inverted wav for debug (mel -> wav)
                    wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
                    audio.save_wav(
                        wav,
                        os.path.join(wav_dir,
                                     "step-{}-wave-from-mel.wav".format(step)),
                        sr=hparams.sample_rate)

                    # save alignment plot to disk (control purposes)
                    plot.plot_alignment(
                        alignment,
                        os.path.join(plot_dir,
                                     "step-{}-align.png".format(step)),
                        title="{}, {}, step={}, loss={:.5f}".format(
                            "Tacotron", time_string(), step, loss),
                        max_len=target_length // hparams.outputs_per_step)
                    # save real and predicted mel-spectrogram plot to disk (control purposes)
                    plot.plot_spectrogram(
                        mel_prediction,
                        os.path.join(
                            plot_dir,
                            "step-{}-mel-spectrogram.png".format(step)),
                        title="{}, {}, step={}, loss={:.5f}".format(
                            "Tacotron", time_string(), step, loss),
                        target_spectrogram=target,
                        max_len=target_length)

                if step % args.embedding_interval == 0 or step == args.tacotron_train_steps or step == 1:
                    # Get current checkpoint state
                    checkpoint_state = tf.train.get_checkpoint_state(save_dir)

            log("Tacotron training complete after {} global steps!".format(
                args.tacotron_train_steps),
                slack=True)
            return save_dir

        except Exception as e:
            log("Exiting due to exception: {}".format(e), slack=True)
            traceback.print_exc()
            coord.request_stop(e)