Exemplo n.º 1
0
def Generator_2(inputs,targets):
  """Build the supervised training graph for generator G2 and train it.

  The generator is created under the variable scope ``Generator/G2``; only
  the trainable variables in that scope are optimised. Image and scalar
  summaries are registered along the way and the slim training loop is run
  until ``FLAGS.max_iter`` steps.

  The training directory (below the module-level ``logdir``) is deleted and
  recreated on every call, so anything previously stored there is lost.

  Args:
    inputs: Pair ``(fiber_output, fiber_input)``; only ``fiber_input`` is
      used here.
    targets: Pair ``(encoder, label)``; only ``label`` is used here.
  """
  # Path components are passed separately rather than embedding a literal
  # backslash, so the directory nests correctly on any platform.
  traindir = os.path.join(logdir, 'G2', 'pix2pix_G')
  if tf.gfile.Exists(traindir):
    tf.gfile.DeleteRecursively(traindir)
  tf.gfile.MakeDirs(traindir)

  # Both arguments must still be 2-element sequences; the unused halves are
  # discarded explicitly.
  _, fiber_input = inputs
  _, label = targets

  with tf.variable_scope('Generator'):
    with tf.variable_scope('G2'):
      generated_data = (pix2pix_G(fiber_input)
                        * circle(FLAGS.input_size, FLAGS.input_size))

  with tf.name_scope('Train_summary'):
    reshaped_fiber_input = get_summary_image(fiber_input, FLAGS.grid_size)
    reshaped_label = get_summary_image(label, FLAGS.grid_size)
    reshaped_generated_data = get_summary_image(generated_data,
                                                FLAGS.grid_size)
    tf.summary.image('Fiber_Input', reshaped_fiber_input)
    tf.summary.image('Fiber_Label', reshaped_label)
    tf.summary.image('Generated_Data', reshaped_generated_data)

  with tf.name_scope('g2_loss'):
    G2_loss = combine_loss(generated_data, label, add_summary=True)
  with tf.name_scope('Train_Loss'):
    reg_loss = tf.losses.get_regularization_loss()
    total_loss = G2_loss + reg_loss
    # Abort the run as soon as the loss becomes inf/nan.
    total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
    tf.summary.scalar('Regularization_loss', reg_loss)
    tf.summary.scalar('G2_loss', G2_loss)
    tf.summary.scalar('Total_loss', total_loss)

  lr = get_lr(FLAGS.generator_lr)
  optimizer = get_optimizer(lr)
  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  # Restrict optimisation to the G2 variables.
  g2_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='Generator/G2')
  train_op = slim.learning.create_train_op(total_loss, optimizer,
                                           update_ops=update_ops,
                                           variables_to_train=g2_variables)

  with tf.name_scope('Train_ops'):
    psnr = tf.reduce_mean(tf.image.psnr(generated_data, label, max_val=1.0))
    ssim = tf.reduce_mean(tf.image.ssim(generated_data, label, max_val=1.0))
    corr = correlation(generated_data, label)
    tf.summary.scalar('PSNR', psnr)
    tf.summary.scalar('SSIM', ssim)
    tf.summary.scalar('Relation', corr)
    tf.summary.scalar('Learning_rate', lr)

  slim.learning.train(train_op, traindir,
                      number_of_steps=FLAGS.max_iter,
                      log_every_n_steps=FLAGS.log_n_steps,
                      # init_fn=get_init_fn('E:\GitHub\MMFI\log\\G2\\pix2pix_G'),
                      save_summaries_secs=FLAGS.save_summaries_secs,
                      save_interval_secs=FLAGS.save_interval_secs)
Exemplo n.º 2
0
def main(_):
    """Evaluate the chained G1 -> G2 generators on the validation list.

    Builds both generators with ``is_training=False``, registers image and
    scalar summaries, writes a side-by-side comparison PNG into the ``eval``
    sub-directory of the log directory and runs a single evaluation pass
    from the latest checkpoint found in the log directory.

    Args:
        _: Unused positional argument.

    Raises:
        ValueError: If no checkpoint can be found in the log directory.
    """
    tf.logging.set_verbosity(tf.logging.INFO)
    with tf.Graph().as_default():
        # Raw string: same value as before, without depending on invalid
        # escape sequences (``\G``, ``\M``, ``\l``) being left untouched.
        logdir = r'E:\GitHub\MMFI\log\GG12\CNN'
        evaldir = os.path.join(logdir, 'eval')
        # Existing evaluation output is kept; the directory is only created
        # when it is missing.
        if not tf.gfile.Exists(evaldir):
            tf.gfile.MakeDirs(evaldir)

        with tf.name_scope('inputs'):
            fiber_output, fiber_input, _, label = data_loader.read_inputs(
                'valid.txt', False)

        with tf.variable_scope('Generator'):
            with tf.variable_scope('G1'):
                generated_input = (
                    pix2pix_G(fiber_output, is_training=False)
                    * circle(FLAGS.input_size, FLAGS.input_size))
            with tf.variable_scope('G2'):
                # G2 consumes the output of G1, not the real fiber input.
                generated_data = (
                    pix2pix_G(generated_input, is_training=False)
                    * circle(FLAGS.input_size, FLAGS.input_size))

        with tf.name_scope('Valid_summary'):
            reshaped_fiber_input = get_summary_image(fiber_input,
                                                     FLAGS.grid_size)
            reshaped_label = get_summary_image(label, FLAGS.grid_size)
            reshaped_generated_input = get_summary_image(generated_input,
                                                         FLAGS.grid_size)
            reshaped_generated_data = get_summary_image(generated_data,
                                                        FLAGS.grid_size)
            tf.summary.image('Input_Fiber', reshaped_fiber_input)
            tf.summary.image('Input_Generator', reshaped_generated_input)
            tf.summary.image('Data_Real', reshaped_label)
            tf.summary.image('Data_Generator', reshaped_generated_data)

        with tf.name_scope('Valid_op'):
            psnr = tf.reduce_mean(
                tf.image.psnr(generated_data, label, max_val=1.0))
            ssim = tf.reduce_mean(
                tf.image.ssim(generated_data, label, max_val=1.0))
            corr = correlation(generated_data, label)
            # inception_score = get_inception_score(generated_data)

            tf.summary.scalar('PSNR', psnr)
            tf.summary.scalar('SSIM', ssim)
            tf.summary.scalar('Relation', corr)

            # 10-column strip of ones (255 after scaling) placed after each
            # panel when the images are concatenated along axis 2.
            grate = tf.ones(
                [1, FLAGS.grid_size * FLAGS.input_size, 10, 1],
                dtype=tf.float32)
            reshaped_images = tf.concat((reshaped_generated_input, grate,
                                         reshaped_fiber_input, grate,
                                         reshaped_label, grate,
                                         reshaped_generated_data, grate), 2)
            # NOTE(review): assumes the image values lie in [0, 1] before
            # scaling to uint8 - confirm against the generator's output range.
            uint8_images = tf.cast(reshaped_images * 255, tf.uint8)
            image_write_ops = tf.write_file(
                '%s/%s' % (evaldir, 'Generator_is_training_False.png'),
                tf.image.encode_png(uint8_images[0]))

            status_message = tf.string_join(
                [' PSNR: ', tf.as_string(psnr), ' ',
                 ' SSIM: ', tf.as_string(ssim), ' ',
                 ' Correlation: ', tf.as_string(corr)],
                name='status_message')

        checkpoint_path = tf.train.latest_checkpoint(logdir)
        # latest_checkpoint returns None when nothing is found; fail early
        # with a clear message instead of handing None to evaluate_once.
        if checkpoint_path is None:
            raise ValueError('No checkpoint found in %s' % logdir)
        tf.logging.info('Evaluating %s', checkpoint_path)

        tf.contrib.training.evaluate_once(
            checkpoint_path,
            hooks=[tf.contrib.training.SummaryAtEndHook(evaldir),
                   tf.contrib.training.StopAfterNEvalsHook(50),
                   tf.train.LoggingTensorHook([status_message],
                                              every_n_iter=5)],
            eval_ops=image_write_ops)
Exemplo n.º 3
0
def TFGAN(inputs,targets):
    """Build a TF-GAN model, its losses and train ops, then run training.

    The real data given to the discriminator is ``label`` concatenated with
    ``fiber_input`` along the last axis; the generator output is split into
    two halves along the last axis (``generated_data``, ``generated_input``)
    for summaries and metrics. The adversarial loss is combined with a pixel
    loss, and generator/discriminator are trained jointly, one step each.

    The training directory (below the module-level ``logdir``) is deleted and
    recreated on every call, so anything previously stored there is lost.

    Args:
        inputs: Pair ``(fiber_output, fiber_input)``; ``fiber_output`` is the
            generator input.
        targets: Pair ``(encoder, label)``; only ``label`` is used here.
    """
    # Path components are passed separately rather than embedding a literal
    # backslash, so the directory nests correctly on any platform.
    traindir = os.path.join(logdir, 'GG12', 'PIX2PIX_MINMAX_1024')
    if tf.gfile.Exists(traindir):
        tf.gfile.DeleteRecursively(traindir)
    tf.gfile.MakeDirs(traindir)

    # Create a GANModel tuple.
    fiber_output, fiber_input = inputs
    _, label = targets
    real_data = tf.concat((label, fiber_input), -1)

    #######################################################################
    ##########################  GAN MODEL #################################
    #######################################################################
    gan_model = tfgan.gan_model(
        generator_fn=generator_fn,
        discriminator_fn=pix2pix_D,
        real_data=real_data,
        generator_inputs=fiber_output,
        generator_scope='Generator',
        discriminator_scope='Discriminator')

    #######################################################################
    ##########################  GAN SUMMARY ###############################
    #######################################################################
    with tf.name_scope('Train_summary'):
        # Same channel order as ``real_data``: data half first, input second.
        generated_data, generated_input = tf.split(
            gan_model.generated_data, 2, -1)
        reshaped_fiber_input = get_summary_image(fiber_input, FLAGS.grid_size)
        reshaped_label = get_summary_image(label, FLAGS.grid_size)
        reshaped_generated_input = get_summary_image(generated_input,
                                                     FLAGS.grid_size)
        reshaped_generated_data = get_summary_image(generated_data,
                                                    FLAGS.grid_size)
        tf.summary.image('Input_Fiber', reshaped_fiber_input)
        tf.summary.image('Input_Generator', reshaped_generated_input)
        tf.summary.image('Data_Real', reshaped_label)
        tf.summary.image('Data_Generator', reshaped_generated_data)

    #######################################################################
    ##########################  GAN LOSS  #################################
    #######################################################################
    with tf.name_scope('pixel_loss'):
        pixel_loss = combine_loss(gan_model.generated_data,
                                  gan_model.real_data,
                                  add_summary=True)
    with tf.name_scope('gan_loss'):
        gan_loss = tfgan.gan_loss(
            gan_model,
            generator_loss_fn=tfgan.losses.modified_generator_loss,
            discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
            # NOTE(review): the earlier comment here claimed this weight only
            # matters for the Wasserstein loss; tfgan.gan_loss appears to add
            # the gradient penalty whenever a weight is given - confirm that
            # it is intended together with the "modified" losses.
            gradient_penalty_weight=1.0,
        )
        tfgan.eval.add_regularization_loss_summaries(gan_model)
    with tf.name_scope('Train_Loss'):
        gan_loss = tfgan.losses.combine_adversarial_loss(
            gan_loss, gan_model, pixel_loss,
            weight_factor=FLAGS.adversarial_loss_weight)

    #######################################################################
    ##########################   GAN OPS   ################################
    #######################################################################
    with tf.name_scope('Train_ops'):
        gen_lr = get_lr(1e-5, decay_steps=5000)
        dis_lr = get_lr(5e-5, decay_steps=5000)
        train_ops = tfgan.gan_train_ops(
            gan_model, gan_loss,
            generator_optimizer=get_optimizer(gen_lr),
            discriminator_optimizer=get_optimizer(dis_lr),
            # summarize_gradients=False,
            # colocate_gradients_with_ops=True,
            # transform_grads_fn=tf.contrib.training.clip_gradient_norms_fn(1e3),
            # aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
        )
        psnr = tf.reduce_mean(
            tf.image.psnr(generated_data, label, max_val=1.0))
        ssim = tf.reduce_mean(
            tf.image.ssim(generated_data, label, max_val=1.0))
        corr = correlation(generated_data, label)
        tf.summary.scalar('PSNR', psnr)
        tf.summary.scalar('SSIM', ssim)
        tf.summary.scalar('Relation', corr)
        tf.summary.scalar('generator_lr', gen_lr)
        # tf.summary.scalar('discriminator_lr', dis_lr)

    #######################################################################
    ##########################   GAN TRAIN   ##############################
    #######################################################################
    train_steps = tfgan.GANTrainSteps(generator_train_steps=1,
                                      discriminator_train_steps=1)
    message = tf.string_join(
        [' Train step: ', tf.as_string(tf.train.get_or_create_global_step()),
         '   PSNR:', tf.as_string(psnr), '   SSIM:', tf.as_string(ssim),
         '   Correlation:', tf.as_string(corr)],
        name='status_message')

    # Raw string: same value as before, without depending on invalid escape
    # sequences (``\G``, ``\M``, ``\l``) being left untouched.
    generator_checkpoint_dir = r'E:\GitHub\MMFI\log\GG12\CNN'
    hooks = [
        tf.train.StopAtStepHook(num_steps=FLAGS.max_iter),
        tf.train.LoggingTensorHook([message],
                                   every_n_iter=FLAGS.log_n_steps),
        get_tfgan_init_fn(generator_checkpoint_dir, 'Generator'),
        # get_tfgan_init_fn(r'E:\GitHub\MMFI\log\G2\pix2pix_D', 'Discriminator'),
    ]

    tfgan.gan_train(train_ops, logdir=traindir,
                    get_hooks_fn=tfgan.get_joint_train_hooks(train_steps),
                    hooks=hooks,
                    save_summaries_steps=FLAGS.save_summaries_steps*2,
                    save_checkpoint_secs=FLAGS.save_interval_secs)