def _get_estimator_spec(mode,
                        gan_model,
                        loss_fn,
                        get_eval_metric_ops_fn,
                        generator_optimizer,
                        discriminator_optimizer,
                        get_hooks_fn=None):
  """Get the EstimatorSpec for the current mode."""
  if mode == model_fn_lib.ModeKeys.PREDICT:
    estimator_spec = model_fn_lib.EstimatorSpec(
        mode=mode, predictions=gan_model.generated_data)
  else:
    gan_loss = loss_fn(gan_model)
    if mode == model_fn_lib.ModeKeys.EVAL:
      estimator_spec = _get_eval_estimator_spec(gan_model, gan_loss,
                                                get_eval_metric_ops_fn)
    else:  # model_fn_lib.ModeKeys.TRAIN:
      gopt = (
          generator_optimizer()
          if callable(generator_optimizer) else generator_optimizer)
      dopt = (
          discriminator_optimizer()
          if callable(discriminator_optimizer) else discriminator_optimizer)
      get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
      estimator_spec = _get_train_estimator_spec(gan_model, gan_loss, gopt,
                                                 dopt, get_hooks_fn)

  return estimator_spec
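
A minimal usage sketch, not from the original source: passing the optimizers as zero-argument callables defers their construction until the TRAIN branch actually runs, which is why the helper accepts either optimizer instances or callables. `build_gan_model` and the learning rates below are hypothetical.

def _example_model_fn(features, labels, mode, params):
  gan_model = build_gan_model(features, labels)  # hypothetical helper
  return _get_estimator_spec(
      mode,
      gan_model,
      loss_fn=tfgan.gan_loss,
      get_eval_metric_ops_fn=None,
      # Callables: no optimizer variables are created in PREDICT/EVAL modes.
      generator_optimizer=lambda: tf.train.AdamOptimizer(1e-4, 0.5),
      discriminator_optimizer=lambda: tf.train.AdamOptimizer(1e-4, 0.5))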
Example No. 2
def _get_estimator_spec(mode,
                        gan_model,
                        loss_fn,
                        get_eval_metric_ops_fn,
                        generator_optimizer,
                        discriminator_optimizer,
                        get_hooks_fn=None):
    """Get the EstimatorSpec for the current mode."""
    if mode == model_fn_lib.ModeKeys.PREDICT:
        estimator_spec = model_fn_lib.EstimatorSpec(
            mode=mode, predictions=gan_model.generated_data)
    else:
        gan_loss = loss_fn(gan_model)
        if mode == model_fn_lib.ModeKeys.EVAL:
            estimator_spec = _get_eval_estimator_spec(gan_model, gan_loss,
                                                      get_eval_metric_ops_fn)
        else:  # model_fn_lib.ModeKeys.TRAIN:
            gopt = (generator_optimizer()
                    if callable(generator_optimizer) else generator_optimizer)
            dopt = (discriminator_optimizer()
                    if callable(discriminator_optimizer) else
                    discriminator_optimizer)
            get_hooks_fn = (
                get_hooks_fn or tfgan_train.get_sequential_train_hooks())
            estimator_spec = _get_train_estimator_spec(gan_model, gan_loss,
                                                       gopt, dopt,
                                                       get_hooks_fn)

    return estimator_spec
Example No. 3
def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
             discriminator_optimizer, use_loss_summaries=True,
             get_hooks_fn=tfgan_train.get_sequential_train_hooks(),
             name=None):
  """Creates a `GANHead`.

  Args:
    generator_loss_fn: A TFGAN loss function for the generator. Takes a
      `GANModel` and returns a scalar.
    discriminator_loss_fn: Same as `generator_loss_fn`, but for the
      discriminator.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: Same as `generator_optimizer`, but for the
      discriminator updates.
    use_loss_summaries: If `True`, add loss summaries; if `False`, do not;
      if `None`, use the loss functions' defaults.
    get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
      list of hooks.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`.

  Returns:
    An instance of `GANHead`.
  """
  return GANHead(generator_loss_fn=generator_loss_fn,
                 discriminator_loss_fn=discriminator_loss_fn,
                 generator_optimizer=generator_optimizer,
                 discriminator_optimizer=discriminator_optimizer,
                 use_loss_summaries=use_loss_summaries,
                 get_hooks_fn=get_hooks_fn,
                 name=name)
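
A hedged call-site sketch, assuming the `tf.contrib.gan` namespace is imported as `tfgan`; the Wasserstein losses and Adam hyperparameters are illustrative, not taken from the source.

head = gan_head(
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5),
    name='gan_head')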
Example No. 4
def _get_estimator_spec(
    mode, gan_model, generator_loss_fn, discriminator_loss_fn,
    get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
    get_hooks_fn=None, use_loss_summaries=True, is_chief=True):
  """Get the EstimatorSpec for the current mode."""
  if mode == model_fn_lib.ModeKeys.PREDICT:
    estimator_spec = model_fn_lib.EstimatorSpec(
        mode=mode, predictions=gan_model.generated_data)
  else:
    gan_loss = tfgan_tuples.GANLoss(
        generator_loss=generator_loss_fn(
            gan_model, add_summaries=use_loss_summaries),
        discriminator_loss=discriminator_loss_fn(
            gan_model, add_summaries=use_loss_summaries))
    if mode == model_fn_lib.ModeKeys.EVAL:
      estimator_spec = _get_eval_estimator_spec(
          gan_model, gan_loss, get_eval_metric_ops_fn)
    else:  # model_fn_lib.ModeKeys.TRAIN:
      gopt = (generator_optimizer() if callable(generator_optimizer) else
              generator_optimizer)
      dopt = (discriminator_optimizer() if callable(discriminator_optimizer)
              else discriminator_optimizer)
      get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
      estimator_spec = _get_train_estimator_spec(
          gan_model, gan_loss, gopt, dopt, get_hooks_fn, is_chief=is_chief)

  return estimator_spec
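
In practice this helper is reached through a `GANEstimator`-style constructor that forwards its arguments here. A sketch under that assumption; `generator` and `discriminator` are user-supplied network functions, and the model_dir and hyperparameters are hypothetical.

gan_estimator = tfgan.estimator.GANEstimator(
    model_dir='/tmp/gan',
    generator_fn=generator,
    discriminator_fn=discriminator,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, 0.5),
    use_loss_summaries=True)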
Example No. 5
def gan_head(generator_loss_fn,
             discriminator_loss_fn,
             generator_optimizer,
             discriminator_optimizer,
             use_loss_summaries=True,
             get_hooks_fn=tfgan_train.get_sequential_train_hooks(),
             name=None):
    """Creates a `GANHead`.

  Args:
    generator_loss_fn: A TFGAN loss function for the generator. Takes a
      `GANModel` and returns a scalar.
    discriminator_loss_fn: Same as `generator_loss_fn`, but for the
      discriminator.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: Same as `generator_optimizer`, but for the
      discriminator updates.
    use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
        If `None`, uses defaults.
    get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
        of hooks.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`.

  Returns:
    An instance of `GANHead`.
  """
    return GANHead(generator_loss_fn=generator_loss_fn,
                   discriminator_loss_fn=discriminator_loss_fn,
                   generator_optimizer=generator_optimizer,
                   discriminator_optimizer=discriminator_optimizer,
                   use_loss_summaries=use_loss_summaries,
                   get_hooks_fn=get_hooks_fn,
                   name=name)
Example No. 6
    def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
        model = create_gan_model_fn()
        loss = train.gan_loss(model)

        g_opt = get_sync_optimizer()
        d_opt = get_sync_optimizer()
        train_ops = train.gan_train_ops(model,
                                        loss,
                                        g_opt,
                                        d_opt,
                                        summarize_gradients=True,
                                        colocate_gradients_with_ops=True)

        sequential_train_hooks = train.get_sequential_train_hooks()(train_ops)
        self.assertLen(sequential_train_hooks, 4)
        sync_opts = [
            hook._sync_optimizer
            for hook in sequential_train_hooks if isinstance(
                hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
        ]
        self.assertLen(sync_opts, 2)
        self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

        joint_train_hooks = train.get_joint_train_hooks()(train_ops)
        self.assertLen(joint_train_hooks, 5)
        sync_opts = [
            hook._sync_optimizer for hook in joint_train_hooks if isinstance(
                hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
        ]
        self.assertLen(sync_opts, 2)
        self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))
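
The test depends on a `get_sync_optimizer` helper defined elsewhere in the test module; each sync optimizer contributes one `_SyncReplicasOptimizerHook`, which is why two of the four sequential hooks (and two of the five joint hooks) are sync hooks. A plausible definition, shown only as an assumption:

def get_sync_optimizer():
  return tf.train.SyncReplicasOptimizer(
      tf.train.GradientDescentOptimizer(learning_rate=1.0),
      replicas_to_aggregate=1)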
Example No. 7
def _get_estimator_spec(
    mode, gan_model, generator_loss_fn, discriminator_loss_fn,
    get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
    get_hooks_fn=None, use_loss_summaries=True, is_chief=True):
  """Get the EstimatorSpec for the current mode."""
  if mode == model_fn_lib.ModeKeys.PREDICT:
    estimator_spec = model_fn_lib.EstimatorSpec(
        mode=mode, predictions=gan_model.generated_data)
  else:
    gan_loss = tfgan_tuples.GANLoss(
        generator_loss=generator_loss_fn(
            gan_model, add_summaries=use_loss_summaries),
        discriminator_loss=discriminator_loss_fn(
            gan_model, add_summaries=use_loss_summaries))
    if mode == model_fn_lib.ModeKeys.EVAL:
      estimator_spec = _get_eval_estimator_spec(
          gan_model, gan_loss, get_eval_metric_ops_fn)
    else:  # model_fn_lib.ModeKeys.TRAIN:
      if callable(generator_optimizer):
        generator_optimizer = generator_optimizer()
      if callable(discriminator_optimizer):
        discriminator_optimizer = discriminator_optimizer()
      get_hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
      estimator_spec = _get_train_estimator_spec(
          gan_model, gan_loss, generator_optimizer, discriminator_optimizer,
          get_hooks_fn, is_chief=is_chief)

  return estimator_spec
Example No. 8
  def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
    model = create_gan_model_fn()
    loss = train.gan_loss(model)

    g_opt = get_sync_optimizer()
    d_opt = get_sync_optimizer()
    train_ops = train.gan_train_ops(
        model,
        loss,
        g_opt,
        d_opt,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    sequential_train_hooks = train.get_sequential_train_hooks()(train_ops)
    self.assertLen(sequential_train_hooks, 4)
    sync_opts = [
        hook._sync_optimizer for hook in sequential_train_hooks if
        isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)]
    self.assertLen(sync_opts, 2)
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))

    joint_train_hooks = train.get_joint_train_hooks()(train_ops)
    self.assertLen(joint_train_hooks, 5)
    sync_opts = [
        hook._sync_optimizer for hook in joint_train_hooks if
        isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)]
    self.assertLen(sync_opts, 2)
    self.assertSetEqual(frozenset(sync_opts), frozenset((g_opt, d_opt)))
Example No. 9
    def __init__(self,
                 generator_loss_fn,
                 discriminator_loss_fn,
                 generator_optimizer,
                 discriminator_optimizer,
                 use_loss_summaries=True,
                 get_hooks_fn=None,
                 get_eval_metric_ops_fn=None,
                 name=None):
        """`Head` for GAN training.

    Args:
      generator_loss_fn: A TFGAN loss function for the generator. Takes a
        `GANModel` and returns a scalar.
      discriminator_loss_fn: Same as `generator_loss_fn`, but for the
      discriminator.
      generator_optimizer: The optimizer for generator updates.
      discriminator_optimizer: Same as `generator_optimizer`, but for the
        discriminator updates.
      use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
        If `None`, uses defaults.
      get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
        list of hooks. Defaults to `train.get_sequential_train_hooks()`
      get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
        dict of metric results keyed by name. The output of this function is
        passed into `tf.estimator.EstimatorSpec` during evaluation.
      name: name of the head. If provided, summary and metrics keys will be
        suffixed by `"/" + name`.
    """

        if not callable(generator_loss_fn):
            raise TypeError('generator_loss_fn must be callable.')
        if not callable(discriminator_loss_fn):
            raise TypeError('discriminator_loss_fn must be callable.')
        if use_loss_summaries not in [True, False, None]:
            raise ValueError('use_loss_summaries must be True, False or None.')
        if get_hooks_fn is not None and not callable(get_hooks_fn):
            raise TypeError('get_hooks_fn must be callable.')
        if name is not None and not isinstance(name, str):
            raise TypeError('name must be string.')

        if get_hooks_fn is None:
            get_hooks_fn = tfgan_train.get_sequential_train_hooks()

        if use_loss_summaries in [True, False]:
            generator_loss_fn = functools.partial(
                generator_loss_fn, add_summaries=use_loss_summaries)
            discriminator_loss_fn = functools.partial(
                discriminator_loss_fn, add_summaries=use_loss_summaries)
        self._generator_loss_fn = generator_loss_fn
        self._discriminator_loss_fn = discriminator_loss_fn
        self._generator_optimizer = generator_optimizer
        self._discriminator_optimizer = discriminator_optimizer
        self._get_hooks_fn = get_hooks_fn
        self._get_eval_metric_ops_fn = get_eval_metric_ops_fn
        self._name = name
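
The explicit `use_loss_summaries in [True, False]` check exists so that `None` means "keep the loss function's own default" rather than being coerced to `False`. A toy illustration of the `functools.partial` binding (the loss function here is hypothetical):

import functools

def toy_loss(gan_model, add_summaries=True):
  return 0.0

bound = functools.partial(toy_loss, add_summaries=False)
# bound(model) now always runs with add_summaries=False; with
# use_loss_summaries=None the original toy_loss would be kept unchanged.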
Example No. 10
  def __init__(self, generator_loss_fn, discriminator_loss_fn,
               generator_optimizer, discriminator_optimizer,
               use_loss_summaries=True,
               get_hooks_fn=None,
               get_eval_metric_ops_fn=None,
               name=None):
    """`Head` for GAN training.

    Args:
      generator_loss_fn: A TFGAN loss function for the generator. Takes a
        `GANModel` and returns a scalar.
      discriminator_loss_fn: Same as `generator_loss_fn`, but for the
        discriminator.
      generator_optimizer: The optimizer for generator updates.
      discriminator_optimizer: Same as `generator_optimizer`, but for the
        discriminator updates.
      use_loss_summaries: If `True`, add loss summaries; if `False`, do not;
        if `None`, use the loss functions' defaults.
      get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
        list of hooks. Defaults to `train.get_sequential_train_hooks()`.
      get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
        dict of metric results keyed by name. The output of this function is
        passed into `tf.estimator.EstimatorSpec` during evaluation.
      name: name of the head. If provided, summary and metrics keys will be
        suffixed by `"/" + name`.
    """

    if not callable(generator_loss_fn):
      raise TypeError('generator_loss_fn must be callable.')
    if not callable(discriminator_loss_fn):
      raise TypeError('discriminator_loss_fn must be callable.')
    if use_loss_summaries not in [True, False, None]:
      raise ValueError('use_loss_summaries must be True, False or None.')
    if get_hooks_fn is not None and not callable(get_hooks_fn):
      raise TypeError('get_hooks_fn must be callable.')
    if name is not None and not isinstance(name, str):
      raise TypeError('name must be string.')

    if get_hooks_fn is None:
      get_hooks_fn = tfgan_train.get_sequential_train_hooks()

    if use_loss_summaries in [True, False]:
      generator_loss_fn = functools.partial(
          generator_loss_fn, add_summaries=use_loss_summaries)
      discriminator_loss_fn = functools.partial(
          discriminator_loss_fn, add_summaries=use_loss_summaries)
    self._generator_loss_fn = generator_loss_fn
    self._discriminator_loss_fn = discriminator_loss_fn
    self._generator_optimizer = generator_optimizer
    self._discriminator_optimizer = discriminator_optimizer
    self._get_hooks_fn = get_hooks_fn
    self._get_eval_metric_ops_fn = get_eval_metric_ops_fn
    self._name = name
Example No. 11
    def __init__(self,
                 generator_loss_fn,
                 discriminator_loss_fn,
                 generator_optimizer,
                 discriminator_optimizer,
                 use_loss_summaries=True,
                 get_hooks_fn=None,
                 name=None):
        """`Head` for GAN training.

    Args:
      generator_loss_fn: A TFGAN loss function for the generator. Takes a
        `GANModel` and returns a scalar.
      discriminator_loss_fn: Same as `generator_loss_fn`, but for the
      discriminator.
      generator_optimizer: The optimizer for generator updates.
      discriminator_optimizer: Same as `generator_optimizer`, but for the
        discriminator updates.
      use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
        If `None`, uses defaults.
      get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
        of hooks. Defaults to `train.get_sequential_train_hooks()`
      name: name of the head. If provided, summary and metrics keys will be
        suffixed by `"/" + name`.
    """
        if get_hooks_fn is None:
            get_hooks_fn = tfgan_train.get_sequential_train_hooks()
        # TODO(joelshor): Validate inputs.

        if use_loss_summaries in [True, False]:
            generator_loss_fn = functools.partial(
                generator_loss_fn, add_summaries=use_loss_summaries)
            discriminator_loss_fn = functools.partial(
                discriminator_loss_fn, add_summaries=use_loss_summaries)
        self._generator_loss_fn = generator_loss_fn
        self._discriminator_loss_fn = discriminator_loss_fn
        self._generator_optimizer = generator_optimizer
        self._discriminator_optimizer = discriminator_optimizer
        self._get_hooks_fn = get_hooks_fn
        self._name = name
Example No. 12
  def __init__(self, generator_loss_fn, discriminator_loss_fn,
               generator_optimizer, discriminator_optimizer,
               use_loss_summaries=True,
               get_hooks_fn=None,
               name=None):
    """`Head` for GAN training.

    Args:
      generator_loss_fn: A TFGAN loss function for the generator. Takes a
        `GANModel` and returns a scalar.
      discriminator_loss_fn: Same as `generator_loss_fn`, but for the
        discriminator.
      generator_optimizer: The optimizer for generator updates.
      discriminator_optimizer: Same as `generator_optimizer`, but for the
        discriminator updates.
      use_loss_summaries: If `True`, add loss summaries; if `False`, do not;
        if `None`, use the loss functions' defaults.
      get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
        list of hooks. Defaults to `train.get_sequential_train_hooks()`.
      name: name of the head. If provided, summary and metrics keys will be
        suffixed by `"/" + name`.
    """
    if get_hooks_fn is None:
      get_hooks_fn = tfgan_train.get_sequential_train_hooks()
    # TODO(joelshor): Validate inputs.

    if use_loss_summaries in [True, False]:
      generator_loss_fn = functools.partial(
          generator_loss_fn, add_summaries=use_loss_summaries)
      discriminator_loss_fn = functools.partial(
          discriminator_loss_fn, add_summaries=use_loss_summaries)
    self._generator_loss_fn = generator_loss_fn
    self._discriminator_loss_fn = discriminator_loss_fn
    self._generator_optimizer = generator_optimizer
    self._discriminator_optimizer = discriminator_optimizer
    self._get_hooks_fn = get_hooks_fn
    self._name = name
Example No. 13
def model_fn(features, labels, mode, params):
    is_chief = not tf.get_variable_scope().reuse

    batch_size = tf.shape(labels)[0]
    noise = tf.random_normal([batch_size, FLAGS.emb_dim])
    noise = tf.nn.l2_normalize(noise, axis=1)
    gan_model = tfgan.gan_model(generator_fn=generator,
                                discriminator_fn=discriminator,
                                real_data=features[:, 1:],
                                generator_inputs=(noise, labels - 1),
                                check_shapes=False)
    if is_chief:
        for variable in tf.trainable_variables():
            tf.summary.histogram(variable.op.name, variable)
        tf.summary.histogram('logits/gen_logits',
                             gan_model.discriminator_gen_outputs[0])
        tf.summary.histogram('logits/real_logits',
                             gan_model.discriminator_real_outputs[0])

    def gen_loss_fn(gan_model, add_summaries):
        # The generator is trained through the RL objective added by rl_loss
        # below, so the TFGAN generator loss is a placeholder zero.
        return 0

    def dis_loss_fn(gan_model, add_summaries):
        discriminator_real_outputs = gan_model.discriminator_real_outputs
        discriminator_gen_outputs = gan_model.discriminator_gen_outputs
        real_logits = tf.boolean_mask(discriminator_real_outputs[0],
                                      discriminator_real_outputs[1])
        gen_logits = tf.boolean_mask(discriminator_gen_outputs[0],
                                     discriminator_gen_outputs[1])
        return modified_discriminator_loss(real_logits,
                                           gen_logits,
                                           add_summaries=add_summaries)

    with tf.name_scope('losses'):
        gan_loss = tfgan.gan_loss(
            gan_model,
            generator_loss_fn=gen_loss_fn,
            discriminator_loss_fn=dis_loss_fn,
            gradient_penalty_weight=10 if FLAGS.wass else 0,
            add_summaries=is_chief)
        if is_chief:
            tfgan.eval.add_regularization_loss_summaries(gan_model)
    gan_loss = rl_loss(gan_model, gan_loss, add_summaries=is_chief)
    loss = gan_loss.generator_loss + gan_loss.discriminator_loss

    with tf.name_scope('train'):
        gen_opt = tf.train.AdamOptimizer(params.gen_lr, 0.5)
        dis_opt = tf.train.AdamOptimizer(params.dis_lr, 0.5)
        if params.multi_gpu:
            gen_opt = tf.contrib.estimator.TowerOptimizer(gen_opt)
            dis_opt = tf.contrib.estimator.TowerOptimizer(dis_opt)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=gen_opt,
            discriminator_optimizer=dis_opt,
            transform_grads_fn=transform_grads_fn,
            summarize_gradients=is_chief,
            check_for_unused_update_ops=True,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
        train_op = train_ops.global_step_inc_op
        train_hooks = get_sequential_train_hooks()(train_ops)

    # Summarize the generated caption on the fly.
    if is_chief:
        with open('data/word_counts.txt', 'r') as f:
            dic = list(f)
            dic = [i.split()[0] for i in dic]
            dic.append('<unk>')
            dic = tf.convert_to_tensor(dic)
        sentence = crop_sentence(gan_model.generated_data[0][0], FLAGS.end_id)
        sentence = tf.gather(dic, sentence)
        real = crop_sentence(gan_model.real_data[0], FLAGS.end_id)
        real = tf.gather(dic, real)
        train_hooks.append(
            tf.train.LoggingTensorHook(
                {'fake': sentence, 'real': real}, every_n_iter=100))
        tf.summary.text('fake', sentence)

    gen_var = tf.trainable_variables('Generator')
    dis_var = []
    dis_var.extend(tf.trainable_variables('Discriminator/rnn'))
    dis_var.extend(tf.trainable_variables('Discriminator/embedding'))
    saver = tf.train.Saver(gen_var + dis_var)

    def init_fn(scaffold, session):
        saver.restore(session, FLAGS.sae_ckpt)

    scaffold = tf.train.Scaffold(init_fn=init_fn)

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      scaffold=scaffold,
                                      training_hooks=train_hooks)
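
A hedged sketch of driving the model_fn above with `tf.estimator`; the model_dir, hyperparameter values, and `train_input_fn` are assumptions, though the `gen_lr`, `dis_lr`, and `multi_gpu` fields match what the model_fn reads from `params`.

params = tf.contrib.training.HParams(gen_lr=1e-4, dis_lr=1e-4, multi_gpu=False)
estimator = tf.estimator.Estimator(
    model_fn=model_fn, model_dir='/tmp/captioning', params=params)
estimator.train(input_fn=train_input_fn, max_steps=1000000)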
def model_fn(features, labels, mode, params):
    """The full unsupervised captioning model."""
    is_chief = not tf.get_variable_scope().reuse

    with slim.arg_scope(inception_v4.inception_v4_arg_scope()):
        net, _ = inception_v4.inception_v4(features['im'],
                                           None,
                                           is_training=False)
    net = tf.squeeze(net, [1, 2])
    inc_saver = tf.train.Saver(tf.global_variables('InceptionV4'))

    gan_model = tfgan.gan_model(generator_fn=generator,
                                discriminator_fn=discriminator,
                                real_data=labels['sentence'][:, 1:],
                                generator_inputs=(net, labels['len'] - 1),
                                check_shapes=False)

    if is_chief:
        for variable in tf.trainable_variables():
            tf.summary.histogram(variable.op.name, variable)
        tf.summary.histogram('logits/gen_logits',
                             gan_model.discriminator_gen_outputs[0])
        tf.summary.histogram('logits/real_logits',
                             gan_model.discriminator_real_outputs[0])

    def gen_loss_fn(gan_model, add_summaries):
        # As above, the generator is trained through the RL objective added
        # by rl_loss below, so the TFGAN generator loss is a placeholder zero.
        return 0

    def dis_loss_fn(gan_model, add_summaries):
        discriminator_real_outputs = gan_model.discriminator_real_outputs
        discriminator_gen_outputs = gan_model.discriminator_gen_outputs
        real_logits = tf.boolean_mask(discriminator_real_outputs[0],
                                      discriminator_real_outputs[1])
        gen_logits = tf.boolean_mask(discriminator_gen_outputs[0],
                                     discriminator_gen_outputs[1])
        return modified_discriminator_loss(real_logits,
                                           gen_logits,
                                           add_summaries=add_summaries)

    with tf.name_scope('losses'):
        pool_fn = functools.partial(tfgan.features.tensor_pool,
                                    pool_size=FLAGS.pool_size)
        gan_loss = tfgan.gan_loss(
            gan_model,
            generator_loss_fn=gen_loss_fn,
            discriminator_loss_fn=dis_loss_fn,
            gradient_penalty_weight=10 if FLAGS.wass else 0,
            tensor_pool_fn=pool_fn if FLAGS.use_pool else None,
            add_summaries=is_chief)
        if is_chief:
            tfgan.eval.add_regularization_loss_summaries(gan_model)
    gan_loss = rl_loss(gan_model,
                       gan_loss,
                       features['classes'],
                       features['scores'],
                       features['num'],
                       add_summaries=is_chief)
    sen_ae_loss = sentence_ae(gan_model, features, labels, is_chief)
    loss = gan_loss.generator_loss + gan_loss.discriminator_loss + sen_ae_loss
    gan_loss = gan_loss._replace(generator_loss=gan_loss.generator_loss +
                                 sen_ae_loss)

    with tf.name_scope('train'):
        gen_opt = tf.train.AdamOptimizer(params.gen_lr, 0.5)
        dis_opt = tf.train.AdamOptimizer(params.dis_lr, 0.5)
        if params.multi_gpu:
            gen_opt = tf.contrib.estimator.TowerOptimizer(gen_opt)
            dis_opt = tf.contrib.estimator.TowerOptimizer(dis_opt)
        train_ops = tfgan.gan_train_ops(
            gan_model,
            gan_loss,
            generator_optimizer=gen_opt,
            discriminator_optimizer=dis_opt,
            transform_grads_fn=transform_grads_fn,
            summarize_gradients=is_chief,
            check_for_unused_update_ops=not FLAGS.use_pool,
            aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
        train_op = train_ops.global_step_inc_op
        train_hooks = get_sequential_train_hooks()(train_ops)

    # Summarize the generated caption on the fly.
    if is_chief:
        with open('data/word_counts.txt', 'r') as f:
            dic = list(f)
            dic = [i.split()[0] for i in dic]
            dic.append('<unk>')
            dic = tf.convert_to_tensor(dic)
        sentence = crop_sentence(gan_model.generated_data[0][0], FLAGS.end_id)
        sentence = tf.gather(dic, sentence)
        real = crop_sentence(gan_model.real_data[0], FLAGS.end_id)
        real = tf.gather(dic, real)
        train_hooks.append(
            tf.train.LoggingTensorHook(
                {'fake': sentence, 'real': real}, every_n_iter=100))
        tf.summary.text('fake', sentence)
        tf.summary.image('im', features['im'][None, 0])

    gen_saver = tf.train.Saver(tf.trainable_variables('Generator'))
    dis_var = []
    dis_var.extend(tf.trainable_variables('Discriminator/rnn'))
    dis_var.extend(tf.trainable_variables('Discriminator/embedding'))
    dis_var.extend(tf.trainable_variables('Discriminator/fc'))
    dis_saver = tf.train.Saver(dis_var)

    def init_fn(scaffold, session):
        inc_saver.restore(session, FLAGS.inc_ckpt)
        if FLAGS.imcap_ckpt:
            gen_saver.restore(session, FLAGS.imcap_ckpt)
        if FLAGS.sae_ckpt:
            dis_saver.restore(session, FLAGS.sae_ckpt)

    scaffold = tf.train.Scaffold(init_fn=init_fn)

    return tf.estimator.EstimatorSpec(mode=mode,
                                      loss=loss,
                                      train_op=train_op,
                                      scaffold=scaffold,
                                      training_hooks=train_hooks)
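
`TowerOptimizer` only takes effect when the model_fn itself is replicated across GPUs; a sketch of the companion wrapping, assuming the same `tf.contrib.estimator` API used above.

if params.multi_gpu:
  model_fn = tf.contrib.estimator.replicate_model_fn(model_fn)
estimator = tf.estimator.Estimator(model_fn=model_fn, params=params)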