Example #1
0
 def testNesting(self):
     """A layer added to an inner List after nesting is tracked by the outer List.

     The inner List is wrapped by the outer one while still empty; the Dense
     layer is appended and called afterwards (in graph mode), and its two
     variables must still surface through the outer List's `variables` as
     ResourceVariables.
     """
     with context.graph_mode():
         inner_list = data_structures.List()
         outer_list = data_structures.List([inner_list])
         inner_list.append(non_keras_core.Dense(1))
         # Calling the layer builds it, creating its variables.
         dense_layer = inner_list[0]
         dense_layer(array_ops.ones([2, 3]))
         self.assertEqual(2, len(outer_list.variables))
         self.assertIsInstance(
             outer_list.variables[0],
             resource_variable_ops.ResourceVariable,
         )
Example #2
0
 def __init__(self):
     """Build a Mapping of layers that itself holds Lists of layers.

     Resulting `layer_dict` entries:
       "output": a Dense(7) layer;
       "norm":   a List of two BatchNormalization layers;
       "dense":  a List of Dense(5) and Dense(6) (the latter with a
                 `tf.reduce_sum` kernel regularizer).
     Layers are created in the same order as the keys are populated below.
     """
     super(HasMapping, self).__init__()
     self.layer_dict = data_structures.Mapping(output=core.Dense(7))
     # Insert the (initially empty) sub-Lists first, then fill them.
     for key in ("norm", "dense"):
         self.layer_dict[key] = data_structures.List()
     dense_layers = self.layer_dict["dense"]
     dense_layers.extend([
         core.Dense(5),
         core.Dense(6, kernel_regularizer=tf.reduce_sum),
     ])
     norm_layers = self.layer_dict["norm"]
     for _ in range(2):
         norm_layers.append(normalization.BatchNormalization())
Example #3
0
 def testNoOverwrite(self):
     """Mapping refuses to overwrite or delete an existing key.

     Reassigning an existing key raises ValueError and keeps the original value.
     `del` raises AttributeError. `update` with a dict whose key is already
     present raises ValueError.
     """
     store = data_structures.Mapping()
     first_value = data_structures.List()
     store["a"] = first_value

     # Overwrite attempt fails and leaves the original object in place.
     with self.assertRaises(ValueError):
         store["a"] = data_structures.List()
     self.assertIs(first_value, store["a"])

     # Deletion is not supported.
     with self.assertRaises(AttributeError):
         del store["a"]

     # A fresh key may be added via update, but not replaced by a second one.
     store.update(b=data_structures.Mapping())
     with self.assertRaises(ValueError):
         store.update({"b": data_structures.Mapping()})
 def testNoOverwrite(self):
   """Existing Mapping entries cannot be replaced or removed.

   Item assignment to a present key raises ValueError and the original value
   survives. `del` raises AttributeError. `update` with a dict whose key is
   already present raises ValueError.
   """
   m = data_structures.Mapping()
   kept = data_structures.List()
   m["a"] = kept
   with self.assertRaises(ValueError):
     m["a"] = data_structures.List()
   self.assertIs(kept, m["a"])
   with self.assertRaises(AttributeError):
     del m["a"]  # pylint: disable=unsupported-delete-operation
   # Keyword-form update adds the new key "b"; re-adding "b" must fail.
   m.update(b=data_structures.Mapping())
   with self.assertRaises(ValueError):
     m.update({"b": data_structures.Mapping()})
Example #5
0
 def testListWrapperBasic(self):
     """_ListWrapper behaves like the built-in list under comparison and `+`.

     _ListWrapper, unlike List, compares like the built-in list type, because
     it is used to automatically replace lists. Plain-list assertions sit next
     to the wrapped ones so the expected built-in behavior is visible.
     """
     wrapper = data_structures._ListWrapper
     a = tracking.AutoTrackable()
     b = tracking.AutoTrackable()

     # Equality, including mixed list/_ListWrapper operands on either side.
     self.assertEqual([a, a], [a, a])
     self.assertEqual(wrapper([a, a]), wrapper([a, a]))
     self.assertEqual([a, a], wrapper([a, a]))
     self.assertEqual(wrapper([a, a]), [a, a])

     # Inequality.
     self.assertNotEqual([a, a], [b, a])
     self.assertNotEqual(wrapper([a, a]), wrapper([b, a]))
     self.assertNotEqual([a, a], wrapper([b, a]))

     # Ordering: a proper prefix sorts before the longer list.
     self.assertLess([a], [a, b])
     self.assertLess(wrapper([a]), wrapper([a, b]))
     self.assertLessEqual([a], [a, b])
     self.assertLessEqual(wrapper([a]), wrapper([a, b]))
     self.assertGreater([a, b], [a])
     self.assertGreater(wrapper([a, b]), wrapper([a]))
     self.assertGreaterEqual([a, b], [a])
     self.assertGreaterEqual(wrapper([a, b]), wrapper([a]))

     # Interoperation with plain lists, with List, and isinstance checks.
     self.assertEqual([a], wrapper([a]))
     self.assertEqual([a], list(data_structures.List([a])))
     self.assertEqual([a, a], wrapper([a]) + [a])
     self.assertEqual([a, a], [a] + wrapper([a]))
     self.assertIsInstance(wrapper([a]), list)
Example #6
0
 def __init__(self):
     """Populate a List of Dense layers through each supported mutation path.

     `layer_list` is grown by construction, `append`, `extend` (with a plain
     list and with a List), and `+=` (with a plain list and with the sum of two
     Lists), ending with Dense(3) through Dense(12) in order. Some layers carry
     `tf.reduce_sum` kernel/bias regularizers. `layers_with_updates` is a List
     built from a one-element tuple holding a BatchNormalization layer.
     """
     super(HasList, self).__init__()
     self.layer_list = data_structures.List([core.Dense(3)])
     self.layer_list.append(core.Dense(4))
     self.layer_list.extend([
         core.Dense(5),
         core.Dense(6, kernel_regularizer=tf.reduce_sum),
     ])
     self.layer_list += [
         core.Dense(7, bias_regularizer=tf.reduce_sum),
         core.Dense(8),
     ]
     # List + List concatenation, then in-place add onto the tracked List.
     list_of_nine = data_structures.List([core.Dense(9)])
     list_of_ten = data_structures.List([core.Dense(10)])
     self.layer_list += list_of_nine + list_of_ten
     self.layer_list.extend(
         data_structures.List([core.Dense(11), core.Dense(12)]))
     self.layers_with_updates = data_structures.List(
         (normalization.BatchNormalization(),))
Example #7
0
 def testDictWrapperBadKeys(self):
     """Saving weights fails clearly when a tracked dict has a non-string key.

     A dict attached to an AutoTrackable gets the integer key 1, mapped to a
     List. `model.save_weights` is expected to raise ValueError with a message
     matching "non-string key". (Presumably the dict is wrapped on attribute
     assignment — the test name refers to a dict wrapper; not visible here.)
     """
     a = tracking.AutoTrackable()
     a.d = {}
     a.d[1] = data_structures.List()
     model = training.Model()
     model.sub = a
     save_path = os.path.join(self.get_temp_dir(), "ckpt")
     # assertRaisesRegexp is a deprecated unittest alias that was removed in
     # Python 3.12; assertRaisesRegex is the supported name.
     with self.assertRaisesRegex(ValueError, "non-string key"):
         model.save_weights(save_path)
Example #8
0
    def testCopy(self):
        """`List.copy()` returns an independent List.

        Appending to the copy must leave the original's contents unchanged.
        """
        v1, v2, v3 = (resource_variable_ops.ResourceVariable(1.)
                      for _ in range(3))

        original = data_structures.List([v1, v2])
        duplicate = original.copy()
        duplicate.append(v3)

        self.assertEqual(list(original), [v1, v2])
        self.assertEqual(list(duplicate), [v1, v2, v3])
Example #9
0
    def testSlicing(self):
        """Slicing a List yields the same elements as slicing a plain list."""
        v1, v2, v3, v4 = [resource_variable_ops.ResourceVariable(1.)
                          for _ in range(4)]

        seq = data_structures.List([v1, v2, v3, v4])
        # (slice, expected elements) pairs, equivalent to seq[1:], seq[1:-1]
        # and seq[:-1].
        cases = (
            (slice(1, None), [v2, v3, v4]),
            (slice(1, -1), [v2, v3]),
            (slice(None, -1), [v1, v2, v3]),
        )
        for window, expected in cases:
            self.assertEqual(seq[window], expected)
Example #10
0
 def testNonLayerVariables(self):
     """Variables stored directly in a List are reported by its weight properties.

     Covers the default (trainable) state, toggling `trainable` off and back on,
     and appending a variable created with `trainable=False`.
     """
     trainable_var = resource_variable_ops.ResourceVariable([1.])
     container = data_structures.List([trainable_var])

     # Default state: the List is trainable and holds no layers.
     self.assertTrue(container.trainable)
     self.assertEqual([], container.layers)
     self.assertEqual([trainable_var], container.variables)
     self.assertEqual([trainable_var], container.trainable_weights)
     self.assertEqual([], container.non_trainable_variables)

     # Freezing the List reclassifies its variable as non-trainable.
     container.trainable = False
     self.assertEqual([trainable_var], container.variables)
     self.assertEqual([], container.trainable_variables)
     self.assertEqual([trainable_var], container.non_trainable_variables)

     # With the List trainable again, a variable created non-trainable is
     # still reported as non-trainable.
     container.trainable = True
     frozen_var = resource_variable_ops.ResourceVariable(1., trainable=False)
     container.append(frozen_var)
     self.assertEqual([trainable_var, frozen_var], container.weights)
     self.assertEqual([trainable_var], container.trainable_weights)
     self.assertEqual([frozen_var], container.non_trainable_weights)
Example #11
0
 def testNonStringKeys(self):
     """Assigning under a non-string (int) key raises TypeError."""
     store = data_structures.Mapping()
     int_key = 1
     with self.assertRaises(TypeError):
         store[int_key] = data_structures.List()
Example #12
0
    def testNotTrackable(self):
        """Constructing a List from a non-trackable object raises ValueError."""

        class NotTrackable(object):
            """A plain object with no tracking support."""

        plain_object = NotTrackable()
        with self.assertRaises(ValueError):
            data_structures.List([plain_object])
Example #13
0
 def testCallNotImplemented(self):
     """Calling a List raises TypeError with a message matching "not callable"."""
     # assertRaisesRegexp is a deprecated unittest alias that was removed in
     # Python 3.12; assertRaisesRegex is the supported name.
     with self.assertRaisesRegex(TypeError, "not callable"):
         data_structures.List()(1.)
Example #14
0
 def testRMul(self):
     """`int * List` repeats the contents, matching built-in list semantics."""
     var = resource_variable_ops.ResourceVariable(1.)
     triple = data_structures.List([var] * 3)
     repeated = list(2 * triple)
     self.assertEqual(repeated, [var] * 6)
Example #15
0
 def testIMul(self):
     """In-place `*= 2` doubles a List's contents."""
     var = resource_variable_ops.ResourceVariable(1.)
     doubled = data_structures.List([var])
     doubled *= 2
     self.assertEqual(list(doubled), [var, var])
Example #16
0
 def testIMul_zero(self):
     """In-place `*= 0` raises ValueError saying "List only supports append"."""
     l = data_structures.List([])
     # assertRaisesRegexp is a deprecated unittest alias that was removed in
     # Python 3.12; assertRaisesRegex is the supported name.
     with self.assertRaisesRegex(ValueError, "List only supports append"):
         l *= 0
Example #17
0
 def testHash(self):
     """Distinct empty Lists are distinct set members.

     This is consistent with identity-based hashing/equality, which the
     assertions below check indirectly.
     """
     first, second = data_structures.List(), data_structures.List()
     members = {first, second}
     self.assertEqual(2, len(members))
     self.assertNotIn(data_structures.List(), members)
Example #18
0
 def testNoPop(self):
     """Calling `pop` on a List raises AttributeError."""
     empty = data_structures.List()
     with self.assertRaises(AttributeError):
         empty.pop()