def testListAddOrder(self):
   """Asserts `+` yields [1., 2.] whichever operand is a `_ListWrapper`."""
   wrap = data_structures._ListWrapper
   # Wrapper on both sides of the operator.
   both_wrapped = wrap([1.]) + wrap([2.])
   self.assertEqual([1., 2.], both_wrapped)
   # Wrapper on the left, plain list on the right.
   left_wrapped = wrap([1.]) + [2.]
   self.assertEqual([1., 2.], left_wrapped)
   # Plain list on the left, wrapper on the right.
   right_wrapped = [1.] + wrap([2.])
   self.assertEqual([1., 2.], right_wrapped)
 def testSetSlice_cannotSaveIfTrackableModified(self):
   """Replaces a variable via full-slice assignment, then runs the save check.

   The wrapper starts with two variables; the slice assignment keeps only the
   second one. The wrapper is then passed to `assertUnableToSave` with a
   pattern naming `__setslice__`.
   """
   dropped_var = resource_variable_ops.ResourceVariable(1.)
   kept_var = resource_variable_ops.ResourceVariable(1.)
   wrapper = data_structures._ListWrapper([1, 2, dropped_var, kept_var])
   replacement = (2, 8, 9, kept_var)
   wrapper[:] = replacement
   self.assertEqual(wrapper, [2, 8, 9, kept_var])
   self.assertUnableToSave(wrapper, "Unable to save .*__setslice__")
  def testFunctionCaching(self):
    """Asserts a plain list and a `_ListWrapper` argument give the same trace."""
    @def_function.function
    def f(list_input):
      return list_input[0] + constant_op.constant(1.)

    # Trace once with a built-in list argument...
    plain_arg = [constant_op.constant(2.)]
    trace_from_list = f.get_concrete_function(plain_arg)
    # ...and once with the same structure held in a wrapper.
    wrapped_arg = data_structures._ListWrapper([constant_op.constant(3.)])
    trace_from_wrapper = f.get_concrete_function(wrapped_arg)
    self.assertIs(trace_from_list, trace_from_wrapper)
 def testSort(self):
   """Sorts an already-ordered wrapper, then checks contents and the save check."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   wrapper.sort()
   self.assertEqual(wrapper, [1, 2, 3, 4])
   # The sort above did not reorder anything, yet saving is still expected to
   # be refused. That is deliberate: otherwise users would face a hard to debug
   # situation in which sort on a ListWrapper is sometimes trackable and
   # sometimes not.
   self.assertUnableToSave(wrapper, "Unable to save .*sort")
 def testIMulPositive(self):
   """Checks dependencies and checkpoint restore around an in-place `*= 2`."""
   var = variables.Variable(1.)
   wrapper = data_structures._ListWrapper([1, 2, 3, 4, var])
   self.assertEqual([("4", var)], wrapper._checkpoint_dependencies)
   checkpoint = util.Checkpoint(l=wrapper)
   save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
   # Change the variable after saving so the restore below is observable.
   var.assign(5.)
   wrapper *= 2
   self.assertEqual(wrapper, [1, 2, 3, 4, var, 1, 2, 3, 4, var])
   # The variable now sits at indices 4 and 9 of the doubled list.
   self.assertEqual([("4", var), ("9", var)], wrapper._checkpoint_dependencies)
   checkpoint.restore(save_path)
   self.assertAllClose(1., var.numpy())
示例#6
0
 def testDelItem(self):
     """Deletes index 0, then checks the remaining items and the save check."""
     wrapper = data_structures._ListWrapper([1, 2, 3, 4])
     del wrapper[0]
     remaining = [2, 3, 4]
     self.assertEqual(wrapper, remaining)
     self.assertUnableToSave(wrapper, "Unable to save .*__delitem__")
 def testLayerCollectionWithExternalMutation(self):
   """Asserts `layers` reflects a layer appended to the underlying list."""
   backing = []
   wrapper = data_structures._ListWrapper(backing)
   dense = core.Dense(1)
   # Mutate the plain list directly rather than going through the wrapper.
   backing.append(dense)
   self.assertEqual([dense], wrapper.layers)
示例#8
0
 def testSameStructure(self):
     """Runs `nest.assert_same_structure` on a list and a wrapped copy of it."""
     plain = [1]
     wrapped_copy = data_structures._ListWrapper(copy.copy(plain))
     nest.assert_same_structure(plain, wrapped_copy)
示例#9
0
 def testListWrapperBasic(self):
     """Asserts `_ListWrapper` compares, concatenates and type-checks like `list`."""
     # Unlike List, _ListWrapper is expected to compare the way the built-in
     # list type does, because it is used to automatically replace lists.
     wrap = data_structures._ListWrapper
     first = tracking.AutoTrackable()
     second = tracking.AutoTrackable()

     # Equality, with plain lists as the baseline.
     self.assertEqual([first, first], [first, first])
     self.assertEqual(wrap([first, first]), wrap([first, first]))
     self.assertEqual([first, first], wrap([first, first]))
     self.assertEqual(wrap([first, first]), [first, first])

     # Inequality.
     self.assertNotEqual([first, first], [second, first])
     self.assertNotEqual(wrap([first, first]), wrap([second, first]))
     self.assertNotEqual([first, first], wrap([second, first]))

     # Ordering: each plain-list assertion is mirrored with wrappers.
     self.assertLess([first], [first, second])
     self.assertLess(wrap([first]), wrap([first, second]))
     self.assertLessEqual([first], [first, second])
     self.assertLessEqual(wrap([first]), wrap([first, second]))
     self.assertGreater([first, second], [first])
     self.assertGreater(wrap([first, second]), wrap([first]))
     self.assertGreaterEqual([first, second], [first])
     self.assertGreaterEqual(wrap([first, second]), wrap([first]))

     # Conversion, concatenation and type check.
     self.assertEqual([first], wrap([first]))
     self.assertEqual([first], list(data_structures.List([first])))
     self.assertEqual([first, first], wrap([first]) + [first])
     self.assertEqual([first, first], [first] + wrap([first]))
     self.assertIsInstance(wrap([first]), list)
示例#10
0
 def testWrapperChangesList(self):
     """Asserts an append through the wrapper shows up in the wrapped list."""
     backing = []
     wrapper = data_structures._ListWrapper(backing)
     wrapper.append(1)
     self.assertEqual([1], backing)
示例#11
0
 def testSetSlice_truncate(self):
     """Asserts assigning an empty list to the full slice empties the wrapper."""
     wrapper = data_structures._ListWrapper([1, 2, 3, 4])
     wrapper[:] = []
     self.assertEqual(wrapper, [])
示例#12
0
 def testIMulNegative(self):
     """Applies `*= -1`, compares with the built-in result, runs the save check."""
     wrapper = data_structures._ListWrapper([1, 2, 3, 4])
     wrapper *= -1
     # Reference value: the same multiplication applied to a plain list.
     builtin_result = [1, 2, 3, 4] * -1
     self.assertEqual(wrapper, builtin_result)
     self.assertUnableToSave(wrapper, "Unable to save")
 def testSetSlice_truncate(self):
   """Asserts assigning an empty list to the full slice empties the wrapper."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   wrapper[:] = []
   self.assertEqual(wrapper, [])
 def testSetSlice_extend(self):
   """Asserts a tail-slice assignment of four items grows the wrapper to six."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   # Replace the last two items with four new ones.
   wrapper[2:] = (1, 2, 3, 4)
   self.assertEqual(wrapper, [1, 2, 1, 2, 3, 4])
 def testSetSlice_canSaveForNonTrackableItems(self):
   """Slice-assigns plain ints and asserts there are no checkpoint dependencies."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   wrapper[:] = (2, 8, 9, 0)
   self.assertEqual(wrapper, [2, 8, 9, 0])
   wrapper._maybe_initialize_trackable()  # pylint: disable=protected-access
   dependencies = wrapper._checkpoint_dependencies  # pylint: disable=protected-access
   self.assertEqual(len(dependencies), 0)
 def testWrapperChangesList(self):
   """Asserts an append through the wrapper shows up in the wrapped list."""
   backing = []
   wrapper = data_structures._ListWrapper(backing)
   wrapper.append(1)
   self.assertEqual([1], backing)
 def testDelSlice(self):
   """Deletes the slice 2:3, then checks the remaining items and the save check."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   del wrapper[2:3]
   remaining = [1, 2, 4]
   self.assertEqual(wrapper, remaining)
   self.assertUnableToSave(wrapper, "Unable to save .*__delslice__")
 def testDelItem(self):
   """Deletes index 0, then checks the remaining items and the save check."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   del wrapper[0]
   remaining = [2, 3, 4]
   self.assertEqual(wrapper, remaining)
   self.assertUnableToSave(wrapper, "Unable to save .*__delitem__")
 def testNotHashable(self):
   """Asserts building and hashing an empty `_ListWrapper` raises TypeError."""
   with self.assertRaises(TypeError):
     # Construction stays inside the context, as in the one-line form.
     empty_wrapper = data_structures._ListWrapper()
     hash(empty_wrapper)
示例#20
0
 def testDelSlice(self):
     """Deletes the slice 2:3, then checks the remaining items and the save check."""
     wrapper = data_structures._ListWrapper([1, 2, 3, 4])
     del wrapper[2:3]
     remaining = [1, 2, 4]
     self.assertEqual(wrapper, remaining)
     self.assertUnableToSave(wrapper, "Unable to save .*__delslice__")
 def testIMulNegative(self):
   """Applies `*= -1`, compares with the built-in result, runs the save check."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   wrapper *= -1
   # Reference value: the same multiplication applied to a plain list.
   builtin_result = [1, 2, 3, 4] * -1
   self.assertEqual(wrapper, builtin_result)
   self.assertUnableToSave(wrapper, "Unable to save")
示例#22
0
 def testSetSlice_canSaveForNonTrackableItems(self):
     """Slice-assigns plain ints and asserts there are no checkpoint dependencies."""
     wrapper = data_structures._ListWrapper([1, 2, 3, 4])
     wrapper[:] = (2, 8, 9, 0)
     self.assertEqual(wrapper, [2, 8, 9, 0])
     wrapper._maybe_initialize_trackable()  # pylint: disable=protected-access
     dependencies = wrapper._checkpoint_dependencies  # pylint: disable=protected-access
     self.assertEqual(len(dependencies), 0)
示例#23
0
 def testLayerCollectionWithExternalMutation(self):
     """Asserts `layers` reflects a layer appended to the underlying list."""
     backing = []
     wrapper = data_structures._ListWrapper(backing)
     dense = core.Dense(1)
     # Mutate the plain list directly rather than going through the wrapper.
     backing.append(dense)
     self.assertEqual([dense], wrapper.layers)
示例#24
0
 def testSetSlice_extend(self):
     """Asserts a tail-slice assignment of four items grows the wrapper to six."""
     wrapper = data_structures._ListWrapper([1, 2, 3, 4])
     # Replace the last two items with four new ones.
     wrapper[2:] = (1, 2, 3, 4)
     self.assertEqual(wrapper, [1, 2, 1, 2, 3, 4])
 def testAcceptsNonTrackableContent(self):
   """Asserts a wrapper built from plain ints equals the same plain list."""
   plain_ints = [1, 2, 3]
   wrapper = data_structures._ListWrapper(plain_ints)
   self.assertEqual(wrapper, [1, 2, 3])
示例#26
0
 def testListAddOrder(self):
     """Asserts `+` yields [1., 2.] whichever operand is a `_ListWrapper`."""
     wrap = data_structures._ListWrapper
     # Wrapper on both sides of the operator.
     both_wrapped = wrap([1.]) + wrap([2.])
     self.assertEqual([1., 2.], both_wrapped)
     # Wrapper on the left, plain list on the right.
     left_wrapped = wrap([1.]) + [2.]
     self.assertEqual([1., 2.], left_wrapped)
     # Plain list on the left, wrapper on the right.
     right_wrapped = [1.] + wrap([2.])
     self.assertEqual([1., 2.], right_wrapped)
 def testListWrapperBasic(self):
   """Asserts `_ListWrapper` compares, concatenates and type-checks like `list`."""
   # Unlike List, _ListWrapper is expected to compare the way the built-in
   # list type does, because it is used to automatically replace lists.
   wrap = data_structures._ListWrapper
   first = tracking.AutoTrackable()
   second = tracking.AutoTrackable()

   # Equality, with plain lists as the baseline.
   self.assertEqual([first, first], [first, first])
   self.assertEqual(wrap([first, first]), wrap([first, first]))
   self.assertEqual([first, first], wrap([first, first]))
   self.assertEqual(wrap([first, first]), [first, first])

   # Inequality.
   self.assertNotEqual([first, first], [second, first])
   self.assertNotEqual(wrap([first, first]), wrap([second, first]))
   self.assertNotEqual([first, first], wrap([second, first]))

   # Ordering: each plain-list assertion is mirrored with wrappers.
   self.assertLess([first], [first, second])
   self.assertLess(wrap([first]), wrap([first, second]))
   self.assertLessEqual([first], [first, second])
   self.assertLessEqual(wrap([first]), wrap([first, second]))
   self.assertGreater([first, second], [first])
   self.assertGreater(wrap([first, second]), wrap([first]))
   self.assertGreaterEqual([first, second], [first])
   self.assertGreaterEqual(wrap([first, second]), wrap([first]))

   # Conversion, concatenation and type check.
   self.assertEqual([first], wrap([first]))
   self.assertEqual([first], list(data_structures.List([first])))
   self.assertEqual([first, first], wrap([first]) + [first])
   self.assertEqual([first, first], [first] + wrap([first]))
   self.assertIsInstance(wrap([first]), list)
示例#28
0
 def testAcceptsNonTrackableContent(self):
     """Asserts a wrapper built from plain ints equals the same plain list."""
     plain_ints = [1, 2, 3]
     wrapper = data_structures._ListWrapper(plain_ints)
     self.assertEqual(wrapper, [1, 2, 3])
示例#29
0
 def testPickle(self):
     """Round-trips a wrapper through pickle and asserts the contents survive."""
     wrapper = data_structures._ListWrapper([1, 2])
     payload = pickle.dumps(wrapper)
     # Drop the source object before loading, so only the bytes remain.
     del wrapper
     restored = pickle.loads(payload)
     self.assertEqual([1, 2], restored)
示例#30
0
 def testListChangesWrapper(self):
     """Asserts an append to the wrapped list shows up through the wrapper."""
     backing = []
     wrapper = data_structures._ListWrapper(backing)
     backing.append(1)
     self.assertEqual([1], wrapper)
 def testListChangesWrapper(self):
   """Asserts an append to the wrapped list shows up through the wrapper."""
   backing = []
   wrapper = data_structures._ListWrapper(backing)
   backing.append(1)
   self.assertEqual([1], wrapper)
示例#32
0
 def testNotHashable(self):
     """Asserts building and hashing an empty `_ListWrapper` raises TypeError."""
     with self.assertRaises(TypeError):
         # Construction stays inside the context, as in the one-line form.
         empty_wrapper = data_structures._ListWrapper()
         hash(empty_wrapper)
 def testSameStructure(self):
   """Runs `nest.assert_same_structure` on a list and a wrapped copy of it."""
   plain = [1]
   wrapped_copy = data_structures._ListWrapper(copy.copy(plain))
   nest.assert_same_structure(plain, wrapped_copy)