def test_registration(self):
  """Registers a predicate-based checkpoint saver and checks lookup/validation.

  The saver applies only to objects that have a `check_attr` attribute; the
  test checks name lookup, the registered save/restore callables, and the
  errors raised by `validate_restore_function` for an unknown name and for an
  object the predicate rejects.
  """
  expected_name = "Testing.test_predicate"
  registration.register_checkpoint_saver(
      package="Testing",
      name="test_predicate",
      predicate=lambda x: hasattr(x, "check_attr"),
      save_fn=lambda: "save",
      restore_fn=lambda: "restore")

  # A fresh Trackable has no `check_attr`, so no registered saver matches it.
  candidate = base.Trackable()
  self.assertIsNone(registration.get_registered_saver_name(candidate))

  # After adding the attribute, the predicate accepts the object.
  candidate.check_attr = 1
  found_name = registration.get_registered_saver_name(candidate)
  self.assertEqual(found_name, expected_name)

  save_fn = registration.get_save_function(found_name)
  self.assertEqual(save_fn(), "save")
  restore_fn = registration.get_restore_function(found_name)
  self.assertEqual(restore_fn(), "restore")

  # Validation succeeds for the matching name and fails for an unknown one.
  registration.validate_restore_function(candidate, expected_name)
  with self.assertRaisesRegex(ValueError, "saver cannot be found"):
    registration.validate_restore_function(candidate, "Invalid.name")

  # An object the predicate rejects cannot use the registered saver.
  rejected = base.Trackable()
  with self.assertRaisesRegex(ValueError, "saver cannot be used"):
    registration.validate_restore_function(rejected, expected_name)
def testAssertConsumedWithUnusedPythonState(self):
  """assert_consumed() passes when only the saved object defined get_config.

  The object written to the checkpoint has a `get_config` hook; the object it
  is restored into is a plain Trackable without one.
  """
  source = base.Trackable()
  source.get_config = lambda: {}
  writer = util.Checkpoint(obj=source)
  prefix = os.path.join(self.get_temp_dir(), "ckpt")
  checkpoint_file = writer.save(prefix)

  reader = util.Checkpoint(obj=base.Trackable())
  # The restore status is used immediately and not kept, as in the original.
  reader.restore(checkpoint_file).assert_consumed()
def testOverwrite(self):
  """Re-tracking an existing dependency name requires overwrite=True.

  Tracking a second object under "leaf" raises ValueError; with
  overwrite=True the new object replaces the old one as the sole dependency.
  """
  root = base.Trackable()

  def check_only_dependency(expected_obj):
    # `root` must track exactly one (name, object) pair: ("leaf", expected).
    [(dep_name, dep_obj)] = root._checkpoint_dependencies
    self.assertIs(expected_obj, dep_obj)
    self.assertEqual("leaf", dep_name)

  first = base.Trackable()
  root._track_trackable(first, name="leaf")
  check_only_dependency(first)

  replacement = base.Trackable()
  with self.assertRaises(ValueError):
    root._track_trackable(replacement, name="leaf")
  root._track_trackable(replacement, name="leaf", overwrite=True)
  check_only_dependency(replacement)
def testAssertConsumedFailsWithUsedPythonState(self):
  """assert_consumed() names a saved Python-state attribute nobody restored.

  The saved object contributes a `foo_attr` PythonStringStateSaveable; the
  restore target is a plain Trackable, so the error must mention `foo_attr`.
  """
  stateful = base.Trackable()
  foo_saveable_factory = functools.partial(
      base.PythonStringStateSaveable,
      state_callback=lambda: "",
      restore_callback=lambda x: None)
  saveable_factories = {"foo_attr": foo_saveable_factory}
  stateful._gather_saveables_for_checkpoint = lambda: saveable_factories

  writer = util.Checkpoint(obj=stateful)
  checkpoint_path = writer.save(os.path.join(self.get_temp_dir(), "ckpt"))

  reader = util.Checkpoint(obj=base.Trackable())
  restore_status = reader.restore(checkpoint_path)
  with self.assertRaisesRegex(AssertionError, "foo_attr"):
    restore_status.assert_consumed()
def main(_):
  """Builds several kinds of trackable objects and hands them to demo helpers.

  Covers, in order: a base.Trackable with an explicitly tracked variable, an
  AutoTrackable with tracked and untracked attributes, AutoTrackables held in
  list/dict containers (then shared between them), a nested Module chain, and
  a Keras functional model, which is finally traced inside a tf.function.

  The helpers called here (trackable, autotrackable, listing, deleting,
  containers, sharing, modules, models, graph) and the Module/Layer classes
  are defined elsewhere; what they report is not visible from this function.

  Args:
    _: Unused positional argument (presumably argv from an app runner —
      confirm against the entry point).
  """
  # 1) base.Trackable: dependencies must be declared explicitly.
  tr1 = base.Trackable()
  v = tf.Variable(1)
  tr1._track_trackable(v, name='tr1_v')
  # Note: this loop variable rebinds the unused parameter `_`; harmless.
  for _ in range(3):
    trackable(tr1, v)

  # 2) AutoTrackable: attribute assignment is tracked automatically, except
  # inside no_automatic_dependency_tracking_scope, which is how `untracked`
  # is attached here.
  tr2 = tracking.AutoTrackable()
  tracked, untracked = tf.Variable(1000), tf.Variable(0)
  tr2.v = tracked
  with base.no_automatic_dependency_tracking_scope(tr2):
    tr2.untracked = untracked
  for _ in range(2):
    autotrackable(tr2, tracked, untracked)
  listing()
  deleting(tr2)

  # 3) Containers: br1/br2 go in a list attribute, br3 in a dict attribute.
  tr3 = tracking.AutoTrackable()
  br1 = tracking.AutoTrackable()
  br1.v = tf.Variable(5)
  br2 = tracking.AutoTrackable()
  br2.v = tf.Variable(5)
  tr3.br_list = [br1, br2]
  br3 = tracking.AutoTrackable()
  br3.v = tf.Variable(5)
  tr3.br_dict = {'br3': br3}
  containers(tr3)
  # Reassigning the dict makes br1/br2 reachable from both br_list and
  # br_dict, i.e. the same objects are shared by two containers.
  tr3.br_dict = {'br1': br1, 'br2': br2, 'br3': br3}
  sharing(tr3)

  # 4) Nested modules: m1 -> m2 -> m3 via the `sub` attribute.
  mod1 = Module('m1')
  mod1.sub = Module('m2')
  mod1.sub.sub = Module('m3')
  modules(mod1)

  # NOTE(review): disabled tracing of mod1; the reason is not recorded here.
  # Either re-enable it or delete it.
  # @tf.function
  # def tracer1():
  #   return mod1()
  # graph(tracer1)

  # 5) Keras functional model wrapping a nested Layer chain l1 -> l2 -> l3.
  # Input is one int32 scalar per example (shape=() excludes the batch axis).
  ins = [tf.keras.Input(shape=(), dtype=tf.int32)]
  lay = Layer(name='l1', sub=Layer(name='l2', sub=Layer(name='l3')))
  outs = [lay(ins)]
  # NOTE(review): the name 'm2' is also used for a Module above — intentional?
  mod2 = tf.keras.Model(name='m2', inputs=ins, outputs=outs)
  models(mod2, lay)

  # Wrap a model call on a batch of two int scalars in a tf.function and pass
  # the function to `graph`.
  @tf.function
  def tracer2():
    return mod2(tf.constant([100, 100]))
  graph(tracer2)
def testAttributeException(self):
  """assert_consumed() fails, naming (root).v1, when v1 has no variable slot.

  A checkpoint holding variables v1 and v2 is restored into a Checkpoint whose
  v1 is a plain Trackable, so the saved v1 value cannot be consumed.
  """
  unused_v1_pattern = r"Unused attributes(.|\n)*\(root\).v1"
  with context.eager_mode():
    saved_v1 = variables_lib.Variable(2.)
    saved_v2 = variables_lib.Variable(3.)
    source_root = trackable_utils.Checkpoint(v1=saved_v1, v2=saved_v2)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    checkpoint_path = source_root.save(prefix)

    target_v1 = base.Trackable()
    target_v2 = variables_lib.Variable(0.)
    target_root = trackable_utils.Checkpoint(v1=target_v1, v2=target_v2)
    load_status = target_root.restore(checkpoint_path)
    with self.assertRaisesRegex(AssertionError, unused_v1_pattern):
      load_status.assert_consumed()
def test_object_tracker(self):
  """An object made via make_and_track_object in a tf.function gets tracked.

  Calling the traced function inside object_tracker_scope must leave exactly
  that one object in the tracker's `trackable_objects`.
  """
  test_case.skip_if_not_tf2('Tensorflow 2.x required')
  expected = base.Trackable()

  @tf.function
  def preprocessing_fn():
    annotators.make_and_track_object(lambda: expected)
    return 1

  tracker = annotators.ObjectTracker()
  with annotators.object_tracker_scope(tracker):
    preprocessing_fn()

  recorded = tracker.trackable_objects
  self.assertLen(recorded, 1)
  self.assertEqual(expected, recorded[0])
def testAddVariableOverwrite(self):
  """Redeclaring variable "v" replaces it with overwrite=True, else raises.

  Each redeclaration happens in a fresh graph; the root's object list must
  always be [root, <current v>].
  """
  root = base.Trackable()

  def declare_v(**overwrite_kwarg):
    # Same name/shape/getter every time; only `overwrite` varies per call.
    return root._add_variable_with_custom_getter(
        name="v", shape=[], getter=variable_scope.get_variable,
        **overwrite_kwarg)

  first_v = declare_v()
  self.assertEqual([root, first_v], util.list_objects(root))

  with ops.Graph().as_default():
    second_v = declare_v(overwrite=True)
    self.assertEqual([root, second_v], util.list_objects(root))

  with ops.Graph().as_default():
    with self.assertRaisesRegex(ValueError,
                                "already declared as a dependency"):
      declare_v(overwrite=False)
def testPartialRestoreWarningAttribute(self):
  """A partial restore logs a warning naming the unused attribute (root).v1.

  The checkpoint holds variables v1 and v2; the restore target's v1 is a plain
  Trackable, so only v2 is restored. The logged warning must mention
  "(root).v1" and "expect_partial()", and must not mention "(root).v2".
  """
  with context.eager_mode():
    original_root = trackable_utils.Checkpoint(
        v1=variables_lib.Variable(2.),
        v2=variables_lib.Variable(3.))
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    save_path = original_root.save(prefix)
    partial_root = trackable_utils.Checkpoint(
        v1=base.Trackable(),
        v2=variables_lib.Variable(0.))
    # Weak reference so the test can check that the root is freed after `del`.
    weak_partial_root = weakref.ref(partial_root)
    with test.mock.patch.object(logging, "warning") as mock_log:
      # Note: Unlike in testPartialRestoreWarningObject, the warning actually
      # prints immediately here, since all of the objects have been created
      # and there's no deferred restoration sitting around.
      # NOTE(review): the status returned by restore() is deliberately not
      # kept. Binding it to a name would likely change when the warning
      # fires — confirm before refactoring.
      partial_root.restore(save_path)
      # v2 was matched and restored from the checkpoint value 3.
      self.assertEqual(3., partial_root.v2.numpy())
      # After dropping the only local reference the root must be gone, i.e.
      # nothing else still holds a strong reference to it.
      del partial_root
      self.assertIsNone(weak_partial_root())
      messages = str(mock_log.call_args_list)
    self.assertIn("(root).v1", messages)
    self.assertNotIn("(root).v2", messages)
    self.assertIn("expect_partial()", messages)