Пример #1
0
    def testRecoverSession(self):
        """Round-trips a checkpoint of `v` through SessionManager.recover_session.

        Saves `v` (value 1) into a freshly emptied directory, then checks
        recovery via `checkpoint_dir`, via `checkpoint_filename_with_path`,
        and that passing both arguments raises ValueError.
        """
        # Create a checkpoint in a freshly emptied directory.
        checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
        try:
            gfile.DeleteRecursively(checkpoint_dir)
        except errors.OpError:
            pass  # Ignore, e.g. when the directory does not exist yet.
        gfile.MakeDirs(checkpoint_dir)

        with ops.Graph().as_default():
            v = variables.VariableV1(1, name="v")
            sm = session_manager.SessionManager(
                ready_op=variables.report_uninitialized_variables())
            saver = saver_lib.Saver({"v": v})
            # The directory is empty, so nothing can be restored yet.
            sess, initialized = sm.recover_session(
                "", saver=saver, checkpoint_dir=checkpoint_dir)
            self.assertFalse(initialized)
            sess.run(v.initializer)
            # assertEqual, not the deprecated assertEquals alias (removed in
            # Python 3.12).
            self.assertEqual(1, sess.run(v))
            saver.save(
                sess, os.path.join(checkpoint_dir,
                                   "recover_session_checkpoint"))
        self._test_recovered_variable(checkpoint_dir=checkpoint_dir)
        self._test_recovered_variable(
            checkpoint_filename_with_path=checkpoint_management.
            latest_checkpoint(checkpoint_dir))
        # Cannot set both checkpoint_dir and checkpoint_filename_with_path.
        with self.assertRaises(ValueError):
            self._test_recovered_variable(
                checkpoint_dir=checkpoint_dir,
                checkpoint_filename_with_path=checkpoint_management.
                latest_checkpoint(checkpoint_dir))
Пример #2
0
 def __init__(self,
              master,
              is_chief=True,
              checkpoint_dir=None,
              monitors=None,
              scaffold=None,
              config=None):
     """Finalizes the default graph, sets up session creation, starts monitors.

     Args:
       master: Master target; stored as `self._master` (presumably consumed by
         `_create_session`, which is not visible here).
       is_chief: Stored as `self._is_chief`.
       checkpoint_dir: Stored as `self._checkpoint_dir`.
       monitors: Optional iterable of monitors. Each one's
         `begin(max_steps=None)` is called once. Defaults to an empty list.
       scaffold: Optional `Scaffold`; a default `Scaffold()` is used if None.
       config: Stored as `self._config`.
     """
     self._graph = ops.get_default_graph()
     self._master = master
     self._checkpoint_dir = checkpoint_dir
     self._is_chief = is_chief
     self._config = config
     self._monitors = monitors or []
     self._scaffold = scaffold or Scaffold()
     # Finalize the graph so no further ops can be added. The graph is written
     # out later, by write_graph() at the end of this method.
     self._graph.finalize()
     # Build the SessionManager from the scaffold's local-init and ready ops.
     self._session_manager = sm.SessionManager(
         local_init_op=self._scaffold.local_init_op,
         ready_op=self._scaffold.ready_op,
         graph=ops.get_default_graph())
     # Create the session through _create_session, wrapped in a
     # RecoverableSession (presumably so it can be re-created on failure).
     self._sess = recoverable_session.RecoverableSession(
         self._create_session)
     # Record the starting global step; write_graph() below uses it.
     # NOTE(review): `self._tf_sess` is not assigned in this view; presumably a
     # property exposing the session underlying `self._sess` -- confirm.
     self._init_step = self._tf_sess.run(self._scaffold.global_step_tensor)
     # Call the begin() method of monitors.
     for monitor in self._monitors:
         monitor.begin(max_steps=None)
     # Write the graph out, note: this uses self._init_step.
     self.write_graph()
Пример #3
0
 def testRecoverSessionNoChkptStillRunsLocalInitOp(self):
     """recover_session runs local_init_op even when nothing is recovered.

     This checks backwards compatibility: recover_session is expected to
     execute local_init_op exactly once, regardless of whether the session
     was successfully recovered.
     """
     with ops.Graph().as_default():
         w = variables.VariableV1(
             1,
             trainable=False,
             collections=[ops.GraphKeys.LOCAL_VARIABLES],
             name="w")
         with self.cached_session():
             self.assertEqual(False,
                              variables.is_variable_initialized(w).eval())
         sm2 = session_manager.SessionManager(
             ready_op=variables.report_uninitialized_variables(),
             ready_for_local_init_op=None,
             local_init_op=w.initializer)
         # Try to recover with neither a saver nor a checkpoint directory.
         sess, initialized = sm2.recover_session("",
                                                 saver=None,
                                                 checkpoint_dir=None)
         # Nothing was restored, but local_init_op still ran and set w.
         self.assertFalse(initialized)
         self.assertEqual(
             True,
             variables.is_variable_initialized(
                 sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
         # assertEqual replaces the deprecated assertEquals alias (removed in
         # Python 3.12).
         self.assertEqual(1, sess.run(w))
Пример #4
0
    def testWaitForSessionLocalInit(self):
        """wait_for_session runs local_init_op once global variables are ready.

        Only the global variable `v` is initialized up front; the local
        variable `w` must be initialized by the SessionManager's
        local_init_op.
        """
        server = server_lib.Server.create_local_server()
        with ops.Graph().as_default() as graph:
            v = variables.VariableV1(1, name="v")
            w = variables.VariableV1(
                v,
                trainable=False,
                collections=[ops.GraphKeys.LOCAL_VARIABLES],
                name="w")
            ready_op = variables.report_uninitialized_variables()
            globals_ready_op = variables.report_uninitialized_variables(
                variables.global_variables())
            sm = session_manager.SessionManager(
                graph=graph,
                ready_op=ready_op,
                ready_for_local_init_op=globals_ready_op,
                local_init_op=w.initializer)

            # Initialize v but leave w for local_init_op.
            init_sess = session_lib.Session(server.target, graph=graph)
            init_sess.run(v.initializer)

            sess = sm.wait_for_session(server.target, max_wait_secs=3)
            for tensor_name in ("v:0", "w:0"):
                tensor = sess.graph.get_tensor_by_name(tensor_name)
                self.assertEqual(
                    True,
                    variables.is_variable_initialized(tensor).eval(
                        session=sess))
            for var in (v, w):
                self.assertEqual(1, sess.run(var))
Пример #5
0
 def testUseSessionManager(self):
   """A Supervisor accepts an externally constructed SessionManager."""
   with ops.Graph().as_default():
     variables.VariableV1([1.0, 2.0, 3.0])
     manager = session_manager_lib.SessionManager()
     # The supervisor receives `manager` directly, so the additional init_op
     # it would otherwise build is ignored.
     sv = supervisor.Supervisor(logdir="", session_manager=manager)
     sv.prepare_or_wait_for_session("")
 def testPrepareSessionWithReadyForLocalInitOp(self):
   """prepare_session runs local_init_op once the global variables are ready.

   init_op initializes only the global variable `v`; the local variable `w`
   is expected to be initialized by local_init_op afterwards.
   """
   with ops.Graph().as_default():
     v = variables.Variable(1, name="v")
     w = variables.Variable(
         v,
         trainable=False,
         collections=[ops.GraphKeys.LOCAL_VARIABLES],
         name="w")
     # cached_session() replaces the deprecated test_session().
     with self.cached_session():
       self.assertEqual(False, variables.is_variable_initialized(v).eval())
       self.assertEqual(False, variables.is_variable_initialized(w).eval())
     sm2 = session_manager.SessionManager(
         ready_op=variables.report_uninitialized_variables(),
         ready_for_local_init_op=variables.report_uninitialized_variables(
             variables.global_variables()),
         local_init_op=w.initializer)
     sess = sm2.prepare_session("", init_op=v.initializer)
     self.assertEqual(
         True,
         variables.is_variable_initialized(
             sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
     self.assertEqual(
         True,
         variables.is_variable_initialized(
             sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
     # assertEqual replaces the deprecated assertEquals alias (removed in
     # Python 3.12).
     self.assertEqual(1, sess.run(v))
     self.assertEqual(1, sess.run(w))
Пример #7
0
 def testPrepareSessionSucceedsWithInitFn(self):
     """prepare_session invokes a user-supplied init_fn with the session."""
     with ops.Graph().as_default():
         v = variables.Variable([125], name="v")
         sm = session_manager.SessionManager(
             ready_op=variables.assert_variables_initialized())

         def init_fn(session):
             session.run(v.initializer)

         sess = sm.prepare_session("", init_fn=init_fn)
         self.assertAllClose([125], sess.run(v))
Пример #8
0
 def _init_session_manager(self, session_manager=None):
   if session_manager is None:
     self._session_manager = session_manager_mod.SessionManager(
         local_init_op=self._local_init_op,
         ready_op=self._ready_op, graph=self._graph,
         recovery_wait_secs=self._recovery_wait_secs)
   else:
     self._session_manager = session_manager
Пример #9
0
 def testPrepareSessionSucceeds(self):
     """prepare_session runs init_op and returns a session with `v` ready."""
     expected = [1.0, 2.0, 3.0]
     with ops.Graph().as_default():
         v = variables.Variable(expected, name="v")
         ready = variables.assert_variables_initialized()
         sm = session_manager.SessionManager(ready_op=ready)
         init_op = variables.global_variables_initializer()
         sess = sm.prepare_session("", init_op=init_op)
         self.assertAllClose(expected, sess.run(v))
Пример #10
0
    def _get_session_manager(self):
        if self._session_manager:
            return self._session_manager

        self._session_manager = sm.SessionManager(
            local_init_op=self._scaffold.local_init_op,
            ready_op=self._scaffold.ready_op,
            graph=ops.get_default_graph())
        return self._session_manager
Пример #11
0
 def testInitWithNoneLocalInitOpError(self):
     """A ready_for_local_init_op without a local_init_op raises ValueError."""
     # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
     # (removed in Python 3.12).
     with self.assertRaisesRegex(
             ValueError, "If you pass a ready_for_local_init_op "
             "you must also pass a local_init_op "):
         session_manager.SessionManager(ready_for_local_init_op=variables.
                                        report_uninitialized_variables(
                                            variables.global_variables()),
                                        local_init_op=None)
Пример #12
0
    def testWaitForSessionReturnsNoneAfterTimeout(self):
        """wait_for_session raises DeadlineExceededError if never ready.

        Despite the test name, the asserted behavior is an exception rather
        than a None return: `v` is never initialized, so the ready_op never
        passes within max_wait_secs.
        """
        with ops.Graph().as_default():
            variables.Variable(1, name="v")
            sm = session_manager.SessionManager(
                ready_op=variables.assert_variables_initialized(),
                recovery_wait_secs=1)

            # A 1s recovery wait and a 3s budget allow a few readiness checks.
            self.assertRaises(errors.DeadlineExceededError,
                              sm.wait_for_session,
                              master="",
                              max_wait_secs=3)
Пример #13
0
 def testPrepareSessionSucceedsWithLocalInitFeedDict(self):
     """local_init_feed_dict is fed to the placeholder used by local_init_op."""
     feed_values = [1.0, 2.0, 3.0]
     with ops.Graph().as_default():
         p = array_ops.placeholder(dtypes.float32, shape=(3, ))
         v = variables.VariableV1(
             p, name="v", collections=[ops.GraphKeys.LOCAL_VARIABLES])
         sm = session_manager.SessionManager(
             local_init_op=v.initializer,
             local_init_feed_dict={p: feed_values},
             ready_op=variables.report_uninitialized_variables())
         sess = sm.prepare_session("")
         self.assertAllClose(feed_values, sess.run(v))
Пример #14
0
 def testPrepareSessionSucceedsWithInitFeedDict(self):
     """init_feed_dict is fed to the placeholder used by init_op."""
     feed_values = [1.0, 2.0, 3.0]
     with ops.Graph().as_default():
         p = array_ops.placeholder(dtypes.float32, shape=(3, ))
         v = variables.Variable(p, name="v")
         sm = session_manager.SessionManager(
             ready_op=variables.assert_variables_initialized())
         init_op = variables.global_variables_initializer()
         sess = sm.prepare_session(
             "", init_op=init_op, init_feed_dict={p: feed_values})
         self.assertAllClose(feed_values, sess.run(v))
Пример #15
0
    def testRecoverSession(self):
        """Saves `v` and recovers it into a fresh graph via recover_session.

        The recovering graph declares `v` with initial value 2, so reading 1
        afterwards shows the value came from the checkpoint.
        """
        # Create a checkpoint in a freshly emptied directory.
        checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
        try:
            gfile.DeleteRecursively(checkpoint_dir)
        except errors.OpError:
            pass  # Ignore, e.g. when the directory does not exist yet.
        gfile.MakeDirs(checkpoint_dir)

        with ops.Graph().as_default():
            v = variables.Variable(1, name="v")
            sm = session_manager.SessionManager(
                ready_op=variables.assert_variables_initialized())
            saver = saver_lib.Saver({"v": v})
            # The directory is empty, so nothing can be restored yet.
            sess, initialized = sm.recover_session(
                "", saver=saver, checkpoint_dir=checkpoint_dir)
            self.assertFalse(initialized)
            sess.run(v.initializer)
            # assertEqual replaces the deprecated assertEquals alias (removed
            # in Python 3.12).
            self.assertEqual(1, sess.run(v))
            saver.save(
                sess, os.path.join(checkpoint_dir,
                                   "recover_session_checkpoint"))
        # Create a new Graph and SessionManager and recover.
        with ops.Graph().as_default():
            v = variables.Variable(2, name="v")
            with self.cached_session():
                self.assertEqual(False,
                                 variables.is_variable_initialized(v).eval())
            sm2 = session_manager.SessionManager(
                ready_op=variables.assert_variables_initialized())
            saver = saver_lib.Saver({"v": v})
            sess, initialized = sm2.recover_session(
                "", saver=saver, checkpoint_dir=checkpoint_dir)
            self.assertTrue(initialized)
            self.assertEqual(
                True,
                variables.is_variable_initialized(
                    sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
            self.assertEqual(1, sess.run(v))
Пример #16
0
 def testWaitForSessionInsufficientReadyForLocalInitCheck(self):
     """wait_for_session times out when nothing initializes `v`.

     local_init_op (w.initializer) reads the global variable `v`, which is
     never initialized here, so the session is expected never to become
     ready.
     """
     with ops.Graph().as_default() as graph:
         v = variables.Variable(1, name="v")
         w = variables.Variable(v,
                                trainable=False,
                                collections=[ops.GraphKeys.LOCAL_VARIABLES],
                                name="w")
         sm = session_manager.SessionManager(
             graph=graph,
             ready_op=variables.report_uninitialized_variables(),
             ready_for_local_init_op=None,
             local_init_op=w.initializer)
     # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
     # (removed in Python 3.12).
     with self.assertRaisesRegex(errors_impl.DeadlineExceededError,
                                 "Session was not ready after waiting.*"):
         sm.wait_for_session("", max_wait_secs=3)
Пример #17
0
 def testPrepareSessionDidNotInitLocalVariableList(self):
   """prepare_session fails when init_op leaves the local variable `w` unset.

   Only `v.initializer` is passed as init_op and no local_init_op is given,
   so the ready_op still reports `w` as uninitialized.
   """
   with ops.Graph().as_default():
     v = variables.VariableV1(1, name="v")
     w = variables.VariableV1(
         v,
         trainable=False,
         collections=[ops.GraphKeys.LOCAL_VARIABLES],
         name="w")
     with self.cached_session():
       for var in (v, w):
         self.assertEqual(False, variables.is_variable_initialized(var).eval())
     sm2 = session_manager.SessionManager(
         ready_op=variables.report_uninitialized_variables())
     self.assertRaisesRegex(RuntimeError,
                            "Init operations did not make model ready",
                            sm2.prepare_session,
                            "",
                            init_op=[v.initializer])
Пример #18
0
 def testPrepareSessionWithCyclicInitializer(self):
   """Regression test: a cyclic initial_value must not recurse forever.

   Variable._build_initializer_expr previously entered an infinite recursion
   when a variable's initial_value involved cyclic dependencies (here, a
   while loop).
   """
   with ops.Graph().as_default():
     loop_out = control_flow_ops.while_loop(
         lambda i: i < 1, lambda i: i + 1, [0])
     v = variables.VariableV1(array_ops.identity(loop_out), name="v")
     with self.cached_session():
       self.assertEqual(False, variables.is_variable_initialized(v).eval())
     sm = session_manager.SessionManager(
         ready_op=variables.report_uninitialized_variables())
     sess = sm.prepare_session("", init_op=v.initializer)
     # The loop counts from 0 up to 1, so v should hold 1.
     self.assertEqual(1, sess.run(v))
     v_tensor = sess.graph.get_tensor_by_name("v:0")
     self.assertEqual(
         True, variables.is_variable_initialized(v_tensor).eval(session=sess))
 def testPrepareSessionWithInsufficientReadyForLocalInitCheck(self):
   """prepare_session fails when local_init_op reads the uninitialized `v`.

   No init_op and no ready_for_local_init_op are given, so local_init_op
   (w.initializer) runs while `v` is still uninitialized.
   """
   with ops.Graph().as_default():
     v = variables.Variable(1, name="v")
     w = variables.Variable(
         v,
         trainable=False,
         collections=[ops.GraphKeys.LOCAL_VARIABLES],
         name="w")
     # cached_session() replaces the deprecated test_session().
     with self.cached_session():
       self.assertEqual(False, variables.is_variable_initialized(v).eval())
       self.assertEqual(False, variables.is_variable_initialized(w).eval())
     sm2 = session_manager.SessionManager(
         ready_op=variables.report_uninitialized_variables(),
         ready_for_local_init_op=None,
         local_init_op=w.initializer)
   # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias
   # (removed in Python 3.12).
   with self.assertRaisesRegex(errors_impl.FailedPreconditionError,
                               "Attempting to use uninitialized value v"):
     sm2.prepare_session("", init_op=None)
Пример #20
0
    def testWaitForSessionWithReadyForLocalInitOpFailsToReadyLocal(self):
        """wait_for_session times out with an overly strict readiness check.

        ready_for_local_init_op is the no-argument
        report_uninitialized_variables(), which (presumably) also reports the
        local variable `w`. So local_init_op never gets to initialize `w`.
        """
        with ops.Graph().as_default() as graph:
            v = variables.Variable(1, name="v")
            w = variables.Variable(v,
                                   trainable=False,
                                   collections=[ops.GraphKeys.LOCAL_VARIABLES],
                                   name="w")
            ready_op = variables.report_uninitialized_variables()
            too_strict_ready_for_local_init = (
                variables.report_uninitialized_variables())
            sm = session_manager.SessionManager(
                graph=graph,
                ready_op=ready_op,
                ready_for_local_init_op=too_strict_ready_for_local_init,
                local_init_op=w.initializer)

            # Time-out because w fails to be initialized, because of the
            # overly restrictive ready_for_local_init_op.
            self.assertRaises(errors_impl.DeadlineExceededError,
                              sm.wait_for_session,
                              "",
                              max_wait_secs=3)
Пример #21
0
 def testPrepareSessionWithInsufficientReadyForLocalInitCheck(self):
   """prepare_session reports the model as not ready when `v` stays unset.

   No init_op and no ready_for_local_init_op are supplied, so nothing ever
   initializes the global variable `v`.
   """
   with ops.Graph().as_default():
     v = variables.VariableV1(1, name="v")
     w = variables.VariableV1(
         v,
         trainable=False,
         collections=[ops.GraphKeys.LOCAL_VARIABLES],
         name="w")
     with self.cached_session():
       for var in (v, w):
         self.assertEqual(False, variables.is_variable_initialized(var).eval())
     sm2 = session_manager.SessionManager(
         ready_op=variables.report_uninitialized_variables(),
         ready_for_local_init_op=None,
         local_init_op=w.initializer)
   self.assertRaisesRegex(RuntimeError,
                          "Init operations did not make model ready.*",
                          sm2.prepare_session,
                          "",
                          init_op=None)
Пример #22
0
    def testRecoverSessionFailsStillRunsLocalInitOp(self):
        """local_init_op still runs when recover_session restores nothing.

        The checkpoint directory is left empty, so `v` cannot be restored and
        stays uninitialized. The local variable `w` is nevertheless
        initialized by local_init_op.
        """
        checkpoint_dir = os.path.join(
            self.get_temp_dir(),
            "recover_session_ready_for_local_init_fails_stil_run")
        # Start from an empty directory; failure to delete is ignored.
        try:
            gfile.DeleteRecursively(checkpoint_dir)
        except errors.OpError:
            pass
        gfile.MakeDirs(checkpoint_dir)

        with ops.Graph().as_default():
            v = variables.VariableV1(2, name="v")
            w = variables.VariableV1(
                1,
                trainable=False,
                collections=[ops.GraphKeys.LOCAL_VARIABLES],
                name="w")
            with self.cached_session():
                for var in (v, w):
                    self.assertEqual(
                        False, variables.is_variable_initialized(var).eval())
            sm2 = session_manager.SessionManager(
                ready_op=variables.report_uninitialized_variables(),
                ready_for_local_init_op=None,
                local_init_op=w.initializer)
            saver = saver_lib.Saver({"v": v})
            sess, initialized = sm2.recover_session(
                "",
                saver=saver,
                checkpoint_dir=checkpoint_dir,
                wait_for_checkpoint=False)
            self.assertFalse(initialized)

            def is_initialized(tensor_name):
                tensor = sess.graph.get_tensor_by_name(tensor_name)
                return variables.is_variable_initialized(tensor).eval(
                    session=sess)

            self.assertEqual(False, is_initialized("v:0"))
            self.assertEqual(True, is_initialized("w:0"))
            self.assertEqual(1, sess.run(w))
Пример #23
0
  def __init__(self,
               master='',
               is_chief=True,
               checkpoint_dir=None,
               hooks=None,
               scaffold=None,
               config=None):
    """Creates a MonitoredSession.

    Args:
      master: `String` representation of the TensorFlow master to use.
      is_chief: If True, it will take care of initialization and recovery of
        the underlying TensorFlow session. If False, it will wait on a chief to
        initialize or recover the TensorFlow session.
      checkpoint_dir: A string. Optional path to a directory from which to
        restore variables.
      hooks: An iterable of `SessionRunHook` objects. Each hook's `begin()` is
        called once, before the scaffold is finalized.
      scaffold: A `Scaffold` used for gathering or building supportive ops. If
        not specified a default one is created. It's used to finalize the graph.
      config: `ConfigProto` proto used to configure the session.
    """
    self._graph = ops.get_default_graph()
    self._master = master
    self._checkpoint_dir = checkpoint_dir
    self._is_chief = is_chief
    self._config = config
    self._hooks = hooks or []
    self._scaffold = scaffold or Scaffold()
    # NOTE(review): presumably set to a Coordinator later (e.g. during session
    # creation); not visible here.
    self._coord = None
    # Hooks' begin() runs before the scaffold is finalized (presumably so that
    # hooks can still add ops to the graph).
    for h in self._hooks:
      h.begin()
    # Finalize the scaffold, then build the SessionManager from its ops.
    self._scaffold.finalize()
    self._session_manager = sm.SessionManager(
        local_init_op=self._scaffold.local_init_op,
        ready_op=self._scaffold.ready_op,
        graph=ops.get_default_graph())
    # Create the session; creation is delegated to _create_session (defined
    # outside this view).
    self._sess = _RecoverableSession(self._create_session)
    # NOTE(review): write_graph() is defined outside this view.
    self.write_graph()
Пример #24
0
 def _test_recovered_variable(self,
                              checkpoint_dir=None,
                              checkpoint_filename_with_path=None):
   """Recovers `v` into a fresh graph and asserts it holds the value 1.

   The new graph declares `v` with initial value 2, so reading 1 shows the
   value came from the checkpoint.

   Args:
     checkpoint_dir: Directory passed through to recover_session.
     checkpoint_filename_with_path: Checkpoint file passed through to
       recover_session.
   """
   with ops.Graph().as_default():
     v = variables.VariableV1(2, name="v")
     with session_lib.Session():
       self.assertEqual(False, variables.is_variable_initialized(v).eval())
     recovery_manager = session_manager.SessionManager(
         ready_op=variables.report_uninitialized_variables())
     saver = saver_lib.Saver({"v": v})
     sess, initialized = recovery_manager.recover_session(
         "",
         saver=saver,
         checkpoint_dir=checkpoint_dir,
         checkpoint_filename_with_path=checkpoint_filename_with_path)
     self.assertTrue(initialized)
     v_tensor = sess.graph.get_tensor_by_name("v:0")
     self.assertEqual(
         True, variables.is_variable_initialized(v_tensor).eval(session=sess))
     self.assertEqual(1, sess.run(v))
Пример #25
0
 def __init__(self,
              master,
              is_chief=True,
              checkpoint_dir=None,
              hooks=None,
              scaffold=None,
              config=None):
     """Stores session settings, runs hook setup and creates the session.

     Args:
       master: Master target; stored as `self._master`.
       is_chief: Stored as `self._is_chief`.
       checkpoint_dir: Stored as `self._checkpoint_dir`.
       hooks: Optional iterable of hooks. Each hook's `begin()` is called once,
         before the scaffold is finalized. Defaults to an empty list.
       scaffold: Optional `Scaffold`; a default `Scaffold()` is used if None.
       config: Stored as `self._config`.
     """
     self._graph = ops.get_default_graph()
     self._master = master
     self._checkpoint_dir = checkpoint_dir
     self._is_chief = is_chief
     self._config = config
     self._hooks = hooks or []
     self._scaffold = scaffold or Scaffold()
     # Hooks' begin() runs before the scaffold is finalized (presumably so that
     # hooks can still add ops to the graph).
     for h in self._hooks:
         h.begin()
     # Finalize the scaffold, then build the SessionManager from its ops.
     self._scaffold.finalize()
     self._session_manager = sm.SessionManager(
         local_init_op=self._scaffold.local_init_op,
         ready_op=self._scaffold.ready_op,
         graph=ops.get_default_graph())
     # Create the session; creation is delegated to _create_session (defined
     # outside this view).
     self._sess = _RecoverableSession(self._create_session)
     # NOTE(review): write_graph() is defined outside this view.
     self.write_graph()
Пример #26
0
def evaluate(graph,
             output_dir,
             checkpoint_path,
             eval_dict,
             update_op=None,
             global_step_tensor=None,
             supervisor_master='',
             log_every_steps=10,
             feed_fn=None,
             max_steps=None):
    """Evaluate a model loaded from a checkpoint.

    Given `graph`, a directory to write summaries to (`output_dir`), a
    checkpoint to restore variables from, and a `dict` of `Tensor`s to
    evaluate, run an eval loop for `max_steps` steps.

    In each step of evaluation, all tensors in the `eval_dict` are evaluated,
    and every `log_every_steps` steps, they are logged. At the very end of
    evaluation, a summary is evaluated (finding the summary ops using
    `Supervisor`'s logic) and written to `output_dir`.

    Args:
      graph: A `Graph` to evaluate. It is expected that this graph is not in
        use elsewhere.
      output_dir: A string containing the directory to write a summary to.
      checkpoint_path: A string containing the path to a checkpoint to
        restore. Can be `None` if the graph doesn't require loading any
        variables.
      eval_dict: A `dict` mapping string names to tensors to evaluate. It is
        evaluated in every logging step. The result of the final evaluation
        is returned. If update_op is None, then it's evaluated in every step.
      update_op: A `Tensor` which is run in every step.
      global_step_tensor: A `Variable` containing the global step. If `None`,
        one is extracted from the graph using the same logic as in
        `Supervisor`. Used to place eval summaries on training curves.
      supervisor_master: The master string to use when preparing the session.
      log_every_steps: Integer. Output logs every `log_every_steps` evaluation
        steps. The logs contain the `eval_dict` and timing information.
      feed_fn: A function that is called every iteration to produce a
        `feed_dict` passed to `session.run` calls. Optional.
      max_steps: Integer. Evaluate `eval_dict` this many times. If `None`, the
        loop runs until an `OutOfRangeError` or `StopIteration` signals that
        the input is exhausted.

    Returns:
      A tuple `(eval_results, global_step)`:
      eval_results: A `dict` mapping `string` to numeric values (`int`,
        `float`) that are the result of running eval_dict in the last step.
        May be `None` if the input was exhausted before any result could be
        computed.
      global_step: The global step this evaluation corresponds to.
    """
    global_step_tensor = contrib_variables.assert_or_get_global_step(
        graph, global_step_tensor)

    # Add a scalar summary for every Tensor in eval_dict whose key does not
    # already have one; non-Tensor values are skipped.
    existing_tags = [
        tensor_util.constant_value(summary.op.inputs[0])
        for summary in ops.get_collection(ops.GraphKeys.SUMMARIES)
    ]
    for key, value in eval_dict.items():
        if key in existing_tags:
            continue
        if isinstance(value, ops.Tensor):
            summaries.summarize_tensor(value, tag=key)

    # Create or get the summary op, the saver, and the init/ready ops.
    summary_op = logging_ops.get_summary_op()
    saver = _get_saver()
    local_init_op = _get_local_init_op()
    ready_op = _get_ready_op()

    session_manager = session_manager_lib.SessionManager(
        local_init_op=local_init_op, ready_op=ready_op)
    session, initialized = session_manager.recover_session(
        master=supervisor_master, saver=saver, checkpoint_dir=checkpoint_path)

    # Start queue runners.
    coord = coordinator.Coordinator()
    threads = _start_queue_runners(session, coord)

    with session:
        if not initialized:
            logging.warning('Failed to initialize from %s.', checkpoint_path)
            # TODO(ipolosukhin): This should be failing, but old code relies on that.
            session.run(variables.initialize_all_variables())
            if checkpoint_path:
                _restore_from_checkpoint(session, graph, checkpoint_path,
                                         saver)

        current_global_step = session.run(global_step_tensor)
        eval_results = None
        # Bind feed_dict before the loop: the `finally` clause below reads it,
        # which would raise UnboundLocalError if no step ran (max_steps == 0)
        # or if the first feed_fn() call raised.
        feed_dict = None
        # TODO(amodei): Fix this to run through the eval set exactly once.
        step = 0
        logging.info('Eval steps [%d,%s) for training step %d.', step,
                     'inf' if max_steps is None else str(max_steps),
                     current_global_step)
        try:
            try:
                while (max_steps is None) or (step < max_steps):
                    start_time = time.time()
                    feed_dict = feed_fn() if feed_fn is not None else None
                    eval_results = None
                    if update_op is not None:
                        session.run(update_op, feed_dict=feed_dict)
                    else:
                        eval_results = _run_dict(session,
                                                 eval_dict,
                                                 feed_dict=feed_dict)

                    # TODO(wicke): We should assert that the global step hasn't changed.
                    step += 1
                    if step % log_every_steps == 0:
                        if eval_results is None:
                            eval_results = _run_dict(session,
                                                     eval_dict,
                                                     feed_dict=feed_dict)
                        duration = time.time() - start_time
                        logging.info(
                            'Results after %d steps (%.3f sec/batch): %s.',
                            step, float(duration),
                            ', '.join('%s = %s' % (k, v)
                                      for k, v in eval_results.items()))
            finally:
                if eval_results is None:
                    eval_results = _run_dict(session,
                                             eval_dict,
                                             feed_dict=feed_dict)
                # Stop queue runners.
                coord.request_stop()
                coord.join(threads, stop_grace_period_secs=120)

                # Make our own summary writer and write a summary to the eval
                # dir, only if feed_fn is not provided.
                # TODO(ipolosukhin): Convert evaluation to use streaming_metrics,
                # then we can save for non feed_fn as well.
                if summary_op is not None and feed_fn is None:
                    summary_writer = None
                    try:
                        summary_writer = get_summary_writer(output_dir)
                        summary_str = session.run(summary_op)
                        if summary_str:
                            summary_writer.add_summary(summary_str,
                                                       current_global_step)
                    finally:
                        if summary_writer:
                            summary_writer.close()
        # Catch OutOfRangeError, which is thrown when a queue is out of data
        # (and for other reasons as well).
        except errors.OutOfRangeError as e:
            if max_steps is None:
                logging.info('Input queue is exhausted.')
            else:
                logging.warning('Input queue is exhausted: %s.', e)
        # Catch StopIteration, which is thrown if a DataReader is out of data.
        except StopIteration as e:
            if max_steps is None:
                logging.info('Input iterator is exhausted.')
            else:
                logging.warning('Input iterator is exhausted: %s.', e)

    return eval_results, current_global_step
Пример #27
0
def evaluate(graph,
             output_dir,
             checkpoint_path,
             eval_dict,
             update_op=None,
             global_step_tensor=None,
             supervisor_master='',
             log_every_steps=10,
             feed_fn=None,
             max_steps=None):
  """Evaluate a model loaded from a checkpoint.

  Given `graph`, a directory to write summaries to (`output_dir`), a checkpoint
  to restore variables from, and a `dict` of `Tensor`s to evaluate, run an eval
  loop for `max_steps` steps, or until an exception (generally, an
  end-of-input signal from a reader operation) is raised from running
  `eval_dict`.

  In each step of evaluation, all tensors in the `eval_dict` are evaluated, and
  every `log_every_steps` steps, they are logged. At the very end of evaluation,
  a summary is evaluated (finding the summary ops using `Supervisor`'s logic)
  and written to `output_dir`.

  The session and the queue-runner coordinator are always shut down when the
  loop ends, even if the final evaluation of `eval_dict` raises.

  Args:
    graph: A `Graph` to train. It is expected that this graph is not in use
      elsewhere.
    output_dir: A string containing the directory to write a summary to.
    checkpoint_path: A string containing the path to a checkpoint to restore.
      Can be `None` if the graph doesn't require loading any variables.
    eval_dict: A `dict` mapping string names to tensors to evaluate. It is
      evaluated in every logging step. The result of the final evaluation is
      returned. If `update_op` is None, then it's evaluated in every step. If
      `max_steps` is `None`, this should depend on a reader that will raise an
      end-of-input exception when the inputs are exhausted.
    update_op: A `Tensor` which is run in every step.
    global_step_tensor: A `Variable` containing the global step. If `None`,
      one is extracted from the graph using the same logic as in `Supervisor`.
      Used to place eval summaries on training curves.
    supervisor_master: The master string to use when preparing the session.
    log_every_steps: Integer. Output logs every `log_every_steps` evaluation
      steps. The logs contain the `eval_dict` and timing information. A falsy
      value (`0` or `None`) disables per-step logging.
    feed_fn: A function that is called every iteration to produce a `feed_dict`
      passed to `session.run` calls. Optional.
    max_steps: Integer. Evaluate `eval_dict` this many times.

  Returns:
    A tuple `(eval_results, global_step)`:
    eval_results: A `dict` mapping `string` to numeric values (`int`, `float`)
      that are the result of running eval_dict in the last step. After the
      loop, `eval_dict` is run once more if the latest results are stale, so
      this is `None` only if no run of `eval_dict` succeeded (e.g. the input
      was exhausted before the first evaluation).
    global_step: The global step this evaluation corresponds to.

  Raises:
    ValueError: if `output_dir` is empty.
  """
  if not output_dir:
    raise ValueError('Output directory should be non-empty %s.' % output_dir)
  with graph.as_default():
    global_step_tensor = contrib_variables.assert_or_get_global_step(
        graph, global_step_tensor)

    # Create or get summary op, global_step and saver.
    saver = _get_saver()
    local_init_op = _get_local_init_op()
    ready_op = _get_ready_op()

    session_manager = session_manager_lib.SessionManager(
        local_init_op=local_init_op,
        ready_op=ready_op)
    session, initialized = session_manager.recover_session(
        master=supervisor_master,
        saver=saver,
        checkpoint_dir=checkpoint_path)

    # Start queue runners.
    coord = coordinator.Coordinator()
    threads = queue_runner.start_queue_runners(session, coord)

  with session:
    if not initialized:
      logging.warning('Failed to initialize from %s.', checkpoint_path)
      # TODO(ipolosukhin): This should be failing, but old code relies on that.
      session.run(variables.global_variables_initializer())
      if checkpoint_path:
        _restore_from_checkpoint(session, graph, checkpoint_path, saver)

    current_global_step = session.run(global_step_tensor)
    eval_results = None
    # TODO(amodei): Fix this to run through the eval set exactly once.
    step = 0
    # Step at which `eval_results` was last computed; lets us avoid running
    # `eval_dict` twice for the same step.
    eval_step = None
    feed_dict = None
    logging.info('Eval steps [%d,%s) for training step %d.', step,
                 'inf' if max_steps is None
                 else str(max_steps), current_global_step)
    try:
      try:
        while (max_steps is None) or (step < max_steps):
          step += 1
          start_time = time.time()
          feed_dict = feed_fn() if feed_fn is not None else None
          if update_op is not None:
            session.run(update_op, feed_dict=feed_dict)
          else:
            eval_results = session.run(eval_dict, feed_dict=feed_dict)
            eval_step = step

          # TODO(wicke): We should assert that the global step hasn't changed.
          # Guard against a zero/None interval, which would otherwise raise
          # ZeroDivisionError in the modulo below.
          if log_every_steps and step % log_every_steps == 0:
            if step != eval_step:
              eval_results = session.run(eval_dict, feed_dict=feed_dict)
              eval_step = step
            duration = time.time() - start_time
            logging.info('Results after %d steps (%.3f sec/batch): %s.',
                         step, float(duration),
                         _eval_results_to_str(eval_results))
      finally:
        try:
          # Make sure the returned results reflect the last step. This run may
          # itself raise (e.g. OutOfRangeError when `eval_dict` reads from an
          # exhausted queue), so cleanup lives in the nested `finally` below.
          if eval_results is None or step != eval_step:
            eval_results = session.run(eval_dict, feed_dict=feed_dict)
        finally:
          # Stop session first, before queue runners.
          session.close()

          # Stop queue runners.
          try:
            coord.request_stop()
            coord.join(threads, stop_grace_period_secs=120)
          except (RuntimeError, errors.CancelledError) as e:
            logging.warning('Coordinator didn\'t stop cleanly: %s', e)

    # catch OutOfRangeError which is thrown when queue is out of data (and for
    # other reasons as well).
    except errors.OutOfRangeError as e:
      if max_steps is None:
        logging.info('Input queue is exhausted.')
      else:
        logging.warning('Input queue is exhausted: %s.', e)
    # catch StopIteration which is thrown is DataReader is out of data.
    except StopIteration as e:
      if max_steps is None:
        logging.info('Input iterator is exhausted.')
      else:
        logging.warning('Input iterator is exhausted: %s.', e)

  # Save summaries for this evaluation.
  _write_summary_results(output_dir, eval_results, current_global_step)

  return eval_results, current_global_step
Пример #28
0
    def testPrepareSessionWithPartialInitOp(self):
        """Checks that a local_init_op made of selected initializers runs alone.

        The SessionManager's `local_init_op` initializes only `w`, `x`,
        `w_res` and `x_res`. With `init_op=None`, `prepare_session` must leave
        `v` / `v_res` uninitialized while the local variables hold their
        initial values (1 and 3 * 1).
        """
        with ops.Graph().as_default():
            v = variables.Variable(1, name="v")
            w = variables.Variable(v,
                                   trainable=False,
                                   collections=[ops.GraphKeys.LOCAL_VARIABLES],
                                   name="w")
            x = variables.Variable(3 * v,
                                   trainable=False,
                                   collections=[ops.GraphKeys.LOCAL_VARIABLES],
                                   name="x")
            # TODO(b/70206927): Use ResourceVariables once they are handled properly.
            v_res = variables.Variable(1, name="v_res")
            w_res = variables.Variable(
                v_res,
                trainable=False,
                collections=[ops.GraphKeys.LOCAL_VARIABLES],
                name="w_res")
            x_res = variables.Variable(
                3 * v_res,
                trainable=False,
                collections=[ops.GraphKeys.LOCAL_VARIABLES],
                name="x_res")

            # Sanity check: nothing is initialized before prepare_session.
            with self.cached_session():
                self.assertEqual(False,
                                 variables.is_variable_initialized(v).eval())
                self.assertEqual(False,
                                 variables.is_variable_initialized(w).eval())
                self.assertEqual(False,
                                 variables.is_variable_initialized(x).eval())
                self.assertEqual(
                    False,
                    variables.is_variable_initialized(v_res).eval())
                self.assertEqual(
                    False,
                    variables.is_variable_initialized(w_res).eval())
                self.assertEqual(
                    False,
                    variables.is_variable_initialized(x_res).eval())
            sm2 = session_manager.SessionManager(local_init_op=[
                w.initializer, x.initializer, w_res.initializer,
                x_res.initializer
            ])
            sess = sm2.prepare_session("", init_op=None)
            self.assertEqual(
                False,
                variables.is_variable_initialized(
                    sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
            self.assertEqual(
                True,
                variables.is_variable_initialized(
                    sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
            self.assertEqual(
                True,
                variables.is_variable_initialized(
                    sess.graph.get_tensor_by_name("x:0")).eval(session=sess))
            # assertEquals is a deprecated alias (removed in Python 3.12).
            self.assertEqual(1, sess.run(w))
            self.assertEqual(3, sess.run(x))
            self.assertEqual(
                False,
                variables.is_variable_initialized(
                    sess.graph.get_tensor_by_name("v_res:0")).eval(
                        session=sess))
            self.assertEqual(
                True,
                variables.is_variable_initialized(
                    sess.graph.get_tensor_by_name("w_res:0")).eval(
                        session=sess))
            self.assertEqual(
                True,
                variables.is_variable_initialized(
                    sess.graph.get_tensor_by_name("x_res:0")).eval(
                        session=sess))
            self.assertEqual(1, sess.run(w_res))
            self.assertEqual(3, sess.run(x_res))