def testPrepareSessionFails(self):
    """prepare_session fails without a checkpoint, then recovers with one.

    A checkpoint of `v` is written from a first graph. The checkpoint
    directory is then moved away, so a fresh graph and SessionManager given
    no init_op cannot initialize `v` and prepare_session raises. After the
    directory is moved back, prepare_session must restore `v` from it.
    """
    checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")
    checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2")
    # Clean each directory independently: a failure on the first must not
    # skip the second (a stale second directory would break os.rename).
    for stale_dir in (checkpoint_dir, checkpoint_dir2):
      try:
        gfile.DeleteRecursively(stale_dir)
      except errors.OpError:
        pass  # Best-effort cleanup; e.g. the directory may not exist.
    gfile.MakeDirs(checkpoint_dir)

    with ops.Graph().as_default():
      v = variables.Variable([1.0, 2.0, 3.0], name="v")
      sm = session_manager.SessionManager(
          ready_op=variables.assert_variables_initialized())
      saver = saver_lib.Saver({"v": v})
      sess = sm.prepare_session(
          "",
          init_op=variables.global_variables_initializer(),
          saver=saver,
          checkpoint_dir=checkpoint_dir)
      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
      checkpoint_filename = os.path.join(checkpoint_dir,
                                         "prepare_session_checkpoint")
      saver.save(sess, checkpoint_filename)
    # Create a new Graph and SessionManager and recover.
    with ops.Graph().as_default():
      # Move the checkpoint away, leaving an empty checkpoint directory.
      os.rename(checkpoint_dir, checkpoint_dir2)
      gfile.MakeDirs(checkpoint_dir)
      v = variables.Variable([6.0, 7.0, 8.0], name="v")
      with self.cached_session():
        self.assertFalse(variables.is_variable_initialized(v).eval())
      # Bind the new manager; previously it was discarded and the manager
      # built for the first graph was reused. A 1-second retry interval keeps
      # the checkpoint wait close to max_wait_secs.
      sm = session_manager.SessionManager(
          ready_op=variables.assert_variables_initialized(),
          recovery_wait_secs=1)
      saver = saver_lib.Saver({"v": v})
      # This should fail as there's no checkpoint within 2 seconds.
      with self.assertRaisesRegex(
          RuntimeError, "no init_op or init_fn or local_init_op was given"):
        sm.prepare_session(
            "",
            init_op=None,
            saver=saver,
            checkpoint_dir=checkpoint_dir,
            wait_for_checkpoint=True,
            max_wait_secs=2)
      # Move the checkpoint directory back.
      gfile.DeleteRecursively(checkpoint_dir)
      os.rename(checkpoint_dir2, checkpoint_dir)
      # This should succeed as there's a checkpoint.
      sess = sm.prepare_session(
          "",
          init_op=None,
          saver=saver,
          checkpoint_dir=checkpoint_dir,
          wait_for_checkpoint=True,
          max_wait_secs=2)
      self.assertTrue(
          variables.is_variable_initialized(
              sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
      # The values must come from the checkpoint, not this graph's initializer.
      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
示例#2
0
    def testPrepareSessionFails(self):
        """prepare_session fails without a checkpoint, then recovers with one.

        A checkpoint of `v` is written from a first graph. The checkpoint
        directory is then moved away, so a fresh graph and SessionManager
        given no init_op cannot initialize `v` and prepare_session raises.
        After the directory is moved back, prepare_session must restore `v`.
        """
        checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")
        checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2")
        # Clean each directory independently: a failure on the first must not
        # skip the second (a stale second directory would break os.rename).
        for stale_dir in (checkpoint_dir, checkpoint_dir2):
            try:
                gfile.DeleteRecursively(stale_dir)
            except errors.OpError:
                pass  # Best-effort cleanup; e.g. the directory may not exist.
        gfile.MakeDirs(checkpoint_dir)

        with ops.Graph().as_default():
            v = variables.Variable([1.0, 2.0, 3.0], name="v")
            sm = session_manager.SessionManager(
                ready_op=variables.assert_variables_initialized())
            saver = saver_lib.Saver({"v": v})
            sess = sm.prepare_session(
                "",
                init_op=variables.global_variables_initializer(),
                saver=saver,
                checkpoint_dir=checkpoint_dir)
            self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
            checkpoint_filename = os.path.join(checkpoint_dir,
                                               "prepare_session_checkpoint")
            saver.save(sess, checkpoint_filename)
        # Create a new Graph and SessionManager and recover.
        with ops.Graph().as_default():
            # Move the checkpoint away, leaving an empty checkpoint directory.
            os.rename(checkpoint_dir, checkpoint_dir2)
            gfile.MakeDirs(checkpoint_dir)
            v = variables.Variable([6.0, 7.0, 8.0], name="v")
            with self.cached_session():
                self.assertFalse(variables.is_variable_initialized(v).eval())
            # Bind the new manager; previously it was discarded and the
            # manager built for the first graph was reused. A 1-second retry
            # interval keeps the checkpoint wait close to max_wait_secs.
            sm = session_manager.SessionManager(
                ready_op=variables.assert_variables_initialized(),
                recovery_wait_secs=1)
            saver = saver_lib.Saver({"v": v})
            # This should fail as there's no checkpoint within 2 seconds.
            with self.assertRaisesRegex(
                    RuntimeError,
                    "no init_op or init_fn or local_init_op was given"):
                sm.prepare_session("",
                                   init_op=None,
                                   saver=saver,
                                   checkpoint_dir=checkpoint_dir,
                                   wait_for_checkpoint=True,
                                   max_wait_secs=2)
            # Move the checkpoint directory back.
            gfile.DeleteRecursively(checkpoint_dir)
            os.rename(checkpoint_dir2, checkpoint_dir)
            # This should succeed as there's a checkpoint.
            sess = sm.prepare_session("",
                                      init_op=None,
                                      saver=saver,
                                      checkpoint_dir=checkpoint_dir,
                                      wait_for_checkpoint=True,
                                      max_wait_secs=2)
            self.assertTrue(
                variables.is_variable_initialized(
                    sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
            # The values must come from the checkpoint, not the initializer.
            self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
 def testPrepareSessionSucceeds(self):
   """prepare_session runs the given init_op, so `v` holds its initial value."""
   with ops.Graph().as_default():
     expected = [1.0, 2.0, 3.0]
     var = variables.Variable(expected, name="v")
     manager = session_manager.SessionManager(
         ready_op=variables.assert_variables_initialized())
     init_op = variables.global_variables_initializer()
     session = manager.prepare_session("", init_op=init_op)
     self.assertAllClose(expected, session.run(var))
示例#4
0
 def testPrepareSessionSucceedsWithInitFn(self):
     """prepare_session calls init_fn with the new session to initialize `v`."""
     with ops.Graph().as_default():
         var = variables.Variable([125], name="v")
         manager = session_manager.SessionManager(
             ready_op=variables.assert_variables_initialized())

         def _initialize(session):
             # Runs only the initializer of `var`; no init_op is supplied.
             return session.run(var.initializer)

         sess = manager.prepare_session("", init_fn=_initialize)
         self.assertAllClose([125], sess.run(var))
示例#5
0
 def testPrepareSessionSucceeds(self):
     """prepare_session runs init_op, leaving `v` at its initial value."""
     with ops.Graph().as_default():
         initial = [1.0, 2.0, 3.0]
         var = variables.Variable(initial, name="v")
         ready = variables.assert_variables_initialized()
         manager = session_manager.SessionManager(ready_op=ready)
         session = manager.prepare_session(
             "", init_op=variables.global_variables_initializer())
         self.assertAllClose(initial, session.run(var))
 def testPrepareSessionSucceedsWithInitFn(self):
   """An init_fn without an init_op is enough to initialize `v`."""
   with ops.Graph().as_default():
     var = variables.Variable([125], name="v")
     ready = variables.assert_variables_initialized()
     manager = session_manager.SessionManager(ready_op=ready)

     def init_fn(session):
       return session.run(var.initializer)

     prepared = manager.prepare_session("", init_fn=init_fn)
     self.assertAllClose([125], prepared.run(var))
 def testVariables(self):
   """assert_variables_initialized fails before init and passes after it."""
   with ops.Graph().as_default(), self.cached_session():
     v = variables.VariableV1([1, 2])
     w = variables.VariableV1([3, 4])
     _ = v, w
     inited = variables.assert_variables_initialized()
     with self.assertRaisesOpError("Attempting to use uninitialized value"):
       self.evaluate(inited)
     # Use self.evaluate consistently instead of Operation.run().
     self.evaluate(variables.global_variables_initializer())
     self.evaluate(inited)
  def testWaitForSessionReturnsNoneAfterTimeout(self):
    """wait_for_session raises DeadlineExceededError when never ready.

    Despite the name, this asserts that an error is raised rather than that
    None is returned: `v` is never initialized, so the ready_op never passes.
    """
    with ops.Graph().as_default():
      variables.Variable(1, name="v")
      ready_op = variables.assert_variables_initialized()
      # A 1-second retry interval with a 3-second budget allows a few polls.
      manager = session_manager.SessionManager(
          ready_op=ready_op, recovery_wait_secs=1)
      with self.assertRaises(errors.DeadlineExceededError):
        manager.wait_for_session(master="", max_wait_secs=3)
示例#9
0
 def testVariables(self):
   """The readiness assertion fails before initialization and passes after."""
   with ops.Graph().as_default(), self.cached_session():
     _ = (variables.VariableV1([1, 2]), variables.VariableV1([3, 4]))
     ready_check = variables.assert_variables_initialized()
     with self.assertRaisesOpError("Attempting to use uninitialized value"):
       self.evaluate(ready_check)
     init_all = variables.global_variables_initializer()
     self.evaluate(init_all)
     self.evaluate(ready_check)
示例#10
0
    def testWaitForSessionReturnsNoneAfterTimeout(self):
        """wait_for_session gives up with DeadlineExceededError.

        Despite the name, this asserts that an error is raised rather than
        that None is returned: `v` is never initialized, so the ready_op
        never succeeds within max_wait_secs.
        """
        with ops.Graph().as_default():
            variables.Variable(1, name="v")
            manager = session_manager.SessionManager(
                ready_op=variables.assert_variables_initialized(),
                recovery_wait_secs=1)
            # max_wait_secs exceeds recovery_wait_secs, so several readiness
            # checks happen before the deadline.
            expect_timeout = self.assertRaises(errors.DeadlineExceededError)
            with expect_timeout:
                manager.wait_for_session(master="", max_wait_secs=3)
示例#11
0
 def testPrepareSessionSucceedsWithInitFeedDict(self):
     """init_feed_dict feeds the placeholder that the initializer reads."""
     with ops.Graph().as_default():
         init_value = array_ops.placeholder(dtypes.float32, shape=(3, ))
         var = variables.Variable(init_value, name="v")
         manager = session_manager.SessionManager(
             ready_op=variables.assert_variables_initialized())
         sess = manager.prepare_session(
             "",
             init_op=variables.global_variables_initializer(),
             init_feed_dict={init_value: [1.0, 2.0, 3.0]})
         self.assertAllClose([1.0, 2.0, 3.0], sess.run(var))
 def testPrepareSessionSucceedsWithInitFeedDict(self):
   """The initializer's placeholder is fed through init_feed_dict."""
   with ops.Graph().as_default():
     values = [1.0, 2.0, 3.0]
     p = array_ops.placeholder(dtypes.float32, shape=(3,))
     v = variables.Variable(p, name="v")
     sm = session_manager.SessionManager(
         ready_op=variables.assert_variables_initialized())
     init_op = variables.global_variables_initializer()
     sess = sm.prepare_session("", init_op=init_op, init_feed_dict={p: values})
     self.assertAllClose(values, sess.run(v))
示例#13
0
 def testVariableList(self):
   """Only the variables passed to assert_variables_initialized are checked."""
   with ops.Graph().as_default(), self.cached_session():
     v = variables.VariableV1([1, 2])
     w = variables.VariableV1([3, 4])
     check_v = variables.assert_variables_initialized([v]).op
     with self.assertRaisesOpError("Attempting to use uninitialized value"):
       check_v.run()
     # Initializing the unrelated `w` must not satisfy the check on `v`.
     self.evaluate(w.initializer)
     with self.assertRaisesOpError("Attempting to use uninitialized value"):
       check_v.run()
     v.initializer.run()
     check_v.run()
 def testVariableList(self):
   """A check built from [v] ignores w and passes once v is initialized."""
   with ops.Graph().as_default(), self.cached_session():
     v = variables.VariableV1([1, 2])
     w = variables.VariableV1([3, 4])
     inited_op = variables.assert_variables_initialized([v]).op
     # Neither the fresh graph nor initializing only `w` satisfies the check.
     for prepare in (lambda: None, lambda: self.evaluate(w.initializer)):
       prepare()
       with self.assertRaisesOpError("Attempting to use uninitialized value"):
         inited_op.run()
     v.initializer.run()
     inited_op.run()
示例#15
0
    def testRecoverSession(self):
        """recover_session restores `v` from a checkpoint when one exists.

        With an empty checkpoint directory, recover_session reports the model
        as not initialized. After `v` = 1 is saved, a new graph whose `v`
        starts at 2 recovers the saved value.
        """
        # Create a checkpoint.
        checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
        try:
            gfile.DeleteRecursively(checkpoint_dir)
        except errors.OpError:
            pass  # Best-effort cleanup; e.g. the directory may not exist.
        gfile.MakeDirs(checkpoint_dir)

        with ops.Graph().as_default():
            v = variables.Variable(1, name="v")
            sm = session_manager.SessionManager(
                ready_op=variables.assert_variables_initialized())
            saver = saver_lib.Saver({"v": v})
            sess, initialized = sm.recover_session(
                "", saver=saver, checkpoint_dir=checkpoint_dir)
            self.assertFalse(initialized)
            sess.run(v.initializer)
            # assertEquals is a deprecated alias removed in Python 3.12.
            self.assertEqual(1, sess.run(v))
            saver.save(
                sess, os.path.join(checkpoint_dir,
                                   "recover_session_checkpoint"))
        # Create a new Graph and SessionManager and recover.
        with ops.Graph().as_default():
            v = variables.Variable(2, name="v")
            with self.cached_session():
                self.assertFalse(variables.is_variable_initialized(v).eval())
            sm2 = session_manager.SessionManager(
                ready_op=variables.assert_variables_initialized())
            saver = saver_lib.Saver({"v": v})
            sess, initialized = sm2.recover_session(
                "", saver=saver, checkpoint_dir=checkpoint_dir)
            self.assertTrue(initialized)
            self.assertTrue(
                variables.is_variable_initialized(
                    sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
            # Value comes from the checkpoint (1), not the initial value (2).
            self.assertEqual(1, sess.run(v))
  def testRecoverSession(self):
    """recover_session restores `v` from a checkpoint when one exists.

    With an empty checkpoint directory, recover_session reports the model as
    not initialized. After `v` = 1 is saved, a new graph whose `v` starts at
    2 recovers the saved value.
    """
    # Create a checkpoint.
    checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
    try:
      gfile.DeleteRecursively(checkpoint_dir)
    except errors.OpError:
      pass  # Best-effort cleanup; e.g. the directory may not exist.
    gfile.MakeDirs(checkpoint_dir)

    with ops.Graph().as_default():
      v = variables.Variable(1, name="v")
      sm = session_manager.SessionManager(
          ready_op=variables.assert_variables_initialized())
      saver = saver_lib.Saver({"v": v})
      sess, initialized = sm.recover_session(
          "", saver=saver, checkpoint_dir=checkpoint_dir)
      self.assertFalse(initialized)
      sess.run(v.initializer)
      # assertEquals is a deprecated alias removed in Python 3.12.
      self.assertEqual(1, sess.run(v))
      saver.save(sess,
                 os.path.join(checkpoint_dir, "recover_session_checkpoint"))
    # Create a new Graph and SessionManager and recover.
    with ops.Graph().as_default():
      v = variables.Variable(2, name="v")
      with self.cached_session():
        self.assertFalse(variables.is_variable_initialized(v).eval())
      sm2 = session_manager.SessionManager(
          ready_op=variables.assert_variables_initialized())
      saver = saver_lib.Saver({"v": v})
      sess, initialized = sm2.recover_session(
          "", saver=saver, checkpoint_dir=checkpoint_dir)
      self.assertTrue(initialized)
      self.assertTrue(
          variables.is_variable_initialized(
              sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
      # Value comes from the checkpoint (1), not the initial value (2).
      self.assertEqual(1, sess.run(v))
示例#17
0
  def _init_ready_op(self, ready_op=USE_DEFAULT):
    """Sets `self._ready_op`.

    Args:
      ready_op: `Operation` used to check whether the model is initialized.
        If it is `Supervisor.USE_DEFAULT`, the op returned by
        `_get_first_op_from_collection(GraphKeys.READY_OP)` is used. If that
        returns None, an op asserting that the variables are initialized is
        created and added to the READY_OP collection. The created op can be
        None (as when the graph has no variables); then nothing is added and
        `self._ready_op` is None.
    """
    if ready_op is not Supervisor.USE_DEFAULT:
      self._ready_op = ready_op
      return
    collected = self._get_first_op_from_collection(ops.GraphKeys.READY_OP)
    if collected is not None:
      self._ready_op = collected
      return
    created = variables.assert_variables_initialized()
    if created is not None:
      ops.add_to_collection(ops.GraphKeys.READY_OP, created)
    self._ready_op = created
示例#18
0
  def _init_ready_op(self, ready_op=USE_DEFAULT):
    """Resolves and stores the op that checks model readiness.

    Args:
      ready_op: `Operation` to check if the model is initialized. The
        sentinel `Supervisor.USE_DEFAULT` means: reuse the op found via the
        READY_OP collection or, failing that, create one with
        `variables.assert_variables_initialized()` and register it in that
        collection when it is not None. Any other value (including None) is
        stored unchanged.
    """
    use_default = ready_op is Supervisor.USE_DEFAULT
    if use_default:
      ready_op = self._get_first_op_from_collection(ops.GraphKeys.READY_OP)
    if use_default and ready_op is None:
      ready_op = variables.assert_variables_initialized()
      # The assertion op may be None, so register only a real op.
      if ready_op is not None:
        ops.add_to_collection(ops.GraphKeys.READY_OP, ready_op)
    self._ready_op = ready_op
示例#19
0
 def testNoVars(self):
   """assert_variables_initialized returns None for a graph with no variables."""
   with ops.Graph().as_default():
     # assertIsNone checks identity with None, not just equality.
     self.assertIsNone(variables.assert_variables_initialized())
 def testNoVars(self):
   """assert_variables_initialized returns None for a graph with no variables."""
   with ops.Graph().as_default():
     # assertIsNone checks identity with None, not just equality.
     self.assertIsNone(variables.assert_variables_initialized())