コード例 #1
0
ファイル: base_test.py プロジェクト: hemphillmc/tensorflow
 def testAssertConsumedWithUnusedPythonState(self):
     """Restoring into an object without the saved `get_config` still passes.

     The object that is saved has a `get_config` callable attached; the object
     restored into is a plain `CheckpointableBase`. The test expects
     `assert_consumed()` on the restore status to complete without raising.
     """
     source = base.CheckpointableBase()
     source.get_config = lambda: {}
     writer = util.Checkpoint(obj=source)
     prefix = os.path.join(self.get_temp_dir(), "ckpt")
     checkpoint_path = writer.save(prefix)

     reader = util.Checkpoint(obj=base.CheckpointableBase())
     load_status = reader.restore(checkpoint_path)
     load_status.assert_consumed()
コード例 #2
0
 def testOverwrite(self):
   """Re-tracking an existing dependency name requires `overwrite=True`.

   Tracks one object under "leaf", checks that tracking a second object under
   the same name raises ValueError, then tracks it with overwrite=True and
   checks that it is now the only dependency, still named "leaf".
   """
   root = base.CheckpointableBase()

   def sole_dependency():
     # Single-element unpacking raises unless exactly one dependency exists.
     [(dep_name, dep_obj)] = root._checkpoint_dependencies
     return dep_name, dep_obj

   first_leaf = base.CheckpointableBase()
   root._track_checkpointable(first_leaf, name="leaf")
   name, dependency = sole_dependency()
   self.assertIs(first_leaf, dependency)
   self.assertEqual("leaf", name)

   replacement = base.CheckpointableBase()
   with self.assertRaises(ValueError):
     root._track_checkpointable(replacement, name="leaf")
   root._track_checkpointable(replacement, name="leaf", overwrite=True)
   name, dependency = sole_dependency()
   self.assertIs(replacement, dependency)
   self.assertEqual("leaf", name)
コード例 #3
0
ファイル: base_test.py プロジェクト: hemphillmc/tensorflow
 def testAssertConsumedFailsWithUsedPythonState(self):
     """`assert_consumed()` fails when saved Python state is not restored.

     The saved object returns a "foo_attr" entry (a `PythonStringStateSaveable`
     factory) from `_gather_saveables_for_checkpoint`. Restoring into a plain
     `CheckpointableBase`, whose `_gather_saveables_for_checkpoint` is not
     overridden, must make `assert_consumed()` raise an AssertionError whose
     message matches "foo_attr".
     """
     source = base.CheckpointableBase()
     # `functools.partial` pre-binds the callbacks; the checkpoint machinery
     # presumably supplies the remaining constructor arguments — TODO confirm.
     saveable_factories = {
         "foo_attr": functools.partial(
             base.PythonStringStateSaveable,
             state_callback=lambda: "",
             restore_callback=lambda x: None),
     }
     source._gather_saveables_for_checkpoint = lambda: saveable_factories

     writer = util.Checkpoint(obj=source)
     checkpoint_path = writer.save(os.path.join(self.get_temp_dir(), "ckpt"))
     reader = util.Checkpoint(obj=base.CheckpointableBase())
     load_status = reader.restore(checkpoint_path)
     # NOTE(review): assertRaisesRegexp is a deprecated alias of
     # assertRaisesRegex (removed from unittest in Python 3.12); switch once
     # Python 2 support is not needed — confirm the base TestCase provides it.
     with self.assertRaisesRegexp(AssertionError, "foo_attr"):
         load_status.assert_consumed()
コード例 #4
0
ファイル: base_test.py プロジェクト: kuo1220/verbose-barnacle
 def testAddVariableOverwrite(self):
   """Re-adding variable "v" requires overwrite=True.

   Adds "v" to a root object, then re-adds it in a fresh graph with
   overwrite=True, after which `util.list_objects(root)` lists the new
   variable instead of the old one. Finally, re-adding it in another fresh
   graph with overwrite=False must raise a ValueError about the existing
   dependency. Fresh graphs are presumably used to avoid a graph-level
   variable name clash for "v" — confirm.
   """
   root = base.CheckpointableBase()

   def add_v(**extra_kwargs):
     # Every call shares the same name, shape and getter; only `overwrite`
     # varies between calls.
     return root._add_variable_with_custom_getter(
         name="v", shape=[], getter=variable_scope.get_variable,
         **extra_kwargs)

   first = add_v()
   self.assertEqual([root, first], util.list_objects(root))

   with ops.Graph().as_default():
     second = add_v(overwrite=True)
     self.assertEqual([root, second], util.list_objects(root))

   # NOTE(review): assertRaisesRegexp is a deprecated alias of
   # assertRaisesRegex (removed from unittest in Python 3.12); switch once
   # Python 2 support is not needed — confirm the base TestCase provides it.
   with ops.Graph().as_default(), self.assertRaisesRegexp(
       ValueError, "already declared as a dependency"):
     add_v(overwrite=False)