def testHashing(self):
  """Checks set membership of List and hashing of _ListWrapper.

  Two distinct empty List objects stay separate set members, and a fresh
  empty List is not found in that set. Adding a _ListWrapper to a set must
  raise TypeError.
  """
  has_sequences = {data_structures.List(), data_structures.List()}
  self.assertEqual(2, len(has_sequences))
  self.assertNotIn(data_structures.List(), has_sequences)
  # Build the wrapper before entering assertRaises so that only the set
  # insertion (which hashes the wrapper) can satisfy the expected TypeError.
  wrapper = data_structures._ListWrapper([])
  with self.assertRaises(TypeError):
    has_sequences.add(wrapper)
def testHashing(self):
  """Checks set membership of List and hashing of _ListWrapper.

  Two distinct empty List objects stay separate set members, and a fresh
  empty List is not found in that set. Adding a _ListWrapper to a set must
  raise TypeError.
  """
  has_sequences = {data_structures.List(), data_structures.List()}
  self.assertEqual(2, len(has_sequences))
  self.assertNotIn(data_structures.List(), has_sequences)
  # Build the wrapper before entering assertRaises so that only the set
  # insertion (which hashes the wrapper) can satisfy the expected TypeError.
  wrapper = data_structures._ListWrapper([])
  with self.assertRaises(TypeError):
    has_sequences.add(wrapper)
def testSetSlice_cannotSaveIfCheckpointableModified(self):
  """Slice assignment that removes a variable makes the wrapper unsaveable."""
  first_var = resource_variable_ops.ResourceVariable(1.)
  second_var = resource_variable_ops.ResourceVariable(1.)
  wrapper = data_structures._ListWrapper([1, 2, first_var, second_var])
  # Replace every element: first_var is dropped, second_var is kept.
  wrapper[:] = (2, 8, 9, second_var)
  self.assertEqual(wrapper, [2, 8, 9, second_var])
  self.assertUnableToSave(wrapper, "Unable to save .*__setslice__")
def testSetSlice_cannotSaveIfCheckpointableModified(self):
  """Slice assignment that removes a variable makes the wrapper unsaveable."""
  first_var = resource_variable_ops.ResourceVariable(1.)
  second_var = resource_variable_ops.ResourceVariable(1.)
  wrapper = data_structures._ListWrapper([1, 2, first_var, second_var])
  # Replace every element: first_var is dropped, second_var is kept.
  wrapper[:] = (2, 8, 9, second_var)
  self.assertEqual(wrapper, [2, 8, 9, second_var])
  self.assertUnableToSave(wrapper, "Unable to save .*__setslice__")
def testSort(self):
  """In-place sort makes the wrapper unsaveable, even when it is a no-op."""
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  wrapper.sort()
  self.assertEqual(wrapper, [1, 2, 3, 4])
  # The input was already sorted, yet saving is still refused. This is
  # deliberate: if sort() on a wrapper were sometimes checkpointable and
  # sometimes not, users would face a hard-to-debug inconsistency.
  self.assertUnableToSave(wrapper, "Unable to save .*sort")
def testSort(self):
  """In-place sort makes the wrapper unsaveable, even when it is a no-op."""
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  wrapper.sort()
  self.assertEqual(wrapper, [1, 2, 3, 4])
  # The input was already sorted, yet saving is still refused. This is
  # deliberate: if sort() on a wrapper were sometimes checkpointable and
  # sometimes not, users would face a hard-to-debug inconsistency.
  self.assertUnableToSave(wrapper, "Unable to save .*sort")
def testSetSlice_canSaveForNonCheckpointableItems(self):
  """Slice-assigning only plain values leaves no checkpoint dependencies."""
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  wrapper[:] = (2, 8, 9, 0)
  self.assertEqual(wrapper, [2, 8, 9, 0])
  wrapper._maybe_initialize_checkpointable()  # pylint: disable=protected-access
  deps = wrapper._checkpoint_dependencies  # pylint: disable=protected-access
  self.assertEqual(len(deps), 0)
def testDelSlice(self):
  """Deleting a slice edits the list, and the wrapper then refuses to save."""
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  del wrapper[2:3]
  expected = [1, 2, 4]
  self.assertEqual(wrapper, expected)
  self.assertUnableToSave(wrapper, "Unable to save .*__delslice__")
def testNotHashable(self):
  """Hashing a _ListWrapper must raise TypeError.

  The wrapper is built outside assertRaises (with an explicit list, as at
  every other call site). Otherwise a TypeError raised while constructing it,
  e.g. from a missing constructor argument, would make this test pass without
  hash() ever being called.
  """
  wrapper = data_structures._ListWrapper([])
  with self.assertRaises(TypeError):
    hash(wrapper)
def testDelSlice(self):
  """Deleting a slice edits the list, and the wrapper then refuses to save."""
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  del wrapper[2:3]
  expected = [1, 2, 4]
  self.assertEqual(wrapper, expected)
  self.assertUnableToSave(wrapper, "Unable to save .*__delslice__")
def testListChangesWrapper(self):
  """Mutating the wrapped list is visible through the wrapper."""
  backing = []
  wrapper = data_structures._ListWrapper(backing)
  backing.append(1)
  self.assertEqual([1], wrapper)
def testListWrapperBasic(self):
  """_ListWrapper compares like the built-in list type.

  Unlike List, _ListWrapper is used to automatically replace lists, so its
  equality, ordering, concatenation and isinstance behavior must match a
  plain list's.
  """
  wrap = data_structures._ListWrapper
  a = tracking.Checkpointable()
  b = tracking.Checkpointable()
  # Equality, with plain lists as the reference behavior.
  self.assertEqual([a, a], [a, a])
  self.assertEqual(wrap([a, a]), wrap([a, a]))
  self.assertEqual([a, a], wrap([a, a]))
  self.assertEqual(wrap([a, a]), [a, a])
  # Inequality.
  self.assertNotEqual([a, a], [b, a])
  self.assertNotEqual(wrap([a, a]), wrap([b, a]))
  self.assertNotEqual([a, a], wrap([b, a]))
  # Ordering.
  self.assertLess([a], [a, b])
  self.assertLess(wrap([a]), wrap([a, b]))
  self.assertLessEqual([a], [a, b])
  self.assertLessEqual(wrap([a]), wrap([a, b]))
  self.assertGreater([a, b], [a])
  self.assertGreater(wrap([a, b]), wrap([a]))
  self.assertGreaterEqual([a, b], [a])
  self.assertGreaterEqual(wrap([a, b]), wrap([a]))
  # Conversion, concatenation and type.
  self.assertEqual([a], wrap([a]))
  self.assertEqual([a], list(data_structures.List([a])))
  self.assertEqual([a, a], wrap([a]) + [a])
  self.assertEqual([a, a], [a] + wrap([a]))
  self.assertIsInstance(wrap([a]), list)
def testSetSlice_extend(self):
  """Assigning a longer sequence to a tail slice grows the list."""
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  wrapper[2:] = (1, 2, 3, 4)
  expected = [1, 2, 1, 2, 3, 4]
  self.assertEqual(wrapper, expected)
def testLayerCollectionWithExternalMutation(self):
  """A layer appended to the wrapped list appears in the wrapper's layers."""
  backing = []
  wrapper = data_structures._ListWrapper(backing)
  dense = core.Dense(1)
  backing.append(dense)
  self.assertEqual([dense], wrapper.layers)
def testSetSlice_extend(self):
  """Assigning a longer sequence to a tail slice grows the list."""
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  wrapper[2:] = (1, 2, 3, 4)
  expected = [1, 2, 1, 2, 3, 4]
  self.assertEqual(wrapper, expected)
def testSetSlice_truncate(self):
  """Assigning an empty list to the full slice empties the wrapper."""
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  wrapper[:] = []
  self.assertEqual(wrapper, [])
def testSetSlice_canSaveForNonCheckpointableItems(self):
  """Slice-assigning only plain values leaves no checkpoint dependencies."""
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  wrapper[:] = (2, 8, 9, 0)
  self.assertEqual(wrapper, [2, 8, 9, 0])
  wrapper._maybe_initialize_checkpointable()  # pylint: disable=protected-access
  deps = wrapper._checkpoint_dependencies  # pylint: disable=protected-access
  self.assertEqual(len(deps), 0)
def testSetSlice_truncate(self):
  """Assigning an empty list to the full slice empties the wrapper."""
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  wrapper[:] = []
  self.assertEqual(wrapper, [])
def testSlicing(self):
  """Reading slices of a wrapper yields the same elements as a plain list."""
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  self.assertEqual(wrapper[1:], [2, 3, 4])  # Drop the first element.
  self.assertEqual(wrapper[1:-1], [2, 3])  # Drop both ends.
  self.assertEqual(wrapper[:-1], [1, 2, 3])  # Drop the last element.
def testAcceptsNonCheckpointableContent(self):
  """Plain (non-checkpointable) values can be wrapped and compared."""
  contents = [1, 2, 3]
  wrapper = data_structures._ListWrapper(contents)
  self.assertEqual(wrapper, [1, 2, 3])
def testListWrapperBasic(self):
  """_ListWrapper compares like the built-in list type.

  Unlike List, _ListWrapper is used to automatically replace lists, so its
  equality, ordering, concatenation and isinstance behavior must match a
  plain list's.
  """
  wrap = data_structures._ListWrapper
  a = tracking.Checkpointable()
  b = tracking.Checkpointable()
  # Equality, with plain lists as the reference behavior.
  self.assertEqual([a, a], [a, a])
  self.assertEqual(wrap([a, a]), wrap([a, a]))
  self.assertEqual([a, a], wrap([a, a]))
  self.assertEqual(wrap([a, a]), [a, a])
  # Inequality.
  self.assertNotEqual([a, a], [b, a])
  self.assertNotEqual(wrap([a, a]), wrap([b, a]))
  self.assertNotEqual([a, a], wrap([b, a]))
  # Ordering.
  self.assertLess([a], [a, b])
  self.assertLess(wrap([a]), wrap([a, b]))
  self.assertLessEqual([a], [a, b])
  self.assertLessEqual(wrap([a]), wrap([a, b]))
  self.assertGreater([a, b], [a])
  self.assertGreater(wrap([a, b]), wrap([a]))
  self.assertGreaterEqual([a, b], [a])
  self.assertGreaterEqual(wrap([a, b]), wrap([a]))
  # Conversion, concatenation and type.
  self.assertEqual([a], wrap([a]))
  self.assertEqual([a], list(data_structures.List([a])))
  self.assertEqual([a, a], wrap([a]) + [a])
  self.assertEqual([a, a], [a] + wrap([a]))
  self.assertIsInstance(wrap([a]), list)
def testWrapperChangesList(self):
  """Appending through the wrapper mutates the wrapped list."""
  backing = []
  wrapper = data_structures._ListWrapper(backing)
  wrapper.append(1)
  self.assertEqual([1], backing)
def testListChangesWrapper(self):
  """Mutating the wrapped list is visible through the wrapper."""
  backing = []
  wrapper = data_structures._ListWrapper(backing)
  backing.append(1)
  self.assertEqual([1], wrapper)
def testLayerCollectionWithExternalMutation(self):
  """A layer appended to the wrapped list appears in the wrapper's layers."""
  backing = []
  wrapper = data_structures._ListWrapper(backing)
  dense = core.Dense(1)
  backing.append(dense)
  self.assertEqual([dense], wrapper.layers)
def testDelItem(self):
  """Deleting an element edits the list, and the wrapper then refuses to save."""
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  del wrapper[0]
  expected = [2, 3, 4]
  self.assertEqual(wrapper, expected)
  self.assertUnableToSave(wrapper, "Unable to save .*__delitem__")
def testWrapperChangesList(self):
  """Appending through the wrapper mutates the wrapped list."""
  backing = []
  wrapper = data_structures._ListWrapper(backing)
  wrapper.append(1)
  self.assertEqual([1], backing)
def testNotHashable(self):
  """Hashing a _ListWrapper must raise TypeError.

  The wrapper is built outside assertRaises (with an explicit list, as at
  every other call site). Otherwise a TypeError raised while constructing it,
  e.g. from a missing constructor argument, would make this test pass without
  hash() ever being called.
  """
  wrapper = data_structures._ListWrapper([])
  with self.assertRaises(TypeError):
    hash(wrapper)
def testDelItem(self):
  """Deleting an element edits the list, and the wrapper then refuses to save."""
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  del wrapper[0]
  expected = [2, 3, 4]
  self.assertEqual(wrapper, expected)
  self.assertUnableToSave(wrapper, "Unable to save .*__delitem__")
def testAcceptsNonCheckpointableContent(self):
  """Plain (non-checkpointable) values can be wrapped and compared."""
  contents = [1, 2, 3]
  wrapper = data_structures._ListWrapper(contents)
  self.assertEqual(wrapper, [1, 2, 3])