def main(args=None):
    """Build the A -> B mask GAN with a ResNet-50 encoder and train or test it.

    The graph reads batches of domain A and domain B from TFRecord files
    through feed-able iterators, predicts a mask ``segment_a`` for every
    image of A, blends ``real_a`` with a highlighted copy of itself to form
    ``fake_b`` and scores the result with the ``Discriminator_B`` critic
    (Wasserstein loss plus gradient penalty).

    ``flags.mode`` selects what happens after the graph is built:

    * ``'train'``: alternate generator / critic updates until
      ``flags.training_iter`` steps, writing summaries, optional debug
      images and checkpoints on the way.
    * ``'test'``: run the validation iterators once and save, per batch,
      the input, the predicted mask, its binarised version, the fake image
      and the domain-B sample as PNG files.

    Args:
        args: Argument list whose first element is the script path
            (``.../main_<network>.py``); it is only used to derive
            ``flags.network_name``.  Falls back to ``sys.argv`` when ``None``.
    """
    print(args)
    if args is None:
        # The declared default would otherwise crash on ``args[0]`` below.
        import sys
        args = sys.argv
    tf.reset_default_graph()

    # ---- Dataset parser -------------------------------------------------
    flags.network_name = args[0].split('/')[-1].split('.')[0].split(
        'main_')[-1]
    flags.logs_dir = './logs_' + flags.network_name
    dataset_parser = GANParser(flags=flags)

    # ---- One-off conversion of the raw data to TFRecord files -----------
    # Manual switch: set the condition to True once to (re)build the records.
    if False:
        dataset_parser.load_paths(is_jpg=True, load_val=True)
        dataset_parser.data2record(
            name='{}_train.tfrecords'.format(dataset_parser.dataset_name),
            set_type='train',
            test_num=None)
        dataset_parser.data2record(
            name='{}_val.tfrecords'.format(dataset_parser.dataset_name),
            set_type='val',
            test_num=None)
        return

    # ---- Build graph ----------------------------------------------------
    with tf.Graph().as_default():
        # Input (TFRecord)
        with tf.name_scope('TFRecord'):
            # Dataset A
            training_a_dataset = dataset_parser.tfrecord_get_dataset(
                name='{}_trainA.tfrecords'.format(dataset_parser.dataset_name),
                batch_size=flags.batch_size,
                shuffle_size=None)
            val_a_dataset = dataset_parser.tfrecord_get_dataset(
                name='{}_valA.tfrecords'.format(dataset_parser.dataset_name),
                batch_size=flags.batch_size,
                need_flip=(flags.mode == 'train'))
            # Dataset B
            training_b_dataset = dataset_parser.tfrecord_get_dataset(
                name='{}_trainB.tfrecords'.format(dataset_parser.dataset_name),
                batch_size=flags.batch_size,
                shuffle_size=None)
            val_b_dataset = dataset_parser.tfrecord_get_dataset(
                name='{}_valB.tfrecords'.format(dataset_parser.dataset_name),
                batch_size=flags.batch_size,
                is_label=True,
                need_flip=(flags.mode == 'train'))

            # Feed-able iterators: the string handle fed at run time decides
            # whether the training or the validation set is consumed.
            with tf.name_scope('RealA'):
                handle_a = tf.placeholder(tf.string, shape=[])
                iterator_a = tf.contrib.data.Iterator.from_string_handle(
                    handle_a, training_a_dataset.output_types,
                    training_a_dataset.output_shapes)
                real_a, real_a_name, real_a_shape = iterator_a.get_next()
            with tf.name_scope('RealB'):
                handle_b = tf.placeholder(tf.string, shape=[])
                iterator_b = tf.contrib.data.Iterator.from_string_handle(
                    handle_b, training_b_dataset.output_types,
                    training_b_dataset.output_shapes)
                real_b, real_b_name, real_b_shape = iterator_b.get_next()
            with tf.name_scope('InitialA_op'):
                training_a_iterator = \
                    training_a_dataset.make_initializable_iterator()
                validation_a_iterator = \
                    val_a_dataset.make_initializable_iterator()
            with tf.name_scope('InitialB_op'):
                training_b_iterator = \
                    training_b_dataset.make_initializable_iterator()
                validation_b_iterator = \
                    val_b_dataset.make_initializable_iterator()

        # Network (computes predictions from the inference model)
        with tf.name_scope('Network'):
            # Inputs
            global_step = tf.Variable(0,
                                      trainable=False,
                                      name='global_step',
                                      dtype=tf.int32)
            global_step_update_op = tf.assign_add(
                global_step, 1, name='global_step_update_op')
            # Per-channel mean subtracted before the ResNet encoder.
            mean_rgb = tf.constant((123.68, 116.78, 103.94), dtype=tf.float32)
            fake_b_pool = tf.placeholder(
                tf.float32,
                shape=[
                    None, flags.image_height, flags.image_width,
                    flags.c_in_dim
                ],
                name='fake_B_pool')
            image_linear_shape = tf.constant(
                flags.image_height * flags.image_width * flags.c_in_dim,
                dtype=tf.int32,
                name='image_linear_shape')

            # A -> B: ResNet-50 encoder followed by a three-stage decoder.
            with tf.name_scope('Generator'):
                # NOTE(review): is_training=True is used in test mode as
                # well -- confirm that this is intended for batch norm.
                with slim.arg_scope(resnet_v1.resnet_arg_scope()):
                    net, _ = resnet_v1.resnet_v1_50(real_a - mean_rgb,
                                                    num_classes=None,
                                                    is_training=True,
                                                    global_pool=False,
                                                    output_stride=8)
                with tf.variable_scope('Generator_A2B'):
                    d1 = deconv2d(net, 256, 3, 2, name='g_d1_dc')
                    d1 = tf.nn.relu(instance_normalization(d1, 'g_d1_bn'))
                    d2 = deconv2d(d1, 128, 3, 2, name='g_d2_dc')
                    d2 = tf.nn.relu(instance_normalization(d2, 'g_d2_bn'))
                    d3 = deconv2d(d2, 64, 3, 2, name='g_d3_dc')
                    d3 = tf.nn.relu(instance_normalization(d3, 'g_d3_bn'))
                    # Reflect-pad by 3 so the 7x7 VALID conv keeps the size.
                    d3 = tf.pad(d3, [[0, 0], [3, 3], [3, 3], [0, 0]],
                                "REFLECT")
                    logits_a = conv2d(d3, 1, 7, 1, padding='VALID',
                                      name='g_pred_c')

            adjusted_a = high_light(real_a, name='high_light')
            segment_a = tf.nn.tanh(logits_a, name='segment_a')
            # Mask resized to the size recorded for the A sample.  Both
            # dimensions come from ``real_a_shape`` (the width used to be
            # read from ``real_b_shape``, which tied the output size of A to
            # whatever B sample happened to be drawn).
            logits_a_ori = tf.image.resize_bilinear(
                logits_a, (real_a_shape[0][0], real_a_shape[0][1]),
                name='logits_a_ori')
            segment_a_ori = tf.nn.tanh(logits_a_ori, name='segment_a_ori')

            with tf.variable_scope('Fake_B'):
                foreground = tf.multiply(real_a, segment_a, name='foreground')
                background = tf.multiply(adjusted_a, (1 - segment_a),
                                         name='background')
                fake_b_logits = tf.add(foreground, background,
                                       name='fake_b_logits')
                fake_b = tf.clip_by_value(fake_b_logits, 0, 255,
                                          name='fake_b')

            # The critic consumes flattened images.
            fake_b_f = tf.reshape(fake_b, [-1, image_linear_shape],
                                  name='fake_b_f')
            fake_b_pool_f = tf.reshape(fake_b_pool, [-1, image_linear_shape],
                                       name='fake_b_pool_f')
            real_b_f = tf.reshape(real_b, [-1, image_linear_shape],
                                  name='real_b_f')
            dis_fake_b = discriminator_se_wgangp(fake_b_f, flags,
                                                 reuse=False,
                                                 name="Discriminator_B")
            dis_fake_b_pool = discriminator_se_wgangp(fake_b_pool_f, flags,
                                                      reuse=True,
                                                      name="Discriminator_B")
            dis_real_b = discriminator_se_wgangp(real_b_f, flags,
                                                 reuse=True,
                                                 name="Discriminator_B")

            # WGAN loss
            with tf.name_scope('loss_gen_a2b'):
                loss_gen_a2b = -tf.reduce_mean(dis_fake_b)
            with tf.name_scope('loss_dis_b'):
                loss_dis_b_adv_real = -tf.reduce_mean(dis_real_b)
                loss_dis_b_adv_fake = tf.reduce_mean(dis_fake_b_pool)
                loss_dis_b = tf.reduce_mean(dis_fake_b_pool) - tf.reduce_mean(
                    dis_real_b)
            with tf.name_scope('wgan-gp'):
                # Gradient penalty on random interpolations between real and
                # pooled fake samples.
                alpha = tf.random_uniform(shape=[flags.batch_size, 1],
                                          minval=0.,
                                          maxval=1.)
                differences = fake_b_pool_f - real_b_f
                interpolates = real_b_f + (alpha * differences)
                gradients = tf.gradients(
                    discriminator_se_wgangp(interpolates, flags, reuse=True,
                                            name="Discriminator_B"),
                    [interpolates])[0]
                slopes = tf.sqrt(
                    tf.reduce_sum(tf.square(gradients),
                                  reduction_indices=[1]))
                gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
                loss_dis_b += flags.lambda_gp * gradient_penalty

            # Optimizer: the generator update also trains the encoder.
            trainable_var_resnet = tf.get_collection(
                key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='resnet_v1_50')
            trainable_var_gen_a2b = tf.get_collection(
                key=tf.GraphKeys.TRAINABLE_VARIABLES,
                scope='Generator_A2B') + trainable_var_resnet
            trainable_var_dis_b = tf.get_collection(
                key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='Discriminator_B')
            with tf.name_scope('learning_rate_decay'):
                # Linear decay to zero over ``flags.training_iter`` steps.
                decay = tf.maximum(
                    0.,
                    1. - (tf.cast(global_step, tf.float32) /
                          flags.training_iter),
                    name='decay')
                learning_rate = tf.multiply(flags.learning_rate, decay,
                                            name='learning_rate')
            train_op_gen_a2b = train_op(loss_gen_a2b, learning_rate, flags,
                                        trainable_var_gen_a2b,
                                        name='gen_a2b')
            train_op_dis_b = train_op(loss_dis_b, learning_rate, flags,
                                      trainable_var_dis_b, name='dis_b')

        saver = tf.train.Saver(max_to_keep=2)

        # Graph logs
        with tf.name_scope('GEN_a2b'):
            tf.summary.scalar("loss/gen_a2b/all", loss_gen_a2b)
        with tf.name_scope('DIS_b'):
            tf.summary.scalar("loss/dis_b/all", loss_dis_b)
            tf.summary.scalar("loss/dis_b/adv_real", loss_dis_b_adv_real)
            tf.summary.scalar("loss/dis_b/adv_fake", loss_dis_b_adv_fake)
        summary_op = tf.summary.merge_all()

        # ---- Session ----------------------------------------------------
        tfconfig = tf.ConfigProto(allow_soft_placement=True)
        tfconfig.gpu_options.allow_growth = True
        with tf.Session(config=tfconfig) as sess:
            with tf.name_scope('Initial'):
                ckpt = tf.train.get_checkpoint_state(
                    dataset_parser.checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    print("Model restored: {}".format(
                        ckpt.model_checkpoint_path))
                    saver.restore(sess, ckpt.model_checkpoint_path)
                else:
                    print("No Model found.")
                    init_op = tf.group(tf.global_variables_initializer(),
                                       tf.local_variables_initializer())
                    sess.run(init_op)
                    # Fresh start only: seed the encoder with the pretrained
                    # ResNet weights so a restored checkpoint is never
                    # overwritten by them.
                    init_fn = slim.assign_from_checkpoint_fn(
                        './pretrained/resnet_v1_50.ckpt',
                        slim.get_model_variables('resnet_v1_50'))
                    init_fn(sess)
                summary_writer = tf.summary.FileWriter(
                    dataset_parser.logs_dir, sess.graph)

            if flags.mode == 'train':
                # ---- Training mode --------------------------------------
                print('Training mode! Batch size:{:d}'.format(
                    flags.batch_size))
                with tf.variable_scope('Input_port'):
                    training_a_handle = sess.run(
                        training_a_iterator.string_handle())
                    training_b_handle = sess.run(
                        training_b_iterator.string_handle())
                    image_pool_b = ImagePool(flags.pool_size)

                print('Start Training!')
                start_time = time.time()
                sess.run([
                    training_a_iterator.initializer,
                    training_b_iterator.initializer
                ])
                feed_dict_train = {
                    handle_a: training_a_handle,
                    handle_b: training_b_handle
                }
                global_step_sess = sess.run(global_step)
                while global_step_sess < flags.training_iter:
                    try:
                        # Generator update; keep the generated batch.
                        _, fake_b_sess = sess.run(
                            [train_op_gen_a2b, fake_b],
                            feed_dict=feed_dict_train)

                        # Critic update on samples drawn from the pool.
                        fake_b_pool_query = image_pool_b.query(fake_b_sess)
                        sess.run(train_op_dis_b,
                                 feed_dict={
                                     fake_b_pool: fake_b_pool_query,
                                     handle_b: training_b_handle
                                 })

                        sess.run(global_step_update_op)
                        global_step_sess, learning_rate_sess = sess.run(
                            [global_step, learning_rate])
                        print(
                            'global step:[{:d}/{:d}], learning rate:{:f}, time:{:4.4f}'
                            .format(global_step_sess, flags.training_iter,
                                    learning_rate_sess,
                                    time.time() - start_time))

                        # Logging the events
                        if global_step_sess % flags.log_freq == 1:
                            print('Logging the events')
                            summary_op_sess = sess.run(
                                summary_op,
                                feed_dict={
                                    handle_a: training_a_handle,
                                    handle_b: training_b_handle,
                                    fake_b_pool: fake_b_pool_query
                                })
                            summary_writer.add_summary(summary_op_sess,
                                                       global_step_sess)

                        # Observe training situation (for debugging)
                        if flags.debug and \
                                global_step_sess % flags.observe_freq == 1:
                            (real_a_sess, real_b_sess, adjusted_a_sess,
                             segment_a_sess, fake_b_sess, real_a_name_sess,
                             real_b_name_sess) = sess.run(
                                 [
                                     real_a, real_b, adjusted_a, segment_a,
                                     fake_b, real_a_name, real_b_name
                                 ],
                                 feed_dict={
                                     handle_a: training_a_handle,
                                     handle_b: training_b_handle
                                 })
                            print('Logging training images.')
                            dataset_parser.visualize_data(
                                real_a=real_a_sess,
                                real_b=real_b_sess,
                                adjusted_a=adjusted_a_sess,
                                segment_a=segment_a_sess,
                                fake_b=fake_b_sess,
                                shape=(1, 1),
                                global_step=global_step_sess,
                                logs_dir=dataset_parser.logs_image_train_dir,
                                real_a_name=real_a_name_sess[0].decode(),
                                real_b_name=real_b_name_sess[0].decode())

                        # Saving the checkpoint
                        if global_step_sess % flags.save_freq == 0:
                            print('Saving model...')
                            saver.save(sess,
                                       dataset_parser.checkpoint_dir +
                                       '/model.ckpt',
                                       global_step=global_step_sess)

                    except tf.errors.OutOfRangeError:
                        # One of the datasets is exhausted: restart both and
                        # keep training until the step budget is used up.
                        print(
                            '----------------One epochs finished!----------------'
                        )
                        sess.run([
                            training_a_iterator.initializer,
                            training_b_iterator.initializer
                        ])

            elif flags.mode == 'test':
                # ---- Testing mode ---------------------------------------
                from PIL import Image
                import numpy as np
                import scipy.ndimage

                print('Start Testing!')
                with tf.variable_scope('Input_port'):
                    val_a_handle = sess.run(
                        validation_a_iterator.string_handle())
                    val_b_handle = sess.run(
                        validation_b_iterator.string_handle())
                    sess.run([
                        validation_a_iterator.initializer,
                        validation_b_iterator.initializer
                    ])
                feed_dict_test = {
                    handle_a: val_a_handle,
                    handle_b: val_b_handle
                }
                save_interval = 1  # write PNGs for every ``save_interval``-th batch
                image_idx = 0
                while True:
                    try:
                        (segment_a_ori_sess, real_a_name_sess, real_b_sess,
                         real_a_sess, fake_b_sess) = sess.run(
                             [
                                 segment_a_ori, real_a_name, real_b, real_a,
                                 fake_b
                             ],
                             feed_dict=feed_dict_test)

                        # tanh output in [-1, 1] -> grey levels in [0, 255].
                        segment_a_np = (np.squeeze(segment_a_ori_sess) +
                                        1.0) * 127.5

                        # Adaptive threshold: midpoint between the mean of
                        # the above-average and the below-average pixels,
                        # applied to a Gaussian-smoothed copy of the mask.
                        binary_a = np.zeros_like(segment_a_np, dtype=np.uint8)
                        binary_mean = np.mean(segment_a_np)
                        binary_a_high = np.mean(
                            segment_a_np[segment_a_np > binary_mean])
                        binary_a_low = np.mean(
                            segment_a_np[segment_a_np < binary_mean])
                        binary_a_ave = (binary_a_high + binary_a_low) / 2.0
                        segment_a_np_blur = scipy.ndimage.gaussian_filter(
                            segment_a_np, sigma=3)
                        binary_a[segment_a_np_blur > binary_a_ave] = 255

                        if image_idx % save_interval == 0:
                            sample_name = real_a_name_sess[0].decode()
                            outputs = (
                                ('0_img', np.squeeze(real_a_sess)),
                                ('1_pred', segment_a_np),
                                ('2_binary', binary_a),
                                ('3_fake', np.squeeze(fake_b_sess)),
                                ('4_gt', np.squeeze(real_b_sess)),
                            )
                            for suffix, array in outputs:
                                x_png = Image.fromarray(
                                    array.astype(np.uint8))
                                x_png.save('{}/{}_{}.png'.format(
                                    dataset_parser.logs_image_val_dir,
                                    sample_name, suffix),
                                           format='PNG')

                        print(image_idx)
                        image_idx += 1
                    except tf.errors.OutOfRangeError:
                        print(
                            '----------------One epochs finished!----------------'
                        )
                        break
def main(args=None):
    """Build the A -> B mask GAN based on ``generator_combine`` and run it.

    The graph reads batches of domain A and domain B from TFRecord files,
    predicts a mask ``segment_a`` for every image of A, multiplies it with
    the image to form ``fake_b`` (the background term is all zeros here) and
    scores the result with the ``Discriminator_B`` critic (Wasserstein loss
    plus gradient penalty).

    ``flags.mode`` selects what happens after the graph is built:

    * ``'train'``: run ``flags.num_epochs`` passes over the training
      iterators, alternating generator / critic updates and writing
      summaries, optional debug images and checkpoints.
    * ``'test'``: run the generator on every ``*.jpg`` found one folder
      below ``./dataset/msra4000`` and save each mask, resized back to the
      source image size, as a PNG file.

    Args:
        args: Argument list whose first element is the script path
            (``.../main_<network>.py``); it is only used to derive
            ``flags.network_name``.  Falls back to ``sys.argv`` when ``None``.
    """
    print(args)
    if args is None:
        # The declared default would otherwise crash on ``args[0]`` below.
        import sys
        args = sys.argv
    tf.reset_default_graph()

    # ---- Dataset parser -------------------------------------------------
    flags.network_name = args[0].split('/')[-1].split('.')[0].split(
        'main_')[-1]
    flags.logs_dir = './logs_' + flags.network_name
    dataset_parser = GANParser(flags=flags)

    # ---- One-off conversion of the raw data to TFRecord files -----------
    # Manual switch: set the condition to True once to (re)build the records.
    if False:
        dataset_parser.load_paths(is_jpg=True, load_val=True)
        dataset_parser.data2record(
            name='{}_train.tfrecords'.format(dataset_parser.dataset_name),
            set_type='train',
            test_num=None)
        dataset_parser.data2record(
            name='{}_val.tfrecords'.format(dataset_parser.dataset_name),
            set_type='val',
            test_num=None)
        return

    # ---- Input (TFRecord) -----------------------------------------------
    with tf.variable_scope('TFRecord'):
        # Dataset A
        training_a_dataset = dataset_parser.tfrecord_get_dataset(
            name='{}_trainA.tfrecords'.format(dataset_parser.dataset_name),
            batch_size=flags.batch_size,
            shuffle_size=None)
        val_a_dataset = dataset_parser.tfrecord_get_dataset(
            name='{}_valA.tfrecords'.format(dataset_parser.dataset_name),
            batch_size=flags.batch_size)
        # Dataset B
        training_b_dataset = dataset_parser.tfrecord_get_dataset(
            name='{}_trainB.tfrecords'.format(dataset_parser.dataset_name),
            batch_size=flags.batch_size,
            shuffle_size=None)
        val_b_dataset = dataset_parser.tfrecord_get_dataset(
            name='{}_valB.tfrecords'.format(dataset_parser.dataset_name),
            batch_size=flags.batch_size)

        # Feed-able iterators: the string handle fed at run time selects the
        # dataset that is actually consumed.
        with tf.variable_scope('RealA'):
            handle_a = tf.placeholder(tf.string, shape=[])
            iterator_a = tf.contrib.data.Iterator.from_string_handle(
                handle_a, training_a_dataset.output_types,
                training_a_dataset.output_shapes)
            real_a = iterator_a.get_next()
        with tf.variable_scope('RealB'):
            handle_b = tf.placeholder(tf.string, shape=[])
            iterator_b = tf.contrib.data.Iterator.from_string_handle(
                handle_b, training_b_dataset.output_types,
                training_b_dataset.output_shapes)
            real_b = iterator_b.get_next()
        with tf.variable_scope('InitialA_op'):
            training_a_iterator = \
                training_a_dataset.make_initializable_iterator()
            validation_a_iterator = val_a_dataset.make_initializable_iterator()
        with tf.variable_scope('InitialB_op'):
            training_b_iterator = \
                training_b_dataset.make_initializable_iterator()
            validation_b_iterator = val_b_dataset.make_initializable_iterator()

    # ---- Network (computes predictions from the inference model) --------
    with tf.variable_scope('Network'):
        # Inputs
        global_step = tf.Variable(0,
                                  trainable=False,
                                  name='global_step',
                                  dtype=tf.int32)
        global_step_update_op = tf.assign_add(global_step, 1,
                                              name='global_step_update_op')
        fake_b_pool = tf.placeholder(
            tf.float32,
            shape=[
                None, flags.image_height, flags.image_width, flags.c_in_dim
            ],
            name='fake_B_pool')
        image_linear_shape = tf.constant(
            flags.image_height * flags.image_width * flags.c_in_dim,
            dtype=tf.int32,
            name='image_linear_shape')
        # Separate entry point used in test mode to feed images directly.
        real_a_test = tf.placeholder(
            tf.float32,
            shape=[
                None, flags.image_height, flags.image_width, flags.c_in_dim
            ],
            name='real_a_test')

        # A -> B
        adjusted_a = tf.zeros_like(real_a, tf.float32, name='mask',
                                   optimize=True)
        segment_a = generator_combine(real_a, flags, False,
                                      name="Generator_A2B")
        # Same generator (third argument True, the first call passes False)
        # applied to the test placeholder.
        segment_a_test = generator_combine(real_a_test, flags, True,
                                           name="Generator_A2B")

        with tf.variable_scope('Fake_B'):
            foreground = tf.multiply(real_a, segment_a, name='foreground')
            background = tf.multiply(adjusted_a, (1 - segment_a),
                                     name='background')
            fake_b = tf.add(foreground, background, name='fake_b')

        # The critic consumes flattened images.
        fake_b_f = tf.reshape(fake_b, [-1, image_linear_shape],
                              name='fake_b_f')
        fake_b_pool_f = tf.reshape(fake_b_pool, [-1, image_linear_shape],
                                   name='fake_b_pool_f')
        real_b_f = tf.reshape(real_b, [-1, image_linear_shape],
                              name='real_b_f')
        dis_fake_b = discriminator_rule(fake_b_f, flags, reuse=False,
                                        name="Discriminator_B")
        dis_fake_b_pool = discriminator_rule(fake_b_pool_f, flags,
                                             reuse=True,
                                             name="Discriminator_B")
        dis_real_b = discriminator_rule(real_b_f, flags, reuse=True,
                                        name="Discriminator_B")

        # WGAN loss
        with tf.name_scope('loss_gen_a2b'):
            loss_gen_a2b = -tf.reduce_mean(dis_fake_b)
        with tf.name_scope('loss_dis_b'):
            loss_dis_b_adv_real = -tf.reduce_mean(dis_real_b)
            loss_dis_b_adv_fake = tf.reduce_mean(dis_fake_b_pool)
            loss_dis_b = tf.reduce_mean(dis_fake_b_pool) - tf.reduce_mean(
                dis_real_b)
        with tf.name_scope('wgan-gp'):
            # Gradient penalty on random interpolations between real and
            # pooled fake samples.
            alpha = tf.random_uniform(shape=[flags.batch_size, 1],
                                      minval=0.,
                                      maxval=1.)
            differences = fake_b_pool_f - real_b_f
            interpolates = real_b_f + (alpha * differences)
            gradients = tf.gradients(
                discriminator_rule(interpolates, flags, reuse=True,
                                   name="Discriminator_B"),
                [interpolates])[0]
            slopes = tf.sqrt(
                tf.reduce_sum(tf.square(gradients), reduction_indices=[1]))
            gradient_penalty = tf.reduce_mean((slopes - 1.)**2)
            loss_dis_b += flags.lambda_gp * gradient_penalty

        # Optimizer
        trainable_var_gen_a2b = tf.get_collection(
            key=tf.GraphKeys.TRAINABLE_VARIABLES,
            scope='Network/Generator_A2B')
        trainable_var_dis_b = tf.get_collection(
            key=tf.GraphKeys.TRAINABLE_VARIABLES,
            scope='Network/Discriminator_B')
        with tf.name_scope('learning_rate_decay'):
            # Linear decay to zero over num_epochs * data_size steps.
            decay = tf.maximum(
                0.,
                1. - (tf.cast(global_step, tf.float32) /
                      (flags.num_epochs * flags.data_size)),
                name='decay')
            learning_rate = tf.multiply(flags.learning_rate, decay,
                                        name='learning_rate')
        train_op_gen_a2b = train_op(loss_gen_a2b, learning_rate, flags,
                                    trainable_var_gen_a2b, name='gen_a2b')
        train_op_dis_b = train_op(loss_dis_b, learning_rate, flags,
                                  trainable_var_dis_b, name='dis_b')

    saver = tf.train.Saver(max_to_keep=2)

    # Graph logs
    with tf.variable_scope('GEN_a2b'):
        tf.summary.scalar("loss/gen_a2b/all", loss_gen_a2b)
    with tf.variable_scope('DIS_b'):
        tf.summary.scalar("loss/dis_b/all", loss_dis_b)
        tf.summary.scalar("loss/dis_b/adv_real", loss_dis_b_adv_real)
        tf.summary.scalar("loss/dis_b/adv_fake", loss_dis_b_adv_fake)
    summary_op = tf.summary.merge_all()

    # ---- Session --------------------------------------------------------
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session(config=tfconfig) as sess:
        with tf.variable_scope('Initial'):
            ckpt = tf.train.get_checkpoint_state(dataset_parser.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                print("Model restored: {}".format(ckpt.model_checkpoint_path))
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print("No Model found.")
                init_op = tf.group(tf.global_variables_initializer(),
                                   tf.local_variables_initializer())
                sess.run(init_op)
            summary_writer = tf.summary.FileWriter(dataset_parser.logs_dir,
                                                   sess.graph)

        if flags.mode == 'train':
            # ---- Training mode ------------------------------------------
            print('Training mode! Batch size:{:d}'.format(flags.batch_size))
            with tf.variable_scope('Input_port'):
                training_a_handle = sess.run(
                    training_a_iterator.string_handle())
                training_b_handle = sess.run(
                    training_b_iterator.string_handle())
                image_pool_b = ImagePool(flags.pool_size)

            print('Start Training!')
            start_time = time.time()
            feed_dict_train = {
                handle_a: training_a_handle,
                handle_b: training_b_handle
            }
            for epoch in range(flags.num_epochs):
                sess.run([
                    training_a_iterator.initializer,
                    training_b_iterator.initializer
                ])
                while True:
                    try:
                        # Generator update; keep the generated batch.
                        _, fake_b_sess = sess.run(
                            [train_op_gen_a2b, fake_b],
                            feed_dict=feed_dict_train)

                        # Critic update on samples drawn from the pool.
                        fake_b_pool_query = image_pool_b.query(fake_b_sess)
                        sess.run(train_op_dis_b,
                                 feed_dict={
                                     fake_b_pool: fake_b_pool_query,
                                     handle_b: training_b_handle
                                 })

                        sess.run(global_step_update_op)
                        global_step_sess, learning_rate_sess = sess.run(
                            [global_step, learning_rate])
                        print(
                            'epoch:[{:d}/{:d}], global step:{:d}, learning rate:{:f}, time:{:4.4f}'
                            .format(epoch, flags.num_epochs, global_step_sess,
                                    learning_rate_sess,
                                    time.time() - start_time))

                        # Logging the events
                        if global_step_sess % flags.log_freq == 1:
                            print('Logging the events')
                            summary_op_sess = sess.run(
                                summary_op,
                                feed_dict={
                                    handle_a: training_a_handle,
                                    handle_b: training_b_handle,
                                    fake_b_pool: fake_b_pool_query
                                })
                            summary_writer.add_summary(summary_op_sess,
                                                       global_step_sess)

                        # Observe training situation (for debugging)
                        if flags.debug and \
                                global_step_sess % flags.observe_freq == 1:
                            (real_a_sess, real_b_sess, adjusted_a_sess,
                             segment_a_sess, fake_b_sess) = sess.run(
                                 [
                                     real_a, real_b, adjusted_a, segment_a,
                                     fake_b
                                 ],
                                 feed_dict={
                                     handle_a: training_a_handle,
                                     handle_b: training_b_handle
                                 })
                            print('Logging training images.')
                            dataset_parser.visualize_data(
                                real_a=real_a_sess,
                                real_b=real_b_sess,
                                adjusted_a=adjusted_a_sess,
                                segment_a=segment_a_sess,
                                fake_b=fake_b_sess,
                                shape=(1, 1),
                                global_step=global_step_sess,
                                logs_dir=dataset_parser.logs_image_train_dir)

                        # Saving the checkpoint
                        if global_step_sess % flags.save_freq == 0:
                            print('Saving model...')
                            saver.save(sess,
                                       dataset_parser.checkpoint_dir +
                                       '/model.ckpt',
                                       global_step=global_step_sess)

                    except tf.errors.OutOfRangeError:
                        # Dataset exhausted: leave the inner loop so the next
                        # epoch re-initialises the iterators.
                        print(
                            '----------------One epochs finished!----------------'
                        )
                        break

        elif flags.mode == 'test':
            # ---- Testing mode -------------------------------------------
            import os
            from glob import glob

            import numpy as np
            from PIL import Image

            dataset_dir = './dataset/msra4000'
            data = []
            for folder in os.listdir(dataset_dir):
                data.extend(glob(os.path.join(dataset_dir, folder, "*.jpg")))
            data_len = len(data)

            for img_idx, img_path in enumerate(data):
                print('[{:d}/{:d}]'.format(img_idx, data_len))
                img_name = os.path.splitext(os.path.basename(img_path))[0]

                test_img = Image.open(img_path)
                test_img_size = test_img.size  # PIL order: (width, height)
                # PIL expects (width, height); the resulting array is then
                # (image_height, image_width, C) as the placeholder requires.
                test_img = test_img.resize(
                    (flags.image_width, flags.image_height), Image.BILINEAR)
                x = np.array(test_img)
                if len(x.shape) < 3:
                    # Grey-scale image: replicate the single channel.
                    x = np.dstack((x, x, x))
                x = np.expand_dims(x, axis=0)

                segment_a_test_sess = sess.run(segment_a_test,
                                               feed_dict={real_a_test: x})
                segment_a_test_sess = np.array(segment_a_test_sess) * 255
                x_png = Image.fromarray(
                    np.squeeze(segment_a_test_sess).astype(np.uint8))
                # Back to the size of the source image.
                x_png = x_png.resize(test_img_size, Image.BILINEAR)
                x_png.save('{}/{}.png'.format(
                    dataset_parser.logs_image_val_dir, img_name),
                           format='PNG')
def main(args=None):
    """Build the A -> B mask GAN with a patch discriminator and train it.

    The graph reads batches of domain A and domain B from TFRecord files,
    predicts a mask ``segment_a`` for every image of A, blends ``real_a``
    with a highlighted copy of itself to form ``fake_b`` and trains the
    generator and ``Discriminator_B`` with ``mae_criterion`` adversarial
    losses.  Only ``flags.mode == 'train'`` is handled by this variant.

    The learning rate is fed per step: it stays at ``flags.learning_rate``
    for the first ``flags.num_epochs_decay`` epochs and then decreases
    linearly towards zero at ``flags.num_epochs``.

    Args:
        args: Argument list whose first element is the script path
            (``.../main_<network>.py``); it is only used to derive
            ``flags.network_name``.  Falls back to ``sys.argv`` when ``None``.
    """
    print(args)
    if args is None:
        # The declared default would otherwise crash on ``args[0]`` below.
        import sys
        args = sys.argv
    tf.reset_default_graph()

    # ---- Dataset parser -------------------------------------------------
    flags.network_name = args[0].split('/')[-1].split('.')[0].split(
        'main_')[-1]
    flags.logs_dir = './logs_' + flags.network_name
    dataset_parser = GANParser(dataset_dir=flags.dataset_dir, flags=flags)

    # ---- One-off conversion of the raw data to TFRecord files -----------
    # Manual switch: set the condition to True once to (re)build the records.
    if False:
        dataset_parser.load_paths(is_jpg=True, load_val=True)
        dataset_parser.data2record(
            name='{}_train.tfrecords'.format(dataset_parser.dataset_name),
            set_type='train',
            test_num=None)
        dataset_parser.data2record(
            name='{}_val.tfrecords'.format(dataset_parser.dataset_name),
            set_type='val',
            test_num=None)
        return

    # ---- Input (TFRecord) -----------------------------------------------
    with tf.variable_scope('TFRecord'):
        # Dataset A
        training_a_dataset = dataset_parser.tfrecord_get_dataset(
            name='{}_trainA.tfrecords'.format(dataset_parser.dataset_name),
            batch_size=flags.batch_size,
            shuffle_size=None)
        val_a_dataset = dataset_parser.tfrecord_get_dataset(
            name='{}_valA.tfrecords'.format(dataset_parser.dataset_name),
            batch_size=flags.batch_size)
        # Dataset B
        training_b_dataset = dataset_parser.tfrecord_get_dataset(
            name='{}_trainB.tfrecords'.format(dataset_parser.dataset_name),
            batch_size=flags.batch_size,
            shuffle_size=None)
        val_b_dataset = dataset_parser.tfrecord_get_dataset(
            name='{}_valB.tfrecords'.format(dataset_parser.dataset_name),
            batch_size=flags.batch_size)

        # Feed-able iterators: the string handle fed at run time selects the
        # dataset that is actually consumed.
        with tf.variable_scope('RealA'):
            handle_a = tf.placeholder(tf.string, shape=[])
            iterator_a = tf.contrib.data.Iterator.from_string_handle(
                handle_a, training_a_dataset.output_types,
                training_a_dataset.output_shapes)
            real_a = iterator_a.get_next()
        with tf.variable_scope('RealB'):
            handle_b = tf.placeholder(tf.string, shape=[])
            iterator_b = tf.contrib.data.Iterator.from_string_handle(
                handle_b, training_b_dataset.output_types,
                training_b_dataset.output_shapes)
            real_b = iterator_b.get_next()
        # The validation iterators are created but not consumed in this
        # function; they are kept so the graph stays unchanged.
        with tf.variable_scope('InitialA_op'):
            training_a_iterator = \
                training_a_dataset.make_initializable_iterator()
            validation_a_iterator = val_a_dataset.make_initializable_iterator()
        with tf.variable_scope('InitialB_op'):
            training_b_iterator = \
                training_b_dataset.make_initializable_iterator()
            validation_b_iterator = val_b_dataset.make_initializable_iterator()

    # ---- Network (computes predictions from the inference model) --------
    with tf.variable_scope('Network'):
        # Inputs
        global_step = tf.Variable(0, trainable=False, name='global_step')
        # Built once here: creating the assignment op inside the training
        # loop would add a new node to the graph on every iteration.
        global_step_update_op = tf.assign_add(global_step, 1,
                                              name='global_step_update_op')
        learning_rate = tf.placeholder(tf.float32, name="learning_rate")
        fake_b_pool = tf.placeholder(tf.float32,
                                     shape=[None, None, None, 3],
                                     name='fake_B_pool')

        # A -> B
        adjusted_a = high_light(real_a, name='high_light')
        segment_a = generator_resnet_se_fixed(real_a, flags, False,
                                              name="Generator_A2B")

        # TODO : gp, argmax?
        with tf.variable_scope('Fake_B'):
            foreground = tf.multiply(real_a, segment_a, name='foreground')
            background = tf.multiply(adjusted_a, (1 - segment_a),
                                     name='background')
            fake_b = tf.add(foreground, background, name='fake_b')

        dis_fake_b = discriminator_patch142(fake_b, flags, reuse=False,
                                            name="Discriminator_B")
        dis_fake_b_pool = discriminator_patch142(fake_b_pool, flags,
                                                 reuse=True,
                                                 name="Discriminator_B")
        dis_real_b = discriminator_patch142(real_b, flags, reuse=True,
                                            name="Discriminator_B")

        # Adversarial losses: the generator wants its samples scored as 1,
        # the discriminator wants real -> 1 and pooled fakes -> 0.
        with tf.variable_scope('loss_gen_a2b'):
            loss_gen_a2b = mae_criterion(dis_fake_b,
                                         tf.ones_like(dis_fake_b),
                                         name='adv')
        with tf.variable_scope('loss_dis_b'):
            loss_dis_b_adv_real = mae_criterion(dis_real_b,
                                                tf.ones_like(dis_real_b),
                                                name='adv_real')
            loss_dis_b_adv_fake = mae_criterion(
                dis_fake_b_pool,
                tf.zeros_like(dis_fake_b_pool),
                name='adv_fake')
            loss_dis_b = loss_dis_b_adv_real + loss_dis_b_adv_fake

        # Optimizer
        trainable_var_gen_a2b = tf.get_collection(
            key=tf.GraphKeys.TRAINABLE_VARIABLES,
            scope='Network/Generator_A2B')
        trainable_var_dis_b = tf.get_collection(
            key=tf.GraphKeys.TRAINABLE_VARIABLES,
            scope='Network/Discriminator_B')
        train_op_gen_a2b = train_op(loss_gen_a2b, learning_rate, flags,
                                    trainable_var_gen_a2b, name='gen_a2b')
        train_op_dis_b = train_op(loss_dis_b, learning_rate, flags,
                                  trainable_var_dis_b, name='dis_b')

    saver = tf.train.Saver(max_to_keep=2)

    # Graph logs
    with tf.variable_scope('GEN_a2b'):
        tf.summary.scalar("loss/gen_a2b/all", loss_gen_a2b)
    with tf.variable_scope('DIS_b'):
        tf.summary.scalar("loss/dis_b/all", loss_dis_b)
        tf.summary.scalar("loss/dis_b/adv_real", loss_dis_b_adv_real)
        tf.summary.scalar("loss/dis_b/adv_fake", loss_dis_b_adv_fake)
    summary_op = tf.summary.merge_all()

    # ---- Session --------------------------------------------------------
    tfconfig = tf.ConfigProto(allow_soft_placement=True)
    tfconfig.gpu_options.allow_growth = True
    with tf.Session(config=tfconfig) as sess:
        with tf.variable_scope('Initial'):
            ckpt = tf.train.get_checkpoint_state(dataset_parser.checkpoint_dir)
            if ckpt and ckpt.model_checkpoint_path:
                print("Model restored: {}".format(ckpt.model_checkpoint_path))
                saver.restore(sess, ckpt.model_checkpoint_path)
            else:
                print("No Model found.")
                init_op = tf.group(tf.global_variables_initializer(),
                                   tf.local_variables_initializer())
                sess.run(init_op)
            summary_writer = tf.summary.FileWriter(dataset_parser.logs_dir,
                                                   sess.graph)

        if flags.mode == 'train':
            # ---- Training mode ------------------------------------------
            print('Training mode! Batch size:{:d}'.format(flags.batch_size))
            with tf.variable_scope('Input_port'):
                training_a_handle = sess.run(
                    training_a_iterator.string_handle())
                training_b_handle = sess.run(
                    training_b_iterator.string_handle())
                image_pool_b = ImagePool(flags.pool_size)

            print('Start Training!')
            start_time = time.time()
            global_step_sess = sess.run(global_step)
            for epoch in range(flags.num_epochs):
                sess.run([
                    training_a_iterator.initializer,
                    training_b_iterator.initializer
                ])
                # Constant rate first, then linear decay towards zero.
                if epoch < flags.num_epochs_decay:
                    learning_rate_sess = flags.learning_rate
                else:
                    learning_rate_sess = (
                        flags.learning_rate * (flags.num_epochs - epoch) /
                        (flags.num_epochs - flags.num_epochs_decay))
                feed_dict_train = {
                    learning_rate: learning_rate_sess,
                    handle_a: training_a_handle,
                    handle_b: training_b_handle
                }
                while True:
                    try:
                        print(
                            'epoch:[{:d}/{:d}], global step:{:d}, learning rate:{:f}, time:{:4.4f}'
                            .format(epoch, flags.num_epochs, global_step_sess,
                                    learning_rate_sess,
                                    time.time() - start_time))

                        # Generator update; keep the generated batch.
                        _, fake_b_sess = sess.run(
                            [train_op_gen_a2b, fake_b],
                            feed_dict=feed_dict_train)

                        # Discriminator update on samples from the pool.
                        fake_b_pool_query = image_pool_b.query(fake_b_sess)
                        sess.run(train_op_dis_b,
                                 feed_dict={
                                     learning_rate: learning_rate_sess,
                                     fake_b_pool: fake_b_pool_query,
                                     handle_b: training_b_handle
                                 })

                        # Increment the step variable with the pre-built op;
                        # the fetched value is the new step count.
                        global_step_sess = sess.run(global_step_update_op)

                        # Logging the events
                        if global_step_sess % flags.log_freq == 1:
                            print('Logging the events')
                            summary_op_sess = sess.run(
                                summary_op,
                                feed_dict={
                                    handle_a: training_a_handle,
                                    handle_b: training_b_handle,
                                    fake_b_pool: fake_b_pool_query
                                })
                            summary_writer.add_summary(summary_op_sess,
                                                       global_step_sess)

                        # Observe training situation (for debugging)
                        if flags.debug and \
                                global_step_sess % flags.observe_freq == 1:
                            (real_a_sess, real_b_sess, adjusted_a_sess,
                             segment_a_sess, fake_b_sess) = sess.run(
                                 [
                                     real_a, real_b, adjusted_a, segment_a,
                                     fake_b
                                 ],
                                 feed_dict={
                                     handle_a: training_a_handle,
                                     handle_b: training_b_handle
                                 })
                            print('Logging training images.')
                            dataset_parser.visualize_data(
                                real_a=real_a_sess,
                                real_b=real_b_sess,
                                adjusted_a=adjusted_a_sess,
                                segment_a=segment_a_sess,
                                fake_b=fake_b_sess,
                                shape=(1, 1),
                                global_step=global_step_sess,
                                logs_dir=dataset_parser.logs_image_train_dir)

                        # Saving the checkpoint
                        if global_step_sess % flags.save_freq == 0:
                            print('Saving model...')
                            saver.save(sess,
                                       dataset_parser.checkpoint_dir +
                                       '/model.ckpt',
                                       global_step=global_step_sess)

                    except tf.errors.OutOfRangeError:
                        # Dataset exhausted: leave the inner loop so the next
                        # epoch re-initialises the iterators.
                        print(
                            '----------------One epochs finished!----------------'
                        )
                        break