def testIsVariableInitialized(self):
   """Checks is_variable_initialized before and after an assignment.

   Runs the check with use_gpu=True and with use_gpu=False.
   """
   for use_gpu in [True, False]:
     # cached_session replaces the deprecated test_session.
     with self.cached_session(use_gpu=use_gpu):
       v0 = state_ops.variable_op([1, 2], dtypes.float32)
       # A freshly created variable op has not been given a value yet.
       self.assertEqual(False, variables.is_variable_initialized(v0).eval())
       # Assigning a value is what makes it initialized.
       state_ops.assign(v0, [[2.0, 3.0]]).eval()
       self.assertEqual(True, variables.is_variable_initialized(v0).eval())
 def testPrepareSessionWithReadyForLocalInitOp(self):
   """prepare_session runs local_init_op once global variables are ready.

   `w` is a local variable initialized from the global `v`. The
   ready_for_local_init_op only reports uninitialized *global* variables, so
   after init_op initializes `v`, local_init_op can initialize `w`.
   """
   with ops.Graph().as_default():
     v = variables.Variable(1, name="v")
     w = variables.Variable(
         v,
         trainable=False,
         collections=[ops.GraphKeys.LOCAL_VARIABLES],
         name="w")
     with self.cached_session():
       self.assertEqual(False, variables.is_variable_initialized(v).eval())
       self.assertEqual(False, variables.is_variable_initialized(w).eval())
     sm2 = session_manager.SessionManager(
         ready_op=variables.report_uninitialized_variables(),
         ready_for_local_init_op=variables.report_uninitialized_variables(
             variables.global_variables()),
         local_init_op=w.initializer)
     sess = sm2.prepare_session("", init_op=v.initializer)
     self.assertEqual(
         True,
         variables.is_variable_initialized(
             sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
     self.assertEqual(
         True,
         variables.is_variable_initialized(
             sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
     self.assertEqual(1, sess.run(v))
     self.assertEqual(1, sess.run(w))
 def testIsVariableInitialized(self):
   """Checks is_variable_initialized before and after an assignment.

   Runs the check with use_gpu=True and with use_gpu=False.
   """
   for use_gpu in [True, False]:
     # cached_session replaces the deprecated test_session.
     with self.cached_session(use_gpu=use_gpu):
       v0 = state_ops.variable_op([1, 2], dtypes.float32)
       # A freshly created variable op has not been given a value yet.
       self.assertEqual(False, variables.is_variable_initialized(v0).eval())
       # Assigning a value is what makes it initialized.
       state_ops.assign(v0, [[2.0, 3.0]]).eval()
       self.assertEqual(True, variables.is_variable_initialized(v0).eval())
# Example #4
# 0
    def testWaitForSessionLocalInit(self):
        """wait_for_session finishes local init after `v` is set elsewhere.

        A separate session initializes only the global variable `v` on a
        local server; the session returned by wait_for_session must then have
        both `v` and the local variable `w` (initialized from `v`) ready.
        """
        server = server_lib.Server.create_local_server()
        with ops.Graph().as_default() as graph:
            v = variables.Variable(1, name="v")
            w = variables.Variable(v,
                                   trainable=False,
                                   collections=[ops.GraphKeys.LOCAL_VARIABLES],
                                   name="w")
            sm = session_manager.SessionManager(
                graph=graph,
                ready_op=variables.report_uninitialized_variables(),
                ready_for_local_init_op=variables.
                report_uninitialized_variables(variables.global_variables()),
                local_init_op=w.initializer)

            # Initialize v but not w
            s = session_lib.Session(server.target, graph=graph)
            s.run(v.initializer)

            sess = sm.wait_for_session(server.target, max_wait_secs=3)
            self.assertEqual(
                True,
                variables.is_variable_initialized(
                    sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
            self.assertEqual(
                True,
                variables.is_variable_initialized(
                    sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
            self.assertEqual(1, sess.run(v))
            self.assertEqual(1, sess.run(w))
  def testPrepareSessionFails(self):
    """prepare_session fails without a checkpoint, then succeeds with one.

    A first graph saves `v` = [1, 2, 3] to a checkpoint. In a second graph
    the checkpoint directory is temporarily moved away; with init_op=None,
    prepare_session must raise RuntimeError. After the directory is moved
    back, prepare_session must restore `v` from the checkpoint.
    """
    checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")
    checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2")
    # Clean each directory independently: a missing first directory must not
    # skip deleting the second, or the os.rename below can fail on reruns.
    for stale_dir in (checkpoint_dir, checkpoint_dir2):
      try:
        gfile.DeleteRecursively(stale_dir)
      except errors.OpError:
        pass  # Ignore
    gfile.MakeDirs(checkpoint_dir)

    with ops.Graph().as_default():
      v = variables.Variable([1.0, 2.0, 3.0], name="v")
      sm = session_manager.SessionManager(
          ready_op=variables.report_uninitialized_variables())
      saver = saver_lib.Saver({"v": v})
      sess = sm.prepare_session(
          "",
          init_op=variables.global_variables_initializer(),
          saver=saver,
          checkpoint_dir=checkpoint_dir)
      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
      checkpoint_filename = os.path.join(checkpoint_dir,
                                         "prepare_session_checkpoint")
      saver.save(sess, checkpoint_filename)
    # Create a new Graph and SessionManager and recover.
    with ops.Graph().as_default():
      # Renames the checkpoint directory.
      os.rename(checkpoint_dir, checkpoint_dir2)
      gfile.MakeDirs(checkpoint_dir)
      v = variables.Variable([6.0, 7.0, 8.0], name="v")
      with self.cached_session():
        self.assertEqual(False, variables.is_variable_initialized(v).eval())
      # Bind the new manager; otherwise `sm` would still be the manager
      # created inside the first graph block above.
      sm = session_manager.SessionManager(
          ready_op=variables.report_uninitialized_variables())
      saver = saver_lib.Saver({"v": v})
      # This should fail as there's no checkpoint within 2 seconds.
      with self.assertRaisesRegex(
          RuntimeError, "no init_op or init_fn or local_init_op was given"):
        sm.prepare_session(
            "",
            init_op=None,
            saver=saver,
            checkpoint_dir=checkpoint_dir,
            wait_for_checkpoint=True,
            max_wait_secs=2)
      # Rename the checkpoint directory back.
      gfile.DeleteRecursively(checkpoint_dir)
      os.rename(checkpoint_dir2, checkpoint_dir)
      # This should succeed as there's checkpoint.
      sess = sm.prepare_session(
          "",
          init_op=None,
          saver=saver,
          checkpoint_dir=checkpoint_dir,
          wait_for_checkpoint=True,
          max_wait_secs=2)
      self.assertEqual(
          True,
          variables.is_variable_initialized(
              sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
      # The value must come from the checkpoint, not from v's initializer.
      self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
  def testWaitForSessionLocalInit(self):
    """wait_for_session finishes local init after `v` is set elsewhere.

    A separate session initializes only the global variable `v` on a local
    server; the session returned by wait_for_session must then have both `v`
    and the local variable `w` (initialized from `v`) ready.
    """
    server = server_lib.Server.create_local_server()
    with ops.Graph().as_default() as graph:
      v = variables.Variable(1, name="v")
      w = variables.Variable(
          v,
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES],
          name="w")
      sm = session_manager.SessionManager(
          graph=graph,
          ready_op=variables.report_uninitialized_variables(),
          ready_for_local_init_op=variables.report_uninitialized_variables(
              variables.global_variables()),
          local_init_op=w.initializer)

      # Initialize v but not w
      s = session_lib.Session(server.target, graph=graph)
      s.run(v.initializer)

      sess = sm.wait_for_session(server.target, max_wait_secs=3)
      self.assertEqual(
          True,
          variables.is_variable_initialized(
              sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
      self.assertEqual(
          True,
          variables.is_variable_initialized(
              sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
      self.assertEqual(1, sess.run(v))
      self.assertEqual(1, sess.run(w))
# Example #7
# 0
 def testRecoverSessionNoChkptStillRunsLocalInitOp(self):
     """recover_session runs local_init_op even without a checkpoint.

     This test checks for backwards compatibility: recover_session must
     execute local_init_op exactly once, regardless of whether the session
     was successfully recovered.
     """
     with ops.Graph().as_default():
         w = variables.Variable(1,
                                trainable=False,
                                collections=[ops.GraphKeys.LOCAL_VARIABLES],
                                name="w")
         with self.cached_session():
             self.assertEqual(False,
                              variables.is_variable_initialized(w).eval())
         sm2 = session_manager.SessionManager(
             ready_op=variables.report_uninitialized_variables(),
             ready_for_local_init_op=None,
             local_init_op=w.initializer)
         # Try to recover session from None
         sess, initialized = sm2.recover_session("",
                                                 saver=None,
                                                 checkpoint_dir=None)
         # Succeeds because recover_session still runs local_init_op
         self.assertFalse(initialized)
         self.assertEqual(
             True,
             variables.is_variable_initialized(
                 sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
         self.assertEqual(1, sess.run(w))
# Example #8
# 0
    def testPrepareSessionFails(self):
        """prepare_session fails without a checkpoint, then succeeds with one.

        A first graph saves `v` = [1, 2, 3] to a checkpoint. In a second
        graph the checkpoint directory is temporarily moved away; with
        init_op=None, prepare_session must raise RuntimeError. After the
        directory is moved back, prepare_session must restore `v`.
        """
        checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")
        checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2")
        # Clean each directory independently: a missing first directory must
        # not skip deleting the second, or os.rename below can fail on reruns.
        for stale_dir in (checkpoint_dir, checkpoint_dir2):
            try:
                gfile.DeleteRecursively(stale_dir)
            except errors.OpError:
                pass  # Ignore
        gfile.MakeDirs(checkpoint_dir)

        with ops.Graph().as_default():
            v = variables.Variable([1.0, 2.0, 3.0], name="v")
            sm = session_manager.SessionManager(
                ready_op=variables.assert_variables_initialized())
            saver = saver_lib.Saver({"v": v})
            sess = sm.prepare_session(
                "",
                init_op=variables.global_variables_initializer(),
                saver=saver,
                checkpoint_dir=checkpoint_dir)
            self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
            checkpoint_filename = os.path.join(checkpoint_dir,
                                               "prepare_session_checkpoint")
            saver.save(sess, checkpoint_filename)
        # Create a new Graph and SessionManager and recover.
        with ops.Graph().as_default():
            # Renames the checkpoint directory.
            os.rename(checkpoint_dir, checkpoint_dir2)
            gfile.MakeDirs(checkpoint_dir)
            v = variables.Variable([6.0, 7.0, 8.0], name="v")
            with self.cached_session():
                self.assertEqual(False,
                                 variables.is_variable_initialized(v).eval())
            # Bind the new manager; otherwise `sm` would still be the manager
            # created inside the first graph block above.
            sm = session_manager.SessionManager(
                ready_op=variables.assert_variables_initialized())
            saver = saver_lib.Saver({"v": v})
            # This should fail as there's no checkpoint within 2 seconds.
            with self.assertRaisesRegex(
                    RuntimeError,
                    "no init_op or init_fn or local_init_op was given"):
                sm.prepare_session("",
                                   init_op=None,
                                   saver=saver,
                                   checkpoint_dir=checkpoint_dir,
                                   wait_for_checkpoint=True,
                                   max_wait_secs=2)
            # Rename the checkpoint directory back.
            gfile.DeleteRecursively(checkpoint_dir)
            os.rename(checkpoint_dir2, checkpoint_dir)
            # This should succeed as there's checkpoint.
            sess = sm.prepare_session("",
                                      init_op=None,
                                      saver=saver,
                                      checkpoint_dir=checkpoint_dir,
                                      wait_for_checkpoint=True,
                                      max_wait_secs=2)
            self.assertEqual(
                True,
                variables.is_variable_initialized(
                    sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
            # The value must come from the checkpoint, not v's initializer.
            self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
 def testRecoverSessionNoChkptStillRunsLocalInitOp(self):
   """recover_session runs local_init_op even without a checkpoint.

   This test checks for backwards compatibility: recover_session must
   execute local_init_op exactly once, regardless of whether the session was
   successfully recovered.
   """
   with ops.Graph().as_default():
     w = variables.Variable(
         1,
         trainable=False,
         collections=[ops.GraphKeys.LOCAL_VARIABLES],
         name="w")
     with self.cached_session():
       self.assertEqual(False, variables.is_variable_initialized(w).eval())
     sm2 = session_manager.SessionManager(
         ready_op=variables.report_uninitialized_variables(),
         ready_for_local_init_op=None,
         local_init_op=w.initializer)
     # Try to recover session from None
     sess, initialized = sm2.recover_session(
         "", saver=None, checkpoint_dir=None)
     # Succeeds because recover_session still runs local_init_op
     self.assertFalse(initialized)
     self.assertEqual(
         True,
         variables.is_variable_initialized(
             sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
     self.assertEqual(1, sess.run(w))
 def testPrepareSessionWithReadyForLocalInitOp(self):
     """prepare_session runs local_init_op once global variables are ready.

     `w` is a local variable initialized from the global `v`. The
     ready_for_local_init_op only reports uninitialized *global* variables,
     so after init_op initializes `v`, local_init_op can initialize `w`.
     """
     with ops.Graph().as_default():
         v = variables.Variable(1, name="v")
         w = variables.Variable(v,
                                trainable=False,
                                collections=[ops.GraphKeys.LOCAL_VARIABLES],
                                name="w")
         # cached_session replaces the deprecated test_session.
         with self.cached_session():
             self.assertEqual(False,
                              variables.is_variable_initialized(v).eval())
             self.assertEqual(False,
                              variables.is_variable_initialized(w).eval())
         sm2 = session_manager.SessionManager(
             ready_op=variables.report_uninitialized_variables(),
             ready_for_local_init_op=variables.
             report_uninitialized_variables(variables.global_variables()),
             local_init_op=w.initializer)
         sess = sm2.prepare_session("", init_op=v.initializer)
         self.assertEqual(
             True,
             variables.is_variable_initialized(
                 sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
         self.assertEqual(
             True,
             variables.is_variable_initialized(
                 sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
         self.assertEqual(1, sess.run(v))
         self.assertEqual(1, sess.run(w))
    def testRecoverSessionWithReadyForLocalInitOpFailsToReadyLocal(self):
        """recover_session skips local_init_op when local init is not ready.

        ready_for_local_init_op is report_uninitialized_variables() over all
        variables, which also covers the local `w` that local_init_op would
        initialize. recover_session is therefore expected to restore `v`,
        leave `w` uninitialized and return initialized=False.
        """
        ckpt_dir = os.path.join(
            self.get_temp_dir(),
            "recover_session_ready_for_local_init_fails_to_ready_local")
        try:
            gfile.DeleteRecursively(ckpt_dir)
        except errors.OpError:
            pass  # Nothing to clean up.
        gfile.MakeDirs(ckpt_dir)

        # Phase 1: write a checkpoint in which v == 1.
        with ops.Graph().as_default():
            var_v = variables.VariableV1(1, name="v")
            manager = session_manager.SessionManager(
                ready_op=variables.report_uninitialized_variables())
            ckpt_saver = saver_lib.Saver({"v": var_v})
            session, recovered = manager.recover_session(
                "", saver=ckpt_saver, checkpoint_dir=ckpt_dir)
            self.assertFalse(recovered)
            session.run(var_v.initializer)
            self.assertEqual(1, session.run(var_v))
            ckpt_path = os.path.join(ckpt_dir, "recover_session_checkpoint")
            ckpt_saver.save(session, ckpt_path)

        # Phase 2: a fresh graph and SessionManager recover from it.
        with ops.Graph().as_default():
            var_v = variables.VariableV1(2, name="v")
            var_w = variables.VariableV1(
                var_v,
                trainable=False,
                collections=[ops.GraphKeys.LOCAL_VARIABLES],
                name="w")
            with self.cached_session():
                for var in (var_v, var_w):
                    self.assertEqual(
                        False, variables.is_variable_initialized(var).eval())
            manager2 = session_manager.SessionManager(
                ready_op=variables.report_uninitialized_variables(),
                ready_for_local_init_op=(
                    variables.report_uninitialized_variables()),
                local_init_op=var_w.initializer)
            ckpt_saver = saver_lib.Saver({"v": var_v})
            session, recovered = manager2.recover_session(
                "", saver=ckpt_saver, checkpoint_dir=ckpt_dir)
            self.assertFalse(recovered)

            def initialized_in_session(tensor_name):
                tensor = session.graph.get_tensor_by_name(tensor_name)
                return variables.is_variable_initialized(tensor).eval(
                    session=session)

            self.assertEqual(True, initialized_in_session("v:0"))
            self.assertEqual(False, initialized_in_session("w:0"))
            self.assertEqual(1, session.run(var_v))
  def testRecoverSessionWithReadyForLocalInitOpFailsToReadyLocal(self):
    """recover_session skips local_init_op when local init is not ready.

    We use ready_for_local_init_op=tf.report_uninitialized_variables(), which
    also reports the local `w`; this causes recover_session to not run
    local_init_op, and to return initialized=False.
    """
    # Create a checkpoint.
    checkpoint_dir = os.path.join(
        self.get_temp_dir(),
        "recover_session_ready_for_local_init_fails_to_ready_local")
    try:
      gfile.DeleteRecursively(checkpoint_dir)
    except errors.OpError:
      pass  # Ignore
    gfile.MakeDirs(checkpoint_dir)

    with ops.Graph().as_default():
      v = variables.Variable(1, name="v")
      sm = session_manager.SessionManager(
          ready_op=variables.report_uninitialized_variables())
      saver = saver_lib.Saver({"v": v})
      sess, initialized = sm.recover_session(
          "", saver=saver, checkpoint_dir=checkpoint_dir)
      self.assertFalse(initialized)
      sess.run(v.initializer)
      self.assertEqual(1, sess.run(v))
      saver.save(sess,
                 os.path.join(checkpoint_dir, "recover_session_checkpoint"))
    # Create a new Graph and SessionManager and recover.
    with ops.Graph().as_default():
      v = variables.Variable(2, name="v")
      w = variables.Variable(
          v,
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES],
          name="w")
      with self.cached_session():
        self.assertEqual(False, variables.is_variable_initialized(v).eval())
        self.assertEqual(False, variables.is_variable_initialized(w).eval())
      sm2 = session_manager.SessionManager(
          ready_op=variables.report_uninitialized_variables(),
          ready_for_local_init_op=variables.report_uninitialized_variables(),
          local_init_op=w.initializer)
      saver = saver_lib.Saver({"v": v})
      sess, initialized = sm2.recover_session(
          "", saver=saver, checkpoint_dir=checkpoint_dir)
      self.assertFalse(initialized)
      self.assertEqual(
          True,
          variables.is_variable_initialized(
              sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
      self.assertEqual(
          False,
          variables.is_variable_initialized(
              sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
      self.assertEqual(1, sess.run(v))
# Example #13
# 0
 def testPrepareSessionDidNotInitLocalVariableList(self):
   """prepare_session raises if init_op (given as a list) leaves `w` unset.

   init_op only initializes the global `v` and no local_init_op is given, so
   the local variable `w` stays uninitialized and the model is never ready.
   """
   with ops.Graph().as_default():
     var_v = variables.VariableV1(1, name="v")
     var_w = variables.VariableV1(
         var_v,
         trainable=False,
         collections=[ops.GraphKeys.LOCAL_VARIABLES],
         name="w")
     with self.cached_session():
       for var in (var_v, var_w):
         self.assertEqual(False, variables.is_variable_initialized(var).eval())
     manager = session_manager.SessionManager(
         ready_op=variables.report_uninitialized_variables())
     expected_message = "Init operations did not make model ready"
     with self.assertRaisesRegex(RuntimeError, expected_message):
       manager.prepare_session("", init_op=[var_v.initializer])
 def testPrepareSessionDidNotInitLocalVariableList(self):
   """prepare_session raises if init_op (given as a list) leaves `w` unset.

   init_op only initializes the global `v` and no local_init_op is given, so
   the local variable `w` stays uninitialized and the model is never ready.
   """
   with ops.Graph().as_default():
     v = variables.Variable(1, name="v")
     w = variables.Variable(
         v,
         trainable=False,
         collections=[ops.GraphKeys.LOCAL_VARIABLES],
         name="w")
     with self.cached_session():
       self.assertEqual(False, variables.is_variable_initialized(v).eval())
       self.assertEqual(False, variables.is_variable_initialized(w).eval())
     sm2 = session_manager.SessionManager(
         ready_op=variables.report_uninitialized_variables())
     with self.assertRaisesRegex(RuntimeError,
                                 "Init operations did not make model ready"):
       sm2.prepare_session("", init_op=[v.initializer])
 def testPrepareSessionWithCyclicInitializer(self):
   """Regression test for variables whose initial_value contains a cycle.

   Variable._build_initializer_expr previously recursed infinitely when the
   variable's initial_value involved cyclic dependencies (a while_loop here).
   """
   with ops.Graph().as_default():
     loop_result = control_flow_ops.while_loop(
         lambda count: count < 1, lambda count: count + 1, [0])
     var_v = variables.Variable(array_ops.identity(loop_result), name="v")
     with self.cached_session():
       self.assertEqual(False, variables.is_variable_initialized(var_v).eval())
     manager = session_manager.SessionManager(
         ready_op=variables.report_uninitialized_variables())
     session = manager.prepare_session("", init_op=var_v.initializer)
     self.assertEqual(1, session.run(var_v))
     v_tensor = session.graph.get_tensor_by_name("v:0")
     self.assertEqual(
         True,
         variables.is_variable_initialized(v_tensor).eval(session=session))
# Example #16
# 0
 def testPrepareSessionWithCyclicInitializer(self):
   """Regression test for variables whose initial_value contains a cycle.

   Variable._build_initializer_expr previously recursed infinitely when the
   variable's initial_value involved cyclic dependencies (a while_loop here).
   """
   with ops.Graph().as_default():
     loop_result = control_flow_ops.while_loop(
         lambda count: count < 1, lambda count: count + 1, [0])
     var_v = variables.VariableV1(array_ops.identity(loop_result), name="v")
     with self.cached_session():
       self.assertEqual(False, variables.is_variable_initialized(var_v).eval())
     manager = session_manager.SessionManager(
         ready_op=variables.report_uninitialized_variables())
     session = manager.prepare_session("", init_op=var_v.initializer)
     self.assertEqual(1, session.run(var_v))
     v_tensor = session.graph.get_tensor_by_name("v:0")
     self.assertEqual(
         True,
         variables.is_variable_initialized(v_tensor).eval(session=session))
 def testPrepareSessionWithInsufficientReadyForLocalInitCheck(self):
   """local_init_op fails when it runs before the global `v` is set.

   `w` is initialized from `v`, but init_op is None and there is no
   ready_for_local_init_op, so prepare_session is expected to fail reading
   the uninitialized `v`.
   """
   with ops.Graph().as_default():
     v = variables.Variable(1, name="v")
     w = variables.Variable(
         v,
         trainable=False,
         collections=[ops.GraphKeys.LOCAL_VARIABLES],
         name="w")
     # cached_session replaces the deprecated test_session.
     with self.cached_session():
       self.assertEqual(False, variables.is_variable_initialized(v).eval())
       self.assertEqual(False, variables.is_variable_initialized(w).eval())
     sm2 = session_manager.SessionManager(
         ready_op=variables.report_uninitialized_variables(),
         ready_for_local_init_op=None,
         local_init_op=w.initializer)
   with self.assertRaisesRegex(errors_impl.FailedPreconditionError,
                               "Attempting to use uninitialized value v"):
     sm2.prepare_session("", init_op=None)
 def testPrepareSessionWithInsufficientReadyForLocalInitCheck(self):
   """prepare_session cannot make the model ready without local-init gating.

   `w` is initialized from `v`, but init_op is None and there is no
   ready_for_local_init_op, so prepare_session is expected to raise.
   """
   with ops.Graph().as_default():
     v = variables.Variable(1, name="v")
     w = variables.Variable(
         v,
         trainable=False,
         collections=[ops.GraphKeys.LOCAL_VARIABLES],
         name="w")
     with self.cached_session():
       self.assertEqual(False, variables.is_variable_initialized(v).eval())
       self.assertEqual(False, variables.is_variable_initialized(w).eval())
     sm2 = session_manager.SessionManager(
         ready_op=variables.report_uninitialized_variables(),
         ready_for_local_init_op=None,
         local_init_op=w.initializer)
   with self.assertRaisesRegex(RuntimeError,
                               "Init operations did not make model ready.*"):
     sm2.prepare_session("", init_op=None)
# Example #19
# 0
 def testPrepareSessionWithInsufficientReadyForLocalInitCheck(self):
   """prepare_session cannot make the model ready without local-init gating.

   `w` is initialized from `v`, but init_op is None and there is no
   ready_for_local_init_op, so prepare_session is expected to raise.
   """
   with ops.Graph().as_default():
     var_v = variables.VariableV1(1, name="v")
     var_w = variables.VariableV1(
         var_v,
         trainable=False,
         collections=[ops.GraphKeys.LOCAL_VARIABLES],
         name="w")
     with self.cached_session():
       for var in (var_v, var_w):
         self.assertEqual(False, variables.is_variable_initialized(var).eval())
     manager = session_manager.SessionManager(
         ready_op=variables.report_uninitialized_variables(),
         ready_for_local_init_op=None,
         local_init_op=var_w.initializer)
   expected_message = "Init operations did not make model ready.*"
   with self.assertRaisesRegex(RuntimeError, expected_message):
     manager.prepare_session("", init_op=None)
    def testRecoverSessionFailsStillRunsLocalInitOp(self):
        """local_init_op still runs when nothing can be recovered.

        The checkpoint directory is empty, so `v` is not restored and
        recover_session returns initialized=False; the local variable `w`
        must nevertheless end up initialized by local_init_op.
        """
        ckpt_dir = os.path.join(
            self.get_temp_dir(),
            "recover_session_ready_for_local_init_fails_stil_run")
        try:
            gfile.DeleteRecursively(ckpt_dir)
        except errors.OpError:
            pass  # Nothing to clean up.
        gfile.MakeDirs(ckpt_dir)

        # Fresh graph and SessionManager; recovery is attempted from an
        # empty checkpoint directory.
        with ops.Graph().as_default():
            var_v = variables.VariableV1(2, name="v")
            var_w = variables.VariableV1(
                1,
                trainable=False,
                collections=[ops.GraphKeys.LOCAL_VARIABLES],
                name="w")
            with self.cached_session():
                for var in (var_v, var_w):
                    self.assertEqual(
                        False, variables.is_variable_initialized(var).eval())
            manager = session_manager.SessionManager(
                ready_op=variables.report_uninitialized_variables(),
                ready_for_local_init_op=None,
                local_init_op=var_w.initializer)
            ckpt_saver = saver_lib.Saver({"v": var_v})
            session, recovered = manager.recover_session(
                "",
                saver=ckpt_saver,
                checkpoint_dir=ckpt_dir,
                wait_for_checkpoint=False)
            self.assertFalse(recovered)

            def initialized_in_session(tensor_name):
                tensor = session.graph.get_tensor_by_name(tensor_name)
                return variables.is_variable_initialized(tensor).eval(
                    session=session)

            self.assertEqual(False, initialized_in_session("v:0"))
            self.assertEqual(True, initialized_in_session("w:0"))
            self.assertEqual(1, session.run(var_w))
  def testRecoverSessionFailsStillRunsLocalInitOp(self):
    """local_init_op still runs when nothing can be recovered.

    The checkpoint directory is empty, so `v` is not restored and
    recover_session returns initialized=False; the local variable `w` must
    nevertheless end up initialized by local_init_op.
    """
    # Create an empty checkpoint directory.
    checkpoint_dir = os.path.join(
        self.get_temp_dir(),
        "recover_session_ready_for_local_init_fails_stil_run")
    try:
      gfile.DeleteRecursively(checkpoint_dir)
    except errors.OpError:
      pass  # Ignore
    gfile.MakeDirs(checkpoint_dir)

    # Create a new Graph and SessionManager and recover.
    with ops.Graph().as_default():
      v = variables.Variable(2, name="v")
      w = variables.Variable(
          1,
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES],
          name="w")
      with self.cached_session():
        self.assertEqual(False, variables.is_variable_initialized(v).eval())
        self.assertEqual(False, variables.is_variable_initialized(w).eval())
      sm2 = session_manager.SessionManager(
          ready_op=variables.report_uninitialized_variables(),
          ready_for_local_init_op=None,
          local_init_op=w.initializer)
      saver = saver_lib.Saver({"v": v})
      sess, initialized = sm2.recover_session(
          "",
          saver=saver,
          checkpoint_dir=checkpoint_dir,
          wait_for_checkpoint=False)
      self.assertFalse(initialized)
      self.assertEqual(
          False,
          variables.is_variable_initialized(
              sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
      self.assertEqual(
          True,
          variables.is_variable_initialized(
              sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
      self.assertEqual(1, sess.run(w))
# Example #22
# 0
    def testRecoverSession(self):
        """recover_session restores `v` from a checkpoint in a new graph.

        The first recover_session call has no checkpoint (initialized=False);
        `v` is then initialized to 1 and saved. A second graph, whose `v`
        would initialize to 2, must recover the saved value 1.
        """
        # Create a checkpoint.
        checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
        try:
            gfile.DeleteRecursively(checkpoint_dir)
        except errors.OpError:
            pass  # Ignore
        gfile.MakeDirs(checkpoint_dir)

        with ops.Graph().as_default():
            v = variables.Variable(1, name="v")
            sm = session_manager.SessionManager(
                ready_op=variables.assert_variables_initialized())
            saver = saver_lib.Saver({"v": v})
            sess, initialized = sm.recover_session(
                "", saver=saver, checkpoint_dir=checkpoint_dir)
            self.assertFalse(initialized)
            sess.run(v.initializer)
            self.assertEqual(1, sess.run(v))
            saver.save(
                sess, os.path.join(checkpoint_dir,
                                   "recover_session_checkpoint"))
        # Create a new Graph and SessionManager and recover.
        with ops.Graph().as_default():
            v = variables.Variable(2, name="v")
            with self.cached_session():
                self.assertEqual(False,
                                 variables.is_variable_initialized(v).eval())
            sm2 = session_manager.SessionManager(
                ready_op=variables.assert_variables_initialized())
            saver = saver_lib.Saver({"v": v})
            sess, initialized = sm2.recover_session(
                "", saver=saver, checkpoint_dir=checkpoint_dir)
            self.assertTrue(initialized)
            self.assertEqual(
                True,
                variables.is_variable_initialized(
                    sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
            self.assertEqual(1, sess.run(v))
 def _test_recovered_variable(self,
                              checkpoint_dir=None,
                              checkpoint_filename_with_path=None):
   """Recovers `v` into a fresh graph and checks it holds the saved value 1.

   Args:
     checkpoint_dir: forwarded to SessionManager.recover_session.
     checkpoint_filename_with_path: forwarded to
       SessionManager.recover_session.
   """
   # Create a new Graph and SessionManager and recover from a checkpoint.
   with ops.Graph().as_default():
     v = variables.Variable(2, name="v")
     with session_lib.Session():
       self.assertEqual(False, variables.is_variable_initialized(v).eval())
     sm2 = session_manager.SessionManager(
         ready_op=variables.report_uninitialized_variables())
     saver = saver_lib.Saver({"v": v})
     sess, initialized = sm2.recover_session(
         "",
         saver=saver,
         checkpoint_dir=checkpoint_dir,
         checkpoint_filename_with_path=checkpoint_filename_with_path)
     self.assertTrue(initialized)
     self.assertEqual(
         True,
         variables.is_variable_initialized(
             sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
     self.assertEqual(1, sess.run(v))
# Example #24
# 0
 def _test_recovered_variable(self,
                              checkpoint_dir=None,
                              checkpoint_filename_with_path=None):
   """Recovers `v` into a fresh graph and checks it holds the saved value 1.

   Args:
     checkpoint_dir: forwarded to SessionManager.recover_session.
     checkpoint_filename_with_path: forwarded to
       SessionManager.recover_session.
   """
   with ops.Graph().as_default():
     var_v = variables.VariableV1(2, name="v")
     with session_lib.Session():
       self.assertEqual(False, variables.is_variable_initialized(var_v).eval())
     manager = session_manager.SessionManager(
         ready_op=variables.report_uninitialized_variables())
     ckpt_saver = saver_lib.Saver({"v": var_v})
     session, recovered = manager.recover_session(
         "",
         saver=ckpt_saver,
         checkpoint_dir=checkpoint_dir,
         checkpoint_filename_with_path=checkpoint_filename_with_path)
     self.assertTrue(recovered)
     v_tensor = session.graph.get_tensor_by_name("v:0")
     self.assertEqual(
         True,
         variables.is_variable_initialized(v_tensor).eval(session=session))
     self.assertEqual(1, session.run(var_v))
  def testRecoverSession(self):
    """recover_session restores `v` from a checkpoint in a new graph.

    The first recover_session call has no checkpoint (initialized=False);
    `v` is then initialized to 1 and saved. A second graph, whose `v` would
    initialize to 2, must recover the saved value 1.
    """
    # Create a checkpoint.
    checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
    try:
      gfile.DeleteRecursively(checkpoint_dir)
    except errors.OpError:
      pass  # Ignore
    gfile.MakeDirs(checkpoint_dir)

    with ops.Graph().as_default():
      v = variables.Variable(1, name="v")
      sm = session_manager.SessionManager(
          ready_op=variables.assert_variables_initialized())
      saver = saver_lib.Saver({"v": v})
      sess, initialized = sm.recover_session(
          "", saver=saver, checkpoint_dir=checkpoint_dir)
      self.assertFalse(initialized)
      sess.run(v.initializer)
      self.assertEqual(1, sess.run(v))
      saver.save(sess,
                 os.path.join(checkpoint_dir, "recover_session_checkpoint"))
    # Create a new Graph and SessionManager and recover.
    with ops.Graph().as_default():
      v = variables.Variable(2, name="v")
      with self.cached_session():
        self.assertEqual(False, variables.is_variable_initialized(v).eval())
      sm2 = session_manager.SessionManager(
          ready_op=variables.assert_variables_initialized())
      saver = saver_lib.Saver({"v": v})
      sess, initialized = sm2.recover_session(
          "", saver=saver, checkpoint_dir=checkpoint_dir)
      self.assertTrue(initialized)
      self.assertEqual(
          True,
          variables.is_variable_initialized(
              sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
      self.assertEqual(1, sess.run(v))
# Example #26
# 0
def _wait_for_variable_initialization(session, poll_interval_secs=0.1):
    """Block until every not-yet-confirmed Keras variable is initialized.

    Considers the variables returned by `K._get_variables(K.get_graph())`
    that are not already flagged with `_keras_initialized`, and repeatedly
    evaluates `is_variable_initialized` for them in `session` until all
    report True. Each variable is flagged `_keras_initialized = True` once
    it has been confirmed, so later calls skip it.

    Args:
      session: Session used to evaluate the initialization checks.
      poll_interval_secs: Seconds to sleep between polls while some variables
        are still uninitialized; a value <= 0 polls without sleeping.
    """
    import time  # Only needed for the polling back-off.

    all_variables = K._get_variables(K.get_graph())  # pylint: disable=protected-access
    # Build each check op exactly once. Creating them inside the polling loop
    # would add new nodes to the graph on every iteration while waiting.
    pending = [(v, variables.is_variable_initialized(v))
               for v in all_variables
               if not getattr(v, '_keras_initialized', False)]

    while pending:
        flags = session.run([check for _, check in pending])
        still_pending = []
        for (v, check), flag in zip(pending, flags):
            if flag:
                # Flag only confirmed variables so the marker stays truthful.
                v._keras_initialized = True  # pylint: disable=protected-access
            else:
                still_pending.append((v, check))
        pending = still_pending
        if pending and poll_interval_secs > 0:
            # Back off instead of spinning the session in a tight loop.
            time.sleep(poll_interval_secs)
def _wait_for_variable_initialization(session, poll_interval_secs=0.1):
  """Block until every not-yet-confirmed Keras variable is initialized.

  Considers the variables returned by `K._get_variables(K.get_graph())` that
  are not already flagged with `_keras_initialized`, and repeatedly evaluates
  `is_variable_initialized` for them in `session` until all report True.
  Each variable is flagged `_keras_initialized = True` once it has been
  confirmed, so later calls skip it.

  Args:
    session: Session used to evaluate the initialization checks.
    poll_interval_secs: Seconds to sleep between polls while some variables
      are still uninitialized; a value <= 0 polls without sleeping.
  """
  import time  # Only needed for the polling back-off.

  all_variables = K._get_variables(K.get_graph())  # pylint: disable=protected-access
  # Build each check op exactly once. Creating them inside the polling loop
  # would add new nodes to the graph on every iteration while waiting.
  pending = [(v, variables.is_variable_initialized(v))
             for v in all_variables
             if not getattr(v, '_keras_initialized', False)]

  while pending:
    flags = session.run([check for _, check in pending])
    still_pending = []
    for (v, check), flag in zip(pending, flags):
      if flag:
        # Flag only confirmed variables so the marker stays truthful.
        v._keras_initialized = True  # pylint: disable=protected-access
      else:
        still_pending.append((v, check))
    pending = still_pending
    if pending and poll_interval_secs > 0:
      # Back off instead of spinning the session in a tight loop.
      time.sleep(poll_interval_secs)
# Example #28 (rating: 0)
def get_session(config=None, checkpoint_dir=None, checkpoint_path=None):
    """Return the process-wide session, creating it on first use.

    The first call builds the session with `BasicSessionCreator`, using
    `config` (or a `ConfigProto(allow_soft_placement=True)` default) and the
    given checkpoint location, and caches it in the module-level `_SESSION`.
    Later calls return the cached session; a non-None `config` passed then is
    ignored and only triggers a warning.

    When neither `checkpoint_dir` nor `checkpoint_path` is given, global
    variables of the session's graph that are not yet initialized are
    initialized before the session is returned.

    Args:
      config: Optional session `ConfigProto`; only honored on the first call.
      checkpoint_dir: Optional checkpoint directory for session creation.
      checkpoint_path: Optional checkpoint file path for session creation.

    Returns:
      The shared session object.
    """
    global _SESSION
    if _SESSION is None:
        from tensorlib.training.sessions.core import BasicSessionCreator
        session_creator = BasicSessionCreator(
            config=config or config_pb2.ConfigProto(allow_soft_placement=True),
            checkpoint_dir=checkpoint_dir,
            checkpoint_path=checkpoint_path)
        _SESSION = session_creator.create_session()
    elif config is not None:
        import warnings
        # stacklevel=2 attributes the warning to the caller of get_session.
        warnings.warn("Session has already been created without specific"
                      " session_config: %s, you should invoke method"
                      " `get_session` manually in the beginning of your"
                      " program." % str(config), stacklevel=2)
    session = _SESSION
    if checkpoint_dir is None and checkpoint_path is None:
        _initialize_uninitialized_global_variables(session)
    return session


def _initialize_uninitialized_global_variables(session):
    """Initialize global variables of `session.graph` not yet flagged/initialized."""
    with session.graph.as_default():
        candidate_vars = [
            var for var in variables.global_variables()
            if not getattr(var, "_initialized", False)
        ]
        if not candidate_vars:
            return
        is_initialized = session.run(
            [variables.is_variable_initialized(v) for v in candidate_vars])
        uninitialized_vars = [
            v for v, flag in zip(candidate_vars, is_initialized) if not flag
        ]
        if uninitialized_vars:
            session.run(variables.variables_initializer(uninitialized_vars))
        # Flag only after the initializer has run successfully, so a failed
        # run does not leave variables falsely marked as initialized.
        for v in candidate_vars:
            v._initialized = True
  def testPrepareSessionWithPartialInitOp(self):
    """prepare_session with only a local_init_op initializes just those vars.

    `init_op=None`, so the global variables `v` and `v_res` must remain
    uninitialized, while `w`, `x`, `w_res` and `x_res` (listed in
    `local_init_op`) must be initialized. The assertions pin that the local
    variables read back 1 and 3 even though `v`/`v_res` are never
    initialized.
    """
    with ops.Graph().as_default():
      v = variables.Variable(1, name="v")
      w = variables.Variable(
          v,
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES],
          name="w")
      x = variables.Variable(
          3 * v,
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES],
          name="x")
      # TODO(b/70206927): Use ResourceVariables once they are handled properly.
      v_res = variables.Variable(1, name="v_res")
      w_res = variables.Variable(
          v_res,
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES],
          name="w_res")
      x_res = variables.Variable(
          3 * v_res,
          trainable=False,
          collections=[ops.GraphKeys.LOCAL_VARIABLES],
          name="x_res")

      # Nothing is initialized before the SessionManager runs.
      with self.cached_session():
        for var in (v, w, x, v_res, w_res, x_res):
          self.assertEqual(False, variables.is_variable_initialized(var).eval())

      sm2 = session_manager.SessionManager(local_init_op=[
          w.initializer, x.initializer, w_res.initializer, x_res.initializer
      ])
      sess = sm2.prepare_session("", init_op=None)

      def is_initialized(tensor_name):
        """Evaluate is_variable_initialized for `tensor_name` in `sess`."""
        return variables.is_variable_initialized(
            sess.graph.get_tensor_by_name(tensor_name)).eval(session=sess)

      # Globals stay uninitialized; only local_init_op's variables are set.
      expected = (("v:0", False), ("w:0", True), ("x:0", True),
                  ("v_res:0", False), ("w_res:0", True), ("x_res:0", True))
      for tensor_name, initialized in expected:
        self.assertEqual(initialized, is_initialized(tensor_name),
                         msg=tensor_name)
      self.assertEqual(1, sess.run(w))
      self.assertEqual(3, sess.run(x))
      self.assertEqual(1, sess.run(w_res))
      self.assertEqual(3, sess.run(x_res))
# Example #30 (rating: 0)
    def testPrepareSessionWithPartialInitOp(self):
        """prepare_session with only a local_init_op initializes just those vars.

        `init_op=None`, so the global variables `v` and `v_res` must remain
        uninitialized, while `w`, `x`, `w_res` and `x_res` (listed in
        `local_init_op`) must be initialized. The assertions pin that the
        local variables read back 1 and 3 even though `v`/`v_res` are never
        initialized.
        """
        with ops.Graph().as_default():
            v = variables.Variable(1, name="v")
            w = variables.Variable(v,
                                   trainable=False,
                                   collections=[ops.GraphKeys.LOCAL_VARIABLES],
                                   name="w")
            x = variables.Variable(3 * v,
                                   trainable=False,
                                   collections=[ops.GraphKeys.LOCAL_VARIABLES],
                                   name="x")
            # TODO(b/70206927): Use ResourceVariables once they are handled properly.
            v_res = variables.Variable(1, name="v_res")
            w_res = variables.Variable(
                v_res,
                trainable=False,
                collections=[ops.GraphKeys.LOCAL_VARIABLES],
                name="w_res")
            x_res = variables.Variable(
                3 * v_res,
                trainable=False,
                collections=[ops.GraphKeys.LOCAL_VARIABLES],
                name="x_res")

            # Nothing is initialized before the SessionManager runs.
            with self.cached_session():
                for var in (v, w, x, v_res, w_res, x_res):
                    self.assertEqual(
                        False,
                        variables.is_variable_initialized(var).eval())

            sm2 = session_manager.SessionManager(local_init_op=[
                w.initializer, x.initializer, w_res.initializer,
                x_res.initializer
            ])
            sess = sm2.prepare_session("", init_op=None)

            def is_initialized(tensor_name):
                """Evaluate is_variable_initialized for `tensor_name` in `sess`."""
                return variables.is_variable_initialized(
                    sess.graph.get_tensor_by_name(tensor_name)).eval(
                        session=sess)

            # Globals stay uninitialized; only local_init_op's variables are set.
            expected = (("v:0", False), ("w:0", True), ("x:0", True),
                        ("v_res:0", False), ("w_res:0", True),
                        ("x_res:0", True))
            for tensor_name, initialized in expected:
                self.assertEqual(initialized, is_initialized(tensor_name),
                                 msg=tensor_name)
            self.assertEqual(1, sess.run(w))
            self.assertEqual(3, sess.run(x))
            self.assertEqual(1, sess.run(w_res))
            self.assertEqual(3, sess.run(x_res))