def Discriminator(inputs, targets):
    """Train the pix2pix discriminator against a frozen G2 generator restored from its checkpoint."""
    # `logdir` is assumed to be defined at module level.
    traindir = os.path.join(logdir, 'G2\\pix2pix_D')
    if tf.gfile.Exists(traindir):
        tf.gfile.DeleteRecursively(traindir)
    tf.gfile.MakeDirs(traindir)

    fiber_output, fiber_input = inputs
    encoder, label = targets

    # G2 maps the fiber speckle input back to an image, masked to the circular fiber facet.
    with tf.variable_scope('Generator'):
        with tf.variable_scope('G2'):
            generated_data = pix2pix_G(fiber_input) * circle(FLAGS.input_size, FLAGS.input_size)

    # Conditional discriminator: scores (image, condition) pairs for generated and real data.
    with tf.variable_scope('Discriminator', reuse=tf.AUTO_REUSE):
        discriminator_gen_outputs = pix2pix_D(tf.concat((generated_data, fiber_input), -1))
        discriminator_real_outputs = pix2pix_D(tf.concat((label, fiber_input), -1))

    with tf.name_scope('Train_summary'):
        reshaped_label = get_summary_image(label, FLAGS.grid_size)
        reshaped_fiber_input = get_summary_image(fiber_input, FLAGS.grid_size)
        reshaped_generated_data = get_summary_image(generated_data, FLAGS.grid_size)
        tf.summary.image('Fiber_Label', reshaped_label)
        tf.summary.image('Fiber_Input', reshaped_fiber_input)
        tf.summary.image('Generated_Data', reshaped_generated_data)

    with tf.name_scope('Train_Loss'):
        predict_real = discriminator_real_outputs
        predict_fake = discriminator_gen_outputs
        # L1-style targets: real samples toward +1, generated samples toward -1.
        discrim_real_loss = tf.reduce_mean(tf.abs(1 - predict_real))
        discrim_gen_loss = tf.reduce_mean(tf.abs(-1 - predict_fake))
        discrim_loss = discrim_real_loss + discrim_gen_loss
        total_loss = discrim_loss + tf.losses.get_regularization_loss()
        total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
        tf.summary.scalar('Total_loss', total_loss)
        tf.summary.scalar('discrim_loss', discrim_loss)
        tf.summary.scalar('discrim_real_loss', discrim_real_loss)
        tf.summary.scalar('discrim_gen_loss', discrim_gen_loss)

    with tf.name_scope('Train_OP'):
        discrim_lr = get_lr(FLAGS.discriminator_lr, decay_steps=5000)
        tf.summary.scalar('predict_real', tf.reduce_mean(predict_real))
        tf.summary.scalar('predict_fake', tf.reduce_mean(predict_fake))
        tf.summary.scalar('discrim_lr', discrim_lr)

        # Only the discriminator variables are trained; the generator is restored via init_fn below.
        train_op = slim.learning.create_train_op(
            total_loss,
            get_optimizer(discrim_lr),
            update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS),
            variables_to_train=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                                 scope='Discriminator'))

    slim.learning.train(train_op,
                        traindir,
                        number_of_steps=FLAGS.max_iter,
                        log_every_n_steps=FLAGS.log_n_steps,
                        init_fn=get_init_fn(r'E:\GitHub\MMFI\log\G2\pix2pix_G',
                                            inclusion_scope=['Generator/G2']),
                        save_summaries_secs=FLAGS.save_summaries_secs,
                        save_interval_secs=FLAGS.save_interval_secs)
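# The training ops above rely on get_lr, get_optimizer and get_init_fn, which are defined
# elsewhere in this project. The functions below are a minimal, hypothetical sketch of what
# they might look like; the exponential-decay schedule and the Adam beta1 value are
# assumptions, not the project's actual settings.
def get_lr(base_lr, decay_steps=10000, decay_rate=0.9):
    """Sketch: exponentially decayed learning rate tied to the global step (assumed schedule)."""
    global_step = tf.train.get_or_create_global_step()
    return tf.train.exponential_decay(base_lr, global_step, decay_steps, decay_rate,
                                      staircase=True)


def get_optimizer(learning_rate):
    """Sketch: Adam with beta1=0.5, a common choice for pix2pix-style training (assumption)."""
    return tf.train.AdamOptimizer(learning_rate, beta1=0.5)


def get_init_fn(checkpoint_dir, inclusion_scope=None):
    """Sketch: restore only the variables under `inclusion_scope` from the latest checkpoint."""
    variables_to_restore = slim.get_variables_to_restore(include=inclusion_scope)
    return slim.assign_from_checkpoint_fn(tf.train.latest_checkpoint(checkpoint_dir),
                                          variables_to_restore)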
def Generator_2(inputs, targets):
    """Train generator G2 (fiber speckle input -> original image) with the combined reconstruction loss."""
    traindir = os.path.join(logdir, 'G2\\pix2pix_G')
    if tf.gfile.Exists(traindir):
        tf.gfile.DeleteRecursively(traindir)
    tf.gfile.MakeDirs(traindir)

    fiber_output, fiber_input = inputs
    encoder, label = targets

    with tf.variable_scope('Generator'):
        with tf.variable_scope('G2'):
            generated_data = pix2pix_G(fiber_input) * circle(FLAGS.input_size, FLAGS.input_size)

    with tf.name_scope('Train_summary'):
        reshaped_fiber_input = get_summary_image(fiber_input, FLAGS.grid_size)
        reshaped_label = get_summary_image(label, FLAGS.grid_size)
        reshaped_generated_data = get_summary_image(generated_data, FLAGS.grid_size)
        tf.summary.image('Fiber_Input', reshaped_fiber_input)
        tf.summary.image('Fiber_Label', reshaped_label)
        tf.summary.image('Generated_Data', reshaped_generated_data)

    with tf.name_scope('g2_loss'):
        G2_loss = combine_loss(generated_data, label, add_summary=True)

    with tf.name_scope('Train_Loss'):
        reg_loss = tf.losses.get_regularization_loss()
        total_loss = G2_loss + reg_loss
        total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
        tf.summary.scalar('Regularization_loss', reg_loss)
        tf.summary.scalar('G2_loss', G2_loss)
        tf.summary.scalar('Total_loss', total_loss)

    lr = get_lr(FLAGS.generator_lr)
    optimizer = get_optimizer(lr)
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    # Only the G2 variables are trained.
    train_op = slim.learning.create_train_op(
        total_loss,
        optimizer,
        update_ops=update_ops,
        variables_to_train=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                             scope='Generator/G2'))

    with tf.name_scope('Train_ops'):
        # Image-quality metrics logged alongside the loss during training.
        psnr = tf.reduce_mean(tf.image.psnr(generated_data, label, max_val=1.0))
        ssim = tf.reduce_mean(tf.image.ssim(generated_data, label, max_val=1.0))
        corr = correlation(generated_data, label)
        tf.summary.scalar('PSNR', psnr)
        tf.summary.scalar('SSIM', ssim)
        tf.summary.scalar('Relation', corr)
        tf.summary.scalar('Learning_rate', lr)

    slim.learning.train(train_op,
                        traindir,
                        number_of_steps=FLAGS.max_iter,
                        log_every_n_steps=FLAGS.log_n_steps,
                        # init_fn=get_init_fn(r'E:\GitHub\MMFI\log\G2\pix2pix_G'),
                        save_summaries_secs=FLAGS.save_summaries_secs,
                        save_interval_secs=FLAGS.save_interval_secs)
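# Both generator stages multiply their output by circle(...), and the metrics use
# correlation(...). Those helpers are defined elsewhere in the repo; the sketches below show
# one plausible reading (a binary mask over the circular fiber facet and a mean Pearson
# correlation), given as assumptions rather than the actual implementations.
import numpy as np  # if not already imported at the top of the module


def circle(height, width):
    """Sketch: mask that is 1 inside the inscribed circle (the fiber facet), 0 outside."""
    y, x = np.ogrid[:height, :width]
    cy, cx = (height - 1) / 2.0, (width - 1) / 2.0
    mask = ((y - cy) ** 2 + (x - cx) ** 2) <= (min(height, width) / 2.0) ** 2
    # Shape [1, H, W, 1] so it broadcasts over the [batch, H, W, C] generator output.
    return tf.constant(mask[np.newaxis, :, :, np.newaxis], dtype=tf.float32)


def correlation(x, y):
    """Sketch: mean Pearson correlation between generated images and labels."""
    x = tf.layers.flatten(x)
    y = tf.layers.flatten(y)
    xc = x - tf.reduce_mean(x, axis=1, keepdims=True)
    yc = y - tf.reduce_mean(y, axis=1, keepdims=True)
    cov = tf.reduce_sum(xc * yc, axis=1)
    norm = tf.sqrt(tf.reduce_sum(tf.square(xc), axis=1) * tf.reduce_sum(tf.square(yc), axis=1))
    return tf.reduce_mean(cov / (norm + 1e-8))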
def main(_):
    """Evaluate the cascaded G1 + G2 generators on the validation set and write summary images."""
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        logdir = r'E:\GitHub\MMFI\log\GG12\CNN'
        evaldir = os.path.join(logdir, 'eval')
        if not tf.gfile.Exists(evaldir):
            # tf.gfile.DeleteRecursively(evaldir)
            tf.gfile.MakeDirs(evaldir)

        with tf.name_scope('inputs'):
            fiber_output, fiber_input, encoder, label = data_loader.read_inputs('valid.txt', False)

        # Cascade both generators in inference mode: G1 reconstructs the fiber input from the
        # fiber output, then G2 reconstructs the original image from G1's result.
        with tf.variable_scope('Generator'):
            with tf.variable_scope('G1'):
                generated_input = pix2pix_G(fiber_output, is_training=False) \
                                  * circle(FLAGS.input_size, FLAGS.input_size)
            with tf.variable_scope('G2'):
                generated_data = pix2pix_G(generated_input, is_training=False) \
                                 * circle(FLAGS.input_size, FLAGS.input_size)

        with tf.name_scope('Valid_summary'):
            reshaped_fiber_input = get_summary_image(fiber_input, FLAGS.grid_size)
            reshaped_label = get_summary_image(label, FLAGS.grid_size)
            reshaped_generated_input = get_summary_image(generated_input, FLAGS.grid_size)
            reshaped_generated_data = get_summary_image(generated_data, FLAGS.grid_size)
            tf.summary.image('Input_Fiber', reshaped_fiber_input)
            tf.summary.image('Input_Generator', reshaped_generated_input)
            tf.summary.image('Data_Real', reshaped_label)
            tf.summary.image('Data_Generator', reshaped_generated_data)

        with tf.name_scope('Valid_op'):
            psnr = tf.reduce_mean(tf.image.psnr(generated_data, label, max_val=1.0))
            ssim = tf.reduce_mean(tf.image.ssim(generated_data, label, max_val=1.0))
            corr = correlation(generated_data, label)
            # inception_score = get_inception_score(generated_data)
            tf.summary.scalar('PSNR', psnr)
            tf.summary.scalar('SSIM', ssim)
            tf.summary.scalar('Relation', corr)

            # Tile the image grids side by side, separated by 10-pixel white gutters, and save as PNG.
            grate = tf.ones([1, FLAGS.grid_size * FLAGS.input_size, 10, 1], dtype=tf.float32)
            reshaped_images = tf.concat((reshaped_generated_input, grate,
                                         reshaped_fiber_input, grate,
                                         reshaped_label, grate,
                                         reshaped_generated_data, grate), 2)
            uint8_images = tf.cast(reshaped_images * 255, tf.uint8)
            image_write_ops = tf.write_file(
                '%s/%s' % (evaldir, 'Generator_is_training_False.png'),
                tf.image.encode_png(uint8_images[0]))

        status_message = tf.string_join([' PSNR: ', tf.as_string(psnr), ' ',
                                         ' SSIM: ', tf.as_string(ssim), ' ',
                                         ' Correlation: ', tf.as_string(corr)],
                                        name='status_message')

        checkpoint_path = tf.train.latest_checkpoint(logdir)
        tf.logging.info('Evaluating %s' % checkpoint_path)
        tf.contrib.training.evaluate_once(
            checkpoint_path,
            hooks=[tf.contrib.training.SummaryAtEndHook(evaldir),
                   tf.contrib.training.StopAfterNEvalsHook(50),
                   tf.train.LoggingTensorHook([status_message], every_n_iter=5)],
            eval_ops=image_write_ops)
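# get_summary_image tiles a batch of grid_size**2 images into a single grid for
# tf.summary.image and for the PNG written in main(). Its real implementation lives elsewhere
# in the repo; the version below is a minimal sketch under that assumption, consistent with
# the [1, grid_size*input_size, grid_size*input_size, 1] shape the concat in main() expects.
def get_summary_image(images, grid_size):
    """Sketch: tile the first grid_size**2 images into one [1, grid*H, grid*W, C] image."""
    _, h, w, c = images.get_shape().as_list()
    images = tf.reshape(images[:grid_size * grid_size], [grid_size, grid_size, h, w, c])
    # Reorder to [row, h, col, w, c] so the final reshape lays the images out row by row.
    images = tf.transpose(images, [0, 2, 1, 3, 4])
    return tf.reshape(images, [1, grid_size * h, grid_size * w, c])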