Example #1
def train():
    '''use tfrecords'''
    face_batch, audio_batch, identity5_batch, epoch_now = read_and_decode_TFRecordDataset(
        args.tfrecords, BATCH_SIZE, args.idnum, EPOCH_NUM)

    speech2vid = speech2vid_inference_finetune.Speech2Vid(
        USE_AUDIO, USE_BN, USE_LIP, USE_DECODER, USE_FACEFC, USE_AUDIOFC, USE_FACE, USE_XAVIER, AVOID_LAYERS_LIST, args.idnum)
    audio_encoder_output = speech2vid.audio_encoder(audio_batch)
    identity_encoder_output, x7_face, x4_face, x3_face = speech2vid.identity_encoder(
        identity5_batch, args.idnum)
    prediction = speech2vid.image_decoder(
        audio_encoder_output, identity_encoder_output, x7_face, x4_face, x3_face)
    # prediction = speech2vid.inference(audio_batch, identity5_batch)
    train_loss, loss0, loss1, loss2, loss3, loss4, loss5 = speech2vid.loss_l1(
        prediction, face_batch)

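    # Split the trainable variables into per-module groups (batch norm, audio
    # encoder, identity encoder, lip region, decoder) so that each group can be
    # given its own learning rate and optimizer below.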
    bn_var_list = [var for var in tf.trainable_variables()
                   if 'bn' in var.op.name]
    audio_var_list = [var for var in tf.trainable_variables()
                      if 'audio' in var.op.name and var not in bn_var_list]
    identity_var_list = [
        var for var in tf.trainable_variables() if str(var.op.name).split('/')[0][-4:] == 'face' and var not in bn_var_list]
    lip_var_list = [
        var for var in tf.trainable_variables() if 'face_lip' in var.op.name and var not in bn_var_list]
    decoder_var_list = [var for var in tf.trainable_variables()
                        if var not in audio_var_list + identity_var_list + lip_var_list + bn_var_list]

    global_step = tf.Variable(0, trainable=False)
    '''exponential_decay lr'''
    # identity_learning_rate = tf.train.exponential_decay(
    #     IDENTITY_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # audio_learning_rate = tf.train.exponential_decay(
    #     AUDIO_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # bn_learning_rate = tf.train.exponential_decay(
    #     BN_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # lip_learning_rate = tf.train.exponential_decay(
    #     LIP_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # decoder_learning_rate = tf.train.exponential_decay(
    #     DECODER_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    '''constant lr'''
    identity_learning_rate = BASIC_LEARNING_RATE*IDENTITY_LEARNING_RATE_BASE
    audio_learning_rate = BASIC_LEARNING_RATE*AUDIO_LEARNING_RATE_BASE
    bn_learning_rate = BASIC_LEARNING_RATE*BN_LEARNING_RATE_BASE
    lip_learning_rate = BASIC_LEARNING_RATE*LIP_LEARNING_RATE_BASE
    decoder_learning_rate = BASIC_LEARNING_RATE*DECODER_LEARNING_RATE_BASE
    '''SGD'''
    # identity_optimizer = tf.train.GradientDescentOptimizer(identity_learning_rate)
    # audio_optimizer = tf.train.GradientDescentOptimizer(audio_learning_rate)
    # bn_optimizer = tf.train.GradientDescentOptimizer(bn_learning_rate)
    # lip_optimizer = tf.train.GradientDescentOptimizer(lip_learning_rate)
    # decoder_optimizer = tf.train.GradientDescentOptimizer(decoder_learning_rate)
    '''Momentum'''
    # identity_optimizer = tf.train.MomentumOptimizer(
    #     identity_learning_rate, MOMENTUM)
    # audio_optimizer = tf.train.MomentumOptimizer(audio_learning_rate, MOMENTUM)
    # bn_optimizer = tf.train.MomentumOptimizer(bn_learning_rate, MOMENTUM)
    # lip_optimizer = tf.train.MomentumOptimizer(lip_learning_rate, MOMENTUM)
    # decoder_optimizer = tf.train.MomentumOptimizer(
    #     decoder_learning_rate, MOMENTUM)
    '''Adam'''
    identity_optimizer = tf.train.AdamOptimizer(
        learning_rate=identity_learning_rate)
    audio_optimizer = tf.train.AdamOptimizer(learning_rate=audio_learning_rate)
    bn_optimizer = tf.train.AdamOptimizer(learning_rate=bn_learning_rate)
    lip_optimizer = tf.train.AdamOptimizer(learning_rate=lip_learning_rate)
    decoder_optimizer = tf.train.AdamOptimizer(
        learning_rate=decoder_learning_rate)
    '''Separate learning rate option 1'''
    identity_train_op = identity_optimizer.minimize(
        train_loss, global_step=global_step, var_list=identity_var_list)
    bn_train_op = bn_optimizer.minimize(
        train_loss, global_step=global_step, var_list=bn_var_list)
    lip_train_op = lip_optimizer.minimize(
        train_loss, global_step=global_step, var_list=lip_var_list)
    decoder_train_op = decoder_optimizer.minimize(
        train_loss, global_step=global_step, var_list=decoder_var_list)
    audio_train_op = audio_optimizer.minimize(
        train_loss, global_step=global_step, var_list=audio_var_list)
    train_op = tf.group(identity_train_op, audio_train_op,
                        bn_train_op, lip_train_op, decoder_train_op)
    '''Only train decoder'''
    # decoder_train_op = decoder_optimizer.minimize(
    #     train_loss, global_step=global_step, var_list=decoder_var_list)
    # train_op = tf.group(decoder_train_op)
    '''Separate learning rate option 2'''
    # grads = tf.gradients(train_loss, bn_var_list+audio_var_list +
    #                      identity_var_list+lip_var_list+decoder_var_list)
    # bn_grad = grads[:len(bn_var_list)]
    # audio_grad = grads[len(bn_var_list):len(bn_var_list + audio_var_list)]
    # identity_grad = grads[len(bn_var_list + audio_var_list):len(bn_var_list + audio_var_list + identity_var_list)]
    # lip_grad = grads[len(bn_var_list + audio_var_list + identity_var_list):len(bn_var_list + audio_var_list + identity_var_list + lip_var_list)]
    # decoder_grad = grads[len(
    #     bn_var_list + audio_var_list + identity_var_list + lip_var_list):]
    # identity_train_op = identity_optimizer.apply_gradients(
    #     zip(identity_grad, identity_var_list), global_step=global_step)
    # bn_train_op = bn_optimizer.apply_gradients(
    #     zip(bn_grad, bn_var_list), global_step=global_step)
    # lip_train_op = lip_optimizer.apply_gradients(
    #     zip(lip_grad, lip_var_list), global_step=global_step)
    # decoder_train_op = decoder_optimizer.apply_gradients(
    #     zip(decoder_grad, decoder_var_list), global_step=global_step)
    # audio_train_op = audio_optimizer.apply_gradients(
    #     zip(audio_grad, audio_var_list), global_step=global_step)
    # train_op = tf.group(identity_train_op, audio_train_op,
    #                     bn_train_op, lip_train_op, decoder_train_op)
    '''Only one learning rate'''
    # optimizer = tf.train.GradientDescentOptimizer(identity_learning_rate)
    # train_op = optimizer.minimize(train_loss, global_step=global_step)

    saver = tf.train.Saver(max_to_keep=0)

    tf.summary.scalar("loss", train_loss)
    tf.summary.image("face_gt",  face_batch)
    tf.summary.image("audio",  audio_batch)
    tf.summary.image("prediction",  prediction)

    summary_op = tf.summary.merge_all()

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        if args.ckpt:
            ckpt = tf.train.get_checkpoint_state(args.ckpt)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print '{} loaded'.format(ckpt.model_checkpoint_path)

        # coord = tf.train.Coordinator()
        # threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        train_writer = tf.summary.FileWriter(LOG_SAVE_PATH, sess.graph)
        early_stop_loss_list = []
        try:
            start_time = time.time()
            for step in np.arange(TRAINING_STEPS):
                # if coord.should_stop():
                #     break

                # _,  training_loss, step, summary, identity_learning_rate_, audio_learning_rate_, bn_learning_rate_, lip_learning_rate_, decoder_learning_rate_, loss0_, loss1_, loss2_, loss3_, loss4_, loss5_, audio_encoder_output_, identity_encoder_output_, x7_face_, base_lr_ = sess.run(
                #     [train_op,  train_loss, global_step, summary_op, identity_optimizer._learning_rate, audio_optimizer._learning_rate, bn_optimizer._learning_rate, lip_optimizer._learning_rate, decoder_optimizer._learning_rate, loss0, loss1, loss2, loss3, loss4, loss5, audio_encoder_output, identity_encoder_output, x7_face, BASIC_LEARNING_RATE])

                '''When using Adam'''
                _,  training_loss, step, summary, loss0_, loss1_, loss2_, loss3_, loss4_, loss5_, audio_encoder_output_, identity_encoder_output_, x7_face_, base_lr_, epoch_now_ = sess.run(
                    [train_op,  train_loss, global_step, summary_op, loss0, loss1, loss2, loss3, loss4, loss5, audio_encoder_output, identity_encoder_output, x7_face, BASIC_LEARNING_RATE, epoch_now])

                train_writer.add_summary(summary, step)
                # print 'x7_face_', np.mean(x7_face_)
                # print 'audio_encoder_output', np.max(audio_encoder_output_), np.min(
                #     audio_encoder_output_), np.mean(audio_encoder_output_)
                # print 'identity_encoder_output', np.max(identity_encoder_output_), np.min(
                #     identity_encoder_output_), np.mean(identity_encoder_output_)

                if step % 1 == 0:
                    end_time = time.time()
                    elapsed_time = end_time - start_time

                    # print '{}  Step: {}  Total loss: {}\tLoss0: {}\tLoss1: {}\tLoss2: {}\tLoss3: {}\tLoss4: {}\tLoss5: {}\tTime: {}\tBASE Lr: {}  ID Lr: {}  AU Lr: {}  BN Lr: {}  LIP Lr: {}  DE Lr: {}'.format(
                    #     name, step,  round(training_loss, 5), round(loss0_, 5), round(loss1_, 5), round(loss2_, 5), round(loss3_, 5), round(loss4_, 5), round(loss5_, 5), round(elapsed_time, 2), round(base_lr_, 10), round(identity_learning_rate_, 10), round(audio_learning_rate_, 10), round(bn_learning_rate_, 10), round(lip_learning_rate_, 10), round(decoder_learning_rate_, 10))

                    '''When using Adam'''
                    print '{}  Adam  {}  Epoch: {}  Step: {}  Total loss: {}\tLoss0: {}\tLoss1: {}\tLoss2: {}\tLoss3: {}\tLoss4: {}\tLoss5: {}\tTime: {}\tBASE Lr: {}'.format(
                        datetime.now().strftime("%m/%d %H:%M:%S"), name, epoch_now_[0], step,
                        round(training_loss, 5), round(loss0_, 5), round(loss1_, 5), round(loss2_, 5), round(loss3_, 5),
                        round(loss4_, 5), round(loss5_, 5), round(elapsed_time, 2), round(base_lr_, 10))

                    start_time = time.time()

                if step % 1000 == 0 or (step+1) == TRAINING_STEPS:
                    if not os.path.exists(MODEL_SAVE_PATH):
                        os.makedirs(MODEL_SAVE_PATH)
                    saver.save(sess, os.path.join(MODEL_SAVE_PATH,
                                                  MODEL_NAME), global_step=global_step)
                    print 'Model {} saved at {}'.format(
                        step, MODEL_SAVE_PATH)

                    '''Early stopping'''
                    # 1. test on validation set, record loss and the minimum loss
                    # 2. if loss on validation set is larger than the minimum loss for 10 times, stop training
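                    '''A minimal sketch of that plan (assumes a separate val_loss
                    tensor evaluated on a held-out validation batch, which this
                    graph does not define):'''
                    # val_loss_ = sess.run(val_loss)
                    # early_stop_loss_list.append(val_loss_)
                    # if len(early_stop_loss_list) > 10 and min(early_stop_loss_list) < min(early_stop_loss_list[-10:]):
                    #     print 'Validation loss has not improved for 10 checkpoints, stopping'
                    #     break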

        except KeyboardInterrupt:
            logging.info('Interrupted')
            # coord.request_stop()
        except Exception as e:
            traceback.print_exc()
            # coord.request_stop(e)
        finally:
            if not os.path.exists(MODEL_SAVE_PATH):
                os.makedirs(MODEL_SAVE_PATH)
            saver.save(sess, os.path.join(MODEL_SAVE_PATH,
                                          MODEL_NAME), global_step=global_step)
            print 'Model {} saved at {}'.format(step, MODEL_SAVE_PATH)
Example #2
def train():
    '''use tfrecords'''
    face_batch, audio_batch, identity5_batch = read_and_decode(
        args.tfrecords, BATCH_SIZE, args.idnum)
    face_batch = tf.cast(face_batch, dtype=tf.float32)
    audio_batch = tf.cast(audio_batch, dtype=tf.float32)
    identity5_batch = tf.cast(identity5_batch, dtype=tf.float32)

    speech2vid = speech2vid_inference_finetune.Speech2Vid(
        USE_AUDIO, USE_BN, USE_LIP, USE_DECODER, USE_FACEFC, USE_AUDIOFC,
        USE_FACE, USE_XAVIER, AVOID_LAYERS_LIST)
    audio_encoder_output = speech2vid.audio_encoder(audio_batch)
    identity_encoder_output, x7_face, x4_face, x3_face = speech2vid.identity_encoder(
        identity5_batch, args.idnum)
    prediction = speech2vid.image_decoder(audio_encoder_output,
                                          identity_encoder_output, x7_face,
                                          x4_face, x3_face)
    # prediction = speech2vid.inference(audio_batch, identity5_batch)
    train_loss, loss0, loss1, loss2, loss3, loss4, loss5 = speech2vid.loss_l1(
        prediction, face_batch)

    bn_var_list = [
        var for var in tf.trainable_variables() if 'bn' in var.op.name
    ]
    audio_var_list = [
        var for var in tf.trainable_variables()
        if 'audio' in var.op.name and var not in bn_var_list
    ]
    identity_var_list = [
        var for var in tf.trainable_variables()
        if str(var.op.name).split('/')[0][-4:] == 'face'
        and var not in bn_var_list
    ]
    lip_var_list = [
        var for var in tf.trainable_variables()
        if 'face_lip' in var.op.name and var not in bn_var_list
    ]
    decoder_var_list = [
        var for var in tf.trainable_variables() if var not in audio_var_list +
        identity_var_list + lip_var_list + bn_var_list
    ]

    global_step = tf.Variable(0, trainable=False)
    '''exponential_decay lr'''
    # identity_learning_rate = tf.train.exponential_decay(
    #     IDENTITY_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # audio_learning_rate = tf.train.exponential_decay(
    #     AUDIO_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # bn_learning_rate = tf.train.exponential_decay(
    #     BN_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # lip_learning_rate = tf.train.exponential_decay(
    #     LIP_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    # decoder_learning_rate = tf.train.exponential_decay(
    #     DECODER_LEARNING_RATE_BASE, global_step, 5000, LEARNING_RATE_DECAY, staircase=False)
    '''constant lr'''
    identity_learning_rate = BASIC_LEARNING_RATE * IDENTITY_LEARNING_RATE_BASE
    audio_learning_rate = BASIC_LEARNING_RATE * AUDIO_LEARNING_RATE_BASE
    bn_learning_rate = BASIC_LEARNING_RATE * BN_LEARNING_RATE_BASE
    lip_learning_rate = BASIC_LEARNING_RATE * LIP_LEARNING_RATE_BASE
    decoder_learning_rate = BASIC_LEARNING_RATE * DECODER_LEARNING_RATE_BASE
    '''SGD'''
    # identity_optimizer = tf.train.GradientDescentOptimizer(identity_learning_rate)
    # audio_optimizer = tf.train.GradientDescentOptimizer(audio_learning_rate)
    # bn_optimizer = tf.train.GradientDescentOptimizer(bn_learning_rate)
    # lip_optimizer = tf.train.GradientDescentOptimizer(lip_learning_rate)
    # decoder_optimizer = tf.train.GradientDescentOptimizer(decoder_learning_rate)
    '''Momentum'''
    identity_optimizer = tf.train.MomentumOptimizer(identity_learning_rate,
                                                    MOMENTUM)
    audio_optimizer = tf.train.MomentumOptimizer(audio_learning_rate, MOMENTUM)
    bn_optimizer = tf.train.MomentumOptimizer(bn_learning_rate, MOMENTUM)
    lip_optimizer = tf.train.MomentumOptimizer(lip_learning_rate, MOMENTUM)
    decoder_optimizer = tf.train.MomentumOptimizer(decoder_learning_rate,
                                                   MOMENTUM)
    '''Separate learning rate option 1'''
    identity_train_op = identity_optimizer.minimize(train_loss,
                                                    global_step=global_step,
                                                    var_list=identity_var_list)
    bn_train_op = bn_optimizer.minimize(train_loss,
                                        global_step=global_step,
                                        var_list=bn_var_list)
    lip_train_op = lip_optimizer.minimize(train_loss,
                                          global_step=global_step,
                                          var_list=lip_var_list)
    decoder_train_op = decoder_optimizer.minimize(train_loss,
                                                  global_step=global_step,
                                                  var_list=decoder_var_list)
    audio_train_op = audio_optimizer.minimize(train_loss,
                                              global_step=global_step,
                                              var_list=audio_var_list)
    train_op = tf.group(identity_train_op, audio_train_op, bn_train_op,
                        lip_train_op, decoder_train_op)
    '''Separate learning rate option 2'''
    # grads = tf.gradients(train_loss, identity_var_list+audio_var_list)
    # identity_grad = grads[:len(identity_var_list)]
    # audio_grad = grads[len(identity_var_list):]
    # identity_train_op = identity_optimizer.apply_gradients(
    #     zip(identity_grad, identity_var_list))
    # audio_train_op = audio_optimizer.apply_gradients(
    #     zip(audio_grad, audio_var_list))
    # train_op = tf.group(identity_train_op, audio_train_op)
    '''Only one learning rate'''
    # optimizer = tf.train.GradientDescentOptimizer(identity_learning_rate)
    # train_op = optimizer.minimize(train_loss, global_step=global_step)

    saver = tf.train.Saver()

    tf.summary.scalar("loss", train_loss)
    tf.summary.image("prediction", prediction)
    tf.summary.image("face_gt", face_batch)
    tf.summary.image("audio", audio_batch)

    summary_op = tf.summary.merge_all()

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())
        if args.ckpt:
            ckpt = tf.train.get_checkpoint_state(args.ckpt)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                print '{} loaded'.format(ckpt.model_checkpoint_path)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        train_writer = tf.summary.FileWriter(LOG_SAVE_PATH, sess.graph)
        try:
            if not os.path.exists(MODEL_SAVE_PATH):
                os.makedirs(MODEL_SAVE_PATH)
            saver.save(sess,
                       os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                       global_step=global_step)
            print 'Model {} saved at {}'.format(0, MODEL_SAVE_PATH)

        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            traceback.print_exc()
            coord.request_stop(e)
        finally:
            coord.request_stop()
            coord.join(threads)
Example #3
def create_model():
    groundtruth_batch, audio_batch, identity_batch, epoch_now = read_and_decode_TFRecordDataset(
        args.tfrecords, BATCH_SIZE, args.idnum, EPOCH_NUM)

    if args.mode == 'deconv':
        import speech2vid_inference_deconv
        speech2vid = speech2vid_inference_deconv.Speech2Vid(
            USE_AUDIO, USE_BN, USE_LIP, USE_DECODER, USE_FACEFC, USE_AUDIOFC, USE_FACE, USE_XAVIER, AVOID_LAYERS_LIST, args.idnum)
    elif args.mode == 'upsample':
        import speech2vid_inference_finetune
        speech2vid = speech2vid_inference_finetune.Speech2Vid(
            USE_AUDIO, USE_BN, USE_LIP, USE_DECODER, USE_FACEFC, USE_AUDIOFC, USE_FACE, USE_XAVIER, AVOID_LAYERS_LIST, args.idnum)
    elif args.mode == 'gan':
        import speech2vid_inference_gan
        speech2vid = speech2vid_inference_gan.Speech2Vid(
            USE_AUDIO, USE_BN, USE_LIP, USE_DECODER, USE_FACEFC, USE_AUDIOFC, USE_FACE, USE_XAVIER, AVOID_LAYERS_LIST, args.idnum, IS_TRAINING)

    with tf.variable_scope('generator'):
        print '============ BUILD G ============'
        audio_encoder_output = speech2vid.audio_encoder(audio_batch)
        if args.idnum == 1:
            tf.summary.image("identity1", identity_batch)
        elif args.idnum == 5:
            tf.summary.image("identity1", tf.split(identity_batch, 5, -1)[0])
            tf.summary.image("identity2", tf.split(identity_batch, 5, -1)[1])
            tf.summary.image("identity3", tf.split(identity_batch, 5, -1)[2])
            tf.summary.image("identity4", tf.split(identity_batch, 5, -1)[3])
            tf.summary.image("identity5", tf.split(identity_batch, 5, -1)[4])
        identity_encoder_output, x7_face, x4_face, x3_face = speech2vid.identity_encoder(
            identity_batch, args.idnum)
        prediction = speech2vid.image_decoder(
            audio_encoder_output, identity_encoder_output, x7_face, x4_face, x3_face)

    with tf.name_scope('groundtruth_discriminator'):
        with tf.variable_scope('discriminator'):
            discriminator_real = speech2vid.discriminator(groundtruth_batch)

    with tf.name_scope('prediction_discriminator'):
        with tf.variable_scope('discriminator', reuse=True):
            discriminator_pred = speech2vid.discriminator(prediction)

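    # Standard GAN losses with the non-saturating generator objective: the
    # discriminator is trained to maximise log D(real) + log(1 - D(fake))
    # (written below as the negative mean), while the generator minimises
    # -log D(fake) plus a weighted L1 reconstruction loss.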
    with tf.name_scope('discriminator_loss'):
        discrim_loss = tf.reduce_mean(
            - (tf.log(discriminator_real + EPS) + tf.log(1 - discriminator_pred + EPS)))

    with tf.name_scope("generator_loss"):
        gen_loss_gan = tf.reduce_mean(-tf.log(discriminator_pred + EPS))
        gen_loss_l1, loss0, loss1, loss2, loss3, loss4, loss5 = speech2vid.loss_l1(
            prediction, groundtruth_batch)
        gen_loss = gen_loss_gan * GAN_WEIGHT + gen_loss_l1 * L1_WEIGHT

    with tf.name_scope('discriminator_train'):
        print '============ BUILD D ============'
        discrim_tvars = [var for var in tf.trainable_variables()
                         if var.name.startswith("discriminator")]
        discrim_optim = tf.train.AdamOptimizer(0.001)
        discrim_grads_and_vars = discrim_optim.compute_gradients(
            discrim_loss, var_list=discrim_tvars)
        discrim_train = discrim_optim.apply_gradients(discrim_grads_and_vars)

    with tf.name_scope("generator_train"):
        with tf.control_dependencies([discrim_train]):
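            # The control dependency forces the discriminator update to run
            # before any of the generator updates in the same training step.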
            gen_var_list = [
                var for var in tf.trainable_variables() if 'generator' in var.op.name]
            bn_var_list = [var for var in tf.trainable_variables()
                           if 'bn' in var.op.name]
            audio_var_list = [var for var in tf.trainable_variables()
                              if 'audio' in var.op.name and var not in bn_var_list]
            identity_var_list = [
                var for var in tf.trainable_variables() if str(var.op.name).split('/')[1][-4:] == 'face' and var not in bn_var_list]
            lip_var_list = [
                var for var in tf.trainable_variables() if 'face_lip' in var.op.name and var not in bn_var_list]
            decoder_var_list = list(set(gen_var_list)-set(bn_var_list) -
                                    set(audio_var_list) - set(identity_var_list) - set(lip_var_list))
            identity_learning_rate = BASIC_LEARNING_RATE * IDENTITY_LEARNING_RATE_BASE
            identity_optimizer = tf.train.AdamOptimizer(
                learning_rate=identity_learning_rate)
            identity_grads_and_vars = identity_optimizer.compute_gradients(
                gen_loss, var_list=identity_var_list)
            identity_train_op = identity_optimizer.apply_gradients(
                identity_grads_and_vars)
            audio_learning_rate = BASIC_LEARNING_RATE*AUDIO_LEARNING_RATE_BASE
            audio_optimizer = tf.train.AdamOptimizer(
                learning_rate=audio_learning_rate)
            audio_grads_and_vars = audio_optimizer.compute_gradients(
                gen_loss, var_list=audio_var_list)
            audio_train_op = audio_optimizer.apply_gradients(
                audio_grads_and_vars)
            bn_learning_rate = BASIC_LEARNING_RATE*BN_LEARNING_RATE_BASE
            bn_optimizer = tf.train.AdamOptimizer(
                learning_rate=bn_learning_rate)
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                bn_grads_and_vars = bn_optimizer.compute_gradients(
                    gen_loss, var_list=bn_var_list)
                bn_train_op = bn_optimizer.apply_gradients(bn_grads_and_vars)
            lip_learning_rate = BASIC_LEARNING_RATE*LIP_LEARNING_RATE_BASE
            lip_optimizer = tf.train.AdamOptimizer(
                learning_rate=lip_learning_rate)
            lip_grads_and_vars = lip_optimizer.compute_gradients(
                gen_loss, var_list=lip_var_list)
            lip_train_op = lip_optimizer.apply_gradients(lip_grads_and_vars)
            decoder_learning_rate = BASIC_LEARNING_RATE*DECODER_LEARNING_RATE_BASE
            decoder_optimizer = tf.train.AdamOptimizer(
                learning_rate=decoder_learning_rate)
            decoder_grads_and_vars = decoder_optimizer.compute_gradients(
                gen_loss, var_list=decoder_var_list)
            decoder_train_op = decoder_optimizer.apply_gradients(
                decoder_grads_and_vars)

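    # Track exponential moving averages (decay 0.99) of the losses for smoother
    # reporting, and increment the global step manually since the per-module
    # apply_gradients calls above do not pass global_step.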
    ema = tf.train.ExponentialMovingAverage(decay=0.99)
    update_losses = ema.apply([discrim_loss, gen_loss_gan, gen_loss_l1])
    global_step = tf.train.get_or_create_global_step()
    incr_global_step = tf.assign(global_step, global_step + 1)

    return Model(
        epoch_now=epoch_now,
        input_audio=audio_batch,
        input_groundtruth=groundtruth_batch,
        prediction=prediction,
        discriminator_real=discriminator_real,
        discriminator_pred=discriminator_pred,

        discrim_grads_and_vars=discrim_grads_and_vars,
        identity_grads_and_vars=identity_grads_and_vars,
        audio_grads_and_vars=audio_grads_and_vars,
        bn_grads_and_vars=bn_grads_and_vars,
        lip_grads_and_vars=lip_grads_and_vars,
        decoder_grads_and_vars=decoder_grads_and_vars,

        discrim_loss=ema.average(discrim_loss),
        gen_loss_gan=ema.average(gen_loss_gan),
        gen_loss_l1=ema.average(gen_loss_l1),

        train=tf.group(update_losses, incr_global_step, identity_train_op,
                       audio_train_op, bn_train_op, lip_train_op, decoder_train_op)

    )
Example #4
def evaluation():
    identity5_0 = read_five_images(FIVE_IMAGES_PATH,
                                   args.idnum,
                                   face_detection=args.face_detection)
    if args.idnum == 5:
        identity = tf.placeholder(tf.float32, shape=[1, 112, 112, 15])
    elif args.idnum == 1:
        identity = tf.placeholder(tf.float32, shape=[1, 112, 112, 3])
    mfcc = tf.placeholder(tf.float32, shape=[1, 12, 35, 1])
    if args.matlab == 1:
        C = extract_mfcc_matlab(AUDIO_PATH)
    else:
        C = extract_mfcc(AUDIO_PATH)
    C_length = 0
    save_path_list = []
    for j in range(0, C.shape[1] - 34, 4):
        C_length = C_length + 1
    count = 0
    if args.mode == 'upsample':
        import speech2vid_inference_finetune
        speech2vid = speech2vid_inference_finetune.Speech2Vid(
            USE_AUDIO, USE_BN, USE_LIP, USE_DECODER, USE_FACEFC, USE_AUDIOFC,
            USE_FACE, USE_XAVIER, AVOID_LAYERS_LIST, args.idnum)
    elif args.mode == 'deconv':
        import speech2vid_inference_deconv
        speech2vid = speech2vid_inference_deconv.Speech2Vid(
            USE_AUDIO, USE_BN, USE_LIP, USE_DECODER, USE_FACEFC, USE_AUDIOFC,
            USE_FACE, USE_XAVIER, AVOID_LAYERS_LIST, args.idnum)

    audio_encoder_output = speech2vid.audio_encoder(mfcc)
    identity_encoder_output, x7_face, x4_face, x3_face = speech2vid.identity_encoder(
        identity, args.idnum)
    prediction = speech2vid.image_decoder(audio_encoder_output,
                                          identity_encoder_output, x7_face,
                                          x4_face, x3_face)

    with tf.Session(config=config) as sess:
        ckpt = tf.train.get_checkpoint_state(CHECKPOINT_DIR)
        if ckpt and ckpt.model_checkpoint_path:
            global_step = ckpt.model_checkpoint_path.split('/')[-1].split(
                '-')[-1]
            saver = tf.train.Saver()
            saver.restore(sess, ckpt.model_checkpoint_path)
            print('load checkpoint success, global_step is %s' % global_step)
        else:
            print 'load checkpoint fail'

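        # Slide a 35-frame MFCC window (12 coefficients x 35 frames, matching
        # the 1x12x35x1 placeholder) across the audio with a stride of 4 frames;
        # each window produces one output face frame.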
        for j in range(0, C.shape[1] - 34, 4):
            start_time = time.time()
            mfcc_0 = C[1:, j:j + 35]
            mfcc_0 = np.reshape(mfcc_0, (1, 12, 35, 1))
            prediction_, identity_encoder_output_, audio_encoder_output_ = sess.run(
                [prediction, identity_encoder_output, audio_encoder_output],
                feed_dict={
                    identity: identity5_0,
                    mfcc: mfcc_0
                })
            '''display tensor value'''
            # # for t in ['relu7_audio/Relu']:
            # # for t in ['concat8/concat8', 'upsamp1_1/resampler/Resampler', 'upsamp1_2/resampler/Resampler', 'upsamp1_3/resampler/Resampler', 'upsamp2/resampler/Resampler']:
            # for t in ['upsamp1_2/ResizeNearestNeighbor']:
            #     # for t in ['concat8/concat8', 'conv8/BiasAdd', 'relu8/Relu', 'upsamp1_1/ResizeNearestNeighbor', 'conv1_1/BiasAdd','relu1_1/Relu','upsamp1_2/ResizeNearestNeighbor','conv1_2/BiasAdd','relu1_2/Relu','upsamp1_3/ResizeNearestNeighbor','conv1_3/BiasAdd','relu1_3/Relu','upsamp2/ResizeNearestNeighbor','conv2/BiasAdd','relu2/Relu','upsamp3_1/ResizeNearestNeighbor','conv3_1/BiasAdd','relu3_1/Relu','conv3_2/BiasAdd','relu3_2/Relu','upsamp4/ResizeNearestNeighbor','conv4/BiasAdd','relu4/Relu','upsamp5_1/ResizeNearestNeighbor','conv5_1/BiasAdd','relu5_1/Relu','conv5_2/BiasAdd','relu5_2/Relu','conv5_3/BiasAdd']:
            #     tensor = sess.graph.get_tensor_by_name(
            #         "{}:0".format(t))
            #     tensor_ = sess.run(tensor, feed_dict={
            #         identity: identity5_0, mfcc: mfcc_0})
            #     # print t, np.mean(tensor_)
            #     # print tensor_
            #     for i in range(tensor_.shape[3]):
            #         print t,i+1,np.mean(tensor_[:,:,:,i]),np.var(tensor_[:,:,:,i])
            '''save tensor value to txt'''
            # tensor_name_list = [
            #     tensor.name for tensor in tf.get_default_graph().as_graph_def().node]
            # for tensor_name in tensor_name_list:
            #     print tensor_name
            #     if 'Place' not in tensor_name and 'save' not in tensor_name and 'Pad' not in tensor_name and 'Conv2D' not in tensor_name:
            #         tensor = sess.graph.get_tensor_by_name(
            #             "{}:0".format(tensor_name))
            #         tensor_ = sess.run(tensor)
            #         print type(tensor_)
            #         txt_path = 'tmp/{}.npy'.format(
            #             tensor_name.replace('/', '-'))
            #         np.save(txt_path, tensor_)
            # break
            prediction_Image = np.squeeze(prediction_)
            prediction_Image = 255 * prediction_Image / np.max(
                prediction_Image)
            prediction_Image = Image.fromarray(
                prediction_Image.astype(np.uint8))
            '''put text'''
            text = str(count)
            font = ImageFont.truetype(
                '/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf', 30)
            draw = ImageDraw.Draw(prediction_Image)
            draw.text((10, 10), text=text, font=font)
            # prediction.show()
            if not os.path.exists(OUTPUT_DIR):
                os.makedirs(OUTPUT_DIR)
            count = count + 1
            save_path = ('%s/%03d.jpg' % (OUTPUT_DIR, count))
            save_path_list.append(save_path)
            prediction_Image.save(save_path)
            print 'Finish {}/{}, Audio output mean {}, Identity output mean {}, Prediction mean {}, Elapsed time {}'.format(
                count, C_length, np.mean(audio_encoder_output_),
                np.mean(identity_encoder_output_), np.mean(prediction_),
                time.time() - start_time)

    if args.mp4 == 1:
        save_video(OUTPUT_DIR, WAV_TO_GENERATE_MP4_PATH,
                   '{}-{}.mp4'.format(OUTPUT_DIR, global_step))
Example #5
def train():
    '''use tfrecords'''
    face_batch, audio_batch, identity5_batch = read_and_decode(
        args.tfrecords, BATCH_SIZE, 5)
    face_batch = tf.cast(face_batch, dtype=tf.float32)
    audio_batch = tf.cast(audio_batch, dtype=tf.float32)
    identity5_batch = tf.cast(identity5_batch, dtype=tf.float32)

    # for i in xrange(NUM_GPUS):
    #     with tf.device('/gpu:%d' % i):

    speech2vid = speech2vid_inference_finetune.Speech2Vid()
    prediction = speech2vid.inference(audio_batch, identity5_batch, BATCH_SIZE)
    train_loss = speech2vid.compute_loss(prediction, face_batch)

    global_step = tf.Variable(0, trainable=False)
    train_op = tf.train.GradientDescentOptimizer(LEARNING_RATE_BASE).minimize(
        train_loss, global_step=global_step)

    summary_op = tf.summary.merge_all()
    # train_writer = tf.summary.FileWriter(LOG_SAVE_PATH, graph)
    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)

        try:
            start_time = time.time()
            for step in np.arange(TRAINING_STEPS):
                if coord.should_stop():
                    break

                _, training_loss, step = sess.run(
                    [train_op, train_loss, global_step])

                if step % 500 == 0:
                    end_time = time.time()
                    elapsed_time = end_time - start_time
                    print 'Step: {}, Train loss: {},  Elapsed time: {}'.format(
                        step, training_loss, elapsed_time)
                    summary_str = sess.run(summary_op)
                    # train_writer.add_summary(summary_str, step)
                    start_time = time.time()

                if step % 10000 == 0 or (step + 1) == TRAINING_STEPS:
                    saver.save(sess,
                               os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                               global_step=global_step)
                    print 'Model {} saved'.format(step)
        except KeyboardInterrupt:
            logging.info('Interrupted')
            coord.request_stop()
        except Exception as e:
            traceback.print_exc()
            coord.request_stop(e)
        finally:
            saver.save(sess,
                       os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
                       global_step=global_step)
            print 'Model {} saved'.format(step)
            coord.request_stop()
            coord.join(threads)