def test_registration(self):
        """Registers a predicate-based saver and exercises lookup/validation.

        The saver only matches objects that expose ``check_attr``. Once it
        matches, it is resolved by its qualified name, its save/restore
        functions are looked up, and restore validation is checked for a
        matching object, an unknown saver name, and a non-matching object.
        """
        qualified_name = "Testing.test_predicate"
        registration.register_checkpoint_saver(
            package="Testing",
            name="test_predicate",
            predicate=lambda obj: hasattr(obj, "check_attr"),
            save_fn=lambda: "save",
            restore_fn=lambda: "restore")

        candidate = base.Trackable()
        # Predicate is not satisfied until the attribute is present.
        self.assertIsNone(registration.get_registered_saver_name(candidate))

        candidate.check_attr = 1
        resolved = registration.get_registered_saver_name(candidate)
        self.assertEqual(resolved, qualified_name)

        save_fn = registration.get_save_function(resolved)
        self.assertEqual(save_fn(), "save")
        restore_fn = registration.get_restore_function(resolved)
        self.assertEqual(restore_fn(), "restore")

        # Matching object + known name: validation passes silently.
        registration.validate_restore_function(candidate, qualified_name)
        with self.assertRaisesRegex(ValueError, "saver cannot be found"):
            registration.validate_restore_function(candidate, "Invalid.name")

        # Known name, but the object does not satisfy the predicate.
        unmatched = base.Trackable()
        with self.assertRaisesRegex(ValueError, "saver cannot be used"):
            registration.validate_restore_function(unmatched, qualified_name)
Exemple #2
0
 def testAssertConsumedWithUnusedPythonState(self):
     """Restoring into a plain Trackable still passes ``assert_consumed()``.

     The saved object exposes a ``get_config`` callable; the restoring object
     does not. The test expects ``assert_consumed()`` not to raise.
     """
     source = base.Trackable()
     source.get_config = lambda: {}
     writer = util.Checkpoint(obj=source)
     checkpoint_prefix = os.path.join(self.get_temp_dir(), "ckpt")
     save_path = writer.save(checkpoint_prefix)

     reader = util.Checkpoint(obj=base.Trackable())
     status = reader.restore(save_path)
     status.assert_consumed()
Exemple #3
0
 def testOverwrite(self):
     """Reusing a dependency name is rejected unless ``overwrite=True``."""
     root = base.Trackable()
     original_leaf = base.Trackable()
     root._track_trackable(original_leaf, name="leaf")
     [(dep_name, dep_obj)] = root._checkpoint_dependencies
     self.assertIs(original_leaf, dep_obj)
     self.assertEqual("leaf", dep_name)

     replacement = base.Trackable()
     # Without opting in, the duplicate name is an error...
     with self.assertRaises(ValueError):
         root._track_trackable(replacement, name="leaf")
     # ...with overwrite=True it replaces the existing single dependency.
     root._track_trackable(replacement, name="leaf", overwrite=True)
     [(dep_name, dep_obj)] = root._checkpoint_dependencies
     self.assertIs(replacement, dep_obj)
     self.assertEqual("leaf", dep_name)
Exemple #4
0
 def testAssertConsumedFailsWithUsedPythonState(self):
     """``assert_consumed()`` fails when saved Python state goes unmatched.

     The saved object contributes a ``PythonStringStateSaveable`` factory under
     the key "foo_attr". Restoring into a plain Trackable leaves it unmatched,
     so the raised AssertionError is expected to mention "foo_attr".
     """
     source = base.Trackable()
     saveable_factories = {
         "foo_attr": functools.partial(
             base.PythonStringStateSaveable,
             state_callback=lambda: "",
             restore_callback=lambda x: None),
     }
     # The same dict object is handed back on every call.
     source._gather_saveables_for_checkpoint = lambda: saveable_factories

     writer = util.Checkpoint(obj=source)
     save_path = writer.save(os.path.join(self.get_temp_dir(), "ckpt"))

     reader = util.Checkpoint(obj=base.Trackable())
     status = reader.restore(save_path)
     with self.assertRaisesRegex(AssertionError, "foo_attr"):
         status.assert_consumed()
Exemple #5
0
def main(_):
    """Builds a series of trackable object graphs and hands each to a helper.

    The helpers called here (``trackable``, ``autotrackable``, ``listing``,
    ``deleting``, ``containers``, ``sharing``, ``modules``, ``models``,
    ``graph``) and the ``Module``/``Layer`` classes are defined elsewhere;
    presumably each one inspects or prints one aspect of dependency tracking —
    TODO confirm against their definitions.

    Args:
      _: Ignored positional argument (presumably argv from an absl-style
        ``app.run`` entry point — verify against the caller).
    """
    # 1) A bare base.Trackable that explicitly tracks one variable by name.
    tr1 = base.Trackable()
    v = tf.Variable(1)
    tr1._track_trackable(v, name='tr1_v')
    for _ in range(3):
        trackable(tr1, v)

    # 2) An AutoTrackable: attribute assignment of `tracked` is done normally,
    # while `untracked` is assigned inside no_automatic_dependency_tracking_scope
    # (intended, per the variable names, to stay out of the dependency graph).
    tr2 = tracking.AutoTrackable()
    tracked, untracked = tf.Variable(1000), tf.Variable(0)
    tr2.v = tracked
    with base.no_automatic_dependency_tracking_scope(tr2):
        tr2.untracked = untracked
    for _ in range(2):
        autotrackable(tr2, tracked, untracked)

    listing()

    deleting(tr2)

    # 3) Container attributes: br1/br2 live in a list, br3 in a dict.
    tr3 = tracking.AutoTrackable()
    br1 = tracking.AutoTrackable()
    br1.v = tf.Variable(5)
    br2 = tracking.AutoTrackable()
    br2.v = tf.Variable(5)
    tr3.br_list = [br1, br2]
    br3 = tracking.AutoTrackable()
    br3.v = tf.Variable(5)
    tr3.br_dict = {'br3': br3}
    containers(tr3)

    # Reassign br_dict so br1 and br2 are reachable through both br_list and
    # br_dict (same objects referenced from two containers).
    tr3.br_dict = {'br1': br1, 'br2': br2, 'br3': br3}
    sharing(tr3)

    # 4) Nested Modules: m1 -> m2 -> m3 via the `sub` attribute.
    mod1 = Module('m1')
    mod1.sub = Module('m2')
    mod1.sub.sub = Module('m3')
    modules(mod1)

    # NOTE(review): disabled tracing of mod1; either re-enable or delete this
    # commented-out code.
    # @tf.function
    # def tracer1():
    #     return mod1()

    # graph(tracer1)

    # 5) A Keras functional model wrapping nested Layers l1 -> l2 -> l3.
    ins = [tf.keras.Input(shape=(), dtype=tf.int32)]
    lay = Layer(name='l1', sub=Layer(name='l2', sub=Layer(name='l3')))
    outs = [lay(ins)]
    mod2 = tf.keras.Model(name='m2', inputs=ins, outputs=outs)
    models(mod2, lay)

    # Wrap a call of the model on a fixed int constant in a tf.function and
    # pass the traced function to `graph`.
    @tf.function
    def tracer2():
        return mod2(tf.constant([100, 100]))

    graph(tracer2)
 def testAttributeException(self):
     """``assert_consumed()`` names the checkpointed value left unrestored.

     v1 is saved as a variable but restored into a bare Trackable, so the
     AssertionError message is expected to list ``(root).v1`` among the unused
     attributes.
     """
     unused_v1_pattern = r"Unused attributes(.|\n)*\(root\).v1"
     with context.eager_mode():
         writer = trackable_utils.Checkpoint(
             v1=variables_lib.Variable(2.), v2=variables_lib.Variable(3.))
         save_path = writer.save(os.path.join(self.get_temp_dir(), "ckpt"))

         reader = trackable_utils.Checkpoint(
             v1=base.Trackable(), v2=variables_lib.Variable(0.))
         status = reader.restore(save_path)
         with self.assertRaisesRegex(AssertionError, unused_v1_pattern):
             status.assert_consumed()
  def test_object_tracker(self):
    """An object created via make_and_track_object is recorded by the tracker.

    The object is produced inside a tf.function that runs within
    ``object_tracker_scope``; afterwards the scope's ObjectTracker must hold
    exactly that one object.
    """
    test_case.skip_if_not_tf2('Tensorflow 2.x required')

    tracked = base.Trackable()

    @tf.function
    def preprocessing_fn():
      annotators.make_and_track_object(lambda: tracked)
      return 1

    tracker = annotators.ObjectTracker()
    with annotators.object_tracker_scope(tracker):
      preprocessing_fn()

    recorded = tracker.trackable_objects
    self.assertLen(recorded, 1)
    self.assertEqual(tracked, recorded[0])
Exemple #8
0
 def testAddVariableOverwrite(self):
     """Re-adding a variable named "v" requires ``overwrite=True``.

     With overwrite the new variable replaces the old one in the object list;
     without it, a ValueError mentioning the existing dependency is raised.
     """
     root = base.Trackable()

     def add_v(**extra_kwargs):
         # Same name/shape/getter every time; only `overwrite` varies.
         return root._add_variable_with_custom_getter(
             name="v",
             shape=[],
             getter=variable_scope.get_variable,
             **extra_kwargs)

     first = add_v()
     self.assertEqual([root, first], util.list_objects(root))

     with ops.Graph().as_default():
         replacement = add_v(overwrite=True)
         self.assertEqual([root, replacement], util.list_objects(root))

     with ops.Graph().as_default():
         with self.assertRaisesRegex(ValueError,
                                     "already declared as a dependency"):
             add_v(overwrite=False)
 def testPartialRestoreWarningAttribute(self):
     """A partial restore logs a warning naming only the unrestored value.

     v1 is restored into a bare Trackable while v2 is restored into a
     variable. After the restoring checkpoint is deleted (and confirmed
     collectable via a weak reference), the captured warnings must mention
     ``(root).v1`` and ``expect_partial()`` but not ``(root).v2``.
     """
     with context.eager_mode():
         writer = trackable_utils.Checkpoint(
             v1=variables_lib.Variable(2.), v2=variables_lib.Variable(3.))
         save_path = writer.save(os.path.join(self.get_temp_dir(), "ckpt"))

         reader = trackable_utils.Checkpoint(
             v1=base.Trackable(), v2=variables_lib.Variable(0.))
         reader_ref = weakref.ref(reader)
         with test.mock.patch.object(logging, "warning") as warning_mock:
             # Unlike testPartialRestoreWarningObject, every object already
             # exists here, so no deferred restoration is pending and the
             # warning is emitted without waiting on later object creation.
             reader.restore(save_path)
             self.assertEqual(3., reader.v2.numpy())
             # Drop the only strong reference; the checkpoint must not be
             # kept alive by anything else.
             del reader
             self.assertIsNone(reader_ref())
             logged = str(warning_mock.call_args_list)
         self.assertIn("(root).v1", logged)
         self.assertNotIn("(root).v2", logged)
         self.assertIn("expect_partial()", logged)