def testHashing(self):
    """Lists are distinct set members; a _ListWrapper cannot be hashed."""
    seqs = {data_structures.List() for _ in range(2)}
    # Two empty Lists must not collapse into a single set member.
    self.assertEqual(2, len(seqs))
    # A freshly created List is not considered a member either.
    self.assertNotIn(data_structures.List(), seqs)
    # Adding a _ListWrapper to a set is expected to fail with TypeError.
    with self.assertRaises(TypeError):
        seqs.add(data_structures._ListWrapper([]))
def testNesting(self):
    """Variables of a layer in an inner List are reported by the outer List."""
    with context.graph_mode():
        inner_list = data_structures.List()
        outer_list = data_structures.List([inner_list])
        # The layer is appended after the inner List is already nested.
        inner_list.append(non_keras_core.Dense(1))
        # Calling the layer builds its variables.
        inner_list[0](array_ops.ones([2, 3]))
        outer_vars = outer_list.variables
        self.assertEqual(2, len(outer_vars))
        self.assertIsInstance(outer_vars[0],
                              resource_variable_ops.ResourceVariable)
def __init__(self):
    """Track layers inside a checkpointable data_structures.Mapping.

    The mapping holds a Dense layer under "output" and two Lists of layers
    under "norm" and "dense".
    """
    super(HasMapping, self).__init__()
    self.layer_dict = data_structures.Mapping(output=core.Dense(7))
    # Create the two (initially empty) Lists, then populate them below.
    for key in ("norm", "dense"):
        self.layer_dict[key] = data_structures.List()
    dense_layers = [
        core.Dense(5),
        core.Dense(6, kernel_regularizer=math_ops.reduce_sum),
    ]
    self.layer_dict["dense"].extend(dense_layers)
    for _ in range(2):
        self.layer_dict["norm"].append(normalization.BatchNormalization())
def testNoOverwrite(self):
    """Mapping rejects overwriting or deleting an existing entry."""
    m = data_structures.Mapping()
    kept = data_structures.List()
    m["a"] = kept
    # Reassigning an existing key fails and leaves the old value in place.
    with self.assertRaises(ValueError):
        m["a"] = data_structures.List()
    self.assertIs(kept, m["a"])
    # Deleting a key is not supported.
    with self.assertRaises(AttributeError):
        del m["a"]
    # update() may add a new key, but a second update of that key fails.
    m.update(b=data_structures.Mapping())
    with self.assertRaises(ValueError):
        m.update({"b": data_structures.Mapping()})
def testListWrapperBasic(self):
    """_ListWrapper compares and concatenates like a built-in list.

    Unlike List, _ListWrapper is used to automatically replace plain lists,
    so it must interoperate with them in comparisons and concatenation.
    """
    wrap = data_structures._ListWrapper
    first = tracking.AutoCheckpointable()
    second = tracking.AutoCheckpointable()

    # Equality: wrapper/wrapper and wrapper/list, in both operand orders.
    self.assertEqual([first, first], [first, first])
    self.assertEqual(wrap([first, first]), wrap([first, first]))
    self.assertEqual([first, first], wrap([first, first]))
    self.assertEqual(wrap([first, first]), [first, first])

    # Inequality when an element differs.
    self.assertNotEqual([first, first], [second, first])
    self.assertNotEqual(wrap([first, first]), wrap([second, first]))
    self.assertNotEqual([first, first], wrap([second, first]))

    # Ordering follows list semantics (a proper prefix sorts first).
    self.assertLess([first], [first, second])
    self.assertLess(wrap([first]), wrap([first, second]))
    self.assertLessEqual([first], [first, second])
    self.assertLessEqual(wrap([first]), wrap([first, second]))
    self.assertGreater([first, second], [first])
    self.assertGreater(wrap([first, second]), wrap([first]))
    self.assertGreaterEqual([first, second], [first])
    self.assertGreaterEqual(wrap([first, second]), wrap([first]))

    # Conversion and concatenation with plain lists.
    self.assertEqual([first], wrap([first]))
    self.assertEqual([first], list(data_structures.List([first])))
    self.assertEqual([first, first], wrap([first]) + [first])
    self.assertEqual([first, first], [first] + wrap([first]))
    self.assertIsInstance(wrap([first]), list)
def testDictWrapperBadKeys(self):
    """save_weights raises ValueError when a tracked dict has a non-string key."""
    holder = tracking.Checkpointable()
    holder.d = {}
    # Access through the attribute, since assignment may have wrapped the dict.
    holder.d[1] = data_structures.List()
    model = training.Model()
    model.sub = holder
    ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
    with self.assertRaisesRegexp(ValueError, "non-string key"):
        model.save_weights(ckpt_prefix)
def testSlicing(self):
    """Slices of a List compare equal to the matching plain-list slices."""
    v1, v2, v3, v4 = [
        resource_variable_ops.ResourceVariable(1.) for _ in range(4)
    ]
    seq = data_structures.List([v1, v2, v3, v4])
    self.assertEqual(seq[1:], [v2, v3, v4])
    self.assertEqual(seq[1:-1], [v2, v3])
    self.assertEqual(seq[:-1], [v1, v2, v3])
def __init__(self):
    """Track layers inside checkpointable data_structures.List containers.

    Adds layers to ``layer_list`` in every supported way: construction,
    ``append``, ``extend`` with a plain list, ``+=`` with a plain list,
    ``+=`` with the sum of two Lists, and ``extend`` with a List built from a
    concatenated plain list. ``layers_with_updates`` holds a single
    BatchNormalization layer.
    """
    super(HasList, self).__init__()
    self.layer_list = data_structures.List([core.Dense(3)])
    self.layer_list.append(core.Dense(4))
    self.layer_list.extend(
        [core.Dense(5), core.Dense(6, kernel_regularizer=math_ops.reduce_sum)])
    self.layer_list += [
        core.Dense(7, bias_regularizer=math_ops.reduce_sum),
        core.Dense(8)
    ]
    self.layer_list += (
        data_structures.List([core.Dense(9)]) + data_structures.List(
            [core.Dense(10)]))
    # The built-in list() must get its iterable positionally: the Python 2
    # ``sequence=`` keyword raises TypeError on Python 3.
    self.layer_list.extend(
        data_structures.List(
            list([core.Dense(11)]) + [core.Dense(12)]))
    self.layers_with_updates = data_structures.List(
        sequence=(normalization.BatchNormalization(),))
def testCopy(self):
    """Appending to a copy() of a List does not modify the source List."""
    var_a, var_b, var_c = (
        resource_variable_ops.ResourceVariable(1.) for _ in range(3))
    source = data_structures.List([var_a, var_b])
    duplicate = source.copy()
    duplicate.append(var_c)
    self.assertEqual(list(source), [var_a, var_b])
    self.assertEqual(list(duplicate), [var_a, var_b, var_c])
def testNonLayerVariables(self):
    """A List reports raw (non-layer) variables and honours its trainable flag."""
    trainable_var = resource_variable_ops.ResourceVariable([1.])
    container = data_structures.List([trainable_var])
    self.assertTrue(container.trainable)
    self.assertEqual([], container.layers)
    self.assertEqual([trainable_var], container.variables)
    self.assertEqual([trainable_var], container.trainable_weights)
    self.assertEqual([], container.non_trainable_variables)

    # With trainable=False, every variable is reported as non-trainable.
    container.trainable = False
    self.assertEqual([trainable_var], container.variables)
    self.assertEqual([], container.trainable_variables)
    self.assertEqual([trainable_var], container.non_trainable_variables)

    # Re-enable training and add a variable that is itself non-trainable.
    container.trainable = True
    frozen_var = resource_variable_ops.ResourceVariable(1., trainable=False)
    container.append(frozen_var)
    self.assertEqual([trainable_var, frozen_var], container.weights)
    self.assertEqual([trainable_var], container.trainable_weights)
    self.assertEqual([frozen_var], container.non_trainable_weights)
def testNotCheckpointable(self):
    """Constructing a List from a non-checkpointable object raises ValueError."""

    class NotCheckpointable(object):
        pass

    plain_object = NotCheckpointable()
    with self.assertRaises(ValueError):
        data_structures.List([plain_object])
def testHash(self):
    """Distinct empty Lists stay separate set members."""
    members = set()
    members.add(data_structures.List())
    members.add(data_structures.List())
    self.assertEqual(2, len(members))
    # A new, equally empty List is not found in the set.
    self.assertNotIn(data_structures.List(), members)
def testNonStringKeys(self):
    """Assigning under an integer key in a Mapping raises TypeError."""
    m = data_structures.Mapping()
    value = data_structures.List()
    with self.assertRaises(TypeError):
        m[1] = value
def testIMul_zero(self):
    """In-place multiplication of a List by zero raises ValueError."""
    seq = data_structures.List([])
    with self.assertRaisesRegexp(ValueError, "List only supports append"):
        seq *= 0
def testIMul(self):
    """``*= 2`` on a List repeats its contents, as with a plain list."""
    variable = resource_variable_ops.ResourceVariable(1.)
    seq = data_structures.List([variable])
    seq *= 2
    self.assertEqual(list(seq), [variable, variable])
def testNoPop(self):
    """List exposes no pop(); calling it raises AttributeError."""
    empty = data_structures.List()
    with self.assertRaises(AttributeError):
        empty.pop()
def testCallNotImplemented(self):
    """Calling a List raises a TypeError whose message says "not callable"."""
    seq = data_structures.List()
    with self.assertRaisesRegexp(TypeError, "not callable"):
        seq(1.)
def testRMul(self):
    """``2 * List`` repeats the contents like ``2 * list``."""
    variable = resource_variable_ops.ResourceVariable(1.)
    seq = data_structures.List([variable] * 3)
    self.assertEqual(list(2 * seq), [variable] * 6)