def testNesting(self):
    """A layer appended to a List nested in another List is seen by the outer.

    The layer is added only after the inner List has been placed inside the
    outer one, so tracking must propagate through the nesting.
    """
    with context.graph_mode():
        child = data_structures.List()
        parent = data_structures.List([child])
        child.append(non_keras_core.Dense(1))
        # Invoking the layer on an input creates its variables (the test
        # expects two of them, presumably kernel and bias).
        child[0](array_ops.ones([2, 3]))
        self.assertEqual(2, len(parent.variables))
        self.assertIsInstance(parent.variables[0],
                              resource_variable_ops.ResourceVariable)
def testNoOverwrite(self):
    """Mapping entries cannot be reassigned, deleted, or overwritten by update()."""
    tracked = data_structures.Mapping()
    first = data_structures.List()
    tracked["a"] = first
    # Reassigning an existing key is rejected and the old value survives.
    with self.assertRaises(ValueError):
        tracked["a"] = data_structures.List()
    self.assertIs(first, tracked["a"])
    # Deleting a key is not supported.
    with self.assertRaises(AttributeError):
        del tracked["a"]  # pylint: disable=unsupported-delete-operation
    # update() may introduce a new key but may not overwrite an existing one.
    tracked.update(b=data_structures.Mapping())
    with self.assertRaises(ValueError):
        tracked.update({"b": data_structures.Mapping()})
def testListWrapperBasic(self):
    """ListWrapper behaves like a built-in list in comparisons and concatenation.

    Unlike List, ListWrapper is used to automatically replace plain lists, so
    it must compare, order, and concatenate like ``list`` in both operand
    positions, and must be an instance of ``list``.
    """
    wrap = data_structures.ListWrapper
    first = tracking.AutoTrackable()
    second = tracking.AutoTrackable()

    # Equality, with wrapper on either side.
    self.assertEqual([first, first], [first, first])
    self.assertEqual(wrap([first, first]), wrap([first, first]))
    self.assertEqual([first, first], wrap([first, first]))
    self.assertEqual(wrap([first, first]), [first, first])

    # Inequality.
    self.assertNotEqual([first, first], [second, first])
    self.assertNotEqual(wrap([first, first]), wrap([second, first]))
    self.assertNotEqual([first, first], wrap([second, first]))

    # Ordering (prefix comparisons, as for list).
    self.assertLess([first], [first, second])
    self.assertLess(wrap([first]), wrap([first, second]))
    self.assertLessEqual([first], [first, second])
    self.assertLessEqual(wrap([first]), wrap([first, second]))
    self.assertGreater([first, second], [first])
    self.assertGreater(wrap([first, second]), wrap([first]))
    self.assertGreaterEqual([first, second], [first])
    self.assertGreaterEqual(wrap([first, second]), wrap([first]))

    # Interoperation with plain lists and with List.
    self.assertEqual([first], wrap([first]))
    self.assertEqual([first], list(data_structures.List([first])))
    self.assertEqual([first, first], wrap([first]) + [first])
    self.assertEqual([first, first], [first] + wrap([first]))
    self.assertIsInstance(wrap([first]), list)

    # Adding a non-list sequence (TensorShape) on the right.
    expected = tensor_shape.TensorShape([None, 2]).as_list()
    combined = wrap([None]) + tensor_shape.TensorShape([2])
    self.assertEqual(expected, combined.as_list())
def testCopy(self):
    """copy() yields an independent List: appending to it leaves the source."""
    v1, v2, v3 = [
        resource_variable_ops.ResourceVariable(1.) for _ in range(3)]
    source = data_structures.List([v1, v2])
    duplicate = source.copy()
    duplicate.append(v3)
    self.assertEqual(list(source), [v1, v2])
    self.assertEqual(list(duplicate), [v1, v2, v3])
def testSlicing(self):
    """Slicing a List returns the same elements a plain list slice would."""
    v1, v2, v3, v4 = [
        resource_variable_ops.ResourceVariable(1.) for _ in range(4)]
    seq = data_structures.List([v1, v2, v3, v4])
    self.assertEqual(seq[1:], [v2, v3, v4])
    self.assertEqual(seq[1:-1], [v2, v3])
    self.assertEqual(seq[:-1], [v1, v2, v3])
def testNonLayerVariables(self):
    """Bare variables held in a List are partitioned by trainability."""
    weight = resource_variable_ops.ResourceVariable([1.])
    tracked = data_structures.List([weight])

    # Default state: trainable, no layers, the variable is trainable.
    self.assertTrue(tracked.trainable)
    self.assertEqual([], tracked.layers)
    self.assertEqual([weight], tracked.variables)
    self.assertEqual([weight], tracked.trainable_weights)
    self.assertEqual([], tracked.non_trainable_variables)

    # Freezing the container reports every variable as non-trainable.
    tracked.trainable = False
    self.assertEqual([weight], tracked.variables)
    self.assertEqual([], tracked.trainable_variables)
    self.assertEqual([weight], tracked.non_trainable_variables)

    # After unfreezing, a variable created with trainable=False stays
    # non-trainable while the original one is trainable again.
    tracked.trainable = True
    frozen = resource_variable_ops.ResourceVariable(1., trainable=False)
    tracked.append(frozen)
    self.assertEqual([weight, frozen], tracked.weights)
    self.assertEqual([weight], tracked.trainable_weights)
    self.assertEqual([frozen], tracked.non_trainable_weights)
def testNoPop(self):
    """List has no pop(); attempting to call it raises AttributeError."""
    self.assertRaises(AttributeError, lambda: data_structures.List().pop())
def testCallNotImplemented(self):
    """Calling a List instance raises a TypeError saying it is not callable."""
    self.assertRaisesRegex(
        TypeError, "not callable",
        lambda: data_structures.List()(1.))  # pylint: disable=not-callable
def testNotTrackable(self):
    """Building a List from an object that is not trackable raises ValueError."""

    class Untrackable(object):
        pass

    self.assertRaises(ValueError, data_structures.List, [Untrackable()])
def testNonStringKeys(self):
    """Assigning under a non-string key in a Mapping raises TypeError."""
    container = data_structures.Mapping()

    def assign_int_key():
        container[1] = data_structures.List()

    self.assertRaises(TypeError, assign_int_key)
def testRMul(self):
    """int * List repeats the contents the way int * list does."""
    var = resource_variable_ops.ResourceVariable(1.)
    tripled = data_structures.List([var] * 3)
    product = 2 * tripled
    self.assertEqual(list(product), [var] * 6)
def testIMul(self):
    """In-place multiplication of a List by a positive int repeats its contents."""
    var = resource_variable_ops.ResourceVariable(1.)
    repeated = data_structures.List([var])
    repeated *= 2
    self.assertEqual(list(repeated), [var, var])
def testIMul_zero(self):
    """In-place multiplication by zero is rejected with a ValueError."""
    empty = data_structures.List([])
    # Multiplying by zero would remove elements, which List disallows.
    with self.assertRaisesRegex(ValueError, "List only supports append"):
        empty *= 0
def testHash(self):
    """Separate empty Lists are distinct set members, unlike plain lists."""
    members = {data_structures.List() for _ in range(2)}
    self.assertEqual(2, len(members))
    self.assertNotIn(data_structures.List(), members)