Beispiel #1
0
def train():
    """Train Speech2Vid from TFRecords with plain gradient descent.

    Reads (face, audio, identity) batches via ``read_and_decode`` from
    ``args.tfrecords``, builds the model with
    ``speech2vid_inference.Speech2Vid``, and minimizes ``compute_loss`` with
    ``GradientDescentOptimizer(LEARNING_RATE_BASE)`` for at most
    ``TRAINING_STEPS`` iterations. Every 50 global steps it prints the loss
    and writes summaries (when any exist). Every 2000 global steps, and on
    the final step, it saves a checkpoint to ``MODEL_SAVE_PATH/MODEL_NAME``.

    Relies on module-level globals: ``args``, ``BATCH_SIZE``,
    ``LEARNING_RATE_BASE``, ``TRAINING_STEPS``, ``LOG_SAVE_PATH``,
    ``MODEL_SAVE_PATH``, ``MODEL_NAME``, ``read_and_decode``,
    ``speech2vid_inference``, ``tf``, ``np`` and ``os``.
    """
    # Alternative input path (not used here): feed placeholders instead of
    # TFRecords, e.g.
    #   input_audio = tf.placeholder(tf.float32, shape=[None, 12, 35, 1])
    #   input_face = tf.placeholder(tf.float32, shape=[None, 112, 112, 3])
    face_batch, audio_batch, identity5_batch = read_and_decode(
        args.tfrecords, BATCH_SIZE, 5)
    face_batch = tf.cast(face_batch, dtype=tf.float32)
    audio_batch = tf.cast(audio_batch, dtype=tf.float32)
    identity5_batch = tf.cast(identity5_batch, dtype=tf.float32)

    speech2vid = speech2vid_inference.Speech2Vid()
    prediction = speech2vid.inference(audio_batch, identity5_batch, BATCH_SIZE)
    train_loss = speech2vid.compute_loss(prediction, face_batch)

    global_step = tf.Variable(0, trainable=False)
    train_op = tf.train.GradientDescentOptimizer(
        LEARNING_RATE_BASE).minimize(train_loss, global_step=global_step)

    # merge_all() returns None when no summary ops were registered; running
    # None would raise, so summaries are only evaluated when present.
    summary_op = tf.summary.merge_all()

    # saver.save() does not create missing directories.
    if not os.path.exists(MODEL_SAVE_PATH):
        os.makedirs(MODEL_SAVE_PATH)

    with tf.Session() as sess:
        train_writer = tf.summary.FileWriter(LOG_SAVE_PATH, sess.graph)
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            # The loop only bounds the iteration count; `step` below is the
            # graph's global_step as fetched after each run.
            for _iteration in np.arange(TRAINING_STEPS):
                if coord.should_stop():
                    break
                _, training_loss, step = sess.run(
                    [train_op, train_loss, global_step])

                if step % 50 == 0:
                    # Print the fetched numeric loss, not the loss tensor.
                    print('Step {}, train loss = {}'.format(step,
                                                            training_loss))
                    if summary_op is not None:
                        summary_str = sess.run(summary_op)
                        train_writer.add_summary(summary_str, step)

                if step % 2000 == 0 or (step + 1) == TRAINING_STEPS:
                    saver.save(sess,
                               os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                               global_step=global_step)
                    print('Model {} saved'.format(step))
        except tf.errors.OutOfRangeError:
            print('Done training, epoch limit reached')
        finally:
            # Always stop and join the input threads and flush the event
            # file, even when an unexpected exception is propagating.
            coord.request_stop()
            train_writer.close()
            coord.join(threads)
Beispiel #2
0
def train():
    """Fine-tune Speech2Vid with a separate Momentum learning rate per sub-network.

    Builds the model from ``speech2vid_inference_finetune`` (audio encoder,
    identity encoder, image decoder) on TFRecord batches read with
    ``read_and_decode(args.tfrecords, BATCH_SIZE, args.idnum)`` and computes
    ``loss_l1``. It then partitions the trainable variables by op name into
    bn / audio / identity / lip / decoder groups. Each group gets its own
    ``MomentumOptimizer`` with learning rate
    ``BASIC_LEARNING_RATE * <GROUP>_LEARNING_RATE_BASE``. If ``args.ckpt`` is
    set and contains a checkpoint, that checkpoint is restored; a checkpoint
    is then saved to ``MODEL_SAVE_PATH/MODEL_NAME``.

    NOTE(review): the visible body never runs ``train_op`` or
    ``summary_op``. It only initializes/restores and saves once, so the
    training loop appears to be missing from this copy.
    NOTE(review): checkpoints written by the previous version advanced
    ``global_step`` five times per step (every group's minimize() received
    it); restored step counts from such checkpoints are inflated.
    """
    face_batch, audio_batch, identity5_batch = read_and_decode(
        args.tfrecords, BATCH_SIZE, args.idnum)
    face_batch = tf.cast(face_batch, dtype=tf.float32)
    audio_batch = tf.cast(audio_batch, dtype=tf.float32)
    identity5_batch = tf.cast(identity5_batch, dtype=tf.float32)

    speech2vid = speech2vid_inference_finetune.Speech2Vid(
        USE_AUDIO, USE_BN, USE_LIP, USE_DECODER, USE_FACEFC, USE_AUDIOFC,
        USE_FACE, USE_XAVIER, AVOID_LAYERS_LIST)
    audio_encoder_output = speech2vid.audio_encoder(audio_batch)
    identity_encoder_output, x7_face, x4_face, x3_face = speech2vid.identity_encoder(
        identity5_batch, args.idnum)
    prediction = speech2vid.image_decoder(audio_encoder_output,
                                          identity_encoder_output, x7_face,
                                          x4_face, x3_face)
    # Single-call alternative:
    # prediction = speech2vid.inference(audio_batch, identity5_batch)
    train_loss, loss0, loss1, loss2, loss3, loss4, loss5 = speech2vid.loss_l1(
        prediction, face_batch)

    # Partition trainable variables by op name. BN variables are claimed
    # first and excluded from every other group.
    trainable_vars = tf.trainable_variables()
    bn_var_list = [var for var in trainable_vars if 'bn' in var.op.name]
    audio_var_list = [
        var for var in trainable_vars
        if 'audio' in var.op.name and var not in bn_var_list
    ]
    # Identity encoder: the top-level scope name ends with 'face'.
    identity_var_list = [
        var for var in trainable_vars
        if str(var.op.name).split('/')[0][-4:] == 'face'
        and var not in bn_var_list
    ]
    lip_var_list = [
        var for var in trainable_vars
        if 'face_lip' in var.op.name and var not in bn_var_list
    ]
    # Everything not claimed above belongs to the decoder. Build the
    # exclusion list once instead of re-concatenating it for every variable.
    claimed_vars = (audio_var_list + identity_var_list + lip_var_list +
                    bn_var_list)
    decoder_var_list = [
        var for var in trainable_vars if var not in claimed_vars
    ]

    global_step = tf.Variable(0, trainable=False)
    # Exponentially decaying alternative, e.g.:
    # identity_learning_rate = tf.train.exponential_decay(
    #     IDENTITY_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # (analogously for audio / bn / lip / decoder).

    # Constant per-group learning rates.
    identity_learning_rate = BASIC_LEARNING_RATE * IDENTITY_LEARNING_RATE_BASE
    audio_learning_rate = BASIC_LEARNING_RATE * AUDIO_LEARNING_RATE_BASE
    bn_learning_rate = BASIC_LEARNING_RATE * BN_LEARNING_RATE_BASE
    lip_learning_rate = BASIC_LEARNING_RATE * LIP_LEARNING_RATE_BASE
    decoder_learning_rate = BASIC_LEARNING_RATE * DECODER_LEARNING_RATE_BASE

    # Momentum optimizers; tf.train.GradientDescentOptimizer(<lr>) is the
    # plain-SGD alternative.
    identity_optimizer = tf.train.MomentumOptimizer(identity_learning_rate,
                                                    MOMENTUM)
    audio_optimizer = tf.train.MomentumOptimizer(audio_learning_rate, MOMENTUM)
    bn_optimizer = tf.train.MomentumOptimizer(bn_learning_rate, MOMENTUM)
    lip_optimizer = tf.train.MomentumOptimizer(lip_learning_rate, MOMENTUM)
    decoder_optimizer = tf.train.MomentumOptimizer(decoder_learning_rate,
                                                   MOMENTUM)

    # Separate learning rates: one minimize() per variable group.
    # - global_step is NOT passed to these calls. Passing it to every one
    #   advanced it once per group (5x per step). It is incremented exactly
    #   once below, after all group updates.
    # - Groups with no variables are skipped: minimize() raises ValueError
    #   on an empty var_list, e.g. when a USE_* flag removes a sub-network.
    group_optimizers = [
        (identity_optimizer, identity_var_list),
        (bn_optimizer, bn_var_list),
        (lip_optimizer, lip_var_list),
        (decoder_optimizer, decoder_var_list),
        (audio_optimizer, audio_var_list),
    ]
    group_train_ops = [
        optimizer.minimize(train_loss, var_list=var_list)
        for optimizer, var_list in group_optimizers if var_list
    ]
    with tf.control_dependencies(group_train_ops):
        train_op = tf.group(tf.assign_add(global_step, 1))
    # Single-learning-rate alternative:
    # optimizer = tf.train.GradientDescentOptimizer(identity_learning_rate)
    # train_op = optimizer.minimize(train_loss, global_step=global_step)

    saver = tf.train.Saver()

    tf.summary.scalar("loss", train_loss)
    tf.summary.image("prediction", prediction)
    tf.summary.image("face_gt", face_batch)
    tf.summary.image("audio", audio_batch)

    summary_op = tf.summary.merge_all()

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        if args.ckpt:
            ckpt = tf.train.get_checkpoint_state(args.ckpt)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print('{} loaded'.format(ckpt.model_checkpoint_path))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        train_writer = tf.summary.FileWriter(LOG_SAVE_PATH, sess.graph)
        try:
            if not os.path.exists(MODEL_SAVE_PATH):
                os.makedirs(MODEL_SAVE_PATH)
            # After a restore, global_step is not 0; report the actual value
            # that saver.save() appends to the checkpoint name.
            saved_step = sess.run(global_step)
            saver.save(sess,
                       os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                       global_step=global_step)
            print('Model {} saved at {}'.format(saved_step, MODEL_SAVE_PATH))

        except KeyboardInterrupt:
            logging.info('Interrupted')
        except Exception as e:
            traceback.print_exc()
            # Recorded on the coordinator; join() below re-raises it.
            coord.request_stop(e)
        finally:
            coord.request_stop()
            train_writer.close()
            coord.join(threads)
def train():
    """Train Speech2Vid from TFRecords with plain SGD, timing every 500 steps.

    Reads (face, audio, identity) batches via ``read_and_decode`` from
    ``args.tfrecords``, builds ``speech2vid_inference.Speech2Vid`` and
    minimizes ``compute_loss`` with
    ``GradientDescentOptimizer(LEARNING_RATE_BASE)`` for at most
    ``TRAINING_STEPS`` iterations. Every 500 global steps it prints the loss
    and the wall-clock time since the previous report. Every 10000 global
    steps, and on the final step, it checkpoints. A final checkpoint is also
    written on exit, including after an interrupt or error.
    """
    face_batch, audio_batch, identity5_batch = read_and_decode(
        args.tfrecords, BATCH_SIZE, 5)
    face_batch = tf.cast(face_batch, dtype=tf.float32)
    audio_batch = tf.cast(audio_batch, dtype=tf.float32)
    identity5_batch = tf.cast(identity5_batch, dtype=tf.float32)

    # Multi-GPU placement is not implemented, e.g.:
    # for i in xrange(NUM_GPUS):
    #     with tf.device('/gpu:%d' % i):

    speech2vid = speech2vid_inference.Speech2Vid()
    prediction = speech2vid.inference(audio_batch, identity5_batch, BATCH_SIZE)
    train_loss = speech2vid.compute_loss(prediction, face_batch)

    global_step = tf.Variable(0, trainable=False)
    train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE_BASE).minimize(
        train_loss, global_step=global_step)

    # merge_all() returns None when no summaries were registered.
    summary_op = tf.summary.merge_all()
    # train_writer = tf.summary.FileWriter(LOG_SAVE_PATH, graph)
    saver = tf.train.Saver()
    checkpoint_prefix = os.path.join(MODEL_SAVE_PATH, MODEL_NAME)

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            # saver.save() does not create missing directories.
            if not os.path.exists(MODEL_SAVE_PATH):
                os.makedirs(MODEL_SAVE_PATH)
            start_time = time.time()
            # The loop only bounds the iteration count; `step` is the graph's
            # global_step fetched after each run.
            for _iteration in np.arange(TRAINING_STEPS):
                if coord.should_stop():
                    break

                _, training_loss, step = sess.run(
                    [train_op, train_loss, global_step])

                if step % 500 == 0:
                    elapsed_time = time.time() - start_time
                    print('Step: {}, Train loss: {}, Elapsed time: {}'.format(
                        step, training_loss, elapsed_time))
                    if summary_op is not None:
                        summary_str = sess.run(summary_op)
                        # train_writer.add_summary(summary_str, step)
                    start_time = time.time()

                if step % 10000 == 0 or (step + 1) == TRAINING_STEPS:
                    saver.save(sess, checkpoint_prefix,
                               global_step=global_step)
                    print('Model {} saved'.format(step))
        except KeyboardInterrupt:
            logging.info('Interrupted')
        except Exception as e:
            traceback.print_exc()
            # Recorded on the coordinator; join() below re-raises it.
            coord.request_stop(e)
        finally:
            try:
                # Read the step from the graph: a Python-side step counter may
                # be unbound if no iteration completed.
                final_step = sess.run(global_step)
                saver.save(sess, checkpoint_prefix, global_step=global_step)
                print('Model {} saved'.format(final_step))
            finally:
                # Stop and join the input threads even if the save failed.
                coord.request_stop()
                coord.join(threads)
Beispiel #4
0
# import numpy as np
# import cv2
# a = np.load(
#     '/Users/lls/Documents/face/data/lrw1018/lipread_mp4/MIGHT/test/MIGHT_00001_2.npz')
# face_gt = a['face_gt']
# mfcc_gt = a['mfcc_gt']
# identity1 = a['identity1']
# identity5 = a['identity5']
# print face_gt.shape
# print mfcc_gt.shape
# # cv2.imshow('a', face_gt)
# # cv2.waitKey(0)

from build_data_utils import read_and_decode
import tensorflow as tf
# Smoke test: decode one example (batch size 1) from a TFRecord file and
# print the shape and a small slice of its audio tensor. The face and
# identity tensors can be inspected the same way, e.g.
# print(face_batch.shape) / print(identity5_batch.shape).
face_batch, audio_batch, identity5_batch = read_and_decode(
    '../../../data/lrw1018/lipread_mp4/MIGHT/test/test.tfrecords', 1, 5)
audio_batch = tf.cast(audio_batch, dtype=tf.float32)
print(audio_batch.shape)

with tf.Session() as sess:
    # The train() functions start queue runners right after read_and_decode,
    # which suggests it is queue-based (TODO confirm). Without running input
    # threads, the sess.run below would block forever.
    # Local variables are initialized too, in case the input producer keeps
    # an epoch counter.
    sess.run(tf.group(tf.global_variables_initializer(),
                      tf.local_variables_initializer()))
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # The 4-index slice implies a rank-4 audio tensor.
        audio_slice = sess.run(audio_batch[:, 1:3, 1:3, :])
        print(audio_slice)
    finally:
        coord.request_stop()
        coord.join(threads)