コード例 #1
0
ファイル: base_test.py プロジェクト: shen-zc/tensorflow
 def testAssertConsumedWithUnusedPythonState(self):
     """Restore a checkpoint whose saved object defined `get_config`.

     The saved object carries a `get_config` returning an empty dict; the
     restore target is a plain `base.Trackable`. The test expects
     `assert_consumed()` on the restore status not to raise.
     """
     source_obj = base.Trackable()
     source_obj.get_config = lambda: {}
     source_ckpt = util.Checkpoint(obj=source_obj)
     checkpoint_path = source_ckpt.save(
         os.path.join(self.get_temp_dir(), "ckpt"))
     target_ckpt = util.Checkpoint(obj=base.Trackable())
     target_ckpt.restore(checkpoint_path).assert_consumed()
コード例 #2
0
    def test_registration(self):
        """Exercise a predicate-based checkpoint saver registration.

        Registers "Testing.test_predicate", whose predicate matches objects
        having a `check_attr` attribute, then checks saver-name lookup, the
        registered save/restore callables, and the errors raised by
        `validate_restore_function`.
        """
        registered_name = "Testing.test_predicate"
        registration.register_checkpoint_saver(
            package="Testing",
            name="test_predicate",
            predicate=lambda obj: hasattr(obj, "check_attr"),
            save_fn=lambda: "save",
            restore_fn=lambda: "restore")

        # Before `check_attr` exists, no registered saver name is expected.
        tracked = base.Trackable()
        self.assertIsNone(registration.get_registered_saver_name(tracked))

        tracked.check_attr = 1
        saver_name = registration.get_registered_saver_name(tracked)
        self.assertEqual(saver_name, registered_name)

        save_fn = registration.get_save_function(saver_name)
        self.assertEqual(save_fn(), "save")
        restore_fn = registration.get_restore_function(saver_name)
        self.assertEqual(restore_fn(), "restore")

        # A matching object validates; an unknown saver name, or an object
        # the predicate rejects, raises ValueError.
        registration.validate_restore_function(tracked, registered_name)
        with self.assertRaisesRegex(ValueError, "saver cannot be found"):
            registration.validate_restore_function(tracked, "Invalid.name")
        unmatched = base.Trackable()
        with self.assertRaisesRegex(ValueError, "saver cannot be used"):
            registration.validate_restore_function(unmatched, registered_name)
コード例 #3
0
 def test_descendants(self):
   """`descendants()` lists the root first, then its single tracked child."""
   root = base.Trackable()
   leaf = base.Trackable()
   root._track_trackable(leaf, name="leaf")
   descendants = trackable_view.TrackableView(root).descendants()
   # Compare the count by value: `assertIs` on ints only passes because
   # CPython happens to cache small integer objects.
   self.assertEqual(2, len(descendants))
   self.assertIs(root, descendants[0])
   self.assertIs(leaf, descendants[1])
コード例 #4
0
 def test_children(self):
     """`TrackableView.children` yields exactly one dependency named "leaf"."""
     parent = base.Trackable()
     child = base.Trackable()
     parent._track_trackable(child, name="leaf")
     children = trackable_view.TrackableView.children(parent)
     # Single-element unpacking fails unless there is exactly one child.
     [(only_name, only_dependency)] = children.items()
     self.assertIs(child, only_dependency)
     self.assertEqual("leaf", only_name)
コード例 #5
0
 def test_all_nodes(self):
     """`all_nodes()` returns the root first, then its single tracked child."""
     root = base.Trackable()
     leaf = base.Trackable()
     root._track_trackable(leaf, name="leaf")
     all_nodes = trackable_view.TrackableView(root).all_nodes()
     # Compare the count by value: `assertIs` on ints only passes because
     # CPython happens to cache small integer objects.
     self.assertEqual(2, len(all_nodes))
     self.assertIs(root, all_nodes[0])
     self.assertIs(leaf, all_nodes[1])
コード例 #6
0
 def test_standard_saveable_name(self):
     """Check `extract_saveable_name` against two checkpoint keys.

     Each case is (expected name, key passed in). The second case shows that
     a plain `ATTRIBUTES/` path segment is not treated like `.ATTRIBUTES/`.
     """
     cases = [
         ("object_path/.ATTRIBUTES/", "object_path/.ATTRIBUTES/123"),
         ("object/path/ATTRIBUTES/.ATTRIBUTES/",
          "object/path/ATTRIBUTES/.ATTRIBUTES/"),
     ]
     for expected_name, checkpoint_key in cases:
         self.assertEqual(
             expected_name,
             checkpoint_util.extract_saveable_name(base.Trackable(),
                                                   checkpoint_key))
コード例 #7
0
 def test_all_nodes(self):
   """Check the first two node ids from `CheckpointView.descendants()`."""
   parent = base.Trackable()
   child = base.Trackable()
   parent._track_trackable(child, name="leaf")
   checkpoint = trackable_utils.Checkpoint(root=parent)
   save_path = checkpoint.save(
       os.path.join(self.get_temp_dir(), "root_ckpt"))
   view = checkpoint_view.CheckpointView(save_path)
   node_ids = view.descendants()
   # NOTE(review): the expected order (1 before 0) depends on how
   # CheckpointView.descendants walks the object graph; confirm against it.
   self.assertEqual(1, node_ids[0])
   self.assertEqual(0, node_ids[1])
コード例 #8
0
 def test_children(self):
   """The first child listed for checkpoint node 0 is ("leaf", 1)."""
   parent = base.Trackable()
   child = base.Trackable()
   parent._track_trackable(child, name="leaf")
   checkpoint = trackable_utils.Checkpoint(root=parent)
   save_path = checkpoint.save(
       os.path.join(self.get_temp_dir(), "root_ckpt"))
   view = checkpoint_view.CheckpointView(save_path)
   first_child = next(iter(view.children(0).items()))
   child_name, child_node_id = first_child
   self.assertEqual("leaf", child_name)
   self.assertEqual(1, child_node_id)
コード例 #9
0
ファイル: base_test.py プロジェクト: whoozle/tensorflow
 def testAssertConsumedFailsWithUsedPythonState(self):
   """`assert_consumed` flags a saved "foo_attr" the target cannot restore.

   The saved object exposes a "foo_attr" saveable through an overridden
   `_gather_saveables_for_checkpoint`. The restore target is a bare
   `base.Trackable`, so the AssertionError is expected to mention "foo_attr".
   """
   source_obj = base.Trackable()
   saveable_factory = functools.partial(
       base.PythonStringStateSaveable,
       state_callback=lambda: "",
       restore_callback=lambda x: None)
   attributes = {"foo_attr": saveable_factory}
   source_obj._gather_saveables_for_checkpoint = lambda: attributes
   source_ckpt = util.Checkpoint(obj=source_obj)
   checkpoint_path = source_ckpt.save(
       os.path.join(self.get_temp_dir(), "ckpt"))
   target_ckpt = util.Checkpoint(obj=base.Trackable())
   load_status = target_ckpt.restore(checkpoint_path)
   with self.assertRaisesRegex(AssertionError, "foo_attr"):
     load_status.assert_consumed()
コード例 #10
0
ファイル: base_test.py プロジェクト: whoozle/tensorflow
 def testOverwrite(self):
   """Re-tracking an existing dependency name needs `overwrite=True`."""
   root = base.Trackable()

   def single_child():
     # Unpacking into one pair fails unless root has exactly one child.
     (name, dependency), = root._trackable_children().items()
     return name, dependency

   first_leaf = base.Trackable()
   root._track_trackable(first_leaf, name="leaf")
   child_name, child = single_child()
   self.assertIs(first_leaf, child)
   self.assertEqual("leaf", child_name)

   replacement = base.Trackable()
   with self.assertRaises(ValueError):
     root._track_trackable(replacement, name="leaf")
   root._track_trackable(replacement, name="leaf", overwrite=True)
   child_name, child = single_child()
   self.assertIs(replacement, child)
   self.assertEqual("leaf", child_name)
コード例 #11
0
    def test_convert_no_saveable(self):
        """Wrap a bare Trackable in `SaveableCompatibilityConverter`.

        Serializing is expected to produce an empty result. Restoring from an
        empty dict succeeds. Restoring from a non-empty dict raises
        ValueError.
        """
        trackable_obj = base.Trackable()
        converter = saveable_object_util.SaveableCompatibilityConverter(
            trackable_obj)
        serialized = converter._serialize_to_tensors()
        self.assertEmpty(serialized)
        converter._restore_from_tensors({})

        with self.assertRaisesRegex(ValueError, "Could not restore object"):
            converter._restore_from_tensors({"": 0})
コード例 #12
0
ファイル: base_test.py プロジェクト: whoozle/tensorflow
 def testAddVariableOverwrite(self):
   """Re-adding variable "v" to the same Trackable requires overwrite=True.

   The first add happens outside any explicit graph context. A replacement
   with overwrite=True is added inside a fresh graph. A further re-add with
   overwrite=False, inside another fresh graph, must raise ValueError.
   """
   root = base.Trackable()

   def add_v(**extra_kwargs):
     # All adds share the name, shape and getter; only `overwrite` varies.
     return root._add_variable_with_custom_getter(
         name="v", shape=[], getter=variable_scope.get_variable,
         **extra_kwargs)

   first_var = add_v()
   self.assertEqual([root, first_var], util.list_objects(root))
   with ops.Graph().as_default():
     second_var = add_v(overwrite=True)
     self.assertEqual([root, second_var], util.list_objects(root))
   with ops.Graph().as_default():
     with self.assertRaisesRegex(ValueError,
                                 "already declared as a dependency"):
       add_v(overwrite=False)