Example no. 1
  def test_two_listeners_with_default_saver(self):
    with ops.Graph().as_default():
      global_step = variables.get_or_create_global_step()
      train_op = state_ops.assign_add(global_step, 1)
      listener1 = MockCheckpointSaverListener()
      listener2 = MockCheckpointSaverListener()
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir,
          save_steps=1,
          listeners=[listener1, listener2])
      with monitored_session.SingularMonitoredSession(
          hooks=[hook],
          checkpoint_dir=self.model_dir) as sess:
        sess.run(train_op)
        sess.run(train_op)
        global_step_val = sess.run(global_step)
      listener1_counts = listener1.get_counts()
      listener2_counts = listener2.get_counts()
    self.assertEqual(2, global_step_val)
    self.assertEqual({
        'begin': 1,
        'before_save': 2,
        'after_save': 2,
        'end': 1
    }, listener1_counts)
    self.assertEqual(listener1_counts, listener2_counts)

    with ops.Graph().as_default():
      global_step = variables.get_or_create_global_step()
      with monitored_session.SingularMonitoredSession(
          checkpoint_dir=self.model_dir) as sess2:
        global_step_saved_val = sess2.run(global_step)
    self.assertEqual(2, global_step_saved_val)
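The listener used in these tests, MockCheckpointSaverListener, is defined elsewhere in the test file. Below is a minimal sketch of what such a listener could look like, assuming only the public `CheckpointSaverListener` callbacks (`begin`, `before_save`, `after_save`, `end`); the `get_counts` helper is specific to this sketch.

class MockCheckpointSaverListener(
    basic_session_run_hooks.CheckpointSaverListener):
  """Counts how many times each listener callback is invoked."""

  def __init__(self):
    self.begin_count = 0
    self.before_save_count = 0
    self.after_save_count = 0
    self.end_count = 0

  def begin(self):
    self.begin_count += 1

  def before_save(self, session, global_step_value):
    self.before_save_count += 1

  def after_save(self, session, global_step_value):
    self.after_save_count += 1

  def end(self, session, global_step_value):
    self.end_count += 1

  def get_counts(self):
    # Shape of the dict matches the assertions in the tests above.
    return {
        'begin': self.begin_count,
        'before_save': self.before_save_count,
        'after_save': self.after_save_count,
        'end': self.end_count
    }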
Example no. 2
 def test_listener_with_monitored_session(self):
   with ops.Graph().as_default():
     scaffold = monitored_session.Scaffold()
     global_step = variables.get_or_create_global_step()
     train_op = state_ops.assign_add(global_step, 1)
     listener = MockCheckpointSaverListener()
     hook = basic_session_run_hooks.CheckpointSaverHook(
         self.model_dir,
         save_steps=1,
         scaffold=scaffold,
         listeners=[listener])
     with monitored_session.SingularMonitoredSession(
         hooks=[hook],
         scaffold=scaffold,
         checkpoint_dir=self.model_dir) as sess:
       sess.run(train_op)
       sess.run(train_op)
       global_step_val = sess.run(global_step)
     listener_counts = listener.get_counts()
   self.assertEqual(2, global_step_val)
   self.assertEqual({
       'begin': 1,
       'before_save': 2,
       'after_save': 2,
       'end': 1
   }, listener_counts)
 def test_save_secs_calls_listeners_periodically(self):
     with self.graph.as_default():
         listener = MockCheckpointSaverListener()
         hook = basic_session_run_hooks.CheckpointSaverHook(
             self.model_dir,
             save_secs=2,
             scaffold=self.scaffold,
             listeners=[listener])
         hook.begin()
         self.scaffold.finalize()
         with session_lib.Session() as sess:
             sess.run(self.scaffold.init_op)
             mon_sess = monitored_session._HookedSession(sess, [hook])
             mon_sess.run(self.train_op)  # hook runs here
             mon_sess.run(self.train_op)
             time.sleep(2.5)
             mon_sess.run(self.train_op)  # hook runs here
             mon_sess.run(self.train_op)
             mon_sess.run(self.train_op)
             time.sleep(2.5)
             mon_sess.run(self.train_op)  # hook runs here
             mon_sess.run(
                 self.train_op)  # hook won't run here, so it does at end
             hook.end(sess)  # hook runs here
         self.assertEqual(
             {
                 'begin': 1,
                 'before_save': 4,
                 'after_save': 4,
                 'end': 1
             }, listener.get_counts())
 def test_save_steps_saves_periodically(self):
     with self.graph.as_default():
         hook = basic_session_run_hooks.CheckpointSaverHook(
             self.model_dir, save_steps=2, scaffold=self.scaffold)
         hook.begin()
         self.scaffold.finalize()
         with session_lib.Session() as sess:
             sess.run(self.scaffold.init_op)
             mon_sess = monitored_session._HookedSession(sess, [hook])
             mon_sess.run(self.train_op)
             mon_sess.run(self.train_op)
             # Not saved
             self.assertEqual(
                 1,
                 checkpoint_utils.load_variable(self.model_dir,
                                                self.global_step.name))
             mon_sess.run(self.train_op)
             # saved
             self.assertEqual(
                 3,
                 checkpoint_utils.load_variable(self.model_dir,
                                                self.global_step.name))
             mon_sess.run(self.train_op)
             # Not saved
             self.assertEqual(
                 3,
                 checkpoint_utils.load_variable(self.model_dir,
                                                self.global_step.name))
             mon_sess.run(self.train_op)
             # saved
             self.assertEqual(
                 5,
                 checkpoint_utils.load_variable(self.model_dir,
                                                self.global_step.name))
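These periodic-save tests also rely on `self.graph`, `self.scaffold`, `self.global_step` and `self.train_op` fixtures that the snippets do not show. A plausible `setUp` sketch, reusing the module aliases already used above (and assuming `import tempfile`); the fixture code in the original test file may differ:

def setUp(self):
  self.model_dir = tempfile.mkdtemp()  # assumes `import tempfile`
  self.graph = ops.Graph()
  with self.graph.as_default():
    self.scaffold = monitored_session.Scaffold()
    self.global_step = variables.get_or_create_global_step()
    self.train_op = state_ops.assign_add(self.global_step, 1)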
Example no. 5
  def testResumeTrainAchievesRoughlyTheSameLoss(self):
    number_of_steps = [300, 1, 5]
    logdir = os.path.join(self.get_temp_dir(), 'resume_train_same_loss')

    for i in range(len(number_of_steps)):
      with ops.Graph().as_default():
        random_seed.set_random_seed(i)
        tf_inputs = constant_op.constant(self._inputs, dtype=dtypes.float32)
        tf_labels = constant_op.constant(self._labels, dtype=dtypes.float32)

        tf_predictions = logistic_classifier(tf_inputs)
        loss_ops.log_loss(tf_predictions, tf_labels)
        total_loss = loss_ops.get_total_loss()

        optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

        train_op = training.create_train_op(total_loss, optimizer)

        saver = saver_lib.Saver()

        loss = training.train(
            train_op,
            logdir,
            hooks=[
                basic_session_run_hooks.StopAtStepHook(
                    num_steps=number_of_steps[i]),
                basic_session_run_hooks.CheckpointSaverHook(
                    logdir, save_steps=50, saver=saver),
            ])
        self.assertIsNotNone(loss)
        self.assertLess(loss, .015)
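`logistic_classifier`, `self._inputs` and `self._labels` come from the surrounding test class and are not shown. A hypothetical stand-in, assuming a `tf.contrib.layers` fully connected layer; the names, shapes and data here are illustrative only:

import numpy as np
from tensorflow.contrib import layers
from tensorflow.python.ops import math_ops

def logistic_classifier(inputs):
  # A single sigmoid unit is enough for the resume-training test.
  return layers.fully_connected(inputs, 1, activation_fn=math_ops.sigmoid)

# Toy data of the kind the test fixture might provide.
_inputs = np.random.rand(16, 4).astype(np.float32)
_labels = (np.random.rand(16, 1) > 0.5).astype(np.float32)
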
def train_mnist():
    with tf.get_default_graph().as_default():
        input_pipe = MnistData(BATCH_SIZE)
        train_features, train_labels = input_pipe.build_train_data_tensor()

        global_step = 0

        train_op, loss = train_graph(train_features, train_labels)

        assign_ops = get_assign_ops()

        checkpoint_hook = basic_session_run_hooks.CheckpointSaverHook(
            MODEL_DIR, save_steps=TEST_FREQ)

        class _LoggerHook(tf.train.SessionRunHook):
            """logs loss and runtime."""
            def begin(self):
                self._step = global_step

            def before_run(self, run_context):
                self._step += 1  # count steps so the periodic logging below fires
                self._start_time = time.time()
                return tf.train.SessionRunArgs(loss)  # asks for loss value.

            def after_run(self, run_context, run_values):
                duration = time.time() - self._start_time
                loss_value = run_values.results
                if self._step > 0 and self._step % 100 == 0:
                    num_examples_per_step = BATCH_SIZE
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)
                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec' +
                        '; %.3f sec/batch)')
                    tf.logging.info(
                        format_str %
                        (datetime.datetime.now(), self._step, loss_value,
                         examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=MODEL_DIR,
                hooks=[_LoggerHook(), checkpoint_hook],
                save_checkpoint_secs=None,
                config=tf.ConfigProto(log_device_placement=False)) as sess:

            if INIT_DICT is not None:
                sess.run(assign_ops)
                dict_str = ';  '.join(
                    map(
                        lambda scope: scope + '/' + ', '.join(INIT_DICT[
                            scope].keys()), INIT_DICT.keys()))
                tf.logging.info('Instantiated tensors with assign ops: ' +
                                dict_str)

            for i in range(TRAIN_STEPS):
                sess.run(train_op)
Example no. 7
 def test_saves_when_saver_and_scaffold_both_missing(self):
   with self.graph.as_default():
     hook = basic_session_run_hooks.CheckpointSaverHook(
         self.model_dir, save_steps=1)
     hook.begin()
     self.scaffold.finalize()
     with session_lib.Session() as sess:
       sess.run(self.scaffold.init_op)
       mon_sess = monitored_session._HookedSession(sess, [hook])
       mon_sess.run(self.train_op)
       self.assertEqual(1,
                        checkpoint_utils.load_variable(self.model_dir,
                                                       self.global_step.name))
def MonitoredTrainingSession(
        master='',  # pylint: disable=invalid-name
        is_chief=True,
        checkpoint_dir=None,
        hooks=None,
        scaffold=None,
        config=None):
    """Creates a `MonitoredSession` for training.

  For a chief, this utility sets proper session initializer/restorer. It also
  creates hooks related to checkpoint and summary saving. For workers, this
  utility sets proper session creator which waits for the chief to
  initialize/restore.


  Args:
    master: `String` the TensorFlow master to use.
    is_chief: If `True`, it will take care of initialization and recovery of
      the underlying TensorFlow session. If `False`, it will wait on a chief to
      initialize or recover the TensorFlow session.
    checkpoint_dir: A string.  Optional path to a directory where to restore
      variables.
    hooks: Optional list of `SessionRunHook` objects.
    scaffold: A `Scaffold` used for gathering or building supportive ops. If
      not specified, a default one is created. It's used to finalize the graph.
    config: `ConfigProto` proto used to configure the session.

  Returns:
    A `MonitoredSession` object.
  """
    hooks = hooks or []
    scaffold = scaffold or Scaffold()
    if not is_chief:
        session_creator = WorkerSessionCreator(scaffold=scaffold,
                                               master=master,
                                               config=config)
    else:
        session_creator = ChiefSessionCreator(scaffold=scaffold,
                                              checkpoint_dir=checkpoint_dir,
                                              master=master,
                                              config=config)
        hooks.extend([
            basic_session_run_hooks.StepCounterHook(output_dir=checkpoint_dir),
            basic_session_run_hooks.SummarySaverHook(
                scaffold=scaffold, output_dir=checkpoint_dir),
            basic_session_run_hooks.CheckpointSaverHook(checkpoint_dir,
                                                        save_secs=600,
                                                        scaffold=scaffold),
        ])

    return MonitoredSession(session_creator=session_creator, hooks=hooks)
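A hedged usage sketch of the simplified wrapper above; `build_graph`, the checkpoint directory, and the `ops` alias (as in the test examples) are assumptions, not part of the snippet.

# Hypothetical usage; build_graph() is a placeholder that must create and
# increment the global step.
with ops.Graph().as_default():
  train_op = build_graph()
  with MonitoredTrainingSession(
      master='',                         # in-process master
      is_chief=True,                     # chief adds the 600s CheckpointSaverHook
      checkpoint_dir='/tmp/train_logs') as sess:
    while not sess.should_stop():
      sess.run(train_op)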
Example no. 9
    def test_save_secs_saves_periodically(self, mock_time):
        # Let's have a realistic start time
        current_time = 1484695987.209386

        with self.graph.as_default():
            mock_time.return_value = current_time
            hook = basic_session_run_hooks.CheckpointSaverHook(
                self.model_dir, save_secs=2, scaffold=self.scaffold)
            hook.begin()
            self.scaffold.finalize()

            with session_lib.Session() as sess:
                sess.run(self.scaffold.init_op)
                mon_sess = monitored_session._HookedSession(sess, [hook])

                mock_time.return_value = current_time
                mon_sess.run(self.train_op)  # Saved.

                mock_time.return_value = current_time + 0.5
                mon_sess.run(self.train_op)  # Not saved.

                self.assertEqual(
                    1,
                    checkpoint_utils.load_variable(self.model_dir,
                                                   self.global_step.name))

                # Simulate 2.5 seconds of sleep.
                mock_time.return_value = current_time + 2.5
                mon_sess.run(self.train_op)  # Saved.

                mock_time.return_value = current_time + 2.6
                mon_sess.run(self.train_op)  # Not saved.

                mock_time.return_value = current_time + 2.7
                mon_sess.run(self.train_op)  # Not saved.

                self.assertEqual(
                    3,
                    checkpoint_utils.load_variable(self.model_dir,
                                                   self.global_step.name))

                # Simulate 7.5 more seconds of sleep (10 seconds from start).
                mock_time.return_value = current_time + 10
                mon_sess.run(self.train_op)  # Saved.
                self.assertEqual(
                    6,
                    checkpoint_utils.load_variable(self.model_dir,
                                                   self.global_step.name))
Example no. 10
  def test_save_secs_calls_listeners_periodically(self, mock_time):
    # Let's have a realistic start time
    current_time = 1484695987.209386

    with self.graph.as_default():
      mock_time.return_value = current_time
      listener = MockCheckpointSaverListener()
      hook = basic_session_run_hooks.CheckpointSaverHook(
          self.model_dir,
          save_secs=2,
          scaffold=self.scaffold,
          listeners=[listener])
      hook.begin()
      self.scaffold.finalize()
      with session_lib.Session() as sess:
        sess.run(self.scaffold.init_op)
        mon_sess = monitored_session._HookedSession(sess, [hook])

        mock_time.return_value = current_time + 0.5
        mon_sess.run(self.train_op)  # hook runs here

        mock_time.return_value = current_time + 0.5
        mon_sess.run(self.train_op)

        mock_time.return_value = current_time + 3.0
        mon_sess.run(self.train_op)  # hook runs here

        mock_time.return_value = current_time + 3.5
        mon_sess.run(self.train_op)

        mock_time.return_value = current_time + 4.0
        mon_sess.run(self.train_op)

        mock_time.return_value = current_time + 6.5
        mon_sess.run(self.train_op)  # hook runs here

        mock_time.return_value = current_time + 7.0
        mon_sess.run(self.train_op)  # hook won't run here, so it does at end

        mock_time.return_value = current_time + 7.5
        hook.end(sess)  # hook runs here
      self.assertEqual({
          'begin': 1,
          'before_save': 4,
          'after_save': 4,
          'end': 1
      }, listener.get_counts())
Example no. 11
    def __init__(self, estimator):
        """Initializes a `CheckpointInputPipelineHook`.

    Args:
      estimator: Estimator.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of saver or scaffold should be set.
    """
        # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
        # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
        # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
        # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
        # to be different to avoid conflicts with the model checkpoint.

        # pylint: disable=protected-access
        checkpoint_prefix = "input"
        if estimator._config.num_worker_replicas > 1:
            # Distributed setting.
            suffix = "_{}_{}".format(estimator._config.task_type,
                                     estimator._config.task_id)
            checkpoint_prefix += suffix
        # pylint: enable=protected-access

        # We use a composition paradigm instead of inheriting from
        # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
        # to check whether a `CheckpointSaverHook` is already present in the list
        # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
        # would thwart this behavior. This hook checkpoints *only the iterators*
        # and not the graph variables.
        self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
            estimator.model_dir,
            save_secs=estimator._config.save_checkpoints_secs,  # pylint: disable=protected-access
            save_steps=estimator._config.save_checkpoints_steps,  # pylint: disable=protected-access
            checkpoint_basename=checkpoint_prefix + ".ckpt")

        # Name for the protocol buffer file that will contain the list of most
        # recent checkpoints stored as a `CheckpointState` protocol buffer.
        # This file, kept in the same directory as the checkpoint files, is
        # automatically managed by the `Saver` to keep track of recent checkpoints.
        # The default name used by the `Saver` for this file is "checkpoint". Here
        # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
        # `checkpoint_dir` is the same as the model checkpoint directory, there are
        # no conflicts during restore.
        self._latest_filename = "checkpoint_" + checkpoint_prefix
        self._first_run = True
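Because of the composition described above, the hook is simply passed alongside any model hooks. A hedged usage sketch follows; the `estimator_lib` alias, `model_fn` and `train_input_fn` are assumptions, and the input pipeline must expose a saveable iterator for its state to be checkpointed.

# Hypothetical usage; model_fn and train_input_fn are placeholders.
est = estimator_lib.Estimator(model_fn=model_fn, model_dir='/tmp/model')
while True:
  est.train(
      train_input_fn,
      hooks=[CheckpointInputPipelineHook(est)],  # saves iterator state only
      steps=1000)
  # Model checkpoints ("model.ckpt*") and input-pipeline checkpoints
  # ("input*.ckpt*") coexist in model_dir without overwriting each other.
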
    def test_summary_writer_defs(self):
        fake_summary_writer.FakeSummaryWriter.install()
        writer_cache.FileWriterCache.clear()
        summary_writer = writer_cache.FileWriterCache.get(self.model_dir)

        with self.graph.as_default():
            hook = basic_session_run_hooks.CheckpointSaverHook(
                self.model_dir, save_steps=2, scaffold=self.scaffold)
            hook.begin()
            self.scaffold.finalize()
            with session_lib.Session() as sess:
                sess.run(self.scaffold.init_op)
                mon_sess = monitored_session._HookedSession(sess, [hook])
                mon_sess.run(self.train_op)
            summary_writer.assert_summaries(
                test_case=self,
                expected_logdir=self.model_dir,
                expected_added_meta_graphs=[
                    meta_graph.create_meta_graph_def(
                        graph_def=self.graph.as_graph_def(add_shapes=True),
                        saver_def=self.scaffold.saver.saver_def)
                ])

        fake_summary_writer.FakeSummaryWriter.uninstall()
Example no. 13
    def train_and_evaluate(self):
        """Interleaves training and evaluation.

    The frequency of evaluation is controlled by the constructor arg
    `min_eval_frequency`. When this parameter is 0, evaluation happens
    only after training has completed. Note that evaluation cannot happen
    more frequently than checkpoints are taken. If no new snapshots are
    available when evaluation is supposed to occur, then evaluation doesn't
    happen for another `min_eval_frequency` steps (assuming a checkpoint is
    available at that point). Thus, setting `min_eval_frequency` to 1 means
    that the model will be evaluated every time there is a new checkpoint.

    This is particularly useful for a "Master" task in the cloud, whose
    responsibility it is to take checkpoints, evaluate those checkpoints,
    and write out summaries. Participating in training as the supervisor
    allows such a task to accomplish the first and last items, while
    performing evaluation allows for the second.

    Returns:
      The result of the `evaluate` call to the `Estimator` as well as the
      export results using the specified `ExportStrategy`.
    """
        # The directory to which evaluation summaries are written is determined
        # by adding a suffix to 'eval'; that suffix is the 'name' parameter to
        # the various evaluate(...) methods. By setting it to None, we force
        # the directory name to simply be 'eval'.
        eval_dir_suffix = None

        # We set every_n_steps to 1, but evaluation only occurs when a new
        # snapshot is available. If, by the time we finish evaluation
        # there is a new snapshot, then we just evaluate again. Otherwise,
        # we keep training until one becomes available.
        with _new_attr_context(self, "_train_monitors"):
            self._train_monitors = self._train_monitors or []
            config = self._estimator.config
            intermediate_export = self._checkpoint_and_export and (
                config.save_checkpoints_secs or config.save_checkpoints_steps)
            if intermediate_export:
                # Create a partially specified evaluate function with the desired
                # arguments. This will be executed by the _EvalAndExportListener,
                # which will specify the latest checkpoint path.
                eval_fn = functools.partial(self._call_evaluate,
                                            input_fn=self._eval_input_fn,
                                            steps=self._eval_steps,
                                            metrics=self._eval_metrics,
                                            hooks=self._eval_hooks)

                export_listener = _EvalAndExportListener(
                    eval_fn=eval_fn,
                    export_fn=self._maybe_export,
                    model_dir=self._estimator.model_dir)

                saver_hook = basic_session_run_hooks.CheckpointSaverHook(
                    checkpoint_dir=self._estimator.model_dir,
                    save_secs=config.save_checkpoints_secs,
                    save_steps=config.save_checkpoints_steps,
                    listeners=[export_listener])
                self._train_monitors += [saver_hook]
            else:
                if self._min_eval_frequency:
                    self._train_monitors += [
                        monitors.ValidationMonitor(
                            input_fn=self._eval_input_fn,
                            eval_steps=self._eval_steps,
                            metrics=self._eval_metrics,
                            every_n_steps=self._min_eval_frequency,
                            name=eval_dir_suffix,
                            hooks=self._eval_hooks)
                    ]
            self.train(delay_secs=0)

        # If the checkpoint_and_export flag and appropriate estimator configuration
        # parameters are set, then model evaluations and exports are done during the
        # training process. In particular, this will always occur at the end of
        # training, so we return the most recent results to avoid performing a
        # duplicate evaluation and model export.
        if intermediate_export:
            return export_listener.eval_result, export_listener.export_results
        else:
            eval_result = self._call_evaluate(input_fn=self._eval_input_fn,
                                              steps=self._eval_steps,
                                              metrics=self._eval_metrics,
                                              name=eval_dir_suffix,
                                              hooks=self._eval_hooks)
            export_results = self._maybe_export(eval_result)
            return eval_result, export_results
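The `_EvalAndExportListener` used above follows the standard `CheckpointSaverListener` pattern of reacting to each checkpoint write. Below is a minimal sketch of that pattern with an evaluation callback; the class and argument names are illustrative, not the actual `_EvalAndExportListener` implementation, and the checkpoint directory is a placeholder.

class EvalAfterSaveListener(basic_session_run_hooks.CheckpointSaverListener):
  """Runs an evaluation callback each time the saver hook writes a checkpoint."""

  def __init__(self, eval_fn):
    self._eval_fn = eval_fn
    self.eval_result = None

  def after_save(self, session, global_step_value):
    # The checkpoint for `global_step_value` is already on disk here.
    self.eval_result = self._eval_fn()

saver_hook = basic_session_run_hooks.CheckpointSaverHook(
    checkpoint_dir='/tmp/model',  # placeholder directory
    save_steps=1000,
    listeners=[EvalAfterSaveListener(eval_fn=lambda: None)])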
Example no. 14
def MonitoredTrainingSession(master='',
                             is_chief=True,
                             checkpoint_dir=None,
                             scaffold=None,
                             hooks=None,
                             chief_only_hooks=None,
                             save_checkpoint_secs=USE_DEFAULT,
                             save_summaries_steps=USE_DEFAULT,
                             save_summaries_secs=USE_DEFAULT,
                             config=None,
                             stop_grace_period_secs=120,
                             log_step_count_steps=100,
                             save_checkpoint_steps=USE_DEFAULT,
                             summary_dir=None):
    if save_summaries_steps == USE_DEFAULT and save_summaries_secs == USE_DEFAULT:
        save_summaries_steps = 100
        save_summaries_secs = None
    elif save_summaries_secs == USE_DEFAULT:
        save_summaries_secs = None
    elif save_summaries_steps == USE_DEFAULT:
        save_summaries_steps = None

    if (save_checkpoint_steps == USE_DEFAULT
            and save_checkpoint_secs == USE_DEFAULT):
        save_checkpoint_steps = None
        save_checkpoint_secs = 600
    elif save_checkpoint_secs == USE_DEFAULT:
        save_checkpoint_secs = None
    elif save_checkpoint_steps == USE_DEFAULT:
        save_checkpoint_steps = None

    scaffold = scaffold or Scaffold()

    all_hooks = []
    if is_chief and chief_only_hooks:
        all_hooks.extend(chief_only_hooks)

    session_creator = ChiefSessionCreator(scaffold=scaffold,
                                          checkpoint_dir=checkpoint_dir,
                                          master=master,
                                          config=config)

    summary_dir = summary_dir or checkpoint_dir
    if summary_dir:
        if (save_summaries_steps
                and save_summaries_steps > 0) or (save_summaries_secs
                                                  and save_summaries_secs > 0):
            all_hooks.append(
                basic_session_run_hooks.SummarySaverHook(
                    scaffold=scaffold,
                    save_steps=save_summaries_steps,
                    save_secs=save_summaries_secs,
                    output_dir=summary_dir))

    if checkpoint_dir:
        if (save_checkpoint_secs and save_checkpoint_secs > 0) or (
                save_checkpoint_steps and save_checkpoint_steps > 0):
            all_hooks.append(
                basic_session_run_hooks.CheckpointSaverHook(
                    checkpoint_dir,
                    save_steps=save_checkpoint_steps,
                    save_secs=save_checkpoint_secs,
                    scaffold=scaffold))

    if hooks:
        all_hooks.extend(hooks)

    hvd_info_rank0('all hooks {}'.format(all_hooks))
    return MonitoredSession(session_creator=session_creator,
                            hooks=all_hooks,
                            stop_grace_period_secs=stop_grace_period_secs)
Example no. 15
def train(args):
    """Train CIFAR-10 for a number of steps.

  Args:
    args: The command line arguments.
  """
    with tf.Graph().as_default(), tf.device('/cpu:0'):
        # Create the global step.
        global_step = tf.contrib.framework.create_global_step()

        # Calculate the learning rate schedule.
        num_batches_per_epoch = (cifar10.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN /
                                 args.batch_size)
        decay_steps = int(num_batches_per_epoch * cifar10.NUM_EPOCHS_PER_DECAY)

        # Decay the learning rate exponentially based on the number of steps.
        lr = tf.train.exponential_decay(cifar10.INITIAL_LEARNING_RATE,
                                        global_step,
                                        decay_steps,
                                        cifar10.LEARNING_RATE_DECAY_FACTOR,
                                        staircase=True)

        # Create an optimizer that performs gradient descent.
        opt = tf.train.GradientDescentOptimizer(lr)

        # Calculate the gradients for each model tower.
        tower_grads = []
        for i in xrange(args.num_gpus):
            with tf.device('/gpu:%d' % i):
                with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
                    # Calculate the loss for one tower of the CIFAR model. This function
                    # constructs the entire CIFAR model but shares the variables across
                    # all towers.
                    loss = tower_loss(scope, args)

                    # Reuse variables for the next tower.
                    tf.get_variable_scope().reuse_variables()

                    # Retain the summaries from the final tower.
                    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES,
                                                  scope)

                    # Calculate the gradients for the batch of data on this CIFAR tower.
                    grads = opt.compute_gradients(loss)

                    # Keep track of the gradients across all towers.
                    tower_grads.append(grads)

        # We must calculate the mean of each gradient. Note that this is the
        # synchronization point across all towers.
        grads = average_gradients(tower_grads)

        # Add a summary to track the learning rate.
        summaries.append(tf.summary.scalar('learning_rate', lr))

        # Add histograms for gradients.
        for grad, var in grads:
            if grad is not None:
                summaries.append(
                    tf.summary.histogram(var.op.name + '/gradients', grad))

        # Apply the gradients to adjust the shared variables.
        apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

        # Add histograms for trainable variables.
        for var in tf.trainable_variables():
            summaries.append(tf.summary.histogram(var.op.name, var))

        # Track the moving averages of all trainable variables.
        # To understand why the following line is necessary, see:
        # https://github.com/carpedm20/DCGAN-tensorflow/issues/59
        with tf.variable_scope(tf.get_variable_scope(), reuse=False):
            variable_averages = tf.train.ExponentialMovingAverage(
                cifar10.MOVING_AVERAGE_DECAY, global_step)
            variables_averages_op = variable_averages.apply(
                tf.trainable_variables())

        # Group all updates into a single train op.
        train_op = tf.group(apply_gradient_op, variables_averages_op)

        # Build the summary operation from the last tower summaries.
        summary_op = tf.summary.merge(summaries)

        scaffold = monitored_session.Scaffold(summary_op=summary_op)

        # allow_soft_placement must be set to True to build towers on GPU, as some
        # of the ops do not have GPU implementations.
        session_creator = monitored_session.ChiefSessionCreator(
            scaffold,
            checkpoint_dir=args.train_dir,
            config=tf.ConfigProto(
                allow_soft_placement=True,
                log_device_placement=args.log_device_placement))

        hooks = [
            # Hook to save the model every N steps and at the end.
            basic_session_run_hooks.CheckpointSaverHook(
                args.train_dir,
                checkpoint_basename=CHECKPOINT_BASENAME,
                save_steps=args.checkpoint_interval_steps,
                scaffold=scaffold),

            # Hook to save a summary every N steps.
            basic_session_run_hooks.SummarySaverHook(
                save_steps=args.summary_interval_steps,
                output_dir=args.train_dir,
                scaffold=scaffold),

            # Hook to stop at step N.
            basic_session_run_hooks.StopAtStepHook(
                last_step=args.train_max_steps)
        ]

        # Start a new monitored session. This will automatically restart the
        # sessions if the parameter servers are preempted.
        with monitored_session.MonitoredSession(
                session_creator=session_creator, hooks=hooks) as sess:

            while not sess.should_stop():
                start_time = time.time()
                _, loss_value, global_step_value = sess.run(
                    [train_op, loss, global_step])
                duration = time.time() - start_time

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                if global_step_value % 10 == 0:
                    num_examples_per_step = args.batch_size * args.num_gpus
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = duration / args.num_gpus

                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                        'sec/batch)')
                    logging.info(format_str %
                                 (datetime.now(), global_step_value,
                                  loss_value, examples_per_sec, sec_per_batch))
 def test_raise_in_none_secs_and_steps(self):
     with self.assertRaises(ValueError):
         basic_session_run_hooks.CheckpointSaverHook(self.model_dir)
 def test_raise_when_saver_and_scaffold_both_present(self):
     with self.assertRaises(ValueError):
         basic_session_run_hooks.CheckpointSaverHook(
             self.model_dir,
             saver=self.scaffold.saver,
             scaffold=self.scaffold)
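For contrast with the two failure cases above, valid constructions pass exactly one save cadence and at most one of `saver`/`scaffold`. A short sketch using the same test fixtures:

# Valid: step-based cadence, saver supplied via the scaffold.
hook_by_steps = basic_session_run_hooks.CheckpointSaverHook(
    self.model_dir, save_steps=100, scaffold=self.scaffold)

# Also valid: time-based cadence with an explicit saver instead.
hook_by_secs = basic_session_run_hooks.CheckpointSaverHook(
    self.model_dir, save_secs=600, saver=self.scaffold.saver)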
Example no. 18
  def testTrainWithAlteredGradients(self):
    # Use the same learning rate but different gradient multipliers
    # to train two models. Model with equivalently larger learning
    # rate (i.e., learning_rate * gradient_multiplier) has smaller
    # training loss.
    logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs6/')
    logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs7/')

    if gfile.Exists(logdir1):
      gfile.DeleteRecursively(logdir1)
    if gfile.Exists(logdir2):
      gfile.DeleteRecursively(logdir2)

    multipliers = [1., 1000.]
    number_of_steps = 10
    losses = []
    learning_rate = 0.001

    # First, train the model with equivalently smaller learning rate.
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      train_op = self.create_train_op(
          learning_rate=learning_rate, gradient_multiplier=multipliers[0])

      saver = saver_lib.Saver()

      loss = training.train(
          train_op,
          logdir1,
          hooks=[
              basic_session_run_hooks.StopAtStepHook(num_steps=number_of_steps),
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir1, save_steps=50, saver=saver),
          ])

      losses.append(loss)
      self.assertGreater(loss, .5)

    # Second, train the model with equivalently larger learning rate.
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      train_op = self.create_train_op(
          learning_rate=learning_rate, gradient_multiplier=multipliers[1])
      saver = saver_lib.Saver()

      loss = training.train(
          train_op,
          logdir2,
          hooks=[
              basic_session_run_hooks.StopAtStepHook(num_steps=number_of_steps),
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir2, save_steps=50, saver=saver),
          ])

      losses.append(loss)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .5)

    # The loss of the model trained with larger learning rate should
    # be smaller.
    self.assertGreater(losses[0], losses[1])
Example no. 19
  def testTrainWithInitFromCheckpoint(self):
    logdir1 = os.path.join(self.get_temp_dir(), 'tmp_logs1/')
    logdir2 = os.path.join(self.get_temp_dir(), 'tmp_logs2/')

    if gfile.Exists(logdir1):  # For running on jenkins.
      gfile.DeleteRecursively(logdir1)
    if gfile.Exists(logdir2):  # For running on jenkins.
      gfile.DeleteRecursively(logdir2)

    # First, train the model one step (make sure the error is high).
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      train_op = self.create_train_op()
      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir1,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir1, save_steps=1, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=1),
          ],
          save_checkpoint_secs=None)
      self.assertGreater(loss, .5)

    # Next, train the model to convergence.
    with ops.Graph().as_default():
      random_seed.set_random_seed(1)
      train_op = self.create_train_op()
      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir1,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir1, save_steps=1, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=300),
          ],
          save_checkpoint_secs=None)
      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)

    # Finally, advance the model a single step and validate that the loss is
    # still low.
    with ops.Graph().as_default():
      random_seed.set_random_seed(2)
      train_op = self.create_train_op()

      model_variables = variables_lib2.global_variables()
      model_path = os.path.join(logdir1, 'model.ckpt-300')

      assign_fn = variables_lib.assign_from_checkpoint_fn(model_path,
                                                          model_variables)

      def init_fn(_, session):
        assign_fn(session)

      loss = training.train(
          train_op,
          logdir2,
          scaffold=monitored_session.Scaffold(init_fn=init_fn),
          hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)])

      self.assertIsNotNone(loss)
      self.assertLess(loss, .02)
Example no. 20
def _monitored_train(graph,
                     output_dir,
                     train_op,
                     loss_op,
                     global_step_tensor=None,
                     init_op=None,
                     init_feed_dict=None,
                     init_fn=None,
                     log_every_steps=10,
                     supervisor_is_chief=True,
                     supervisor_master='',
                     supervisor_save_model_secs=600,
                     supervisor_save_model_steps=None,
                     keep_checkpoint_max=5,
                     supervisor_save_summaries_secs=None,
                     supervisor_save_summaries_steps=100,
                     feed_fn=None,
                     steps=None,
                     fail_on_nan_loss=True,
                     hooks=None,
                     max_steps=None):
  """Train a model via monitored_session.

  Given `graph`, a directory to write outputs to (`output_dir`), and some ops,
  run a training loop. The given `train_op` performs one step of training on the
  model. The `loss_op` represents the objective function of the training. The
  `train_op` is expected to increment the `global_step_tensor`, a scalar integer
  tensor counting training steps. This function uses `Supervisor` to initialize the
  graph (from a checkpoint if one is available in `output_dir`), write summaries
  defined in the graph, and write regular checkpoints as defined by
  `supervisor_save_model_secs`.

  Training continues until `global_step_tensor` evaluates to `max_steps`, or, if
  `fail_on_nan_loss`, until `loss_op` evaluates to `NaN`. In that case the
  program is terminated with exit code 1.

  Args:
    graph: A graph to train. It is expected that this graph is not in use
      elsewhere.
    output_dir: A directory to write outputs to.
    train_op: An op that performs one training step when run.
    loss_op: A scalar loss tensor.
    global_step_tensor: A tensor representing the global step. If none is given,
      one is extracted from the graph using the same logic as in `Supervisor`.
    init_op: An op that initializes the graph. If `None`, use `Supervisor`'s
      default.
    init_feed_dict: A dictionary that maps `Tensor` objects to feed values.
      This feed dictionary will be used when `init_op` is evaluated.
    init_fn: Optional callable passed to Supervisor to initialize the model.
    log_every_steps: Output logs regularly. The logs contain timing data and the
      current loss. A `0` or negative value disables logging.
    supervisor_is_chief: Whether the current process is the chief supervisor in
      charge of restoring the model and running standard services.
    supervisor_master: The master string to use when preparing the session.
    supervisor_save_model_secs: Save checkpoints every this many seconds. Can
        not be specified with `supervisor_save_model_steps`.
    supervisor_save_model_steps: Save checkpoints every this many steps. Can not
        be specified with `supervisor_save_model_secs`.
    keep_checkpoint_max: The maximum number of recent checkpoint files to
      keep. As new files are created, older files are deleted. If None or 0,
      all checkpoint files are kept. This is simply passed as the max_to_keep
      arg to `tf.Saver` constructor.
    supervisor_save_summaries_secs: Save summaries every
      `supervisor_save_summaries_secs` seconds when training.
    supervisor_save_summaries_steps: Save summaries every
      `supervisor_save_summaries_steps` steps when training. Exactly one of
      `supervisor_save_model_steps` and `supervisor_save_model_secs` should be
      specified, and the other should be None.
    feed_fn: A function that is called every iteration to produce a `feed_dict`
      passed to `session.run` calls. Optional.
    steps: Trains for this many steps (e.g. current global step + `steps`).
    fail_on_nan_loss: If true, raise `NanLossDuringTrainingError` if `loss_op`
      evaluates to `NaN`. If false, continue training as if nothing happened.
    hooks: List of `SessionRunHook` subclass instances. Used for callbacks
      inside the training loop.
    max_steps: Number of total steps for which to train model. If `None`,
      train forever. Two calls of fit(steps=100) mean 200 training iterations.
      On the other hand, two calls of fit(max_steps=100) mean the second call
      will not do any iterations, since the first call already did all 100 steps.

  Returns:
    The final loss value.

  Raises:
    ValueError: If `output_dir`, `train_op`, `loss_op`, or `global_step_tensor`
      is not provided. See `tf.contrib.framework.get_global_step` for how we
      look up the latter if not provided explicitly.
    NanLossDuringTrainingError: If `fail_on_nan_loss` is `True`, and loss ever
      evaluates to `NaN`.
    ValueError: If both `steps` and `max_steps` are not `None`.
  """
  if (steps is not None) and (max_steps is not None):
    raise ValueError('Can not provide both steps and max_steps.')
  if not output_dir:
    raise ValueError('Output directory should be non-empty %s.' % output_dir)
  if train_op is None:
    raise ValueError('Missing train_op.')
  if loss_op is None:
    raise ValueError('Missing loss_op.')
  if hooks is None:
    hooks = []
  if not isinstance(hooks, list):
    raise ValueError('Hooks should be a list.')
  with graph.as_default():
    global_step_tensor = contrib_variables.assert_or_get_global_step(
        graph, global_step_tensor)
  if global_step_tensor is None:
    raise ValueError('No "global_step" was provided or found in the graph.')

  if max_steps is not None:
    try:
      start_step = load_variable(output_dir, global_step_tensor.name)
      if max_steps <= start_step:
        logging.info('Skipping training since max_steps has already been reached.')
        return None
    except:  # pylint: disable=bare-except
      pass

  # Adapted SessionRunHooks such as ExportMonitor depend on the
  # CheckpointSaverHook having been executed before they run.
  # The `hooks` param consists of deprecated monitor hooks
  # (such as ExportMonitor); they are appended after the basic_session_run_hooks.
  all_hooks = []
  with graph.as_default():
    all_hooks.append(basic_session_run_hooks.NanTensorHook(
        loss_op, fail_on_nan_loss=fail_on_nan_loss))
    if log_every_steps > 0:
      all_hooks.append(basic_session_run_hooks.LoggingTensorHook({
          'loss': loss_op.name,
          'step': global_step_tensor.name
      }, every_n_iter=log_every_steps))

    def make_saver():
      return tf_saver.Saver(
          sharded=True, max_to_keep=keep_checkpoint_max, defer_build=True,
          write_version=saver_pb2.SaverDef.V1)

    scaffold = monitored_session.Scaffold(
        init_op=init_op,
        init_feed_dict=init_feed_dict,
        init_fn=init_fn,
        saver=monitored_session.Scaffold.get_or_default('saver',
                                                        ops.GraphKeys.SAVERS,
                                                        make_saver))

    if not supervisor_is_chief:
      session_creator = monitored_session.WorkerSessionCreator(
          scaffold=scaffold,
          master=supervisor_master)
    else:
      session_creator = monitored_session.ChiefSessionCreator(
          scaffold=scaffold,
          checkpoint_dir=output_dir,
          master=supervisor_master)
      summary_writer = summary_io.SummaryWriterCache.get(output_dir)
      all_hooks.append(
          basic_session_run_hooks.StepCounterHook(
              summary_writer=summary_writer))
      all_hooks.append(
          basic_session_run_hooks.SummarySaverHook(
              save_secs=supervisor_save_summaries_secs,
              save_steps=supervisor_save_summaries_steps,
              summary_writer=summary_writer,
              scaffold=scaffold))
      if (supervisor_save_model_secs is not None
          or supervisor_save_model_steps is not None):
        all_hooks.append(
            basic_session_run_hooks.CheckpointSaverHook(
                output_dir,
                save_secs=supervisor_save_model_secs,
                save_steps=supervisor_save_model_steps,
                scaffold=scaffold))

    if steps is not None or max_steps is not None:
      all_hooks.append(basic_session_run_hooks.StopAtStepHook(steps, max_steps))
    all_hooks.extend(hooks)

    with monitored_session.MonitoredSession(
        session_creator=session_creator,
        hooks=all_hooks) as super_sess:
      loss = None
      while not super_sess.should_stop():
        _, loss = super_sess.run([train_op, loss_op], feed_fn() if feed_fn else
                                 None)
    summary_io.SummaryWriterCache.clear()
    return loss
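A hedged sketch of invoking `_monitored_train`; only the argument names come from the signature above, while the toy model, the extra imports, and the output directory are assumptions.

from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training import gradient_descent

graph = ops.Graph()
with graph.as_default():
  global_step = contrib_variables.get_or_create_global_step()
  w = tf_variables.Variable(5.0, name='w')         # toy parameter
  loss_op = w * w                                  # minimize w^2
  train_op = gradient_descent.GradientDescentOptimizer(0.1).minimize(
      loss_op, global_step=global_step)

final_loss = _monitored_train(
    graph=graph,
    output_dir='/tmp/output',        # checkpoints and summaries go here
    train_op=train_op,
    loss_op=loss_op,
    supervisor_save_model_secs=600,  # CheckpointSaverHook cadence
    max_steps=100)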
Example no. 21
def PartialRestoreSession(
        master='',  # pylint: disable=invalid-name
        is_chief=True,
        checkpoint_dir=None,
        restore_var_list=None,
        scaffold=None,
        hooks=None,
        chief_only_hooks=None,
        save_checkpoint_secs=600,
        save_summaries_steps=monitored_session.USE_DEFAULT,
        save_summaries_secs=monitored_session.USE_DEFAULT,
        config=None,
        stop_grace_period_secs=120,
        log_step_count_steps=100):
    """Creates a `MonitoredSession` for training.

    Supports partial restoration from checkpoints with parameter
    `restore_var_list`, by adding `CheckpointRestorerHook`.

  For a chief, this utility sets proper session initializer/restorer. It also
  creates hooks related to checkpoint and summary saving. For workers, this
  utility sets proper session creator which waits for the chief to
  initialize/restore. Please check `tf.train.MonitoredSession` for more
  information.


  Args:
    master: `String` the TensorFlow master to use.
    is_chief: If `True`, it will take care of initialization and recovery of
      the underlying TensorFlow session. If `False`, it will wait on a chief to
      initialize or recover the TensorFlow session.
    checkpoint_dir: A string.  Optional path to a directory where to restore
      variables.
    restore_var_list: an optional list of variables to restore; pass it when
      not all variables should be recovered from the checkpoint.
      Useful when changing network structures during training, i.e., finetuning
      a pretrained model with new layers.
    scaffold: A `Scaffold` used for gathering or building supportive ops. If
      not specified, a default one is created. It's used to finalize the graph.
    hooks: Optional list of `SessionRunHook` objects.
    chief_only_hooks: list of `SessionRunHook` objects. Activate these hooks if
      `is_chief==True`, ignore otherwise.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
      using a default checkpoint saver. If `save_checkpoint_secs` is set to
      `None`, then the default checkpoint saver isn't used.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If both
      `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
      the default summary saver isn't used. Default 100.
    save_summaries_secs: The frequency, in secs, that the summaries are written
      to disk using a default summary saver.  If both `save_summaries_steps` and
      `save_summaries_secs` are set to `None`, then the default summary saver
      isn't used. Default not enabled.
    config: an instance of `tf.ConfigProto` proto used to configure the session.
      It's the `config` argument of constructor of `tf.Session`.
    stop_grace_period_secs: Number of seconds given to threads to stop after
      `close()` has been called.
    log_step_count_steps: The frequency, in number of global steps, that the
      global step/sec is logged.

  Returns:
    A `MonitoredSession` object.
  """
    if save_summaries_steps == monitored_session.USE_DEFAULT \
            and save_summaries_secs == monitored_session.USE_DEFAULT:
        save_summaries_steps = 100
        save_summaries_secs = None
    elif save_summaries_secs == monitored_session.USE_DEFAULT:
        save_summaries_secs = None
    elif save_summaries_steps == monitored_session.USE_DEFAULT:
        save_summaries_steps = None

    scaffold = scaffold or monitored_session.Scaffold()
    if not is_chief:
        session_creator = monitored_session.WorkerSessionCreator(
            scaffold=scaffold, master=master, config=config)
        return monitored_session.MonitoredSession(
            session_creator=session_creator,
            hooks=hooks or [],
            stop_grace_period_secs=stop_grace_period_secs)

    all_hooks = []
    if chief_only_hooks:
        all_hooks.extend(chief_only_hooks)
    if restore_var_list is None:
        restore_checkpoint_dir = checkpoint_dir
    else:
        restore_checkpoint_dir = None
        all_hooks.append(
            CheckpointRestorerHook(checkpoint_dir, var_list=restore_var_list))
        all_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
        missing_vars = [v for v in all_vars if v not in restore_var_list]
        logging.warning("MonitoredTrainingSession not restoring %s",
                        missing_vars)
    session_creator = monitored_session.ChiefSessionCreator(
        scaffold=scaffold,
        checkpoint_dir=restore_checkpoint_dir,
        master=master,
        config=config)

    if checkpoint_dir:
        all_hooks.append(
            basic_session_run_hooks.StepCounterHook(
                output_dir=checkpoint_dir, every_n_steps=log_step_count_steps))

        if (save_summaries_steps
                and save_summaries_steps > 0) or (save_summaries_secs
                                                  and save_summaries_secs > 0):
            all_hooks.append(
                basic_session_run_hooks.SummarySaverHook(
                    scaffold=scaffold,
                    save_steps=save_summaries_steps,
                    save_secs=save_summaries_secs,
                    output_dir=checkpoint_dir))
        if save_checkpoint_secs and save_checkpoint_secs > 0:
            all_hooks.append(
                basic_session_run_hooks.CheckpointSaverHook(
                    checkpoint_dir,
                    save_secs=save_checkpoint_secs,
                    scaffold=scaffold))

    if hooks:
        all_hooks.extend(hooks)
    return monitored_session.MonitoredSession(
        session_creator=session_creator,
        hooks=all_hooks,
        stop_grace_period_secs=stop_grace_period_secs)
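A hedged finetuning sketch of the partial-restore path above: only pretrained backbone variables are restored, while newly added layers start from their initializers. The scope name, checkpoint directory and `train_op` are illustrative placeholders.

# Hypothetical finetuning usage; the "backbone" scope and train_op are placeholders.
backbone_vars = ops.get_collection(
    ops.GraphKeys.GLOBAL_VARIABLES, scope='backbone')

with PartialRestoreSession(
    checkpoint_dir='/tmp/pretrained',  # holds the backbone weights
    restore_var_list=backbone_vars,    # new "head/*" variables stay at init
    save_checkpoint_secs=600) as sess:
  while not sess.should_stop():
    sess.run(train_op)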
Example no. 22
def train(train_op,
          logdir,
          master='',
          is_chief=True,
          scaffold=None,
          hooks=None,
          chief_only_hooks=None,
          save_checkpoint_secs=600,
          save_summaries_steps=100,
          config=None):
    """Runs the training loop.

  Args:
    train_op: A `Tensor` that, when executed, will apply the gradients and
      return the loss value.
    logdir: The directory where the graph and checkpoints are saved.
    master: The URL of the master.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    scaffold: A tf.train.Scaffold instance.
    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
      training loop.
    chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
      inside the training loop for the chief trainer only.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
      using a default checkpoint saver. If `save_checkpoint_secs` is set to
      `None`, then the default checkpoint saver isn't used.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If
      `save_summaries_steps` is set to `None`, then the default summary saver
      isn't used.
    config: An instance of `tf.ConfigProto`.

  Returns:
    the value of the loss function after training.

  Raises:
    ValueError: if `logdir` is `None` and either `save_checkpoint_secs` or
    `save_summaries_steps` is set.
  """
    # TODO(nsilberman): move this logic into monitored_session.py
    scaffold = scaffold or monitored_session.Scaffold()

    hooks = hooks or []

    if is_chief:
        session_creator = monitored_session.ChiefSessionCreator(
            scaffold=scaffold,
            checkpoint_dir=logdir,
            master=master,
            config=config)

        if chief_only_hooks:
            hooks.extend(chief_only_hooks)

        hooks.append(
            basic_session_run_hooks.StepCounterHook(output_dir=logdir))

        if save_summaries_steps:
            if logdir is None:
                raise ValueError(
                    'logdir cannot be None when save_summaries_steps is set')
            hooks.append(
                basic_session_run_hooks.SummarySaverHook(
                    scaffold=scaffold,
                    save_steps=save_summaries_steps,
                    output_dir=logdir))

        if save_checkpoint_secs:
            if logdir is None:
                raise ValueError(
                    'logdir cannot be None when save_checkpoint_secs is set')
            hooks.append(
                basic_session_run_hooks.CheckpointSaverHook(
                    logdir, save_secs=save_checkpoint_secs, scaffold=scaffold))
    else:
        session_creator = monitored_session.WorkerSessionCreator(
            scaffold=scaffold, master=master, config=config)

    with monitored_session.MonitoredSession(session_creator=session_creator,
                                            hooks=hooks) as session:
        loss = None
        while not session.should_stop():
            loss = session.run(train_op)
    return loss
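A hedged sketch of driving this training loop, reusing the `ops`, `variables`, `gradient_descent`, `training` and `basic_session_run_hooks` aliases from the earlier test examples; the toy loss, the extra import and the log directory are assumptions.

from tensorflow.python.ops import variables as tf_variables

with ops.Graph().as_default():
  global_step = variables.get_or_create_global_step()
  w = tf_variables.Variable(3.0, name='w')
  total_loss = w * w
  optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.1)
  train_op = training.create_train_op(total_loss, optimizer)

  final_loss = train(
      train_op,
      logdir='/tmp/train_logs',  # placeholder log directory
      save_checkpoint_secs=600,  # adds the default CheckpointSaverHook
      save_summaries_steps=100,
      hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=50)])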
Example no. 23
    def _model_fn(features, labels, mode):
        """Function that returns predictions, training loss, and training op."""
        weights = None
        if weights_name and weights_name in features:
            weights = features.pop(weights_name)

        keys = None
        if keys_name and keys_name in features:
            keys = features.pop(keys_name)

        # If we're doing eval, optionally ignore device_assigner.
        # Also ignore device assigner if we're exporting (mode == INFER)
        dev_assn = device_assigner
        if (mode == model_fn_lib.ModeKeys.INFER
                or (local_eval and mode == model_fn_lib.ModeKeys.EVAL)):
            dev_assn = None

        graph_builder = graph_builder_class(params, device_assigner=dev_assn)
        inference = {}
        output_alternatives = None
        if (mode == model_fn_lib.ModeKeys.EVAL
                or mode == model_fn_lib.ModeKeys.INFER):
            inference[eval_metrics.INFERENCE_PROB_NAME] = (
                graph_builder.inference_graph(features))

            if params.regression:
                predictions = {
                    None: inference[eval_metrics.INFERENCE_PROB_NAME]
                }
                output_alternatives = {
                    None:
                    (constants.ProblemType.LINEAR_REGRESSION, predictions)
                }
            else:
                inference[eval_metrics.INFERENCE_PRED_NAME] = math_ops.argmax(
                    inference[eval_metrics.INFERENCE_PROB_NAME], 1)

                predictions = {
                    prediction_key.PredictionKey.PROBABILITIES:
                    inference[eval_metrics.INFERENCE_PROB_NAME],
                    prediction_key.PredictionKey.CLASSES:
                    inference[eval_metrics.INFERENCE_PRED_NAME]
                }
                output_alternatives = {
                    None: (constants.ProblemType.CLASSIFICATION, predictions)
                }

            if keys is not None:
                inference[keys_name] = keys

        # labels might be None if we're doing prediction (which brings up the
        # question of why we force everything to adhere to a single model_fn).
        loss_deps = []
        training_graph = None
        training_hooks = []
        scaffold = None
        if labels is not None and mode == model_fn_lib.ModeKeys.TRAIN:
            training_graph = control_flow_ops.group(
                graph_builder.training_graph(features,
                                             labels,
                                             input_weights=weights,
                                             num_trainers=num_trainers,
                                             trainer_id=trainer_id),
                state_ops.assign_add(contrib_framework.get_global_step(), 1))
            loss_deps.append(training_graph)
            if hasattr(graph_builder, 'finalize_training'):
                finalize_listener = EveryCheckpointPreSaveListener(
                    graph_builder.finalize_training())
                scaffold = monitored_session.Scaffold()
                training_hooks.append(
                    basic_session_run_hooks.CheckpointSaverHook(
                        model_dir,
                        save_secs=600,
                        save_steps=None,
                        scaffold=scaffold,
                        listeners=[finalize_listener]))

        training_loss = None
        if (mode == model_fn_lib.ModeKeys.EVAL
                or mode == model_fn_lib.ModeKeys.TRAIN):
            with ops.control_dependencies(loss_deps):
                training_loss = graph_builder.training_loss(features,
                                                            labels,
                                                            name=LOSS_NAME)

        # Put weights back in
        if weights is not None:
            features[weights_name] = weights

        if early_stopping_rounds:
            training_hooks.append(TensorForestLossHook(early_stopping_rounds))

        if report_feature_importances:
            training_hooks.append(
                TensorForestRunOpAtEndHook({
                    'feature_importances':
                    graph_builder.feature_importances()
                }))

        return model_fn_lib.ModelFnOps(mode=mode,
                                       predictions=inference,
                                       loss=training_loss,
                                       train_op=training_graph,
                                       training_hooks=training_hooks,
                                       scaffold=scaffold,
                                       output_alternatives=output_alternatives)
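
The hook wiring above relies on `EveryCheckpointPreSaveListener`, which is defined elsewhere in the tensor_forest code and not shown here. A minimal sketch of what such a `CheckpointSaverListener` might look like (an illustration only, not the actual tensor_forest class):

from tensorflow.python.training import basic_session_run_hooks

class RunOpBeforeSaveListener(basic_session_run_hooks.CheckpointSaverListener):
    """Runs a given op right before each checkpoint is written."""

    def __init__(self, op):
        self._op = op

    def before_save(self, session, global_step_value):
        # Run e.g. graph_builder.finalize_training() so its effect is
        # captured in the checkpoint that is about to be saved.
        session.run(self._op)

Such a listener plugs in exactly as above, via the `listeners` argument of `CheckpointSaverHook`.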
Esempio n. 24
0
def random_forest_model_fn(features, labels, mode, params, config):
    """Function that returns predictions, training loss, and training op."""
    labels_tensor = labels
    if isinstance(labels, dict) and len(labels) == 1:
        labels_tensor = list(labels.values())[0]

    weights_name = params["weights_name"]
    keys_name = params["keys_name"]
    num_classes = tf.identity(params['num_classes'], name='num_classes')
    params_toGraphs = tensor_forest.ForestHParams(
        num_classes=params['num_classes'],
        num_features=params['num_features'],
        num_trees=params['num_trees'],
        max_nodes=params['max_nodes'],
        regression=params['regression'],
        split_after_samples=params['split_after_samples'])
    # Note the fill() call (around line 90 of tensor_forest.py):
    # https://github.com/tensorflow/tensorflow/blob/r1.2/tensorflow/contrib
    # /tensor_forest/python/tensor_forest.py
    params_toGraphs = params_toGraphs.fill()
    graph_builder_class = tensor_forest.RandomForestGraphs

    early_stopping_rounds = params["early_stopping_rounds"]
    num_trainers = 1
    trainer_id = 0
    report_feature_importances = False
    model_dir = None
    local_eval = False
    device_assigner = None
    weights = None
    if weights_name and weights_name in features:
        weights = features.pop(weights_name)

    keys = None
    if keys_name and keys_name in features:
        keys = features.pop(keys_name)

    # If we're doing eval, optionally ignore device_assigner.
    # Also ignore device assigner if we're exporting (mode == INFER)
    dev_assn = device_assigner
    if (mode == model_fn_lib.ModeKeys.INFER
            or (local_eval and mode == model_fn_lib.ModeKeys.EVAL)):
        dev_assn = None

    graph_builder = graph_builder_class(params_toGraphs,
                                        device_assigner=dev_assn)
    inference = {}
    predictions = {}
    output_alternatives = None
    # if (mode == model_fn_lib.ModeKeys.EVAL or
    #             mode == model_fn_lib.ModeKeys.INFER):
    if True:
        inference[eval_metrics.INFERENCE_PROB_NAME] = (
            graph_builder.inference_graph(features))

        if params_toGraphs.regression:
            predictions = {None: inference[eval_metrics.INFERENCE_PROB_NAME]}
            output_alternatives = {
                None: (constants.ProblemType.LINEAR_REGRESSION, predictions)
            }
        else:
            inference[eval_metrics.INFERENCE_PRED_NAME] = math_ops.argmax(
                inference[eval_metrics.INFERENCE_PROB_NAME], 1)

            predictions = {
                prediction_key.PredictionKey.PROBABILITIES:
                inference[eval_metrics.INFERENCE_PROB_NAME],
                prediction_key.PredictionKey.CLASSES:
                inference[eval_metrics.INFERENCE_PRED_NAME]
            }
            output_alternatives = {
                None: (constants.ProblemType.CLASSIFICATION, predictions)
            }

        if report_feature_importances:
            inference[eval_metrics.FEATURE_IMPORTANCE_NAME] = (
                graph_builder.feature_importances())

        if keys is not None:
            inference[keys_name] = keys

    # labels might be None if we're doing prediction (which brings up the
    # question of why we force everything to adhere to a single model_fn).
    loss_deps = []
    training_graph = None
    training_hooks = []
    scaffold = None
    if labels is not None and mode == model_fn_lib.ModeKeys.TRAIN:
        training_graph = control_flow_ops.group(
            graph_builder.training_graph(features,
                                         labels,
                                         input_weights=weights,
                                         num_trainers=num_trainers,
                                         trainer_id=trainer_id),
            state_ops.assign_add(contrib_framework.get_global_step(), 1))
        loss_deps.append(training_graph)
        if hasattr(graph_builder, 'finalize_training'):
            finalize_listener = EveryCheckpointPreSaveListener(
                graph_builder.finalize_training())
            scaffold = monitored_session.Scaffold()
            training_hooks.append(
                basic_session_run_hooks.CheckpointSaverHook(
                    model_dir,
                    save_secs=600,
                    save_steps=None,
                    scaffold=scaffold,
                    listeners=[finalize_listener]))

    training_loss = None
    if (mode == model_fn_lib.ModeKeys.EVAL
            or mode == model_fn_lib.ModeKeys.TRAIN):
        with ops.control_dependencies(loss_deps):
            training_loss = graph_builder.training_loss(
                features, labels, name='rf_training_loss')

    # Give these tensors explicit names so they can be fetched from a hook
    if not params['regression']:
        confusion_matrix_print = confusion_matrix(
            labels=labels_tensor,
            predictions=predictions['classes'],
            num_classes=num_classes,
        )

        confusion_matrix_print = tf.identity(confusion_matrix_print,
                                             name='confusion_matrix_print')
    else:
        confusion_matrix_print = tf.identity(0, name='confusion_matrix_print')

    regression_ornot = tf.identity(params['regression'],
                                   name='regression_ornot')
    # Put weights back in
    if weights is not None:
        features[weights_name] = weights

    if early_stopping_rounds:
        training_hooks.append(TensorForestLossHook(early_stopping_rounds))

    metrics = {}
    # metrics[metric_key.MetricKey.AUC] = metrics_lib.streaming_auc(
    #     labels=labels_tensor,
    #     predictions=inference[eval_metrics.INFERENCE_PRED_NAME]
    # )
    if not params_toGraphs.regression:
        metrics['eval_confusion_matrix'] = confusion_matrix(
            labels=labels_tensor,
            predictions=predictions['classes'],
            num_classes=params['num_classes'],
        )

    return model_fn_lib.ModelFnOps(mode=mode,
                                   predictions=inference,
                                   loss=training_loss,
                                   train_op=training_graph,
                                   training_hooks=training_hooks,
                                   scaffold=scaffold,
                                   eval_metric_ops=metrics,
                                   output_alternatives=output_alternatives)
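
For context, a hedged sketch of how `random_forest_model_fn` might be handed to a `tf.contrib.learn.Estimator`; the `params` keys mirror what the function reads above, but every value (and the model directory) is an arbitrary placeholder:

import tensorflow as tf

params = {
    'num_classes': 10,
    'num_features': 784,
    'num_trees': 50,
    'max_nodes': 1000,
    'regression': False,
    'split_after_samples': 250,
    'weights_name': None,
    'keys_name': None,
    'early_stopping_rounds': 100,
}

classifier = tf.contrib.learn.Estimator(
    model_fn=random_forest_model_fn,
    model_dir='/tmp/random_forest_model',
    params=params)
# classifier.fit(input_fn=train_input_fn, steps=1000)  # train_input_fn is user code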
Esempio n. 25
0
def MonitoredTrainingSession(
        master='',  # pylint: disable=invalid-name
        is_chief=True,
        checkpoint_dir=None,
        scaffold=None,
        hooks=None,
        chief_only_hooks=None,
        save_checkpoint_secs=600,
        save_summaries_steps=100,
        save_summaries_secs=None,
        config=None,
        stop_grace_period_secs=120,
        log_step_count_steps=100):
    """Creates a `MonitoredSession` for training.

  For a chief, this utility sets proper session initializer/restorer. It also
  creates hooks related to checkpoint and summary saving. For workers, this
  utility sets proper session creator which waits for the chief to
  initialize/restore.


  Args:
    master: `String` the TensorFlow master to use.
    is_chief: If `True`, it will take care of initialization and recovery the
      underlying TensorFlow session. If `False`, it will wait on a chief to
      initialize or recover the TensorFlow session.
    checkpoint_dir: A string.  Optional path to a directory where to restore
      variables.
    scaffold: A `Scaffold` used for gathering or building supportive ops. If
      not specified, a default one is created. It's used to finalize the graph.
    hooks: Optional list of `SessionRunHook` objects.
    chief_only_hooks: list of `SessionRunHook` objects. Activate these hooks if
      `is_chief==True`, ignore otherwise.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
      using a default checkpoint saver. If `save_checkpoint_secs` is set to
      `None`, then the default checkpoint saver isn't used.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If both
      `save_summaries_steps` and `save_summaries_secs` are set to `None`, then
      the default summary saver isn't used.
    save_summaries_secs: The frequency, in secs, that the summaries are written
      to disk using a default summary saver.  If both `save_summaries_steps` and
      `save_summaries_secs` are set to `None`, then the default summary saver
      isn't used.
    config: an instance of `tf.ConfigProto` proto used to configure the session.
      It's the `config` argument of constructor of `tf.Session`.
    stop_grace_period_secs: Number of seconds given to threads to stop after
      `close()` has been called.
    log_step_count_steps: The frequency, in number of global steps, that the
      global step/sec is logged.

  Returns:
    A `MonitoredSession` object.
  """
    scaffold = scaffold or Scaffold()
    if not is_chief:
        session_creator = WorkerSessionCreator(scaffold=scaffold,
                                               master=master,
                                               config=config)
        return MonitoredSession(session_creator=session_creator,
                                hooks=hooks or [],
                                stop_grace_period_secs=stop_grace_period_secs)

    all_hooks = []
    if chief_only_hooks:
        all_hooks.extend(chief_only_hooks)
    session_creator = ChiefSessionCreator(scaffold=scaffold,
                                          checkpoint_dir=checkpoint_dir,
                                          master=master,
                                          config=config)

    if checkpoint_dir:
        all_hooks.append(
            basic_session_run_hooks.StepCounterHook(
                output_dir=checkpoint_dir, every_n_steps=log_step_count_steps))

        if (save_summaries_steps
                and save_summaries_steps > 0) or (save_summaries_secs
                                                  and save_summaries_secs > 0):
            all_hooks.append(
                basic_session_run_hooks.SummarySaverHook(
                    scaffold=scaffold,
                    save_steps=save_summaries_steps,
                    save_secs=save_summaries_secs,
                    output_dir=checkpoint_dir))
        if save_checkpoint_secs and save_checkpoint_secs > 0:
            all_hooks.append(
                basic_session_run_hooks.CheckpointSaverHook(
                    checkpoint_dir,
                    save_secs=save_checkpoint_secs,
                    scaffold=scaffold))

    if hooks:
        all_hooks.extend(hooks)
    return MonitoredSession(session_creator=session_creator,
                            hooks=all_hooks,
                            stop_grace_period_secs=stop_grace_period_secs)
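
A minimal usage sketch for the chief path described in the docstring above, assuming a trivial graph; the checkpoint directory and step limit are placeholders:

import tensorflow as tf

with tf.Graph().as_default():
    global_step = tf.train.get_or_create_global_step()
    loss = tf.get_variable('loss_var', initializer=1.0)
    tf.summary.scalar('loss', loss)
    train_op = tf.group(loss.assign(loss * 0.99),
                        tf.assign_add(global_step, 1))

    # On the chief this installs the StepCounterHook, the default
    # SummarySaverHook (every 100 steps) and the default CheckpointSaverHook
    # (every 600 seconds) shown above.
    with tf.train.MonitoredTrainingSession(
            is_chief=True,
            checkpoint_dir='/tmp/mts_example',
            hooks=[tf.train.StopAtStepHook(last_step=1000)]) as sess:
        while not sess.should_stop():
            sess.run(train_op)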
Esempio n. 26
0
def train(args):
    """Train CIFAR-10 for a number of steps.

  Args:
    args: The command line arguments.
  """

    with tf.Graph().as_default():

        # Create the global step
        global_step = tf.contrib.framework.create_global_step()

        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs(args.data_dir,
                                                  args.batch_size,
                                                  args.use_fp16)

        # Build a Graph that computes the logits predictions from the
        # inference model.
        logits = cifar10.inference(images, args.batch_size, args.use_fp16)

        # Calculate loss.
        loss = cifar10.loss(logits, labels)

        # Build a Graph that trains the model with one batch of examples and
        # updates the model parameters.
        train_op = cifar10.train(loss, global_step, args.batch_size)

        scaffold = monitored_session.Scaffold()

        session_creator = monitored_session.ChiefSessionCreator(
            scaffold,
            checkpoint_dir=args.train_dir,
            config=tf.ConfigProto(
                log_device_placement=args.log_device_placement))

        hooks = [
            # Hook to save the model every N steps and at the end.
            basic_session_run_hooks.CheckpointSaverHook(
                args.train_dir,
                checkpoint_basename=CHECKPOINT_BASENAME,
                save_steps=args.checkpoint_interval_steps,
                scaffold=scaffold),

            # Hook to save a summary every N steps.
            basic_session_run_hooks.SummarySaverHook(
                save_steps=args.summary_interval_steps,
                output_dir=args.train_dir,
                scaffold=scaffold),

            # Hook to stop at step N.
            basic_session_run_hooks.StopAtStepHook(
                last_step=args.train_max_steps)
        ]

        # Start a new monitored session. This will automatically restart the
        # sessions if the parameter servers are preempted.
        with monitored_session.MonitoredSession(
                session_creator=session_creator, hooks=hooks) as sess:

            while not sess.should_stop():

                start_time = time.time()
                _, loss_value, global_step_value = sess.run(
                    [train_op, loss, global_step])
                duration = time.time() - start_time

                assert not np.isnan(
                    loss_value), 'Model diverged with loss = NaN'

                if global_step_value % 10 == 0:
                    num_examples_per_step = args.batch_size
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)

                    logging.info(
                        ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                         'sec/batch)'), datetime.now(), global_step_value,
                        loss_value, examples_per_sec, sec_per_batch)
Esempio n. 27
0
  def testTrainAllVarsHasLowerLossThanTrainSubsetOfVars(self):
    logdir = os.path.join(self.get_temp_dir(), 'tmp_logs3/')
    if gfile.Exists(logdir):  # For running on jenkins.
      gfile.DeleteRecursively(logdir)

    # First, train only the weights of the model.
    with ops.Graph().as_default():
      random_seed.set_random_seed(0)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      weights = variables_lib.get_variables_by_name('weights')

      train_op = training.create_train_op(
          total_loss, optimizer, variables_to_train=weights)

      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir, save_steps=1, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=200),
          ])
      self.assertGreater(loss, .015)
      self.assertLess(loss, .05)

    # Next, train the biases of the model.
    with ops.Graph().as_default():
      random_seed.set_random_seed(1)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)
      biases = variables_lib.get_variables_by_name('biases')

      train_op = training.create_train_op(
          total_loss, optimizer, variables_to_train=biases)

      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir, save_steps=1, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=300),
          ])
      self.assertGreater(loss, .015)
      self.assertLess(loss, .05)

    # Finally, train both weights and bias to get lower loss.
    with ops.Graph().as_default():
      random_seed.set_random_seed(2)
      total_loss = self.ModelLoss()
      optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1.0)

      train_op = training.create_train_op(total_loss, optimizer)
      saver = saver_lib.Saver()
      loss = training.train(
          train_op,
          logdir,
          hooks=[
              basic_session_run_hooks.CheckpointSaverHook(
                  logdir, save_steps=1, saver=saver),
              basic_session_run_hooks.StopAtStepHook(num_steps=400),
          ])
      self.assertIsNotNone(loss)
      self.assertLess(loss, .015)
Esempio n. 28
0
  def test_raise_when_saver_and_scaffold_both_missing(self):
    with self.assertRaises(ValueError):
      basic_session_run_hooks.CheckpointSaverHook(self.model_dir)
Esempio n. 29
0
    def __init__(self, estimator, external_state_policy=None):
        """Initializes a `CheckpointInputPipelineHook`.

    If the input pipeline depends on external state (e.g. seeds for
    RandomUniform) beyond the input pipeline, this hook would be unable to
    serialize and deserialize that state. If it's acceptable to ignore that
    state, change the `external_state_policy` argument to 'warn' or 'ignore'.
    For example:

    ```python
    est = tf.estimator.Estimator(model_fn)
    while True:
      est.train(
          train_input_fn,
          hooks=[tf.data.experimental.CheckpointInputPipelineHook(
              est, external_state_policy='warn')],
          steps=train_steps_per_eval)
      # Note: We do not pass the hook here.
      metrics = est.evaluate(eval_input_fn)
      if should_stop_the_training(metrics):
        break
    ```

    Args:
      estimator: Estimator.
      external_state_policy: A string that identifies how to handle input
        pipelines that depend on external state. Possible values are
        'ignore': The external state is silently ignored.
        'warn': The external state is ignored, logging a warning.
        'fail': The operation fails upon encountering external state.
        By default we set it to 'fail'.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of saver or scaffold should be set.
      ValueError: If `external_state_policy` is not one of 'warn', 'ignore' or
        'fail'.
    """
        if external_state_policy is None:
            external_state_policy = "fail"
        self._external_state_policy = _convert_external_state_policy_to_enum(
            external_state_policy)
        # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
        # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
        # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
        # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
        # to be different to avoid conflicts with the model checkpoint.

        # pylint: disable=protected-access
        checkpoint_prefix = "input"
        if estimator._config.num_worker_replicas > 1:
            # Distributed setting.
            suffix = "_{}_{}".format(estimator._config.task_type,
                                     estimator._config.task_id)
            checkpoint_prefix += suffix
        # pylint: enable=protected-access

        # We use a composition paradigm instead of inheriting from
        # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
        # to check whether a `CheckpointSaverHook` is already present in the list
        # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
        # would thwart this behavior. This hook checkpoints *only the iterators*
        # and not the graph variables.
        self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
            estimator.model_dir,
            save_secs=estimator._config.save_checkpoints_secs,  # pylint: disable=protected-access
            save_steps=estimator._config.save_checkpoints_steps,  # pylint: disable=protected-access
            checkpoint_basename=checkpoint_prefix + ".ckpt")

        # Name for the protocol buffer file that will contain the list of most
        # recent checkpoints stored as a `CheckpointState` protocol buffer.
        # This file, kept in the same directory as the checkpoint files, is
        # automatically managed by the `Saver` to keep track of recent checkpoints.
        # The default name used by the `Saver` for this file is "checkpoint". Here
        # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
        # `checkpoint_dir` is the same as the model checkpoint directory, there are
        # no conflicts during restore.
        self._latest_filename = "checkpoint_" + checkpoint_prefix
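
`_convert_external_state_policy_to_enum` is not shown in this example. As an illustration only (not TensorFlow's actual helper), a sketch of such a string-to-enum conversion built on the public `tf.data.experimental.ExternalStatePolicy` values, assuming a TF 2.x runtime:

import tensorflow as tf

def convert_external_state_policy_to_enum(policy):
    """Maps 'warn' / 'ignore' / 'fail' onto ExternalStatePolicy members."""
    if isinstance(policy, tf.data.experimental.ExternalStatePolicy):
        return policy
    mapping = {
        'warn': tf.data.experimental.ExternalStatePolicy.WARN,
        'ignore': tf.data.experimental.ExternalStatePolicy.IGNORE,
        'fail': tf.data.experimental.ExternalStatePolicy.FAIL,
    }
    try:
        return mapping[policy]
    except KeyError:
        raise ValueError(
            "external_state_policy must be one of 'warn', 'ignore', 'fail'; "
            "got {!r}".format(policy))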
Esempio n. 30
0
def train_mnist():
    """
    Trains the model: builds the input pipeline, graph, hooks and the training loop in one function.
    """
    with tf.get_default_graph().as_default():
        if DATA == 'MNIST':
            input_pipe = MnistData(BATCH_SIZE)
            train_features, train_labels = input_pipe.build_train_data_tensor()
        elif DATA == 'CIFAR10':
            train_features, train_labels = cifar10_input.distorted_inputs(
                BATCH_SIZE)
        else:
            raise ValueError('DATA value not supported: ' + DATA)

        checkpoint = tf.train.get_checkpoint_state(MODEL_DIR)
        if checkpoint is None:
            tf.logging.info('No checkpoint found; training from scratch')
            global_step = 0
        else:
            # Assuming model_checkpoint_path looks something like:
            #       /my-favorite-path/cifar10_train/model.ckpt-0,
            # extract global_step from it.
            # needs to be before graph setup
            global_step = float(
                checkpoint.model_checkpoint_path.split('/')[-1].split('-')[-1])
            tf.logging.info('Continuing training from step: ' +
                            str(global_step))

        train_op, mask_update_op, pruning_obj, loss = train_graph(
            train_features, train_labels)

        assign_ops = get_assign_ops()

        checkpoint_hook = basic_session_run_hooks.CheckpointSaverHook(
            MODEL_DIR, save_steps=TEST_FREQ)

        # needs to be after graph setup
        if checkpoint is None:
            if RESTORE:
                var_list = []
                for var_str, shape in zip(RESTORE_VARS, RESTORE_SHAPES):
                    i = var_str.index('/')
                    scope = var_str[:i]
                    with tf.variable_scope(scope, reuse=True):
                        var = tf.get_variable(var_str[i + 1:], shape=shape)
                        var_list.append(var)

                saver = tf.train.Saver(var_list)
        else:
            #            cifar_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
            #            cifar_vars = cifar_vars[:-1]
            #            for var in cifar_vars:
            #                tf.logging.info(str(var))
            #            saver2 = tf.train.Saver(cifar_vars)
            saver2 = tf.train.Saver()

        class _LoggerHook(tf.train.SessionRunHook):
            """logs loss and runtime."""
            def begin(self):
                self._step = -0.5 + global_step

            def before_run(self, run_context):
                # this is a hack so that it correctly counts
                # (trainop, mask_update_op) as one step
                self._step += 0.5
                self._start_time = time.time()
                return tf.train.SessionRunArgs(loss)  # asks for loss value.

            def after_run(self, run_context, run_values):
                duration = time.time() - self._start_time
                loss_value = run_values.results
                if self._step > 0 and self._step % 100 == 0:
                    num_examples_per_step = BATCH_SIZE
                    examples_per_sec = num_examples_per_step / duration
                    sec_per_batch = float(duration)
                    format_str = (
                        '%s: step %d, loss = %.2f (%.1f examples/sec' +
                        '; %.3f sec/batch)')
                    tf.logging.info(
                        format_str %
                        (datetime.datetime.now(), self._step, loss_value,
                         examples_per_sec, sec_per_batch))

        with tf.train.MonitoredTrainingSession(
                checkpoint_dir=MODEL_DIR,
                hooks=[_LoggerHook(), checkpoint_hook],
                save_checkpoint_secs=None,
                config=tf.ConfigProto(log_device_placement=False)) as sess:
            if checkpoint is not None:
                saver2.restore(sess, checkpoint.model_checkpoint_path)
            elif RESTORE:
                ckpt_path = tf.train.latest_checkpoint(RESTORE_DIR)
                if not ckpt_path:
                    raise ValueError(
                        'No checkpoint files found in restore dir: ' +
                        RESTORE_DIR)
                saver.restore(sess, ckpt_path)

            if INIT_DICT is not None:
                sess.run(assign_ops)
                dict_str = ';  '.join(
                    scope + '/' + ', '.join(INIT_DICT[scope].keys())
                    for scope in INIT_DICT)
                tf.logging.info('Instantiated tensors with assign ops: ' +
                                dict_str)

            if PRUNING_CKPTS:
                pruning_dir = MODEL_DIR + '/../' + MODEL_NAME + '_prune_ckpts/'
                if not os.path.isdir(pruning_dir):
                    os.mkdir(pruning_dir)

                j = 0
                max_j = len(PRUNING_CKPTS)

            for i in range(TRAIN_STEPS):
                sess.run(train_op)
                sess.run(mask_update_op)

                if (PRUNING_CKPTS and j < max_j and
                        sess.run(pruning_obj._sparsity) > PRUNING_CKPTS[j]):
                    tf.logging.info('Saving pruning ckpt: ' +
                                    str(PRUNING_CKPTS[j]))
                    new_dir = pruning_dir + str(PRUNING_CKPTS[j]) + '/'
                    utils.duplicate_saved(MODEL_DIR, new_dir)
                    j += 1
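
The comment near the top of train_mnist parses the global step out of the checkpoint filename (it deliberately keeps it as a float for the half-step counting hack in _LoggerHook). A standalone sketch of that parsing, with placeholder paths:

import os

def step_from_checkpoint_path(model_checkpoint_path):
    # "/some/dir/model.ckpt-1234" -> 1234
    return int(os.path.basename(model_checkpoint_path).split('-')[-1])

assert step_from_checkpoint_path('/my-favorite-path/cifar10_train/model.ckpt-0') == 0
assert step_from_checkpoint_path('/tmp/model_dir/model.ckpt-1234') == 1234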