예제 #1
0
파일: train.py 프로젝트: qingnengli/MMFI
def Discriminator(inputs, targets):
  """Train the pix2pix discriminator on the output of the G2 generator.

  Builds the G2 generator under the 'Generator/G2' variable scope, scores
  both its output and the ground-truth label with pix2pix_D (each one
  concatenated with the generator input on the last axis), and runs
  slim.learning.train while optimizing only the variables under the
  'Discriminator' scope.

  The loss is mean(|1 - D(real)|) + mean(|-1 - D(fake)|) plus the
  regularization loss collected by tf.losses.

  Args:
    inputs: pair (fiber_output, fiber_input); only fiber_input is used.
    targets: pair (encoder, label); only label is used.

  Side effects:
    Deletes and recreates the pix2pix_D log directory under the
    module-level logdir, then writes summaries and checkpoints there.
  """
  traindir = os.path.join(logdir, 'G2\\pix2pix_D')
  # Always start from an empty directory: files of an earlier run are removed.
  if tf.gfile.Exists(traindir):
    tf.gfile.DeleteRecursively(traindir)
  tf.gfile.MakeDirs(traindir)

  # The first element of each pair is not needed by this training stage.
  _, fiber_input = inputs
  _, label = targets

  with tf.variable_scope('Generator'):
    with tf.variable_scope('G2'):
      # circle(...) is defined elsewhere; presumably a circular mask - TODO confirm.
      generated_data = pix2pix_G(fiber_input) * circle(FLAGS.input_size, FLAGS.input_size)

  # AUTO_REUSE lets the two pix2pix_D calls share a single set of variables.
  with tf.variable_scope('Discriminator', reuse=tf.AUTO_REUSE):
    discriminator_gen_outputs = pix2pix_D(tf.concat((generated_data, fiber_input), -1))
    discriminator_real_outputs = pix2pix_D(tf.concat((label, fiber_input), -1))

  with tf.name_scope('Train_summary'):
    reshaped_label = get_summary_image(label, FLAGS.grid_size)
    reshaped_fiber_input = get_summary_image(fiber_input, FLAGS.grid_size)
    reshaped_generated_data = get_summary_image(generated_data, FLAGS.grid_size)
    tf.summary.image('Fiber_Label', reshaped_label)
    tf.summary.image('Fiber_Input', reshaped_fiber_input)
    tf.summary.image('Generated_Data', reshaped_generated_data)

  with tf.name_scope('Train_Loss'):
    predict_real = discriminator_real_outputs
    predict_fake = discriminator_gen_outputs
    # L1 distance of the scores to the targets +1 (real) and -1 (generated).
    discrim_real_loss = tf.reduce_mean(tf.abs(1 - predict_real))
    discrim_gen_loss = tf.reduce_mean(tf.abs(-1 - predict_fake))
    discrim_loss = discrim_real_loss + discrim_gen_loss
    total_loss = discrim_loss + tf.losses.get_regularization_loss()
    total_loss = tf.check_numerics(total_loss, 'Loss is inf or nan.')
    tf.summary.scalar('Total_loss', total_loss)
    tf.summary.scalar('discrim_loss', discrim_loss)
    tf.summary.scalar('discrim_real_loss', discrim_real_loss)
    tf.summary.scalar('discrim_gen_loss', discrim_gen_loss)

  with tf.name_scope('Train_OP'):
    tf.summary.scalar('predict_real', tf.reduce_mean(predict_real))
    tf.summary.scalar('predict_fake', tf.reduce_mean(predict_fake))
    # Built once so the logged rate is the very tensor the optimizer consumes.
    discrim_lr = get_lr(FLAGS.discriminator_lr, decay_steps=5000)
    tf.summary.scalar('discrim_lr', discrim_lr)

  # NOTE(review): update_ops is the whole UPDATE_OPS collection, so update ops
  # registered by the generator (if any) also run on every step - confirm
  # that this is intended.
  train_op = slim.learning.create_train_op(
      total_loss,
      get_optimizer(discrim_lr),
      update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS),
      variables_to_train=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           scope='Discriminator'))

  # Raw string: same value as before, without the invalid escape sequences
  # (\G, \M, \l) the previous literal relied on.
  # get_init_fn is defined elsewhere; presumably it restores the
  # 'Generator/G2' variables from this directory - TODO confirm.
  generator_checkpoint_dir = r'E:\GitHub\MMFI\log\G2\pix2pix_G'
  slim.learning.train(train_op, traindir,
                      number_of_steps=FLAGS.max_iter,
                      log_every_n_steps=FLAGS.log_n_steps,
                      init_fn=get_init_fn(generator_checkpoint_dir,
                                          inclusion_scope=['Generator/G2']),
                      save_summaries_secs=FLAGS.save_summaries_secs,
                      save_interval_secs=FLAGS.save_interval_secs)
예제 #2
0
파일: train.py 프로젝트: qingnengli/MMFI
def Generator_2(inputs,targets):
  """Train the G2 generator (pix2pix_G) to map fiber_input to label.

  The generator is built under the 'Generator/G2' variable scope and its
  output is multiplied by circle(FLAGS.input_size, FLAGS.input_size). The
  optimized loss is combine_loss(generated, label) plus the regularization
  loss; only variables under 'Generator/G2' are trained.

  Args:
    inputs: pair (fiber_output, fiber_input); only fiber_input is used.
    targets: pair (encoder, label); only label is used.

  Side effects:
    Deletes and recreates the pix2pix_G log directory under the
    module-level logdir, then writes summaries and checkpoints there.
  """
  train_dir = os.path.join(logdir, 'G2\\pix2pix_G')
  # Wipe the output of any earlier run before training starts.
  if tf.gfile.Exists(train_dir):
    tf.gfile.DeleteRecursively(train_dir)
  tf.gfile.MakeDirs(train_dir)

  _, fiber_input = inputs
  _, label = targets

  with tf.variable_scope('Generator'), tf.variable_scope('G2'):
    raw_output = pix2pix_G(fiber_input)
    generated_data = raw_output * circle(FLAGS.input_size, FLAGS.input_size)

  with tf.name_scope('Train_summary'):
    tagged_tensors = (
        ('Fiber_Input', fiber_input),
        ('Fiber_Label', label),
        ('Generated_Data', generated_data),
    )
    # Build every image grid first, then register the summaries in order.
    grids = [get_summary_image(tensor, FLAGS.grid_size)
             for _, tensor in tagged_tensors]
    for (tag, _), grid in zip(tagged_tensors, grids):
      tf.summary.image(tag, grid)

  with tf.name_scope('g2_loss'):
    generator_loss = combine_loss(generated_data, label, add_summary=True)

  with tf.name_scope('Train_Loss'):
    regularization = tf.losses.get_regularization_loss()
    # check_numerics raises at run time if the summed loss is inf or nan.
    total_loss = tf.check_numerics(generator_loss + regularization,
                                   'Loss is inf or nan.')
    tf.summary.scalar('Regularization_loss', regularization)
    tf.summary.scalar('G2_loss', generator_loss)
    tf.summary.scalar('Total_loss', total_loss)

  learning_rate = get_lr(FLAGS.generator_lr)
  train_op = slim.learning.create_train_op(
      total_loss,
      get_optimizer(learning_rate),
      update_ops=tf.get_collection(tf.GraphKeys.UPDATE_OPS),
      variables_to_train=tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                           scope='Generator/G2'))

  with tf.name_scope('Train_ops'):
    # Image-quality metrics of the generated batch against the label batch.
    scalars = [
        ('PSNR', tf.reduce_mean(tf.image.psnr(generated_data, label, max_val=1.0))),
        ('SSIM', tf.reduce_mean(tf.image.ssim(generated_data, label, max_val=1.0))),
        ('Relation', correlation(generated_data, label)),
        ('Learning_rate', learning_rate),
    ]
    for tag, value in scalars:
      tf.summary.scalar(tag, value)

  # No init_fn is passed, so nothing is restored from an earlier G2 checkpoint.
  slim.learning.train(
      train_op,
      train_dir,
      number_of_steps=FLAGS.max_iter,
      log_every_n_steps=FLAGS.log_n_steps,
      save_summaries_secs=FLAGS.save_summaries_secs,
      save_interval_secs=FLAGS.save_interval_secs)
예제 #3
0
def main(_):
  """Evaluate the chained G1 -> G2 generators on the validation list.

  Reads inputs from 'valid.txt', feeds fiber_output through the G1
  generator and its output through the G2 generator (both built with
  is_training=False), records image and PSNR / SSIM / correlation
  summaries, and writes a side-by-side PNG into the 'eval' subdirectory
  of the log directory.

  Args:
    _: unused positional argument.

  Raises:
    ValueError: if the log directory contains no checkpoint to evaluate.
  """
  tf.logging.set_verbosity(tf.logging.INFO)
  with tf.Graph().as_default():
    # Raw string: same value as before, without the invalid escape sequences
    # (\G, \M, \l) the previous literal relied on.
    logdir = r'E:\GitHub\MMFI\log\GG12\CNN'
    evaldir = os.path.join(logdir, 'eval')
    # An existing eval directory is kept as it is.
    if not tf.gfile.Exists(evaldir):
      tf.gfile.MakeDirs(evaldir)

    with tf.name_scope('inputs'):
      fiber_output, fiber_input, _encoder, label = data_loader.read_inputs('valid.txt', False)

    with tf.variable_scope('Generator'):
      with tf.variable_scope('G1'):
        generated_input = (pix2pix_G(fiber_output, is_training=False)
                           * circle(FLAGS.input_size, FLAGS.input_size))
      with tf.variable_scope('G2'):
        generated_data = (pix2pix_G(generated_input, is_training=False)
                          * circle(FLAGS.input_size, FLAGS.input_size))

    with tf.name_scope('Valid_summary'):
      reshaped_fiber_input = get_summary_image(fiber_input, FLAGS.grid_size)
      reshaped_label = get_summary_image(label, FLAGS.grid_size)
      reshaped_generated_input = get_summary_image(generated_input, FLAGS.grid_size)
      reshaped_generated_data = get_summary_image(generated_data, FLAGS.grid_size)
      tf.summary.image('Input_Fiber', reshaped_fiber_input)
      tf.summary.image('Input_Generator', reshaped_generated_input)
      tf.summary.image('Data_Real', reshaped_label)
      tf.summary.image('Data_Generator', reshaped_generated_data)

    with tf.name_scope('Valid_op'):
      psnr = tf.reduce_mean(tf.image.psnr(generated_data, label, max_val=1.0))
      ssim = tf.reduce_mean(tf.image.ssim(generated_data, label, max_val=1.0))
      corr = correlation(generated_data, label)

      tf.summary.scalar('PSNR', psnr)
      tf.summary.scalar('SSIM', ssim)
      tf.summary.scalar('Relation', corr)

      # A 10-column strip of ones placed after each image grid as a separator.
      grate = tf.ones([1, FLAGS.grid_size * FLAGS.input_size, 10, 1], dtype=tf.float32)
      reshaped_images = tf.concat((reshaped_generated_input, grate,
                                   reshaped_fiber_input, grate,
                                   reshaped_label, grate,
                                   reshaped_generated_data, grate), 2)
      # NOTE(review): assumes the image values lie in [0, 1]; values outside
      # that range are not clipped before the uint8 cast - confirm.
      uint8_images = tf.cast(reshaped_images * 255, tf.uint8)
      image_write_ops = tf.write_file('%s/%s' % (evaldir, 'Generator_is_training_False.png'),
                                      tf.image.encode_png(uint8_images[0]))

      status_message = tf.string_join([' PSNR: ', tf.as_string(psnr), ' ',
                                       ' SSIM: ', tf.as_string(ssim), ' ',
                                       ' Correlation: ', tf.as_string(corr)],
                                      name='status_message')

    checkpoint_path = tf.train.latest_checkpoint(logdir)
    # latest_checkpoint returns None when nothing has been saved; fail loudly
    # rather than evaluating a model with no restored weights.
    if checkpoint_path is None:
      raise ValueError('No checkpoint found in %s' % logdir)
    tf.logging.info('Evaluating %s' % checkpoint_path)

    tf.contrib.training.evaluate_once(
        checkpoint_path,
        hooks=[tf.contrib.training.SummaryAtEndHook(evaldir),
               tf.contrib.training.StopAfterNEvalsHook(50),
               tf.train.LoggingTensorHook([status_message], every_n_iter=5)],
        eval_ops=image_write_ops)