def testAssertConsumedWithUnusedPythonState(self):
    """A saved `get_config` hook does not trip `assert_consumed` on restore.

    The object being saved has a `get_config` attribute returning an empty
    dict. The restore target is a plain `CheckpointableBase` without it.
    The test expects `assert_consumed()` to complete without raising.
    """
    source_obj = base.CheckpointableBase()
    source_obj.get_config = lambda: {}
    # Keep the saving Checkpoint bound to a name, as the original did.
    saver = util.Checkpoint(obj=source_obj)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    checkpoint_path = saver.save(prefix)
    target = util.Checkpoint(obj=base.CheckpointableBase())
    load_status = target.restore(checkpoint_path)
    load_status.assert_consumed()
def testOverwrite(self):
    """Re-tracking a dependency under an existing name requires `overwrite=True`.

    Without `overwrite`, `_track_checkpointable` raises ValueError for a
    duplicate name. With `overwrite=True`, the new object replaces the old
    one as the sole "leaf" dependency.
    """
    parent = base.CheckpointableBase()
    original_child = base.CheckpointableBase()

    def check_only_leaf_is(expected_child):
        # Unpacking into a one-element list also asserts exactly one dependency.
        [(dep_name, dep_obj)] = parent._checkpoint_dependencies
        self.assertIs(expected_child, dep_obj)
        self.assertEqual("leaf", dep_name)

    parent._track_checkpointable(original_child, name="leaf")
    check_only_leaf_is(original_child)

    replacement_child = base.CheckpointableBase()
    with self.assertRaises(ValueError):
        parent._track_checkpointable(replacement_child, name="leaf")
    # The same call succeeds once overwriting is explicitly permitted.
    parent._track_checkpointable(replacement_child, name="leaf", overwrite=True)
    check_only_leaf_is(replacement_child)
def testAssertConsumedFailsWithUsedPythonState(self):
    """Unmatched Python-state saveables make `assert_consumed` fail.

    The saved object contributes a `PythonStringStateSaveable` under the key
    "foo_attr". The restore target does not provide it, so
    `assert_consumed()` is expected to raise an AssertionError whose message
    matches "foo_attr".
    """
    source_obj = base.CheckpointableBase()
    make_saveable = functools.partial(
        base.PythonStringStateSaveable,
        state_callback=lambda: "",
        restore_callback=lambda x: None)
    saveable_factories = {"foo_attr": make_saveable}
    source_obj._gather_saveables_for_checkpoint = lambda: saveable_factories

    saver = util.Checkpoint(obj=source_obj)
    checkpoint_path = saver.save(os.path.join(self.get_temp_dir(), "ckpt"))

    target = util.Checkpoint(obj=base.CheckpointableBase())
    load_status = target.restore(checkpoint_path)
    with self.assertRaisesRegexp(AssertionError, "foo_attr"):
        load_status.assert_consumed()
def testAddVariableOverwrite(self):
    """Re-adding variable "v" succeeds only when `overwrite=True`.

    The variable is first added in the current default graph. It is then
    re-added in a fresh graph with `overwrite=True`, which replaces it in
    `util.list_objects(root)`. Finally, re-adding with `overwrite=False` in
    another fresh graph must raise ValueError mentioning
    "already declared as a dependency".
    """
    root = base.CheckpointableBase()

    def add_v(**extra_kwargs):
        # Every call uses the same name/shape/getter; only `overwrite` varies.
        return root._add_variable_with_custom_getter(
            name="v", shape=[], getter=variable_scope.get_variable,
            **extra_kwargs)

    first_var = add_v()
    self.assertEqual([root, first_var], util.list_objects(root))

    with ops.Graph().as_default():
        replacement_var = add_v(overwrite=True)
        self.assertEqual([root, replacement_var], util.list_objects(root))

    with ops.Graph().as_default():
        with self.assertRaisesRegexp(
                ValueError, "already declared as a dependency"):
            add_v(overwrite=False)