Code example #1
  def test_get_estimator_spec(self, mode, joint_train):
    with ops.Graph().as_default():
      self._gan_model = get_dummy_gan_model()
      spec = estimator._get_estimator_spec(
          mode,
          self._gan_model,
          generator_loss_fn=losses.wasserstein_generator_loss,
          discriminator_loss_fn=losses.wasserstein_discriminator_loss,
          get_eval_metric_ops_fn=get_metrics,
          generator_optimizer=self._generator_optimizer,
          discriminator_optimizer=self._discriminator_optimizer,
          joint_train=joint_train,
          is_on_tpu=FLAGS.use_tpu,
          gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1))

    self.assertIsInstance(spec, tpu_estimator.TPUEstimatorSpec)
    self.assertEqual(mode, spec.mode)
    if mode == model_fn_lib.ModeKeys.PREDICT:
      self.assertEqual({'generated_data': self._gan_model.generated_data},
                       spec.predictions)
    elif mode == model_fn_lib.ModeKeys.TRAIN:
      self.assertShapeEqual(np.array(0), spec.loss)  # must be a scalar
      self.assertIsNotNone(spec.train_op)
      self.assertIsNotNone(spec.training_hooks)
    elif mode == model_fn_lib.ModeKeys.EVAL:
      self.assertEqual(self._gan_model.generated_data, spec.predictions)
      self.assertShapeEqual(np.array(0), spec.loss)  # must be a scalar
      self.assertIsNotNone(spec.eval_metrics)
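The `mode` and `joint_train` arguments above come from a parameterized test decorator that this excerpt omits. A minimal sketch of that wiring, assuming `absl.testing.parameterized` (used throughout the TF test suite); the class name is hypothetical:

from absl.testing import parameterized

class GetEstimatorSpecTest(parameterized.TestCase):  # hypothetical name

  @parameterized.parameters(
      (model_fn_lib.ModeKeys.TRAIN, True),    # (mode, joint_train)
      (model_fn_lib.ModeKeys.TRAIN, False),
      (model_fn_lib.ModeKeys.EVAL, False),
      (model_fn_lib.ModeKeys.PREDICT, False),
  )
  def test_get_estimator_spec(self, mode, joint_train):
    ...  # body as in the excerpt above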
Code example #2
def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
    """Returns a hooks function for sequential GAN training.

  When using these train hooks, IT IS RECOMMENDED TO USE `use_locking=True` ON
  ALL OPTIMIZERS TO AVOID RACE CONDITIONS.

  The order of steps taken is:
  1) Combined generator and discriminator steps
  2) Generator only steps, if any remain
  3) Discriminator only steps, if any remain

  **NOTE**: Unlike `get_sequential_train_hooks`, this method performs updates
  for the generator and discriminator simultaneously whenever possible. This
  reduces the number of `tf.Session` calls, and can also change the training
  semantics.

  To illustrate the difference look at the following example:

  `train_steps=namedtuples.GANTrainSteps(3, 5)` will cause
  `get_sequential_train_hooks` to make 8 session calls:
    1) 3 generator steps
    2) 5 discriminator steps

  In contrast, `get_joint_train_steps` will make 5 session calls:
  1) 3 generator + discriminator steps
  2) 2 discriminator steps

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator
      and discriminator training steps to take.

  Returns:
    A function that takes a GANTrainOps tuple and returns a list of hooks.
  """
    g_steps = train_steps.generator_train_steps
    d_steps = train_steps.discriminator_train_steps
    # Get the number of each type of step that should be run.
    num_d_and_g_steps = min(g_steps, d_steps)
    num_g_steps = g_steps - num_d_and_g_steps
    num_d_steps = d_steps - num_d_and_g_steps

    def get_hooks(train_ops):
        g_op = train_ops.generator_train_op
        d_op = train_ops.discriminator_train_op

        joint_hook = RunTrainOpsHook([g_op, d_op], num_d_and_g_steps)
        g_hook = RunTrainOpsHook(g_op, num_g_steps)
        d_hook = RunTrainOpsHook(d_op, num_d_steps)

        return [joint_hook, g_hook, d_hook]

    return get_hooks
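The 3:5 example in the docstring can be checked with plain Python. This sketch reproduces the step arithmetic from the function body; `GANTrainSteps` is re-declared locally so no TensorFlow import is needed:

from collections import namedtuple

# Local stand-in with the same shape as tfgan's namedtuple.
GANTrainSteps = namedtuple(
    'GANTrainSteps', ['generator_train_steps', 'discriminator_train_steps'])

steps = GANTrainSteps(3, 5)
num_joint = min(steps.generator_train_steps, steps.discriminator_train_steps)
num_g_only = steps.generator_train_steps - num_joint
num_d_only = steps.discriminator_train_steps - num_joint
# 3 joint calls + 0 generator-only + 2 discriminator-only = 5 session calls.
print(num_joint, num_g_only, num_d_only)  # -> 3 0 2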
Code example #3
File: train_test.py  Project: zmk520/tensorflow
    def _test_multiple_steps_helper(self, get_hooks_fn_fn):
        train_ops = self._gan_train_ops(generator_add=10,
                                        discriminator_add=100)
        train_steps = namedtuples.GANTrainSteps(generator_train_steps=3,
                                                discriminator_train_steps=4)
        final_step = train.gan_train(
            train_ops,
            get_hooks_fn=get_hooks_fn_fn(train_steps),
            logdir='',
            hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)])

        self.assertTrue(np.isscalar(final_step))
        self.assertEqual(1 + 3 * 10 + 4 * 100, final_step)
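The expected value decomposes cleanly: `generator_add=10` is applied on each of the 3 generator steps, `discriminator_add=100` on each of the 4 discriminator steps, and `StopAtStepHook(num_steps=1)` allows one ordinary global-step increment:

expected_final_step = 1 + 3 * 10 + 4 * 100  # = 431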
Code example #4
File: train_test.py  Project: zmk520/tensorflow
    def test_supervisor_run_gan_model_train_ops_multiple_steps(self):
        step = training_util.create_global_step()
        train_ops = namedtuples.GANTrainOps(
            generator_train_op=constant_op.constant(3.0),
            discriminator_train_op=constant_op.constant(2.0),
            global_step_inc_op=step.assign_add(1))
        train_steps = namedtuples.GANTrainSteps(generator_train_steps=3,
                                                discriminator_train_steps=4)

        final_loss = slim_learning.train(
            train_op=train_ops,
            logdir='',
            global_step=step,
            number_of_steps=1,
            train_step_fn=train.get_sequential_train_steps(train_steps))
        self.assertTrue(np.isscalar(final_loss))
        self.assertEqual(17.0, final_loss)
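The 17.0 expectation is the sum of the per-op "losses": the generator op evaluates to 3.0 and runs 3 times, the discriminator op evaluates to 2.0 and runs 4 times:

expected_final_loss = 3 * 3.0 + 4 * 2.0  # = 17.0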
Code example #5
def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a hooks function for sequential GAN training.

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator
      and discriminator training steps to take.

  Returns:
    A function that takes a GANTrainOps tuple and returns a list of hooks.
  """
  def get_hooks(train_ops):
    generator_hook = RunTrainOpsHook(train_ops.generator_train_op,
                                     train_steps.generator_train_steps)
    discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op,
                                         train_steps.discriminator_train_steps)
    return [generator_hook, discriminator_hook]
  return get_hooks
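A minimal usage sketch, wiring the hooks function into the standard `tfgan.gan_train` entry point; `train_ops` and the log directory are hypothetical:

hooks_fn = get_sequential_train_hooks(namedtuples.GANTrainSteps(1, 5))
tfgan.gan_train(train_ops,               # a GANTrainOps tuple, assumed built
                logdir='/tmp/gan_logs',  # hypothetical path
                get_hooks_fn=hooks_fn)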
Code example #6
def get_sequential_train_steps(train_steps=namedtuples.GANTrainSteps(1, 1)):
    """Returns a thin wrapper around slim.learning.train_step, for GANs.

  This function is to provide support for the Supervisor. For new code, please
  use `MonitoredSession` and `get_sequential_train_hooks`.

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator
      and discriminator training steps to take.

  Returns:
    A function that can be used for `train_step_fn` for GANs.
  """
    def sequential_train_steps(sess, train_ops, global_step,
                               train_step_kwargs):
        """A thin wrapper around slim.learning.train_step, for GANs.

    Args:
      sess: A Tensorflow session.
      train_ops: A GANTrainOps tuple of train ops to run.
      global_step: The global step.
      train_step_kwargs: Dictionary controlling `train_step` behavior.

    Returns:
      A scalar final loss and a bool whether or not the train loop should stop.
    """
        # Only run `should_stop` at the end, if required. Make a local copy of
        # `train_step_kwargs`, if necessary, so as not to modify the caller's
        # dictionary.
        should_stop_op, train_kwargs = None, train_step_kwargs
        if 'should_stop' in train_step_kwargs:
            should_stop_op = train_step_kwargs['should_stop']
            train_kwargs = train_step_kwargs.copy()
            del train_kwargs['should_stop']

        # Run generator training steps.
        gen_loss = 0
        for _ in range(train_steps.generator_train_steps):
            cur_gen_loss, _ = slim_learning.train_step(
                sess, train_ops.generator_train_op, global_step, train_kwargs)
            gen_loss += cur_gen_loss

        # Run discriminator training steps.
        dis_loss = 0
        for _ in range(train_steps.discriminator_train_steps):
            cur_dis_loss, _ = slim_learning.train_step(
                sess, train_ops.discriminator_train_op, global_step,
                train_kwargs)
            dis_loss += cur_dis_loss

        sess.run(train_ops.global_step_inc_op)

        # Run the `should_stop` op after the global step has been incremented, so
        # that the `should_stop` aligns with the proper `global_step` count.
        if should_stop_op is not None:
            should_stop = sess.run(should_stop_op)
        else:
            should_stop = False

        return gen_loss + dis_loss, should_stop

    return sequential_train_steps
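The docstring recommends `MonitoredSession` with `get_sequential_train_hooks` for new code; a minimal sketch of that replacement path, with a hypothetical `train_ops`:

hooks = get_sequential_train_hooks(namedtuples.GANTrainSteps(1, 1))(train_ops)
with tf.train.MonitoredTrainingSession(hooks=hooks) as sess:
    while not sess.should_stop():
        # The hooks run the generator/discriminator ops before each call;
        # this op only advances the global step.
        sess.run(train_ops.global_step_inc_op)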
Code example #7
def start_train():
    conf.is_training = True
    train_input = data_provider.get_stage_I_train_input_fn()
    condition, real_image = train_input()

    gan_model, gan_loss = get_model_and_loss(condition, real_image)

    gan_train_ops = tfgan.gan_train_ops(
        model=gan_model,
        loss=gan_loss,
        generator_optimizer=generator_optimizer,
        discriminator_optimizer=discriminator_optimizer)

    # generator : discriminator steps = 1:10
    train_step_fn = tfgan.get_sequential_train_steps(
        namedtuples.GANTrainSteps(1, 10))

    with tf.Session() as sess:
        # get_saver
        saver = tf.train.Saver()

        if not tf.train.get_checkpoint_state(conf.stageI_model_path):
            init_op = tf.global_variables_initializer()
            sess.run(init_op)
        else:
            saver.restore(sess,
                          tf.train.latest_checkpoint(conf.stageI_model_path))

        train_writer = tf.summary.FileWriter(conf.stageI_model_path,
                                             sess.graph)
        merged = tf.summary.merge_all()
        step = sess.run(global_step)

        with slim.queues.QueueRunners(sess):
            for _ in range(conf.training_steps):
                # test data
                # data = sess.run(real_image)
                # data = visualize_data(data)
                # img = Image.fromarray(data, 'RGB')
                # img.show()
                # data = sess.run(gan_model.generator_inputs)
                # print(data)
                #
                step = step + 1

                cur_loss, _ = train_step_fn(sess, gan_train_ops, global_step,
                                            {})
                # NOTE: this builds a new summary op each iteration, after
                # `merged` was already created, so it never reaches the writer.
                tf.summary.scalar("loss", cur_loss)
                if step % 50 == 0:
                    summary_str = sess.run(merged)
                    train_writer.add_summary(summary_str, step)

                # save var
                if step % 200 == 0:
                    saver.save(sess, conf.stageI_model_path, global_step)

                # visualize data
                if step % 1000 == 0:
                    gen_data = sess.run(gan_model.generated_data)
                    datas = visualize_data(gen_data)
                    scipy.misc.toimage(datas).save('image/{}.jpg'.format(step))
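The in-loop `tf.summary.scalar` call noted above has no effect because it creates ops after `merge_all` ran. A corrected pattern registers the summary once and feeds the Python float at run time; `loss_ph` is a name introduced here, not from the project:

loss_ph = tf.placeholder(tf.float32, shape=(), name='loss_ph')
tf.summary.scalar('loss', loss_ph)
merged = tf.summary.merge_all()
# ... then, inside the training loop:
summary_str = sess.run(merged, feed_dict={loss_ph: cur_loss})
train_writer.add_summary(summary_str, step)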
Code example #8
    status_message = tf.string_join(
        ['Starting train step: ',
         tf.as_string(tf.train.get_or_create_global_step())],
        name='status_message')
    """tfgan.gan_train(
        train_ops,
        hooks=[tf.train.StopAtStepHook(num_steps = max_number_of_steps),
            tf.train.LoggingTensorHook([status_message], every_n_iter=10)],
        logdir=train_log_dir,
        get_hooks_fn=tfgan.get_joint_train_hooks())"""

#train_step_fn = tfgan.get_sequential_train_steps()
train_step_fn = lib.get_sequential_train_steps(
    namedtuples.GANTrainSteps(num_g_steps, num_d_steps))
global_step = tf.train.get_or_create_global_step()

for var in tf.trainable_variables():
    tf.summary.histogram(var.name, var)

merged_summary_op = tf.summary.merge_all()

loss_values = []

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    summary_writer = tf.summary.FileWriter(logs_path,
                                           graph=tf.get_default_graph())
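The excerpt ends before the training loop. A plausible continuation under the names it defines (a sketch only; `num_steps` and `train_ops` are assumed to be defined elsewhere):

    for step in range(num_steps):
        cur_loss, _ = train_step_fn(sess, train_ops, global_step, {})
        loss_values.append(cur_loss)
        if step % 100 == 0:
            summary_writer.add_summary(sess.run(merged_summary_op), step)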
Code example #9
File: StageII.py  Project: sssste/StackGAN
def start_train():
    stageI_train_input = data_provider.get_stage_I_train_input_fn()
    condition, real_image = stageI_train_input()
    stageI_gan_model, _ = StageI.get_model_and_loss(condition, real_image)
    conf.is_training = True
    need_to_init = False

    condition, real_image = data_provider.get_stage_II_train_input_fn()()

    with tf.Session() as sess:
        # get_saver
        saver = tf.train.Saver()

        if tf.train.get_checkpoint_state(conf.stageII_model_path):
            saver.restore(sess,
                          tf.train.latest_checkpoint(conf.stageII_model_path))
        else:
            if not tf.train.get_checkpoint_state(conf.stageI_model_path):
                raise FileNotFoundError("StageI model not found!")
            else:
                saver.restore(
                    sess, tf.train.latest_checkpoint(conf.stageI_model_path))
                sI_var = tf.global_variables()
                need_to_init = True
                sess.run(tf.assign(global_step, 0))  # reset the step for StageII

        with tf.variable_scope('Generator', reuse=True):
            gen_img = stageI_gan_model.generator_fn(condition)

        # StageI does not take part in training

        param = tf.get_collection_ref(tf.GraphKeys.UPDATE_OPS)
        del param[:]

        gen_input = {"gen_img": gen_img, "caption": condition["caption"]}

        stageII_gan_model, gan_loss = get_model_and_loss(gen_input, real_image)

        gan_train_ops = tfgan.gan_train_ops(
            model=stageII_gan_model,
            loss=gan_loss,
            generator_optimizer=generator_optimizer,
            discriminator_optimizer=discriminator_optimizer)

        if need_to_init:
            var_to_init = [x for x in tf.global_variables() if x not in sI_var]
            sess.run(tf.initialize_variables(var_to_init))

        train_step_fn = tfgan.get_sequential_train_steps(
            namedtuples.GANTrainSteps(1, 10))

        train_writer = tf.summary.FileWriter(conf.stageII_model_path,
                                             sess.graph)
        merged = tf.summary.merge_all()
        step = sess.run(global_step)

        with slim.queues.QueueRunners(sess):
            for _ in range(conf.training_steps):
                # test data (debug only, disabled as in the StageI script)
                # data = sess.run(real_image)
                # data = visualize_data(data)
                # img = Image.fromarray(data, 'RGB')
                # img.show()
                # data = sess.run(stageII_gan_model.generator_inputs)
                # print(data)
                #
                step = step + 1

                cur_loss, _ = train_step_fn(sess, gan_train_ops, global_step,
                                            {})
                # NOTE: this builds a new summary op each iteration, after
                # `merged` was already created, so it never reaches the writer.
                tf.summary.scalar("loss", cur_loss)
                if step % 50 == 0:
                    summary_str = sess.run(merged)
                    train_writer.add_summary(summary_str, step)

                # save var
                if step % 200 == 0:
                    saver.save(sess, conf.stageII_model_path, global_step)

                # visualize data
                if step % 1000 == 0:
                    gen_data = sess.run(stageII_gan_model.generated_data)
                    datas = visualize_data(gen_data)
                    scipy.misc.toimage(datas).save('image/{}.jpg'.format(step))
Code example #10
    def __init__(
            self,
            # Arguments to construct the `model_fn`.
            generator_fn=None,
            discriminator_fn=None,
            generator_loss_fn=None,
            discriminator_loss_fn=None,
            generator_optimizer=None,
            discriminator_optimizer=None,
            get_eval_metric_ops_fn=None,
            add_summaries=None,
            joint_train=False,
            gan_train_steps=tfgan_tuples.GANTrainSteps(1, 1),
            # TPUEstimator options.
            model_dir=None,
            config=None,
            params=None,
            use_tpu=True,
            train_batch_size=None,
            eval_batch_size=None,
            predict_batch_size=None,
            batch_axis=None,
            eval_on_tpu=True,
            export_to_tpu=True,
            warm_start_from=None):
        """Initializes a TPUGANEstimator instance.

    Args:
      generator_fn: A python function that takes a Tensor, Tensor list, or
        Tensor dictionary as inputs and returns the outputs of the GAN
        generator. See `TFGAN` for more details and examples. Additionally, if
        it has an argument called `mode`, the Estimator's `mode` will be passed
        in (e.g. TRAIN, EVAL, PREDICT). This is useful for things like batch
        normalization.
      discriminator_fn: A python function that takes the output of
        `generator_fn` or real data in the GAN setup, and `generator_inputs`.
        Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details
        and examples.
      generator_loss_fn: The loss function on the generator. Takes a `GANModel`
        tuple.
      discriminator_loss_fn: The loss function on the discriminator. Takes a
        `GANModel` tuple.
      generator_optimizer: The optimizer for generator updates, or a function
        that takes no arguments and returns an optimizer. This function will
        be called when the default graph is the `GANEstimator`'s graph, so
        utilities like `tf.contrib.framework.get_or_create_global_step` will
        work.
      discriminator_optimizer: Same as `generator_optimizer`, but for the
        discriminator updates.
      get_eval_metric_ops_fn: A function that takes a list of arguments and
        returns a dict of metric results keyed by name. The output of this
        function is passed into `tf.estimator.EstimatorSpec` during evaluation.
        The arguments must be:
            * generator_inputs
            * generated_data
            * real_data
            * discriminator_real_outputs
            * discriminator_gen_outputs
      add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
        This is ignored for jobs that run on TPU, such as the train job if
        `use_tpu` is `True` or the eval job if `eval_on_tpu` is `True`.
      joint_train: A Python boolean. If `True`, jointly train the generator and
        the discriminator. If `False`, sequentially train them. See `train.py`
        in TFGAN for more details on the differences between the two GAN
        training methods.
      gan_train_steps: A `tfgan.GANTrainSteps` named tuple describing the ratio
        of generator to discriminator steps. For now, only supports 1:1
        training.
      model_dir: Same as `TPUEstimator`: Directory to save model parameters,
        graph and etc. This can also be used to load checkpoints from the
        directory into an estimator to continue training a previously saved
        model. If `None`, the model_dir in `config` will be used if set. If both
        are set, they must be same. If both are `None`, a temporary directory
        will be used.
      config: Same as `TPUEstimator`: A `tpu_config.RunConfig` configuration
        object. Cannot be `None`.
      params: Same as `TPUEstimator`: An optional `dict` of hyper parameters
        that will be passed into `input_fn` and `model_fn`.  Keys are names of
        parameters, values are basic python types. There are reserved keys for
        `TPUEstimator`, including 'batch_size'.
      use_tpu: Same as `TPUEstimator`: A bool indicating whether TPU support is
        enabled. Currently, TPU training and evaluation respect this bit, but
        eval_on_tpu can override execution of eval. See below. Predict still
        happens on CPU.
      train_batch_size: Same as `TPUEstimator`: An int representing the global
        training batch size. TPUEstimator transforms this global batch size to a
        per-shard batch size, as params['batch_size'], when calling `input_fn`
        and `model_fn`. Cannot be `None` if `use_tpu` is `True`. Must be
        divisible by total number of replicas.
      eval_batch_size: Same as `TPUEstimator`: An int representing evaluation
        batch size. Must be divisible by total number of replicas.
      predict_batch_size: Same as `TPUEstimator`: An int representing the
        prediction batch size. Must be divisible by total number of replicas.
      batch_axis: Same as `TPUEstimator`: A python tuple of int values
        describing how each tensor produced by the Estimator `input_fn` should
        be split across the TPU compute shards. For example, if your input_fn
        produced (images, labels) where the images tensor is in `HWCN` format,
        your shard dimensions would be [3, 0], where 3 corresponds to the `N`
        dimension of your images Tensor, and 0 corresponds to the dimension
        along which to split the labels to match up with the corresponding
        images. If None is supplied, and per_host_input_for_training is True,
        batches will be sharded based on the major dimension. If
        tpu_config.per_host_input_for_training is False or `PER_HOST_V2`,
        batch_axis is ignored.
      eval_on_tpu: Same as `TPUEstimator`: If False, evaluation runs on CPU or
        GPU. In this case, the model_fn must return `EstimatorSpec` when called
        with `mode` as `EVAL`.
      export_to_tpu: Same as `TPUEstimator`: If True, `export_savedmodel()`
        exports a metagraph for serving on TPU besides the one on CPU.
      warm_start_from: Same as `TPUEstimator`: Optional string filepath to a
        checkpoint or SavedModel to warm-start from, or a
        `tf.estimator.WarmStartSettings` object to fully configure
        warm-starting.  If the string filepath is provided instead of a
        `WarmStartSettings`, then all variables are warm-started, and it is
        assumed that vocabularies and Tensor names are unchanged.

    Raises:
      ValueError: If loss functions aren't callable.
      ValueError: If `gan_train_steps` isn't a `tfgan_tuples.GANTrainSteps`
        tuple.
      ValueError: If `gan_train_steps` isn't 1:1 training.
    """
        if not callable(generator_loss_fn):
            raise ValueError('generator_loss_fn must be callable.')
        if not callable(discriminator_loss_fn):
            raise ValueError('discriminator_loss_fn must be callable.')
        if not isinstance(gan_train_steps, tfgan_tuples.GANTrainSteps):
            raise ValueError(
                '`gan_train_steps` must be `tfgan_tuples.GANTrainSteps`. Instead, '
                'was type: %s' % type(gan_train_steps))
        if (gan_train_steps.generator_train_steps != 1
                or gan_train_steps.discriminator_train_steps != 1):
            raise ValueError('Estimator currently only supports 1:1 training.')

        if use_tpu:
            generator_optimizer = _maybe_make_cross_shard_optimizer(
                generator_optimizer)
            discriminator_optimizer = _maybe_make_cross_shard_optimizer(
                discriminator_optimizer)

        def _model_fn(features, labels, mode, params):
            """GANEstimator model function."""
            del params  # unused
            if mode not in [
                    model_fn_lib.ModeKeys.TRAIN, model_fn_lib.ModeKeys.EVAL,
                    model_fn_lib.ModeKeys.PREDICT
            ]:
                raise ValueError('Mode not recognized: %s' % mode)
            real_data = labels  # rename inputs for clarity
            generator_inputs = features  # rename inputs for clarity

            # Make GANModel, which encapsulates the GAN model architectures.
            # TODO(joelshor): Switch TF-GAN over to TPU-compatible summaries, then
            # remove `add_summaries` logic below.
            is_on_tpu = _is_on_tpu(mode, use_tpu, eval_on_tpu)
            gan_model = gan_estimator_lib._get_gan_model(  # pylint:disable=protected-access
                mode,
                generator_fn,
                discriminator_fn,
                real_data,
                generator_inputs,
                add_summaries=None if is_on_tpu else add_summaries)

            # Make the TPUEstimatorSpec, which incorporates the GANModel, losses, eval
            # metrics, and optimizers (if required).
            estimator_spec = _get_estimator_spec(
                mode, gan_model, generator_loss_fn, discriminator_loss_fn,
                get_eval_metric_ops_fn, generator_optimizer,
                discriminator_optimizer, joint_train, is_on_tpu,
                gan_train_steps)
            assert isinstance(estimator_spec, tpu_estimator.TPUEstimatorSpec)
            return estimator_spec

        super(TPUGANEstimator,
              self).__init__(model_fn=_model_fn,
                             model_dir=model_dir,
                             config=config,
                             params=params,
                             use_tpu=use_tpu,
                             train_batch_size=train_batch_size,
                             eval_batch_size=eval_batch_size,
                             predict_batch_size=predict_batch_size,
                             batch_axis=batch_axis,
                             eval_on_tpu=eval_on_tpu,
                             export_to_tpu=export_to_tpu,
                             warm_start_from=warm_start_from)
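A minimal construction sketch for the estimator above (hypothetical generator/discriminator functions, `input_fn`, and `RunConfig`; not a full TPU setup):

est = TPUGANEstimator(
    generator_fn=my_generator,          # assumed user-defined
    discriminator_fn=my_discriminator,  # assumed user-defined
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4),
    joint_train=True,
    train_batch_size=32,
    use_tpu=False,                      # run on CPU/GPU while iterating
    config=my_run_config)               # assumed tpu_config.RunConfig
est.train(input_fn=my_input_fn, max_steps=1000)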