def _get_estimator_spec(mode, gan_model, loss_fn, get_eval_metric_ops_fn,
                        generator_optimizer, discriminator_optimizer,
                        get_hooks_fn=None):
  """Get the EstimatorSpec for the current mode.

  PREDICT returns the generated data as predictions without computing a loss.
  EVAL and TRAIN both compute `loss_fn(gan_model)` first; every mode other
  than PREDICT and EVAL is handled as TRAIN.

  Args:
    mode: A `model_fn_lib.ModeKeys` value.
    gan_model: The `GANModel` the spec is built from.
    loss_fn: Callable mapping `gan_model` to the GAN loss.
    get_eval_metric_ops_fn: Forwarded to `_get_eval_estimator_spec` in EVAL.
    generator_optimizer: Generator optimizer, or a zero-argument callable
      that returns one.
    discriminator_optimizer: Same as `generator_optimizer`, for the
      discriminator.
    get_hooks_fn: Training hook factory; a falsy value falls back to
      `tfgan_train.get_sequential_train_hooks()`.

  Returns:
    The `EstimatorSpec` for `mode`.
  """
  if mode == model_fn_lib.ModeKeys.PREDICT:
    return model_fn_lib.EstimatorSpec(
        mode=mode, predictions=gan_model.generated_data)

  gan_loss = loss_fn(gan_model)
  if mode == model_fn_lib.ModeKeys.EVAL:
    return _get_eval_estimator_spec(gan_model, gan_loss,
                                    get_eval_metric_ops_fn)

  # Remaining case: model_fn_lib.ModeKeys.TRAIN. Optimizers may be passed as
  # factories; they are only invoked on this branch.
  def _resolve(optimizer_or_fn):
    return optimizer_or_fn() if callable(optimizer_or_fn) else optimizer_or_fn

  gen_optimizer = _resolve(generator_optimizer)
  dis_optimizer = _resolve(discriminator_optimizer)
  if not get_hooks_fn:
    get_hooks_fn = tfgan_train.get_sequential_train_hooks()
  return _get_train_estimator_spec(gan_model, gan_loss, gen_optimizer,
                                   dis_optimizer, get_hooks_fn)
def _get_estimator_spec(mode, gan_model, loss_fn, get_eval_metric_ops_fn,
                        generator_optimizer, discriminator_optimizer,
                        get_hooks_fn=None):
  """Get the EstimatorSpec for the current mode.

  Args:
    mode: A `model_fn_lib.ModeKeys` value. Anything other than PREDICT or
      EVAL is treated as TRAIN.
    gan_model: The `GANModel` the spec is built from.
    loss_fn: Callable mapping `gan_model` to the GAN loss (not called in
      PREDICT).
    get_eval_metric_ops_fn: Forwarded to `_get_eval_estimator_spec` in EVAL.
    generator_optimizer: Generator optimizer, or a zero-argument callable
      that returns one.
    discriminator_optimizer: Same as `generator_optimizer`, for the
      discriminator.
    get_hooks_fn: Training hook factory; a falsy value falls back to
      `tfgan_train.get_sequential_train_hooks()`.

  Returns:
    The `EstimatorSpec` for `mode`.
  """
  if mode == model_fn_lib.ModeKeys.PREDICT:
    return model_fn_lib.EstimatorSpec(
        mode=mode, predictions=gan_model.generated_data)

  gan_loss = loss_fn(gan_model)
  if mode == model_fn_lib.ModeKeys.EVAL:
    return _get_eval_estimator_spec(gan_model, gan_loss,
                                    get_eval_metric_ops_fn)

  # model_fn_lib.ModeKeys.TRAIN: materialize lazily-specified optimizers,
  # generator first, then discriminator.
  if callable(generator_optimizer):
    gen_optimizer = generator_optimizer()
  else:
    gen_optimizer = generator_optimizer
  if callable(discriminator_optimizer):
    dis_optimizer = discriminator_optimizer()
  else:
    dis_optimizer = discriminator_optimizer
  hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
  return _get_train_estimator_spec(gan_model, gan_loss, gen_optimizer,
                                   dis_optimizer, hooks_fn)
def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
             discriminator_optimizer, use_loss_summaries=True,
             get_hooks_fn=None, name=None):
  """Creates a `GANHead`.

  Args:
    generator_loss_fn: A TFGAN loss function for the generator. Takes a
      `GANModel` and returns a scalar.
    discriminator_loss_fn: Same as `generator_loss_fn`, but for the
      discriminator.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: Same as `generator_optimizer`, but for the
      discriminator updates.
    use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
      If `None`, uses defaults.
    get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
      of hooks. If `None` (the default), `GANHead` substitutes
      `tfgan_train.get_sequential_train_hooks()`.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`.

  Returns:
    An instance of `GANHead`.
  """
  # `None` is used as the default instead of calling
  # `tfgan_train.get_sequential_train_hooks()` in the signature, which would
  # run once at import time. `GANHead.__init__` maps `None` to that same
  # hooks function, so the resulting head is unchanged.
  return GANHead(generator_loss_fn=generator_loss_fn,
                 discriminator_loss_fn=discriminator_loss_fn,
                 generator_optimizer=generator_optimizer,
                 discriminator_optimizer=discriminator_optimizer,
                 use_loss_summaries=use_loss_summaries,
                 get_hooks_fn=get_hooks_fn,
                 name=name)
def _get_estimator_spec(
    mode, gan_model, generator_loss_fn, discriminator_loss_fn,
    get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
    get_hooks_fn=None, use_loss_summaries=True, is_chief=True):
  """Get the EstimatorSpec for the current mode.

  Args:
    mode: A `model_fn_lib.ModeKeys` value. Anything other than PREDICT or
      EVAL is treated as TRAIN.
    gan_model: The `GANModel` the spec is built from.
    generator_loss_fn: Generator loss function, called as
      `generator_loss_fn(gan_model, add_summaries=use_loss_summaries)`.
    discriminator_loss_fn: Same as `generator_loss_fn`, for the
      discriminator.
    get_eval_metric_ops_fn: Forwarded to `_get_eval_estimator_spec` in EVAL.
    generator_optimizer: Generator optimizer, or a zero-argument callable
      that returns one.
    discriminator_optimizer: Same as `generator_optimizer`, for the
      discriminator.
    get_hooks_fn: Training hook factory; a falsy value falls back to
      `tfgan_train.get_sequential_train_hooks()`.
    use_loss_summaries: Passed as `add_summaries` to both loss functions.
    is_chief: Forwarded to `_get_train_estimator_spec`.

  Returns:
    The `EstimatorSpec` for `mode`.
  """
  if mode == model_fn_lib.ModeKeys.PREDICT:
    return model_fn_lib.EstimatorSpec(
        mode=mode, predictions=gan_model.generated_data)

  # Generator loss is built before the discriminator loss.
  g_loss = generator_loss_fn(gan_model, add_summaries=use_loss_summaries)
  d_loss = discriminator_loss_fn(gan_model, add_summaries=use_loss_summaries)
  gan_loss = tfgan_tuples.GANLoss(generator_loss=g_loss,
                                  discriminator_loss=d_loss)

  if mode == model_fn_lib.ModeKeys.EVAL:
    return _get_eval_estimator_spec(gan_model, gan_loss,
                                    get_eval_metric_ops_fn)

  # model_fn_lib.ModeKeys.TRAIN
  gen_optimizer = (generator_optimizer() if callable(generator_optimizer)
                   else generator_optimizer)
  dis_optimizer = (discriminator_optimizer()
                   if callable(discriminator_optimizer)
                   else discriminator_optimizer)
  hooks_fn = get_hooks_fn or tfgan_train.get_sequential_train_hooks()
  return _get_train_estimator_spec(gan_model, gan_loss, gen_optimizer,
                                   dis_optimizer, hooks_fn, is_chief=is_chief)
def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
             discriminator_optimizer, use_loss_summaries=True,
             get_hooks_fn=None, name=None):
  """Creates a `GANHead`.

  Args:
    generator_loss_fn: A TFGAN loss function for the generator. Takes a
      `GANModel` and returns a scalar.
    discriminator_loss_fn: Same as `generator_loss_fn`, but for the
      discriminator.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: Same as `generator_optimizer`, but for the
      discriminator updates.
    use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
      If `None`, uses defaults.
    get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
      of hooks. If `None` (the default), `GANHead` substitutes
      `tfgan_train.get_sequential_train_hooks()`.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`.

  Returns:
    An instance of `GANHead`.
  """
  # `None` is used as the default instead of calling
  # `tfgan_train.get_sequential_train_hooks()` in the signature, which would
  # run once at import time. `GANHead.__init__` maps `None` to that same
  # hooks function, so the resulting head is unchanged.
  return GANHead(generator_loss_fn=generator_loss_fn,
                 discriminator_loss_fn=discriminator_loss_fn,
                 generator_optimizer=generator_optimizer,
                 discriminator_optimizer=discriminator_optimizer,
                 use_loss_summaries=use_loss_summaries,
                 get_hooks_fn=get_hooks_fn,
                 name=name)
def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
  """Both hook factories include sync hooks for the two sync optimizers."""
  model = create_gan_model_fn()
  loss = train.gan_loss(model)

  g_opt = get_sync_optimizer()
  d_opt = get_sync_optimizer()
  train_ops = train.gan_train_ops(
      model,
      loss,
      g_opt,
      d_opt,
      summarize_gradients=True,
      colocate_gradients_with_ops=True)
  expected_sync_opts = frozenset((g_opt, d_opt))

  # Each factory must return a fixed number of hooks, exactly two of which
  # are sync-replicas hooks wrapping `g_opt` and `d_opt`.
  for make_hooks_fn, expected_num_hooks in (
      (train.get_sequential_train_hooks, 4),
      (train.get_joint_train_hooks, 5)):
    hooks = make_hooks_fn()(train_ops)
    self.assertLen(hooks, expected_num_hooks)
    found_sync_opts = [
        hook._sync_optimizer
        for hook in hooks
        if isinstance(hook, sync_replicas_optimizer._SyncReplicasOptimizerHook)
    ]
    self.assertLen(found_sync_opts, 2)
    self.assertSetEqual(frozenset(found_sync_opts), expected_sync_opts)
def _get_estimator_spec(
    mode, gan_model, generator_loss_fn, discriminator_loss_fn,
    get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
    get_hooks_fn=None, use_loss_summaries=True, is_chief=True):
  """Get the EstimatorSpec for the current mode.

  Args:
    mode: A `model_fn_lib.ModeKeys` value. Anything other than PREDICT or
      EVAL is treated as TRAIN.
    gan_model: The `GANModel` the spec is built from.
    generator_loss_fn: Generator loss function, called with
      `add_summaries=use_loss_summaries`.
    discriminator_loss_fn: Same as `generator_loss_fn`, for the
      discriminator.
    get_eval_metric_ops_fn: Forwarded to `_get_eval_estimator_spec` in EVAL.
    generator_optimizer: Generator optimizer, or a zero-argument callable
      that returns one.
    discriminator_optimizer: Same as `generator_optimizer`, for the
      discriminator.
    get_hooks_fn: Training hook factory; a falsy value falls back to
      `tfgan_train.get_sequential_train_hooks()`.
    use_loss_summaries: Passed as `add_summaries` to both loss functions.
    is_chief: Forwarded to `_get_train_estimator_spec`.

  Returns:
    The `EstimatorSpec` for `mode`.
  """
  if mode == model_fn_lib.ModeKeys.PREDICT:
    return model_fn_lib.EstimatorSpec(
        mode=mode, predictions=gan_model.generated_data)

  summaries = {'add_summaries': use_loss_summaries}
  gan_loss = tfgan_tuples.GANLoss(
      generator_loss=generator_loss_fn(gan_model, **summaries),
      discriminator_loss=discriminator_loss_fn(gan_model, **summaries))

  if mode == model_fn_lib.ModeKeys.EVAL:
    return _get_eval_estimator_spec(gan_model, gan_loss,
                                    get_eval_metric_ops_fn)

  # model_fn_lib.ModeKeys.TRAIN: call optimizer factories (generator first).
  def _materialize(optimizer):
    return optimizer() if callable(optimizer) else optimizer

  gen_optimizer = _materialize(generator_optimizer)
  dis_optimizer = _materialize(discriminator_optimizer)
  if not get_hooks_fn:
    get_hooks_fn = tfgan_train.get_sequential_train_hooks()
  return _get_train_estimator_spec(
      gan_model, gan_loss, gen_optimizer, dis_optimizer,
      get_hooks_fn, is_chief=is_chief)
def test_train_hooks_exist_in_get_hooks_fn(self, create_gan_model_fn):
  """Sequential and joint hooks both wrap the generator/discriminator opts."""
  gan_model = create_gan_model_fn()
  gan_loss = train.gan_loss(gan_model)
  generator_opt = get_sync_optimizer()
  discriminator_opt = get_sync_optimizer()
  train_ops = train.gan_train_ops(
      gan_model, gan_loss, generator_opt, discriminator_opt,
      summarize_gradients=True, colocate_gradients_with_ops=True)

  def check_hooks(hooks, num_hooks):
    # `hooks` must have `num_hooks` entries, two of which are sync-replicas
    # hooks for exactly the generator and discriminator optimizers.
    self.assertLen(hooks, num_hooks)
    wrapped = [
        h._sync_optimizer for h in hooks
        if isinstance(h, sync_replicas_optimizer._SyncReplicasOptimizerHook)
    ]
    self.assertLen(wrapped, 2)
    self.assertSetEqual(frozenset(wrapped),
                        frozenset((generator_opt, discriminator_opt)))

  check_hooks(train.get_sequential_train_hooks()(train_ops), 4)
  check_hooks(train.get_joint_train_hooks()(train_ops), 5)
def __init__(self,
             generator_loss_fn,
             discriminator_loss_fn,
             generator_optimizer,
             discriminator_optimizer,
             use_loss_summaries=True,
             get_hooks_fn=None,
             get_eval_metric_ops_fn=None,
             name=None):
  """`Head` for GAN training.

  Args:
    generator_loss_fn: A TFGAN loss function for the generator. Takes a
      `GANModel` and returns a scalar.
    discriminator_loss_fn: Same as `generator_loss_fn`, but for the
      discriminator.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: Same as `generator_optimizer`, but for the
      discriminator updates.
    use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
      If `None`, uses defaults.
    get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
      list of hooks. Defaults to `train.get_sequential_train_hooks()`
    get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
      dict of metric results keyed by name. The output of this function is
      passed into `tf.estimator.EstimatorSpec` during evaluation.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`.

  Raises:
    TypeError: If `generator_loss_fn` or `discriminator_loss_fn` is not
      callable, if `get_hooks_fn` or `get_eval_metric_ops_fn` is neither
      `None` nor callable, or if `name` is neither `None` nor a `str`.
    ValueError: If `use_loss_summaries` is not `True`, `False` or `None`.
  """
  if not callable(generator_loss_fn):
    raise TypeError('generator_loss_fn must be callable.')
  if not callable(discriminator_loss_fn):
    raise TypeError('discriminator_loss_fn must be callable.')
  if use_loss_summaries not in [True, False, None]:
    raise ValueError('use_loss_summaries must be True, False or None.')
  if get_hooks_fn is not None and not callable(get_hooks_fn):
    raise TypeError('get_hooks_fn must be callable.')
  if name is not None and not isinstance(name, str):
    raise TypeError('name must be string.')
  # Validate eagerly, like `get_hooks_fn`, rather than failing only once the
  # function is eventually invoked. Checked last so the earlier checks keep
  # their existing precedence.
  if (get_eval_metric_ops_fn is not None and
      not callable(get_eval_metric_ops_fn)):
    raise TypeError('get_eval_metric_ops_fn must be callable.')

  if get_hooks_fn is None:
    get_hooks_fn = tfgan_train.get_sequential_train_hooks()

  # Bind the summary flag now; with `None` the loss functions keep their own
  # `add_summaries` default.
  if use_loss_summaries in [True, False]:
    generator_loss_fn = functools.partial(
        generator_loss_fn, add_summaries=use_loss_summaries)
    discriminator_loss_fn = functools.partial(
        discriminator_loss_fn, add_summaries=use_loss_summaries)
  self._generator_loss_fn = generator_loss_fn
  self._discriminator_loss_fn = discriminator_loss_fn
  self._generator_optimizer = generator_optimizer
  self._discriminator_optimizer = discriminator_optimizer
  self._get_hooks_fn = get_hooks_fn
  self._get_eval_metric_ops_fn = get_eval_metric_ops_fn
  self._name = name
def __init__(self,
             generator_loss_fn,
             discriminator_loss_fn,
             generator_optimizer,
             discriminator_optimizer,
             use_loss_summaries=True,
             get_hooks_fn=None,
             get_eval_metric_ops_fn=None,
             name=None):
  """`Head` for GAN training.

  Args:
    generator_loss_fn: A TFGAN loss function for the generator. Takes a
      `GANModel` and returns a scalar.
    discriminator_loss_fn: Same as `generator_loss_fn`, but for the
      discriminator.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: Same as `generator_optimizer`, but for the
      discriminator updates.
    use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
      If `None`, uses defaults.
    get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
      list of hooks. Defaults to `train.get_sequential_train_hooks()`
    get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
      dict of metric results keyed by name. The output of this function is
      passed into `tf.estimator.EstimatorSpec` during evaluation.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`.

  Raises:
    TypeError: If `generator_loss_fn` or `discriminator_loss_fn` is not
      callable, if `get_hooks_fn` or `get_eval_metric_ops_fn` is neither
      `None` nor callable, or if `name` is neither `None` nor a `str`.
    ValueError: If `use_loss_summaries` is not `True`, `False` or `None`.
  """
  if not callable(generator_loss_fn):
    raise TypeError('generator_loss_fn must be callable.')
  if not callable(discriminator_loss_fn):
    raise TypeError('discriminator_loss_fn must be callable.')
  if use_loss_summaries not in [True, False, None]:
    raise ValueError('use_loss_summaries must be True, False or None.')
  if get_hooks_fn is not None and not callable(get_hooks_fn):
    raise TypeError('get_hooks_fn must be callable.')
  if name is not None and not isinstance(name, str):
    raise TypeError('name must be string.')
  # Validate eagerly, like `get_hooks_fn`, rather than failing only once the
  # function is eventually invoked. Checked last so the earlier checks keep
  # their existing precedence.
  if (get_eval_metric_ops_fn is not None and
      not callable(get_eval_metric_ops_fn)):
    raise TypeError('get_eval_metric_ops_fn must be callable.')

  if get_hooks_fn is None:
    get_hooks_fn = tfgan_train.get_sequential_train_hooks()

  # Bind the summary flag now; with `None` the loss functions keep their own
  # `add_summaries` default.
  if use_loss_summaries in [True, False]:
    generator_loss_fn = functools.partial(
        generator_loss_fn, add_summaries=use_loss_summaries)
    discriminator_loss_fn = functools.partial(
        discriminator_loss_fn, add_summaries=use_loss_summaries)
  self._generator_loss_fn = generator_loss_fn
  self._discriminator_loss_fn = discriminator_loss_fn
  self._generator_optimizer = generator_optimizer
  self._discriminator_optimizer = discriminator_optimizer
  self._get_hooks_fn = get_hooks_fn
  self._get_eval_metric_ops_fn = get_eval_metric_ops_fn
  self._name = name
def __init__(self,
             generator_loss_fn,
             discriminator_loss_fn,
             generator_optimizer,
             discriminator_optimizer,
             use_loss_summaries=True,
             get_hooks_fn=None,
             name=None):
  """`Head` for GAN training.

  Args:
    generator_loss_fn: A TFGAN loss function for the generator. Takes a
      `GANModel` and returns a scalar.
    discriminator_loss_fn: Same as `generator_loss_fn`, but for the
      discriminator.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: Same as `generator_optimizer`, but for the
      discriminator updates.
    use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
      If `None`, uses defaults.
    get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
      of hooks. Defaults to `train.get_sequential_train_hooks()`
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`.

  Raises:
    TypeError: If `generator_loss_fn` or `discriminator_loss_fn` is not
      callable, if `get_hooks_fn` is neither `None` nor callable, or if
      `name` is neither `None` nor a `str`.
    ValueError: If `use_loss_summaries` is not `True`, `False` or `None`.
  """
  # Input validation (previously a TODO), using the same checks and messages
  # as the newer `GANHead.__init__`.
  if not callable(generator_loss_fn):
    raise TypeError('generator_loss_fn must be callable.')
  if not callable(discriminator_loss_fn):
    raise TypeError('discriminator_loss_fn must be callable.')
  if use_loss_summaries not in [True, False, None]:
    raise ValueError('use_loss_summaries must be True, False or None.')
  if get_hooks_fn is not None and not callable(get_hooks_fn):
    raise TypeError('get_hooks_fn must be callable.')
  if name is not None and not isinstance(name, str):
    raise TypeError('name must be string.')

  if get_hooks_fn is None:
    get_hooks_fn = tfgan_train.get_sequential_train_hooks()

  # Bind the summary flag now; with `None` the loss functions keep their own
  # `add_summaries` default.
  if use_loss_summaries in [True, False]:
    generator_loss_fn = functools.partial(
        generator_loss_fn, add_summaries=use_loss_summaries)
    discriminator_loss_fn = functools.partial(
        discriminator_loss_fn, add_summaries=use_loss_summaries)
  self._generator_loss_fn = generator_loss_fn
  self._discriminator_loss_fn = discriminator_loss_fn
  self._generator_optimizer = generator_optimizer
  self._discriminator_optimizer = discriminator_optimizer
  self._get_hooks_fn = get_hooks_fn
  self._name = name
def __init__(self,
             generator_loss_fn,
             discriminator_loss_fn,
             generator_optimizer,
             discriminator_optimizer,
             use_loss_summaries=True,
             get_hooks_fn=None,
             name=None):
  """`Head` for GAN training.

  Args:
    generator_loss_fn: A TFGAN loss function for the generator. Takes a
      `GANModel` and returns a scalar.
    discriminator_loss_fn: Same as `generator_loss_fn`, but for the
      discriminator.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: Same as `generator_optimizer`, but for the
      discriminator updates.
    use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
      If `None`, uses defaults.
    get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
      of hooks. Defaults to `train.get_sequential_train_hooks()`
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`.

  Raises:
    TypeError: If `generator_loss_fn` or `discriminator_loss_fn` is not
      callable, if `get_hooks_fn` is neither `None` nor callable, or if
      `name` is neither `None` nor a `str`.
    ValueError: If `use_loss_summaries` is not `True`, `False` or `None`.
  """
  # Input validation (previously a TODO), using the same checks and messages
  # as the newer `GANHead.__init__`.
  if not callable(generator_loss_fn):
    raise TypeError('generator_loss_fn must be callable.')
  if not callable(discriminator_loss_fn):
    raise TypeError('discriminator_loss_fn must be callable.')
  if use_loss_summaries not in [True, False, None]:
    raise ValueError('use_loss_summaries must be True, False or None.')
  if get_hooks_fn is not None and not callable(get_hooks_fn):
    raise TypeError('get_hooks_fn must be callable.')
  if name is not None and not isinstance(name, str):
    raise TypeError('name must be string.')

  if get_hooks_fn is None:
    get_hooks_fn = tfgan_train.get_sequential_train_hooks()

  # Bind the summary flag now; with `None` the loss functions keep their own
  # `add_summaries` default.
  if use_loss_summaries in [True, False]:
    generator_loss_fn = functools.partial(
        generator_loss_fn, add_summaries=use_loss_summaries)
    discriminator_loss_fn = functools.partial(
        discriminator_loss_fn, add_summaries=use_loss_summaries)
  self._generator_loss_fn = generator_loss_fn
  self._discriminator_loss_fn = discriminator_loss_fn
  self._generator_optimizer = generator_optimizer
  self._discriminator_optimizer = discriminator_optimizer
  self._get_hooks_fn = get_hooks_fn
  # `name` is documented above but was previously discarded; keep it so the
  # head can apply the documented key suffix.
  self._name = name
def model_fn(features, labels, mode, params):
  """Builds the training `EstimatorSpec` for the sentence GAN.

  NOTE(review): this block's indentation was reconstructed from a flattened
  source; confirm which statements belong inside the `if is_chief:` and
  `with tf.name_scope(...)` scopes.

  Args:
    features: Real sentences; column 0 is dropped (`features[:, 1:]`) before
      they are used as `real_data` (presumably a start token -- TODO confirm).
    labels: Passed to the generator as `labels - 1`; its leading dimension
      sets the noise batch size (presumably sentence lengths -- confirm).
    mode: Only passed through to the returned `EstimatorSpec`; nothing here
      branches on it, so a training spec is always built.
    params: Hyperparameters; `gen_lr`, `dis_lr` and `multi_gpu` are read.

  Returns:
    A `tf.estimator.EstimatorSpec` whose `loss` is the sum of the generator
    and discriminator losses, whose `train_op` is the global-step increment
    op, with sequential GAN train hooks and a scaffold whose `init_fn`
    restores weights from `FLAGS.sae_ckpt`.
  """
  # Presumably only the first (non-reusing) tower is chief under multi-GPU
  # replication; summaries and text logging below are limited to it.
  is_chief = not tf.get_variable_scope().reuse

  # Generator input: L2-normalized Gaussian noise of width FLAGS.emb_dim.
  batch_size = tf.shape(labels)[0]
  noise = tf.random_normal([batch_size, FLAGS.emb_dim])
  noise = tf.nn.l2_normalize(noise, axis=1)
  gan_model = tfgan.gan_model(generator_fn=generator,
                              discriminator_fn=discriminator,
                              real_data=features[:, 1:],
                              generator_inputs=(noise, labels - 1),
                              check_shapes=False)

  if is_chief:
    # Histograms of the trainable variables that exist at this point, plus
    # the discriminator logits for generated and real inputs.
    for variable in tf.trainable_variables():
      tf.summary.histogram(variable.op.name, variable)
    tf.summary.histogram('logits/gen_logits',
                         gan_model.discriminator_gen_outputs[0])
    tf.summary.histogram('logits/real_logits',
                         gan_model.discriminator_real_outputs[0])

  def gen_loss_fn(gan_model, add_summaries):
    # Constant zero; the generator objective presumably comes from `rl_loss`
    # below -- confirm.
    return 0

  def dis_loss_fn(gan_model, add_summaries):
    # Discriminator outputs are indexed as (logits, mask) pairs; only the
    # logits selected by the boolean mask enter the loss.
    discriminator_real_outputs = gan_model.discriminator_real_outputs
    discriminator_gen_outputs = gan_model.discriminator_gen_outputs
    real_logits = tf.boolean_mask(discriminator_real_outputs[0],
                                  discriminator_real_outputs[1])
    gen_logits = tf.boolean_mask(discriminator_gen_outputs[0],
                                 discriminator_gen_outputs[1])
    return modified_discriminator_loss(real_logits, gen_logits,
                                       add_summaries=add_summaries)

  with tf.name_scope('losses'):
    # Gradient penalty weight is 10 when FLAGS.wass is set, otherwise 0.
    gan_loss = tfgan.gan_loss(
        gan_model,
        generator_loss_fn=gen_loss_fn,
        discriminator_loss_fn=dis_loss_fn,
        gradient_penalty_weight=10 if FLAGS.wass else 0,
        add_summaries=is_chief)
    if is_chief:
      tfgan.eval.add_regularization_loss_summaries(gan_model)
  gan_loss = rl_loss(gan_model, gan_loss, add_summaries=is_chief)
  # Scalar loss reported through the EstimatorSpec.
  loss = gan_loss.generator_loss + gan_loss.discriminator_loss

  with tf.name_scope('train'):
    gen_opt = tf.train.AdamOptimizer(params.gen_lr, 0.5)
    dis_opt = tf.train.AdamOptimizer(params.dis_lr, 0.5)
    if params.multi_gpu:
      gen_opt = tf.contrib.estimator.TowerOptimizer(gen_opt)
      dis_opt = tf.contrib.estimator.TowerOptimizer(dis_opt)
    train_ops = tfgan.gan_train_ops(
        gan_model,
        gan_loss,
        generator_optimizer=gen_opt,
        discriminator_optimizer=dis_opt,
        transform_grads_fn=transform_grads_fn,
        summarize_gradients=is_chief,
        check_for_unused_update_ops=True,
        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
    # The Estimator's train_op only increments the global step; the G/D
    # update ops are presumably run by the sequential train hooks.
    train_op = train_ops.global_step_inc_op
    train_hooks = get_sequential_train_hooks()(train_ops)

  if is_chief:
    # Vocabulary: first whitespace-separated token of each line of
    # data/word_counts.txt, with '<unk>' appended as the last entry.
    with open('data/word_counts.txt', 'r') as f:
      dic = list(f)
    dic = [i.split()[0] for i in dic]
    dic.append('<unk>')
    dic = tf.convert_to_tensor(dic)
    # Decode one generated and one real sentence to words (crop_sentence
    # presumably truncates at FLAGS.end_id) and log them every 100 steps.
    sentence = crop_sentence(gan_model.generated_data[0][0], FLAGS.end_id)
    sentence = tf.gather(dic, sentence)
    real = crop_sentence(gan_model.real_data[0], FLAGS.end_id)
    real = tf.gather(dic, real)
    train_hooks.append(
        tf.train.LoggingTensorHook({
            'fake': sentence,
            'real': real
        }, every_n_iter=100))
    tf.summary.text('fake', sentence)

  # Warm start: Generator variables plus the Discriminator rnn/embedding
  # variables are restored from FLAGS.sae_ckpt via the scaffold's init_fn.
  gen_var = tf.trainable_variables('Generator')
  dis_var = []
  dis_var.extend(tf.trainable_variables('Discriminator/rnn'))
  dis_var.extend(tf.trainable_variables('Discriminator/embedding'))
  saver = tf.train.Saver(gen_var + dis_var)

  def init_fn(scaffold, session):
    saver.restore(session, FLAGS.sae_ckpt)
    # NOTE(review): redundant `pass`; safe to remove.
    pass

  scaffold = tf.train.Scaffold(init_fn=init_fn)

  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op,
                                    scaffold=scaffold,
                                    training_hooks=train_hooks)
def model_fn(features, labels, mode, params):
  """The full unsupervised captioning model.

  Builds a training `EstimatorSpec`. Image features from InceptionV4 (run
  with `is_training=False`) condition the generator. The discriminator scores
  generated against real sentences. A sentence-autoencoder loss
  (`sentence_ae`) is added to the generator loss.

  NOTE(review): this block's indentation was reconstructed from a flattened
  source; confirm which statements belong inside the `if is_chief:` and
  `with ...` scopes.

  Args:
    features: Dict. Reads 'im' (images fed to InceptionV4) and 'classes',
      'scores' and 'num', which are forwarded to `rl_loss`. The whole dict is
      also passed to `sentence_ae`.
    labels: Dict. Reads 'sentence' (column 0 is dropped before it is used as
      `real_data`) and 'len' (passed to the generator as `labels['len'] - 1`).
      The whole dict is also passed to `sentence_ae`.
    mode: Only passed through to the returned `EstimatorSpec`; nothing here
      branches on it.
    params: Hyperparameters; `gen_lr`, `dis_lr` and `multi_gpu` are read.

  Returns:
    A `tf.estimator.EstimatorSpec` with the summed GAN + autoencoder loss,
    the global-step increment op as `train_op`, sequential GAN train hooks,
    and a scaffold whose `init_fn` restores InceptionV4 from FLAGS.inc_ckpt
    and, when the flags are set, the generator from FLAGS.imcap_ckpt and the
    discriminator from FLAGS.sae_ckpt.
  """
  # Presumably only the first (non-reusing) tower is chief under multi-GPU
  # replication; summaries and text logging below are limited to it.
  is_chief = not tf.get_variable_scope().reuse

  with slim.arg_scope(inception_v4.inception_v4_arg_scope()):
    net, _ = inception_v4.inception_v4(features['im'], None, is_training=False)
  # Remove axes 1 and 2 (they must have size 1 for tf.squeeze to succeed).
  net = tf.squeeze(net, [1, 2])
  inc_saver = tf.train.Saver(tf.global_variables('InceptionV4'))

  gan_model = tfgan.gan_model(generator_fn=generator,
                              discriminator_fn=discriminator,
                              real_data=labels['sentence'][:, 1:],
                              generator_inputs=(net, labels['len'] - 1),
                              check_shapes=False)

  if is_chief:
    # Histograms of the trainable variables that exist at this point, plus
    # the discriminator logits for generated and real inputs.
    for variable in tf.trainable_variables():
      tf.summary.histogram(variable.op.name, variable)
    tf.summary.histogram('logits/gen_logits',
                         gan_model.discriminator_gen_outputs[0])
    tf.summary.histogram('logits/real_logits',
                         gan_model.discriminator_real_outputs[0])

  def gen_loss_fn(gan_model, add_summaries):
    # Constant zero; the generator objective presumably comes from `rl_loss`
    # and `sentence_ae` below -- confirm.
    return 0

  def dis_loss_fn(gan_model, add_summaries):
    # Discriminator outputs are indexed as (logits, mask) pairs; only the
    # logits selected by the boolean mask enter the loss.
    discriminator_real_outputs = gan_model.discriminator_real_outputs
    discriminator_gen_outputs = gan_model.discriminator_gen_outputs
    real_logits = tf.boolean_mask(discriminator_real_outputs[0],
                                  discriminator_real_outputs[1])
    gen_logits = tf.boolean_mask(discriminator_gen_outputs[0],
                                 discriminator_gen_outputs[1])
    return modified_discriminator_loss(real_logits, gen_logits,
                                       add_summaries=add_summaries)

  with tf.name_scope('losses'):
    # Tensor pool of size FLAGS.pool_size, used only when FLAGS.use_pool.
    pool_fn = functools.partial(tfgan.features.tensor_pool,
                                pool_size=FLAGS.pool_size)
    gan_loss = tfgan.gan_loss(
        gan_model,
        generator_loss_fn=gen_loss_fn,
        discriminator_loss_fn=dis_loss_fn,
        gradient_penalty_weight=10 if FLAGS.wass else 0,
        tensor_pool_fn=pool_fn if FLAGS.use_pool else None,
        add_summaries=is_chief)
    if is_chief:
      tfgan.eval.add_regularization_loss_summaries(gan_model)
  gan_loss = rl_loss(gan_model, gan_loss, features['classes'],
                     features['scores'], features['num'],
                     add_summaries=is_chief)
  sen_ae_loss = sentence_ae(gan_model, features, labels, is_chief)
  # Scalar loss reported through the EstimatorSpec.
  loss = gan_loss.generator_loss + gan_loss.discriminator_loss + sen_ae_loss
  # Fold the autoencoder loss into the generator loss so that it is part of
  # the generator objective handed to gan_train_ops.
  gan_loss = gan_loss._replace(
      generator_loss=gan_loss.generator_loss + sen_ae_loss)

  with tf.name_scope('train'):
    gen_opt = tf.train.AdamOptimizer(params.gen_lr, 0.5)
    dis_opt = tf.train.AdamOptimizer(params.dis_lr, 0.5)
    if params.multi_gpu:
      gen_opt = tf.contrib.estimator.TowerOptimizer(gen_opt)
      dis_opt = tf.contrib.estimator.TowerOptimizer(dis_opt)
    # The unused-update-op check is turned off whenever the tensor pool is on.
    train_ops = tfgan.gan_train_ops(
        gan_model,
        gan_loss,
        generator_optimizer=gen_opt,
        discriminator_optimizer=dis_opt,
        transform_grads_fn=transform_grads_fn,
        summarize_gradients=is_chief,
        check_for_unused_update_ops=not FLAGS.use_pool,
        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
    # The Estimator's train_op only increments the global step; the G/D
    # update ops are presumably run by the sequential train hooks.
    train_op = train_ops.global_step_inc_op
    train_hooks = get_sequential_train_hooks()(train_ops)

  # Summarize the generated caption on the fly.
  if is_chief:
    # Vocabulary: first whitespace-separated token of each line of
    # data/word_counts.txt, with '<unk>' appended as the last entry.
    with open('data/word_counts.txt', 'r') as f:
      dic = list(f)
    dic = [i.split()[0] for i in dic]
    dic.append('<unk>')
    dic = tf.convert_to_tensor(dic)
    # Decode one generated and one real sentence to words (crop_sentence
    # presumably truncates at FLAGS.end_id) and log them every 100 steps.
    sentence = crop_sentence(gan_model.generated_data[0][0], FLAGS.end_id)
    sentence = tf.gather(dic, sentence)
    real = crop_sentence(gan_model.real_data[0], FLAGS.end_id)
    real = tf.gather(dic, real)
    train_hooks.append(
        tf.train.LoggingTensorHook({
            'fake': sentence,
            'real': real
        }, every_n_iter=100))
    tf.summary.text('fake', sentence)
    # First image of the batch, with the batch axis re-added by [None, 0].
    tf.summary.image('im', features['im'][None, 0])

  gen_saver = tf.train.Saver(tf.trainable_variables('Generator'))
  dis_var = []
  dis_var.extend(tf.trainable_variables('Discriminator/rnn'))
  dis_var.extend(tf.trainable_variables('Discriminator/embedding'))
  dis_var.extend(tf.trainable_variables('Discriminator/fc'))
  dis_saver = tf.train.Saver(dis_var)

  def init_fn(scaffold, session):
    # InceptionV4 is always restored; the generator and discriminator only
    # when their checkpoint flags are non-empty.
    inc_saver.restore(session, FLAGS.inc_ckpt)
    if FLAGS.imcap_ckpt:
      gen_saver.restore(session, FLAGS.imcap_ckpt)
    if FLAGS.sae_ckpt:
      dis_saver.restore(session, FLAGS.sae_ckpt)

  scaffold = tf.train.Scaffold(init_fn=init_fn)

  return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op,
                                    scaffold=scaffold,
                                    training_hooks=train_hooks)