def testListAddOrder(self):
  # Concatenation should give a list-equal result regardless of which
  # operand (or both) is a _ListWrapper.
  wrap = data_structures._ListWrapper
  self.assertEqual([1., 2.], wrap([1.]) + wrap([2.]))
  self.assertEqual([1., 2.], wrap([1.]) + [2.])
  self.assertEqual([1., 2.], [1.] + wrap([2.]))
def testSetSlice_cannotSaveIfTrackableModified(self):
  # Replacing a slice of a wrapper that holds trackable variables must leave
  # the wrapper unsavable.
  first_var = resource_variable_ops.ResourceVariable(1.)
  second_var = resource_variable_ops.ResourceVariable(1.)
  wrapper = data_structures._ListWrapper([1, 2, first_var, second_var])
  replacement = (2, 8, 9, second_var)
  wrapper[:] = replacement
  self.assertEqual(wrapper, [2, 8, 9, second_var])
  self.assertUnableToSave(wrapper, "Unable to save .*__setslice__")
def testFunctionCaching(self):
  # Passing a _ListWrapper should reuse the concrete function traced for a
  # plain list of the same structure.

  @def_function.function
  def f(list_input):
    return list_input[0] + constant_op.constant(1.)

  plain_list_trace = f.get_concrete_function([constant_op.constant(2.)])
  wrapped_arg = data_structures._ListWrapper([constant_op.constant(3.)])
  wrapped_list_trace = f.get_concrete_function(wrapped_arg)
  self.assertIs(plain_list_trace, wrapped_list_trace)
def testSort(self):
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  wrapper.sort()
  self.assertEqual(wrapper, [1, 2, 3, 4])
  # The input was already sorted, so sort() changed nothing, yet saving is
  # still refused on purpose: otherwise sort() on a _ListWrapper would be
  # trackable only some of the time, which is hard for users to debug.
  self.assertUnableToSave(wrapper, "Unable to save .*sort")
def testIMulPositive(self):
  # In-place repetition by a positive factor tracks the repeated variable
  # under each new index, and a checkpoint saved beforehand still restores.
  var = variables.Variable(1.)
  wrapper = data_structures._ListWrapper([1, 2, 3, 4, var])
  self.assertEqual([("4", var)], wrapper._checkpoint_dependencies)
  checkpoint = util.Checkpoint(l=wrapper)
  save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
  # Change the value after saving so the restore below is observable.
  var.assign(5.)
  wrapper *= 2
  self.assertEqual(wrapper, [1, 2, 3, 4, var, 1, 2, 3, 4, var])
  self.assertEqual([("4", var), ("9", var)],
                   wrapper._checkpoint_dependencies)
  checkpoint.restore(save_path)
  self.assertAllClose(1., var.numpy())
def testDelItem(self):
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  del wrapper[0]
  self.assertEqual(wrapper, [2, 3, 4])
  # After deleting an element, saving is expected to be refused.
  self.assertUnableToSave(wrapper, "Unable to save .*__delitem__")
def testLayerCollectionWithExternalMutation(self):
  # A layer appended directly to the wrapped list (bypassing the wrapper)
  # should still appear in the wrapper's `layers`.
  backing = []
  wrapper = data_structures._ListWrapper(backing)
  dense = core.Dense(1)
  backing.append(dense)
  self.assertEqual([dense], wrapper.layers)
def testSameStructure(self):
  # `nest` should treat a wrapper around a copy of a list as having the same
  # structure as the list itself.
  plain = [1]
  wrapped_copy = data_structures._ListWrapper(copy.copy(plain))
  nest.assert_same_structure(plain, wrapped_copy)
def testListWrapperBasic(self):
  # _ListWrapper, unlike List, compares like the built-in list type (since it
  # is used to automatically replace lists).
  a = tracking.AutoTrackable()
  b = tracking.AutoTrackable()
  wrap = data_structures._ListWrapper

  # Equality, checked for every list/wrapper operand order.
  self.assertEqual([a, a], [a, a])
  self.assertEqual(wrap([a, a]), wrap([a, a]))
  self.assertEqual([a, a], wrap([a, a]))
  self.assertEqual(wrap([a, a]), [a, a])

  # Inequality, checked for the same four operand orders as equality. The
  # wrapper-on-the-left vs. plain-list case was previously missing.
  self.assertNotEqual([a, a], [b, a])
  self.assertNotEqual(wrap([a, a]), wrap([b, a]))
  self.assertNotEqual([a, a], wrap([b, a]))
  self.assertNotEqual(wrap([a, a]), [b, a])

  # Ordering comparisons; the plain-list checks are a baseline for the
  # wrapper checks that follow each one.
  self.assertLess([a], [a, b])
  self.assertLess(wrap([a]), wrap([a, b]))
  self.assertLessEqual([a], [a, b])
  self.assertLessEqual(wrap([a]), wrap([a, b]))
  self.assertGreater([a, b], [a])
  self.assertGreater(wrap([a, b]), wrap([a]))
  self.assertGreaterEqual([a, b], [a])
  self.assertGreaterEqual(wrap([a, b]), wrap([a]))

  # Conversion, concatenation, and list subclassing.
  self.assertEqual([a], wrap([a]))
  self.assertEqual([a], list(data_structures.List([a])))
  self.assertEqual([a, a], wrap([a]) + [a])
  self.assertEqual([a, a], [a] + wrap([a]))
  self.assertIsInstance(wrap([a]), list)
def testWrapperChangesList(self):
  # An append through the wrapper is visible in the wrapped list object.
  backing = []
  wrapper = data_structures._ListWrapper(backing)
  wrapper.append(1)
  self.assertEqual([1], backing)
def testSetSlice_truncate(self):
  # Assigning an empty sequence to the full slice empties the wrapper.
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  wrapper[:] = []
  self.assertEqual(wrapper, [])
def testIMulNegative(self):
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  wrapper *= -1
  # Should match built-in list semantics for a negative repeat count.
  self.assertEqual(wrapper, [1, 2, 3, 4] * -1)
  self.assertUnableToSave(wrapper, "Unable to save")
def testSetSlice_extend(self):
  # Assigning a longer sequence to a tail slice grows the wrapper.
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  new_tail = (1, 2, 3, 4)
  wrapper[2:] = new_tail
  self.assertEqual(wrapper, [1, 2, 1, 2, 3, 4])
def testSetSlice_canSaveForNonTrackableItems(self):
  # Slice assignment with plain Python values only should leave the wrapper
  # with no checkpoint dependencies.
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  wrapper[:] = 2, 8, 9, 0
  self.assertEqual(wrapper, [2, 8, 9, 0])
  wrapper._maybe_initialize_trackable()  # pylint: disable=protected-access
  deps = wrapper._checkpoint_dependencies  # pylint: disable=protected-access
  self.assertEqual(len(deps), 0)
def testDelSlice(self):
  wrapper = data_structures._ListWrapper([1, 2, 3, 4])
  del wrapper[2:3]
  self.assertEqual(wrapper, [1, 2, 4])
  # After deleting a slice, saving is expected to be refused.
  self.assertUnableToSave(wrapper, "Unable to save .*__delslice__")
def testNotHashable(self):
  # Like the built-in list, a _ListWrapper must reject hashing.
  with self.assertRaises(TypeError):
    empty_wrapper = data_structures._ListWrapper()
    hash(empty_wrapper)
def testAcceptsNonTrackableContent(self):
  # Plain Python values are accepted and compare equal to an equivalent list.
  contents = [1, 2, 3]
  wrapper = data_structures._ListWrapper(contents)
  self.assertEqual(wrapper, [1, 2, 3])
def testPickle(self):
  # A wrapper should survive a pickle round trip. The original is deleted
  # before loading, presumably so the result cannot depend on it.
  original = data_structures._ListWrapper([1, 2])
  payload = pickle.dumps(original)
  del original
  restored = pickle.loads(payload)
  self.assertEqual([1, 2], restored)
def testListChangesWrapper(self):
  # An append made directly to the wrapped list is visible through the
  # wrapper.
  backing = []
  wrapper = data_structures._ListWrapper(backing)
  backing.append(1)
  self.assertEqual([1], wrapper)