def testIsVariableInitialized(self):
  """Checks is_variable_initialized before and after assigning a value.

  The check is repeated with `use_gpu` set to True and to False.
  """
  for gpu_flag in (True, False):
    with self.test_session(use_gpu=gpu_flag):
      var = state_ops.variable_op([1, 2], dtypes.float32)

      # No value has been assigned yet.
      initialized_before = variables.is_variable_initialized(var).eval()
      self.assertEqual(False, initialized_before)

      # Running the assign gives the variable a value.
      assign_op = state_ops.assign(var, [[2.0, 3.0]])
      assign_op.eval()
      initialized_after = variables.is_variable_initialized(var).eval()
      self.assertEqual(True, initialized_after)
def testPrepareSessionWithReadyForLocalInitOp(self):
  """prepare_session initializes both the global and the local variable.

  `v` is initialized through `init_op`; `w` (a LOCAL_VARIABLES variable whose
  initial value is `v`) is initialized through `local_init_op`. Afterwards
  both must report as initialized and evaluate to 1.
  """
  with ops.Graph().as_default():
    v = variables.Variable(1, name="v")
    w = variables.Variable(
        v,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w")
    with self.test_session():
      self.assertEqual(False, variables.is_variable_initialized(v).eval())
      self.assertEqual(False, variables.is_variable_initialized(w).eval())
    sm2 = session_manager.SessionManager(
        ready_op=variables.report_uninitialized_variables(),
        ready_for_local_init_op=variables.report_uninitialized_variables(
            variables.global_variables()),
        local_init_op=w.initializer)
    sess = sm2.prepare_session("", init_op=v.initializer)
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(1, sess.run(v))
    self.assertEqual(1, sess.run(w))
def testWaitForSessionLocalInit(self):
  """wait_for_session returns a session with `v` and `w` both initialized.

  A separate session on the same local server runs only `v.initializer`; the
  manager is configured with `w.initializer` as its `local_init_op`.
  """
  server = server_lib.Server.create_local_server()
  with ops.Graph().as_default() as graph:
    v = variables.Variable(1, name="v")
    w = variables.Variable(
        v,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w")
    sm = session_manager.SessionManager(
        graph=graph,
        ready_op=variables.report_uninitialized_variables(),
        ready_for_local_init_op=variables.report_uninitialized_variables(
            variables.global_variables()),
        local_init_op=w.initializer)

    # Initialize v but not w
    s = session_lib.Session(server.target, graph=graph)
    s.run(v.initializer)

    sess = sm.wait_for_session(server.target, max_wait_secs=3)
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(1, sess.run(v))
    self.assertEqual(1, sess.run(w))
def testPrepareSessionFails(self):
  """prepare_session fails without a checkpoint and succeeds once it exists.

  A checkpoint is written from a first graph, its directory is renamed away,
  and a second graph's manager is asked to prepare a session with no init op:
  that must raise. After the directory is renamed back the same call must
  succeed and leave `v` initialized.
  """
  checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")
  checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2")
  # Clean each directory independently so that a missing first directory
  # does not skip removal of a stale second one.
  for stale_dir in (checkpoint_dir, checkpoint_dir2):
    try:
      gfile.DeleteRecursively(stale_dir)
    except errors.OpError:
      pass  # Ignore
  gfile.MakeDirs(checkpoint_dir)

  with ops.Graph().as_default():
    v = variables.Variable([1.0, 2.0, 3.0], name="v")
    sm = session_manager.SessionManager(
        ready_op=variables.report_uninitialized_variables())
    saver = saver_lib.Saver({"v": v})
    sess = sm.prepare_session(
        "",
        init_op=variables.global_variables_initializer(),
        saver=saver,
        checkpoint_dir=checkpoint_dir)
    self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
    checkpoint_filename = os.path.join(checkpoint_dir,
                                       "prepare_session_checkpoint")
    saver.save(sess, checkpoint_filename)

  # Create a new Graph and SessionManager and recover.
  with ops.Graph().as_default():
    # Renames the checkpoint directory.
    os.rename(checkpoint_dir, checkpoint_dir2)
    gfile.MakeDirs(checkpoint_dir)
    v = variables.Variable([6.0, 7.0, 8.0], name="v")
    with self.cached_session():
      self.assertEqual(False, variables.is_variable_initialized(v).eval())
    # Bind the new manager; previously it was constructed and discarded, so
    # the calls below silently reused the manager tied to the first graph.
    sm = session_manager.SessionManager(
        ready_op=variables.report_uninitialized_variables())
    saver = saver_lib.Saver({"v": v})
    # This should fail as there's no checkpoint within 2 seconds.
    with self.assertRaisesRegex(
        RuntimeError, "no init_op or init_fn or local_init_op was given"):
      sess = sm.prepare_session(
          "",
          init_op=None,
          saver=saver,
          checkpoint_dir=checkpoint_dir,
          wait_for_checkpoint=True,
          max_wait_secs=2)
    # Rename the checkpoint directory back.
    gfile.DeleteRecursively(checkpoint_dir)
    os.rename(checkpoint_dir2, checkpoint_dir)
    # This should succeed as there's checkpoint.
    sess = sm.prepare_session(
        "",
        init_op=None,
        saver=saver,
        checkpoint_dir=checkpoint_dir,
        wait_for_checkpoint=True,
        max_wait_secs=2)
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
def testWaitForSessionLocalInit(self):
  """wait_for_session returns a session with `v` and `w` both initialized.

  Only `v.initializer` is run explicitly (from a separate session on the same
  local server); `w.initializer` is supplied as the manager's `local_init_op`.
  """
  server = server_lib.Server.create_local_server()
  with ops.Graph().as_default() as graph:
    v = variables.Variable(1, name="v")
    w = variables.Variable(
        v,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w")
    sm = session_manager.SessionManager(
        graph=graph,
        ready_op=variables.report_uninitialized_variables(),
        ready_for_local_init_op=variables.report_uninitialized_variables(
            variables.global_variables()),
        local_init_op=w.initializer)

    # Initialize v but not w
    s = session_lib.Session(server.target, graph=graph)
    s.run(v.initializer)

    sess = sm.wait_for_session(server.target, max_wait_secs=3)
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(1, sess.run(v))
    self.assertEqual(1, sess.run(w))
def testRecoverSessionNoChkptStillRunsLocalInitOp(self):
  """recover_session runs local_init_op even when nothing is recovered.

  This test checks for backwards compatibility. In particular, we continue to
  ensure that recover_session will execute local_init_op exactly once,
  regardless of whether the session was successfully recovered.
  """
  with ops.Graph().as_default():
    w = variables.Variable(
        1,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w")
    with self.cached_session():
      self.assertEqual(False, variables.is_variable_initialized(w).eval())
    sm2 = session_manager.SessionManager(
        ready_op=variables.report_uninitialized_variables(),
        ready_for_local_init_op=None,
        local_init_op=w.initializer)
    # Try to recover session from None
    sess, initialized = sm2.recover_session(
        "", saver=None, checkpoint_dir=None)
    # Succeeds because recover_session still run local_init_op
    self.assertFalse(initialized)
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(1, sess.run(w))
def testPrepareSessionFails(self):
  """prepare_session fails without a checkpoint and succeeds once it exists.

  Uses `assert_variables_initialized()` as the ready op. A checkpoint is
  written from a first graph, its directory is renamed away, and a second
  graph's manager must fail to prepare a session with no init op. After the
  directory is renamed back the same call must succeed.
  """
  checkpoint_dir = os.path.join(self.get_temp_dir(), "prepare_session")
  checkpoint_dir2 = os.path.join(self.get_temp_dir(), "prepare_session2")
  # Clean each directory independently so that a missing first directory
  # does not skip removal of a stale second one.
  for stale_dir in (checkpoint_dir, checkpoint_dir2):
    try:
      gfile.DeleteRecursively(stale_dir)
    except errors.OpError:
      pass  # Ignore
  gfile.MakeDirs(checkpoint_dir)

  with ops.Graph().as_default():
    v = variables.Variable([1.0, 2.0, 3.0], name="v")
    sm = session_manager.SessionManager(
        ready_op=variables.assert_variables_initialized())
    saver = saver_lib.Saver({"v": v})
    sess = sm.prepare_session(
        "",
        init_op=variables.global_variables_initializer(),
        saver=saver,
        checkpoint_dir=checkpoint_dir)
    self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
    checkpoint_filename = os.path.join(checkpoint_dir,
                                       "prepare_session_checkpoint")
    saver.save(sess, checkpoint_filename)

  # Create a new Graph and SessionManager and recover.
  with ops.Graph().as_default():
    # Renames the checkpoint directory.
    os.rename(checkpoint_dir, checkpoint_dir2)
    gfile.MakeDirs(checkpoint_dir)
    v = variables.Variable([6.0, 7.0, 8.0], name="v")
    with self.cached_session():
      self.assertEqual(False, variables.is_variable_initialized(v).eval())
    # Bind the new manager; previously it was constructed and discarded, so
    # the calls below silently reused the manager tied to the first graph.
    sm = session_manager.SessionManager(
        ready_op=variables.assert_variables_initialized())
    saver = saver_lib.Saver({"v": v})
    # This should fail as there's no checkpoint within 2 seconds.
    with self.assertRaisesRegex(
        RuntimeError, "no init_op or init_fn or local_init_op was given"):
      sess = sm.prepare_session(
          "",
          init_op=None,
          saver=saver,
          checkpoint_dir=checkpoint_dir,
          wait_for_checkpoint=True,
          max_wait_secs=2)
    # Rename the checkpoint directory back.
    gfile.DeleteRecursively(checkpoint_dir)
    os.rename(checkpoint_dir2, checkpoint_dir)
    # This should succeed as there's checkpoint.
    sess = sm.prepare_session(
        "",
        init_op=None,
        saver=saver,
        checkpoint_dir=checkpoint_dir,
        wait_for_checkpoint=True,
        max_wait_secs=2)
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
def testRecoverSessionNoChkptStillRunsLocalInitOp(self):
  """recover_session runs local_init_op even when nothing is recovered.

  This test checks for backwards compatibility. In particular, we continue to
  ensure that recover_session will execute local_init_op exactly once,
  regardless of whether the session was successfully recovered.
  """
  with ops.Graph().as_default():
    w = variables.Variable(
        1,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w")
    with self.cached_session():
      self.assertEqual(False, variables.is_variable_initialized(w).eval())
    sm2 = session_manager.SessionManager(
        ready_op=variables.report_uninitialized_variables(),
        ready_for_local_init_op=None,
        local_init_op=w.initializer)
    # Try to recover session from None
    sess, initialized = sm2.recover_session(
        "", saver=None, checkpoint_dir=None)
    # Succeeds because recover_session still run local_init_op
    self.assertFalse(initialized)
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(1, sess.run(w))
def testPrepareSessionWithReadyForLocalInitOp(self):
  """prepare_session initializes both the global and the local variable.

  `v` is initialized through `init_op` and `w` through `local_init_op`; both
  must then report as initialized and evaluate to 1.
  """
  with ops.Graph().as_default():
    v = variables.Variable(1, name="v")
    w = variables.Variable(
        v,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w")
    with self.test_session():
      self.assertEqual(False, variables.is_variable_initialized(v).eval())
      self.assertEqual(False, variables.is_variable_initialized(w).eval())
    sm2 = session_manager.SessionManager(
        ready_op=variables.report_uninitialized_variables(),
        ready_for_local_init_op=variables.report_uninitialized_variables(
            variables.global_variables()),
        local_init_op=w.initializer)
    sess = sm2.prepare_session("", init_op=v.initializer)
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(1, sess.run(v))
    self.assertEqual(1, sess.run(w))
def testRecoverSessionWithReadyForLocalInitOpFailsToReadyLocal(self):
  """recover_session leaves `w` uninitialized and reports initialized=False.

  ready_for_local_init_op is report_uninitialized_variables() with no
  variable list, which causes recover_session to not run local_init_op and
  to return initialized=False.
  """
  ckpt_dir = os.path.join(
      self.get_temp_dir(),
      "recover_session_ready_for_local_init_fails_to_ready_local")
  try:
    gfile.DeleteRecursively(ckpt_dir)
  except errors.OpError:
    pass  # Ignore
  gfile.MakeDirs(ckpt_dir)

  # First graph: initialize `v` by hand and write a checkpoint holding it.
  with ops.Graph().as_default():
    v = variables.VariableV1(1, name="v")
    ready_op = variables.report_uninitialized_variables()
    first_manager = session_manager.SessionManager(ready_op=ready_op)
    saver = saver_lib.Saver({"v": v})
    sess, initialized = first_manager.recover_session(
        "", saver=saver, checkpoint_dir=ckpt_dir)
    self.assertFalse(initialized)
    sess.run(v.initializer)
    self.assertEqual(1, sess.run(v))
    save_path = os.path.join(ckpt_dir, "recover_session_checkpoint")
    saver.save(sess, save_path)

  # Second graph: recover from that checkpoint with a fresh manager.
  with ops.Graph().as_default():
    v = variables.VariableV1(2, name="v")
    w = variables.VariableV1(
        v,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w")
    with self.cached_session():
      for var in (v, w):
        self.assertEqual(False, variables.is_variable_initialized(var).eval())
    ready_op = variables.report_uninitialized_variables()
    ready_for_local_init_op = variables.report_uninitialized_variables()
    second_manager = session_manager.SessionManager(
        ready_op=ready_op,
        ready_for_local_init_op=ready_for_local_init_op,
        local_init_op=w.initializer)
    saver = saver_lib.Saver({"v": v})
    sess, initialized = second_manager.recover_session(
        "", saver=saver, checkpoint_dir=ckpt_dir)
    self.assertFalse(initialized)

    v_tensor = sess.graph.get_tensor_by_name("v:0")
    self.assertEqual(
        True, variables.is_variable_initialized(v_tensor).eval(session=sess))
    w_tensor = sess.graph.get_tensor_by_name("w:0")
    self.assertEqual(
        False, variables.is_variable_initialized(w_tensor).eval(session=sess))
    self.assertEqual(1, sess.run(v))
def testRecoverSessionWithReadyForLocalInitOpFailsToReadyLocal(self):
  """recover_session leaves `w` uninitialized and reports initialized=False.

  We use ready_for_local_init_op=tf.report_uninitialized_variables(), which
  causes recover_session to not run local_init_op, and to return
  initialized=False.
  """
  # Create a checkpoint.
  checkpoint_dir = os.path.join(
      self.get_temp_dir(),
      "recover_session_ready_for_local_init_fails_to_ready_local")
  try:
    gfile.DeleteRecursively(checkpoint_dir)
  except errors.OpError:
    pass  # Ignore
  gfile.MakeDirs(checkpoint_dir)

  with ops.Graph().as_default():
    v = variables.Variable(1, name="v")
    sm = session_manager.SessionManager(
        ready_op=variables.report_uninitialized_variables())
    saver = saver_lib.Saver({"v": v})
    sess, initialized = sm.recover_session(
        "", saver=saver, checkpoint_dir=checkpoint_dir)
    self.assertFalse(initialized)
    sess.run(v.initializer)
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(1, sess.run(v))
    saver.save(sess,
               os.path.join(checkpoint_dir, "recover_session_checkpoint"))

  # Create a new Graph and SessionManager and recover.
  with ops.Graph().as_default():
    v = variables.Variable(2, name="v")
    w = variables.Variable(
        v,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w")
    with self.cached_session():
      self.assertEqual(False, variables.is_variable_initialized(v).eval())
      self.assertEqual(False, variables.is_variable_initialized(w).eval())
    sm2 = session_manager.SessionManager(
        ready_op=variables.report_uninitialized_variables(),
        ready_for_local_init_op=variables.report_uninitialized_variables(),
        local_init_op=w.initializer)
    saver = saver_lib.Saver({"v": v})
    sess, initialized = sm2.recover_session(
        "", saver=saver, checkpoint_dir=checkpoint_dir)
    self.assertFalse(initialized)
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    self.assertEqual(
        False,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
    self.assertEqual(1, sess.run(v))
def testPrepareSessionDidNotInitLocalVariableList(self):
  """prepare_session raises when a list init_op leaves `w` uninitialized.

  `init_op` is the one-element list `[v.initializer]`; the local variable `w`
  has no initializer supplied to the manager.
  """
  with ops.Graph().as_default():
    v = variables.VariableV1(1, name="v")
    w = variables.VariableV1(
        v,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w")
    with self.cached_session():
      for var in (v, w):
        self.assertEqual(False, variables.is_variable_initialized(var).eval())
    ready_op = variables.report_uninitialized_variables()
    manager = session_manager.SessionManager(ready_op=ready_op)
    init_ops = [v.initializer]
    with self.assertRaisesRegex(RuntimeError,
                                "Init operations did not make model ready"):
      manager.prepare_session("", init_op=init_ops)
def testPrepareSessionDidNotInitLocalVariableList(self):
  """prepare_session raises when a list init_op leaves `w` uninitialized."""
  with ops.Graph().as_default():
    v = variables.Variable(1, name="v")
    w = variables.Variable(
        v,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w")
    with self.cached_session():
      self.assertEqual(False, variables.is_variable_initialized(v).eval())
      self.assertEqual(False, variables.is_variable_initialized(w).eval())
    sm2 = session_manager.SessionManager(
        ready_op=variables.report_uninitialized_variables())
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias.
    with self.assertRaisesRegex(RuntimeError,
                                "Init operations did not make model ready"):
      sm2.prepare_session("", init_op=[v.initializer])
def testPrepareSessionWithCyclicInitializer(self):
  """Regression test for initial values that involve cyclic dependencies.

  Previously Variable._build_initializer_expr would enter into an infinite
  recursion when the variable's initial_value involved cyclic dependencies.
  Here the initial value is the output of a while_loop.
  """
  with ops.Graph().as_default():
    loop_result = control_flow_ops.while_loop(
        lambda i: i < 1, lambda i: i + 1, [0])
    v = variables.Variable(array_ops.identity(loop_result), name="v")
    with self.cached_session():
      initialized_before = variables.is_variable_initialized(v).eval()
      self.assertEqual(False, initialized_before)
    ready_op = variables.report_uninitialized_variables()
    manager = session_manager.SessionManager(ready_op=ready_op)
    sess = manager.prepare_session("", init_op=v.initializer)
    self.assertEqual(1, sess.run(v))
    v_tensor = sess.graph.get_tensor_by_name("v:0")
    self.assertEqual(
        True, variables.is_variable_initialized(v_tensor).eval(session=sess))
def testPrepareSessionWithCyclicInitializer(self):
  """Regression test for initial values that involve cyclic dependencies.

  Previously Variable._build_initializer_expr would enter into an infinite
  recursion when the variable's initial_value involved cyclic dependencies.
  Here the initial value is the output of a while_loop.
  """
  with ops.Graph().as_default():
    loop_result = control_flow_ops.while_loop(
        lambda i: i < 1, lambda i: i + 1, [0])
    v = variables.VariableV1(array_ops.identity(loop_result), name="v")
    with self.cached_session():
      initialized_before = variables.is_variable_initialized(v).eval()
      self.assertEqual(False, initialized_before)
    ready_op = variables.report_uninitialized_variables()
    manager = session_manager.SessionManager(ready_op=ready_op)
    sess = manager.prepare_session("", init_op=v.initializer)
    self.assertEqual(1, sess.run(v))
    v_tensor = sess.graph.get_tensor_by_name("v:0")
    self.assertEqual(
        True, variables.is_variable_initialized(v_tensor).eval(session=sess))
def testPrepareSessionWithInsufficientReadyForLocalInitCheck(self):
  """prepare_session fails when local_init_op needs an uninitialized `v`.

  With ready_for_local_init_op=None and no init_op, the manager is given only
  `w.initializer`, whose initial value is `v`. The test expects a
  FailedPreconditionError mentioning the uninitialized value `v`.
  """
  with ops.Graph().as_default():
    v = variables.Variable(1, name="v")
    w = variables.Variable(
        v,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w")
    with self.test_session():
      self.assertEqual(False, variables.is_variable_initialized(v).eval())
      self.assertEqual(False, variables.is_variable_initialized(w).eval())
    sm2 = session_manager.SessionManager(
        ready_op=variables.report_uninitialized_variables(),
        ready_for_local_init_op=None,
        local_init_op=w.initializer)
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias.
    with self.assertRaisesRegex(errors_impl.FailedPreconditionError,
                                "Attempting to use uninitialized value v"):
      sm2.prepare_session("", init_op=None)
def testPrepareSessionWithInsufficientReadyForLocalInitCheck(self):
  """prepare_session fails when only `w.initializer` is supplied.

  With ready_for_local_init_op=None and no init_op, the test expects a
  RuntimeError reporting that the init operations did not make the model
  ready.
  """
  with ops.Graph().as_default():
    v = variables.Variable(1, name="v")
    w = variables.Variable(
        v,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w")
    with self.cached_session():
      self.assertEqual(False, variables.is_variable_initialized(v).eval())
      self.assertEqual(False, variables.is_variable_initialized(w).eval())
    sm2 = session_manager.SessionManager(
        ready_op=variables.report_uninitialized_variables(),
        ready_for_local_init_op=None,
        local_init_op=w.initializer)
    # assertRaisesRegex replaces the deprecated assertRaisesRegexp alias.
    with self.assertRaisesRegex(RuntimeError,
                                "Init operations did not make model ready.*"):
      sm2.prepare_session("", init_op=None)
def testPrepareSessionWithInsufficientReadyForLocalInitCheck(self):
  """prepare_session fails when only `w.initializer` is supplied.

  The manager gets ready_for_local_init_op=None and the call passes
  init_op=None; a RuntimeError about the model not being ready is expected.
  """
  with ops.Graph().as_default():
    v = variables.VariableV1(1, name="v")
    w = variables.VariableV1(
        v,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w")
    with self.cached_session():
      for var in (v, w):
        self.assertEqual(False, variables.is_variable_initialized(var).eval())
    ready_op = variables.report_uninitialized_variables()
    manager = session_manager.SessionManager(
        ready_op=ready_op,
        ready_for_local_init_op=None,
        local_init_op=w.initializer)
    expected_message = "Init operations did not make model ready.*"
    with self.assertRaisesRegex(RuntimeError, expected_message):
      manager.prepare_session("", init_op=None)
def testRecoverSessionFailsStillRunsLocalInitOp(self):
  """recover_session runs local_init_op although nothing is restored.

  The checkpoint directory is created empty, so `v` stays uninitialized and
  `initialized` is False, while `w` is initialized by local_init_op.
  """
  # Start from a fresh, empty checkpoint directory.
  ckpt_dir = os.path.join(
      self.get_temp_dir(),
      "recover_session_ready_for_local_init_fails_stil_run")
  try:
    gfile.DeleteRecursively(ckpt_dir)
  except errors.OpError:
    pass  # Ignore
  gfile.MakeDirs(ckpt_dir)

  # Build a graph and a SessionManager, then attempt the recovery.
  with ops.Graph().as_default():
    v = variables.VariableV1(2, name="v")
    w = variables.VariableV1(
        1,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w")
    with self.cached_session():
      for var in (v, w):
        self.assertEqual(False, variables.is_variable_initialized(var).eval())
    ready_op = variables.report_uninitialized_variables()
    manager = session_manager.SessionManager(
        ready_op=ready_op,
        ready_for_local_init_op=None,
        local_init_op=w.initializer)
    saver = saver_lib.Saver({"v": v})
    sess, initialized = manager.recover_session(
        "",
        saver=saver,
        checkpoint_dir=ckpt_dir,
        wait_for_checkpoint=False)
    self.assertFalse(initialized)

    v_tensor = sess.graph.get_tensor_by_name("v:0")
    self.assertEqual(
        False, variables.is_variable_initialized(v_tensor).eval(session=sess))
    w_tensor = sess.graph.get_tensor_by_name("w:0")
    self.assertEqual(
        True, variables.is_variable_initialized(w_tensor).eval(session=sess))
    self.assertEqual(1, sess.run(w))
def testRecoverSessionFailsStillRunsLocalInitOp(self):
  """recover_session runs local_init_op although nothing is restored.

  The checkpoint directory is created empty, so `v` stays uninitialized and
  `initialized` is False, while `w` is initialized by local_init_op.
  """
  # Prepare an empty checkpoint directory (no checkpoint is written to it).
  checkpoint_dir = os.path.join(
      self.get_temp_dir(),
      "recover_session_ready_for_local_init_fails_stil_run")
  try:
    gfile.DeleteRecursively(checkpoint_dir)
  except errors.OpError:
    pass  # Ignore
  gfile.MakeDirs(checkpoint_dir)

  # Create a new Graph and SessionManager and recover.
  with ops.Graph().as_default():
    v = variables.Variable(2, name="v")
    w = variables.Variable(
        1,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w")
    with self.cached_session():
      self.assertEqual(False, variables.is_variable_initialized(v).eval())
      self.assertEqual(False, variables.is_variable_initialized(w).eval())
    sm2 = session_manager.SessionManager(
        ready_op=variables.report_uninitialized_variables(),
        ready_for_local_init_op=None,
        local_init_op=w.initializer)
    saver = saver_lib.Saver({"v": v})
    sess, initialized = sm2.recover_session(
        "",
        saver=saver,
        checkpoint_dir=checkpoint_dir,
        wait_for_checkpoint=False)
    self.assertFalse(initialized)
    self.assertEqual(
        False,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(1, sess.run(w))
def testRecoverSession(self):
  """recover_session restores `v` from a checkpoint written by another graph.

  The first graph finds no checkpoint (initialized is False), initializes `v`
  to 1 and saves it. The second graph declares `v` with initial value 2 and
  must recover the saved value 1 with initialized=True.
  """
  # Create a checkpoint.
  checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
  try:
    gfile.DeleteRecursively(checkpoint_dir)
  except errors.OpError:
    pass  # Ignore
  gfile.MakeDirs(checkpoint_dir)

  with ops.Graph().as_default():
    v = variables.Variable(1, name="v")
    sm = session_manager.SessionManager(
        ready_op=variables.assert_variables_initialized())
    saver = saver_lib.Saver({"v": v})
    sess, initialized = sm.recover_session(
        "", saver=saver, checkpoint_dir=checkpoint_dir)
    self.assertFalse(initialized)
    sess.run(v.initializer)
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(1, sess.run(v))
    saver.save(sess,
               os.path.join(checkpoint_dir, "recover_session_checkpoint"))

  # Create a new Graph and SessionManager and recover.
  with ops.Graph().as_default():
    v = variables.Variable(2, name="v")
    with self.cached_session():
      self.assertEqual(False, variables.is_variable_initialized(v).eval())
    sm2 = session_manager.SessionManager(
        ready_op=variables.assert_variables_initialized())
    saver = saver_lib.Saver({"v": v})
    sess, initialized = sm2.recover_session(
        "", saver=saver, checkpoint_dir=checkpoint_dir)
    self.assertTrue(initialized)
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    self.assertEqual(1, sess.run(v))
def _test_recovered_variable(self,
                             checkpoint_dir=None,
                             checkpoint_filename_with_path=None):
  """Recovers `v` in a fresh graph and checks that it was restored to 1.

  Args:
    checkpoint_dir: Forwarded to `recover_session` as `checkpoint_dir`.
    checkpoint_filename_with_path: Forwarded to `recover_session` as
      `checkpoint_filename_with_path`.
  """
  # Create a new Graph and SessionManager and recover from a checkpoint.
  with ops.Graph().as_default():
    v = variables.Variable(2, name="v")
    with session_lib.Session():
      self.assertEqual(False, variables.is_variable_initialized(v).eval())
    sm2 = session_manager.SessionManager(
        ready_op=variables.report_uninitialized_variables())
    saver = saver_lib.Saver({"v": v})
    sess, initialized = sm2.recover_session(
        "",
        saver=saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path)
    self.assertTrue(initialized)
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(1, sess.run(v))
def _test_recovered_variable(self,
                             checkpoint_dir=None,
                             checkpoint_filename_with_path=None):
  """Recovers `v` in a fresh graph and checks that it was restored to 1.

  Args:
    checkpoint_dir: Forwarded to `recover_session` as `checkpoint_dir`.
    checkpoint_filename_with_path: Forwarded to `recover_session` as
      `checkpoint_filename_with_path`.
  """
  with ops.Graph().as_default():
    # The graph declares `v` with initial value 2; recovery should yield 1.
    v = variables.VariableV1(2, name="v")
    with session_lib.Session():
      initialized_before = variables.is_variable_initialized(v).eval()
      self.assertEqual(False, initialized_before)
    ready_op = variables.report_uninitialized_variables()
    manager = session_manager.SessionManager(ready_op=ready_op)
    saver = saver_lib.Saver({"v": v})
    recover_kwargs = dict(
        saver=saver,
        checkpoint_dir=checkpoint_dir,
        checkpoint_filename_with_path=checkpoint_filename_with_path)
    sess, initialized = manager.recover_session("", **recover_kwargs)
    self.assertTrue(initialized)
    v_tensor = sess.graph.get_tensor_by_name("v:0")
    self.assertEqual(
        True, variables.is_variable_initialized(v_tensor).eval(session=sess))
    self.assertEqual(1, sess.run(v))
def testRecoverSession(self):
  """recover_session restores `v` from a checkpoint written by another graph.

  The first graph finds no checkpoint (initialized is False), initializes `v`
  to 1 and saves it. The second graph declares `v` with initial value 2 and
  must recover the saved value 1 with initialized=True.
  """
  # Create a checkpoint.
  checkpoint_dir = os.path.join(self.get_temp_dir(), "recover_session")
  try:
    gfile.DeleteRecursively(checkpoint_dir)
  except errors.OpError:
    pass  # Ignore
  gfile.MakeDirs(checkpoint_dir)

  with ops.Graph().as_default():
    v = variables.Variable(1, name="v")
    sm = session_manager.SessionManager(
        ready_op=variables.assert_variables_initialized())
    saver = saver_lib.Saver({"v": v})
    sess, initialized = sm.recover_session(
        "", saver=saver, checkpoint_dir=checkpoint_dir)
    self.assertFalse(initialized)
    sess.run(v.initializer)
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(1, sess.run(v))
    saver.save(sess,
               os.path.join(checkpoint_dir, "recover_session_checkpoint"))

  # Create a new Graph and SessionManager and recover.
  with ops.Graph().as_default():
    v = variables.Variable(2, name="v")
    with self.cached_session():
      self.assertEqual(False, variables.is_variable_initialized(v).eval())
    sm2 = session_manager.SessionManager(
        ready_op=variables.assert_variables_initialized())
    saver = saver_lib.Saver({"v": v})
    sess, initialized = sm2.recover_session(
        "", saver=saver, checkpoint_dir=checkpoint_dir)
    self.assertTrue(initialized)
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    self.assertEqual(1, sess.run(v))
def _wait_for_variable_initialization(session):
  """Utility to wait for variables to be initialized.

  Collects the graph variables not yet flagged `_keras_initialized` and
  repeatedly runs an initialization check for them in `session` until every
  one reports as initialized. Each variable is flagged `_keras_initialized`
  once its check passes.

  Args:
    session: Session used to run the initialization checks.
  """
  all_variables = K._get_variables(K.get_graph())  # pylint: disable=protected-access
  candidate_vars = [
      v for v in all_variables if not getattr(v, '_keras_initialized', False)
  ]
  if not candidate_vars:
    return

  # Build the check ops once; creating them inside the polling loop would add
  # a new set of ops to the graph on every iteration.
  init_checks = [variables.is_variable_initialized(v) for v in candidate_vars]
  while True:
    is_initialized = session.run(init_checks)
    all_ready = True
    for flag, v in zip(is_initialized, candidate_vars):
      if flag:
        # Only flag variables that really are initialized, so an error on a
        # later poll cannot leave an uninitialized variable marked as done.
        v._keras_initialized = True  # pylint: disable=protected-access
      else:
        all_ready = False
    if all_ready:
      break
def get_session(config=None, checkpoint_dir=None, checkpoint_path=None):
  """Return the module-level session, creating it on the first call.

  On the first call a `BasicSessionCreator` builds the session from `config`
  (defaulting to a ConfigProto with allow_soft_placement=True) and the two
  checkpoint arguments. On later calls the cached session is reused and a
  warning is issued if a `config` is passed. When neither checkpoint argument
  is given, global variables not yet flagged `_initialized` are checked and
  the uninitialized ones are initialized.

  Args:
    config: Optional session config used only when the session is created.
    checkpoint_dir: Optional value forwarded to the session creator.
    checkpoint_path: Optional value forwarded to the session creator.

  Returns:
    The module-level session.
  """
  global _SESSION
  if _SESSION is not None:
    if config is not None:
      import warnings
      warnings.warn("Session has already been created without specific"
                    " session_config: %s, you should invoke method"
                    " `get_session` manually in the beginning of your"
                    " program." % str(config))
  else:
    from tensorlib.training.sessions.core import BasicSessionCreator
    creator = BasicSessionCreator(
        config=config or config_pb2.ConfigProto(allow_soft_placement=True),
        checkpoint_dir=checkpoint_dir,
        checkpoint_path=checkpoint_path)
    _SESSION = creator.create_session()
  session = _SESSION

  # With a checkpoint argument present, variable initialization is skipped.
  if checkpoint_dir is not None or checkpoint_path is not None:
    return session

  with session.graph.as_default():
    pending = [
        var for var in variables.global_variables()
        if not getattr(var, "_initialized", False)
    ]
    if pending:
      flags = session.run(
          [variables.is_variable_initialized(var) for var in pending])
      needs_init = []
      for already_set, var in zip(flags, pending):
        if not already_set:
          needs_init.append(var)
        # Every checked variable is flagged, whether or not it needed init.
        var._initialized = True
      if needs_init:
        session.run(variables.variables_initializer(needs_init))
  return session
def testPrepareSessionWithPartialInitOp(self):
  """prepare_session with only local init ops leaves the globals untouched.

  The manager receives the initializers of the four local variables as
  `local_init_op` and no `init_op`. Afterwards `v` and `v_res` must still be
  uninitialized while `w`, `x`, `w_res` and `x_res` are initialized with the
  values 1 and 3.
  """
  with ops.Graph().as_default():
    v = variables.Variable(1, name="v")
    w = variables.Variable(
        v,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w")
    x = variables.Variable(
        3 * v,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="x")
    # TODO(b/70206927): Use ResourceVariables once they are handled properly.
    v_res = variables.Variable(1, name="v_res")
    w_res = variables.Variable(
        v_res,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w_res")
    x_res = variables.Variable(
        3 * v_res,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="x_res")
    with self.cached_session():
      self.assertEqual(False, variables.is_variable_initialized(v).eval())
      self.assertEqual(False, variables.is_variable_initialized(w).eval())
      self.assertEqual(False, variables.is_variable_initialized(x).eval())
      self.assertEqual(False, variables.is_variable_initialized(v_res).eval())
      self.assertEqual(False, variables.is_variable_initialized(w_res).eval())
      self.assertEqual(False, variables.is_variable_initialized(x_res).eval())
    sm2 = session_manager.SessionManager(local_init_op=[
        w.initializer, x.initializer, w_res.initializer, x_res.initializer
    ])
    sess = sm2.prepare_session("", init_op=None)
    self.assertEqual(
        False,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("x:0")).eval(session=sess))
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(1, sess.run(w))
    self.assertEqual(3, sess.run(x))
    self.assertEqual(
        False,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("v_res:0")).eval(session=sess))
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("w_res:0")).eval(session=sess))
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("x_res:0")).eval(session=sess))
    self.assertEqual(1, sess.run(w_res))
    self.assertEqual(3, sess.run(x_res))
def testPrepareSessionWithPartialInitOp(self):
  """prepare_session with only local init ops leaves the globals untouched.

  The manager receives the initializers of the four local variables as
  `local_init_op` and no `init_op`. Afterwards `v` and `v_res` must still be
  uninitialized while `w`, `x`, `w_res` and `x_res` are initialized with the
  values 1 and 3.
  """
  with ops.Graph().as_default():
    v = variables.Variable(1, name="v")
    w = variables.Variable(
        v,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w")
    x = variables.Variable(
        3 * v,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="x")
    # TODO(b/70206927): Use ResourceVariables once they are handled properly.
    v_res = variables.Variable(1, name="v_res")
    w_res = variables.Variable(
        v_res,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="w_res")
    x_res = variables.Variable(
        3 * v_res,
        trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES],
        name="x_res")
    with self.cached_session():
      self.assertEqual(False, variables.is_variable_initialized(v).eval())
      self.assertEqual(False, variables.is_variable_initialized(w).eval())
      self.assertEqual(False, variables.is_variable_initialized(x).eval())
      self.assertEqual(False, variables.is_variable_initialized(v_res).eval())
      self.assertEqual(False, variables.is_variable_initialized(w_res).eval())
      self.assertEqual(False, variables.is_variable_initialized(x_res).eval())
    sm2 = session_manager.SessionManager(local_init_op=[
        w.initializer, x.initializer, w_res.initializer, x_res.initializer
    ])
    sess = sm2.prepare_session("", init_op=None)
    self.assertEqual(
        False,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("v:0")).eval(session=sess))
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("w:0")).eval(session=sess))
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("x:0")).eval(session=sess))
    # assertEqual replaces the deprecated assertEquals alias.
    self.assertEqual(1, sess.run(w))
    self.assertEqual(3, sess.run(x))
    self.assertEqual(
        False,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("v_res:0")).eval(session=sess))
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("w_res:0")).eval(session=sess))
    self.assertEqual(
        True,
        variables.is_variable_initialized(
            sess.graph.get_tensor_by_name("x_res:0")).eval(session=sess))
    self.assertEqual(1, sess.run(w_res))
    self.assertEqual(3, sess.run(x_res))