def testAssertConsumedWithUnusedPythonState(self):
  """Checks that a `get_config` hook does not break `assert_consumed()`.

  The saved object carries a `get_config` method returning an empty dict.
  The restore target is a plain Trackable. The test expects
  `assert_consumed()` on the restore status not to raise.
  """
  source_obj = base.Trackable()
  source_obj.get_config = lambda: {}
  saver = util.Checkpoint(obj=source_obj)
  save_path = saver.save(os.path.join(self.get_temp_dir(), "ckpt"))
  # Keep the restoring Checkpoint bound to a name so it stays alive while
  # its restore status is inspected.
  target_ckpt = util.Checkpoint(obj=base.Trackable())
  load_status = target_ckpt.restore(save_path)
  load_status.assert_consumed()
def test_registration(self):
  """Registers a predicate-based checkpoint saver and exercises lookups.

  Covers name resolution before and after the predicate matches, retrieval
  of the save/restore functions, and the two `validate_restore_function`
  failure modes: an unknown saver name, and an object the predicate rejects.
  """
  # NOTE(review): this registration is global and is never undone here;
  # confirm that re-running the test in the same process is handled.
  registration.register_checkpoint_saver(
      package="Testing",
      name="test_predicate",
      predicate=lambda x: hasattr(x, "check_attr"),
      save_fn=lambda: "save",
      restore_fn=lambda: "restore")

  candidate = base.Trackable()
  # The predicate only matches once `check_attr` is present.
  self.assertIsNone(registration.get_registered_saver_name(candidate))
  candidate.check_attr = 1
  registered_name = registration.get_registered_saver_name(candidate)
  self.assertEqual(registered_name, "Testing.test_predicate")

  save_fn = registration.get_save_function(registered_name)
  self.assertEqual(save_fn(), "save")
  restore_fn = registration.get_restore_function(registered_name)
  self.assertEqual(restore_fn(), "restore")

  # A matching object and a known saver name validate cleanly.
  registration.validate_restore_function(candidate, "Testing.test_predicate")
  with self.assertRaisesRegex(ValueError, "saver cannot be found"):
    registration.validate_restore_function(candidate, "Invalid.name")

  # An object lacking `check_attr` is rejected by the predicate.
  non_matching = base.Trackable()
  with self.assertRaisesRegex(ValueError, "saver cannot be used"):
    registration.validate_restore_function(
        non_matching, "Testing.test_predicate")
def test_descendants(self):
  """Checks that TrackableView.descendants() returns the root, then its child.

  A root with a single tracked child ("leaf") is expected to yield exactly
  two descendants, with the root object first and the leaf second.
  """
  root = base.Trackable()
  leaf = base.Trackable()
  root._track_trackable(leaf, name="leaf")
  descendants = trackable_view.TrackableView(root).descendants()
  # Compare the count by value: `assertIs(2, ...)` only passed because
  # CPython caches small integers, which is an implementation detail.
  self.assertEqual(2, len(descendants))
  # Identity checks are intended here: the view must return the same objects.
  self.assertIs(root, descendants[0])
  self.assertIs(leaf, descendants[1])
def test_children(self):
  """Checks that TrackableView.children() maps the tracked name to the child.

  The root tracks exactly one child, so the mapping must have exactly one
  ("leaf", leaf) entry. The unpacking below raises if there are more or fewer.
  """
  root = base.Trackable()
  leaf = base.Trackable()
  root._track_trackable(leaf, name="leaf")
  child_map = trackable_view.TrackableView.children(root)
  [(child_name, child_obj)] = child_map.items()
  self.assertIs(leaf, child_obj)
  self.assertEqual("leaf", child_name)
def test_all_nodes(self):
  """Checks that TrackableView.all_nodes() returns the root, then its child.

  A root with a single tracked child ("leaf") is expected to yield exactly
  two nodes, with the root object first and the leaf second.
  """
  root = base.Trackable()
  leaf = base.Trackable()
  root._track_trackable(leaf, name="leaf")
  all_nodes = trackable_view.TrackableView(root).all_nodes()
  # Compare the count by value: `assertIs(2, ...)` only passed because
  # CPython caches small integers, which is an implementation detail.
  self.assertEqual(2, len(all_nodes))
  # Identity checks are intended here: the view must return the same objects.
  self.assertIs(root, all_nodes[0])
  self.assertIs(leaf, all_nodes[1])
def test_standard_saveable_name(self):
  """Checks that extract_saveable_name truncates after the `.ATTRIBUTES/` marker.

  Anything after `.ATTRIBUTES/` (here "123") is expected to be dropped. A
  plain `ATTRIBUTES/` path component (no leading dot) is not treated as the
  marker.
  """
  # (expected saveable name, full checkpoint key) pairs.
  cases = (
      ("object_path/.ATTRIBUTES/", "object_path/.ATTRIBUTES/123"),
      ("object/path/ATTRIBUTES/.ATTRIBUTES/",
       "object/path/ATTRIBUTES/.ATTRIBUTES/"),
  )
  for expected_name, checkpoint_key in cases:
    extracted = checkpoint_util.extract_saveable_name(
        base.Trackable(), checkpoint_key)
    self.assertEqual(expected_name, extracted)
def test_all_nodes(self):
  """Checks the node ids returned by CheckpointView.descendants().

  A root with one tracked child ("leaf") is saved through a Checkpoint. The
  view over the written file is expected to list node id 1 first and node
  id 0 second.
  """
  # NOTE(review): despite its name this test calls `descendants()`, and the
  # expected 1-before-0 ordering depends on how that method walks the
  # object graph — confirm against the CheckpointView implementation.
  root = base.Trackable()
  leaf = base.Trackable()
  root._track_trackable(leaf, name="leaf")
  checkpoint = trackable_utils.Checkpoint(root=root)
  written_path = checkpoint.save(
      os.path.join(self.get_temp_dir(), "root_ckpt"))
  node_ids = checkpoint_view.CheckpointView(written_path).descendants()
  self.assertEqual(1, node_ids[0])
  self.assertEqual(0, node_ids[1])
def test_children(self):
  """Checks that CheckpointView.children(0) exposes the tracked child.

  After saving a root with one child ("leaf"), the first child entry of node 0
  is expected to be named "leaf" and to have node id 1.
  """
  root = base.Trackable()
  leaf = base.Trackable()
  root._track_trackable(leaf, name="leaf")
  checkpoint = trackable_utils.Checkpoint(root=root)
  written_path = checkpoint.save(
      os.path.join(self.get_temp_dir(), "root_ckpt"))
  view = checkpoint_view.CheckpointView(written_path)
  first_entry = next(iter(view.children(0).items()))
  child_name, child_id = first_entry
  self.assertEqual("leaf", child_name)
  self.assertEqual(1, child_id)
def testAssertConsumedFailsWithUsedPythonState(self):
  """Checks that unmatched Python state makes `assert_consumed()` fail.

  The saved object contributes a "foo_attr" PythonStringStateSaveable. The
  restore target is a plain Trackable without that saveable, so the test
  expects an AssertionError whose message mentions "foo_attr".
  """
  source_obj = base.Trackable()
  saveable_factories = {
      "foo_attr": functools.partial(
          base.PythonStringStateSaveable,
          state_callback=lambda: "",
          restore_callback=lambda x: None),
  }
  source_obj._gather_saveables_for_checkpoint = lambda: saveable_factories
  saver = util.Checkpoint(obj=source_obj)
  save_path = saver.save(os.path.join(self.get_temp_dir(), "ckpt"))
  # Keep the restoring Checkpoint bound to a name so it outlives the status.
  target_ckpt = util.Checkpoint(obj=base.Trackable())
  load_status = target_ckpt.restore(save_path)
  with self.assertRaisesRegex(AssertionError, "foo_attr"):
    load_status.assert_consumed()
def testOverwrite(self):
  """Checks that re-tracking an existing dependency name requires overwrite=True.

  Tracking a second object under "leaf" without `overwrite` is expected to
  raise ValueError. With `overwrite=True` the new object replaces the old one,
  and the root still has exactly one child.
  """
  root = base.Trackable()

  def sole_child():
    # Unpacking fails unless the root has exactly one child.
    [(child_name, child_obj)] = root._trackable_children().items()
    return child_name, child_obj

  first_leaf = base.Trackable()
  root._track_trackable(first_leaf, name="leaf")
  name, dependency = sole_child()
  self.assertIs(first_leaf, dependency)
  self.assertEqual("leaf", name)

  replacement = base.Trackable()
  with self.assertRaises(ValueError):
    root._track_trackable(replacement, name="leaf")
  root._track_trackable(replacement, name="leaf", overwrite=True)
  name, dependency = sole_child()
  self.assertIs(replacement, dependency)
  self.assertEqual("leaf", name)
def test_convert_no_saveable(self):
  """Checks SaveableCompatibilityConverter on a Trackable with no saveables.

  Serialization is expected to produce an empty result. Restoring an empty
  mapping should succeed, while restoring any key should raise ValueError
  ("Could not restore object").
  """
  plain_obj = base.Trackable()
  converter = saveable_object_util.SaveableCompatibilityConverter(plain_obj)
  serialized = converter._serialize_to_tensors()
  self.assertEmpty(serialized)
  converter._restore_from_tensors({})
  unexpected_tensors = {"": 0}
  with self.assertRaisesRegex(ValueError, "Could not restore object"):
    converter._restore_from_tensors(unexpected_tensors)
def testAddVariableOverwrite(self):
  """Checks that _add_variable_with_custom_getter honors the `overwrite` flag.

  Re-adding variable "v" with overwrite=True is expected to replace the
  tracked dependency. With overwrite=False it should raise ValueError
  ("already declared as a dependency").
  """
  root = base.Trackable()
  original_var = root._add_variable_with_custom_getter(
      name="v", shape=[], getter=variable_scope.get_variable)
  self.assertEqual([root, original_var], util.list_objects(root))
  # Each re-declaration runs in a fresh graph; presumably this avoids a
  # getter-level name clash with the variable created above — confirm.
  with ops.Graph().as_default():
    replacement_var = root._add_variable_with_custom_getter(
        name="v", shape=[], overwrite=True,
        getter=variable_scope.get_variable)
    self.assertEqual([root, replacement_var], util.list_objects(root))
  with ops.Graph().as_default(), self.assertRaisesRegex(
      ValueError, "already declared as a dependency"):
    root._add_variable_with_custom_getter(
        name="v", shape=[], overwrite=False,
        getter=variable_scope.get_variable)