Example #1
0
 def testHashing(self):
   """Distinct List objects are separate set members; _ListWrapper is not.

   Two empty List() instances both survive in a set, and a third fresh one
   is not found in it. Adding a _ListWrapper to a set raises TypeError.
   """
   first = data_structures.List()
   second = data_structures.List()
   has_sequences = {first, second}
   self.assertEqual(2, len(has_sequences))
   fresh = data_structures.List()
   self.assertNotIn(fresh, has_sequences)
   with self.assertRaises(TypeError):
     # Constructed inside the context so any TypeError is attributed the same
     # way as in the original single-expression form.
     unhashable = data_structures._ListWrapper([])
     has_sequences.add(unhashable)
 def testHashing(self):
   """A set keeps every distinct List but rejects a _ListWrapper."""
   has_sequences = {data_structures.List(), data_structures.List()}
   self.assertEqual(2, len(has_sequences))
   probe = data_structures.List()
   self.assertNotIn(probe, has_sequences)
   with self.assertRaises(TypeError):
     has_sequences.add(data_structures._ListWrapper(list()))
 def testSetSlice_cannotSaveIfCheckpointableModified(self):
   """After slice-assigning over a list holding variables, saving is refused."""
   var_a = resource_variable_ops.ResourceVariable(1.)
   var_b = resource_variable_ops.ResourceVariable(1.)
   wrapped = data_structures._ListWrapper([1, 2, var_a, var_b])
   replacement = (2, 8, 9, var_b)
   wrapped[:] = replacement
   self.assertEqual(wrapped, [2, 8, 9, var_b])
   self.assertUnableToSave(wrapped, "Unable to save .*__setslice__")
 def testSetSlice_cannotSaveIfCheckpointableModified(self):
   """Overwriting a variable-holding list through ``[:]`` blocks saving."""
   first_var, second_var = [
       resource_variable_ops.ResourceVariable(1.) for _ in range(2)]
   lst = data_structures._ListWrapper([1, 2, first_var, second_var])
   lst[:] = (2, 8, 9, second_var)
   expected = [2, 8, 9, second_var]
   self.assertEqual(lst, expected)
   self.assertUnableToSave(lst, "Unable to save .*__setslice__")
 def testSort(self):
   """sort() on a _ListWrapper refuses saving even when order is unchanged."""
   values = data_structures._ListWrapper([1, 2, 3, 4])
   values.sort()
   self.assertEqual(values, [1, 2, 3, 4])
   # The input was already ordered, so sort() changed nothing, yet saving must
   # still fail. Making the refusal depend on whether sort() happened to move
   # anything would give users a confusing, hard-to-debug mix where sorting a
   # ListWrapper sometimes keeps it checkpointable and sometimes does not.
   self.assertUnableToSave(values, "Unable to save .*sort")
 def testSort(self):
   """A no-op sort still makes the _ListWrapper unsaveable."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   wrapper.sort()
   expected = [1, 2, 3, 4]
   self.assertEqual(wrapper, expected)
   # Deliberately strict: refusing to save after any sort() call, even one
   # that leaves the order intact, keeps the behavior predictable instead of
   # checkpointable only for some inputs.
   self.assertUnableToSave(wrapper, "Unable to save .*sort")
 def testSetSlice_canSaveForNonCheckpointableItems(self):
   """Slice-assigning plain ints leaves no checkpoint dependencies behind."""
   items = data_structures._ListWrapper([1, 2, 3, 4])
   items[:] = (2, 8, 9, 0)
   self.assertEqual(items, [2, 8, 9, 0])
   # pylint: disable=protected-access
   items._maybe_initialize_checkpointable()
   dependencies = items._checkpoint_dependencies
   # pylint: enable=protected-access
   self.assertEqual(len(dependencies), 0)
 def testDelSlice(self):
   """Deleting a slice edits the contents but makes the list unsaveable."""
   seq = data_structures._ListWrapper([1, 2, 3, 4])
   del seq[2:3]
   self.assertEqual(seq, [1, 2, 4])
   self.assertUnableToSave(seq, "Unable to save .*__delslice__")
 def testNotHashable(self):
   """hash() of an empty _ListWrapper raises TypeError."""
   with self.assertRaises(TypeError):
     wrapper = data_structures._ListWrapper()
     hash(wrapper)
 def testDelSlice(self):
   """Removing elements with ``del`` on a slice blocks saving."""
   wrapped = data_structures._ListWrapper([1, 2, 3, 4])
   del wrapped[2:3]
   remaining = [1, 2, 4]
   self.assertEqual(wrapped, remaining)
   self.assertUnableToSave(wrapped, "Unable to save .*__delslice__")
Example #11
0
 def testListChangesWrapper(self):
   """Appending to the wrapped list is visible through the wrapper."""
   underlying = []
   view = data_structures._ListWrapper(underlying)
   underlying.append(1)
   self.assertEqual([1], view)
Example #12
0
 def testListWrapperBasic(self):
   """_ListWrapper supports the same comparisons and concatenation as list."""
   # Unlike List, _ListWrapper compares exactly like the built-in list type,
   # because it is what plain lists get automatically replaced with.
   a = tracking.Checkpointable()
   b = tracking.Checkpointable()
   wrap = data_structures._ListWrapper

   # Equality: plain lists, wrapper vs wrapper, and mixed in both directions.
   self.assertEqual([a, a], [a, a])
   self.assertEqual(wrap([a, a]), wrap([a, a]))
   self.assertEqual([a, a], wrap([a, a]))
   self.assertEqual(wrap([a, a]), [a, a])

   # Inequality.
   self.assertNotEqual([a, a], [b, a])
   self.assertNotEqual(wrap([a, a]), wrap([b, a]))
   self.assertNotEqual([a, a], wrap([b, a]))

   # Ordering (prefix comparisons).
   self.assertLess([a], [a, b])
   self.assertLess(wrap([a]), wrap([a, b]))
   self.assertLessEqual([a], [a, b])
   self.assertLessEqual(wrap([a]), wrap([a, b]))
   self.assertGreater([a, b], [a])
   self.assertGreater(wrap([a, b]), wrap([a]))
   self.assertGreaterEqual([a, b], [a])
   self.assertGreaterEqual(wrap([a, b]), wrap([a]))

   # Conversion, concatenation from either side, and list subclassing.
   self.assertEqual([a], wrap([a]))
   self.assertEqual([a], list(data_structures.List([a])))
   self.assertEqual([a, a], wrap([a]) + [a])
   self.assertEqual([a, a], [a] + wrap([a]))
   self.assertIsInstance(wrap([a]), list)
 def testSetSlice_extend(self):
   """Assigning a longer sequence to a tail slice grows the list."""
   target = data_structures._ListWrapper([1, 2, 3, 4])
   target[2:] = (1, 2, 3, 4)
   self.assertEqual(target, [1, 2, 1, 2, 3, 4])
 def testLayerCollectionWithExternalMutation(self):
   """A layer appended to the wrapped list appears in the wrapper's layers."""
   backing = []
   wrapper = data_structures._ListWrapper(backing)
   dense = core.Dense(1)
   backing.append(dense)
   self.assertEqual([dense], wrapper.layers)
 def testSetSlice_extend(self):
   """Replacing ``[2:]`` with four values extends a four-item list to six."""
   grown = data_structures._ListWrapper([1, 2, 3, 4])
   tail = 1, 2, 3, 4
   grown[2:] = tail
   expected = [1, 2, 1, 2, 3, 4]
   self.assertEqual(grown, expected)
 def testSetSlice_truncate(self):
   """Assigning an empty list to the full slice empties the wrapper."""
   shrinking = data_structures._ListWrapper([1, 2, 3, 4])
   shrinking[:] = list()
   self.assertEqual(shrinking, list())
 def testSetSlice_canSaveForNonCheckpointableItems(self):
   """Replacing every item with ints adds no checkpoint dependencies."""
   numbers = data_structures._ListWrapper([1, 2, 3, 4])
   new_values = 2, 8, 9, 0
   numbers[:] = new_values
   self.assertEqual(numbers, [2, 8, 9, 0])
   numbers._maybe_initialize_checkpointable()  # pylint: disable=protected-access
   deps = numbers._checkpoint_dependencies  # pylint: disable=protected-access
   self.assertEqual(len(deps), 0)
 def testSetSlice_truncate(self):
   """Clearing via ``[:] = []`` leaves an empty wrapper."""
   emptied = data_structures._ListWrapper([1, 2, 3, 4])
   emptied[:] = []
   nothing = []
   self.assertEqual(emptied, nothing)
 def testSlicing(self):
   """Slicing a _ListWrapper yields the same elements as slicing a list."""
   wrapped = data_structures._ListWrapper([1, 2, 3, 4])
   # Slice syntax is kept literal (no slice() objects) so the same dunder
   # methods are exercised on every Python version.
   tail = wrapped[1:]
   self.assertEqual(tail, [2, 3, 4])
   middle = wrapped[1:-1]
   self.assertEqual(middle, [2, 3])
   head = wrapped[:-1]
   self.assertEqual(head, [1, 2, 3])
 def testAcceptsNonCheckpointableContent(self):
   """Plain ints can be wrapped and compare equal to the same list."""
   contents = data_structures._ListWrapper([1, 2, 3])
   self.assertEqual(contents, [1, 2, 3])
 def testListWrapperBasic(self):
   """Check _ListWrapper against plain lists with every rich comparison."""
   # List has different comparison behavior. _ListWrapper instead mirrors
   # the built-in list, since it is the automatic stand-in for lists.
   x = tracking.Checkpointable()
   y = tracking.Checkpointable()
   make = data_structures._ListWrapper

   # ==
   self.assertEqual([x, x], [x, x])
   self.assertEqual(make([x, x]), make([x, x]))
   self.assertEqual([x, x], make([x, x]))
   self.assertEqual(make([x, x]), [x, x])
   # !=
   self.assertNotEqual([x, x], [y, x])
   self.assertNotEqual(make([x, x]), make([y, x]))
   self.assertNotEqual([x, x], make([y, x]))
   # <, <=
   self.assertLess([x], [x, y])
   self.assertLess(make([x]), make([x, y]))
   self.assertLessEqual([x], [x, y])
   self.assertLessEqual(make([x]), make([x, y]))
   # >, >=
   self.assertGreater([x, y], [x])
   self.assertGreater(make([x, y]), make([x]))
   self.assertGreaterEqual([x, y], [x])
   self.assertGreaterEqual(make([x, y]), make([x]))
   # Conversion, concatenation and isinstance.
   self.assertEqual([x], make([x]))
   self.assertEqual([x], list(data_structures.List([x])))
   self.assertEqual([x, x], make([x]) + [x])
   self.assertEqual([x, x], [x] + make([x]))
   self.assertIsInstance(make([x]), list)
Example #22
0
 def testWrapperChangesList(self):
   """Appending through the wrapper mutates the list it wraps."""
   backing = []
   wrapper = data_structures._ListWrapper(backing)
   wrapper.append(1)
   self.assertEqual([1], backing)
 def testListChangesWrapper(self):
   """Mutations made to the original list show through its wrapper."""
   raw = []
   wrapped = data_structures._ListWrapper(raw)
   raw.append(1)
   expected = [1]
   self.assertEqual(expected, wrapped)
 def testLayerCollectionWithExternalMutation(self):
   """``layers`` reflects a Dense layer added directly to the wrapped list."""
   raw = []
   tracked = data_structures._ListWrapper(raw)
   layer = core.Dense(1)
   raw.append(layer)
   observed = tracked.layers
   self.assertEqual([layer], observed)
 def testDelItem(self):
   """Deleting a single index removes it but makes the list unsaveable."""
   elems = data_structures._ListWrapper([1, 2, 3, 4])
   del elems[0]
   self.assertEqual(elems, [2, 3, 4])
   self.assertUnableToSave(elems, "Unable to save .*__delitem__")
 def testWrapperChangesList(self):
   """The wrapped list sees items appended via the wrapper."""
   source = []
   proxy = data_structures._ListWrapper(source)
   proxy.append(1)
   expected = [1]
   self.assertEqual(expected, source)
 def testNotHashable(self):
   """_ListWrapper instances cannot be hashed."""
   with self.assertRaises(TypeError):
     hash(data_structures._ListWrapper())
 def testDelItem(self):
   """``del wrapper[0]`` drops the first element and blocks saving."""
   wrapped = data_structures._ListWrapper([1, 2, 3, 4])
   del wrapped[0]
   after = [2, 3, 4]
   self.assertEqual(wrapped, after)
   self.assertUnableToSave(wrapped, "Unable to save .*__delitem__")
 def testAcceptsNonCheckpointableContent(self):
   """Wrapping non-checkpointable ints works and preserves equality."""
   expected = [1, 2, 3]
   wrapped = data_structures._ListWrapper([1, 2, 3])
   self.assertEqual(wrapped, expected)