import numpy as np
import tensorflow as tf
from tqdm import tqdm

import loss
import network
import utils
from guided_filter import guided_filter


def train(args):
    input_photo = tf.placeholder(
        tf.float32, [args.batch_size, args.patch_size, args.patch_size, 3])
    input_superpixel = tf.placeholder(
        tf.float32, [args.batch_size, args.patch_size, args.patch_size, 3])
    input_cartoon = tf.placeholder(
        tf.float32, [args.batch_size, args.patch_size, args.patch_size, 3])

    # Generator output: the fake cartoonized picture.
    output = network.unet_generator(input_photo)
    # output = guided_filter(input_photo, output, r=1)

    # Edge-smoothed representations for the blur (surface) discriminator.
    blur_fake = guided_filter(output, output, r=5, eps=2e-1)
    blur_cartoon = guided_filter(input_cartoon, input_cartoon, r=5, eps=2e-1)

    # Random grayscale projections for the gray (texture) discriminator.
    gray_fake, gray_cartoon = utils.color_shift(output, input_cartoon)

    d_loss_gray, g_loss_gray = loss.lsgan_loss(
        network.disc_sn, gray_cartoon, gray_fake,
        scale=1, patch=True, name='disc_gray')
    d_loss_blur, g_loss_blur = loss.lsgan_loss(
        network.disc_sn, blur_cartoon, blur_fake,
        scale=1, patch=True, name='disc_blur')

    # VGG19 conv4_4 features for the perceptual content losses.
    vgg_model = loss.Vgg19('vgg19_no_fc.npy')
    vgg_photo = vgg_model.build_conv4_4(input_photo)
    vgg_output = vgg_model.build_conv4_4(output)
    vgg_superpixel = vgg_model.build_conv4_4(input_superpixel)
    h, w, c = vgg_photo.get_shape().as_list()[1:]

    photo_loss = tf.reduce_mean(
        tf.losses.absolute_difference(vgg_photo, vgg_output)) / (h * w * c)
    superpixel_loss = tf.reduce_mean(
        tf.losses.absolute_difference(vgg_superpixel, vgg_output)) / (h * w * c)
    recon_loss = photo_loss + superpixel_loss
    tv_loss = loss.total_variation_loss(output)

    g_loss_total = 1e4 * tv_loss + 1e-1 * g_loss_blur \
        + g_loss_gray + 2e2 * recon_loss
    d_loss_total = d_loss_blur + d_loss_gray

    all_vars = tf.trainable_variables()
    gene_vars = [var for var in all_vars if 'gene' in var.name]
    disc_vars = [var for var in all_vars if 'disc' in var.name]

    tf.summary.scalar('tv_loss', tv_loss)
    tf.summary.scalar('photo_loss', photo_loss)
    tf.summary.scalar('superpixel_loss', superpixel_loss)
    tf.summary.scalar('recon_loss', recon_loss)
    tf.summary.scalar('d_loss_gray', d_loss_gray)
    tf.summary.scalar('g_loss_gray', g_loss_gray)
    tf.summary.scalar('d_loss_blur', d_loss_blur)
    tf.summary.scalar('g_loss_blur', g_loss_blur)
    tf.summary.scalar('d_loss_total', d_loss_total)
    tf.summary.scalar('g_loss_total', g_loss_total)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        g_optim = tf.train.AdamOptimizer(args.adv_train_lr, beta1=0.5, beta2=0.99)\
            .minimize(g_loss_total, var_list=gene_vars)
        d_optim = tf.train.AdamOptimizer(args.adv_train_lr, beta1=0.5, beta2=0.99)\
            .minimize(d_loss_total, var_list=disc_vars)

    # Alternative session config:
    # config = tf.ConfigProto()
    # config.gpu_options.allow_growth = True
    # sess = tf.Session(config=config)
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=args.gpu_fraction)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    train_writer = tf.summary.FileWriter(args.save_dir + '/train_log')
    summary_op = tf.summary.merge_all()
    saver = tf.train.Saver(var_list=gene_vars, max_to_keep=20)

    with tf.device('/device:GPU:0'):
        sess.run(tf.global_variables_initializer())
        # Warm-start the generator from the reconstruction pre-training stage.
        saver.restore(sess, tf.train.latest_checkpoint('pretrain/saved_models'))

        face_photo_dir = 'dataset/photo_face'
        face_photo_list = utils.load_image_list(face_photo_dir)
        scenery_photo_dir = 'dataset/photo_scenery'
        scenery_photo_list = utils.load_image_list(scenery_photo_dir)

        face_cartoon_dir = 'dataset/cartoon_face'
        face_cartoon_list = utils.load_image_list(face_cartoon_dir)
        scenery_cartoon_dir = 'dataset/cartoon_scenery'
        scenery_cartoon_list = utils.load_image_list(scenery_cartoon_dir)

        for total_iter in tqdm(range(args.total_iter)):
            # Every fifth iteration trains on face data, the rest on scenery.
            if np.mod(total_iter, 5) == 0:
                photo_batch = utils.next_batch(face_photo_list, args.batch_size)
                cartoon_batch = utils.next_batch(face_cartoon_list, args.batch_size)
            else:
                photo_batch = utils.next_batch(scenery_photo_list, args.batch_size)
                cartoon_batch = utils.next_batch(scenery_cartoon_list, args.batch_size)

            inter_out = sess.run(output, feed_dict={
                input_photo: photo_batch,
                input_superpixel: photo_batch,
                input_cartoon: cartoon_batch})

            # Adaptive coloring has to be applied together with the
            # clip_by_value in the last layer of the generator network,
            # which is not very stable. To stably reproduce our results,
            # please use power=1.0 and comment out the clip_by_value
            # function in network.py first. If this works, then try
            # adaptive coloring with clip_by_value.
            if args.use_enhance:
                superpixel_batch = utils.selective_adacolor(inter_out, power=1.2)
            else:
                superpixel_batch = utils.simple_superpixel(inter_out, seg_num=200)

            _, g_loss, r_loss = sess.run(
                [g_optim, g_loss_total, recon_loss],
                feed_dict={input_photo: photo_batch,
                           input_superpixel: superpixel_batch,
                           input_cartoon: cartoon_batch})

            _, d_loss, train_info = sess.run(
                [d_optim, d_loss_total, summary_op],
                feed_dict={input_photo: photo_batch,
                           input_superpixel: superpixel_batch,
                           input_cartoon: cartoon_batch})

            train_writer.add_summary(train_info, total_iter)

            if np.mod(total_iter + 1, 50) == 0:
                print('Iter: {}, d_loss: {}, g_loss: {}, recon_loss: {}'
                      .format(total_iter, d_loss, g_loss, r_loss))

                if np.mod(total_iter + 1, 500) == 0:
                    saver.save(sess, args.save_dir + '/saved_models/model',
                               write_meta_graph=False, global_step=total_iter)

                    # Dump preview grids for both domains.
                    photo_face = utils.next_batch(face_photo_list, args.batch_size)
                    cartoon_face = utils.next_batch(face_cartoon_list, args.batch_size)
                    photo_scenery = utils.next_batch(scenery_photo_list, args.batch_size)
                    cartoon_scenery = utils.next_batch(scenery_cartoon_list, args.batch_size)

                    result_face = sess.run(output, feed_dict={
                        input_photo: photo_face,
                        input_superpixel: photo_face,
                        input_cartoon: cartoon_face})
                    result_scenery = sess.run(output, feed_dict={
                        input_photo: photo_scenery,
                        input_superpixel: photo_scenery,
                        input_cartoon: cartoon_scenery})

                    utils.write_batch_image(result_face, args.save_dir + '/images',
                                            str(total_iter) + '_face_result.jpg', 4)
                    utils.write_batch_image(photo_face, args.save_dir + '/images',
                                            str(total_iter) + '_face_photo.jpg', 4)
                    utils.write_batch_image(result_scenery, args.save_dir + '/images',
                                            str(total_iter) + '_scenery_result.jpg', 4)
                    utils.write_batch_image(photo_scenery, args.save_dir + '/images',
                                            str(total_iter) + '_scenery_photo.jpg', 4)
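
# For completeness, a minimal command-line driver for the training
# function above might look like the following sketch. The flag names
# mirror the args.* attributes that train() reads (patch_size,
# batch_size, total_iter, adv_train_lr, gpu_fraction, save_dir,
# use_enhance); the default values here are illustrative assumptions,
# not values taken from the source.
import argparse
import os


def arg_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--patch_size', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--total_iter', type=int, default=100000)
    parser.add_argument('--adv_train_lr', type=float, default=2e-4)
    parser.add_argument('--gpu_fraction', type=float, default=0.5)
    parser.add_argument('--save_dir', type=str, default='train_cartoon')
    parser.add_argument('--use_enhance', action='store_true')
    return parser.parse_args()


if __name__ == '__main__':
    args = arg_parser()
    # train() writes logs, checkpoints, and preview images under save_dir.
    os.makedirs(args.save_dir + '/images', exist_ok=True)
    train(args)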
# Pre-training stage (a separate script from the GAN training stage
# above): the generator alone is trained with a reconstruction loss,
# producing the checkpoint that the GAN stage restores from
# 'pretrain/saved_models'.
def train(args):
    input_photo = tf.placeholder(
        tf.float32, [args.batch_size, args.patch_size, args.patch_size, 3])

    output = network.unet_generator(input_photo)

    # L1 reconstruction loss between the input photo and the output.
    recon_loss = tf.reduce_mean(
        tf.losses.absolute_difference(input_photo, output))

    tf.summary.scalar('recon_loss', recon_loss)
    summary_op = tf.summary.merge_all()

    all_vars = tf.trainable_variables()
    gene_vars = [var for var in all_vars if 'gene' in var.name]

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optim = tf.train.AdamOptimizer(args.adv_train_lr, beta1=0.5, beta2=0.99)\
            .minimize(recon_loss, var_list=gene_vars)

    # Alternative session config:
    # config = tf.ConfigProto()
    # config.gpu_options.allow_growth = True
    # sess = tf.Session(config=config)
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=args.gpu_fraction)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))

    saver = tf.train.Saver(var_list=gene_vars, max_to_keep=20)
    # Write TensorBoard logs to a dedicated log directory rather than
    # the checkpoint path.
    summary_writer = tf.summary.FileWriter(args.save_dir + '/train_log',
                                           tf.get_default_graph())

    with tf.device('/device:GPU:0'):
        sess.run(tf.global_variables_initializer())

        face_photo_dir = 'dataset/photo_face'
        face_photo_list = utils.load_image_list(face_photo_dir)
        scenery_photo_dir = 'dataset/photo_scenery'
        scenery_photo_list = utils.load_image_list(scenery_photo_dir)

        for total_iter in tqdm(range(args.total_iter)):
            if np.mod(total_iter, 5) == 0:
                photo_batch = utils.next_batch(face_photo_list, args.batch_size)
            else:
                photo_batch = utils.next_batch(scenery_photo_list, args.batch_size)

            _, r_loss, summary_str = sess.run(
                [optim, recon_loss, summary_op],
                feed_dict={input_photo: photo_batch})

            summary_writer.add_summary(summary_str, global_step=total_iter)

            if np.mod(total_iter + 1, 50) == 0:
                print('pretrain, iter: {}, recon_loss: {}'.format(
                    total_iter, r_loss))
                if np.mod(total_iter + 1, 500) == 0:
                    # Save under 'saved_models' so the checkpoint matches
                    # the 'pretrain/saved_models' path that the GAN stage
                    # restores from.
                    saver.save(sess, args.save_dir + '/saved_models/model',
                               write_meta_graph=False, global_step=total_iter)

                    photo_face = utils.next_batch(face_photo_list, args.batch_size)
                    photo_scenery = utils.next_batch(scenery_photo_list, args.batch_size)

                    result_face = sess.run(output,
                                           feed_dict={input_photo: photo_face})
                    result_scenery = sess.run(output,
                                              feed_dict={input_photo: photo_scenery})

                    utils.write_batch_image(result_face, args.save_dir + '/images',
                                            str(total_iter) + '_face_result.jpg', 4)
                    utils.write_batch_image(photo_face, args.save_dir + '/images',
                                            str(total_iter) + '_face_photo.jpg', 4)
                    utils.write_batch_image(result_scenery, args.save_dir + '/images',
                                            str(total_iter) + '_scenery_result.jpg', 4)
                    utils.write_batch_image(photo_scenery, args.save_dir + '/images',
                                            str(total_iter) + '_scenery_photo.jpg', 4)
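
# A matching driver sketch for the pre-training script, analogous to the
# one after the GAN training stage; the flag names mirror the args.*
# attributes read above, and the defaults are again assumptions. Setting
# save_dir to 'pretrain' makes checkpoints land in 'pretrain/saved_models',
# which is where the GAN stage's saver.restore() looks for them.
import argparse
import os


def arg_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument('--patch_size', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--total_iter', type=int, default=50000)
    parser.add_argument('--adv_train_lr', type=float, default=2e-4)
    parser.add_argument('--gpu_fraction', type=float, default=0.5)
    parser.add_argument('--save_dir', type=str, default='pretrain')
    return parser.parse_args()


if __name__ == '__main__':
    args = arg_parser()
    os.makedirs(args.save_dir + '/images', exist_ok=True)
    train(args)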