# Example #1
# 0
def my_train():
    """Build the face-aging GAN graph and run the training loop.

    Apart from its locals, this function relies on module-level names:
    ``tf`` (TensorFlow 1.x graph/session API), ``config``, ``FLAGS``,
    ``FaceAging``, ``load_source_batch3``, ``train_generator``, ``d_iter``,
    ``g_iter``, ``SAVE_INTERVAL``, ``VAL_INTERVAL``, ``save_source``,
    ``save_images``, ``datetime`` and ``os``.

    Every step draws one target batch from ``train_generator``, runs the
    discriminator update ``d_iter`` times and the generator update
    ``g_iter`` times on that same batch, then prints the losses.  A
    checkpoint is written every ``SAVE_INTERVAL`` steps and after the final
    step; every ``VAL_INTERVAL`` steps one source batch is rendered for each
    age class and written under ``FLAGS.sample_dir/<step>/``.
    """
    with tf.Graph().as_default():
        # NOTE(review): the session is never closed. Closing it safely would
        # require stopping the queue-runner threads started below (e.g. via
        # tf.train.Coordinator), so it is intentionally left open.
        sess = tf.Session(config=config)
        model = FaceAging(sess=sess,
                          lr=FLAGS.learning_rate,
                          keep_prob=1.,
                          model_num=FLAGS.model_index,
                          batch_size=FLAGS.batch_size,
                          age_loss_weight=FLAGS.age_loss_weight,
                          gan_loss_weight=FLAGS.gan_loss_weight,
                          fea_loss_weight=FLAGS.fea_loss_weight,
                          tv_loss_weight=FLAGS.tv_loss_weight)

        # Placeholders fed from train_generator on every training step.
        imgs = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3])
        true_label_features_128 = tf.placeholder(
            tf.float32, [FLAGS.batch_size, 128, 128, FLAGS.age_groups])
        true_label_features_64 = tf.placeholder(
            tf.float32, [FLAGS.batch_size, 64, 64, FLAGS.age_groups])
        false_label_features_64 = tf.placeholder(
            tf.float32, [FLAGS.batch_size, 64, 64, FLAGS.age_groups])
        age_label = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Source-face tensors; presumably queue-backed, since queue runners
        # are started before the loop. face_label is not used below.
        source_img_227, source_img_128, face_label = load_source_batch3(
            FLAGS.source_file, FLAGS.root_folder, FLAGS.batch_size)

        model.train_age_lsgan_transfer(source_img_227, source_img_128, imgs,
                                       true_label_features_128,
                                       true_label_features_64,
                                       false_label_features_64,
                                       FLAGS.fea_layer_name, age_label)

        # Generator output (sharing weights via reuse=True) used only for
        # the periodic sample dumps.
        ge_samples = model.generate_images(imgs,
                                           true_label_features_128,
                                           reuse=True,
                                           mode='train')

        # One saver for the GAN weights (checkpointed), and two for the
        # pretrained AlexNet / age networks, which are only restored here.
        model.saver = tf.train.Saver(model.save_d_vars + model.save_g_vars,
                                     max_to_keep=200)
        model.alexnet_saver = tf.train.Saver(model.alexnet_vars)
        model.age_saver = tf.train.Saver(model.age_vars)

        # Divide the weights back out so the printed values are the raw loss
        # terms (assumes model.*_loss already include their weight).
        d_error = model.d_loss / model.gan_loss_weight
        g_error = model.g_loss / model.gan_loss_weight
        fea_error = model.fea_loss / model.fea_loss_weight
        age_error = model.age_loss / model.age_loss_weight

        # Start running operations on the Graph.
        sess.run(tf.global_variables_initializer())
        tf.train.start_queue_runners(sess)

        # Restore pretrained weights after the initializer so they are not
        # overwritten by it.
        model.alexnet_saver.restore(sess, FLAGS.alexnet_pretrained_model)
        model.age_saver.restore(sess, FLAGS.age_pretrained_model)

        # Try to resume from FLAGS.checkpoint_dir; a failed load only prints
        # a message and training continues from the current weights.
        if model.load(FLAGS.checkpoint_dir, model.saver):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        print("{} Start training...".format(datetime.now()))

        # Pre-bind the printed losses so the progress line cannot raise
        # NameError when d_iter or g_iter is 0.
        d_loss = g_loss = fea_loss = age_loss = float('nan')

        for step in range(FLAGS.max_steps):
            (images, t_label_features_128, t_label_features_64,
             f_label_features_64, age_labels) = \
                train_generator.next_target_batch_transfer2()
            train_feed = {
                imgs: images,
                true_label_features_128: t_label_features_128,
                true_label_features_64: t_label_features_64,
                false_label_features_64: f_label_features_64,
                age_label: age_labels,
            }

            # Discriminator and generator updates share the same batch.
            for _ in range(d_iter):
                _, d_loss = sess.run([model.d_optim, d_error],
                                     feed_dict=train_feed)

            for _ in range(g_iter):
                _, g_loss, fea_loss, age_loss = sess.run(
                    [model.g_optim, g_error, fea_error, age_error],
                    feed_dict=train_feed)

            format_str = (
                '%s: step %d, d_loss = %.3f, g_loss = %.3f, fea_loss=%.3f, age_loss=%.3f'
            )
            print(format_str %
                  (datetime.now(), step, d_loss, g_loss, fea_loss, age_loss))

            # Save a checkpoint every SAVE_INTERVAL steps and after the
            # final step.
            if step % SAVE_INTERVAL == SAVE_INTERVAL - 1 or (
                    step + 1) == FLAGS.max_steps:
                model.save(FLAGS.checkpoint_dir, step, 'acgan')

            if step % VAL_INTERVAL == VAL_INTERVAL - 1:
                path = os.path.join(FLAGS.sample_dir, str(step))
                # os.makedirs also creates FLAGS.sample_dir if it is missing.
                if not os.path.exists(path):
                    os.makedirs(path)

                # Render one fixed source batch into every age class.
                # NOTE: the [4, 8] grid holds 32 images -- presumably equal
                # to FLAGS.batch_size; TODO confirm.
                source = sess.run(source_img_128)
                save_source(source, [4, 8], os.path.join(path, 'source.jpg'))
                for j in range(train_generator.n_classes):
                    sample_feed = {
                        imgs: source,
                        true_label_features_128:
                        train_generator.label_features_128[j],
                    }
                    samples = sess.run(ge_samples, feed_dict=sample_feed)
                    # os.path.join keeps absolute sample_dir paths intact
                    # (the old './{}/' prefix turned them into relative ones).
                    save_images(samples, [4, 8],
                                os.path.join(path,
                                             'test_{:01d}.jpg'.format(j)))
# Example #2
# 0
def my_train():
    """Build the face-aging GAN graph and run the training loop.

    Apart from its locals, this function relies on module-level names:
    ``tf`` (TensorFlow 1.x graph/session API), ``config``, ``FLAGS``,
    ``FaceAging``, ``load_source_batch3``, ``train_generator``, ``d_iter``,
    ``g_iter``, ``SAVE_INTERVAL``, ``VAL_INTERVAL``, ``save_source``,
    ``save_images``, ``datetime`` and ``os``.

    Every step draws one target batch from ``train_generator``, runs the
    discriminator update ``d_iter`` times and the generator update
    ``g_iter`` times on that same batch, then prints the losses.  A
    checkpoint is written every ``SAVE_INTERVAL`` steps and after the final
    step; every ``VAL_INTERVAL`` steps one source batch is rendered for each
    age class and written under ``FLAGS.sample_dir/<step>/``.  The graph is
    also written to ``./logs`` for TensorBoard.
    """
    with tf.Graph().as_default():
        # Basic setup: create the session/model and declare the placeholders,
        # passing the hyper-parameters through FLAGS.
        # NOTE(review): the session is never closed. Closing it safely would
        # require stopping the queue-runner threads started below (e.g. via
        # tf.train.Coordinator), so it is intentionally left open.
        sess = tf.Session(config=config)
        model = FaceAging(sess=sess,
                          lr=FLAGS.learning_rate,
                          keep_prob=1.,
                          model_num=FLAGS.model_index,
                          batch_size=FLAGS.batch_size,
                          age_loss_weight=FLAGS.age_loss_weight,
                          gan_loss_weight=FLAGS.gan_loss_weight,
                          fea_loss_weight=FLAGS.fea_loss_weight,
                          tv_loss_weight=FLAGS.tv_loss_weight)

        # Placeholders fed from train_generator on every training step.
        imgs = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3])
        true_label_features_128 = tf.placeholder(
            tf.float32, [FLAGS.batch_size, 128, 128, FLAGS.age_groups])
        true_label_features_64 = tf.placeholder(
            tf.float32, [FLAGS.batch_size, 64, 64, FLAGS.age_groups])
        false_label_features_64 = tf.placeholder(
            tf.float32, [FLAGS.batch_size, 64, 64, FLAGS.age_groups])
        age_label = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Source-face tensors (see source_input.py); presumably queue-backed,
        # since queue runners are started before the loop.
        source_img_227, source_img_128, face_label = load_source_batch3(
            FLAGS.source_file, FLAGS.root_folder, FLAGS.batch_size)

        # Build the training graph from the source tensors and placeholders.
        # NOTE(review): author's open question -- should face_label be passed
        # here instead of age_label? Confirm against train_age_lsgan_transfer.
        model.train_age_lsgan_transfer(source_img_227, source_img_128, imgs,
                                       true_label_features_128,
                                       true_label_features_64,
                                       false_label_features_64,
                                       FLAGS.fea_layer_name, age_label)

        # Generator output (sharing weights via reuse=True) used to render
        # sample images for checking training progress.
        ge_samples = model.generate_images(imgs,
                                           true_label_features_128,
                                           reuse=True,
                                           mode='train')

        # One saver for the GAN weights (checkpointed), and two for the
        # pretrained AlexNet / age networks, which are only restored here.
        model.saver = tf.train.Saver(model.save_d_vars + model.save_g_vars,
                                     max_to_keep=200)
        model.alexnet_saver = tf.train.Saver(model.alexnet_vars)
        model.age_saver = tf.train.Saver(model.age_vars)

        # Divide the weights back out so the printed values are the raw loss
        # terms (assumes model.*_loss already include their weight).
        d_error = model.d_loss / model.gan_loss_weight
        g_error = model.g_loss / model.gan_loss_weight
        fea_error = model.fea_loss / model.fea_loss_weight
        age_error = model.age_loss / model.age_loss_weight

        # Start running operations on the Graph.
        sess.run(tf.global_variables_initializer())
        tf.train.start_queue_runners(sess)

        # Restore pretrained weights after the initializer so they are not
        # overwritten by it.
        model.alexnet_saver.restore(sess, FLAGS.alexnet_pretrained_model)
        model.age_saver.restore(sess, FLAGS.age_pretrained_model)

        # Try to resume from FLAGS.checkpoint_dir; a failed load only prints
        # a message and training continues from the current weights.
        if model.load(FLAGS.checkpoint_dir, model.saver):
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        print("{} Start training...".format(datetime.now()))

        # Graph for TensorBoard: tensorboard --logdir=./logs
        writer = tf.summary.FileWriter("./logs", sess.graph)

        # Pre-bind the printed losses so the progress line cannot raise
        # NameError when d_iter or g_iter is 0.
        d_loss = g_loss = fea_loss = age_loss = float('nan')

        try:
            for step in range(FLAGS.max_steps):
                # Next real target batch from the data generator
                # (data_generator.py).
                (images, t_label_features_128, t_label_features_64,
                 f_label_features_64, age_labels) = \
                    train_generator.next_target_batch_transfer2()
                # Feeds all five placeholders declared above.
                train_feed = {
                    imgs: images,
                    # condition for the resnet generator
                    true_label_features_128: t_label_features_128,
                    # conditions for the discriminator
                    true_label_features_64: t_label_features_64,
                    false_label_features_64: f_label_features_64,
                    age_label: age_labels,
                }

                print(len(images))
                print("train discriminator------------------------")
                for _ in range(d_iter):
                    _, d_loss = sess.run([model.d_optim, d_error],
                                         feed_dict=train_feed)
                print("train generator--------------------------")
                for _ in range(g_iter):
                    _, g_loss, fea_loss, age_loss = sess.run(
                        [model.g_optim, g_error, fea_error, age_error],
                        feed_dict=train_feed)

                format_str = (
                    '%s: step %d, d_loss = %.3f, g_loss = %.3f, fea_loss=%.3f, age_loss=%.3f'
                )
                print(format_str % (datetime.now(), step, d_loss, g_loss,
                                    fea_loss, age_loss))

                # Save a checkpoint every SAVE_INTERVAL steps and after the
                # final step.
                if step % SAVE_INTERVAL == SAVE_INTERVAL - 1 or (
                        step + 1) == FLAGS.max_steps:
                    model.save(FLAGS.checkpoint_dir, step, 'acgan')

                if step % VAL_INTERVAL == VAL_INTERVAL - 1:
                    path = os.path.join(FLAGS.sample_dir, str(step))
                    # os.makedirs also creates FLAGS.sample_dir if missing.
                    if not os.path.exists(path):
                        os.makedirs(path)

                    # Render one fixed source batch into every age class.
                    # NOTE: the [4, 8] grid holds 32 images -- presumably
                    # equal to FLAGS.batch_size; TODO confirm.
                    source = sess.run(source_img_128)
                    save_source(source, [4, 8],
                                os.path.join(path, 'source.jpg'))
                    for j in range(train_generator.n_classes):
                        sample_feed = {
                            imgs: source,
                            true_label_features_128:
                            train_generator.label_features_128[j],
                        }
                        samples = sess.run(ge_samples, feed_dict=sample_feed)
                        # os.path.join keeps absolute sample_dir paths intact
                        # (the old './{}/' prefix made them relative).
                        save_images(samples, [4, 8],
                                    os.path.join(path,
                                                 'test_{:01d}.jpg'.format(j)))
        finally:
            # Flush and release the TensorBoard event file.
            writer.close()