def Generator_2(inputs, targets):
    """Train the second-stage generator ``G2`` on its own with TF-Slim.

    ``G2`` (``pix2pix_G`` under ``Generator/G2``) is applied to
    ``fiber_input``. Its output is multiplied by
    ``circle(FLAGS.input_size, FLAGS.input_size)`` and trained against
    ``label`` with ``combine_loss`` plus the graph's regularization loss.
    Only trainable variables under the ``Generator/G2`` scope are updated.

    The training directory is deleted and recreated on every call, so each
    call starts a fresh run there.

    Args:
        inputs: pair ``(fiber_output, fiber_input)``; only ``fiber_input`` is
            used.
        targets: pair ``(encoder, label)``; only ``label`` is used.
    """
    # NOTE(review): ``logdir`` is neither a parameter nor a local here; it is
    # assumed to be a module-level global defined elsewhere -- confirm.
    traindir = os.path.join(logdir, 'G2\\pix2pix_G')
    if tf.gfile.Exists(traindir):
        tf.gfile.DeleteRecursively(traindir)
    tf.gfile.MakeDirs(traindir)

    _, fiber_input = inputs
    _, label = targets

    with tf.variable_scope('Generator'):
        with tf.variable_scope('G2'):
            generated_data = (pix2pix_G(fiber_input)
                              * circle(FLAGS.input_size, FLAGS.input_size))

    with tf.name_scope('Train_summary'):
        reshaped_fiber_input = get_summary_image(fiber_input, FLAGS.grid_size)
        reshaped_label = get_summary_image(label, FLAGS.grid_size)
        reshaped_generated_data = get_summary_image(generated_data,
                                                    FLAGS.grid_size)
        tf.summary.image('Fiber_Input', reshaped_fiber_input)
        tf.summary.image('Fiber_Label', reshaped_label)
        tf.summary.image('Generated_Data', reshaped_generated_data)

    with tf.name_scope('g2_loss'):
        G2_loss = combine_loss(generated_data, label, add_summary=True)

    with tf.name_scope('Train_Loss'):
        reg_loss = tf.losses.get_regularization_loss()
        total_loss = G2_loss + reg_loss
        # Raise an error as soon as the loss becomes NaN or Inf.
        total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
        tf.summary.scalar('Regularization_loss', reg_loss)
        tf.summary.scalar('G2_loss', G2_loss)
        tf.summary.scalar('Total_loss', total_loss)

        lr = get_lr(FLAGS.generator_lr)
        optimizer = get_optimizer(lr)
        # Ops in UPDATE_OPS (presumably batch-norm statistics) run with
        # every training step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        train_op = slim.learning.create_train_op(
            total_loss,
            optimizer,
            update_ops=update_ops,
            # Only the G2 sub-network is optimised.
            variables_to_train=tf.get_collection(
                tf.GraphKeys.TRAINABLE_VARIABLES, scope='Generator/G2'))

    with tf.name_scope('Train_ops'):
        # max_val=1.0 assumes pixel values in [0, 1].
        psnr = tf.reduce_mean(tf.image.psnr(generated_data, label, max_val=1.0))
        ssim = tf.reduce_mean(tf.image.ssim(generated_data, label, max_val=1.0))
        corr = correlation(generated_data, label)
        tf.summary.scalar('PSNR', psnr)
        tf.summary.scalar('SSIM', ssim)
        tf.summary.scalar('Relation', corr)
        tf.summary.scalar('Learning_rate', lr)

    # NOTE(review): if the module-level ``logdir`` is E:\GitHub\MMFI\log,
    # ``traindir`` is the same directory as the warm-start path below, and it
    # was deleted at the top of this function -- confirm the intended source.
    slim.learning.train(
        train_op,
        traindir,
        number_of_steps=FLAGS.max_iter,
        log_every_n_steps=FLAGS.log_n_steps,
        # Raw string: same path value as before, without invalid escapes.
        init_fn=get_init_fn(r'E:\GitHub\MMFI\log\G2\pix2pix_G'),
        save_summaries_secs=FLAGS.save_summaries_secs,
        save_interval_secs=FLAGS.save_interval_secs)
def main(_):
    """Evaluate the chained G1 -> G2 generator once on ``valid.txt``.

    Both generators are built with ``is_training=False``. The latest
    checkpoint in ``logdir`` is restored, image and metric summaries are
    written to ``<logdir>/eval``, and a side-by-side PNG is written on every
    evaluation step (at most 50), overwriting the same file each time. The
    panels are: generated input | fiber input | label | generated data.

    Args:
        _: unused positional argument.

    Raises:
        ValueError: if ``logdir`` contains no checkpoint.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Raw string: same path value as before, without invalid escapes.
        logdir = r'E:\GitHub\MMFI\log\GG12\CNN'
        evaldir = os.path.join(logdir, 'eval')
        if not tf.gfile.Exists(evaldir):
            tf.gfile.MakeDirs(evaldir)

        with tf.name_scope('inputs'):
            fiber_output, fiber_input, _encoder, label = (
                data_loader.read_inputs('valid.txt', False))

        with tf.variable_scope('Generator'):
            with tf.variable_scope('G1'):
                generated_input = (
                    pix2pix_G(fiber_output, is_training=False)
                    * circle(FLAGS.input_size, FLAGS.input_size))
            with tf.variable_scope('G2'):
                generated_data = (
                    pix2pix_G(generated_input, is_training=False)
                    * circle(FLAGS.input_size, FLAGS.input_size))

        with tf.name_scope('Valid_summary'):
            reshaped_fiber_input = get_summary_image(fiber_input,
                                                     FLAGS.grid_size)
            reshaped_label = get_summary_image(label, FLAGS.grid_size)
            reshaped_generated_input = get_summary_image(generated_input,
                                                         FLAGS.grid_size)
            reshaped_generated_data = get_summary_image(generated_data,
                                                        FLAGS.grid_size)
            tf.summary.image('Input_Fiber', reshaped_fiber_input)
            tf.summary.image('Input_Generator', reshaped_generated_input)
            tf.summary.image('Data_Real', reshaped_label)
            tf.summary.image('Data_Generator', reshaped_generated_data)

        with tf.name_scope('Valid_op'):
            # max_val=1.0 assumes pixel values in [0, 1].
            psnr = tf.reduce_mean(
                tf.image.psnr(generated_data, label, max_val=1.0))
            ssim = tf.reduce_mean(
                tf.image.ssim(generated_data, label, max_val=1.0))
            corr = correlation(generated_data, label)
            tf.summary.scalar('PSNR', psnr)
            tf.summary.scalar('SSIM', ssim)
            tf.summary.scalar('Relation', corr)

        # White (1.0 -> 255) separator bars, 10 elements wide along axis 2
        # (presumably the width axis of an NHWC tensor).
        grate = tf.ones([1, FLAGS.grid_size * FLAGS.input_size, 10, 1],
                        dtype=tf.float32)
        reshaped_images = tf.concat(
            (reshaped_generated_input, grate,
             reshaped_fiber_input, grate,
             reshaped_label, grate,
             reshaped_generated_data, grate), 2)
        uint8_images = tf.cast(reshaped_images * 255, tf.uint8)
        image_write_ops = tf.write_file(
            os.path.join(evaldir, 'Generator_is_training_False.png'),
            tf.image.encode_png(uint8_images[0]))

        status_message = tf.string_join(
            [' PSNR: ', tf.as_string(psnr), ' ',
             ' SSIM: ', tf.as_string(ssim), ' ',
             ' Correlation: ', tf.as_string(corr)],
            name='status_message')

        checkpoint_path = tf.train.latest_checkpoint(logdir)
        # latest_checkpoint returns None when nothing has been saved yet;
        # fail with a clear message instead of an obscure restore error.
        if checkpoint_path is None:
            raise ValueError('No checkpoint found in %s' % logdir)
        tf.logging.info('Evaluating %s', checkpoint_path)
        tf.contrib.training.evaluate_once(
            checkpoint_path,
            hooks=[tf.contrib.training.SummaryAtEndHook(evaldir),
                   tf.contrib.training.StopAfterNEvalsHook(50),
                   tf.train.LoggingTensorHook([status_message],
                                              every_n_iter=5)],
            eval_ops=image_write_ops)
def TFGAN(inputs, targets):
    """Train the generator/discriminator pair adversarially with TF-GAN.

    The generator (``generator_fn``, defined elsewhere) receives
    ``fiber_output``. Its output is split into two halves along the last
    axis and compared against ``real_data = concat([label, fiber_input], -1)``.
    The discriminator is ``pix2pix_D``. The objective combines TF-GAN's
    "modified" generator/discriminator losses with a pixel loss
    (``combine_loss``) via ``tfgan.losses.combine_adversarial_loss``.
    Generator variables are passed to ``get_tfgan_init_fn`` with the
    GG12/CNN checkpoint directory, presumably for warm-starting.

    The training directory is deleted and recreated on every call.

    Args:
        inputs: pair ``(fiber_output, fiber_input)``.
        targets: pair ``(encoder, label)``; ``encoder`` is unused.
    """
    # NOTE(review): ``logdir`` is assumed to be a module-level global defined
    # elsewhere -- confirm.
    traindir = os.path.join(logdir, 'GG12\\PIX2PIX_MINMAX_1024')
    if tf.gfile.Exists(traindir):
        tf.gfile.DeleteRecursively(traindir)
    tf.gfile.MakeDirs(traindir)

    fiber_output, fiber_input = inputs
    _, label = targets
    # Real samples stack label and fiber input on the last axis, matching
    # the two halves that the generator output is split into below.
    real_data = tf.concat((label, fiber_input), -1)

    # ------------------------------------------------------------ GAN model
    gan_model = tfgan.gan_model(
        generator_fn=generator_fn,
        discriminator_fn=pix2pix_D,
        real_data=real_data,
        generator_inputs=fiber_output,
        generator_scope='Generator',
        discriminator_scope='Discriminator')

    # -------------------------------------------------------- GAN summaries
    with tf.name_scope('Train_summary'):
        generated_data, generated_input = tf.split(
            gan_model.generated_data, 2, -1)
        reshaped_fiber_input = get_summary_image(fiber_input, FLAGS.grid_size)
        reshaped_label = get_summary_image(label, FLAGS.grid_size)
        reshaped_generated_input = get_summary_image(generated_input,
                                                     FLAGS.grid_size)
        reshaped_generated_data = get_summary_image(generated_data,
                                                    FLAGS.grid_size)
        tf.summary.image('Input_Fiber', reshaped_fiber_input)
        tf.summary.image('Input_Generator', reshaped_generated_input)
        tf.summary.image('Data_Real', reshaped_label)
        tf.summary.image('Data_Generator', reshaped_generated_data)

    # ------------------------------------------------------------- GAN loss
    with tf.name_scope('pixel_loss'):
        pixel_loss = combine_loss(gan_model.generated_data,
                                  gan_model.real_data, add_summary=True)
    with tf.name_scope('gan_loss'):
        gan_loss = tfgan.gan_loss(
            gan_model,
            generator_loss_fn=tfgan.losses.modified_generator_loss,
            discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
            # TF-GAN adds a WGAN-GP gradient penalty to the discriminator
            # loss whenever this weight is set, regardless of the loss
            # functions chosen above (it is not limited to Wasserstein loss).
            gradient_penalty_weight=1.0,
        )
        tfgan.eval.add_regularization_loss_summaries(gan_model)
    with tf.name_scope('Train_Loss'):
        combined_loss = tfgan.losses.combine_adversarial_loss(
            gan_loss, gan_model, pixel_loss,
            weight_factor=FLAGS.adversarial_loss_weight)

    # -------------------------------------------------------------- GAN ops
    with tf.name_scope('Train_ops'):
        gen_lr = get_lr(1e-5, decay_steps=5000)
        dis_lr = get_lr(5e-5, decay_steps=5000)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            combined_loss,
            generator_optimizer=get_optimizer(gen_lr),
            discriminator_optimizer=get_optimizer(dis_lr))
        # Metrics on the "data" half of the generator output only.
        # max_val=1.0 assumes pixel values in [0, 1].
        psnr = tf.reduce_mean(tf.image.psnr(generated_data, label, max_val=1.0))
        ssim = tf.reduce_mean(tf.image.ssim(generated_data, label, max_val=1.0))
        corr = correlation(generated_data, label)
        tf.summary.scalar('PSNR', psnr)
        tf.summary.scalar('SSIM', ssim)
        tf.summary.scalar('Relation', corr)
        tf.summary.scalar('generator_lr', gen_lr)

    # ------------------------------------------------------------ GAN train
    # One discriminator step per generator step.
    train_steps = tfgan.GANTrainSteps(generator_train_steps=1,
                                      discriminator_train_steps=1)
    message = tf.string_join(
        [' Train step: ', tf.as_string(tf.train.get_or_create_global_step()),
         ' PSNR:', tf.as_string(psnr),
         ' SSIM:', tf.as_string(ssim),
         ' Correlation:', tf.as_string(corr)],
        name='status_message')
    tfgan.gan_train(
        train_ops,
        logdir=traindir,
        get_hooks_fn=tfgan.get_joint_train_hooks(train_steps),
        hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_iter),
               tf.train.LoggingTensorHook([message],
                                          every_n_iter=FLAGS.log_n_steps),
               # Raw string: same path value as before, without invalid
               # escapes. Presumably restores the Generator from that run.
               get_tfgan_init_fn(r'E:\GitHub\MMFI\log\GG12\CNN', 'Generator')],
        save_summaries_steps=FLAGS.save_summaries_steps * 2,
        save_checkpoint_secs=FLAGS.save_interval_secs)