Example #1
# Assumed imports for this example; `hp` (hyperparameters), `infolog`,
# `Graph`, `GraphTest`, `synthesize`, `plot_losses`, and `plot_wavs` come
# from the surrounding project and are not shown here.
import argparse
import os

import tensorflow as tf
from tqdm import tqdm


def main():
    print()
    parser = argparse.ArgumentParser()

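    # Paths and run options; defaults come from the hyperparameter module `hp`.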
    parser.add_argument('--log_dir', default=hp.logdir)
    parser.add_argument('--log_name', default=hp.logname)
    parser.add_argument('--sample_dir', default=hp.sampledir)
    parser.add_argument('--data_paths', default=hp.data)
    parser.add_argument('--load_path', default=None)
    parser.add_argument('--load_converter', default=None)
    parser.add_argument('--deltree', default=False)

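    # How often (in epochs) to write summaries, synthesize test audio, and
    # save checkpoints, plus the global-step limit for training.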
    parser.add_argument('--summary_interval',
                        type=int,
                        default=hp.summary_interval)
    parser.add_argument('--test_interval', type=int, default=hp.test_interval)
    parser.add_argument('--checkpoint_interval',
                        type=int,
                        default=hp.checkpoint_interval)
    parser.add_argument('--num_iterations',
                        type=int,
                        default=hp.num_iterations)

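    # Note: argparse's type=bool converts any non-empty string to True,
    # so '--debug False' on the command line still enables debug mode.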
    parser.add_argument('--debug', type=bool, default=False)

    config = parser.parse_args()

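    # Create the run's log directory; --deltree empties an existing one.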
    config.log_dir = os.path.join(config.log_dir, config.log_name)
    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    elif config.deltree:
        for the_file in os.listdir(config.log_dir):
            file_path = os.path.join(config.log_dir, the_file)
            os.unlink(file_path)
    log_path = os.path.join(config.log_dir, 'train.log')
    infolog.init(log_path, "log")
    checkpoint_path = os.path.join(config.log_dir, 'model.ckpt')

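    # Build the requested graphs: a training graph (unless running test-only),
    # an optional test graph, and a separate converter graph when needed.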
    if hp.test_only == 0:
        g = Graph(config=config, training=True, train_form=hp.train_form)
        print("Training Graph loaded")
    if hp.test_graph or (hp.test_only > 0):
        g2 = Graph(config=config, training=False, train_form=hp.train_form)
        print("Testing Graph loaded")
        if config.load_converter or (hp.test_only > 0):
            g_conv = Graph(config=config,
                           training=False,
                           train_form='Converter')
            print("Converter Graph loaded")
    if hp.test_only == 0:
        with g.graph.as_default():
            sv = tf.train.Supervisor(logdir=config.log_dir)
            with sv.managed_session() as sess:

                #sess = tf_debug.LocalCLIDebugWrapperSession(sess)
                if config.load_path:
                    # Restore from a checkpoint if the user requested it.
                    infolog.log('Resuming from checkpoint: %s ' %
                                (tf.train.latest_checkpoint(config.log_dir)),
                                slack=True)
                    sv.saver.restore(
                        sess, tf.train.latest_checkpoint(config.log_dir))
                else:
                    infolog.log('Starting new training', slack=True)

                summary_writer = tf.summary.FileWriter(config.log_dir,
                                                       sess.graph)

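                # Main training loop: one pass over all batches per epoch,
                # accumulating the component losses for per-epoch averages.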
                for epoch in range(1, 100000000):
                    if sv.should_stop(): break
                    losses = [0, 0, 0, 0]
                    #for step in tqdm(range(1)):
                    for step in tqdm(range(g.num_batch)):
                        #for step in range(g.num_batch):
                        if hp.train_form == 'Both':
                            if hp.include_dones:
                                gs, merged, loss, loss1, loss2, loss3, _ = sess.run(
                                    [
                                        g.global_step, g.merged, g.loss,
                                        g.loss1, g.loss2, g.loss3, g.train_op
                                    ])
                                loss_one = [loss, loss1, loss2, loss3]
                            else:
                                gs, merged, loss, loss1, loss3, _ = sess.run([
                                    g.global_step, g.merged, g.loss, g.loss1,
                                    g.loss3, g.train_op
                                ])
                                loss_one = [loss, loss1, loss3, 0]
                        elif hp.train_form == 'Encoder':
                            if hp.include_dones:
                                gs, merged, loss, loss1, loss2, _ = sess.run([
                                    g.global_step, g.merged, g.loss, g.loss1,
                                    g.loss2, g.train_op
                                ])
                                loss_one = [loss, loss1, loss2, 0]
                            else:
                                gs, merged, loss, _ = sess.run([
                                    g.global_step, g.merged, g.loss, g.train_op
                                ])
                                loss_one = [loss, 0, 0, 0]
                        else:
                            gs, merged, loss, _ = sess.run(
                                [g.global_step, g.merged, g.loss, g.train_op])
                            loss_one = [loss, 0, 0, 0]

                        losses = [x + y for x, y in zip(losses, loss_one)]

                    losses = [x / g.num_batch for x in losses]
                    print(
                        "###############################################################################"
                    )
                    if hp.train_form == 'Both':
                        if hp.include_dones:
                            infolog.log(
                                "Global Step %d (%04d): Loss = %.8f Loss1 = %.8f Loss2 = %.8f Loss3 = %.8f"
                                % (epoch, gs, losses[0], losses[1], losses[2],
                                   losses[3]))
                        else:
                            infolog.log(
                                "Global Step %d (%04d): Loss = %.8f Loss1 = %.8f Loss3 = %.8f"
                                % (epoch, gs, losses[0], losses[1], losses[2]))
                    elif hp.train_form == 'Encoder':
                        if hp.include_dones:
                            infolog.log(
                                "Global Step %d (%04d): Loss = %.8f Loss1 = %.8f Loss2 = %.8f"
                                % (epoch, gs, losses[0], losses[1], losses[2]))
                        else:
                            infolog.log("Global Step %d (%04d): Loss = %.8f" %
                                        (epoch, gs, losses[0]))
                    else:
                        infolog.log("Global Step %d (%04d): Loss = %.8f" %
                                    (epoch, gs, losses[0]))
                    print(
                        "###############################################################################"
                    )

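                    # Periodically write TensorBoard summaries and plot the
                    # predicted vs. target outputs.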
                    if epoch % config.summary_interval == 0:
                        infolog.log('Saving summary')
                        summary_writer.add_summary(merged, gs)
                        if hp.train_form == 'Both':
                            if hp.include_dones:
                                origx, Kmel_out, Ky1, Kdone_out, Ky2, Kmag_out, Ky3 = sess.run(
                                    [
                                        g.origx, g.mel_output, g.y1,
                                        g.done_output, g.y2, g.mag_output, g.y3
                                    ])
                                plot_losses(config, Kmel_out, Ky1, Kdone_out,
                                            Ky2, Kmag_out, Ky3, gs)
                            else:
                                origx, Kmel_out, Ky1, Kmag_out, Ky3 = sess.run(
                                    [
                                        g.origx, g.mel_output, g.y1,
                                        g.mag_output, g.y3
                                    ])
                                plot_losses(config, Kmel_out, Ky1, None, None,
                                            Kmag_out, Ky3, gs)
                        elif hp.train_form == 'Encoder':
                            if hp.include_dones:
                                origx, Kmel_out, Ky1, Kdone_out, Ky2 = sess.run(
                                    [
                                        g.origx, g.mel_output, g.y1,
                                        g.done_output, g.y2
                                    ])
                                plot_losses(config, Kmel_out, Ky1, Kdone_out,
                                            Ky2, None, None, gs)
                            else:
                                origx, Kmel_out, Ky1 = sess.run(
                                    [g.origx, g.mel_output, g.y1])
                                plot_losses(config, Kmel_out, Ky1, None, None,
                                            None, None, gs)
                        else:
                            origx, Kmag_out, Ky3 = sess.run(
                                [g.origx, g.mag_output, g.y3])
                            plot_losses(config, None, None, None, None,
                                        Kmag_out, Ky3, gs)

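                    # Periodically save a checkpoint.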
                    if epoch % config.checkpoint_interval == 0:
                        infolog.log('Saving checkpoint to: %s-%d' %
                                    (checkpoint_path, gs))
                        sv.saver.save(sess, checkpoint_path, global_step=gs)

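                    # Periodically synthesize audio from the test graph.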
                    if hp.test_graph and hp.train_form != 'Converter':
                        if epoch % config.test_interval == 0:
                            infolog.log('Saving audio')
                            origx = sess.run([g.origx])
                            if not config.load_converter:
                                wavs = synthesize.synthesize_part(
                                    g2, config, gs, origx, None)
                            else:
                                wavs = synthesize.synthesize_part(
                                    g2, config, gs, origx, g_conv)
                            plot_wavs(config, wavs, gs)

                    # break
                    if gs > config.num_iterations: break
    else:
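        # Test-only mode: fetch reference inputs from a test graph and
        # synthesize audio without training.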
        infolog.log('Saving audio')
        gT = GraphTest(config=config)
        with gT.graph.as_default():
            svT = tf.train.Supervisor(logdir=config.log_dir)
            with svT.managed_session() as sessT:
                origx = sessT.run([gT.origx])
            if not config.load_converter:
                wavs = synthesize.synthesize_part(g2, config, 0, origx, None)
            else:
                wavs = synthesize.synthesize_part(g2, config, 0, origx, g_conv)
            plot_wavs(config, wavs, 0)

    print("Done")
Example #2
# Assumed imports for this example; `hp`, `infolog`, `Graph`, `synthesize`,
# `plot_losses`, `plot_wavs`, and `get_most_recent_checkpoint` come from the
# surrounding project and are not shown here.
import argparse
import os

import tensorflow as tf
from tqdm import tqdm


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--log_dir', default=hp.logdir)
    parser.add_argument('--log_name', default=hp.logname)
    parser.add_argument('--sample_dir', default=hp.sampledir)
    parser.add_argument('--data_paths', default=hp.data)
    parser.add_argument('--load_path', default=None)
    parser.add_argument('--initialize_path', default=None)
    parser.add_argument('--deltree', default=False)

    parser.add_argument('--summary_interval',
                        type=int,
                        default=hp.summary_interval)
    parser.add_argument('--test_interval', type=int, default=hp.test_interval)
    parser.add_argument('--checkpoint_interval',
                        type=int,
                        default=hp.checkpoint_interval)
    parser.add_argument('--num_iterations',
                        type=int,
                        default=hp.num_iterations)

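    # As in Example #1, type=bool treats any non-empty string (even 'False') as True.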
    parser.add_argument('--debug', type=bool, default=False)

    config = parser.parse_args()
    config.log_dir = os.path.join(config.log_dir, config.log_name)
    if not os.path.exists(config.log_dir):
        os.makedirs(config.log_dir)
    elif config.deltree:
        for the_file in os.listdir(config.log_dir):
            file_path = os.path.join(config.log_dir, the_file)
            os.unlink(file_path)
    log_path = os.path.join(config.log_dir, 'train.log')
    infolog.init(log_path, "log")
    checkpoint_path = os.path.join(config.log_dir, 'model.ckpt')
    g = Graph(config=config)
    print("Training Graph loaded")
    g2 = Graph(config=config, training=False)
    print("Testing Graph loaded")
    with g.graph.as_default():
        sv = tf.train.Supervisor(logdir=config.log_dir)
        with sv.managed_session() as sess:

            #sess = tf_debug.LocalCLIDebugWrapperSession(sess)
            if config.load_path:
                # Restore from a checkpoint if the user requested it.
                restore_path = get_most_recent_checkpoint(config.load_path)
                sv.saver.restore(sess, restore_path)
                infolog.log('Resuming from checkpoint: %s ' % (restore_path),
                            slack=True)
            elif config.initialize_path:
                restore_path = get_most_recent_checkpoint(
                    config.initialize_path)
                sv.saver.restore(sess, restore_path)
                infolog.log('Initialized from checkpoint: %s ' %
                            (restore_path),
                            slack=True)
            else:
                infolog.log('Starting new training', slack=True)

            summary_writer = tf.summary.FileWriter(config.log_dir, sess.graph)

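            # Main training loop: accumulate the four component losses over all
            # batches and log their per-epoch averages.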
            for epoch in range(1, 100000000):
                if sv.should_stop(): break
                losses = [0, 0, 0, 0]
                #losses = [0,0,0]
                for step in tqdm(range(g.num_batch)):
                    #for step in range(g.num_batch):
                    #gs,merged,loss,loss1,loss3,alginm,_ = sess.run([g.global_step,g.merged,g.loss,g.loss1,g.loss3, g.alignments_li,g.train_op])
                    gs, merged, loss, loss1, loss2, loss3, alginm, _ = sess.run(
                        [
                            g.global_step, g.merged, g.loss, g.loss1, g.loss2,
                            g.loss3, g.alignments_li, g.train_op
                        ])
                    #infolog.log("Step %04d: Loss = %.8f Loss1 = %.8f Loss2 = %.8f Loss3 = %.8f" %(gs,loss,loss1,loss2,loss3))
                    #infolog.log("Step %04d: Loss = %.8f Loss1 = %.8f Loss3 = %.8f" %(gs,loss,loss1,loss3))
                    #loss_one = [loss,loss1,loss3]
                    loss_one = [loss, loss1, loss2, loss3]
                    losses = [x + y for x, y in zip(losses, loss_one)]

                losses = [x / g.num_batch for x in losses]
                print(
                    "###############################################################################"
                )
                infolog.log(
                    "Global Step %d (%04d): Loss = %.8f Loss1 = %.8f Loss2 = %.8f Loss3 = %.8f"
                    % (epoch, gs, losses[0], losses[1], losses[2], losses[3]))
                #infolog.log("Global Step %d (%04d): Loss = %.8f Loss1 = %.8f Loss3 = %.8f" %(epoch,gs,losses[0],losses[1],losses[2]))
                print(
                    "###############################################################################"
                )

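                # Periodically write TensorBoard summaries and plot predicted
                # vs. target mel, done and mag outputs.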
                if epoch % config.summary_interval == 0:
                    infolog.log('Saving summary')
                    summary_writer.add_summary(merged, gs)
                    origx, Kmel_out, Ky1, Kdone_out, Ky2, Kmag_out, Ky3 = sess.run(
                        [
                            g.origx, g.mel_output, g.y1, g.done_output, g.y2,
                            g.mag_output, g.y3
                        ])
                    #origx, Kmel_out,Ky1,Kmag_out,Ky3 = sess.run([g.origx, g.mel_output,g.y1,g.mag_output,g.y3])
                    plot_losses(config, Kmel_out, Ky1, Kdone_out, Ky2,
                                Kmag_out, Ky3, gs)
                    #plot_losses(config,Kmel_out,Ky1,Kmag_out,Ky3,gs)

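                # Periodically save a checkpoint.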
                if epoch % config.checkpoint_interval == 0:
                    infolog.log('Saving checkpoint to: %s-%d' %
                                (checkpoint_path, gs))
                    sv.saver.save(sess, checkpoint_path, global_step=gs)

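                # Periodically synthesize audio with the test graph.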
                if epoch % config.test_interval == 0:
                    infolog.log('Saving audio')
                    origx, Kmel_out, Ky1, Kdone_out, Ky2, Kmag_out, Ky3 = sess.run(
                        [
                            g.origx, g.mel_output, g.y1, g.done_output, g.y2,
                            g.mag_output, g.y3
                        ])
                    #origx, Kmel_out,Ky1,Kmag_out,Ky3 = sess.run([g.origx, g.mel_output,g.y1,g.mag_output,g.y3])
                    wavs = synthesize.synthesize_part(g2, config, gs, origx)
                    plot_wavs(config, wavs, gs)

                # break
                if gs > config.num_iterations: break

    print("Done")