def testMutationDirtiesList(self):
  """Saving fails after an element is inserted at the front of a tracked list."""
  owner = autotrackable.AutoTrackable()
  original_item = autotrackable.AutoTrackable()
  owner.l = [original_item]
  inserted_item = autotrackable.AutoTrackable()
  # insert() shifts `original_item` away from index 0; save() must reject this.
  owner.l.insert(0, inserted_item)
  checkpoint = util.Checkpoint(a=owner)
  with self.assertRaisesRegex(ValueError, "A list element was replaced"):
    prefix = os.path.join(self.get_temp_dir(), "ckpt")
    checkpoint.save(prefix)
 def testOutOfBandEditDirtiesList(self):
   """Saving fails when the list is mutated through a reference held outside."""
   owner = autotrackable.AutoTrackable()
   first_item = autotrackable.AutoTrackable()
   external_list = [first_item]
   owner.l = external_list
   # Mutate the plain list object directly instead of going through `owner.l`.
   external_list.append(autotrackable.AutoTrackable())
   checkpoint = util.Checkpoint(a=owner)
   with self.assertRaisesRegex(ValueError, "The wrapped list was modified"):
     prefix = os.path.join(self.get_temp_dir(), "ckpt")
     checkpoint.save(prefix)
 def testListBasic(self):
   """Objects in a tracked list are reachable and sit under one "l" dependency."""
   owner = autotrackable.AutoTrackable()
   first_item = autotrackable.AutoTrackable()
   owner.l = [first_item]
   second_item = autotrackable.AutoTrackable()
   owner.l.append(second_item)
   reachable = util.list_objects(owner)
   for item in (first_item, second_item):
     self.assertIn(item, reachable)
   # Unpacking enforces that the owner has exactly one direct dependency.
   [list_dependency] = owner._checkpoint_dependencies
   self.assertEqual("l", list_dependency.name)
   for item in (first_item, second_item):
     self.assertIn(item, list_dependency.ref)
# Example No. 4
# 0
 def testListBasic(self):
     """Objects in a tracked list are reachable and stored under child "l"."""
     owner = autotrackable.AutoTrackable()
     first_item = autotrackable.AutoTrackable()
     owner.l = [first_item]
     second_item = autotrackable.AutoTrackable()
     owner.l.append(second_item)
     reachable = util.list_objects(owner)
     for item in (first_item, second_item):
         self.assertIn(item, reachable)
     # The list attribute must appear among the owner's named children.
     self.assertIn("l", owner._trackable_children())
     tracked_list = owner._trackable_children()["l"]
     for item in (first_item, second_item):
         self.assertIn(item, tracked_list)
# Example No. 5
# 0
 def testMultipleAssignment(self):
     """Name clashes error via _track_trackable but not via attribute assignment."""
     root = autotrackable.AutoTrackable()
     root.leaf = autotrackable.AutoTrackable()
     # Re-assigning the same object under the same name is accepted silently.
     root.leaf = root.leaf
     replacement = autotrackable.AutoTrackable()
     # Explicit tracking refuses to reuse an existing name by default...
     with self.assertRaisesRegex(ValueError, "already declared"):
         root._track_trackable(replacement, name="leaf")
     # No error; we're overriding __setattr__, so we can't really stop people
     # from doing this while maintaining backward compatibility.
     root.leaf = replacement
     # ...but accepts it when overwrite=True is passed.
     root._track_trackable(replacement, name="leaf", overwrite=True)
     looked_up = root._lookup_dependency("leaf")
     self.assertIs(replacement, looked_up)
     self.assertIs(replacement, root._trackable_children()["leaf"])
 def testRemoveDependency(self):
   """Deleting a tracked attribute removes it from both dependency lists."""
   root = autotrackable.AutoTrackable()

   def check_dependency_count(expected):
     # Both the general and the unconditional dependency lists must agree.
     self.assertEqual(expected, len(root._checkpoint_dependencies))
     self.assertEqual(expected,
                      len(root._unconditional_checkpoint_dependencies))

   root.a = autotrackable.AutoTrackable()
   check_dependency_count(1)
   self.assertIs(root.a, root._checkpoint_dependencies[0].ref)
   del root.a
   self.assertFalse(hasattr(root, "a"))
   check_dependency_count(0)
   # Re-adding the attribute should restore a single dependency.
   root.a = autotrackable.AutoTrackable()
   check_dependency_count(1)
   self.assertIs(root.a, root._checkpoint_dependencies[0].ref)
 def testAssertions(self):
   """Structures assigned to attributes compare equal to plain equivalents."""
   def zeros_structure():
     # Build a fresh plain structure each time so nothing is shared with
     # the object that was assigned to the tracked attribute.
     return {"k": [np.zeros([2, 2])]}

   holder = autotrackable.AutoTrackable()
   holder.l = zeros_structure()
   self.assertAllEqual(nest.flatten(zeros_structure()),
                       nest.flatten(holder.l))
   self.assertAllClose(zeros_structure(), holder.l)
   nest.map_structure(self.assertAllClose, holder.l, zeros_structure())
   holder.tensors = {"k": [array_ops.ones([2, 2]), array_ops.zeros([3, 3])]}
   expected_values = {"k": [np.ones([2, 2]), np.zeros([3, 3])]}
   self.assertAllClose(expected_values, self.evaluate(holder.tensors))
 def testNestedLists(self):
   """Objects in nested lists are tracked; replacing one blocks saving."""
   holder = autotrackable.AutoTrackable()

   def assert_reachable(*expected):
     # Collect everything reachable from `holder`, check each expected object
     # is present, and hand the collection back for further checks.
     found = util.list_objects(holder)
     for obj in expected:
       self.assertIn(obj, found)
     return found

   holder.l = []
   b = autotrackable.AutoTrackable()
   holder.l.append([b])
   c = autotrackable.AutoTrackable()
   holder.l[0].append(c)
   assert_reachable(b, c)

   # Plain values may be mixed into the list; they are not reported as deps.
   holder.l[0].append(1)
   d = autotrackable.AutoTrackable()
   holder.l[0].append(d)
   found = assert_reachable(d, b, c)
   self.assertNotIn(1, found)

   e = autotrackable.AutoTrackable()
   f = autotrackable.AutoTrackable()
   holder.l1 = [[], [e]]
   holder.l1[0].append(f)
   assert_reachable(e, f)

   checkpoint = util.Checkpoint(a=holder)

   def save():
     checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))

   save()
   # A NoDependency-wrapped list may be mutated without breaking save().
   holder.l[0].append(data_structures.NoDependency([]))
   holder.l[0][-1].append(5)
   save()
   # Dirtying the inner list means the root object is unsaveable.
   holder.l[0][1] = 2
   with self.assertRaisesRegex(ValueError, "A list element was replaced"):
     save()