def freeze(ckpt_path, output_name='cyclegan.pb'):
    """Freeze a trained A->B generator checkpoint into a single .pb file.

    Rebuilds the generator graph with a dynamic-size input placeholder,
    restores the weights from ``ckpt_path``, converts all variables to
    constants, and writes the frozen GraphDef into the checkpoint's
    directory.

    Args:
        ckpt_path: Path to the TF1 checkpoint file to restore.
        output_name: File name for the frozen graph written next to the
            checkpoint (default: ``'cyclegan.pb'``), kept as a default so
            existing callers are unaffected.

    NOTE(review): this function reads the module-level ``args`` namespace
    (batch_size, fine_size, ngf, ndf, output_nc) — confirm it is populated
    before ``freeze`` is called.
    """
    OPTIONS = namedtuple(
        'OPTIONS', 'batch_size image_size gf_dim df_dim output_c_dim')
    options = OPTIONS._make(
        (args.batch_size, args.fine_size, args.ngf, args.ndf, args.output_nc))

    # Height/width are left as None so the frozen model accepts any
    # spatial size; batch is fixed to 1.
    inp_content_image = tf.placeholder(tf.float32,
                                       shape=(1, None, None, 3),
                                       name='input')
    out_image = generator_resnet(inp_content_image, options,
                                 name='generatorA2B')
    # Pin a stable node name so inference code can fetch 'output'.
    out_image = tf.identity(out_image, name='output')

    init_op = tf.global_variables_initializer()
    restore_saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(init_op)
        restore_saver.restore(sess, ckpt_path)
        # Fold variable values into constants so the .pb is self-contained.
        frozen_graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, output_node_names=['output'])
        # os.path.join instead of string concatenation for portability.
        out_path = os.path.join(os.path.dirname(ckpt_path), output_name)
        with open(out_path, 'wb') as f:
            f.write(frozen_graph_def.SerializeToString())
def __init__(self, args):
    """Build the GAN model wrapper from parsed command-line ``args``.

    Copies hyper-parameters onto the instance, selects generator /
    discriminator / GAN-criterion implementations from the flags, builds
    the graph via ``self._build_model(args)``, and sets up optimizers plus
    TF2-style checkpoint managers.
    """
    # --- hyper-parameters copied straight from args ---
    self.batch_size = args.batch_size
    self.image_width = args.image_width
    self.image_height = args.image_height
    self.input_c_dim = args.input_nc
    self.output_c_dim = args.output_nc
    self.L1_lambda = args.L1_lambda
    self.Lg_lambda = args.Lg_lambda
    self.dataset_dir = args.dataset_dir
    self.segment_class = args.segment_class
    # Reciprocal of the GAN-to-segmentation loss ratio; 0 disables it
    # when the ratio is not positive.
    self.alpha_recip = 1. / args.ratio_gan2seg if args.ratio_gan2seg > 0 else 0
    self.use_pix2pix = args.use_pix2pix
    # Default discriminator; may be replaced by the pix2pix variant below.
    self.discriminator = discriminator()
    if args.use_resnet:
        # NOTE(review): when use_resnet and use_pix2pix are both set, the
        # resnet generator wins and the pix2pix discriminator is never
        # selected — confirm this precedence is intended.
        self.generator = generator_resnet()
    else:
        if args.use_pix2pix:
            self.generator = generator_pix2pix()
            self.discriminator = discriminator_pix2pix()
        else:
            self.generator = generator_unet()
    # LSGAN uses MAE-style loss; otherwise sigmoid cross-entropy.
    if args.use_lsgan:
        self.criterionGAN = mae_criterion
    else:
        self.criterionGAN = sce_criterion
    # tf.keras.utils.plot_model(self.discriminator, 'multi_input_and_output_model.png', show_shapes=True)
    # input("")
    # Immutable bag of options passed into the graph-building helpers.
    OPTIONS = namedtuple(
        'OPTIONS', 'batch_size image_height image_width \
        gf_dim df_dim output_c_dim is_training segment_class'
    )
    self.options = OPTIONS._make(
        (args.batch_size, args.image_height, args.image_width, args.ngf,
         args.ndf, args.output_nc, args.phase == 'train',
         args.segment_class))
    self._build_model(args)
    # Image pool for discriminator updates (history of generated images).
    self.pool = ImagePool(args.max_size)
    #### [ADDED] CHECKPOINT MANAGER
    # NOTE(review): learning rate is hard-coded here (0.001) rather than
    # taken from args — confirm this is deliberate.
    self.lr = 0.001
    self.d_optim = tf.keras.optimizers.Adam(learning_rate=self.lr,
                                            beta_1=args.beta1)
    self.g_optim = tf.keras.optimizers.Adam(learning_rate=self.lr,
                                            beta_1=args.beta1)
    self.gen_ckpt = tf.train.Checkpoint(optimizer=self.g_optim,
                                        net=self.generator)
    self.disc_ckpt = tf.train.Checkpoint(optimizer=self.d_optim,
                                         net=self.discriminator)
    # NOTE(review): checkpoint directories are hard-coded to the 'gta'
    # dataset — verify against dataset_dir when training other datasets.
    self.gen_ckpt_manager = tf.train.CheckpointManager(
        self.gen_ckpt, './checkpoint/gta/gen_ckpts', max_to_keep=3)
    self.disc_ckpt_manager = tf.train.CheckpointManager(
        self.disc_ckpt, './checkpoint/gta/disc_ckpts', max_to_keep=3)
def main(args=None):
    """Train or test the A->B segmentation-style generator with a WGAN-GP
    discriminator.

    Pipeline: parse the dataset via ``GANParser``, build the TFRecord input
    iterators, construct the generator/discriminator graph with WGAN-GP
    losses, then either run the training loop (``flags.mode == 'train'``)
    or dump predictions/binarized masks/images (``flags.mode == 'test'``).

    Args:
        args: argv-style list; ``args[0]`` (the script path) is used to
            derive ``flags.network_name``.

    NOTE(review): relies on a module-level ``flags`` object for all
    hyper-parameters — confirm it is configured before calling.
    """
    print(args)
    tf.reset_default_graph()

    """ Read dataset parser """
    # Derive the network name from the script filename: 'main_<name>.py'.
    flags.network_name = args[0].split('/')[-1].split('.')[0].split(
        'main_')[-1]
    flags.logs_dir = './logs_' + flags.network_name
    dataset_parser = GANParser(flags=flags)

    """ Transform data to TFRecord format (Only do once.) """
    # One-off data conversion, intentionally disabled; flip to True to
    # regenerate the TFRecord files.
    if False:
        dataset_parser.load_paths(is_jpg=True, load_val=True)
        dataset_parser.data2record(name='{}_train.tfrecords'.format(
            dataset_parser.dataset_name), set_type='train', test_num=None)
        dataset_parser.data2record(name='{}_val.tfrecords'.format(
            dataset_parser.dataset_name), set_type='val', test_num=None)
        # coco_parser.data2record_test(name='coco_stuff2017_test-dev_all_label.tfrecords', is_dev=True, test_num=None)
        # coco_parser.data2record_test(name='coco_stuff2017_test_all_label.tfrecords', is_dev=False, test_num=None)
        return

    """ Build Graph """
    with tf.Graph().as_default():
        """ Input (TFRecord) """
        with tf.name_scope('TFRecord'):
            # DatasetA
            training_a_dataset = dataset_parser.tfrecord_get_dataset(
                name='{}_trainA.tfrecords'.format(dataset_parser.dataset_name),
                batch_size=flags.batch_size,
                shuffle_size=None)
            val_a_dataset = dataset_parser.tfrecord_get_dataset(
                name='{}_valA.tfrecords'.format(dataset_parser.dataset_name),
                batch_size=flags.batch_size,
                need_flip=(flags.mode == 'train'))
            # DatasetB
            training_b_dataset = dataset_parser.tfrecord_get_dataset(
                name='{}_trainB.tfrecords'.format(dataset_parser.dataset_name),
                batch_size=flags.batch_size,
                shuffle_size=None)
            val_b_dataset = dataset_parser.tfrecord_get_dataset(
                name='{}_valB.tfrecords'.format(dataset_parser.dataset_name),
                batch_size=flags.batch_size,
                is_label=True,
                need_flip=(flags.mode == 'train'))

        # A feed-able iterator: the string handle chooses train vs. val at
        # session time without rebuilding the graph.
        with tf.name_scope('RealA'):
            handle_a = tf.placeholder(tf.string, shape=[])
            iterator_a = tf.contrib.data.Iterator.from_string_handle(
                handle_a, training_a_dataset.output_types,
                training_a_dataset.output_shapes)
            real_a, real_a_name, real_a_shape = iterator_a.get_next()
        with tf.name_scope('RealB'):
            handle_b = tf.placeholder(tf.string, shape=[])
            iterator_b = tf.contrib.data.Iterator.from_string_handle(
                handle_b, training_b_dataset.output_types,
                training_b_dataset.output_shapes)
            real_b, real_b_name, real_b_shape = iterator_b.get_next()
        with tf.name_scope('InitialA_op'):
            training_a_iterator = training_a_dataset.make_initializable_iterator()
            validation_a_iterator = val_a_dataset.make_initializable_iterator()
        with tf.name_scope('InitialB_op'):
            training_b_iterator = training_b_dataset.make_initializable_iterator()
            validation_b_iterator = val_b_dataset.make_initializable_iterator()

        """ Network (Computes predictions from the inference model) """
        with tf.name_scope('Network'):
            # Input
            global_step = tf.Variable(0, trainable=False, name='global_step',
                                      dtype=tf.int32)
            global_step_update_op = tf.assign_add(global_step, 1,
                                                  name='global_step_update_op')
            # mean_rgb = tf.constant((123.68, 116.78, 103.94), dtype=tf.float32)
            # Placeholder fed from the CPU-side image pool (history of
            # generated images) when updating the discriminator.
            fake_b_pool = tf.placeholder(tf.float32,
                                         shape=[None, flags.image_height,
                                                flags.image_width,
                                                flags.c_in_dim],
                                         name='fake_B_pool')
            # Flattened image length, used to reshape images into vectors
            # for the fully-connected WGAN discriminator.
            image_linear_shape = tf.constant(
                flags.image_height * flags.image_width * flags.c_in_dim,
                dtype=tf.int32, name='image_linear_shape')

            # A -> B
            '''
            with tf.name_scope('Generator'):
                with slim.arg_scope(vgg.vgg_arg_scope()):
                    net, end_points = vgg.vgg_16(real_a - mean_rgb, num_classes=1,
                                                 is_training=True, spatial_squeeze=False)
                    print(net)
                    return
                with tf.variable_scope('Generator_A2B'):
                    pred = tf.layers.conv2d(tf.nn.relu(net), 1, 1, 1)
                pred_upscale = tf.image.resize_bilinear(
                    pred, (flags.image_height, flags.image_width), name='up_scale')
                segment_a = tf.nn.sigmoid(pred_upscale, name='segment_a')
                # sigmoid cross entropy Loss
                with tf.name_scope('loss_gen_a2b'):
                    loss_gen_a2b = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                        logits=pred_upscale, labels=real_b/255.0, name='sigmoid'), name='mean')
            '''
            # A -> B
            # adjusted_a = tf.zeros_like(real_a, tf.float32, name='mask', optimize=True)
            # adjusted_a = high_light(real_a, name='high_light')
            # Smoothed copy of A used as the "background" of the composite.
            adjusted_a = tf.layers.average_pooling2d(real_a, 11, strides=1,
                                                     padding='same',
                                                     name='adjusted_a')
            logits_a = generator_resnet(real_a, flags, False,
                                        name="Generator_A2B")
            segment_a = tf.nn.tanh(logits_a, name='segment_a')
            # Mask resized back to the original (pre-resize) image shape,
            # used at test time for full-resolution output.
            logits_a_ori = tf.image.resize_bilinear(
                logits_a, (real_a_shape[0][0], real_b_shape[0][1]),
                name='logits_a_ori')
            segment_a_ori = tf.nn.tanh(logits_a_ori, name='segment_a_ori')
            with tf.variable_scope('Fake_B'):
                # Composite: mask * original + (1 - mask) * blurred.
                foreground = tf.multiply(real_a, segment_a, name='foreground')
                background = tf.multiply(adjusted_a, (1 - segment_a),
                                         name='background')
                fake_b_logits = tf.add(foreground, background,
                                       name='fake_b_logits')
                fake_b = tf.clip_by_value(fake_b_logits, 0, 255, name='fake_b')

            # FIX: this assignment was commented out while fake_b_f is used
            # by the discriminator and the generator loss below; without it
            # graph construction fails with a NameError.
            fake_b_f = tf.reshape(fake_b, [-1, image_linear_shape],
                                  name='fake_b_f')
            fake_b_pool_f = tf.reshape(fake_b_pool, [-1, image_linear_shape],
                                       name='fake_b_pool_f')
            real_b_f = tf.reshape(real_b, [-1, image_linear_shape],
                                  name='real_b_f')
            dis_fake_b = discriminator_se_wgangp(fake_b_f, flags, reuse=False,
                                                 name="Discriminator_B")
            dis_fake_b_pool = discriminator_se_wgangp(fake_b_pool_f, flags,
                                                      reuse=True,
                                                      name="Discriminator_B")
            dis_real_b = discriminator_se_wgangp(real_b_f, flags, reuse=True,
                                                 name="Discriminator_B")

            # WGAN Loss
            with tf.name_scope('loss_gen_a2b'):
                loss_gen_a2b = -tf.reduce_mean(dis_fake_b)
            with tf.name_scope('loss_dis_b'):
                loss_dis_b_adv_real = -tf.reduce_mean(dis_real_b)
                loss_dis_b_adv_fake = tf.reduce_mean(dis_fake_b_pool)
                loss_dis_b = tf.reduce_mean(dis_fake_b_pool) - tf.reduce_mean(
                    dis_real_b)
                # Gradient penalty on random interpolates between real and
                # pooled fake samples (WGAN-GP).
                with tf.name_scope('wgan-gp'):
                    alpha = tf.random_uniform(shape=[flags.batch_size, 1],
                                              minval=0., maxval=1.)
                    differences = fake_b_pool_f - real_b_f
                    interpolates = real_b_f + (alpha * differences)
                    gradients = tf.gradients(
                        discriminator_se_wgangp(interpolates, flags,
                                                reuse=True,
                                                name="Discriminator_B"),
                        [interpolates])[0]
                    slopes = tf.sqrt(
                        tf.reduce_sum(tf.square(gradients),
                                      reduction_indices=[1]))
                    gradient_penalty = tf.reduce_mean((slopes - 1.) ** 2)
                    loss_dis_b += flags.lambda_gp * gradient_penalty

            # Optimizer
            '''
            trainable_var_resnet = tf.get_collection(
                key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_16')
            trainable_var_gen_a2b = tf.get_collection(
                key=tf.GraphKeys.TRAINABLE_VARIABLES,
                scope='Generator_A2B') + trainable_var_resnet
            slim.model_analyzer.analyze_vars(trainable_var_gen_a2b, print_info=True)
            '''
            trainable_var_gen_a2b = tf.get_collection(
                key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='Generator_A2B')
            trainable_var_dis_b = tf.get_collection(
                key=tf.GraphKeys.TRAINABLE_VARIABLES, scope='Discriminator_B')
            # Linear learning-rate decay down to 0 at flags.training_iter.
            with tf.name_scope('learning_rate_decay'):
                decay = tf.maximum(
                    0.,
                    1. - (tf.cast(global_step, tf.float32) /
                          flags.training_iter),
                    name='decay')
                learning_rate = tf.multiply(flags.learning_rate, decay,
                                            name='learning_rate')
            train_op_gen_a2b = train_op(loss_gen_a2b, learning_rate, flags,
                                        trainable_var_gen_a2b, name='gen_a2b')
            train_op_dis_b = train_op(loss_dis_b, learning_rate, flags,
                                      trainable_var_dis_b, name='dis_b')

        saver = tf.train.Saver(max_to_keep=2)
        # Graph Logs
        with tf.name_scope('GEN_a2b'):
            tf.summary.scalar("loss/gen_a2b/all", loss_gen_a2b)
        with tf.name_scope('DIS_b'):
            tf.summary.scalar("loss/dis_b/all", loss_dis_b)
            tf.summary.scalar("loss/dis_b/adv_real", loss_dis_b_adv_real)
            tf.summary.scalar("loss/dis_b/adv_fake", loss_dis_b_adv_fake)
        summary_op = tf.summary.merge_all()

        """ Session """
        tfconfig = tf.ConfigProto(allow_soft_placement=True)
        tfconfig.gpu_options.allow_growth = True
        with tf.Session(config=tfconfig) as sess:
            with tf.name_scope('Initial'):
                ckpt = tf.train.get_checkpoint_state(
                    dataset_parser.checkpoint_dir)
                if ckpt and ckpt.model_checkpoint_path:
                    print("Model restored: {}".format(
                        ckpt.model_checkpoint_path))
                    saver.restore(sess, ckpt.model_checkpoint_path)
                else:
                    print("No Model found.")
                    # Initialize only when nothing was restored; running the
                    # initializer after restore would clobber the restored
                    # weights.
                    init_op = tf.group(tf.global_variables_initializer(),
                                       tf.local_variables_initializer())
                    sess.run(init_op)
            # init_fn = slim.assign_from_checkpoint_fn('./pretrained/vgg_16.ckpt',
            #                                          slim.get_model_variables('vgg_16'))
            # init_fn(sess)
            summary_writer = tf.summary.FileWriter(dataset_parser.logs_dir,
                                                   sess.graph)

            """ Training Mode """
            if flags.mode == 'train':
                print('Training mode! Batch size:{:d}'.format(
                    flags.batch_size))
                with tf.variable_scope('Input_port'):
                    training_a_handle = sess.run(
                        training_a_iterator.string_handle())
                    training_b_handle = sess.run(
                        training_b_iterator.string_handle())
                    # val_a_handle = sess.run(validation_a_iterator.string_handle())
                    # val_b_handle = sess.run(validation_b_iterator.string_handle())
                image_pool_a, image_pool_b = ImagePool(
                    flags.pool_size), ImagePool(flags.pool_size)

                print('Start Training!')
                start_time = time.time()
                sess.run([training_a_iterator.initializer,
                          training_b_iterator.initializer])
                feed_dict_train = {handle_a: training_a_handle,
                                   handle_b: training_b_handle}
                # feed_dict_valid = {is_training: False}
                global_step_sess = sess.run(global_step)
                while global_step_sess < flags.training_iter:
                    try:
                        # Update gen_A2B, gen_B2A
                        _, fake_b_sess = sess.run([train_op_gen_a2b, fake_b],
                                                  feed_dict=feed_dict_train)
                        # _, loss_gen_a2b_sess = sess.run([train_op_gen_a2b, loss_gen_a2b], feed_dict=feed_dict_train)

                        # Update dis_B, dis_A: the discriminator sees a
                        # sample drawn from the image pool, not the freshest
                        # generator output (stabilizes training).
                        fake_b_pool_query = image_pool_b.query(fake_b_sess)
                        _ = sess.run(train_op_dis_b,
                                     feed_dict={
                                         fake_b_pool: fake_b_pool_query,
                                         handle_b: training_b_handle
                                     })

                        sess.run(global_step_update_op)
                        global_step_sess, learning_rate_sess = sess.run(
                            [global_step, learning_rate])
                        print(
                            'global step:[{:d}/{:d}], learning rate:{:f}, time:{:4.4f}'
                            .format(global_step_sess, flags.training_iter,
                                    learning_rate_sess,
                                    time.time() - start_time))

                        # Logging the events
                        if global_step_sess % flags.log_freq == 1:
                            print('Logging the events')
                            summary_op_sess = sess.run(
                                summary_op,
                                feed_dict={
                                    handle_a: training_a_handle,
                                    handle_b: training_b_handle,
                                    fake_b_pool: fake_b_pool_query
                                })
                            summary_writer.add_summary(summary_op_sess,
                                                       global_step_sess)
                            # summary_writer.flush()

                        # Observe training situation (For debugging.)
                        if flags.debug and global_step_sess % flags.observe_freq == 1:
                            real_a_sess, real_b_sess, adjusted_a_sess, segment_a_sess, fake_b_sess, \
                                real_a_name_sess, real_b_name_sess = \
                                sess.run([real_a, real_b, adjusted_a,
                                          segment_a, fake_b, real_a_name,
                                          real_b_name],
                                         feed_dict={
                                             handle_a: training_a_handle,
                                             handle_b: training_b_handle})
                            print('Logging training images.')
                            dataset_parser.visualize_data(
                                real_a=real_a_sess,
                                real_b=real_b_sess,
                                adjusted_a=adjusted_a_sess,
                                segment_a=segment_a_sess,
                                fake_b=fake_b_sess,
                                shape=(1, 1),
                                global_step=global_step_sess,
                                logs_dir=dataset_parser.logs_image_train_dir,
                                real_a_name=real_a_name_sess[0].decode(),
                                real_b_name=real_b_name_sess[0].decode())

                        """ Saving the checkpoint """
                        if global_step_sess % flags.save_freq == 0:
                            print('Saving model...')
                            saver.save(sess,
                                       dataset_parser.checkpoint_dir + '/model.ckpt',
                                       global_step=global_step_sess)
                    except tf.errors.OutOfRangeError:
                        # Dataset exhausted: re-initialize the iterators and
                        # keep training until flags.training_iter steps.
                        print(
                            '----------------One epochs finished!----------------'
                        )
                        sess.run([training_a_iterator.initializer,
                                  training_b_iterator.initializer])

            elif flags.mode == 'test':
                from PIL import Image
                import scipy.ndimage.filters
                import scipy.io as sio
                import numpy as np
                print('Start Testing!')
                '''
                with tf.variable_scope('Input_port'):
                    val_a_handle = sess.run(validation_a_iterator.string_handle())
                    val_b_handle = sess.run(validation_b_iterator.string_handle())
                    sess.run([validation_a_iterator.initializer, validation_b_iterator.initializer])
                '''
                with tf.variable_scope('Input_port'):
                    val_a_handle = sess.run(
                        validation_a_iterator.string_handle())
                    val_b_handle = sess.run(
                        validation_b_iterator.string_handle())
                    sess.run([validation_a_iterator.initializer,
                              validation_b_iterator.initializer])
                feed_dict_test = {handle_a: val_a_handle,
                                  handle_b: val_b_handle}

                image_idx = 0
                while True:
                    try:
                        segment_a_ori_sess, real_a_name_sess, real_b_sess, real_a_sess, fake_b_sess = \
                            sess.run([segment_a_ori, real_a_name, real_b,
                                      real_a, fake_b],
                                     feed_dict=feed_dict_test)

                        # Map tanh output [-1, 1] to [0, 255].
                        segment_a_np = (np.squeeze(segment_a_ori_sess) + 1.0) * 127.5
                        binary_a = np.zeros_like(segment_a_np, dtype=np.uint8)
                        # binary_a[segment_a_np > 127.5] = 255
                        # Adaptive threshold: midpoint between the mean of
                        # the above-mean and below-mean pixel populations,
                        # applied to a Gaussian-blurred mask.
                        binary_mean = np.mean(segment_a_np)
                        binary_a_high = np.mean(
                            segment_a_np[segment_a_np > binary_mean])
                        binary_a_low = np.mean(
                            segment_a_np[segment_a_np < binary_mean])
                        binary_a_ave = (binary_a_high + binary_a_low) / 2.0
                        segment_a_np_blur = scipy.ndimage.filters.gaussian_filter(
                            segment_a_np, sigma=11)
                        binary_a[segment_a_np_blur > binary_a_ave] = 255

                        sio.savemat(
                            '{}/{}.mat'.format(
                                dataset_parser.logs_mat_output_dir,
                                real_a_name_sess[0].decode()),
                            {'pred': segment_a_np, 'binary': binary_a})
                        # -----------------------------------------------------------------------------
                        if image_idx % 1 == 0:
                            real_a_sess = np.squeeze(real_a_sess)
                            x_png = Image.fromarray(
                                real_a_sess.astype(np.uint8))
                            x_png.save('{}/{}_0_img.png'.format(
                                dataset_parser.logs_image_val_dir,
                                real_a_name_sess[0].decode()), format='PNG')
                            x_png = Image.fromarray(
                                segment_a_np.astype(np.uint8))
                            x_png.save('{}/{}_1_pred.png'.format(
                                dataset_parser.logs_image_val_dir,
                                real_a_name_sess[0].decode()), format='PNG')
                            x_png = Image.fromarray(binary_a.astype(np.uint8))
                            x_png.save('{}/{}_2_binary.png'.format(
                                dataset_parser.logs_image_val_dir,
                                real_a_name_sess[0].decode()), format='PNG')
                            fake_b_sess = np.squeeze(fake_b_sess)
                            x_png = Image.fromarray(
                                fake_b_sess.astype(np.uint8))
                            x_png.save('{}/{}_3_fake.png'.format(
                                dataset_parser.logs_image_val_dir,
                                real_a_name_sess[0].decode()), format='PNG')
                            real_b_sess = np.squeeze(real_b_sess)
                            x_png = Image.fromarray(
                                real_b_sess.astype(np.uint8))
                            x_png.save('{}/{}_4_gt.png'.format(
                                dataset_parser.logs_image_val_dir,
                                real_a_name_sess[0].decode()), format='PNG')
                        print(image_idx)
                        image_idx += 1
                    except tf.errors.OutOfRangeError:
                        print(
                            '----------------One epochs finished!----------------'
                        )
                        break