def train():
    '''use tfrecords'''
    face_batch, audio_batch, identity5_batch, epoch_now = read_and_decode_TFRecordDataset(
        args.tfrecords, BATCH_SIZE, args.idnum, EPOCH_NUM)
    speech2vid = speech2vid_inference_finetune.Speech2Vid(
        USE_AUDIO, USE_BN, USE_LIP, USE_DECODER, USE_FACEFC, USE_AUDIOFC,
        USE_FACE, USE_XAVIER, AVOID_LAYERS_LIST, args.idnum)
    audio_encoder_output = speech2vid.audio_encoder(audio_batch)
    identity_encoder_output, x7_face, x4_face, x3_face = speech2vid.identity_encoder(
        identity5_batch, args.idnum)
    prediction = speech2vid.image_decoder(
        audio_encoder_output, identity_encoder_output, x7_face, x4_face, x3_face)
    # prediction = speech2vid.inference(audio_batch, identity5_batch)
    train_loss, loss0, loss1, loss2, loss3, loss4, loss5 = speech2vid.loss_l1(
        prediction, face_batch)

    # Partition trainable variables into per-module groups so each group can be
    # trained with its own learning rate.
    bn_var_list = [var for var in tf.trainable_variables()
                   if 'bn' in var.op.name]
    audio_var_list = [var for var in tf.trainable_variables()
                      if 'audio' in var.op.name and var not in bn_var_list]
    identity_var_list = [var for var in tf.trainable_variables()
                         if str(var.op.name).split('/')[0][-4:] == 'face' and var not in bn_var_list]
    lip_var_list = [var for var in tf.trainable_variables()
                    if 'face_lip' in var.op.name and var not in bn_var_list]
    decoder_var_list = [var for var in tf.trainable_variables()
                        if var not in audio_var_list + identity_var_list + lip_var_list + bn_var_list]

    global_step = tf.Variable(0, trainable=False)

    '''exponential_decay lr'''
    # identity_learning_rate = tf.train.exponential_decay(
    #     IDENTITY_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # audio_learning_rate = tf.train.exponential_decay(
    #     AUDIO_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # bn_learning_rate = tf.train.exponential_decay(
    #     BN_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # lip_learning_rate = tf.train.exponential_decay(
    #     LIP_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # decoder_learning_rate = tf.train.exponential_decay(
    #     DECODER_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)

    '''constant lr'''
    identity_learning_rate = BASIC_LEARNING_RATE * IDENTITY_LEARNING_RATE_BASE
    audio_learning_rate = BASIC_LEARNING_RATE * AUDIO_LEARNING_RATE_BASE
    bn_learning_rate = BASIC_LEARNING_RATE * BN_LEARNING_RATE_BASE
    lip_learning_rate = BASIC_LEARNING_RATE * LIP_LEARNING_RATE_BASE
    decoder_learning_rate = BASIC_LEARNING_RATE * DECODER_LEARNING_RATE_BASE

    '''SGD'''
    # identity_optimizer = tf.train.GradientDescentOptimizer(identity_learning_rate)
    # audio_optimizer = tf.train.GradientDescentOptimizer(audio_learning_rate)
    # bn_optimizer = tf.train.GradientDescentOptimizer(bn_learning_rate)
    # lip_optimizer = tf.train.GradientDescentOptimizer(lip_learning_rate)
    # decoder_optimizer = tf.train.GradientDescentOptimizer(decoder_learning_rate)

    '''Momentum'''
    # identity_optimizer = tf.train.MomentumOptimizer(
    #     identity_learning_rate, MOMENTUM)
    # audio_optimizer = tf.train.MomentumOptimizer(audio_learning_rate, MOMENTUM)
    # bn_optimizer = tf.train.MomentumOptimizer(bn_learning_rate, MOMENTUM)
    # lip_optimizer = tf.train.MomentumOptimizer(lip_learning_rate, MOMENTUM)
    # decoder_optimizer = tf.train.MomentumOptimizer(
    #     decoder_learning_rate, MOMENTUM)

    '''Adam'''
    identity_optimizer = tf.train.AdamOptimizer(
        learning_rate=identity_learning_rate)
    audio_optimizer = tf.train.AdamOptimizer(learning_rate=audio_learning_rate)
    bn_optimizer = tf.train.AdamOptimizer(learning_rate=bn_learning_rate)
    lip_optimizer = tf.train.AdamOptimizer(learning_rate=lip_learning_rate)
    decoder_optimizer = tf.train.AdamOptimizer(
        learning_rate=decoder_learning_rate)

    '''Separate learning rate option 1'''
    identity_train_op = identity_optimizer.minimize(
        train_loss, global_step=global_step, var_list=identity_var_list)
    bn_train_op = bn_optimizer.minimize(
        train_loss, global_step=global_step, var_list=bn_var_list)
    lip_train_op = lip_optimizer.minimize(
        train_loss, global_step=global_step, var_list=lip_var_list)
    decoder_train_op = decoder_optimizer.minimize(
        train_loss, global_step=global_step, var_list=decoder_var_list)
    audio_train_op = audio_optimizer.minimize(
        train_loss, global_step=global_step, var_list=audio_var_list)
    train_op = tf.group(identity_train_op, audio_train_op,
                        bn_train_op, lip_train_op, decoder_train_op)

    '''Only train decoder'''
    # decoder_train_op = decoder_optimizer.minimize(
    #     train_loss, global_step=global_step, var_list=decoder_var_list)
    # train_op = tf.group(decoder_train_op)

    '''Separate learning rate option 2'''
    # grads = tf.gradients(train_loss, bn_var_list + audio_var_list +
    #                      identity_var_list + lip_var_list + decoder_var_list)
    # bn_grad = grads[:len(bn_var_list)]
    # audio_grad = grads[len(bn_var_list):len(bn_var_list + audio_var_list)]
    # identity_grad = grads[len(bn_var_list + audio_var_list):len(bn_var_list + audio_var_list + identity_var_list)]
    # lip_grad = grads[len(bn_var_list + audio_var_list + identity_var_list):len(bn_var_list + audio_var_list + identity_var_list + lip_var_list)]
    # decoder_grad = grads[len(
    #     bn_var_list + audio_var_list + identity_var_list + lip_var_list):]
    # identity_train_op = identity_optimizer.apply_gradients(
    #     zip(identity_grad, identity_var_list), global_step=global_step)
    # bn_train_op = bn_optimizer.apply_gradients(
    #     zip(bn_grad, bn_var_list), global_step=global_step)
    # lip_train_op = lip_optimizer.apply_gradients(
    #     zip(lip_grad, lip_var_list), global_step=global_step)
    # decoder_train_op = decoder_optimizer.apply_gradients(
    #     zip(decoder_grad, decoder_var_list), global_step=global_step)
    # audio_train_op = audio_optimizer.apply_gradients(
    #     zip(audio_grad, audio_var_list), global_step=global_step)
    # train_op = tf.group(identity_train_op, audio_train_op,
    #                     bn_train_op, lip_train_op, decoder_train_op)

    '''Only one learning rate'''
    # optimizer = tf.train.GradientDescentOptimizer(identity_learning_rate)
    # train_op = optimizer.minimize(train_loss, global_step=global_step)

    saver = tf.train.Saver(max_to_keep=0)
    tf.summary.scalar("loss", train_loss)
    tf.summary.image("face_gt", face_batch)
    tf.summary.image("audio", audio_batch)
    tf.summary.image("prediction", prediction)
    summary_op = tf.summary.merge_all()

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        if args.ckpt:
            ckpt = tf.train.get_checkpoint_state(args.ckpt)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print '{} loaded'.format(ckpt.model_checkpoint_path)
        # coord = tf.train.Coordinator()
        # threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        train_writer = tf.summary.FileWriter(LOG_SAVE_PATH, sess.graph)
        early_stop_loss_list = []
        try:
            start_time = time.time()
            for step in np.arange(TRAINING_STEPS):
                # if coord.should_stop():
                #     break
                # _, training_loss, step, summary, identity_learning_rate_, audio_learning_rate_, bn_learning_rate_, lip_learning_rate_, decoder_learning_rate_, loss0_, loss1_, loss2_, loss3_, loss4_, loss5_, audio_encoder_output_, identity_encoder_output_, x7_face_, base_lr_ = sess.run(
                #     [train_op, train_loss, global_step, summary_op, identity_optimizer._learning_rate, audio_optimizer._learning_rate, bn_optimizer._learning_rate, lip_optimizer._learning_rate, decoder_optimizer._learning_rate, loss0, loss1, loss2, loss3, loss4, loss5, audio_encoder_output, identity_encoder_output, x7_face, BASIC_LEARNING_RATE])
                '''When using Adam'''
                _, training_loss, step, summary, loss0_, loss1_, loss2_, loss3_, loss4_, loss5_, audio_encoder_output_, identity_encoder_output_, x7_face_, base_lr_, epoch_now_ = sess.run(
                    [train_op, train_loss, global_step, summary_op, loss0, loss1, loss2,
                     loss3, loss4, loss5, audio_encoder_output, identity_encoder_output,
                     x7_face, BASIC_LEARNING_RATE, epoch_now])
                train_writer.add_summary(summary, step)
                # print 'x7_face_', np.mean(x7_face_)
                # print 'audio_encoder_output', np.max(audio_encoder_output_), np.min(
                #     audio_encoder_output_), np.mean(audio_encoder_output_)
                # print 'identity_encoder_output', np.max(identity_encoder_output_), np.min(
                #     identity_encoder_output_), np.mean(identity_encoder_output_)
                if step % 1 == 0:
                    end_time = time.time()
                    elapsed_time = end_time - start_time
                    # print '{} Step: {} Total loss: {}\tLoss0: {}\tLoss1: {}\tLoss2: {}\tLoss3: {}\tLoss4: {}\tLoss5: {}\tTime: {}\tBASE Lr: {} ID Lr: {} AU Lr: {} BN Lr: {} LIP Lr: {} DE Lr: {}'.format(
                    #     name, step, round(training_loss, 5), round(loss0_, 5), round(loss1_, 5), round(loss2_, 5), round(loss3_, 5), round(loss4_, 5), round(loss5_, 5), round(elapsed_time, 2), round(base_lr_, 10), round(identity_learning_rate_, 10), round(audio_learning_rate_, 10), round(bn_learning_rate_, 10), round(lip_learning_rate_, 10), round(decoder_learning_rate_, 10))
                    '''When using Adam'''
                    print '{} Adam {} Epoch: {} Step: {} Total loss: {}\tLoss0: {}\tLoss1: {}\tLoss2: {}\tLoss3: {}\tLoss4: {}\tLoss5: {}\tTime: {}\tBASE Lr: {}'.format(
                        datetime.now().strftime("%m/%d %H:%M:%S"), name, epoch_now_[0], step,
                        round(training_loss, 5), round(loss0_, 5), round(loss1_, 5),
                        round(loss2_, 5), round(loss3_, 5), round(loss4_, 5), round(loss5_, 5),
                        round(elapsed_time, 2), round(base_lr_, 10))
                    start_time = time.time()
                if step % 1000 == 0 or (step + 1) == TRAINING_STEPS:
                    if not os.path.exists(MODEL_SAVE_PATH):
                        os.makedirs(MODEL_SAVE_PATH)
                    saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                               global_step=global_step)
                    print 'Model {} saved at {}'.format(step, MODEL_SAVE_PATH)
                '''Early stopping'''
                # 1. test on validation set, record loss and the minimum loss
                # 2. if loss on validation set is larger than the minimum loss for 10 times, stop training
        except KeyboardInterrupt:
            logging.info('Interrupted')
            # coord.request_stop()
        except Exception as e:
            traceback.print_exc()
            # coord.request_stop(e)
        finally:
            if not os.path.exists(MODEL_SAVE_PATH):
                os.makedirs(MODEL_SAVE_PATH)
            saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                       global_step=global_step)
            print 'Model {} saved at {}'.format(step, MODEL_SAVE_PATH)
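# The train() above splits tf.trainable_variables() into five per-module groups
# (bn / audio / identity / lip / decoder), each with its own optimizer and
# learning rate. A minimal, hypothetical sanity check (not part of the original
# script) that the groups are disjoint and together cover every trainable
# variable:
def _check_var_groups(*var_lists):
    grouped = [var for var_list in var_lists for var in var_list]
    assert len(grouped) == len(set(grouped)), \
        'a variable was assigned to more than one learning-rate group'
    assert set(grouped) == set(tf.trainable_variables()), \
        'some trainable variables are not assigned to any learning-rate group'
# Example (call inside train(), right after the lists are built):
# _check_var_groups(bn_var_list, audio_var_list, identity_var_list,
#                   lip_var_list, decoder_var_list)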
def train():
    '''use tfrecords'''
    face_batch, audio_batch, identity5_batch = read_and_decode(
        args.tfrecords, BATCH_SIZE, args.idnum)
    face_batch = tf.cast(face_batch, dtype=tf.float32)
    audio_batch = tf.cast(audio_batch, dtype=tf.float32)
    identity5_batch = tf.cast(identity5_batch, dtype=tf.float32)
    speech2vid = speech2vid_inference_finetune.Speech2Vid(
        USE_AUDIO, USE_BN, USE_LIP, USE_DECODER, USE_FACEFC, USE_AUDIOFC,
        USE_FACE, USE_XAVIER, AVOID_LAYERS_LIST)
    audio_encoder_output = speech2vid.audio_encoder(audio_batch)
    identity_encoder_output, x7_face, x4_face, x3_face = speech2vid.identity_encoder(
        identity5_batch, args.idnum)
    prediction = speech2vid.image_decoder(
        audio_encoder_output, identity_encoder_output, x7_face, x4_face, x3_face)
    # prediction = speech2vid.inference(audio_batch, identity5_batch)
    train_loss, loss0, loss1, loss2, loss3, loss4, loss5 = speech2vid.loss_l1(
        prediction, face_batch)
    bn_var_list = [var for var in tf.trainable_variables()
                   if 'bn' in var.op.name]
    audio_var_list = [var for var in tf.trainable_variables()
                      if 'audio' in var.op.name and var not in bn_var_list]
    identity_var_list = [var for var in tf.trainable_variables()
                         if str(var.op.name).split('/')[0][-4:] == 'face' and var not in bn_var_list]
    lip_var_list = [var for var in tf.trainable_variables()
                    if 'face_lip' in var.op.name and var not in bn_var_list]
    decoder_var_list = [var for var in tf.trainable_variables()
                        if var not in audio_var_list + identity_var_list + lip_var_list + bn_var_list]

    global_step = tf.Variable(0, trainable=False)

    '''exponential_decay lr'''
    # identity_learning_rate = tf.train.exponential_decay(
    #     IDENTITY_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # audio_learning_rate = tf.train.exponential_decay(
    #     AUDIO_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # bn_learning_rate = tf.train.exponential_decay(
    #     BN_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # lip_learning_rate = tf.train.exponential_decay(
    #     LIP_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # decoder_learning_rate = tf.train.exponential_decay(
    #     DECODER_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)

    '''constant lr'''
    identity_learning_rate = BASIC_LEARNING_RATE * IDENTITY_LEARNING_RATE_BASE
    audio_learning_rate = BASIC_LEARNING_RATE * AUDIO_LEARNING_RATE_BASE
    bn_learning_rate = BASIC_LEARNING_RATE * BN_LEARNING_RATE_BASE
    lip_learning_rate = BASIC_LEARNING_RATE * LIP_LEARNING_RATE_BASE
    decoder_learning_rate = BASIC_LEARNING_RATE * DECODER_LEARNING_RATE_BASE

    '''SGD'''
    # identity_optimizer = tf.train.GradientDescentOptimizer(identity_learning_rate)
    # audio_optimizer = tf.train.GradientDescentOptimizer(audio_learning_rate)
    # bn_optimizer = tf.train.GradientDescentOptimizer(bn_learning_rate)
    # lip_optimizer = tf.train.GradientDescentOptimizer(lip_learning_rate)
    # decoder_optimizer = tf.train.GradientDescentOptimizer(decoder_learning_rate)

    '''Momentum'''
    identity_optimizer = tf.train.MomentumOptimizer(identity_learning_rate, MOMENTUM)
    audio_optimizer = tf.train.MomentumOptimizer(audio_learning_rate, MOMENTUM)
    bn_optimizer = tf.train.MomentumOptimizer(bn_learning_rate, MOMENTUM)
    lip_optimizer = tf.train.MomentumOptimizer(lip_learning_rate, MOMENTUM)
    decoder_optimizer = tf.train.MomentumOptimizer(decoder_learning_rate, MOMENTUM)

    '''Separate learning rate option 1'''
    identity_train_op = identity_optimizer.minimize(
        train_loss, global_step=global_step, var_list=identity_var_list)
    bn_train_op = bn_optimizer.minimize(
        train_loss, global_step=global_step, var_list=bn_var_list)
    lip_train_op = lip_optimizer.minimize(
        train_loss, global_step=global_step, var_list=lip_var_list)
    decoder_train_op = decoder_optimizer.minimize(
        train_loss, global_step=global_step, var_list=decoder_var_list)
    audio_train_op = audio_optimizer.minimize(
        train_loss, global_step=global_step, var_list=audio_var_list)
    train_op = tf.group(identity_train_op, audio_train_op,
                        bn_train_op, lip_train_op, decoder_train_op)

    '''Separate learning rate option 2'''
    # grads = tf.gradients(train_loss, identity_var_list + audio_var_list)
    # identity_grad = grads[:len(identity_var_list)]
    # audio_grad = grads[len(identity_var_list):]
    # identity_train_op = identity_optimizer.apply_gradients(
    #     zip(identity_grad, identity_var_list))
    # audio_train_op = audio_optimizer.apply_gradients(
    #     zip(audio_grad, audio_var_list))
    # train_op = tf.group(identity_train_op, audio_train_op)

    '''Only one learning rate'''
    # optimizer = tf.train.GradientDescentOptimizer(identity_learning_rate)
    # train_op = optimizer.minimize(train_loss, global_step=global_step)

    saver = tf.train.Saver()
    tf.summary.scalar("loss", train_loss)
    tf.summary.image("prediction", prediction)
    tf.summary.image("face_gt", face_batch)
    tf.summary.image("audio", audio_batch)
    summary_op = tf.summary.merge_all()

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        if args.ckpt:
            ckpt = tf.train.get_checkpoint_state(args.ckpt)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print '{} loaded'.format(ckpt.model_checkpoint_path)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        train_writer = tf.summary.FileWriter(LOG_SAVE_PATH, sess.graph)
        try:
            if not os.path.exists(MODEL_SAVE_PATH):
                os.makedirs(MODEL_SAVE_PATH)
            saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                       global_step=global_step)
            print 'Model {} saved at {}'.format(0, MODEL_SAVE_PATH)
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            traceback.print_exc()
            coord.request_stop(e)
        finally:
            coord.request_stop()
            coord.join(threads)
def create_model():
    groundtruth_batch, audio_batch, identity_batch, epoch_now = read_and_decode_TFRecordDataset(
        args.tfrecords, BATCH_SIZE, args.idnum, EPOCH_NUM)
    if args.mode == 'deconv':
        import speech2vid_inference_deconv
        speech2vid = speech2vid_inference_deconv.Speech2Vid(
            USE_AUDIO, USE_BN, USE_LIP, USE_DECODER, USE_FACEFC, USE_AUDIOFC,
            USE_FACE, USE_XAVIER, AVOID_LAYERS_LIST, args.idnum)
    elif args.mode == 'upsample':
        import speech2vid_inference_finetune
        speech2vid = speech2vid_inference_finetune.Speech2Vid(
            USE_AUDIO, USE_BN, USE_LIP, USE_DECODER, USE_FACEFC, USE_AUDIOFC,
            USE_FACE, USE_XAVIER, AVOID_LAYERS_LIST, args.idnum)
    elif args.mode == 'gan':
        import speech2vid_inference_gan
        speech2vid = speech2vid_inference_gan.Speech2Vid(
            USE_AUDIO, USE_BN, USE_LIP, USE_DECODER, USE_FACEFC, USE_AUDIOFC,
            USE_FACE, USE_XAVIER, AVOID_LAYERS_LIST, args.idnum, IS_TRAINING)

    with tf.variable_scope('generator'):
        print '============ BUILD G ============'
        audio_encoder_output = speech2vid.audio_encoder(audio_batch)
        if args.idnum == 1:
            tf.summary.image("identity1", identity_batch)
        elif args.idnum == 5:
            tf.summary.image("identity1", tf.split(identity_batch, 5, -1)[0])
            tf.summary.image("identity2", tf.split(identity_batch, 5, -1)[1])
            tf.summary.image("identity3", tf.split(identity_batch, 5, -1)[2])
            tf.summary.image("identity4", tf.split(identity_batch, 5, -1)[3])
            tf.summary.image("identity5", tf.split(identity_batch, 5, -1)[4])
        identity_encoder_output, x7_face, x4_face, x3_face = speech2vid.identity_encoder(
            identity_batch, args.idnum)
        prediction = speech2vid.image_decoder(
            audio_encoder_output, identity_encoder_output, x7_face, x4_face, x3_face)

    with tf.name_scope('groundtruth_discriminator'):
        with tf.variable_scope('discriminator'):
            discriminator_real = speech2vid.discriminator(groundtruth_batch)
    with tf.name_scope('prediction_discriminator'):
        with tf.variable_scope('discriminator', reuse=True):
            discriminator_pred = speech2vid.discriminator(prediction)

    with tf.name_scope('discriminator_loss'):
        discrim_loss = tf.reduce_mean(
            -(tf.log(discriminator_real + EPS) + tf.log(1 - discriminator_pred + EPS)))
    with tf.name_scope("generator_loss"):
        gen_loss_gan = tf.reduce_mean(-tf.log(discriminator_pred + EPS))
        gen_loss_l1, loss0, loss1, loss2, loss3, loss4, loss5 = speech2vid.loss_l1(
            prediction, groundtruth_batch)
        gen_loss = gen_loss_gan * GAN_WEIGHT + gen_loss_l1 * L1_WEIGHT

    with tf.name_scope('discriminator_train'):
        print '============ BUILD D ============'
        discrim_tvars = [var for var in tf.trainable_variables()
                         if var.name.startswith("discriminator")]
        discrim_optim = tf.train.AdamOptimizer(0.001)
        discrim_grads_and_vars = discrim_optim.compute_gradients(
            discrim_loss, var_list=discrim_tvars)
        discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)

    with tf.name_scope("generator_train"):
        with tf.control_dependencies([discrim_train]):
            gen_var_list = [var for var in tf.trainable_variables()
                            if 'generator' in var.op.name]
            bn_var_list = [var for var in tf.trainable_variables()
                           if 'bn' in var.op.name]
            audio_var_list = [var for var in tf.trainable_variables()
                              if 'audio' in var.op.name and var not in bn_var_list]
            identity_var_list = [var for var in tf.trainable_variables()
                                 if str(var.op.name).split('/')[1][-4:] == 'face' and var not in bn_var_list]
            lip_var_list = [var for var in tf.trainable_variables()
                            if 'face_lip' in var.op.name and var not in bn_var_list]
            decoder_var_list = list(set(gen_var_list) - set(bn_var_list) - set(audio_var_list)
                                    - set(identity_var_list) - set(lip_var_list))

            identity_learning_rate = BASIC_LEARNING_RATE * IDENTITY_LEARNING_RATE_BASE
            identity_optimizer = tf.train.AdamOptimizer(
                learning_rate=identity_learning_rate)
            identity_grads_and_vars = identity_optimizer.compute_gradients(
                gen_loss, var_list=identity_var_list)
            identity_train_op = identity_optimizer.apply_gradients(
                identity_grads_and_vars)

            audio_learning_rate = BASIC_LEARNING_RATE * AUDIO_LEARNING_RATE_BASE
            audio_optimizer = tf.train.AdamOptimizer(
                learning_rate=audio_learning_rate)
            audio_grads_and_vars = audio_optimizer.compute_gradients(
                gen_loss, var_list=audio_var_list)
            audio_train_op = audio_optimizer.apply_gradients(
                audio_grads_and_vars)

            bn_learning_rate = BASIC_LEARNING_RATE * BN_LEARNING_RATE_BASE
            bn_optimizer = tf.train.AdamOptimizer(
                learning_rate=bn_learning_rate)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                bn_grads_and_vars = bn_optimizer.compute_gradients(
                    gen_loss, var_list=bn_var_list)
                bn_train_op = bn_optimizer.apply_gradients(bn_grads_and_vars)

            lip_learning_rate = BASIC_LEARNING_RATE * LIP_LEARNING_RATE_BASE
            lip_optimizer = tf.train.AdamOptimizer(
                learning_rate=lip_learning_rate)
            lip_grads_and_vars = lip_optimizer.compute_gradients(
                gen_loss, var_list=lip_var_list)
            lip_train_op = lip_optimizer.apply_gradients(lip_grads_and_vars)

            decoder_learning_rate = BASIC_LEARNING_RATE * DECODER_LEARNING_RATE_BASE
            decoder_optimizer = tf.train.AdamOptimizer(
                learning_rate=decoder_learning_rate)
            decoder_grads_and_vars = decoder_optimizer.compute_gradients(
                gen_loss, var_list=decoder_var_list)
            decoder_train_op = decoder_optimizer.apply_gradients(
                decoder_grads_and_vars)

    ema = tf.train.ExponentialMovingAverage(decay=0.99)
    update_losses = ema.apply([discrim_loss, gen_loss_gan, gen_loss_l1])
    global_step = tf.train.get_or_create_global_step()
    incr_global_step = tf.assign(global_step, global_step + 1)

    return Model(
        epoch_now=epoch_now,
        input_audio=audio_batch,
        input_groundtruth=groundtruth_batch,
        prediction=prediction,
        discriminator_real=discriminator_real,
        discriminator_pred=discriminator_pred,
        discrim_grads_and_vars=discrim_grads_and_vars,
        identity_grads_and_vars=identity_grads_and_vars,
        audio_grads_and_vars=audio_grads_and_vars,
        bn_grads_and_vars=bn_grads_and_vars,
        lip_grads_and_vars=lip_grads_and_vars,
        decoder_grads_and_vars=decoder_grads_and_vars,
        discrim_loss=ema.average(discrim_loss),
        gen_loss_gan=ema.average(gen_loss_gan),
        gen_loss_l1=ema.average(gen_loss_l1),
        train=tf.group(update_losses, incr_global_step, identity_train_op,
                       audio_train_op, bn_train_op, lip_train_op, decoder_train_op))
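# `Model` returned by create_model() is assumed to be defined elsewhere in the
# repository. A minimal sketch of a compatible container (the namedtuple form
# is an assumption; only the field names come from the call site above):
# import collections
# Model = collections.namedtuple('Model', [
#     'epoch_now', 'input_audio', 'input_groundtruth', 'prediction',
#     'discriminator_real', 'discriminator_pred', 'discrim_grads_and_vars',
#     'identity_grads_and_vars', 'audio_grads_and_vars', 'bn_grads_and_vars',
#     'lip_grads_and_vars', 'decoder_grads_and_vars',
#     'discrim_loss', 'gen_loss_gan', 'gen_loss_l1', 'train'])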
def evaluation():
    identity5_0 = read_five_images(FIVE_IMAGES_PATH, args.idnum,
                                   face_detection=args.face_detection)
    if args.idnum == 5:
        identity = tf.placeholder(tf.float32, shape=[1, 112, 112, 15])
    elif args.idnum == 1:
        identity = tf.placeholder(tf.float32, shape=[1, 112, 112, 3])
    mfcc = tf.placeholder(tf.float32, shape=[1, 12, 35, 1])

    if args.matlab == 1:
        C = extract_mfcc_matlab(AUDIO_PATH)
    else:
        C = extract_mfcc(AUDIO_PATH)

    C_length = 0
    save_path_list = []
    for j in range(0, C.shape[1] - 34, 4):
        C_length = C_length + 1
    count = 0

    if args.mode == 'upsample':
        import speech2vid_inference_finetune
        speech2vid = speech2vid_inference_finetune.Speech2Vid(
            USE_AUDIO, USE_BN, USE_LIP, USE_DECODER, USE_FACEFC, USE_AUDIOFC,
            USE_FACE, USE_XAVIER, AVOID_LAYERS_LIST, args.idnum)
    elif args.mode == 'deconv':
        import speech2vid_inference_deconv
        speech2vid = speech2vid_inference_deconv.Speech2Vid(
            USE_AUDIO, USE_BN, USE_LIP, USE_DECODER, USE_FACEFC, USE_AUDIOFC,
            USE_FACE, USE_XAVIER, AVOID_LAYERS_LIST, args.idnum)

    audio_encoder_output = speech2vid.audio_encoder(mfcc)
    identity_encoder_output, x7_face, x4_face, x3_face = speech2vid.identity_encoder(
        identity, args.idnum)
    prediction = speech2vid.image_decoder(
        audio_encoder_output, identity_encoder_output, x7_face, x4_face, x3_face)

    with tf.Session(config=config) as sess:
        ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
            saver = tf.train.Saver()
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('load checkpoint success, global_step is %s' % global_step)
        else:
            print 'load checkpoint failed'

        for j in range(0, C.shape[1] - 34, 4):
            start_time = time.time()
            mfcc_0 = C[1:, j:j + 35]
            mfcc_0 = np.reshape(mfcc_0, (1, 12, 35, 1))
            prediction_, identity_encoder_output_, audio_encoder_output_ = sess.run(
                [prediction, identity_encoder_output, audio_encoder_output],
                feed_dict={identity: identity5_0, mfcc: mfcc_0})

            '''display tensor value'''
            # # for t in ['relu7_audio/Relu']:
            # # for t in ['concat8/concat8', 'upsamp1_1/resampler/Resampler', 'upsamp1_2/resampler/Resampler', 'upsamp1_3/resampler/Resampler', 'upsamp2/resampler/Resampler']:
            # for t in ['upsamp1_2/ResizeNearestNeighbor']:
            # # for t in ['concat8/concat8', 'conv8/BiasAdd', 'relu8/Relu', 'upsamp1_1/ResizeNearestNeighbor', 'conv1_1/BiasAdd','relu1_1/Relu','upsamp1_2/ResizeNearestNeighbor','conv1_2/BiasAdd','relu1_2/Relu','upsamp1_3/ResizeNearestNeighbor','conv1_3/BiasAdd','relu1_3/Relu','upsamp2/ResizeNearestNeighbor','conv2/BiasAdd','relu2/Relu','upsamp3_1/ResizeNearestNeighbor','conv3_1/BiasAdd','relu3_1/Relu','conv3_2/BiasAdd','relu3_2/Relu','upsamp4/ResizeNearestNeighbor','conv4/BiasAdd','relu4/Relu','upsamp5_1/ResizeNearestNeighbor','conv5_1/BiasAdd','relu5_1/Relu','conv5_2/BiasAdd','relu5_2/Relu','conv5_3/BiasAdd']:
            #     tensor = sess.graph.get_tensor_by_name("{}:0".format(t))
            #     tensor_ = sess.run(tensor, feed_dict={identity: identity5_0, mfcc: mfcc_0})
            #     # print t, np.mean(tensor_)
            #     # print tensor_
            #     for i in range(tensor_.shape[3]):
            #         print t, i + 1, np.mean(tensor_[:, :, :, i]), np.var(tensor_[:, :, :, i])

            '''save tensor values to npy'''
            # tensor_name_list = [
            #     tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
            # for tensor_name in tensor_name_list:
            #     print tensor_name
            #     if 'Place' not in tensor_name and 'save' not in tensor_name and 'Pad' not in tensor_name and 'Conv2D' not in tensor_name:
            #         tensor = sess.graph.get_tensor_by_name("{}:0".format(tensor_name))
            #         tensor_ = sess.run(tensor)
            #         print type(tensor_)
            #         txt_path = 'tmp/{}.npy'.format(tensor_name.replace('/', '-'))
            #         np.save(txt_path, tensor_)
            #         break

            prediction_Image = np.squeeze(prediction_)
            prediction_Image = 255 * prediction_Image / np.max(prediction_Image)
            prediction_Image = Image.fromarray(prediction_Image.astype(np.uint8))

            '''put text'''
            text = str(count)
            font = ImageFont.truetype(
                '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', 30)
            draw = ImageDraw.Draw(prediction_Image)
            draw.text((10, 10), text=text, font=font)
            # prediction.show()

            if not os.path.exists(OUTPUT_DIR):
                os.makedirs(OUTPUT_DIR)
            count = count + 1
            save_path = ('%s/%03d.jpg' % (OUTPUT_DIR, count))
            save_path_list.append(save_path)
            prediction_Image.save(save_path)
            print 'Finish {}/{}, Audio output mean {}, Identity output mean {}, Prediction mean {}, Elapsed time {}'.format(
                count, C_length, np.mean(audio_encoder_output_),
                np.mean(identity_encoder_output_), np.mean(prediction_),
                time.time() - start_time)

    if args.mp4 == 1:
        save_video(OUTPUT_DIR, WAV_TO_GENERATE_MP4_PATH,
                   '{}-{}.mp4'.format(OUTPUT_DIR, global_step))
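# Hypothetical helper (not in the original code) that mirrors the MFCC
# sliding-window loop in evaluation(): each generated image consumes a
# 12 x 35 crop of C (rows 1..12, 35 MFCC frames wide) and the window advances
# 4 MFCC frames per image, so the number of output frames is:
def _num_output_frames(mfcc_frame_count, window=35, hop=4):
    # Equivalent to len(range(0, mfcc_frame_count - 34, 4)) used above.
    if mfcc_frame_count < window:
        return 0
    return (mfcc_frame_count - window) // hop + 1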
def train():
    '''use tfrecords'''
    face_batch, audio_batch, identity5_batch = read_and_decode(
        args.tfrecords, BATCH_SIZE, 5)
    face_batch = tf.cast(face_batch, dtype=tf.float32)
    audio_batch = tf.cast(audio_batch, dtype=tf.float32)
    identity5_batch = tf.cast(identity5_batch, dtype=tf.float32)
    # for i in xrange(NUM_GPUS):
    #     with tf.device('/gpu:%d' % i):
    speech2vid = speech2vid_inference_finetune.Speech2Vid()
    prediction = speech2vid.inference(audio_batch, identity5_batch, BATCH_SIZE)
    train_loss = speech2vid.compute_loss(prediction, face_batch)
    global_step = tf.Variable(0, trainable=False)
    train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE_BASE).minimize(
        train_loss, global_step=global_step)
    summary_op = tf.summary.merge_all()
    # train_writer = tf.summary.FileWriter(LOG_SAVE_PATH, graph)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            start_time = time.time()
            for step in np.arange(TRAINING_STEPS):
                if coord.should_stop():
                    break
                _, training_loss, step = sess.run(
                    [train_op, train_loss, global_step])
                if step % 500 == 0:
                    end_time = time.time()
                    elapsed_time = end_time - start_time
                    print 'Step: {}, Train loss: {}, Elapsed time: {}'.format(
                        step, training_loss, elapsed_time)
                    summary_str = sess.run(summary_op)
                    # train_writer.add_summary(summary_str, step)
                    start_time = time.time()
                if step % 10000 == 0 or (step + 1) == TRAINING_STEPS:
                    saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                               global_step=global_step)
                    print 'Model {} saved'.format(step)
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            traceback.print_exc()
            coord.request_stop(e)
        finally:
            saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                       global_step=global_step)
            print 'Model {} saved'.format(step)
            coord.request_stop()
            coord.join(threads)
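# `config` passed to every tf.Session above is assumed to be defined near the
# top of this file. A typical sketch of such a definition (an assumption, not
# the original code):
# config = tf.ConfigProto()
# config.gpu_options.allow_growth = True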