示例#1
0
 def testNesting(self):
     """Checks that an outer List reports variables held via a nested List."""
     with context.graph_mode():
         nested = data_structures.List()
         enclosing = data_structures.List([nested])
         # The layer is added only after `nested` is already inside
         # `enclosing`; the assertions below read it through `enclosing`.
         nested.append(non_keras_core.Dense(1))
         layer = nested[0]
         layer(array_ops.ones([2, 3]))
         self.assertEqual(2, len(enclosing.variables))
         leading_variable = enclosing.variables[0]
         self.assertIsInstance(leading_variable,
                               resource_variable_ops.ResourceVariable)
示例#2
0
 def testNoOverwrite(self):
     """Checks that existing Mapping keys can be neither replaced nor deleted."""
     tracked_map = data_structures.Mapping()
     first_value = data_structures.List()
     tracked_map["a"] = first_value

     # Assigning to an existing key is expected to fail and to leave the
     # originally stored object in place.
     with self.assertRaises(ValueError):
         tracked_map["a"] = data_structures.List()
     self.assertIs(first_value, tracked_map["a"])

     # Item deletion is expected to fail as well.
     with self.assertRaises(AttributeError):
         del tracked_map["a"]  # pylint: disable=unsupported-delete-operation

     # update() may add a new key, but not overwrite one it added before.
     tracked_map.update(b=data_structures.Mapping())
     with self.assertRaises(ValueError):
         tracked_map.update({"b": data_structures.Mapping()})
示例#3
0
 def testListWrapperBasic(self):
     """Checks ListWrapper comparison, concatenation and list-ness."""
     # ListWrapper, unlike List, compares like the built-in list type (since it
     # is used to automatically replace lists).
     wrap = data_structures.ListWrapper
     first = tracking.AutoTrackable()
     second = tracking.AutoTrackable()

     # Equality, with plain lists and wrappers on either side.
     self.assertEqual([first, first], [first, first])
     self.assertEqual(wrap([first, first]), wrap([first, first]))
     self.assertEqual([first, first], wrap([first, first]))
     self.assertEqual(wrap([first, first]), [first, first])

     # Inequality when the leading element differs.
     self.assertNotEqual([first, first], [second, first])
     self.assertNotEqual(wrap([first, first]), wrap([second, first]))
     self.assertNotEqual([first, first], wrap([second, first]))

     # Ordering of a one-element sequence against a two-element sequence
     # that starts with the same element.
     self.assertLess([first], [first, second])
     self.assertLess(wrap([first]), wrap([first, second]))
     self.assertLessEqual([first], [first, second])
     self.assertLessEqual(wrap([first]), wrap([first, second]))
     self.assertGreater([first, second], [first])
     self.assertGreater(wrap([first, second]), wrap([first]))
     self.assertGreaterEqual([first, second], [first])
     self.assertGreaterEqual(wrap([first, second]), wrap([first]))

     # Conversion, concatenation in both operand orders, and isinstance.
     self.assertEqual([first], wrap([first]))
     self.assertEqual([first], list(data_structures.List([first])))
     self.assertEqual([first, first], wrap([first]) + [first])
     self.assertEqual([first, first], [first] + wrap([first]))
     self.assertIsInstance(wrap([first]), list)

     # Adding a TensorShape to a wrapper must give something with as_list().
     expected_dims = tensor_shape.TensorShape([None, 2]).as_list()
     combined = wrap([None]) + tensor_shape.TensorShape([2])
     self.assertEqual(expected_dims, combined.as_list())
示例#4
0
    def testCopy(self):
        """Appending to a List copy must not alter the source List."""
        var_a, var_b, var_c = [
            resource_variable_ops.ResourceVariable(1.) for _ in range(3)
        ]

        source = data_structures.List([var_a, var_b])
        duplicate = source.copy()
        duplicate.append(var_c)

        # Only the copy sees the appended variable.
        self.assertEqual(list(source), [var_a, var_b])
        self.assertEqual(list(duplicate), [var_a, var_b, var_c])
示例#5
0
    def testSlicing(self):
        """Checks that List slices equal the matching plain-list slices."""
        first, second, third, fourth = [
            resource_variable_ops.ResourceVariable(1.) for _ in range(4)
        ]

        tracked = data_structures.List([first, second, third, fourth])

        # Open-ended start, bounded on both sides, and open-ended stop.
        self.assertEqual(tracked[1:], [second, third, fourth])
        self.assertEqual(tracked[1:-1], [second, third])
        self.assertEqual(tracked[:-1], [first, second, third])
示例#6
0
 def testNonLayerVariables(self):
     """Checks variable bookkeeping of a List that holds bare variables."""
     trainable_var = resource_variable_ops.ResourceVariable([1.])
     tracked = data_structures.List([trainable_var])

     # Initial state: trainable, no layers, one trainable variable.
     self.assertTrue(tracked.trainable)
     self.assertEqual([], tracked.layers)
     self.assertEqual([trainable_var], tracked.variables)
     self.assertEqual([trainable_var], tracked.trainable_weights)
     self.assertEqual([], tracked.non_trainable_variables)

     # With `trainable` switched off, the variable is reported as
     # non-trainable.
     tracked.trainable = False
     self.assertEqual([trainable_var], tracked.variables)
     self.assertEqual([], tracked.trainable_variables)
     self.assertEqual([trainable_var], tracked.non_trainable_variables)

     # Switched back on, a variable created with trainable=False is still
     # reported as non-trainable.
     tracked.trainable = True
     frozen_var = resource_variable_ops.ResourceVariable(1., trainable=False)
     tracked.append(frozen_var)
     self.assertEqual([trainable_var, frozen_var], tracked.weights)
     self.assertEqual([trainable_var], tracked.trainable_weights)
     self.assertEqual([frozen_var], tracked.non_trainable_weights)
示例#7
0
 def testNoPop(self):
     """Checks that calling pop() on a List raises AttributeError."""
     with self.assertRaises(AttributeError):
         empty = data_structures.List()
         empty.pop()
示例#8
0
 def testCallNotImplemented(self):
     """Checks that calling a List raises a "not callable" TypeError."""
     with self.assertRaisesRegex(TypeError, "not callable"):
         container = data_structures.List()
         container(1.)  # pylint: disable=not-callable
示例#9
0
    def testNotTrackable(self):
        """Checks that a List built from a plain object raises ValueError."""
        class NotTrackable(object):
            pass

        with self.assertRaises(ValueError):
            plain_instance = NotTrackable()
            data_structures.List([plain_instance])
示例#10
0
 def testNonStringKeys(self):
     """Checks that assigning under an integer key raises TypeError."""
     tracked_map = data_structures.Mapping()
     with self.assertRaises(TypeError):
         tracked_map[1] = data_structures.List()
示例#11
0
 def testRMul(self):
     """Checks that `int * List` repeats elements like a built-in list."""
     variable = resource_variable_ops.ResourceVariable(1.)
     tracked = data_structures.List([variable, variable, variable])
     # The integer is the left operand, so the reflected multiply is used.
     repeated = list(2 * tracked)
     expected = [variable, variable, variable] * 2
     self.assertEqual(repeated, expected)
示例#12
0
 def testIMul(self):
     """Checks that `List *= 2` repeats the elements in place."""
     variable = resource_variable_ops.ResourceVariable(1.)
     tracked = data_structures.List([variable])
     tracked *= 2
     self.assertEqual(list(tracked), [variable] * 2)
示例#13
0
 def testIMul_zero(self):
     """Checks that `List *= 0` raises the append-only ValueError."""
     empty = data_structures.List([])
     with self.assertRaisesRegex(ValueError, "List only supports append"):
         empty *= 0
示例#14
0
 def testHash(self):
     """Checks that separate empty Lists are distinct members of a set."""
     one = data_structures.List()
     another = data_structures.List()
     has_sequences = {one, another}
     # Both instances must survive as separate set members.
     self.assertEqual(2, len(has_sequences))
     # A third, freshly built List must not be found in the set.
     probe = data_structures.List()
     self.assertNotIn(probe, has_sequences)