def train():
    # input_audio = tf.placeholder(
    #     tf.float32, shape=[None, 12, 35, 1], name='input_audio')
    # input_face = tf.placeholder(
    #     tf.float32, shape=[None, 112, 112, 3], name='input_face')
    '''use tfrecords'''
    face_batch, audio_batch, identity5_batch = read_and_decode(
        args.tfrecords, BATCH_SIZE, 5)
    face_batch = tf.cast(face_batch, dtype=tf.float32)
    audio_batch = tf.cast(audio_batch, dtype=tf.float32)
    identity5_batch = tf.cast(identity5_batch, dtype=tf.float32)

    speech2vid = speech2vid_inference.Speech2Vid()
    prediction = speech2vid.inference(audio_batch, identity5_batch, BATCH_SIZE)
    train_loss = speech2vid.compute_loss(prediction, face_batch)

    global_step = tf.Variable(0, trainable=False)
    train_op = tf.train.GradientDescentOptimizer(
        LEARNING_RATE_BASE).minimize(train_loss, global_step=global_step)
    summary_op = tf.summary.merge_all()

    sess = tf.Session()
    train_writer = tf.summary.FileWriter(LOG_SAVE_PATH, sess.graph)
    saver = tf.train.Saver()
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        for step in np.arange(TRAINING_STEPS):
            if coord.should_stop():
                break
            _, training_loss, step = sess.run(
                [train_op, train_loss, global_step])
            if step % 50 == 0:
                # Print the fetched loss value, not the loss tensor.
                print 'Step {}, train loss = {}'.format(step, training_loss)
                summary_str = sess.run(summary_op)
                train_writer.add_summary(summary_str, step)
            if step % 2000 == 0 or (step + 1) == TRAINING_STEPS:
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                           global_step=global_step)
                print 'Model {} saved'.format(step)
    except tf.errors.OutOfRangeError:
        print 'Done training, epoch limit reached'
    finally:
        coord.request_stop()
        coord.join(threads)
        sess.close()
    '''use placeholder'''
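# read_and_decode comes from build_data_utils (imported in the test script at
# the bottom of this file). The sketch below shows what it plausibly looks
# like for this pipeline; the feature keys ('face_gt', 'mfcc_gt',
# 'identity5'), the raw dtypes, and the shapes (112x112x3 face crops,
# 12x35x1 MFCC windows, identity frames stacked along channels, taken from
# the placeholder comments above) are assumptions, not the confirmed
# serialization format.
import tensorflow as tf

def read_and_decode(tfrecords_path, batch_size, identity_num):
    filename_queue = tf.train.string_input_producer([tfrecords_path])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'face_gt': tf.FixedLenFeature([], tf.string),
            'mfcc_gt': tf.FixedLenFeature([], tf.string),
            'identity5': tf.FixedLenFeature([], tf.string),
        })
    # Decode the raw bytes back into fixed-shape tensors; the caller casts
    # them to float32 afterwards.
    face = tf.reshape(
        tf.decode_raw(features['face_gt'], tf.uint8), [112, 112, 3])
    audio = tf.reshape(
        tf.decode_raw(features['mfcc_gt'], tf.float32), [12, 35, 1])
    identity = tf.reshape(
        tf.decode_raw(features['identity5'], tf.uint8),
        [112, 112, 3 * identity_num])
    return tf.train.shuffle_batch(
        [face, audio, identity], batch_size=batch_size,
        capacity=1000 + 3 * batch_size, min_after_dequeue=1000)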
def train():
    '''use tfrecords'''
    face_batch, audio_batch, identity5_batch = read_and_decode(
        args.tfrecords, BATCH_SIZE, args.idnum)
    face_batch = tf.cast(face_batch, dtype=tf.float32)
    audio_batch = tf.cast(audio_batch, dtype=tf.float32)
    identity5_batch = tf.cast(identity5_batch, dtype=tf.float32)

    speech2vid = speech2vid_inference_finetune.Speech2Vid(
        USE_AUDIO, USE_BN, USE_LIP, USE_DECODER, USE_FACEFC, USE_AUDIOFC,
        USE_FACE, USE_XAVIER, AVOID_LAYERS_LIST)
    audio_encoder_output = speech2vid.audio_encoder(audio_batch)
    identity_encoder_output, x7_face, x4_face, x3_face = speech2vid.identity_encoder(
        identity5_batch, args.idnum)
    prediction = speech2vid.image_decoder(
        audio_encoder_output, identity_encoder_output, x7_face, x4_face, x3_face)
    # prediction = speech2vid.inference(audio_batch, identity5_batch)
    train_loss, loss0, loss1, loss2, loss3, loss4, loss5 = speech2vid.loss_l1(
        prediction, face_batch)

    # Partition the trainable variables by sub-network so each group can get
    # its own learning rate (a sanity check for this partition is sketched
    # after this function).
    bn_var_list = [var for var in tf.trainable_variables()
                   if 'bn' in var.op.name]
    audio_var_list = [var for var in tf.trainable_variables()
                      if 'audio' in var.op.name and var not in bn_var_list]
    identity_var_list = [var for var in tf.trainable_variables()
                         if str(var.op.name).split('/')[0][-4:] == 'face'
                         and var not in bn_var_list]
    lip_var_list = [var for var in tf.trainable_variables()
                    if 'face_lip' in var.op.name and var not in bn_var_list]
    decoder_var_list = [var for var in tf.trainable_variables()
                        if var not in audio_var_list + identity_var_list
                        + lip_var_list + bn_var_list]

    global_step = tf.Variable(0, trainable=False)
    '''exponential_decay lr'''
    # identity_learning_rate = tf.train.exponential_decay(
    #     IDENTITY_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # audio_learning_rate = tf.train.exponential_decay(
    #     AUDIO_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # bn_learning_rate = tf.train.exponential_decay(
    #     BN_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # lip_learning_rate = tf.train.exponential_decay(
    #     LIP_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # decoder_learning_rate = tf.train.exponential_decay(
    #     DECODER_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    '''constant lr'''
    identity_learning_rate = BASIC_LEARNING_RATE * IDENTITY_LEARNING_RATE_BASE
    audio_learning_rate = BASIC_LEARNING_RATE * AUDIO_LEARNING_RATE_BASE
    bn_learning_rate = BASIC_LEARNING_RATE * BN_LEARNING_RATE_BASE
    lip_learning_rate = BASIC_LEARNING_RATE * LIP_LEARNING_RATE_BASE
    decoder_learning_rate = BASIC_LEARNING_RATE * DECODER_LEARNING_RATE_BASE
    '''SGD'''
    # identity_optimizer = tf.train.GradientDescentOptimizer(identity_learning_rate)
    # audio_optimizer = tf.train.GradientDescentOptimizer(audio_learning_rate)
    # bn_optimizer = tf.train.GradientDescentOptimizer(bn_learning_rate)
    # lip_optimizer = tf.train.GradientDescentOptimizer(lip_learning_rate)
    # decoder_optimizer = tf.train.GradientDescentOptimizer(decoder_learning_rate)
    '''Momentum'''
    identity_optimizer = tf.train.MomentumOptimizer(identity_learning_rate, MOMENTUM)
    audio_optimizer = tf.train.MomentumOptimizer(audio_learning_rate, MOMENTUM)
    bn_optimizer = tf.train.MomentumOptimizer(bn_learning_rate, MOMENTUM)
    lip_optimizer = tf.train.MomentumOptimizer(lip_learning_rate, MOMENTUM)
    decoder_optimizer = tf.train.MomentumOptimizer(decoder_learning_rate, MOMENTUM)

    '''Separate learning rate option 1'''
    # Pass global_step to only one minimize() call: passing it to all five
    # would increment the step counter five times per training iteration.
    identity_train_op = identity_optimizer.minimize(
        train_loss, var_list=identity_var_list)
    bn_train_op = bn_optimizer.minimize(train_loss, var_list=bn_var_list)
    lip_train_op = lip_optimizer.minimize(train_loss, var_list=lip_var_list)
    decoder_train_op = decoder_optimizer.minimize(
        train_loss, global_step=global_step, var_list=decoder_var_list)
    audio_train_op = audio_optimizer.minimize(
        train_loss, var_list=audio_var_list)
    train_op = tf.group(identity_train_op, audio_train_op, bn_train_op,
                        lip_train_op, decoder_train_op)
    '''Separate learning rate option 2'''
    # grads = tf.gradients(train_loss, identity_var_list + audio_var_list)
    # identity_grad = grads[:len(identity_var_list)]
    # audio_grad = grads[len(identity_var_list):]
    # identity_train_op = identity_optimizer.apply_gradients(
    #     zip(identity_grad, identity_var_list))
    # audio_train_op = audio_optimizer.apply_gradients(
    #     zip(audio_grad, audio_var_list))
    # train_op = tf.group(identity_train_op, audio_train_op)
    '''Only one learning rate'''
    # optimizer = tf.train.GradientDescentOptimizer(identity_learning_rate)
    # train_op = optimizer.minimize(train_loss, global_step=global_step)

    saver = tf.train.Saver()
    tf.summary.scalar("loss", train_loss)
    tf.summary.image("prediction", prediction)
    tf.summary.image("face_gt", face_batch)
    tf.summary.image("audio", audio_batch)
    summary_op = tf.summary.merge_all()

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        if args.ckpt:
            ckpt = tf.train.get_checkpoint_state(args.ckpt)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print '{} loaded'.format(ckpt.model_checkpoint_path)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        train_writer = tf.summary.FileWriter(LOG_SAVE_PATH, sess.graph)
        try:
            if not os.path.exists(MODEL_SAVE_PATH):
                os.makedirs(MODEL_SAVE_PATH)
            saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                       global_step=global_step)
            print 'Model {} saved at {}'.format(0, MODEL_SAVE_PATH)
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            traceback.print_exc()
            coord.request_stop(e)
        finally:
            coord.request_stop()
            coord.join(threads)
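# The variable partition in train() relies on fragile name tests
# ('audio' / 'face_lip' substrings, the [-4:] == 'face' slice), so it is
# worth checking that the five lists are disjoint and together cover every
# trainable variable. This helper is an addition, not part of the original
# script; call it right after the lists are built, e.g.
# check_var_partition([bn_var_list, audio_var_list, identity_var_list,
#                      lip_var_list, decoder_var_list])
def check_var_partition(groups):
    all_vars = set(tf.trainable_variables())
    union = set()
    total = 0
    for group in groups:
        union.update(group)
        total += len(group)
    # If the two assertions hold, the groups form a true partition.
    assert total == len(all_vars), 'variable groups overlap'
    assert union == all_vars, 'some trainable variables are unassigned'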
def train():
    '''use tfrecords'''
    face_batch, audio_batch, identity5_batch = read_and_decode(
        args.tfrecords, BATCH_SIZE, 5)
    face_batch = tf.cast(face_batch, dtype=tf.float32)
    audio_batch = tf.cast(audio_batch, dtype=tf.float32)
    identity5_batch = tf.cast(identity5_batch, dtype=tf.float32)

    # for i in xrange(NUM_GPUS):
    #     with tf.device('/gpu:%d' % i):
    # (a tower-based multi-GPU sketch is given after this function)
    speech2vid = speech2vid_inference.Speech2Vid()
    prediction = speech2vid.inference(audio_batch, identity5_batch, BATCH_SIZE)
    train_loss = speech2vid.compute_loss(prediction, face_batch)

    global_step = tf.Variable(0, trainable=False)
    train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE_BASE).minimize(
        train_loss, global_step=global_step)
    summary_op = tf.summary.merge_all()
    # train_writer = tf.summary.FileWriter(LOG_SAVE_PATH, graph)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            start_time = time.time()
            for step in np.arange(TRAINING_STEPS):
                if coord.should_stop():
                    break
                _, training_loss, step = sess.run(
                    [train_op, train_loss, global_step])
                if step % 500 == 0:
                    end_time = time.time()
                    elapsed_time = end_time - start_time
                    print 'Step: {}, Train loss: {}, Elapsed time: {}'.format(
                        step, training_loss, elapsed_time)
                    summary_str = sess.run(summary_op)
                    # train_writer.add_summary(summary_str, step)
                    start_time = time.time()
                if step % 10000 == 0 or (step + 1) == TRAINING_STEPS:
                    saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                               global_step=global_step)
                    print 'Model {} saved'.format(step)
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            traceback.print_exc()
            coord.request_stop(e)
        finally:
            saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                       global_step=global_step)
            print 'Model {} saved'.format(step)
            coord.request_stop()
            coord.join(threads)
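# The commented xrange(NUM_GPUS) / tf.device lines above suggest a
# data-parallel version that was never finished. The helper below is a
# minimal sketch of the standard TF1 tower pattern, not part of the original
# script: it assumes BATCH_SIZE is divisible by NUM_GPUS and that Speech2Vid
# builds its variables with tf.get_variable so they can be reused across
# towers.
def build_multi_gpu_train_op(speech2vid, face_batch, audio_batch,
                             identity5_batch, global_step):
    optimizer = tf.train.GradientDescentOptimizer(LEARNING_RATE_BASE)
    tower_grads = []
    for i in xrange(NUM_GPUS):
        with tf.device('/gpu:%d' % i):
            with tf.variable_scope(tf.get_variable_scope(), reuse=(i > 0)):
                # Give each tower an interleaved slice of the batch.
                face_i = face_batch[i::NUM_GPUS]
                audio_i = audio_batch[i::NUM_GPUS]
                identity_i = identity5_batch[i::NUM_GPUS]
                prediction_i = speech2vid.inference(
                    audio_i, identity_i, BATCH_SIZE // NUM_GPUS)
                loss_i = speech2vid.compute_loss(prediction_i, face_i)
                tower_grads.append(optimizer.compute_gradients(loss_i))
    # Average the per-tower gradients variable by variable, then apply them
    # once so global_step advances once per training iteration.
    average_grads = []
    for grads_and_vars in zip(*tower_grads):
        if grads_and_vars[0][0] is None:
            continue
        grads = tf.stack([g for g, _ in grads_and_vars])
        average_grads.append(
            (tf.reduce_mean(grads, axis=0), grads_and_vars[0][1]))
    return optimizer.apply_gradients(average_grads, global_step=global_step)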
# import numpy as np
# import cv2
# a = np.load(
#     '/Users/lls/Documents/face/data/lrw1018/lipread_mp4/MIGHT/test/MIGHT_00001_2.npz')
# face_gt = a['face_gt']
# mfcc_gt = a['mfcc_gt']
# identity1 = a['identity1']
# identity5 = a['identity5']
# print face_gt.shape
# print mfcc_gt.shape
# # cv2.imshow('a', face_gt)
# # cv2.waitKey(0)

from build_data_utils import read_and_decode
import tensorflow as tf

face_batch, audio_batch, identity5_batch = read_and_decode(
    '../../../data/lrw1018/lipread_mp4/MIGHT/test/test.tfrecords', 1, 5)
# print face_batch.shape
# print audio_batch.shape
# print identity5_batch.shape
# face_batch = tf.cast(face_batch, dtype=tf.float32)
audio_batch = tf.cast(audio_batch, dtype=tf.float32)
# identity5_batch = tf.cast(identity5_batch, dtype=tf.float32)
# print face_batch.shape
print audio_batch.shape
# print identity5_batch.shape

with tf.Session() as sess:
    # The input pipeline is queue-based, so the queue runners must be
    # started before evaluating any batch tensor; without them sess.run
    # blocks forever.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print sess.run(audio_batch[:, 1:3, 1:3, :])
    coord.request_stop()
    coord.join(threads)
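# For reference, the write side of this pipeline -- turning the .npz samples
# from the commented block above into a TFRecord file -- might look roughly
# like the sketch below. The feature keys and dtypes mirror the assumed read
# side in build_data_utils and are not confirmed; the paths are placeholders.
import numpy as np

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

def write_example(npz_path, writer):
    a = np.load(npz_path)
    example = tf.train.Example(features=tf.train.Features(feature={
        'face_gt': _bytes_feature(a['face_gt'].tostring()),
        'mfcc_gt': _bytes_feature(
            a['mfcc_gt'].astype(np.float32).tostring()),
        'identity5': _bytes_feature(a['identity5'].tostring()),
    }))
    writer.write(example.SerializeToString())

# writer = tf.python_io.TFRecordWriter('test.tfrecords')
# write_example('MIGHT_00001_2.npz', writer)
# writer.close()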