Ejemplo n.º 1
0
 def testNestedLists(self):
     """Checks dependency tracking through lists nested inside lists.

     Trackables appended to an inner list are reported by `list_objects` on
     the owner while plain values are not. A `NoDependency` element may be
     mutated without blocking a save, but replacing a tracked element of an
     inner list makes the checkpoint refuse to save.
     """
     root = autotrackable.AutoTrackable()
     root.l = []
     first = autotrackable.AutoTrackable()
     root.l.append([first])
     second = autotrackable.AutoTrackable()
     root.l[0].append(second)
     deps = util.list_objects(root)
     for tracked in (first, second):
         self.assertIn(tracked, deps)

     # A plain int in the inner list is skipped; later trackables still count.
     root.l[0].append(1)
     third = autotrackable.AutoTrackable()
     root.l[0].append(third)
     deps = util.list_objects(root)
     for tracked in (third, first, second):
         self.assertIn(tracked, deps)
     self.assertNotIn(1, deps)

     # Trackables inside a nested list literal are picked up as well.
     fourth = autotrackable.AutoTrackable()
     fifth = autotrackable.AutoTrackable()
     root.l1 = [[], [fourth]]
     root.l1[0].append(fifth)
     deps = util.list_objects(root)
     for tracked in (fourth, fifth):
         self.assertIn(tracked, deps)

     checkpoint = util.Checkpoint(a=root)
     prefix = os.path.join(self.get_temp_dir(), "ckpt")
     checkpoint.save(prefix)
     # Mutating a list wrapped in NoDependency does not prevent saving.
     root.l[0].append(data_structures.NoDependency([]))
     root.l[0][-1].append(5)
     checkpoint.save(prefix)
     # Dirtying the inner list means the root object is unsaveable.
     root.l[0][1] = 2
     with self.assertRaisesRegex(ValueError, "A list element was replaced"):
         checkpoint.save(prefix)
Ejemplo n.º 2
0
    def test_timedistributed_dense(self):
        """TimeDistributed(Dense) trains, builds a config, and is tracked.

        After one epoch of fitting on random data, every variable of the model
        must be among the objects reachable through trackable dependencies.
        """
        model = keras.models.Sequential()
        time_distributed = keras.layers.TimeDistributed(
            keras.layers.Dense(2), input_shape=(3, 4)
        )
        model.add(time_distributed)
        model.compile(optimizer="rmsprop", loss="mse")
        features = np.random.random((10, 3, 4))
        labels = np.random.random((10, 3, 2))
        model.fit(features, labels, epochs=1, batch_size=10)

        # Smoke test: producing the config must not raise.
        model.get_config()

        # Compare by identity: every model variable must be checkpointable.
        tracked_ids = set(map(id, trackable_util.list_objects(model)))
        for variable in model.variables:
            self.assertIn(id(variable), tracked_ids)
Ejemplo n.º 3
0
 def testAddVariableOverwrite(self):
   """Re-declaring dependency "v" is allowed only with `overwrite=True`."""
   root = base.Trackable()
   original = root._add_variable_with_custom_getter(
       name="v", shape=[], getter=variable_scope.get_variable)
   self.assertEqual([root, original], util.list_objects(root))

   # With overwrite=True the new variable replaces the old dependency.
   with ops.Graph().as_default():
     replacement = root._add_variable_with_custom_getter(
         name="v", shape=[], getter=variable_scope.get_variable,
         overwrite=True)
     self.assertEqual([root, replacement], util.list_objects(root))

   # With overwrite=False the duplicate dependency name is rejected.
   with ops.Graph().as_default(), self.assertRaisesRegex(
       ValueError, "already declared as a dependency"):
     root._add_variable_with_custom_getter(
         name="v", shape=[], getter=variable_scope.get_variable,
         overwrite=False)
Ejemplo n.º 4
0
 def testDictWrapperNoDependency(self):
     """A dict wrapped in `NoDependency` adds nothing to the object graph."""
     module = tf.Module()
     module.d = data_structures.NoDependency({})
     module.d[1] = [3]
     self.assertEqual([module], util.list_objects(module))

     # The untracked dict must not interfere with saving/loading weights.
     parent = training.Model()
     parent.sub = module
     weights_path = os.path.join(self.get_temp_dir(), "ckpt")
     parent.save_weights(weights_path)
     parent.load_weights(weights_path)
Ejemplo n.º 5
0
 def testNonStringKeyNotTrackableValue(self):
     """A non-string key whose value is untracked does not block saving."""
     module = tf.Module()
     module.d = {}
     module.d["a"] = [3]
     module.d[1] = data_structures.NoDependency([3])
     # Only the string-keyed list is tracked; the int-keyed value is not.
     expected = [module, module.d, module.d["a"]]
     self.assertEqual(expected, util.list_objects(module))

     parent = training.Model()
     parent.sub = module
     weights_path = os.path.join(self.get_temp_dir(), "ckpt")
     parent.save_weights(weights_path)
     parent.load_weights(weights_path)
Ejemplo n.º 6
0
 def testListBasic(self):
     """Trackables in a list attribute are tracked by the owner and the list."""
     owner = autotrackable.AutoTrackable()
     first = autotrackable.AutoTrackable()
     owner.l = [first]
     second = autotrackable.AutoTrackable()
     owner.l.append(second)

     all_deps = util.list_objects(owner)
     for member in (first, second):
         self.assertIn(member, all_deps)

     # The list itself is a direct child, keyed by the attribute name.
     self.assertIn("l", owner._trackable_children())
     list_child = owner._trackable_children()["l"]
     for member in (first, second):
         self.assertIn(member, list_child)
Ejemplo n.º 7
0
 def testShallowCopyTrackable(self):
     """`copy.copy` shares sub-objects and keeps them as dependencies."""
     source = autotrackable.AutoTrackable()
     nested = autotrackable.AutoTrackable()
     source.a = [[1.]]
     source.b = {"a": nested}
     clone = copy.copy(source)
     # Shallow copy: new outer object, same nested trackable.
     self.assertIs(nested, clone.b["a"])
     self.assertIsNot(source, clone)
     self.assertEqual([[1.]], clone.a)
     clone_deps = util.list_objects(clone)
     for member in (clone.a, clone.b, clone.b["a"]):
         self.assertIn(member, clone_deps)
    def test_trackable_save_restore(self):
        """Template variables round-trip through a trackable Checkpoint.

        A template built under scope "s1" is saved together with an Adam
        optimizer, then restored into a fresh template built under scope "s2".
        The restored variables must carry the saved values (12, 14), and the
        derived `v + 1` must reflect them (13).
        """
        def _templated():
            v = variable_scope.get_variable(
                "v",
                shape=[1],
                initializer=init_ops.zeros_initializer(),
                use_resource=True)
            v2 = variable_scope.get_variable(
                "v2",
                shape=[1],
                initializer=init_ops.zeros_initializer(),
                use_resource=True)
            manual = _ManualScope()
            return v, v + 1., v2, manual, manual()

        save_template = template.make_template("s1", _templated)
        v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
        # assertCountEqual (unlike a set comparison) also catches duplicates.
        self.assertCountEqual(
            [id(obj) for obj in
             [v1_save, v2_save, manual_scope, manual_scope_v, save_template]],
            [id(obj) for obj in trackable_utils.list_objects(save_template)])
        self.assertDictEqual({"in_manual_scope": manual_scope_v},
                             manual_scope._trackable_children())
        optimizer = adam.AdamOptimizer(0.0)
        save_root = trackable_utils.Checkpoint(my_template=save_template,
                                               optimizer=optimizer)
        # minimize() creates the optimizer's variables, initialized below.
        optimizer.minimize(v1_save.read_value)
        self.evaluate([v.initializer for v in save_template.variables])
        self.evaluate([v.initializer for v in optimizer.variables()])
        self.evaluate(v1_save.assign([12.]))
        self.evaluate(v2_save.assign([14.]))
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        save_path = save_root.save(checkpoint_prefix)

        load_template = template.make_template("s2", _templated)
        load_optimizer = adam.AdamOptimizer(0.0)
        load_root = trackable_utils.Checkpoint(my_template=load_template,
                                               optimizer=load_optimizer)
        # Restore is requested before load_template() creates its variables.
        status = load_root.restore(save_path)
        var, var_plus_one, var2, _, _ = load_template()
        load_optimizer.minimize(var.read_value)
        # Key-set equality already implies exactly three children.
        children = load_template._trackable_children()
        self.assertEqual({"v", "v2", "ManualScope"}, children.keys())
        status.assert_consumed().run_restore_ops()
        self.assertAllEqual([12.], self.evaluate(var))
        self.assertAllEqual([13.], self.evaluate(var_plus_one))
        self.assertAllEqual([14.], self.evaluate(var2))
Ejemplo n.º 9
0
 def testDeepCopyTrackable(self):
     """`copy.deepcopy` yields an independent, fully tracked object graph."""
     source = autotrackable.AutoTrackable()
     nested = autotrackable.AutoTrackable()
     source.a = [[1.]]
     source.b = {"a": nested}
     self.assertIsInstance(source.b, dict)
     clone = copy.deepcopy(source)
     self.assertIsInstance(clone.b, dict)
     self.assertIsNot(source, clone)
     # Deep copy: even the nested trackable is a new object.
     self.assertIsNot(nested, clone.b["a"])
     self.assertEqual([[1.]], clone.a)
     self.assertIsInstance(clone.b["a"], autotrackable.AutoTrackable)
     clone_deps = util.list_objects(clone)
     for member in (clone.a, clone.b, clone.b["a"]):
         self.assertIn(member, clone_deps)
     # The original nested object must not leak into the copy's graph.
     self.assertNotIn(nested, clone_deps)
Ejemplo n.º 10
0
 def testDictionariesBasic(self):
     """Models held in a dict attribute, directly or via a list, are tracked."""
     outer = training.Model()
     by_key = training.Model()
     outer.attribute = {"b": by_key}
     in_list = training.Model()
     outer.attribute["c"] = []
     outer.attribute["c"].append(in_list)

     deps = util.list_objects(outer)
     for model in (by_key, in_list):
         self.assertIn(model, deps)
     self.assertIs(by_key, outer.attribute["b"])
     child_names = outer.attribute._trackable_children().keys()
     self.assertEqual({"b", "c"}, child_names)
     # `layers` includes models nested inside the dict and its list.
     self.assertEqual([by_key, in_list], outer.layers)
     self.assertEqual([by_key, in_list], outer.attribute.layers)
     self.assertEqual([in_list], outer.attribute["c"].layers)

     checkpoint = tf.train.Checkpoint(a=outer)
     prefix = os.path.join(self.get_temp_dir(), "ckpt")
     save_path = checkpoint.save(prefix)
     with self.cached_session():
         status = checkpoint.restore(save_path)
         status.assert_consumed().initialize_or_restore()
Ejemplo n.º 11
0
    def testNoDependency(self):
        """Values wrapped in `NoDependency` are stored but not tracked.

        Also checks that attributes assigned in an `__init__` decorated with
        `no_automatic_dependency_tracking` create no dependencies.
        """
        root = tf.Module()
        tracked = tf.Module()
        root.hasdep = tracked
        untracked = tf.Module()
        root.nodep = data_structures.NoDependency(untracked)
        # Only the plain attribute shows up among the children.
        self.assertLen(root._trackable_children(), 1)
        self.assertIs(root._trackable_children()["hasdep"], root.hasdep)
        # Reading either attribute gives back the original module object.
        self.assertIs(root.hasdep, tracked)
        self.assertIs(root.nodep, untracked)

        class NoDependencyModel(training.Model):

            @tf.__internal__.tracking.no_automatic_dependency_tracking
            def __init__(self):
                super().__init__()
                self.a = []
                self.b = tf.Module()

        model = NoDependencyModel()
        self.assertEqual([model], util.list_objects(model))
Ejemplo n.º 12
0
 def testNonAppendNotTrackable(self):
     """Overwriting or deleting untracked dict values keeps the dict saveable.

     Non-append mutations (deleting or overwriting values) are OK when the
     values aren't tracked.
     """
     module = tf.Module()
     module.d = {}
     module.d["a"] = [3]
     # Overwrite, then delete, a plain int value.
     module.d[1] = 3
     module.d[1] = 2
     self.assertEqual(2, module.d[1])
     del module.d[1]
     # Replace one NoDependency-wrapped module with another.
     module.d[2] = data_structures.NoDependency(tf.Module())
     replacement = tf.Module()
     module.d[2] = data_structures.NoDependency(replacement)
     self.assertIs(replacement, module.d[2])
     self.assertEqual(
         [module, module.d, module.d["a"]], util.list_objects(module))

     parent = training.Model()
     parent.sub = module
     weights_path = os.path.join(self.get_temp_dir(), "ckpt")
     parent.save_weights(weights_path)
     parent.load_weights(weights_path)
Ejemplo n.º 13
0
    def testDictDeepCopy(self):
        """Deep-copying a tracked dict copies values and inherits dirtiness."""
        root = autotrackable.AutoTrackable()
        plain = {"a": [1.]}
        root.a = plain
        clone = copy.deepcopy(root.a)
        self.assertAllEqual([1.], clone["a"])
        self.assertIsNot(root.a, clone)
        self.assertIsNot(root.a["a"], clone["a"])

        # Dirtiness should be inherited: once the original dict is mutated
        # after assignment, listing the wrapper or a deep copy of it raises.
        util.list_objects(root.a)
        plain["b"] = []
        with self.assertRaises(ValueError):
            util.list_objects(root.a)
        with self.assertRaises(ValueError):
            util.list_objects(copy.deepcopy(root.a))
Ejemplo n.º 14
0
    def testListDeepCopy(self):
        """Deep-copying a tracked list copies elements and inherits dirtiness."""
        root = autotrackable.AutoTrackable()
        plain = [[1.]]
        root.a = plain
        clone = copy.deepcopy(root.a)
        self.assertAllEqual([[1.]], clone)
        self.assertIsNot(root.a, clone)
        self.assertIsNot(root.a[0], clone[0])

        # Dirtiness should be inherited: once the original list is mutated
        # after assignment, listing the wrapper or a deep copy of it raises.
        util.list_objects(root.a)
        plain.append(1.)
        with self.assertRaises(ValueError):
            util.list_objects(root.a)
        with self.assertRaises(ValueError):
            util.list_objects(copy.deepcopy(root.a))
Ejemplo n.º 15
0
    def test_trackable_save_restore(self):
        """Template variables round-trip through `tf.train.Checkpoint`.

        A template built under scope "s1" is saved together with an Adam
        optimizer, then restored into a fresh template built under scope "s2".
        The restored variables must carry the saved values (12, 14), and the
        derived `v + 1` must reflect them (13).
        """
        # cached_session() replaces the deprecated test_session().
        with self.cached_session():

            def _templated():
                v = tf.compat.v1.get_variable(
                    "v",
                    shape=[1],
                    initializer=tf.compat.v1.zeros_initializer(),
                    use_resource=True,
                )
                v2 = tf.compat.v1.get_variable(
                    "v2",
                    shape=[1],
                    initializer=tf.compat.v1.zeros_initializer(),
                    use_resource=True,
                )
                manual = _ManualScope()
                return v, v + 1.0, v2, manual, manual()

            save_template = tf.compat.v1.make_template("s1", _templated)
            v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
            # assertCountEqual, unlike comparing sets, also fails when an
            # object is reported more than once.
            self.assertCountEqual(
                [
                    id(v1_save),
                    id(v2_save),
                    id(manual_scope),
                    id(manual_scope_v),
                    id(save_template),
                ],
                [
                    id(obj)
                    for obj in trackable_utils.list_objects(save_template)
                ],
            )
            self.assertDictEqual(
                {"in_manual_scope": manual_scope_v},
                manual_scope._trackable_children(),
            )
            optimizer = adam.Adam(0.0)
            save_root = tf.train.Checkpoint(my_template=save_template,
                                            optimizer=optimizer)
            optimizer.minimize(v1_save.read_value, var_list=[v1_save])
            self.evaluate([v.initializer for v in save_template.variables])
            # Values in `optimizer._hyper` are initialized alongside the
            # optimizer's own variables.
            optimizer_variables = optimizer.variables() + list(
                optimizer._hyper.values())
            self.evaluate([v.initializer for v in optimizer_variables])
            self.evaluate(v1_save.assign([12.0]))
            self.evaluate(v2_save.assign([14.0]))
            checkpoint_directory = self.get_temp_dir()
            checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
            save_path = save_root.save(checkpoint_prefix)

            load_template = tf.compat.v1.make_template("s2", _templated)
            load_optimizer = adam.Adam(0.0)
            load_root = tf.train.Checkpoint(my_template=load_template,
                                            optimizer=load_optimizer)
            # Restore is requested before load_template() creates its
            # variables.
            status = load_root.restore(save_path)
            var, var_plus_one, var2, _, _ = load_template()
            load_optimizer.minimize(var.read_value, var_list=[var])

            children = load_template._trackable_children()
            self.assertEqual({"v", "v2", "ManualScope"}, children.keys())
            status.assert_consumed().run_restore_ops()
            self.assertAllEqual([12.0], self.evaluate(var))
            self.assertAllEqual([13.0], self.evaluate(var_plus_one))
            self.assertAllEqual([14.0], self.evaluate(var2))