def train(model, **kwargs):
  """Trains progressive GAN for stage `stage_id`.

  Args:
    model: A model object holding all information about the progressive GAN
      model, e.g. the return value of build_model().
    **kwargs: A dictionary of
      'train_log_dir': Root directory for training logs.
      'master': Name of the TensorFlow master to use.
      'task': The Task ID. This value is used when training with multiple
        workers to identify each worker.
      'save_summaries_num_images': Save summaries after this number of images.

  Returns:
    None.
  """
  logging.info('stage_id=%d, num_blocks=%d, num_images=%d', model.stage_id,
               model.num_blocks, model.num_images)
  scaffold = make_scaffold(model.stage_id, model.optimizer_var_list, **kwargs)
  tfgan.gan_train(
      model.gan_train_ops,
      logdir=make_train_sub_dir(model.stage_id, **kwargs),
      get_hooks_fn=tfgan.get_sequential_train_hooks(tfgan.GANTrainSteps(1, 1)),
      hooks=[
          tf.estimator.StopAtStepHook(last_step=model.num_images),
          tf.estimator.LoggingTensorHook([make_status_message(model)],
                                         every_n_iter=10)
      ],
      master=kwargs['master'],
      is_chief=(kwargs['task'] == 0),
      scaffold=scaffold,
      save_checkpoint_secs=600,
      save_summaries_steps=(kwargs['save_summaries_num_images']))
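# Illustrative only (not part of the original source): a minimal sketch of how
# the train() above might be driven for a single stage. `build_model` is the
# helper the docstring refers to; the kwargs keys mirror the docstring, and
# the directory, master, and config values here are hypothetical.
def _example_train_stage_one(config):
  model = build_model(stage_id=1, **config)  # The object train() expects.
  train(
      model,
      train_log_dir='/tmp/progressive_gan',  # Hypothetical log root.
      master='',  # Empty string => local, non-distributed training.
      task=0,  # Task 0 acts as chief.
      save_summaries_num_images=100)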
def train(hparams):
  """Trains a StarGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
  # Create the log_dir if it does not exist.
  if not tf.io.gfile.exists(hparams.train_log_dir):
    tf.io.gfile.makedirs(hparams.train_log_dir)

  # Shard the model to different parameter servers.
  with tf.device(tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):

    # Create the input dataset.
    with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
      images, labels = data_provider.provide_data('train', hparams.batch_size,
                                                  hparams.patch_size)

    # Define the model.
    with tf.compat.v1.name_scope('model'):
      model = _define_model(images, labels)

    # Add image summary.
    tfgan.eval.add_stargan_image_summaries(
        model, num_images=3 * hparams.batch_size, display_diffs=True)

    # Define the model loss.
    loss = tfgan.stargan_loss(model)

    # Define the train ops.
    with tf.compat.v1.name_scope('train_ops'):
      train_ops = _define_train_ops(model, loss, hparams.generator_lr,
                                    hparams.discriminator_lr,
                                    hparams.adam_beta1, hparams.adam_beta2,
                                    hparams.max_number_of_steps)

    # Define the train steps.
    train_steps = _define_train_step(hparams.gen_disc_step_ratio)

    # Define a status message.
    status_message = tf.strings.join(
        [
            'Starting train step: ',
            tf.as_string(tf.compat.v1.train.get_or_create_global_step())
        ],
        name='status_message')

    # Train the model.
    tfgan.gan_train(
        train_ops,
        hparams.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        master=hparams.tf_master,
        is_chief=hparams.task == 0)
def train(hparams):
  """Trains a CycleGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
  if not tf.io.gfile.exists(hparams.train_log_dir):
    tf.io.gfile.makedirs(hparams.train_log_dir)

  with tf.device(tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):
    with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
      images_x, images_y = _get_data(hparams.image_set_x_file_pattern,
                                     hparams.image_set_y_file_pattern,
                                     hparams.batch_size, hparams.patch_size)

    # Define CycleGAN model.
    cyclegan_model = _define_model(images_x, images_y)

    # Define CycleGAN loss.
    cyclegan_loss = tfgan.cyclegan_loss(
        cyclegan_model,
        cycle_consistency_loss_weight=hparams.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)

    # Define CycleGAN train ops.
    train_ops = _define_train_ops(cyclegan_model, cyclegan_loss, hparams)

    # Training
    train_steps = tfgan.GANTrainSteps(1, 1)
    status_message = tf.strings.join(
        [
            'Starting train step: ',
            tf.as_string(tf.compat.v1.train.get_or_create_global_step())
        ],
        name='status_message')
    if not hparams.max_number_of_steps:
      return
    tfgan.gan_train(
        train_ops,
        hparams.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook({'status_message': status_message},
                                           every_n_iter=10)
        ],
        master=hparams.master,
        is_chief=hparams.task == 0)
def main(_):
  if not tf.gfile.Exists(FLAGS.train_log_dir):
    tf.gfile.MakeDirs(FLAGS.train_log_dir)

  with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
    with tf.name_scope('inputs'):
      initializer_hook = load_op(FLAGS.batch_size, FLAGS.max_number_of_steps)
      training_input_iter = initializer_hook.input_itr
      images_x, images_y = training_input_iter.get_next()
      # Set batch size for summaries.
      # images_x.set_shape([FLAGS.batch_size, None, None, None])
      # images_y.set_shape([FLAGS.batch_size, None, None, None])

    # Define CycleGAN model.
    cyclegan_model = _define_model(images_x, images_y)

    # Define CycleGAN loss.
    cyclegan_loss = tfgan.cyclegan_loss(
        cyclegan_model,
        cycle_consistency_loss_weight=FLAGS.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)

    # Define CycleGAN train ops.
    train_ops = _define_train_ops(cyclegan_model, cyclegan_loss)

    # Training
    train_steps = tfgan.GANTrainSteps(1, 1)
    status_message = tf.string_join(
        [
            'Starting train step: ',
            tf.as_string(tf.train.get_or_create_global_step())
        ],
        name='status_message')
    if not FLAGS.max_number_of_steps:
      return
    tfgan.gan_train(
        train_ops,
        FLAGS.train_log_dir,
        save_checkpoint_secs=60 * 10,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            initializer_hook,
            tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
            tf.train.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        master=FLAGS.master,
        is_chief=FLAGS.task == 0)
def train(model, **kwargs):
  """Trains progressive GAN for stage `stage_id`.

  Args:
    model: A model object holding all information about the progressive GAN
      model, e.g. the return value of build_model().
    **kwargs: A dictionary of
      'train_root_dir': Root directory for training logs.
      'master': Name of the TensorFlow master to use.
      'task': The Task ID. This value is used when training with multiple
        workers to identify each worker.
      'save_summaries_num_images': Save summaries after this number of images.
      'debug_hook': Whether to attach the debug hook to the training session.

  Returns:
    None.
  """
  logging.info('stage_id=%d, num_blocks=%d, num_images=%d', model.stage_id,
               model.num_blocks, model.num_images)
  scaffold = make_scaffold(model.stage_id, model.optimizer_var_list, **kwargs)
  logdir = make_train_sub_dir(model.stage_id, **kwargs)
  print('starting training, logdir: {}'.format(logdir))
  hooks = []
  if model.stage_train_time_limit is None:
    hooks.append(tf.train.StopAtStepHook(last_step=model.num_images))
  hooks.append(
      tf.train.LoggingTensorHook([make_status_message(model)], every_n_iter=1))
  hooks.append(TrainTimeHook(model.train_time, model.stage_train_time_limit))
  if kwargs['debug_hook']:
    hooks.append(ProganDebugHook())
  tfgan.gan_train(
      model.gan_train_ops,
      logdir=logdir,
      get_hooks_fn=tfgan.get_sequential_train_hooks(tfgan.GANTrainSteps(1, 1)),
      hooks=hooks,
      master=kwargs['master'],
      is_chief=(kwargs['task'] == 0),
      scaffold=scaffold,
      save_checkpoint_secs=600,
      save_summaries_steps=(kwargs['save_summaries_num_images']))
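# TrainTimeHook is used above but not defined in this excerpt. The sketch
# below is an assumption inferred from how it is called, not the original
# implementation: it accumulates wall-clock time into the `train_time`
# variable and requests a stop once the (optional) stage time limit, in
# seconds, is exceeded. It assumes the TF1-style API used elsewhere here.
import time


class TrainTimeHook(tf.train.SessionRunHook):
  """Accumulates train time and stops training once a limit is reached."""

  def __init__(self, train_time, time_limit=None):
    self._train_time = train_time  # tf.Variable of cumulative seconds.
    self._time_limit = time_limit  # Seconds, or None for no limit.
    self._run_start = None

  def begin(self):
    # Build ops here, while the graph is still mutable.
    self._elapsed_ph = tf.placeholder(self._train_time.dtype, [])
    self._assign_add = self._train_time.assign_add(self._elapsed_ph)

  def before_run(self, run_context):
    self._run_start = time.time()

  def after_run(self, run_context, run_values):
    elapsed = time.time() - self._run_start
    total = run_context.session.run(
        self._assign_add, feed_dict={self._elapsed_ph: elapsed})
    if self._time_limit is not None and total >= self._time_limit:
      run_context.request_stop()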
def test_multiple_steps(self, get_hooks_fn_fn):
  """Test multiple train steps."""
  if tf.executing_eagerly():
    # None of the usual utilities work in eager.
    return
  train_ops = self._gan_train_ops(generator_add=10, discriminator_add=100)
  train_steps = tfgan.GANTrainSteps(
      generator_train_steps=3, discriminator_train_steps=4)
  final_step = tfgan.gan_train(
      train_ops,
      get_hooks_fn=get_hooks_fn_fn(train_steps),
      logdir='',
      hooks=[tf.estimator.StopAtStepHook(num_steps=1)])
  self.assertTrue(np.isscalar(final_step))
  # One increment from the global-step op, plus 3 generator steps adding 10
  # each and 4 discriminator steps adding 100 each.
  self.assertEqual(1 + 3 * 10 + 4 * 100, final_step)
def test_run_helper(self, create_gan_model_fn):
  if tf.executing_eagerly():
    # None of the usual utilities work in eager.
    return
  tf.compat.v1.random.set_random_seed(1234)
  model = create_gan_model_fn()
  loss = tfgan.gan_loss(model)

  g_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
  d_opt = tf.compat.v1.train.GradientDescentOptimizer(1.0)
  train_ops = tfgan.gan_train_ops(model, loss, g_opt, d_opt)

  final_step = tfgan.gan_train(
      train_ops, logdir='', hooks=[tf.estimator.StopAtStepHook(num_steps=2)])
  self.assertTrue(np.isscalar(final_step))
  self.assertEqual(2, final_step)
def train(hparams):
  """Trains an MNIST GAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
  if not tf.io.gfile.exists(hparams.train_log_dir):
    tf.io.gfile.makedirs(hparams.train_log_dir)

  # Force all input processing onto CPU in order to reserve the GPU for
  # the forward inference and back-propagation.
  with tf.name_scope('inputs'), tf.device('/cpu:0'):
    images, one_hot_labels = data_provider.provide_data(
        'train', hparams.batch_size, num_parallel_calls=4)

  # Define the GANModel tuple. Optionally, condition the GAN on the label or
  # use an InfoGAN to learn a latent representation.
  if hparams.gan_type == 'unconditional':
    gan_model = tfgan.gan_model(
        generator_fn=networks.unconditional_generator,
        discriminator_fn=networks.unconditional_discriminator,
        real_data=images,
        generator_inputs=tf.random.normal(
            [hparams.batch_size, hparams.noise_dims]))
  elif hparams.gan_type == 'conditional':
    noise = tf.random.normal([hparams.batch_size, hparams.noise_dims])
    gan_model = tfgan.gan_model(
        generator_fn=networks.conditional_generator,
        discriminator_fn=networks.conditional_discriminator,
        real_data=images,
        generator_inputs=(noise, one_hot_labels))
  elif hparams.gan_type == 'infogan':
    cat_dim, cont_dim = 10, 2
    generator_fn = functools.partial(
        networks.infogan_generator, categorical_dim=cat_dim)
    discriminator_fn = functools.partial(
        networks.infogan_discriminator,
        categorical_dim=cat_dim,
        continuous_dim=cont_dim)
    unstructured_inputs, structured_inputs = util.get_infogan_noise(
        hparams.batch_size, cat_dim, cont_dim, hparams.noise_dims)
    gan_model = tfgan.infogan_model(
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        real_data=images,
        unstructured_generator_inputs=unstructured_inputs,
        structured_generator_inputs=structured_inputs)
  tfgan.eval.add_gan_model_image_summaries(gan_model, hparams.grid_size)

  # Get the GANLoss tuple. You can pass a custom function, use one of the
  # already-implemented losses from the losses library, or use the defaults.
  with tf.name_scope('loss'):
    if hparams.gan_type == 'infogan':
      gan_loss = tfgan.gan_loss(
          gan_model,
          generator_loss_fn=tfgan.losses.modified_generator_loss,
          discriminator_loss_fn=tfgan.losses.modified_discriminator_loss,
          mutual_information_penalty_weight=1.0,
          add_summaries=True)
    else:
      gan_loss = tfgan.gan_loss(gan_model, add_summaries=True)
    tfgan.eval.add_regularization_loss_summaries(gan_model)

  # Get the GANTrain ops using custom optimizers.
  with tf.name_scope('train'):
    gen_lr, dis_lr = _learning_rate(hparams.gan_type)
    train_ops = tfgan.gan_train_ops(
        gan_model,
        gan_loss,
        generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
        discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5),
        summarize_gradients=True,
        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)

  # Run the alternating training loop. Skip it if no steps should be taken
  # (used for graph construction tests).
  status_message = tf.strings.join(
      [
          'Starting train step: ',
          tf.as_string(tf.train.get_or_create_global_step())
      ],
      name='status_message')
  if hparams.max_number_of_steps == 0:
    return
  tfgan.gan_train(
      train_ops,
      hooks=[
          tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
          tf.estimator.LoggingTensorHook([status_message], every_n_iter=10)
      ],
      logdir=hparams.train_log_dir,
      get_hooks_fn=tfgan.get_joint_train_hooks(),
      save_checkpoint_secs=60)
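# Illustrative only: the train() above reads a handful of attributes from
# `hparams`, and max_number_of_steps=0 makes it return right after graph
# construction, which is handy for a smoke test. The namedtuple below is a
# hypothetical stand-in listing just the fields the unconditional code path
# touches; the real example builds its hparams from command-line flags.
import collections

_ExampleHParams = collections.namedtuple('_ExampleHParams', [
    'batch_size', 'train_log_dir', 'max_number_of_steps', 'gan_type',
    'grid_size', 'noise_dims'
])


def _example_smoke_test():
  hparams = _ExampleHParams(
      batch_size=32,
      train_log_dir='/tmp/mnist_gan',  # Hypothetical path.
      max_number_of_steps=0,  # 0 => build the graph, skip the training loop.
      gan_type='unconditional',
      grid_size=5,
      noise_dims=64)
  train(hparams)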
def train(hparams):
  """Trains a CycleGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
  if not tf.io.gfile.exists(hparams.train_log_dir):
    tf.io.gfile.makedirs(hparams.train_log_dir)

  with open(hparams.train_log_dir + 'train_result.json', 'w') as fp:
    json.dump(hparams._asdict(), fp, indent=4)

  with tf.device(tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):
    with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
      images_x, images_y = _get_data(hparams.image_set_x_file_pattern,
                                     hparams.image_set_y_file_pattern,
                                     hparams.batch_size, hparams.patch_size,
                                     hparams.tfdata_source)

    # Define CycleGAN model.
    cyclegan_model = _define_model(images_x, images_y)

    # Define CycleGAN loss.
    cyclegan_loss = tfgan.cyclegan_loss(
        cyclegan_model,
        cycle_consistency_loss_weight=hparams.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)

    # Define CycleGAN train ops.
    train_ops = _define_train_ops(cyclegan_model, cyclegan_loss, hparams)

    # Training
    train_steps = tfgan.GANTrainSteps(1, 1)
    status_message = tf.strings.join(
        [
            'Starting train step: ',
            tf.as_string(tf.compat.v1.train.get_or_create_global_step())
        ],
        name='status_message')
    if not hparams.max_number_of_steps:
      return

    additional_params = {}
    if hparams.save_checkpoint_steps:
      # Keep enough checkpoints that none written during the run is rotated out.
      max_to_keep = (
          hparams.max_number_of_steps // hparams.save_checkpoint_steps + 1)
      additional_params = {
          'scaffold': tf.train.Scaffold(
              saver=tf.train.Saver(max_to_keep=max_to_keep)),
          'save_checkpoint_secs': None,  # Disable time-based checkpointing.
          'save_checkpoint_steps': hparams.save_checkpoint_steps,
      }
    tfgan.gan_train(
        train_ops,
        hparams.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook({'status_message': status_message},
                                           every_n_iter=10)
        ],
        master=hparams.master,
        is_chief=hparams.task == 0,
        **additional_params,
    )
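# Worked example of the checkpoint-retention arithmetic above (hypothetical
# values): with max_number_of_steps=500000 and save_checkpoint_steps=50000,
# max_to_keep = 500000 // 50000 + 1 = 11, so the ten periodic checkpoints
# plus the final one all survive the Saver's rotation.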
def train(hparams):
  """Trains a CIFAR10 GAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
  if not tf.io.gfile.exists(hparams.train_log_dir):
    tf.io.gfile.makedirs(hparams.train_log_dir)

  with tf.device(tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):
    # Force all input processing onto CPU in order to reserve the GPU for
    # the forward inference and back-propagation.
    with tf.compat.v1.name_scope('inputs'):
      with tf.device('/cpu:0'):
        images, _ = data_provider.provide_data(
            'train', hparams.batch_size, num_parallel_calls=4)

    # Define the GANModel tuple.
    generator_fn = networks.generator
    discriminator_fn = networks.discriminator
    generator_inputs = tf.random.normal([hparams.batch_size, 64])
    gan_model = tfgan.gan_model(
        generator_fn,
        discriminator_fn,
        real_data=images,
        generator_inputs=generator_inputs)
    tfgan.eval.add_gan_model_image_summaries(gan_model)

    # Get the GANLoss tuple. Use the selected GAN loss functions.
    with tf.compat.v1.name_scope('loss'):
      gan_loss = tfgan.gan_loss(
          gan_model, gradient_penalty_weight=1.0, add_summaries=True)

    # Get the GANTrain ops using the custom optimizers and optional
    # discriminator weight clipping.
    with tf.compat.v1.name_scope('train'):
      gen_opt, dis_opt = _get_optimizers(hparams)
      train_ops = tfgan.gan_train_ops(
          gan_model,
          gan_loss,
          generator_optimizer=gen_opt,
          discriminator_optimizer=dis_opt,
          summarize_gradients=True)

    # Run the alternating training loop. Skip it if no steps should be taken
    # (used for graph construction tests).
    status_message = tf.strings.join(
        [
            'Starting train step: ',
            tf.as_string(tf.compat.v1.train.get_or_create_global_step())
        ],
        name='status_message')
    if hparams.max_number_of_steps == 0:
      return
    tfgan.gan_train(
        train_ops,
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        logdir=hparams.train_log_dir,
        master=hparams.master,
        is_chief=hparams.task == 0)