def main(_):
  """Build the CycleGAN training graph from FLAGS and optionally train it.

  Steps:
  1. Creates FLAGS.train_log_dir if it does not exist yet.
  2. Under `tf.train.replica_device_setter(FLAGS.ps_tasks)`, builds the input
     pipeline, the CycleGAN model, its loss, and its train ops.
  3. If FLAGS.max_number_of_steps is falsy (e.g. 0), returns after building
     the graph without training.
  4. Otherwise runs `tfgan.gan_train` until that step count, logging the
     global step every 10 iterations.

  Args:
    _: Unused positional argument (presumably argv from the app runner).
  """
  log_dir = FLAGS.train_log_dir
  if not tf.gfile.Exists(log_dir):
    tf.gfile.MakeDirs(log_dir)

  with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
    with tf.name_scope('inputs'):
      file_patterns = [
          FLAGS.image_set_x_file_pattern, FLAGS.image_set_y_file_pattern
      ]
      images_x, images_y = data_provider.provide_custom_data(
          file_patterns,
          batch_size=FLAGS.batch_size,
          patch_size=FLAGS.patch_size)
      # Pin the static batch dimension so summaries see a known batch size.
      for image_batch in (images_x, images_y):
        image_batch.set_shape([FLAGS.batch_size, None, None, None])

    model = _define_model(images_x, images_y)
    loss = tfgan.cyclegan_loss(
        model,
        cycle_consistency_loss_weight=FLAGS.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)
    train_ops = _define_train_ops(model, loss)

    # Alternation schedule handed to get_sequential_train_hooks below.
    train_steps = tfgan.GANTrainSteps(1, 1)
    global_step = tf.train.get_or_create_global_step()
    status_message = tf.string_join(
        ['Starting train step: ', tf.as_string(global_step)],
        name='status_message')

    # A falsy step budget means: build the graph only, do not train.
    if not FLAGS.max_number_of_steps:
      return

    hooks_fn = tfgan.get_sequential_train_hooks(train_steps)
    training_hooks = [
        tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
        tf.train.LoggingTensorHook([status_message], every_n_iter=10),
    ]
    tfgan.gan_train(
        train_ops,
        log_dir,
        get_hooks_fn=hooks_fn,
        hooks=training_hooks,
        master=FLAGS.master,
        is_chief=(FLAGS.task == 0))
def main(_):
  """Entry point: assemble CycleGAN training from FLAGS and run it if asked.

  The training log directory is created when missing. The input pipeline,
  model, loss, and train ops are all built inside the device scope given by
  `tf.train.replica_device_setter(FLAGS.ps_tasks)`. Training through
  `tfgan.gan_train` only happens when FLAGS.max_number_of_steps is truthy.
  Otherwise the function returns after graph construction.

  Args:
    _: Ignored positional argument (presumably argv from the app runner).
  """
  if not tf.gfile.Exists(FLAGS.train_log_dir):
    tf.gfile.MakeDirs(FLAGS.train_log_dir)

  device_fn = tf.train.replica_device_setter(FLAGS.ps_tasks)
  with tf.device(device_fn):
    with tf.name_scope('inputs'):
      batch_size = FLAGS.batch_size
      images_x, images_y = data_provider.provide_custom_data(
          [FLAGS.image_set_x_file_pattern, FLAGS.image_set_y_file_pattern],
          batch_size=batch_size,
          patch_size=FLAGS.patch_size)
      # Both batches get a static batch size (needed by the summaries);
      # spatial dims and channels stay unknown.
      static_shape = [batch_size, None, None, None]
      images_x.set_shape(static_shape)
      images_y.set_shape(static_shape)

    gan_model = _define_model(images_x, images_y)
    gan_loss = tfgan.cyclegan_loss(
        gan_model,
        cycle_consistency_loss_weight=FLAGS.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)
    train_ops = _define_train_ops(gan_model, gan_loss)

    train_steps = tfgan.GANTrainSteps(1, 1)
    step_text = tf.as_string(tf.train.get_or_create_global_step())
    status_message = tf.string_join(
        ['Starting train step: ', step_text], name='status_message')

    if FLAGS.max_number_of_steps:
      tfgan.gan_train(
          train_ops,
          FLAGS.train_log_dir,
          get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
          hooks=[
              tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
              tf.train.LoggingTensorHook([status_message], every_n_iter=10),
          ],
          master=FLAGS.master,
          is_chief=(FLAGS.task == 0))
def test_custom_data_provider(self):
  """provide_custom_data yields float32 patches with the expected shapes.

  Checks three things for each returned tensor:
  - the static shape is [None, patch, patch, 3] and the dtype is float32;
  - the evaluated array has the full (batch, patch, patch, 3) shape;
  - every value has magnitude at most 1.0.
  """
  pattern = os.path.join(self.testdata_dir, '*.jpg')
  batch_size, patch_size = 3, 8
  images_list = data_provider.provide_custom_data(
      [pattern, pattern], batch_size=batch_size, patch_size=patch_size)

  expected_static_shape = [None, patch_size, patch_size, 3]
  for tensor in images_list:
    self.assertListEqual(expected_static_shape, tensor.shape.as_list())
    self.assertEqual(tf.float32, tensor.dtype)

  expected_runtime_shape = (batch_size, patch_size, patch_size, 3)
  with self.test_session(use_gpu=True) as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(tf.tables_initializer())
    evaluated = sess.run(images_list)
    for array in evaluated:
      self.assertTupleEqual(expected_runtime_shape, array.shape)
      self.assertTrue(np.all(np.abs(array) <= 1.0))
def provide_data(image_file_patterns, batch_size, patch_size):
  """Wrap the gan/cyclegan data_provider and attach per-domain labels.

  Args:
    image_file_patterns: A list of file pattern globs, one per domain.
    batch_size: Python int. Batch size.
    patch_size: Python int. The patch size to extract.

  Returns:
    A tuple `(images, labels)`:
    - images: the list returned by `data_provider.provide_custom_data`,
      expected to hold one `Tensor` of shape (N, H, W, C) per pattern.
    - labels: a list of the same length. Entry `i` is a one-hot `Tensor` of
      shape (batch_size, num_domains) in which every row encodes domain `i`.
  """
  images = data_provider.provide_custom_data(
      image_file_patterns, batch_size=batch_size, patch_size=patch_size)
  num_domains = len(images)
  labels = []
  for domain_index in range(num_domains):
    labels.append(tf.one_hot([domain_index] * batch_size, num_domains))
  return images, labels
def provide_data(image_file_patterns, batch_size, patch_size):
  """Data provider wrapper for the data_provider in gan/cyclegan.

  Args:
    image_file_patterns: A list of file pattern globs.
    batch_size: Python int. Batch size.
    patch_size: Python int. The patch size to extract.

  Returns:
    images: List of `Tensor` of shape (N, H, W, C) representing the images,
      as returned by `data_provider.provide_custom_data`.
    labels: List of `Tensor` of shape (batch_size, num_domains). The i-th
      entry one-hot encodes domain index i for every element of the batch.
  """
  images = data_provider.provide_custom_data(
      image_file_patterns, batch_size=batch_size, patch_size=patch_size)
  depth = len(images)

  labels = []
  domain_index = 0
  while domain_index < depth:
    domain_ids = [domain_index] * batch_size
    labels.append(tf.one_hot(domain_ids, depth))
    domain_index += 1
  return images, labels
def main(_):
  """Build the CycleGAN training graph from `gConfig` and optionally train.

  All settings come from the module-level `gConfig` mapping. The keys read
  here are:
  - paths and inputs: 'train_log_dir', 'image_set_x_file_pattern',
    'image_set_y_file_pattern', 'batch_size', 'patch_size';
  - placement and cluster: 'ps_tasks', 'master', 'task';
  - training: 'cycle_consistency_loss_weight', 'max_number_of_steps'.

  Behavior:
  - The training log directory is created when missing.
  - The input pipeline, CycleGAN model, loss, and train ops are built under
    `tf.train.replica_device_setter(gConfig['ps_tasks'])`.
  - If gConfig['max_number_of_steps'] is falsy (e.g. 0), the function returns
    after building the graph.
  - Otherwise it calls `tfgan.gan_train` with a StopAtStepHook at that step
    count and a LoggingTensorHook that logs the global step every 10
    iterations.

  Args:
    _: Unused positional argument (presumably argv from the app runner).
  """
  log_dir = gConfig['train_log_dir']
  # Defensive programming: make sure the log directory exists before it is
  # handed to tfgan.gan_train below; create it if it is missing.
  if not tf.gfile.Exists(log_dir):
    tf.gfile.MakeDirs(log_dir)

  with tf.device(tf.train.replica_device_setter(gConfig['ps_tasks'])):
    with tf.name_scope('inputs'):
      images_x, images_y = data_provider.provide_custom_data(
          [
              gConfig['image_set_x_file_pattern'],
              gConfig['image_set_y_file_pattern']
          ],
          batch_size=gConfig['batch_size'],
          patch_size=gConfig['patch_size'])
      # set_shape pins the static batch dimension of each image batch to
      # gConfig['batch_size'] so summaries see a known batch size; height,
      # width and channels are left unconstrained (None).
      images_x.set_shape([gConfig['batch_size'], None, None, None])
      images_y.set_shape([gConfig['batch_size'], None, None, None])

    # Define the CycleGAN model from the two image batches (see _define_model).
    cyclegan_model = _define_model(images_x, images_y)

    # Define the CycleGAN loss. tfgan.features.tensor_pool is supplied as the
    # tensor pool function (presumably a pool of previously generated images
    # fed to the discriminators -- confirm against the TFGAN docs).
    cyclegan_loss = tfgan.cyclegan_loss(
        cyclegan_model,
        cycle_consistency_loss_weight=gConfig['cycle_consistency_loss_weight'],
        tensor_pool_fn=tfgan.features.tensor_pool)

    # Define the CycleGAN train ops.
    train_ops = _define_train_ops(cyclegan_model, cyclegan_loss)

    # Training.
    # GANTrainSteps(1, 1) is the alternation schedule consumed by
    # tfgan.get_sequential_train_hooks below (presumably one generator step
    # and one discriminator step per iteration).
    train_steps = tfgan.GANTrainSteps(1, 1)
    # String tensor with the current global step, logged by the
    # LoggingTensorHook every 10 iterations.
    status_message = tf.string_join([
        'Starting train step: ',
        tf.as_string(tf.train.get_or_create_global_step())
    ], name='status_message')

    # A falsy step budget means: build the graph only, do not train.
    if not gConfig['max_number_of_steps']:
      return

    # StopAtStepHook ends training at gConfig['max_number_of_steps'];
    # only task 0 acts as chief.
    tfgan.gan_train(
        train_ops,
        log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.train.StopAtStepHook(
                num_steps=gConfig['max_number_of_steps']),
            tf.train.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        master=gConfig['master'],
        is_chief=gConfig['task'] == 0)