def my_train():
    """Build the face-aging GAN graph and run the adversarial training loop.

    Takes no arguments and returns nothing. All configuration is read from
    module-level names: ``config`` (session config), ``FLAGS`` (hyper-parameters
    and paths), ``train_generator`` (target-batch provider), ``d_iter`` /
    ``g_iter`` (optimizer iterations per step) and ``SAVE_INTERVAL`` /
    ``VAL_INTERVAL`` (checkpoint / sampling periods).

    Steps performed:
      1. Create the ``FaceAging`` model, the feed placeholders and the
         source-image batch tensors, then build the training graph.
      2. Initialize variables, start the queue runners, restore
         ``model.alexnet_vars`` from ``FLAGS.alexnet_pretrained_model`` and
         ``model.age_vars`` from ``FLAGS.age_pretrained_model``, and try to
         resume from ``FLAGS.checkpoint_dir``.
      3. For ``FLAGS.max_steps`` steps, run the discriminator and generator
         optimizers and print each loss divided by its loss weight.
      4. Periodically save a checkpoint and write generated sample images
         under ``<FLAGS.sample_dir>/<step>/``.

    The session is closed and the queue-runner threads are stopped and joined
    when the function exits, whether normally or through an exception.
    """
    with tf.Graph().as_default(), tf.Session(config=config) as sess:
        model = FaceAging(
            sess=sess,
            lr=FLAGS.learning_rate,
            keep_prob=1.,
            model_num=FLAGS.model_index,
            batch_size=FLAGS.batch_size,
            age_loss_weight=FLAGS.age_loss_weight,
            gan_loss_weight=FLAGS.gan_loss_weight,
            fea_loss_weight=FLAGS.fea_loss_weight,
            tv_loss_weight=FLAGS.tv_loss_weight)

        # Placeholders that are fed from train_generator on every step.
        imgs = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3])
        true_label_features_128 = tf.placeholder(
            tf.float32, [FLAGS.batch_size, 128, 128, FLAGS.age_groups])
        true_label_features_64 = tf.placeholder(
            tf.float32, [FLAGS.batch_size, 64, 64, FLAGS.age_groups])
        false_label_features_64 = tf.placeholder(
            tf.float32, [FLAGS.batch_size, 64, 64, FLAGS.age_groups])
        age_label = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Source-image batch tensors. face_label is returned by the loader but
        # is not used anywhere in this function.
        source_img_227, source_img_128, face_label = load_source_batch3(
            FLAGS.source_file, FLAGS.root_folder, FLAGS.batch_size)

        model.train_age_lsgan_transfer(
            source_img_227, source_img_128, imgs,
            true_label_features_128, true_label_features_64,
            false_label_features_64, FLAGS.fea_layer_name, age_label)

        # Generator output reused for the periodic sample images below.
        ge_samples = model.generate_images(
            imgs, true_label_features_128, reuse=True, mode='train')

        # Create the savers: one for the GAN itself, one per pretrained net.
        model.saver = tf.train.Saver(
            model.save_d_vars + model.save_g_vars, max_to_keep=200)
        model.alexnet_saver = tf.train.Saver(model.alexnet_vars)
        model.age_saver = tf.train.Saver(model.age_vars)

        # Report each loss with its weight divided out.
        d_error = model.d_loss / model.gan_loss_weight
        g_error = model.g_loss / model.gan_loss_weight
        fea_error = model.fea_loss / model.fea_loss_weight
        age_error = model.age_loss / model.age_loss_weight

        # Start running operations on the Graph.
        sess.run(tf.global_variables_initializer())

        # A coordinator lets the queue-runner threads be stopped and joined on
        # exit instead of being left running against a closed session.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            model.alexnet_saver.restore(sess, FLAGS.alexnet_pretrained_model)
            model.age_saver.restore(sess, FLAGS.age_pretrained_model)

            if model.load(FLAGS.checkpoint_dir, model.saver):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")

            # The original printed the literal "{}" because the placeholder
            # was never formatted; fill it with the current time, matching
            # the per-step log line below.
            print("{} Start training...".format(datetime.now()))

            # Loop over max_steps
            for step in range(FLAGS.max_steps):
                (images, t_label_features_128, t_label_features_64,
                 f_label_features_64, age_labels) = \
                    train_generator.next_target_batch_transfer2()
                train_feed = {
                    imgs: images,
                    true_label_features_128: t_label_features_128,
                    true_label_features_64: t_label_features_64,
                    false_label_features_64: f_label_features_64,
                    age_label: age_labels,
                }

                for _ in range(d_iter):
                    _, d_loss = sess.run(
                        [model.d_optim, d_error], feed_dict=train_feed)

                for _ in range(g_iter):
                    _, g_loss, fea_loss, age_loss = sess.run(
                        [model.g_optim, g_error, fea_error, age_error],
                        feed_dict=train_feed)

                format_str = (
                    '%s: step %d, d_loss = %.3f, g_loss = %.3f, '
                    'fea_loss=%.3f, age_loss=%.3f')
                print(format_str % (datetime.now(), step, d_loss, g_loss,
                                    fea_loss, age_loss))

                # Save the model checkpoint periodically and on the last step.
                if (step % SAVE_INTERVAL == SAVE_INTERVAL - 1
                        or (step + 1) == FLAGS.max_steps):
                    model.save(FLAGS.checkpoint_dir, step, 'acgan')

                # Periodically dump one sample image per age class.
                if step % VAL_INTERVAL == VAL_INTERVAL - 1:
                    sample_path = os.path.join(FLAGS.sample_dir, str(step))
                    # makedirs also creates FLAGS.sample_dir when missing.
                    if not os.path.exists(sample_path):
                        os.makedirs(sample_path)

                    source = sess.run(source_img_128)
                    # [4, 8] is presumably the rows x cols of the saved image
                    # grid (i.e. a batch of 32) -- TODO confirm in save_source.
                    save_source(source, [4, 8],
                                os.path.join(sample_path, 'source.jpg'))
                    for j in range(train_generator.n_classes):
                        true_label_fea = train_generator.label_features_128[j]
                        sample_feed = {
                            imgs: source,
                            true_label_features_128: true_label_fea,
                        }
                        samples = sess.run(ge_samples, feed_dict=sample_feed)
                        # Join onto sample_path directly: prefixing './' (as
                        # the original did) points at the wrong location when
                        # FLAGS.sample_dir is an absolute path.
                        save_images(
                            samples, [4, 8],
                            os.path.join(sample_path,
                                         'test_{:01d}.jpg'.format(j)))
        finally:
            coord.request_stop()
            coord.join(threads)
def my_train():
    """Build the face-aging GAN graph and run the adversarial training loop.

    Takes no arguments and returns nothing. All configuration is read from
    module-level names: ``config`` (session config), ``FLAGS`` (hyper-parameters
    and paths), ``train_generator`` (target-batch provider), ``d_iter`` /
    ``g_iter`` (optimizer iterations per step) and ``SAVE_INTERVAL`` /
    ``VAL_INTERVAL`` (checkpoint / sampling periods).

    Steps performed:
      1. Create the ``FaceAging`` model, the feed placeholders and the
         source-image batch tensors, then build the training graph.
      2. Initialize variables, start the queue runners, restore
         ``model.alexnet_vars`` from ``FLAGS.alexnet_pretrained_model`` and
         ``model.age_vars`` from ``FLAGS.age_pretrained_model``, and try to
         resume from ``FLAGS.checkpoint_dir``.
      3. Write the graph to ``./logs`` for TensorBoard.
      4. For ``FLAGS.max_steps`` steps, run the discriminator and generator
         optimizers and print each loss divided by its loss weight.
      5. Periodically save a checkpoint and write generated sample images
         under ``<FLAGS.sample_dir>/<step>/``.

    The session and summary writer are closed and the queue-runner threads
    are stopped and joined when the function exits, whether normally or
    through an exception.
    """
    with tf.Graph().as_default(), tf.Session(config=config) as sess:
        # Basic setup: create the model with its hyper-parameters and declare
        # the placeholders.
        model = FaceAging(
            sess=sess,
            lr=FLAGS.learning_rate,
            keep_prob=1.,
            model_num=FLAGS.model_index,
            batch_size=FLAGS.batch_size,
            age_loss_weight=FLAGS.age_loss_weight,
            gan_loss_weight=FLAGS.gan_loss_weight,
            fea_loss_weight=FLAGS.fea_loss_weight,
            tv_loss_weight=FLAGS.tv_loss_weight)

        # Original note gives the shape as (_, 128, 128, 3), i.e. image_size
        # is presumably 128 -- confirm against the FLAGS definition.
        imgs = tf.placeholder(
            tf.float32,
            [FLAGS.batch_size, FLAGS.image_size, FLAGS.image_size, 3])
        true_label_features_128 = tf.placeholder(
            tf.float32, [FLAGS.batch_size, 128, 128, FLAGS.age_groups])
        true_label_features_64 = tf.placeholder(
            tf.float32, [FLAGS.batch_size, 64, 64, FLAGS.age_groups])
        false_label_features_64 = tf.placeholder(
            tf.float32, [FLAGS.batch_size, 64, 64, FLAGS.age_groups])
        age_label = tf.placeholder(tf.int32, [FLAGS.batch_size])

        # Source-image batch tensors, created once here (the original note
        # points at source_input.py for the loader).
        source_img_227, source_img_128, face_label = load_source_batch3(
            FLAGS.source_file, FLAGS.root_folder, FLAGS.batch_size)

        # Pass the source tensors and the placeholders to the model.
        # NOTE(review): the original author asked whether face_label should be
        # passed here instead of age_label; face_label is currently unused.
        model.train_age_lsgan_transfer(
            source_img_227, source_img_128, imgs,
            true_label_features_128, true_label_features_64,
            false_label_features_64, FLAGS.fea_layer_name, age_label)

        # ge_samples produces the images used to check intermediate progress.
        ge_samples = model.generate_images(
            imgs, true_label_features_128, reuse=True, mode='train')

        # Create the savers: one for the GAN itself, one per pretrained net.
        model.saver = tf.train.Saver(
            model.save_d_vars + model.save_g_vars, max_to_keep=200)
        model.alexnet_saver = tf.train.Saver(model.alexnet_vars)
        model.age_saver = tf.train.Saver(model.age_vars)

        # Report each loss with its weight divided out.
        d_error = model.d_loss / model.gan_loss_weight
        g_error = model.g_loss / model.gan_loss_weight
        fea_error = model.fea_loss / model.fea_loss_weight
        age_error = model.age_loss / model.age_loss_weight

        # Start running operations on the Graph.
        sess.run(tf.global_variables_initializer())

        # A coordinator lets the queue-runner threads be stopped and joined on
        # exit instead of being left running against a closed session.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        writer = None
        try:
            model.alexnet_saver.restore(sess, FLAGS.alexnet_pretrained_model)
            model.age_saver.restore(sess, FLAGS.age_pretrained_model)

            if model.load(FLAGS.checkpoint_dir, model.saver):
                print(" [*] Load SUCCESS")
            else:
                print(" [!] Load failed...")

            # The original printed the literal "{}" because the placeholder
            # was never formatted; fill it with the current time, matching
            # the per-step log line below.
            print("{} Start training...".format(datetime.now()))

            # Dump the graph for inspection with:
            #   tensorboard --logdir=./logs
            writer = tf.summary.FileWriter("./logs", sess.graph)

            # Loop over max_steps (200000 in the original author's setup).
            for step in range(FLAGS.max_steps):
                # Fetch one batch of real target data per step (the original
                # note points at data_generator.py for the provider).
                (images, t_label_features_128, t_label_features_64,
                 f_label_features_64, age_labels) = \
                    train_generator.next_target_batch_transfer2()
                # This feed covers all five placeholders declared above.
                train_feed = {
                    imgs: images,
                    # condition for resnet generator
                    true_label_features_128: t_label_features_128,
                    # condition for discriminator
                    true_label_features_64: t_label_features_64,
                    # condition for discriminator
                    false_label_features_64: f_label_features_64,
                    age_label: age_labels,
                }
                print(len(images))

                print("train discriminator------------------------")
                # d_iter was noted as 1 by the original author.
                for _ in range(d_iter):
                    _, d_loss = sess.run(
                        [model.d_optim, d_error], feed_dict=train_feed)

                print("train generator--------------------------")
                # g_iter was noted as 1 by the original author.
                for _ in range(g_iter):
                    _, g_loss, fea_loss, age_loss = sess.run(
                        [model.g_optim, g_error, fea_error, age_error],
                        feed_dict=train_feed)

                format_str = (
                    '%s: step %d, d_loss = %.3f, g_loss = %.3f, '
                    'fea_loss=%.3f, age_loss=%.3f')
                print(format_str % (datetime.now(), step, d_loss, g_loss,
                                    fea_loss, age_loss))

                # Save the model checkpoint periodically and on the last step.
                if (step % SAVE_INTERVAL == SAVE_INTERVAL - 1
                        or (step + 1) == FLAGS.max_steps):
                    model.save(FLAGS.checkpoint_dir, step, 'acgan')

                # Periodically dump one sample image per age class.
                if step % VAL_INTERVAL == VAL_INTERVAL - 1:
                    sample_path = os.path.join(FLAGS.sample_dir, str(step))
                    # makedirs also creates FLAGS.sample_dir when missing.
                    if not os.path.exists(sample_path):
                        os.makedirs(sample_path)

                    source = sess.run(source_img_128)
                    # [4, 8] is presumably the rows x cols of the saved image
                    # grid (i.e. a batch of 32) -- TODO confirm in save_source.
                    save_source(source, [4, 8],
                                os.path.join(sample_path, 'source.jpg'))
                    for j in range(train_generator.n_classes):
                        true_label_fea = train_generator.label_features_128[j]
                        sample_feed = {
                            imgs: source,
                            true_label_features_128: true_label_fea,
                        }
                        samples = sess.run(ge_samples, feed_dict=sample_feed)
                        # Join onto sample_path directly: prefixing './' (as
                        # the original did) points at the wrong location when
                        # FLAGS.sample_dir is an absolute path.
                        save_images(
                            samples, [4, 8],
                            os.path.join(sample_path,
                                         'test_{:01d}.jpg'.format(j)))
        finally:
            # Flush the graph event file before tearing down the threads.
            if writer is not None:
                writer.close()
            coord.request_stop()
            coord.join(threads)