def testNestedLists(self):
  """Trackables nested inside lists of lists are discovered and saved."""
  root = autotrackable.AutoTrackable()
  root.l = []
  first = autotrackable.AutoTrackable()
  root.l.append([first])
  second = autotrackable.AutoTrackable()
  root.l[0].append(second)
  deps = util.list_objects(root)
  for expected in (first, second):
    self.assertIn(expected, deps)

  # Plain Python values can sit next to trackables; they are not dependencies.
  root.l[0].append(1)
  third = autotrackable.AutoTrackable()
  root.l[0].append(third)
  deps = util.list_objects(root)
  for expected in (third, first, second):
    self.assertIn(expected, deps)
  self.assertNotIn(1, deps)

  # A list literal holding lists is tracked, including later appends to it.
  fourth = autotrackable.AutoTrackable()
  fifth = autotrackable.AutoTrackable()
  root.l1 = [[], [fourth]]
  root.l1[0].append(fifth)
  deps = util.list_objects(root)
  for expected in (fourth, fifth):
    self.assertIn(expected, deps)

  checkpoint = util.Checkpoint(a=root)
  checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
  # Appending (and then mutating) a NoDependency-wrapped list still saves.
  root.l[0].append(data_structures.NoDependency([]))
  root.l[0][-1].append(5)
  checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))

  # Dirtying the inner list means the root object is unsaveable.
  root.l[0][1] = 2
  with self.assertRaisesRegex(ValueError, "A list element was replaced"):
    checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
def test_timedistributed_dense(self):
    """TimeDistributed(Dense) trains, serializes its config, tracks variables."""
    model = keras.models.Sequential()
    wrapped = keras.layers.TimeDistributed(
        keras.layers.Dense(2), input_shape=(3, 4)
    )
    model.add(wrapped)
    model.compile(optimizer="rmsprop", loss="mse")
    inputs = np.random.random((10, 3, 4))
    targets = np.random.random((10, 3, 2))
    model.fit(inputs, targets, epochs=1, batch_size=10)

    # Producing the config must not raise.
    model.get_config()

    # Every model variable should be reachable through the trackable graph.
    tracked_ids = set(map(id, trackable_util.list_objects(model)))
    for variable in model.variables:
        self.assertIn(id(variable), tracked_ids)
def testAddVariableOverwrite(self):
  """A dependency name can only be redeclared when overwrite=True."""
  root = base.Trackable()

  def add_v(**kwargs):
    # Every call targets the same dependency name, "v".
    return root._add_variable_with_custom_getter(
        name="v", shape=[], getter=variable_scope.get_variable, **kwargs)

  original = add_v()
  self.assertEqual([root, original], util.list_objects(root))

  # overwrite=True replaces the existing "v" dependency.
  with ops.Graph().as_default():
    replacement = add_v(overwrite=True)
    self.assertEqual([root, replacement], util.list_objects(root))

  # overwrite=False refuses to redeclare "v".
  with ops.Graph().as_default():
    with self.assertRaisesRegex(ValueError,
                                "already declared as a dependency"):
      add_v(overwrite=False)
def testDictWrapperNoDependency(self):
    """A NoDependency-wrapped dict contributes nothing to the object graph."""
    module = tf.Module()
    module.d = data_structures.NoDependency({})
    module.d[1] = [3]
    self.assertEqual([module], util.list_objects(module))

    model = training.Model()
    model.sub = module
    # Round-trip the weights; the untracked dict should not interfere.
    weights_path = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(weights_path)
    model.load_weights(weights_path)
def testNonStringKeyNotTrackableValue(self):
    """A non-string dict key is accepted when its value is untracked."""
    module = tf.Module()
    module.d = {}
    module.d["a"] = [3]
    module.d[1] = data_structures.NoDependency([3])
    # Only the string-keyed entry appears among the dependencies.
    expected = [module, module.d, module.d["a"]]
    self.assertEqual(expected, util.list_objects(module))

    model = training.Model()
    model.sub = module
    weights_path = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(weights_path)
    model.load_weights(weights_path)
def testListBasic(self):
  """Elements of a tracked list are both transitive and direct-child deps."""
  root = autotrackable.AutoTrackable()
  first = autotrackable.AutoTrackable()
  root.l = [first]
  second = autotrackable.AutoTrackable()
  root.l.append(second)

  all_deps = util.list_objects(root)
  for member in (first, second):
    self.assertIn(member, all_deps)

  # The list itself is registered as the direct child named "l".
  children = root._trackable_children()
  self.assertIn("l", children)
  tracked_list = children["l"]
  for member in (first, second):
    self.assertIn(member, tracked_list)
def testShallowCopyTrackable(self):
  """copy.copy shares sub-objects, and the copy still tracks them."""
  original = autotrackable.AutoTrackable()
  child = autotrackable.AutoTrackable()
  original.a = [[1.]]
  original.b = {"a": child}

  clone = copy.copy(original)
  self.assertIs(child, clone.b["a"])
  self.assertIsNot(original, clone)
  self.assertEqual([[1.]], clone.a)

  clone_deps = util.list_objects(clone)
  for dep in (clone.a, clone.b, clone.b["a"]):
    self.assertIn(dep, clone_deps)
def test_trackable_save_restore(self):
  """Template variables and a nested scope round-trip through a checkpoint."""

  def _templated():
    v = variable_scope.get_variable(
        "v", shape=[1], initializer=init_ops.zeros_initializer(),
        use_resource=True)
    v2 = variable_scope.get_variable(
        "v2", shape=[1], initializer=init_ops.zeros_initializer(),
        use_resource=True)
    manual = _ManualScope()
    return v, v + 1., v2, manual, manual()

  save_template = template.make_template("s1", _templated)
  v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
  # The template's object graph is exactly its variables, the manual scope,
  # the scope's variable, and the template itself. The native unittest
  # assertCountEqual replaces the Python 2 `six` shim.
  self.assertCountEqual(
      [id(obj) for obj in [v1_save, v2_save, manual_scope, manual_scope_v,
                           save_template]],
      [id(obj) for obj in trackable_utils.list_objects(save_template)])
  self.assertDictEqual({"in_manual_scope": manual_scope_v},
                       manual_scope._trackable_children())

  optimizer = adam.AdamOptimizer(0.0)
  save_root = trackable_utils.Checkpoint(my_template=save_template,
                                         optimizer=optimizer)
  optimizer.minimize(v1_save.read_value)
  self.evaluate([v.initializer for v in save_template.variables])
  self.evaluate([v.initializer for v in optimizer.variables()])
  self.evaluate(v1_save.assign([12.]))
  self.evaluate(v2_save.assign([14.]))
  checkpoint_directory = self.get_temp_dir()
  checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
  save_path = save_root.save(checkpoint_prefix)

  # Restore into a differently-scoped template. restore() is called before
  # the template has built its variables; the values are checked only after
  # run_restore_ops().
  load_template = template.make_template("s2", _templated)
  load_optimizer = adam.AdamOptimizer(0.0)
  load_root = trackable_utils.Checkpoint(my_template=load_template,
                                         optimizer=load_optimizer)
  status = load_root.restore(save_path)
  var, var_plus_one, var2, _, _ = load_template()
  load_optimizer.minimize(var.read_value)
  # Key-set equality already pins the child count to three.
  children = load_template._trackable_children()
  self.assertEqual({"v", "v2", "ManualScope"}, children.keys())
  status.assert_consumed().run_restore_ops()
  self.assertAllEqual([12.], self.evaluate(var))
  self.assertAllEqual([13.], self.evaluate(var_plus_one))
  self.assertAllEqual([14.], self.evaluate(var2))
def testDeepCopyTrackable(self):
  """copy.deepcopy clones sub-objects and preserves container types."""
  original = autotrackable.AutoTrackable()
  child = autotrackable.AutoTrackable()
  original.a = [[1.]]
  original.b = {"a": child}
  self.assertIsInstance(original.b, dict)

  clone = copy.deepcopy(original)
  self.assertIsInstance(clone.b, dict)
  self.assertIsNot(original, clone)
  self.assertIsNot(child, clone.b["a"])
  self.assertEqual([[1.]], clone.a)
  self.assertIsInstance(clone.b["a"], autotrackable.AutoTrackable)

  clone_deps = util.list_objects(clone)
  for dep in (clone.a, clone.b, clone.b["a"]):
    self.assertIn(dep, clone_deps)
  # The clone's graph must not reach back to the original child.
  self.assertNotIn(child, clone_deps)
def testDictionariesBasic(self):
    """Models held in a dict attribute are tracked, listed as layers, saved."""
    root = training.Model()
    first = training.Model()
    root.attribute = {"b": first}
    second = training.Model()
    root.attribute["c"] = []
    root.attribute["c"].append(second)

    all_deps = util.list_objects(root)
    for member in (first, second):
        self.assertIn(member, all_deps)
    self.assertIs(first, root.attribute["b"])
    self.assertEqual(
        {"b", "c"}, root.attribute._trackable_children().keys()
    )
    # Nested models are reported as layers at each level of the structure.
    self.assertEqual([first, second], root.layers)
    self.assertEqual([first, second], root.attribute.layers)
    self.assertEqual([second], root.attribute["c"].layers)

    checkpoint = tf.train.Checkpoint(a=root)
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    save_path = checkpoint.save(prefix)
    with self.cached_session():
        status = checkpoint.restore(save_path)
        status.assert_consumed().initialize_or_restore()
def testNoDependency(self):
    """NoDependency and no_automatic_dependency_tracking both suppress tracking."""
    root = tf.Module()
    tracked = tf.Module()
    root.hasdep = tracked
    untracked = tf.Module()
    root.nodep = data_structures.NoDependency(untracked)

    children = root._trackable_children()
    self.assertLen(children, 1)
    self.assertIs(children["hasdep"], root.hasdep)
    # Attribute access returns the original objects, not a wrapper.
    self.assertIs(root.hasdep, tracked)
    self.assertIs(root.nodep, untracked)

    class NoDependencyModel(training.Model):
        @tf.__internal__.tracking.no_automatic_dependency_tracking
        def __init__(self):
            super().__init__()
            self.a = []
            self.b = tf.Module()

    model = NoDependencyModel()
    # Attributes assigned in the decorated __init__ do not become dependencies.
    self.assertEqual([model], util.list_objects(model))
def testNonAppendNotTrackable(self):
    """Overwriting or deleting dict entries is fine when values are untracked."""
    # Non-append mutations (deleting or overwriting values) are OK when the
    # values aren't tracked.
    module = tf.Module()
    module.d = {}
    module.d["a"] = [3]
    module.d[1] = 3
    module.d[1] = 2
    self.assertEqual(2, module.d[1])
    del module.d[1]

    module.d[2] = data_structures.NoDependency(tf.Module())
    replacement = tf.Module()
    module.d[2] = data_structures.NoDependency(replacement)
    self.assertIs(replacement, module.d[2])
    self.assertEqual(
        [module, module.d, module.d["a"]], util.list_objects(module)
    )

    model = training.Model()
    model.sub = module
    weights_path = os.path.join(self.get_temp_dir(), "ckpt")
    model.save_weights(weights_path)
    model.load_weights(weights_path)
def testDictDeepCopy(self):
  """Deep copies of a tracked dict are independent but inherit dirtiness."""
  root = autotrackable.AutoTrackable()
  source = {"a": [1.]}
  root.a = source
  duplicate = copy.deepcopy(root.a)
  self.assertAllEqual([1.], duplicate["a"])
  self.assertIsNot(root.a, duplicate)
  self.assertIsNot(root.a["a"], duplicate["a"])

  # Dirtiness should be inherited. Listing succeeds before `source` is
  # mutated directly (bypassing `root.a`). Afterwards both the tracked
  # attribute and a fresh deep copy of it are rejected.
  util.list_objects(root.a)
  source["b"] = []
  with self.assertRaises(ValueError):
    util.list_objects(root.a)
  with self.assertRaises(ValueError):
    util.list_objects(copy.deepcopy(root.a))
def testListDeepCopy(self):
  """Deep copies of a tracked list are independent but inherit dirtiness."""
  root = autotrackable.AutoTrackable()
  source = [[1.]]
  root.a = source
  duplicate = copy.deepcopy(root.a)
  self.assertAllEqual([[1.]], duplicate)
  self.assertIsNot(root.a, duplicate)
  self.assertIsNot(root.a[0], duplicate[0])

  # Dirtiness should be inherited. Listing succeeds before `source` is
  # mutated directly (bypassing `root.a`). Afterwards both the tracked
  # attribute and a fresh deep copy of it are rejected.
  util.list_objects(root.a)
  source.append(1.)
  with self.assertRaises(ValueError):
    util.list_objects(root.a)
  with self.assertRaises(ValueError):
    util.list_objects(copy.deepcopy(root.a))
def test_trackable_save_restore(self):
    """Template variables and a nested scope round-trip through a checkpoint."""
    # `test_session` is deprecated on tf.test.TestCase; `cached_session` is
    # its documented replacement.
    with self.cached_session():

        def _templated():
            v = tf.compat.v1.get_variable(
                "v",
                shape=[1],
                initializer=tf.compat.v1.zeros_initializer(),
                use_resource=True,
            )
            v2 = tf.compat.v1.get_variable(
                "v2",
                shape=[1],
                initializer=tf.compat.v1.zeros_initializer(),
                use_resource=True,
            )
            manual = _ManualScope()
            return v, v + 1.0, v2, manual, manual()

        save_template = tf.compat.v1.make_template("s1", _templated)
        v1_save, _, v2_save, manual_scope, manual_scope_v = save_template()
        # The template's object graph is exactly these five objects.
        expected_ids = {
            id(obj)
            for obj in (
                v1_save,
                v2_save,
                manual_scope,
                manual_scope_v,
                save_template,
            )
        }
        self.assertEqual(
            expected_ids,
            set(map(id, trackable_utils.list_objects(save_template))),
        )
        self.assertDictEqual(
            {"in_manual_scope": manual_scope_v},
            manual_scope._trackable_children(),
        )

        optimizer = adam.Adam(0.0)
        save_root = tf.train.Checkpoint(
            my_template=save_template, optimizer=optimizer
        )
        optimizer.minimize(v1_save.read_value, var_list=[v1_save])
        self.evaluate([v.initializer for v in save_template.variables])
        # Hyperparameter variables are initialized alongside slot variables.
        optimizer_variables = optimizer.variables() + list(
            optimizer._hyper.values()
        )
        self.evaluate([v.initializer for v in optimizer_variables])
        self.evaluate(v1_save.assign([12.0]))
        self.evaluate(v2_save.assign([14.0]))
        checkpoint_directory = self.get_temp_dir()
        checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
        save_path = save_root.save(checkpoint_prefix)

        # Restore into a differently-scoped template. restore() is called
        # before the template builds its variables; the values are checked
        # only after run_restore_ops().
        load_template = tf.compat.v1.make_template("s2", _templated)
        load_optimizer = adam.Adam(0.0)
        load_root = tf.train.Checkpoint(
            my_template=load_template, optimizer=load_optimizer
        )
        status = load_root.restore(save_path)
        var, var_plus_one, var2, _, _ = load_template()
        load_optimizer.minimize(var.read_value, var_list=[var])
        children = load_template._trackable_children()
        self.assertEqual({"v", "v2", "ManualScope"}, children.keys())
        status.assert_consumed().run_restore_ops()
        self.assertAllEqual([12.0], self.evaluate(var))
        self.assertAllEqual([13.0], self.evaluate(var_plus_one))
        self.assertAllEqual([14.0], self.evaluate(var2))