示例#1
0
 def testHashing(self):
   """Separate Lists are distinct set members; _ListWrapper is unhashable."""
   has_sequences = {data_structures.List(), data_structures.List()}
   self.assertEqual(2, len(has_sequences))
   fresh = data_structures.List()
   self.assertNotIn(fresh, has_sequences)
   # Construction stays inside the context so any TypeError it raises counts.
   with self.assertRaises(TypeError):
     has_sequences.add(data_structures._ListWrapper([]))
 def testNesting(self):
     """A layer held in an inner List reports its variables via the outer List."""
     with context.graph_mode():
         inner_list = data_structures.List()
         outer_list = data_structures.List([inner_list])
         dense_layer = non_keras_core.Dense(1)
         inner_list.append(dense_layer)
         # Calling the layer builds its variables.
         inner_list[0](array_ops.ones([2, 3]))
         self.assertEqual(2, len(outer_list.variables))
         self.assertIsInstance(
             outer_list.variables[0],
             resource_variable_ops.ResourceVariable)
 def __init__(self):
     """Populate `layer_dict` with a Dense output layer and two Lists of layers."""
     super(HasMapping, self).__init__()
     self.layer_dict = data_structures.Mapping(output=core.Dense(7))
     mapping = self.layer_dict
     mapping["norm"] = data_structures.List()
     mapping["dense"] = data_structures.List()
     dense_layers = [
         core.Dense(5),
         core.Dense(6, kernel_regularizer=math_ops.reduce_sum),
     ]
     mapping["dense"].extend(dense_layers)
     # Two BatchNormalization layers, constructed and appended one at a time.
     for _ in range(2):
         mapping["norm"].append(normalization.BatchNormalization())
示例#4
0
 def testNoOverwrite(self):
   """Existing Mapping entries cannot be reassigned, deleted, or re-updated."""
   make_list = data_structures.List
   make_mapping = data_structures.Mapping
   tracked = make_mapping()
   first_value = make_list()
   tracked["a"] = first_value
   # Reassigning an existing key raises and keeps the original value.
   with self.assertRaises(ValueError):
     tracked["a"] = make_list()
   self.assertIs(first_value, tracked["a"])
   # Deleting an entry raises AttributeError.
   with self.assertRaises(AttributeError):
     del tracked["a"]
   # update() adds a new key, but refuses to overwrite it afterwards.
   tracked.update(b=make_mapping())
   with self.assertRaises(ValueError):
     tracked.update({"b": make_mapping()})
 def testListWrapperBasic(self):
     """_ListWrapper compares and combines exactly like a built-in list.

     Unlike List, _ListWrapper is used to automatically replace plain lists,
     so equality, ordering, concatenation and isinstance checks must all
     agree with list semantics.
     """
     wrap = data_structures._ListWrapper
     first = tracking.AutoCheckpointable()
     second = tracking.AutoCheckpointable()
     # Equality, including list vs. wrapper in both directions.
     self.assertEqual([first, first], [first, first])
     self.assertEqual(wrap([first, first]), wrap([first, first]))
     self.assertEqual([first, first], wrap([first, first]))
     self.assertEqual(wrap([first, first]), [first, first])
     # Inequality.
     self.assertNotEqual([first, first], [second, first])
     self.assertNotEqual(wrap([first, first]), wrap([second, first]))
     self.assertNotEqual([first, first], wrap([second, first]))
     # Ordering comparisons.
     self.assertLess([first], [first, second])
     self.assertLess(wrap([first]), wrap([first, second]))
     self.assertLessEqual([first], [first, second])
     self.assertLessEqual(wrap([first]), wrap([first, second]))
     self.assertGreater([first, second], [first])
     self.assertGreater(wrap([first, second]), wrap([first]))
     self.assertGreaterEqual([first, second], [first])
     self.assertGreaterEqual(wrap([first, second]), wrap([first]))
     # Conversion, concatenation and type identity.
     self.assertEqual([first], wrap([first]))
     self.assertEqual([first], list(data_structures.List([first])))
     self.assertEqual([first, first], wrap([first]) + [first])
     self.assertEqual([first, first], [first] + wrap([first]))
     self.assertIsInstance(wrap([first]), list)
示例#6
0
 def testDictWrapperBadKeys(self):
   """Saving weights fails when a tracked dict holds a non-string key."""
   a = tracking.Checkpointable()
   # A plain dict assigned to a checkpointable attribute; presumably wrapped
   # for tracking (hence the test name) — the int key below is the bad input.
   a.d = {}
   a.d[1] = data_structures.List()
   model = training.Model()
   model.sub = a
   save_path = os.path.join(self.get_temp_dir(), "ckpt")
   # assertRaisesRegexp is a deprecated alias that Python 3.12 removed.
   with self.assertRaisesRegex(ValueError, "non-string key"):
     model.save_weights(save_path)
    def testSlicing(self):
        """Slicing a List returns the matching elements in order."""
        v1, v2, v3, v4 = [
            resource_variable_ops.ResourceVariable(1.) for _ in range(4)]

        tracked = data_structures.List([v1, v2, v3, v4])
        expectations = (
            (slice(1, None), [v2, v3, v4]),
            (slice(1, -1), [v2, v3]),
            (slice(None, -1), [v1, v2, v3]),
        )
        for index, expected in expectations:
            self.assertEqual(tracked[index], expected)
示例#8
0
 def __init__(self):
   """Populate `layer_list` via append, extend, `+=` and List concatenation."""
   super(HasList, self).__init__()
   self.layer_list = data_structures.List([core.Dense(3)])
   self.layer_list.append(core.Dense(4))
   self.layer_list.extend(
       [core.Dense(5),
        core.Dense(6, kernel_regularizer=math_ops.reduce_sum)])
   self.layer_list += [
       core.Dense(7, bias_regularizer=math_ops.reduce_sum),
       core.Dense(8)
   ]
   self.layer_list += (
       data_structures.List([core.Dense(9)]) + data_structures.List(
           [core.Dense(10)]))
   # list() rejects keyword arguments on Python 3.7+ (`list(sequence=...)`
   # raises TypeError), so the iterable is passed positionally.
   self.layer_list.extend(
       data_structures.List(
           list([core.Dense(11)]) + [core.Dense(12)]))
   self.layers_with_updates = data_structures.List(
       sequence=(normalization.BatchNormalization(),))
    def testCopy(self):
        """Appending to a copy of a List leaves the original unchanged."""
        first, second, third = (
            resource_variable_ops.ResourceVariable(1.) for _ in range(3))

        source = data_structures.List([first, second])
        duplicate = source.copy()
        duplicate.append(third)
        self.assertEqual(list(source), [first, second])
        self.assertEqual(list(duplicate), [first, second, third])
示例#10
0
 def testNonLayerVariables(self):
   """Bare variables in a List show up in its weight/variable properties.

   The List's `trainable` flag and each variable's own `trainable` setting
   decide whether it is reported as trainable or non-trainable.
   """
   variable = resource_variable_ops.ResourceVariable([1.])
   container = data_structures.List([variable])
   self.assertTrue(container.trainable)
   self.assertEqual([], container.layers)
   self.assertEqual([variable], container.variables)
   self.assertEqual([variable], container.trainable_weights)
   self.assertEqual([], container.non_trainable_variables)
   # With the List frozen, its variable is reported as non-trainable.
   container.trainable = False
   self.assertEqual([variable], container.variables)
   self.assertEqual([], container.trainable_variables)
   self.assertEqual([variable], container.non_trainable_variables)
   # Unfreeze, then add a variable that is itself non-trainable.
   container.trainable = True
   frozen = resource_variable_ops.ResourceVariable(1., trainable=False)
   container.append(frozen)
   self.assertEqual([variable, frozen], container.weights)
   self.assertEqual([variable], container.trainable_weights)
   self.assertEqual([frozen], container.non_trainable_weights)
示例#11
0
  def testNotCheckpointable(self):
    """A List refuses an element that is not checkpointable."""

    class NotCheckpointable(object):
      """Plain object subclass with no checkpointable base."""

    candidate = NotCheckpointable()
    with self.assertRaises(ValueError):
      data_structures.List([candidate])
 def testHash(self):
     """Two empty Lists are distinct set members; a third is not found."""
     has_sequences = {data_structures.List(), data_structures.List()}
     self.assertEqual(2, len(has_sequences))
     probe = data_structures.List()
     self.assertNotIn(probe, has_sequences)
示例#13
0
 def testNonStringKeys(self):
   """Mapping rejects an integer key with TypeError."""
   mapping = data_structures.Mapping()
   value = data_structures.List()
   with self.assertRaises(TypeError):
     mapping[1] = value
 def testIMul_zero(self):
     """In-place multiplication by zero raises ValueError.

     The expected message says List only supports append-style mutation.
     """
     empty = data_structures.List([])
     # assertRaisesRegexp is a deprecated alias that Python 3.12 removed.
     with self.assertRaisesRegex(ValueError, "List only supports append"):
         empty *= 0
 def testIMul(self):
     """In-place multiplication by 2 repeats the List's contents."""
     variable = resource_variable_ops.ResourceVariable(1.)
     repeated = data_structures.List([variable])
     repeated *= 2
     self.assertEqual(list(repeated), [variable, variable])
示例#16
0
 def testNoPop(self):
   """Looking up pop() on a List raises AttributeError."""
   empty = data_structures.List()
   with self.assertRaises(AttributeError):
     empty.pop()
示例#17
0
 def testCallNotImplemented(self):
   """Calling a List instance raises a TypeError mentioning "not callable"."""
   # assertRaisesRegexp is a deprecated alias that Python 3.12 removed.
   with self.assertRaisesRegex(TypeError, "not callable"):
     data_structures.List()(1.)
 def testRMul(self):
     """`int * List` repeats the List's contents."""
     variable = resource_variable_ops.ResourceVariable(1.)
     triple = data_structures.List([variable] * 3)
     doubled = 2 * triple
     self.assertEqual(list(doubled), [variable] * 6)