def test_supervisor_run_gan_model_train_ops_multiple_steps(self):
  """Test that the train ops work with the old-style supervisor."""
  if tf.executing_eagerly():
    # None of the usual utilities work in eager.
    return
  step = tf.compat.v1.train.create_global_step()
  train_ops = tfgan.GANTrainOps(
      generator_train_op=tf.constant(3.0),
      discriminator_train_op=tf.constant(2.0),
      global_step_inc_op=step.assign_add(1))
  train_steps = tfgan.GANTrainSteps(
      generator_train_steps=3, discriminator_train_steps=4)

  number_of_steps = 1

  # Typical simple Supervisor use.
  train_step_kwargs = {}
  train_step_kwargs['should_stop'] = tf.greater_equal(step, number_of_steps)
  train_step_fn = tfgan.get_sequential_train_steps(train_steps)
  sv = tf.compat.v1.train.Supervisor(logdir='', global_step=step)
  with sv.managed_session(master='') as sess:
    while not sv.should_stop():
      total_loss, should_stop = train_step_fn(sess, train_ops, step,
                                              train_step_kwargs)
      if should_stop:
        sv.request_stop()
        break

  # Correctness checks: 3 generator steps * 3.0 + 4 discriminator steps * 2.0.
  self.assertTrue(np.isscalar(total_loss))
  self.assertEqual(17.0, total_loss)
def train(model, **kwargs):
  """Trains progressive GAN for stage `stage_id`.

  Args:
    model: A model object containing all information of the progressive GAN
      model, e.g. the return value of build_model().
    **kwargs: A dictionary of
      'train_log_dir': A string of the root directory of training logs.
      'master': Name of the TensorFlow master to use.
      'task': The Task ID. This value is used when training with multiple
        workers to identify each worker.
      'save_summaries_num_images': Save summaries in this number of images.

  Returns:
    None.
  """
  logging.info('stage_id=%d, num_blocks=%d, num_images=%d', model.stage_id,
               model.num_blocks, model.num_images)

  scaffold = make_scaffold(model.stage_id, model.optimizer_var_list, **kwargs)

  tfgan.gan_train(
      model.gan_train_ops,
      logdir=make_train_sub_dir(model.stage_id, **kwargs),
      get_hooks_fn=tfgan.get_sequential_train_hooks(tfgan.GANTrainSteps(1, 1)),
      hooks=[
          tf.estimator.StopAtStepHook(last_step=model.num_images),
          tf.estimator.LoggingTensorHook([make_status_message(model)],
                                         every_n_iter=10)
      ],
      master=kwargs['master'],
      is_chief=(kwargs['task'] == 0),
      scaffold=scaffold,
      save_checkpoint_secs=600,
      save_summaries_steps=(kwargs['save_summaries_num_images']))
def _define_train_step(gen_disc_step_ratio):
  """Get the training step for generator and discriminator for each GAN step.

  Args:
    gen_disc_step_ratio: A python number. The ratio of generator to
      discriminator training steps.

  Returns:
    GANTrainSteps namedtuple representing the training step configuration.
  """
  if gen_disc_step_ratio <= 1:
    discriminator_step = int(1 / gen_disc_step_ratio)
    return tfgan.GANTrainSteps(1, discriminator_step)
  else:
    generator_step = int(gen_disc_step_ratio)
    return tfgan.GANTrainSteps(generator_step, 1)
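# A minimal usage sketch (an illustrative addition, not part of the original
# code), assuming `tfgan` is imported as in the surrounding snippets. It shows
# how the ratio maps to whole step counts: a ratio of 0.5 yields two
# discriminator steps per generator step, and a ratio of 3 yields three
# generator steps per discriminator step.
assert _define_train_step(0.5) == tfgan.GANTrainSteps(1, 2)
assert _define_train_step(3) == tfgan.GANTrainSteps(3, 1)
assert _define_train_step(1) == tfgan.GANTrainSteps(1, 1)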
def train(hparams):
  """Trains a CycleGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
  if not tf.io.gfile.exists(hparams.train_log_dir):
    tf.io.gfile.makedirs(hparams.train_log_dir)

  with tf.device(
      tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):
    with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
      images_x, images_y = _get_data(hparams.image_set_x_file_pattern,
                                     hparams.image_set_y_file_pattern,
                                     hparams.batch_size, hparams.patch_size)

    # Define CycleGAN model.
    cyclegan_model = _define_model(images_x, images_y)

    # Define CycleGAN loss.
    cyclegan_loss = tfgan.cyclegan_loss(
        cyclegan_model,
        cycle_consistency_loss_weight=hparams.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)

    # Define CycleGAN train ops.
    train_ops = _define_train_ops(cyclegan_model, cyclegan_loss, hparams)

    # Training
    train_steps = tfgan.GANTrainSteps(1, 1)
    status_message = tf.strings.join([
        'Starting train step: ',
        tf.as_string(tf.compat.v1.train.get_or_create_global_step())
    ],
                                     name='status_message')
    if not hparams.max_number_of_steps:
      return
    tfgan.gan_train(
        train_ops,
        hparams.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook({'status_message': status_message},
                                           every_n_iter=10)
        ],
        master=hparams.master,
        is_chief=hparams.task == 0)
def test_multiple_steps(self, get_hooks_fn_fn):
  """Test multiple train steps."""
  if tf.executing_eagerly():
    # None of the usual utilities work in eager.
    return
  train_ops = self._gan_train_ops(generator_add=10, discriminator_add=100)
  train_steps = tfgan.GANTrainSteps(
      generator_train_steps=3, discriminator_train_steps=4)
  final_step = tfgan.gan_train(
      train_ops,
      get_hooks_fn=get_hooks_fn_fn(train_steps),
      logdir='',
      hooks=[tf.estimator.StopAtStepHook(num_steps=1)])

  # Expected global step: 1 increment + 3 generator steps * 10 +
  # 4 discriminator steps * 100.
  self.assertTrue(np.isscalar(final_step))
  self.assertEqual(1 + 3 * 10 + 4 * 100, final_step)
def main(_):
  if not tf.gfile.Exists(FLAGS.train_log_dir):
    tf.gfile.MakeDirs(FLAGS.train_log_dir)

  with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
    with tf.name_scope('inputs'):
      initializer_hook = load_op(FLAGS.batch_size, FLAGS.max_number_of_steps)
      training_input_iter = initializer_hook.input_itr
      images_x, images_y = training_input_iter.get_next()
      # Set batch size for summaries.
      # images_x.set_shape([FLAGS.batch_size, None, None, None])
      # images_y.set_shape([FLAGS.batch_size, None, None, None])

    # Define CycleGAN model.
    cyclegan_model = _define_model(images_x, images_y)

    # Define CycleGAN loss.
    cyclegan_loss = tfgan.cyclegan_loss(
        cyclegan_model,
        cycle_consistency_loss_weight=FLAGS.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)

    # Define CycleGAN train ops.
    train_ops = _define_train_ops(cyclegan_model, cyclegan_loss)

    # Training
    train_steps = tfgan.GANTrainSteps(1, 1)
    status_message = tf.string_join(
        [
            'Starting train step: ',
            tf.as_string(tf.train.get_or_create_global_step())
        ],
        name='status_message')
    if not FLAGS.max_number_of_steps:
      return
    tfgan.gan_train(
        train_ops,
        FLAGS.train_log_dir,
        save_checkpoint_secs=60 * 10,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            initializer_hook,
            tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
            tf.train.LoggingTensorHook([status_message], every_n_iter=10)
        ],
        master=FLAGS.master,
        is_chief=FLAGS.task == 0)
def test_train(self, g_steps, d_steps, joint_train, expected_total_substeps,
               expected_g_substep_mask, expected_d_substep_mask):
  real_opt = tf.compat.v1.train.GradientDescentOptimizer(1e-2)
  gopt = TestOptimizerWrapper(real_opt, name='g_opt')
  dopt = TestOptimizerWrapper(real_opt, name='d_opt')
  est = tfgan.estimator.TPUGANEstimator(
      generator_fn=generator_fn,
      discriminator_fn=discriminator_fn,
      generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
      discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
      generator_optimizer=gopt,
      discriminator_optimizer=dopt,
      gan_train_steps=tfgan.GANTrainSteps(g_steps, d_steps),
      joint_train=joint_train,
      get_eval_metric_ops_fn=get_metrics,
      train_batch_size=4,
      eval_batch_size=10,
      predict_batch_size=8,
      use_tpu=flags.FLAGS.use_tpu,
      config=self._config)

  def train_input_fn(params):
    data = tf.ones([params['batch_size'], 4], dtype=tf.float32)
    return data, data

  est.train(train_input_fn, steps=1)
  self.assertEqual(1, est.get_variable_value('global_step'))

  substep_counter_name = 'discriminator_train/substep_counter'
  if d_steps == 0:
    substep_counter_name = 'generator_train/substep_counter'
  substep_counter = est.get_variable_value(substep_counter_name)
  self.assertEqual(expected_total_substeps, substep_counter)

  if expected_g_substep_mask is not None:
    g_substep_mask = est.get_variable_value('generator_train/substep_mask')
    self.assertIn(g_substep_mask, expected_g_substep_mask)
  if expected_d_substep_mask is not None:
    d_substep_mask = est.get_variable_value('discriminator_train/substep_mask')
    self.assertIn(d_substep_mask, expected_d_substep_mask)
def train(model, **kwargs):
  """Trains progressive GAN for stage `stage_id`.

  Args:
    model: A model object containing all information of the progressive GAN
      model, e.g. the return value of build_model().
    **kwargs: A dictionary of
      'train_root_dir': A string of the root directory of training logs.
      'master': Name of the TensorFlow master to use.
      'task': The Task ID. This value is used when training with multiple
        workers to identify each worker.
      'save_summaries_num_images': Save summaries in this number of images.
      'debug_hook': Whether to attach the debug hook to the training session.

  Returns:
    None.
  """
  logging.info('stage_id=%d, num_blocks=%d, num_images=%d', model.stage_id,
               model.num_blocks, model.num_images)

  scaffold = make_scaffold(model.stage_id, model.optimizer_var_list, **kwargs)

  logdir = make_train_sub_dir(model.stage_id, **kwargs)
  print('starting training, logdir: {}'.format(logdir))

  hooks = []
  if model.stage_train_time_limit is None:
    hooks.append(tf.train.StopAtStepHook(last_step=model.num_images))
  hooks.append(
      tf.train.LoggingTensorHook([make_status_message(model)], every_n_iter=1))
  hooks.append(TrainTimeHook(model.train_time, model.stage_train_time_limit))
  if kwargs['debug_hook']:
    hooks.append(ProganDebugHook())

  tfgan.gan_train(
      model.gan_train_ops,
      logdir=logdir,
      get_hooks_fn=tfgan.get_sequential_train_hooks(tfgan.GANTrainSteps(1, 1)),
      hooks=hooks,
      master=kwargs['master'],
      is_chief=(kwargs['task'] == 0),
      scaffold=scaffold,
      save_checkpoint_secs=600,
      save_summaries_steps=(kwargs['save_summaries_num_images']))
def test_get_train_estimator_spec(self, joint_train):
  with tf.Graph().as_default():
    if joint_train:
      gan_model_fns = [get_dummy_gan_model]
    else:
      gan_model_fns = [get_dummy_gan_model, get_dummy_gan_model]
    spec = get_train_estimator_spec(
        gan_model_fns,
        self._loss_fns,
        {},  # gan_loss_kwargs
        self._optimizers,
        joint_train=joint_train,
        is_on_tpu=flags.FLAGS.use_tpu,
        gan_train_steps=tfgan.GANTrainSteps(1, 1),
        add_summaries=not flags.FLAGS.use_tpu)

    self.assertIsInstance(spec, TPUEstimatorSpec)
    self.assertEqual(tf.estimator.ModeKeys.TRAIN, spec.mode)
    self.assertShapeEqual(np.array(0), spec.loss)  # must be a scalar
    self.assertIsNotNone(spec.train_op)
    self.assertIsNotNone(spec.training_hooks)
def train(hparams):
  """Trains a CycleGAN.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
  """
  if not tf.io.gfile.exists(hparams.train_log_dir):
    tf.io.gfile.makedirs(hparams.train_log_dir)

  with open(hparams.train_log_dir + 'train_result.json', 'w') as fp:
    json.dump(hparams._asdict(), fp, indent=4)

  with tf.device(
      tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):
    with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
      images_x, images_y = _get_data(hparams.image_set_x_file_pattern,
                                     hparams.image_set_y_file_pattern,
                                     hparams.batch_size, hparams.patch_size,
                                     hparams.tfdata_source)

    # Define CycleGAN model.
    cyclegan_model = _define_model(images_x, images_y)

    # Define CycleGAN loss.
    cyclegan_loss = tfgan.cyclegan_loss(
        cyclegan_model,
        cycle_consistency_loss_weight=hparams.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)

    # Define CycleGAN train ops.
    train_ops = _define_train_ops(cyclegan_model, cyclegan_loss, hparams)

    # Training
    train_steps = tfgan.GANTrainSteps(1, 1)
    status_message = tf.strings.join([
        'Starting train step: ',
        tf.as_string(tf.compat.v1.train.get_or_create_global_step())
    ],
                                     name='status_message')
    if not hparams.max_number_of_steps:
      return

    additional_params = {}
    if hparams.save_checkpoint_steps:
      max_to_keep = (
          hparams.max_number_of_steps // hparams.save_checkpoint_steps + 1)
      additional_params = {
          'scaffold':
              tf.train.Scaffold(
                  saver=tf.train.Saver(max_to_keep=max_to_keep)),
          'save_checkpoint_secs': None,
          'save_checkpoint_steps': hparams.save_checkpoint_steps,
      }
    tfgan.gan_train(
        train_ops,
        hparams.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook({'status_message': status_message},
                                           every_n_iter=10)
        ],
        master=hparams.master,
        is_chief=hparams.task == 0,
        **additional_params,
    )
def main(_):
  log_dir = FLAGS.train_log_dir
  if not tf.gfile.Exists(log_dir):
    tf.gfile.MakeDirs(log_dir)

  with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
    validation_iteration_count = FLAGS.validation_itr_count
    validation_sample_count = FLAGS.validation_sample_count
    loader_name = FLAGS.loader_name
    neighborhood = 0
    loader = get_class(loader_name + '.' + loader_name)(FLAGS.path)
    data_set = loader.load_data(neighborhood, True)
    shadow_map, shadow_ratio = loader.load_shadow_map(neighborhood, data_set)

    with tf.name_scope('inputs'):
      initializer_hook = load_op(FLAGS.batch_size, FLAGS.max_number_of_steps,
                                 loader, data_set, shadow_map, shadow_ratio,
                                 FLAGS.regularization_support_rate)
      training_input_iter = initializer_hook.input_itr
      images_x, images_y = training_input_iter.get_next()
      # Set batch size for summaries.
      # images_x.set_shape([FLAGS.batch_size, None, None, None])
      # images_y.set_shape([FLAGS.batch_size, None, None, None])

    # Define model.
    gan_type = FLAGS.gan_type
    gan_train_wrapper_dict = {
        "cycle_gan":
            CycleGANWrapper(
                cycle_consistency_loss_weight=FLAGS
                .cycle_consistency_loss_weight,
                identity_loss_weight=FLAGS.identity_loss_weight,
                use_identity_loss=FLAGS.use_identity_loss),
        "gan_x2y":
            GANWrapper(
                identity_loss_weight=FLAGS.identity_loss_weight,
                use_identity_loss=FLAGS.use_identity_loss,
                swap_inputs=False),
        "gan_y2x":
            GANWrapper(
                identity_loss_weight=FLAGS.identity_loss_weight,
                use_identity_loss=FLAGS.use_identity_loss,
                swap_inputs=True)
    }
    wrapper = gan_train_wrapper_dict[gan_type]

    with tf.variable_scope('Model', reuse=tf.AUTO_REUSE):
      the_gan_model = wrapper.define_model(images_x, images_y)
      peer_validation_hook = wrapper.create_validation_hook(
          data_set, loader, log_dir, neighborhood, shadow_map, shadow_ratio,
          validation_iteration_count, validation_sample_count)
      the_gan_loss = wrapper.define_loss(the_gan_model)

    # Define CycleGAN train ops.
    train_ops = _define_train_ops(the_gan_model, the_gan_loss)

    # Training
    train_steps = tfgan.GANTrainSteps(1, 1)
    status_message = tf.string_join([
        'Starting train step: ',
        tf.as_string(tf.train.get_or_create_global_step())
    ],
                                    name='status_message')
    if not FLAGS.max_number_of_steps:
      return

    gpu = tf.config.experimental.list_physical_devices('GPU')
    tf.config.experimental.set_memory_growth(gpu[0], True)

    training_scaffold = Scaffold(saver=tf.train.Saver(max_to_keep=20))

    gan_train(
        train_ops,
        log_dir,
        scaffold=training_scaffold,
        save_checkpoint_steps=validation_iteration_count,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            initializer_hook, peer_validation_hook,
            tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
            tf.train.LoggingTensorHook([status_message], every_n_iter=1000)
        ],
        master=FLAGS.master,
        is_chief=FLAGS.task == 0)