def train(model, **kwargs):
  """Runs the training loop for a single progressive GAN stage.

  Args:
    model: Model object describing one progressive GAN stage (for example the
      value returned by build_model()). The attributes read here are
      `stage_id`, `num_blocks`, `num_images`, `optimizer_var_list` and
      `gan_train_ops`.
    **kwargs: Training options. Keys read directly by this function:
      'master': Name of the TensorFlow master to use.
      'task': The Task ID. Used when training with multiple workers to
        identify each worker; task 0 is treated as the chief.
      'save_summaries_num_images': Save summaries in this number of images
        (passed to `gan_train` as `save_summaries_steps`).
      All kwargs are also forwarded to make_scaffold() and
      make_train_sub_dir() (e.g. 'train_log_dir', the root directory of
      training logs).

  Returns:
    None.
  """
  logging.info('stage_id=%d, num_blocks=%d, num_images=%d', model.stage_id,
               model.num_blocks, model.num_images)

  stage_scaffold = make_scaffold(model.stage_id, model.optimizer_var_list,
                                 **kwargs)

  stage_train_ops = model.gan_train_ops
  stage_logdir = make_train_sub_dir(model.stage_id, **kwargs)
  # One generator update and one discriminator update per training step,
  # run sequentially through hooks.
  step_hooks_fn = tfgan.get_sequential_train_hooks(tfgan.GANTrainSteps(1, 1))
  stage_hooks = [
      # Stop once the global step reaches `model.num_images`.
      tf.estimator.StopAtStepHook(last_step=model.num_images),
      tf.estimator.LoggingTensorHook([make_status_message(model)],
                                     every_n_iter=10),
  ]

  tfgan.gan_train(
      stage_train_ops,
      logdir=stage_logdir,
      get_hooks_fn=step_hooks_fn,
      hooks=stage_hooks,
      master=kwargs['master'],
      is_chief=kwargs['task'] == 0,
      scaffold=stage_scaffold,
      save_checkpoint_secs=600,
      save_summaries_steps=kwargs['save_summaries_num_images'])
def train(hparams):
  """Builds the StarGAN graph and runs the TF-GAN training loop.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
      Fields read here: train_log_dir, ps_replicas, batch_size, patch_size,
      generator_lr, discriminator_lr, adam_beta1, adam_beta2,
      max_number_of_steps, gen_disc_step_ratio, tf_master and task.
  """
  log_dir = hparams.train_log_dir
  if not tf.io.gfile.exists(log_dir):
    tf.io.gfile.makedirs(log_dir)

  # Variables are placed on parameter servers by the replica device setter.
  device_setter = tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)
  with tf.device(device_setter):
    # The input pipeline is pinned to the CPU.
    with tf.compat.v1.name_scope('inputs'):
      with tf.device('/cpu:0'):
        input_images, input_labels = data_provider.provide_data(
            'train', hparams.batch_size, hparams.patch_size)

    with tf.compat.v1.name_scope('model'):
      stargan_model = _define_model(input_images, input_labels)

    tfgan.eval.add_stargan_image_summaries(
        stargan_model, num_images=3 * hparams.batch_size, display_diffs=True)

    gan_loss = tfgan.stargan_loss(stargan_model)

    with tf.compat.v1.name_scope('train_ops'):
      gan_train_ops = _define_train_ops(
          stargan_model,
          gan_loss,
          hparams.generator_lr,
          hparams.discriminator_lr,
          hparams.adam_beta1,
          hparams.adam_beta2,
          hparams.max_number_of_steps)

    # Generator/discriminator step counts per iteration.
    step_ratio = _define_train_step(hparams.gen_disc_step_ratio)

    global_step = tf.compat.v1.train.get_or_create_global_step()
    status_message = tf.strings.join(
        ['Starting train step: ', tf.as_string(global_step)],
        name='status_message')

    hooks_fn = tfgan.get_sequential_train_hooks(step_ratio)
    training_hooks = [
        tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
        tf.estimator.LoggingTensorHook([status_message], every_n_iter=10),
    ]
    tfgan.gan_train(
        gan_train_ops,
        log_dir,
        get_hooks_fn=hooks_fn,
        hooks=training_hooks,
        master=hparams.tf_master,
        is_chief=hparams.task == 0)
def train(hparams):
  """Builds the CycleGAN graph and runs the TF-GAN training loop.

  The graph is always constructed; when `hparams.max_number_of_steps` is
  falsy the function returns before calling `gan_train`.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
      Fields read here: train_log_dir, ps_replicas, image_set_x_file_pattern,
      image_set_y_file_pattern, batch_size, patch_size,
      cycle_consistency_loss_weight, max_number_of_steps, master and task
      (the whole object is also passed to _define_train_ops).
  """
  log_dir = hparams.train_log_dir
  if not tf.io.gfile.exists(log_dir):
    tf.io.gfile.makedirs(log_dir)

  device_setter = tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)
  with tf.device(device_setter):
    # The input pipeline is pinned to the CPU.
    with tf.compat.v1.name_scope('inputs'):
      with tf.device('/cpu:0'):
        domain_x_images, domain_y_images = _get_data(
            hparams.image_set_x_file_pattern,
            hparams.image_set_y_file_pattern,
            hparams.batch_size,
            hparams.patch_size)

    model = _define_model(domain_x_images, domain_y_images)
    loss = tfgan.cyclegan_loss(
        model,
        cycle_consistency_loss_weight=hparams.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)
    ops = _define_train_ops(model, loss, hparams)

    # One generator step and one discriminator step per iteration.
    step_counts = tfgan.GANTrainSteps(1, 1)
    global_step = tf.compat.v1.train.get_or_create_global_step()
    status_message = tf.strings.join(
        ['Starting train step: ', tf.as_string(global_step)],
        name='status_message')

    # A zero/unset step budget means there is nothing to run.
    if not hparams.max_number_of_steps:
      return

    hooks_fn = tfgan.get_sequential_train_hooks(step_counts)
    training_hooks = [
        tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
        tf.estimator.LoggingTensorHook(
            {'status_message': status_message}, every_n_iter=10),
    ]
    tfgan.gan_train(
        ops,
        log_dir,
        get_hooks_fn=hooks_fn,
        hooks=training_hooks,
        master=hparams.master,
        is_chief=hparams.task == 0)
def train(hparams, override_generator_fn=None, override_discriminator_fn=None):
  """Trains a StarGAN with a TF-GAN StarGANEstimator.

  Training proceeds in chunks of `hparams.steps_per_eval` global steps. After
  each chunk a summary image is rendered from the CelebA test set and written
  to `hparams.output_dir` as `summary_image_<step>.png`, where `<step>` is the
  global step reached.

  Args:
    hparams: An HParams instance containing the hyperparameters for training.
    override_generator_fn: A generator function that overrides the default one.
    override_discriminator_fn: A discriminator function that overrides the
      default one.

  Raises:
    ValueError: If `steps_per_eval` is not positive, or if
      `max_number_of_steps` is not divisible by `steps_per_eval`.
  """
  # Create directories if not exist.
  if not tf.io.gfile.exists(hparams.output_dir):
    tf.io.gfile.makedirs(hparams.output_dir)

  # Make sure steps integers are consistent. A non-positive `steps_per_eval`
  # would otherwise raise ZeroDivisionError in the modulo below or make the
  # training loop never terminate.
  if hparams.steps_per_eval <= 0:
    raise ValueError('`steps_per_eval` must be positive.')
  if hparams.max_number_of_steps % hparams.steps_per_eval != 0:
    raise ValueError('`max_number_of_steps` must be divisible by '
                     '`steps_per_eval`.')

  # Create optimizers.
  gen_opt, dis_opt = _get_optimizer(hparams.generator_lr,
                                    hparams.discriminator_lr,
                                    hparams.adam_beta1, hparams.adam_beta2)

  # Create estimator.
  stargan_estimator = tfgan.estimator.StarGANEstimator(
      generator_fn=override_generator_fn or network.generator,
      discriminator_fn=override_discriminator_fn or network.discriminator,
      loss_fn=tfgan.stargan_loss,
      generator_optimizer=gen_opt,
      discriminator_optimizer=dis_opt,
      get_hooks_fn=tfgan.get_sequential_train_hooks(
          _define_train_step(hparams.gen_disc_step_ratio)),
      add_summaries=tfgan.estimator.SummaryType.IMAGES)

  # Get input function for training and test images.
  train_input_fn = lambda: data_provider.provide_data(  # pylint:disable=g-long-lambda
      'train', hparams.batch_size, hparams.patch_size)
  test_images_np = data_provider.provide_celeba_test_set(hparams.patch_size)
  filename_str = os.path.join(hparams.output_dir, 'summary_image_%i.png')

  # Periodically train and write prediction output to disk.
  cur_step = 0
  while cur_step < hparams.max_number_of_steps:
    cur_step += hparams.steps_per_eval
    # `max_steps` is an absolute global-step target, so each call trains up to
    # exactly `cur_step`. Passing `steps=cur_step` would instead train
    # `cur_step` *additional* steps per call, so total training would grow
    # quadratically and the image filenames would not match the real step.
    stargan_estimator.train(train_input_fn, max_steps=cur_step)
    summary_img = _get_summary_image(stargan_estimator, test_images_np)
    # PNG output is binary data, so open the file in binary mode.
    with tf.io.gfile.GFile(filename_str % cur_step, 'wb') as f:
      PIL.Image.fromarray(
          (255 * summary_img).astype(np.uint8)).save(f, 'PNG')
def main(_):
  """Builds a CycleGAN from command-line flags and trains it.

  Inputs come from the iterator exposed by the hook that load_op() returns.
  That hook is also registered with `gan_train` (presumably to initialize the
  iterator — confirm against load_op). The graph is always built; training is
  skipped when --max_number_of_steps is 0/unset.
  """
  log_dir = FLAGS.train_log_dir
  if not tf.gfile.Exists(log_dir):
    tf.gfile.MakeDirs(log_dir)

  with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
    with tf.name_scope('inputs'):
      input_hook = load_op(FLAGS.batch_size, FLAGS.max_number_of_steps)
      input_iterator = input_hook.input_itr
      images_x, images_y = input_iterator.get_next()

    model = _define_model(images_x, images_y)
    loss = tfgan.cyclegan_loss(
        model,
        cycle_consistency_loss_weight=FLAGS.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)
    ops = _define_train_ops(model, loss)

    # One generator step and one discriminator step per iteration.
    step_counts = tfgan.GANTrainSteps(1, 1)
    global_step = tf.train.get_or_create_global_step()
    status_message = tf.string_join(
        ['Starting train step: ', tf.as_string(global_step)],
        name='status_message')

    if not FLAGS.max_number_of_steps:
      return

    hooks_fn = tfgan.get_sequential_train_hooks(step_counts)
    training_hooks = [
        input_hook,
        tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
        tf.train.LoggingTensorHook([status_message], every_n_iter=10),
    ]
    tfgan.gan_train(
        ops,
        log_dir,
        save_checkpoint_secs=60 * 10,
        get_hooks_fn=hooks_fn,
        hooks=training_hooks,
        master=FLAGS.master,
        is_chief=FLAGS.task == 0)
def train(model, **kwargs):
  """Trains progressive GAN for stage `stage_id`.

  Args:
    model: An model object having all information of progressive GAN model,
      e.g. the return of build_model(). Attributes read here: stage_id,
      num_blocks, num_images, optimizer_var_list, stage_train_time_limit,
      train_time and gan_train_ops.
    **kwargs: A dictionary of
        'train_root_dir': A string of root directory of training logs.
        'master': Name of the TensorFlow master to use.
        'task': The Task ID. This value is used when training with multiple
          workers to identify each worker; task 0 is the chief.
        'save_summaries_num_images': Save summaries in this number of images.
        'debug_hook': Optional (defaults to False). Whether to attach the
          debug hook to the training session.
      All kwargs are also forwarded to make_scaffold() and
      make_train_sub_dir().

  Returns:
    None.
  """
  logging.info('stage_id=%d, num_blocks=%d, num_images=%d', model.stage_id,
               model.num_blocks, model.num_images)

  scaffold = make_scaffold(model.stage_id, model.optimizer_var_list, **kwargs)
  logdir = make_train_sub_dir(model.stage_id, **kwargs)
  # Use the same logger as above instead of a bare print().
  logging.info('starting training, logdir: %s', logdir)

  hooks = []
  # Only stop on the image/step budget when no wall-clock limit is set for
  # this stage; otherwise stopping is left to TrainTimeHook (presumably —
  # confirm against TrainTimeHook).
  if model.stage_train_time_limit is None:
    hooks.append(tf.train.StopAtStepHook(last_step=model.num_images))
  hooks.append(
      tf.train.LoggingTensorHook([make_status_message(model)], every_n_iter=1))
  hooks.append(TrainTimeHook(model.train_time, model.stage_train_time_limit))
  # 'debug_hook' is optional so callers passing only the basic kwargs do not
  # hit a KeyError.
  if kwargs.get('debug_hook', False):
    hooks.append(ProganDebugHook())

  tfgan.gan_train(
      model.gan_train_ops,
      logdir=logdir,
      get_hooks_fn=tfgan.get_sequential_train_hooks(tfgan.GANTrainSteps(1, 1)),
      hooks=hooks,
      master=kwargs['master'],
      is_chief=(kwargs['task'] == 0),
      scaffold=scaffold,
      save_checkpoint_secs=600,
      save_summaries_steps=(kwargs['save_summaries_num_images']))
def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
  """Sync-optimizer hooks come back from both hooks-fn factories.

  Builds GAN train ops with two sync optimizers. It then checks that
  `get_sequential_train_hooks` returns 4 hooks and `get_joint_train_hooks`
  returns 5. In both cases, exactly one sync-optimizer hook must exist per
  optimizer.
  """
  # None of the usual utilities work in eager.
  if tf.executing_eagerly():
    return

  gan_model = create_gan_model_fn()
  gan_loss = tfgan.gan_loss(gan_model)
  gen_optimizer = get_sync_optimizer()
  dis_optimizer = get_sync_optimizer()
  ops = tfgan.gan_train_ops(
      gan_model,
      gan_loss,
      gen_optimizer,
      dis_optimizer,
      summarize_gradients=True,
      colocate_gradients_with_ops=True)

  sync_hook_type = get_sync_optimizer_hook_type()
  expected_optimizers = frozenset((gen_optimizer, dis_optimizer))
  for hooks_fn_factory, expected_hook_count in (
      (tfgan.get_sequential_train_hooks, 4),
      (tfgan.get_joint_train_hooks, 5),
  ):
    hooks = hooks_fn_factory()(ops)
    self.assertLen(hooks, expected_hook_count)
    found_optimizers = [
        h._sync_optimizer for h in hooks if isinstance(h, sync_hook_type)
    ]
    self.assertLen(found_optimizers, 2)
    self.assertSetEqual(frozenset(found_optimizers), expected_optimizers)
def train(hparams):
  """Trains a CycleGAN.

  The hyperparameters are also written to `<train_log_dir>/train_result.json`.
  When `hparams.save_checkpoint_steps` is set, checkpoints are saved every
  that many steps instead of on the time-based schedule. Enough of them are
  retained to cover the whole run.

  Args:
    hparams: An HParams namedtuple (its `_asdict()` is serialised) containing
      the hyperparameters for training.
  """
  import os  # Local import: only needed for path joining below.

  if not tf.io.gfile.exists(hparams.train_log_dir):
    tf.io.gfile.makedirs(hparams.train_log_dir)
  # Join the path instead of concatenating, so the file lands inside the log
  # dir even without a trailing separator. GFile matches the tf.io.gfile calls
  # above and also works on non-local filesystems.
  with tf.io.gfile.GFile(
      os.path.join(hparams.train_log_dir, 'train_result.json'), 'w') as fp:
    json.dump(hparams._asdict(), fp, indent=4)

  with tf.device(
      tf.compat.v1.train.replica_device_setter(hparams.ps_replicas)):
    with tf.compat.v1.name_scope('inputs'), tf.device('/cpu:0'):
      images_x, images_y = _get_data(hparams.image_set_x_file_pattern,
                                     hparams.image_set_y_file_pattern,
                                     hparams.batch_size, hparams.patch_size,
                                     hparams.tfdata_source)

    # Define CycleGAN model.
    cyclegan_model = _define_model(images_x, images_y)

    # Define CycleGAN loss.
    cyclegan_loss = tfgan.cyclegan_loss(
        cyclegan_model,
        cycle_consistency_loss_weight=hparams.cycle_consistency_loss_weight,
        tensor_pool_fn=tfgan.features.tensor_pool)

    # Define CycleGAN train ops.
    train_ops = _define_train_ops(cyclegan_model, cyclegan_loss, hparams)

    # Training
    train_steps = tfgan.GANTrainSteps(1, 1)
    status_message = tf.strings.join([
        'Starting train step: ',
        tf.as_string(tf.compat.v1.train.get_or_create_global_step())
    ], name='status_message')
    if not hparams.max_number_of_steps:
      return

    additional_params = {}
    if hparams.save_checkpoint_steps:
      # Keep one checkpoint per save interval over the whole run, plus one.
      max_to_keep = (
          hparams.max_number_of_steps // hparams.save_checkpoint_steps + 1)
      # Scaffold/Saver live under tf.compat.v1 in TF2 (tf.train.Saver was
      # removed). This matches the other tf.compat.v1.train calls above.
      additional_params = {
          'scaffold':
              tf.compat.v1.train.Scaffold(
                  saver=tf.compat.v1.train.Saver(max_to_keep=max_to_keep)),
          # Disable the time-based default so only the step schedule applies.
          'save_checkpoint_secs': None,
          # NOTE(review): upstream tfgan.gan_train may not accept
          # `save_checkpoint_steps`; confirm the tfgan version in use does.
          'save_checkpoint_steps': hparams.save_checkpoint_steps,
      }

    tfgan.gan_train(
        train_ops,
        hparams.train_log_dir,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            tf.estimator.StopAtStepHook(num_steps=hparams.max_number_of_steps),
            tf.estimator.LoggingTensorHook({'status_message': status_message},
                                           every_n_iter=10)
        ],
        master=hparams.master,
        is_chief=hparams.task == 0,
        **additional_params)
def main(_):
  """Builds the GAN selected by --gan_type and trains it with peer validation.

  Loads the data set through the loader class named by --loader_name and
  builds the input pipeline, model, loss and train ops. It then runs
  `gan_train` with the input-initializer and peer-validation hooks.
  Checkpoints are saved every --validation_itr_count steps. Nothing is trained
  when --max_number_of_steps is 0/unset.

  Raises:
    ValueError: If --gan_type is not one of the supported variants.
  """
  # Factories (not instances), so an unsupported --gan_type is rejected
  # before the data set is loaded and only the selected wrapper is built.
  wrapper_factories = {
      'cycle_gan': lambda: CycleGANWrapper(
          cycle_consistency_loss_weight=FLAGS.cycle_consistency_loss_weight,
          identity_loss_weight=FLAGS.identity_loss_weight,
          use_identity_loss=FLAGS.use_identity_loss),
      'gan_x2y': lambda: GANWrapper(
          identity_loss_weight=FLAGS.identity_loss_weight,
          use_identity_loss=FLAGS.use_identity_loss,
          swap_inputs=False),
      'gan_y2x': lambda: GANWrapper(
          identity_loss_weight=FLAGS.identity_loss_weight,
          use_identity_loss=FLAGS.use_identity_loss,
          swap_inputs=True),
  }
  gan_type = FLAGS.gan_type
  if gan_type not in wrapper_factories:
    raise ValueError('Unsupported gan_type {!r}; expected one of {}.'.format(
        gan_type, sorted(wrapper_factories)))

  log_dir = FLAGS.train_log_dir
  if not tf.gfile.Exists(log_dir):
    tf.gfile.MakeDirs(log_dir)

  with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
    validation_iteration_count = FLAGS.validation_itr_count
    validation_sample_count = FLAGS.validation_sample_count
    loader_name = FLAGS.loader_name
    neighborhood = 0
    # NOTE(review): get_class presumably resolves a dotted "module.Class"
    # path, where module and class share the loader's name — confirm.
    loader = get_class(loader_name + '.' + loader_name)(FLAGS.path)
    data_set = loader.load_data(neighborhood, True)
    shadow_map, shadow_ratio = loader.load_shadow_map(neighborhood, data_set)

    with tf.name_scope('inputs'):
      initializer_hook = load_op(FLAGS.batch_size, FLAGS.max_number_of_steps,
                                 loader, data_set, shadow_map, shadow_ratio,
                                 FLAGS.regularization_support_rate)
      training_input_iter = initializer_hook.input_itr
      images_x, images_y = training_input_iter.get_next()

    # Define model.
    wrapper = wrapper_factories[gan_type]()
    with tf.variable_scope('Model', reuse=tf.AUTO_REUSE):
      the_gan_model = wrapper.define_model(images_x, images_y)
      peer_validation_hook = wrapper.create_validation_hook(
          data_set, loader, log_dir, neighborhood, shadow_map, shadow_ratio,
          validation_iteration_count, validation_sample_count)
      the_gan_loss = wrapper.define_loss(the_gan_model)

    # Define train ops for the selected GAN variant.
    train_ops = _define_train_ops(the_gan_model, the_gan_loss)

    # Training
    train_steps = tfgan.GANTrainSteps(1, 1)
    status_message = tf.string_join([
        'Starting train step: ',
        tf.as_string(tf.train.get_or_create_global_step())
    ], name='status_message')
    if not FLAGS.max_number_of_steps:
      return

    # Enable on-demand GPU memory allocation on every visible GPU. TensorFlow
    # requires this setting to be consistent across GPUs. The loop is a no-op
    # on CPU-only hosts, where indexing gpu[0] raised IndexError.
    for gpu in tf.config.experimental.list_physical_devices('GPU'):
      tf.config.experimental.set_memory_growth(gpu, True)

    training_scaffold = Scaffold(saver=tf.train.Saver(max_to_keep=20))
    gan_train(
        train_ops,
        log_dir,
        scaffold=training_scaffold,
        save_checkpoint_steps=validation_iteration_count,
        get_hooks_fn=tfgan.get_sequential_train_hooks(train_steps),
        hooks=[
            initializer_hook,
            peer_validation_hook,
            tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps),
            tf.train.LoggingTensorHook([status_message], every_n_iter=1000)
        ],
        master=FLAGS.master,
        is_chief=FLAGS.task == 0)
def train(hparams, override_generator_fn=None, override_discriminator_fn=None):
  """Trains a StarGAN.

  Writes the hyperparameters to `<output_dir>/train_result.json` and stores
  estimator checkpoints under `<output_dir>/checkpoints/`. Every
  `hparams.steps_per_eval` global steps, a summary image rendered from the
  test set is written to `<output_dir>/summary_image_<step>.png`.

  Args:
    hparams: An HParams namedtuple (its `_asdict()` is serialised) containing
      the hyperparameters for training.
    override_generator_fn: A generator function that overrides the default one.
    override_discriminator_fn: A discriminator function that overrides the
      default one.

  Raises:
    ValueError: If `steps_per_eval` is not positive, if `max_number_of_steps`
      is not divisible by `steps_per_eval`, or if both `cls_model` and
      `cls_checkpoint` are set.
    NotImplementedError: If `tfdata_source` is empty; only TensorFlow Datasets
      sources are supported.
  """
  import ast  # Local import: only needed to parse `hparams.download`.

  # Create directories if not exist.
  if not tf.io.gfile.exists(hparams.output_dir):
    tf.io.gfile.makedirs(hparams.output_dir)
  # os.path.join keeps the file inside output_dir even without a trailing
  # separator. GFile also supports non-local filesystems.
  with tf.io.gfile.GFile(
      os.path.join(hparams.output_dir, 'train_result.json'), 'w') as fp:
    json.dump(hparams._asdict(), fp, indent=4)

  # Validate configuration before building anything. A non-positive
  # `steps_per_eval` would otherwise raise ZeroDivisionError or loop forever.
  if hparams.steps_per_eval <= 0:
    raise ValueError('`steps_per_eval` must be positive.')
  if hparams.max_number_of_steps % hparams.steps_per_eval != 0:
    raise ValueError('`max_number_of_steps` must be divisible by '
                     '`steps_per_eval`.')
  if hparams.cls_model and hparams.cls_checkpoint:
    raise ValueError('Can only assign one parameter between hparams.cls_model '
                     'and hparams.cls_checkpoint')
  if not hparams.tfdata_source:
    raise NotImplementedError('TODO: support external data source.')

  # `download` is expected as a string literal such as 'True'/'False'.
  # literal_eval accepts the same literals the previous eval() did, without
  # executing arbitrary expressions. A real bool is also accepted.
  if isinstance(hparams.download, bool):
    download = hparams.download
  else:
    download = ast.literal_eval(hparams.download)

  # Create optimizers.
  gen_opt, dis_opt = _get_optimizer(hparams.generator_lr,
                                    hparams.discriminator_lr,
                                    hparams.adam_beta1, hparams.adam_beta2)

  # Choose the discriminator: a custom Keras classifier, a custom TF
  # classifier, or the default network.
  if hparams.cls_model:
    print('[!!!!] LOAD custom classification model in discriminator.')
    network_discriminator = network.CustomKerasDiscriminator(
        os.path.join(hparams.cls_model, 'base_model.h5'))
  elif hparams.cls_checkpoint:
    network_discriminator = network.custom_tf_discriminator()
  else:
    network_discriminator = network.discriminator

  # Create estimator. NOTE(review): `cls_model`/`cls_checkpoint` are not
  # upstream StarGANEstimator arguments; this relies on a project variant.
  stargan_estimator = tfgan.estimator.StarGANEstimator(
      model_dir=os.path.join(hparams.output_dir, 'checkpoints/'),
      generator_fn=override_generator_fn or network.generator,
      discriminator_fn=override_discriminator_fn or network_discriminator,
      loss_fn=_get_stargan_loss(
          reconstruction_loss_weight=hparams.reconstruction_loss_weight,
          self_consistency_loss_weight=hparams.self_consistency_loss_weight,
          classification_loss_weight=hparams.classification_loss_weight),
      generator_optimizer=gen_opt,
      discriminator_optimizer=dis_opt,
      get_hooks_fn=tfgan.get_sequential_train_hooks(
          _define_train_step(hparams.gen_disc_step_ratio)),
      add_summaries=tfgan.estimator.SummaryType.IMAGES,
      config=tf.estimator.RunConfig(
          save_checkpoints_steps=hparams.save_checkpoints_steps,
          keep_checkpoint_max=hparams.keep_checkpoint_max),
      cls_model=hparams.cls_model,
      cls_checkpoint=hparams.cls_checkpoint)

  # Get input function for training and test images.
  print('[**] load train dataset: tensorflow dataset: {x}'.format(
      x=hparams.tfdata_source))
  domains = tuple(hparams.tfdata_source_domains.split(','))

  def train_input_fn():
    return data_provider.provide_data(
        hparams.tfdata_source,
        hparams.batch_size,
        hparams.patch_size,
        split='train',
        color_labeled=hparams.use_color_labels,
        num_parallel_calls=None,
        shuffle=True,
        domains=domains,
        download=download,
        data_dir=hparams.data_dir)

  if hparams.tfdata_source.startswith('cycle_gan'):
    test_images_np = data_provider.provide_cyclegan_test_set(
        hparams.tfdata_source, hparams.patch_size)
    num_domains = 2
  elif hparams.tfdata_source == 'celeb_a':
    test_images_np = data_provider.provide_celeba_test_set(
        hparams.patch_size, download=download, data_dir=hparams.data_dir)
    num_domains = len(test_images_np)
  else:
    test_images_np, num_domains = data_provider.provide_categorized_test_set(
        hparams.tfdata_source,
        hparams.patch_size,
        color_labeled=hparams.use_color_labels,
        download=download,
        data_dir=hparams.data_dir)

  filename_str = os.path.join(hparams.output_dir, 'summary_image_%i.png')

  # Periodically train and write prediction output to disk.
  cur_step = 0
  while cur_step < hparams.max_number_of_steps:
    cur_step += hparams.steps_per_eval
    print('current step: {cur_step} /{max_step}'.format(
        cur_step=cur_step, max_step=hparams.max_number_of_steps))
    # `max_steps` is an absolute global-step target. `steps=cur_step` would
    # train `cur_step` *additional* steps per call, so total training would
    # grow quadratically and the image filenames would not match the real
    # step. With max_steps, a resumed run also skips intervals it already
    # completed.
    stargan_estimator.train(train_input_fn, max_steps=cur_step)
    summary_img = _get_summary_image(stargan_estimator, test_images_np,
                                     num_domains)
    # Replicate single-channel output to three channels before encoding.
    # Assumes an (H, W, C) array, as the shape[2] indexing implies.
    if summary_img.shape[2] == 1:
      summary_img = np.repeat(summary_img, 3, axis=2)
    # PNG output is binary data, so open the file in binary mode.
    with tf.io.gfile.GFile(filename_str % cur_step, 'wb') as f:
      PIL.Image.fromarray((255 * summary_img).astype(np.uint8)).save(f, 'PNG')