def test_get_estimator_spec(self, mode, joint_train):
  with ops.Graph().as_default():
    self._gan_model = get_dummy_gan_model()
    spec = estimator._get_estimator_spec(
        mode,
        self._gan_model,
        generator_loss_fn=losses.wasserstein_generator_loss,
        discriminator_loss_fn=losses.wasserstein_discriminator_loss,
        get_eval_metric_ops_fn=get_metrics,
        generator_optimizer=self._generator_optimizer,
        discriminator_optimizer=self._discriminator_optimizer,
        joint_train=joint_train,
        is_on_tpu=FLAGS.use_tpu,
        gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1))

  self.assertIsInstance(spec, tpu_estimator.TPUEstimatorSpec)
  self.assertEqual(mode, spec.mode)
  if mode == model_fn_lib.ModeKeys.PREDICT:
    self.assertEqual({'generated_data': self._gan_model.generated_data},
                     spec.predictions)
  elif mode == model_fn_lib.ModeKeys.TRAIN:
    self.assertShapeEqual(np.array(0), spec.loss)  # must be a scalar
    self.assertIsNotNone(spec.train_op)
    self.assertIsNotNone(spec.training_hooks)
  elif mode == model_fn_lib.ModeKeys.EVAL:
    self.assertEqual(self._gan_model.generated_data, spec.predictions)
    self.assertShapeEqual(np.array(0), spec.loss)  # must be a scalar
    self.assertIsNotNone(spec.eval_metrics)
def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a hooks function for joint GAN training.

  When using these train hooks, IT IS RECOMMENDED TO USE `use_locking=True` ON
  ALL OPTIMIZERS TO AVOID RACE CONDITIONS.

  The order of steps taken is:
  1) Combined generator and discriminator steps
  2) Generator only steps, if any remain
  3) Discriminator only steps, if any remain

  **NOTE**: Unlike `get_sequential_train_hooks`, this method performs updates
  for the generator and discriminator simultaneously whenever possible. This
  reduces the number of `tf.Session` calls, and can also change the training
  semantics.

  To illustrate the difference, consider the following example:
  `train_steps=namedtuples.GANTrainSteps(3, 5)` will cause
  `get_sequential_train_hooks` to make 8 session calls:
  1) 3 generator steps
  2) 5 discriminator steps

  In contrast, `get_joint_train_hooks` will make 5 session calls:
  1) 3 generator + discriminator steps
  2) 2 discriminator steps

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator and
      discriminator training steps to take.

  Returns:
    A function that takes a GANTrainOps tuple and returns a list of hooks.
  """
  g_steps = train_steps.generator_train_steps
  d_steps = train_steps.discriminator_train_steps
  # Get the number of each type of step that should be run.
  num_d_and_g_steps = min(g_steps, d_steps)
  num_g_steps = g_steps - num_d_and_g_steps
  num_d_steps = d_steps - num_d_and_g_steps

  def get_hooks(train_ops):
    g_op = train_ops.generator_train_op
    d_op = train_ops.discriminator_train_op

    joint_hook = RunTrainOpsHook([g_op, d_op], num_d_and_g_steps)
    g_hook = RunTrainOpsHook(g_op, num_g_steps)
    d_hook = RunTrainOpsHook(d_op, num_d_steps)

    return [joint_hook, g_hook, d_hook]

  return get_hooks
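# Usage sketch (not from the snippets above): wiring `get_joint_train_hooks`
# into `tfgan.gan_train`. `model` and `loss` stand in for a project's
# `GANModel` and `GANLoss`; note `use_locking=True` on both optimizers, as the
# docstring above recommends for joint training.
train_ops = tfgan.gan_train_ops(
    model=model,
    loss=loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4, use_locking=True),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, use_locking=True))
tfgan.gan_train(
    train_ops,
    logdir='/tmp/gan_logs',
    get_hooks_fn=tfgan.get_joint_train_hooks(namedtuples.GANTrainSteps(3, 5)),
    hooks=[tf.train.StopAtStepHook(num_steps=10000)])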
def _test_multiple_steps_helper(self, get_hooks_fn_fn):
  train_ops = self._gan_train_ops(generator_add=10, discriminator_add=100)
  train_steps = namedtuples.GANTrainSteps(
      generator_train_steps=3, discriminator_train_steps=4)
  final_step = train.gan_train(
      train_ops,
      get_hooks_fn=get_hooks_fn_fn(train_steps),
      logdir='',
      hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)])
  self.assertTrue(np.isscalar(final_step))
  self.assertEqual(1 + 3 * 10 + 4 * 100, final_step)
def test_supervisor_run_gan_model_train_ops_multiple_steps(self):
  step = training_util.create_global_step()
  train_ops = namedtuples.GANTrainOps(
      generator_train_op=constant_op.constant(3.0),
      discriminator_train_op=constant_op.constant(2.0),
      global_step_inc_op=step.assign_add(1))
  train_steps = namedtuples.GANTrainSteps(
      generator_train_steps=3, discriminator_train_steps=4)

  final_loss = slim_learning.train(
      train_op=train_ops,
      logdir='',
      global_step=step,
      number_of_steps=1,
      train_step_fn=train.get_sequential_train_steps(train_steps))
  self.assertTrue(np.isscalar(final_loss))
  self.assertEqual(17.0, final_loss)  # 3 * 3.0 + 4 * 2.0
def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a hooks function for sequential GAN training.

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator and
      discriminator training steps to take.

  Returns:
    A function that takes a GANTrainOps tuple and returns a list of hooks.
  """
  def get_hooks(train_ops):
    generator_hook = RunTrainOpsHook(train_ops.generator_train_op,
                                     train_steps.generator_train_steps)
    discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op,
                                         train_steps.discriminator_train_steps)
    return [generator_hook, discriminator_hook]
  return get_hooks
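# Contrast sketch (assumes the `train_ops` built in the joint-hooks example
# above): the same `gan_train` call with sequential hooks, here 1 generator
# step and 5 discriminator steps per iteration. `gan_train` uses
# `get_sequential_train_hooks()` by default, so passing it explicitly only
# matters when changing the step ratio.
tfgan.gan_train(
    train_ops,
    logdir='/tmp/gan_logs',
    get_hooks_fn=tfgan.get_sequential_train_hooks(
        namedtuples.GANTrainSteps(1, 5)),
    hooks=[tf.train.StopAtStepHook(num_steps=10000)])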
def get_sequential_train_steps(train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a thin wrapper around slim.learning.train_step, for GANs.

  This function is to provide support for the Supervisor. For new code, please
  use `MonitoredSession` and `get_sequential_train_hooks`.

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator and
      discriminator training steps to take.

  Returns:
    A function that can be used for `train_step_fn` for GANs.
  """

  def sequential_train_steps(sess, train_ops, global_step, train_step_kwargs):
    """A thin wrapper around slim.learning.train_step, for GANs.

    Args:
      sess: A TensorFlow session.
      train_ops: A GANTrainOps tuple of train ops to run.
      global_step: The global step.
      train_step_kwargs: Dictionary controlling `train_step` behavior.

    Returns:
      A scalar final loss and a bool whether or not the train loop should stop.
    """
    # Only run `should_stop` at the end, if required. Make a local copy of
    # `train_step_kwargs`, if necessary, so as not to modify the caller's
    # dictionary.
    should_stop_op, train_kwargs = None, train_step_kwargs
    if 'should_stop' in train_step_kwargs:
      should_stop_op = train_step_kwargs['should_stop']
      train_kwargs = train_step_kwargs.copy()
      del train_kwargs['should_stop']

    # Run generator training steps.
    gen_loss = 0
    for _ in range(train_steps.generator_train_steps):
      cur_gen_loss, _ = slim_learning.train_step(
          sess, train_ops.generator_train_op, global_step, train_kwargs)
      gen_loss += cur_gen_loss

    # Run discriminator training steps.
    dis_loss = 0
    for _ in range(train_steps.discriminator_train_steps):
      cur_dis_loss, _ = slim_learning.train_step(
          sess, train_ops.discriminator_train_op, global_step, train_kwargs)
      dis_loss += cur_dis_loss

    sess.run(train_ops.global_step_inc_op)

    # Run the `should_stop` op after the global step has been incremented, so
    # that the `should_stop` aligns with the proper `global_step` count.
    if should_stop_op is not None:
      should_stop = sess.run(should_stop_op)
    else:
      should_stop = False

    return gen_loss + dis_loss, should_stop

  return sequential_train_steps
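# Usage sketch (hypothetical names): the Supervisor-era path described in the
# docstring above, driving GAN training through `slim.learning.train`.
# `train_ops` is a `GANTrainOps` tuple from `tfgan.gan_train_ops`.
train_step_fn = get_sequential_train_steps(
    namedtuples.GANTrainSteps(generator_train_steps=1,
                              discriminator_train_steps=5))
final_loss = slim_learning.train(
    train_op=train_ops,  # the whole GANTrainOps tuple, not a single op
    logdir='/tmp/gan_logs',
    global_step=tf.train.get_or_create_global_step(),
    number_of_steps=10000,
    train_step_fn=train_step_fn)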
def start_train():
  conf.is_training = True
  train_input = data_provider.get_stage_I_train_input_fn()
  condition, real_image = train_input()
  gan_model, gan_loss = get_model_and_loss(condition, real_image)
  gan_train_ops = tfgan.gan_train_ops(
      model=gan_model,
      loss=gan_loss,
      generator_optimizer=generator_optimizer,
      discriminator_optimizer=discriminator_optimizer)
  # generator : discriminator = 1:10
  train_step_fn = tfgan.get_sequential_train_steps(
      namedtuples.GANTrainSteps(1, 10))

  with tf.Session() as sess:
    saver = tf.train.Saver()
    if not tf.train.get_checkpoint_state(conf.stageI_model_path):
      sess.run(tf.global_variables_initializer())
    else:
      saver.restore(sess, tf.train.latest_checkpoint(conf.stageI_model_path))
    train_writer = tf.summary.FileWriter(conf.stageI_model_path, sess.graph)
    merged = tf.summary.merge_all()
    step = sess.run(global_step)
    with slim.queues.QueueRunners(sess):
      for _ in range(conf.training_steps):
        cur_loss, _ = train_step_fn(sess, gan_train_ops, global_step, {})
        step += 1
        # Write merged summaries.
        if step % 50 == 0:
          summary = sess.run(merged)
          train_writer.add_summary(summary, step)
        # Save variables.
        if step % 200 == 0:
          saver.save(sess, conf.stageI_model_path, global_step)
        # Visualize generated data.
        if step % 1000 == 0:
          gen_data = sess.run(gan_model.generated_data)
          datas = visualize_data(gen_data)
          scipy.misc.toimage(datas).save('image/{}.jpg'.format(step))
status_message = tf.string_join(
    ['Starting train step: ',
     tf.as_string(tf.train.get_or_create_global_step())],
    name='status_message')

# Alternative: let `tfgan.gan_train` drive the loop with joint train hooks.
# tfgan.gan_train(
#     train_ops,
#     hooks=[tf.train.StopAtStepHook(num_steps=max_number_of_steps),
#            tf.train.LoggingTensorHook([status_message], every_n_iter=10)],
#     logdir=train_log_dir,
#     get_hooks_fn=tfgan.get_joint_train_hooks())

# train_step_fn = tfgan.get_sequential_train_steps()
train_step_fn = lib.get_sequential_train_steps(
    namedtuples.GANTrainSteps(num_g_steps, num_d_steps))
global_step = tf.train.get_or_create_global_step()

for var in tf.trainable_variables():
  tf.summary.histogram(var.name, var)
merged_summary_op = tf.summary.merge_all()

loss_values = []
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  summary_writer = tf.summary.FileWriter(
      logs_path, graph=tf.get_default_graph())
def start_train():
  stageI_train_input = data_provider.get_stage_I_train_input_fn()
  condition, real_image = stageI_train_input()
  stageI_gan_model, _ = StageI.get_model_and_loss(condition, real_image)
  conf.is_training = True
  need_to_init = False
  condition, real_image = data_provider.get_stage_II_train_input_fn()()

  with tf.Session() as sess:
    saver = tf.train.Saver()
    if tf.train.get_checkpoint_state(conf.stageII_model_path):
      saver.restore(sess, tf.train.latest_checkpoint(conf.stageII_model_path))
    else:
      if not tf.train.get_checkpoint_state(conf.stageI_model_path):
        raise FileNotFoundError("StageI model not found!")
      saver.restore(sess, tf.train.latest_checkpoint(conf.stageI_model_path))
      sI_var = tf.global_variables()
      need_to_init = True
      sess.run(tf.assign(global_step, 0))

    with tf.variable_scope('Generator', reuse=True):
      gen_img = stageI_gan_model.generator_fn(condition)

    # Stage I does not take part in training, so drop its update ops.
    param = tf.get_collection_ref(tf.GraphKeys.UPDATE_OPS)
    del param[:]

    gen_input = {"gen_img": gen_img, "caption": condition["caption"]}
    stageII_gan_model, gan_loss = get_model_and_loss(gen_input, real_image)
    gan_train_ops = tfgan.gan_train_ops(
        model=stageII_gan_model,
        loss=gan_loss,
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer)

    if need_to_init:
      var_to_init = [x for x in tf.global_variables() if x not in sI_var]
      sess.run(tf.variables_initializer(var_to_init))

    train_step_fn = tfgan.get_sequential_train_steps(
        namedtuples.GANTrainSteps(1, 10))
    train_writer = tf.summary.FileWriter(conf.stageII_model_path, sess.graph)
    merged = tf.summary.merge_all()
    step = sess.run(global_step)
    with slim.queues.QueueRunners(sess):
      for _ in range(conf.training_steps):
        cur_loss, _ = train_step_fn(sess, gan_train_ops, global_step, {})
        step += 1
        # Write merged summaries.
        if step % 50 == 0:
          summary = sess.run(merged)
          train_writer.add_summary(summary, step)
        # Save variables.
        if step % 200 == 0:
          saver.save(sess, conf.stageII_model_path, global_step)
        # Visualize generated data.
        if step % 1000 == 0:
          gen_data = sess.run(stageII_gan_model.generated_data)
          datas = visualize_data(gen_data)
          scipy.misc.toimage(datas).save('image/{}.jpg'.format(step))
def __init__(self,
             # Arguments to construct the `model_fn`.
             generator_fn=None,
             discriminator_fn=None,
             generator_loss_fn=None,
             discriminator_loss_fn=None,
             generator_optimizer=None,
             discriminator_optimizer=None,
             get_eval_metric_ops_fn=None,
             add_summaries=None,
             joint_train=False,
             gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1),
             # TPUEstimator options.
             model_dir=None,
             config=None,
             params=None,
             use_tpu=True,
             train_batch_size=None,
             eval_batch_size=None,
             predict_batch_size=None,
             batch_axis=None,
             eval_on_tpu=True,
             export_to_tpu=True,
             warm_start_from=None):
  """Initializes a TPUGANEstimator instance.

  Args:
    generator_fn: A python function that takes a Tensor, Tensor list, or
      Tensor dictionary as inputs and returns the outputs of the GAN
      generator. See `TFGAN` for more details and examples. Additionally, if
      it has an argument called `mode`, the Estimator's `mode` will be passed
      in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch
      normalization.
    discriminator_fn: A python function that takes the output of
      `generator_fn` or real data in the GAN setup, and `generator_inputs`.
      Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details
      and examples.
    generator_loss_fn: The loss function on the generator. Takes a `GANModel`
      tuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      `GANModel` tuple.
    generator_optimizer: The optimizer for generator updates, or a function
      that takes no arguments and returns an optimizer. This function will be
      called when the default graph is the `GANEstimator`'s graph, so
      utilities like `tf.contrib.framework.get_or_create_global_step` will
      work.
    discriminator_optimizer: Same as `generator_optimizer`, but for the
      discriminator updates.
    get_eval_metric_ops_fn: A function that takes a list of arguments and
      returns a dict of metric results keyed by name. The output of this
      function is passed into `tf.estimator.EstimatorSpec` during evaluation.
      The arguments must be:
        * generator_inputs
        * generated_data
        * real_data
        * discriminator_real_outputs
        * discriminator_gen_outputs
    add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
      This is ignored for jobs that run on TPU, such as the train job if
      `use_tpu` is `True` or the eval job if `eval_on_tpu` is `True`.
    joint_train: A Python boolean. If `True`, jointly train the generator and
      the discriminator. If `False`, sequentially train them. See `train.py`
      in TFGAN for more details on the differences between the two GAN
      training methods.
    gan_train_steps: A `tfgan.GANTrainSteps` named tuple describing the ratio
      of generator to discriminator steps. For now, only supports 1:1
      training.
    model_dir: Same as `TPUEstimator`: Directory to save model parameters,
      graph, etc. This can also be used to load checkpoints from the
      directory into an estimator to continue training a previously saved
      model. If `None`, the model_dir in `config` will be used if set. If
      both are set, they must be the same. If both are `None`, a temporary
      directory will be used.
    config: Same as `TPUEstimator`: A `tpu_config.RunConfig` configuration
      object. Cannot be `None`.
    params: Same as `TPUEstimator`: An optional `dict` of hyper parameters
      that will be passed into `input_fn` and `model_fn`. Keys are names of
      parameters, values are basic python types. There are reserved keys for
      `TPUEstimator`, including 'batch_size'.
    use_tpu: Same as `TPUEstimator`: A bool indicating whether TPU support is
      enabled. Currently, TPU training and evaluation respect this bit, but
      eval_on_tpu can override execution of eval. See below.
      Predict still happens on CPU.
    train_batch_size: Same as `TPUEstimator`: An int representing the global
      training batch size. TPUEstimator transforms this global batch size to
      a per-shard batch size, as params['batch_size'], when calling `input_fn`
      and `model_fn`. Cannot be `None` if `use_tpu` is `True`. Must be
      divisible by total number of replicas.
    eval_batch_size: Same as `TPUEstimator`: An int representing evaluation
      batch size. Must be divisible by total number of replicas.
    predict_batch_size: Same as `TPUEstimator`: An int representing the
      prediction batch size. Must be divisible by total number of replicas.
    batch_axis: Same as `TPUEstimator`: A python tuple of int values
      describing how each tensor produced by the Estimator `input_fn` should
      be split across the TPU compute shards. For example, if your input_fn
      produced (images, labels) where the images tensor is in `HWCN` format,
      your shard dimensions would be [3, 0], where 3 corresponds to the `N`
      dimension of your images Tensor, and 0 corresponds to the dimension
      along which to split the labels to match up with the corresponding
      images. If None is supplied, and per_host_input_for_training is True,
      batches will be sharded based on the major dimension. If
      tpu_config.per_host_input_for_training is False or `PER_HOST_V2`,
      batch_axis is ignored.
    eval_on_tpu: Same as `TPUEstimator`: If False, evaluation runs on CPU or
      GPU. In this case, the model_fn must return `EstimatorSpec` when called
      with `mode` as `EVAL`.
    export_to_tpu: Same as `TPUEstimator`: If True, `export_savedmodel()`
      exports a metagraph for serving on TPU besides the one on CPU.
    warm_start_from: Same as `TPUEstimator`: Optional string filepath to a
      checkpoint or SavedModel to warm-start from, or a
      `tf.estimator.WarmStartSettings` object to fully configure
      warm-starting. If the string filepath is provided instead of a
      `WarmStartSettings`, then all variables are warm-started, and it is
      assumed that vocabularies and Tensor names are unchanged.

  Raises:
    ValueError: If loss functions aren't callable.
    ValueError: If `gan_train_steps` isn't a `tfgan_tuples.GANTrainSteps`
      tuple.
    ValueError: If `gan_train_steps` isn't 1:1 training.
  """
  if not callable(generator_loss_fn):
    raise ValueError('generator_loss_fn must be callable.')
  if not callable(discriminator_loss_fn):
    raise ValueError('discriminator_loss_fn must be callable.')
  if not isinstance(gan_train_steps, tfgan_tuples.GANTrainSteps):
    raise ValueError(
        '`gan_train_steps` must be `tfgan_tuples.GANTrainSteps`. Instead, '
        'was type: %s' % type(gan_train_steps))
  if (gan_train_steps.generator_train_steps != 1 or
      gan_train_steps.discriminator_train_steps != 1):
    raise ValueError('Estimator currently only supports 1:1 training.')

  if use_tpu:
    generator_optimizer = _maybe_make_cross_shard_optimizer(
        generator_optimizer)
    discriminator_optimizer = _maybe_make_cross_shard_optimizer(
        discriminator_optimizer)

  def _model_fn(features, labels, mode, params):
    """GANEstimator model function."""
    del params  # unused
    if mode not in [model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
                    model_fn_lib.ModeKeys.PREDICT]:
      raise ValueError('Mode not recognized: %s' % mode)
    real_data = labels  # rename inputs for clarity
    generator_inputs = features  # rename inputs for clarity

    # Make GANModel, which encapsulates the GAN model architectures.
    # TODO(joelshor): Switch TF-GAN over to TPU-compatible summaries, then
    # remove `add_summaries` logic below.
    is_on_tpu = _is_on_tpu(mode, use_tpu, eval_on_tpu)
    gan_model = gan_estimator_lib._get_gan_model(  # pylint:disable=protected-access
        mode,
        generator_fn,
        discriminator_fn,
        real_data,
        generator_inputs,
        add_summaries=None if is_on_tpu else add_summaries)

    # Make the TPUEstimatorSpec, which incorporates the GANModel, losses,
    # eval metrics, and optimizers (if required).
    estimator_spec = _get_estimator_spec(
        mode, gan_model, generator_loss_fn, discriminator_loss_fn,
        get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
        joint_train, is_on_tpu, gan_train_steps)
    assert isinstance(estimator_spec, tpu_estimator.TPUEstimatorSpec)
    return estimator_spec

  super(TPUGANEstimator, self).__init__(
      model_fn=_model_fn,
      model_dir=model_dir,
      config=config,
      params=params,
      use_tpu=use_tpu,
      train_batch_size=train_batch_size,
      eval_batch_size=eval_batch_size,
      predict_batch_size=predict_batch_size,
      batch_axis=batch_axis,
      eval_on_tpu=eval_on_tpu,
      export_to_tpu=export_to_tpu,
      warm_start_from=warm_start_from)
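# Construction sketch (hypothetical names): instantiating `TPUGANEstimator`
# per the argument docs above. `unconditional_generator`,
# `unconditional_discriminator`, `train_input_fn`, and the FLAGS values are
# placeholders for project-specific code.
config = tf.contrib.tpu.RunConfig(
    cluster=tf.contrib.cluster_resolver.TPUClusterResolver(tpu=FLAGS.tpu),
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=100))
gan_estimator = TPUGANEstimator(
    generator_fn=unconditional_generator,
    discriminator_fn=unconditional_discriminator,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4),
    joint_train=True,  # `gan_train_steps` must remain the default 1:1
    train_batch_size=32,
    use_tpu=FLAGS.use_tpu,
    config=config)
gan_estimator.train(train_input_fn, max_steps=10000)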