Пример #1
0
 def testListAddOrder(self):
   """Concatenation yields [1., 2.] whichever operand is the wrapper."""
   wrap = data_structures._ListWrapper
   expected = [1., 2.]
   self.assertEqual(expected, wrap([1.]) + wrap([2.]))
   self.assertEqual(expected, wrap([1.]) + [2.])
   self.assertEqual(expected, [1.] + wrap([2.]))
Пример #2
0
 def testSetSlice_cannotSaveIfTrackableModified(self):
   """Replacing a trackable element through a slice blocks saving."""
   first_var = resource_variable_ops.ResourceVariable(1.)
   second_var = resource_variable_ops.ResourceVariable(1.)
   wrapper = data_structures._ListWrapper([1, 2, first_var, second_var])
   # first_var (index 2) is overwritten by 9; second_var stays at index 3.
   wrapper[:] = (2, 8, 9, second_var)
   self.assertEqual(wrapper, [2, 8, 9, second_var])
   self.assertUnableToSave(wrapper, "Unable to save .*__setslice__")
Пример #3
0
  def testFunctionCaching(self):
    """A plain list and a _ListWrapper argument share one concrete trace."""

    @def_function.function
    def add_one_to_head(list_input):
      return list_input[0] + constant_op.constant(1.)

    plain_list_trace = add_one_to_head.get_concrete_function(
        [constant_op.constant(2.)])
    wrapped_list_trace = add_one_to_head.get_concrete_function(
        data_structures._ListWrapper([constant_op.constant(3.)]))
    self.assertIs(plain_list_trace, wrapped_list_trace)
Пример #4
0
 def testSort(self):
   """Sorting always blocks saving, even when the order does not change."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   wrapper.sort()
   self.assertEqual(wrapper, [1, 2, 3, 4])
   # The input was already sorted, yet saving is still refused on purpose:
   # if sort() were trackable only some of the time, users would face a
   # confusing, hard-to-debug inconsistency.
   self.assertUnableToSave(wrapper, "Unable to save .*sort")
Пример #5
0
 def testIMulPositive(self):
   """In-place repetition tracks every copy of a variable and still restores."""
   var = variables.Variable(1.)
   wrapper = data_structures._ListWrapper([1, 2, 3, 4, var])
   self.assertEqual([("4", var)], wrapper._checkpoint_dependencies)
   checkpoint = util.Checkpoint(l=wrapper)
   save_path = checkpoint.save(os.path.join(self.get_temp_dir(), "ckpt"))
   var.assign(5.)
   wrapper *= 2
   self.assertEqual(wrapper, [1, 2, 3, 4, var, 1, 2, 3, 4, var])
   # After doubling, the same variable appears at indices 4 and 9.
   self.assertEqual([("4", var), ("9", var)],
                    wrapper._checkpoint_dependencies)
   checkpoint.restore(save_path)
   self.assertAllClose(1., var.numpy())
Пример #6
0
 def testDelItem(self):
   """Deleting an element by index blocks saving."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   del wrapper[0]
   self.assertEqual(wrapper, [2, 3, 4])
   self.assertUnableToSave(wrapper, "Unable to save .*__delitem__")
Пример #7
0
 def testLayerCollectionWithExternalMutation(self):
   """`layers` reflects a layer appended directly to the wrapped list."""
   backing_list = []
   wrapper = data_structures._ListWrapper(backing_list)
   dense = core.Dense(1)
   backing_list.append(dense)
   self.assertEqual([dense], wrapper.layers)
Пример #8
0
 def testSameStructure(self):
   """nest treats a _ListWrapper like the plain list it was copied from."""
   plain = [1]
   wrapped_copy = data_structures._ListWrapper(copy.copy(plain))
   nest.assert_same_structure(plain, wrapped_copy)
Пример #9
0
 def testListWrapperBasic(self):
   """_ListWrapper compares, orders and concatenates like a built-in list.

   Unlike List, _ListWrapper is used to automatically replace lists, so the
   matching built-in list comparisons are asserted alongside for reference.
   """
   wrap = data_structures._ListWrapper
   a = tracking.AutoTrackable()
   b = tracking.AutoTrackable()
   # Equality.
   self.assertEqual([a, a], [a, a])
   self.assertEqual(wrap([a, a]), wrap([a, a]))
   self.assertEqual([a, a], wrap([a, a]))
   self.assertEqual(wrap([a, a]), [a, a])
   # Inequality.
   self.assertNotEqual([a, a], [b, a])
   self.assertNotEqual(wrap([a, a]), wrap([b, a]))
   self.assertNotEqual([a, a], wrap([b, a]))
   # Ordering.
   self.assertLess([a], [a, b])
   self.assertLess(wrap([a]), wrap([a, b]))
   self.assertLessEqual([a], [a, b])
   self.assertLessEqual(wrap([a]), wrap([a, b]))
   self.assertGreater([a, b], [a])
   self.assertGreater(wrap([a, b]), wrap([a]))
   self.assertGreaterEqual([a, b], [a])
   self.assertGreaterEqual(wrap([a, b]), wrap([a]))
   # Conversion, concatenation and type identity.
   self.assertEqual([a], wrap([a]))
   self.assertEqual([a], list(data_structures.List([a])))
   self.assertEqual([a, a], wrap([a]) + [a])
   self.assertEqual([a, a], [a] + wrap([a]))
   self.assertIsInstance(wrap([a]), list)
Пример #10
0
 def testWrapperChangesList(self):
   """Appending through the wrapper mutates the wrapped list in place."""
   backing_list = []
   wrapper = data_structures._ListWrapper(backing_list)
   wrapper.append(1)
   self.assertEqual([1], backing_list)
Пример #11
0
 def testSetSlice_truncate(self):
   """Assigning an empty sequence to the full slice clears the wrapper."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   wrapper[:] = []
   self.assertEqual(wrapper, [])
Пример #12
0
 def testIMulNegative(self):
   """Multiplying by a negative count empties the list and blocks saving."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   wrapper *= -1
   # Mirror built-in list semantics: a negative repeat count yields [].
   expected = [1, 2, 3, 4] * -1
   self.assertEqual(wrapper, expected)
   self.assertUnableToSave(wrapper, "Unable to save")
Пример #13
0
 def testSetSlice_truncate(self):
   """Replacing the whole slice with nothing leaves an empty wrapper."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   wrapper[:] = []
   self.assertEqual(wrapper, [])
Пример #14
0
 def testSetSlice_extend(self):
   """A slice assignment longer than the slice grows the list."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   wrapper[2:] = (1, 2, 3, 4)
   self.assertEqual(wrapper, [1, 2, 1, 2, 3, 4])
Пример #15
0
 def testSetSlice_canSaveForNonTrackableItems(self):
   """Slice-replacing only non-trackable values leaves no dependencies."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   wrapper[:] = (2, 8, 9, 0)
   self.assertEqual(wrapper, [2, 8, 9, 0])
   # pylint: disable=protected-access
   wrapper._maybe_initialize_trackable()
   self.assertEqual(len(wrapper._checkpoint_dependencies), 0)
   # pylint: enable=protected-access
Пример #16
0
 def testWrapperChangesList(self):
   """Mutations made via the wrapper land in the underlying list."""
   backing_list = []
   wrapper = data_structures._ListWrapper(backing_list)
   wrapper.append(1)
   self.assertEqual([1], backing_list)
Пример #17
0
 def testDelSlice(self):
   """Deleting a slice blocks saving."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   del wrapper[2:3]
   self.assertEqual(wrapper, [1, 2, 4])
   self.assertUnableToSave(wrapper, "Unable to save .*__delslice__")
Пример #18
0
 def testDelItem(self):
   """Removing the first element by index makes the wrapper unsaveable."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   del wrapper[0]
   self.assertEqual(wrapper, [2, 3, 4])
   self.assertUnableToSave(wrapper, "Unable to save .*__delitem__")
Пример #19
0
 def testNotHashable(self):
   """A _ListWrapper, like a built-in list, cannot be hashed."""
   # Build the wrapper outside assertRaises so that the expected TypeError
   # can only come from hash(). Every other construction in this suite passes
   # the wrapped list explicitly; a constructor failure (e.g. a missing
   # argument) must not make this test pass vacuously.
   wrapper = data_structures._ListWrapper([])
   with self.assertRaises(TypeError):
     hash(wrapper)
Пример #20
0
 def testDelSlice(self):
   """Removing a sub-range with `del` makes the wrapper unsaveable."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   del wrapper[2:3]
   self.assertEqual(wrapper, [1, 2, 4])
   self.assertUnableToSave(wrapper, "Unable to save .*__delslice__")
Пример #21
0
 def testIMulNegative(self):
   """`*=` with a negative count matches list semantics and blocks saving."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   wrapper *= -1
   # A negative repeat count on a built-in list produces [].
   expected = [1, 2, 3, 4] * -1
   self.assertEqual(wrapper, expected)
   self.assertUnableToSave(wrapper, "Unable to save")
Пример #22
0
 def testSetSlice_canSaveForNonTrackableItems(self):
   """Overwriting plain values via a slice creates no checkpoint deps."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   wrapper[:] = (2, 8, 9, 0)
   self.assertEqual(wrapper, [2, 8, 9, 0])
   # pylint: disable=protected-access
   wrapper._maybe_initialize_trackable()
   self.assertEqual(len(wrapper._checkpoint_dependencies), 0)
   # pylint: enable=protected-access
Пример #23
0
 def testLayerCollectionWithExternalMutation(self):
   """A layer added to the wrapped list directly shows up in `layers`."""
   backing_list = []
   wrapper = data_structures._ListWrapper(backing_list)
   dense = core.Dense(1)
   backing_list.append(dense)
   self.assertEqual([dense], wrapper.layers)
Пример #24
0
 def testSetSlice_extend(self):
   """Assigning four items to a two-item tail slice lengthens the list."""
   wrapper = data_structures._ListWrapper([1, 2, 3, 4])
   wrapper[2:] = (1, 2, 3, 4)
   self.assertEqual(wrapper, [1, 2, 1, 2, 3, 4])
Пример #25
0
 def testAcceptsNonTrackableContent(self):
   """Plain Python values can be wrapped and still compare as a list."""
   self.assertEqual(data_structures._ListWrapper([1, 2, 3]), [1, 2, 3])
Пример #26
0
 def testListAddOrder(self):
   """`+` keeps left-to-right order for every wrapper/list combination."""
   wrap = data_structures._ListWrapper
   expected = [1., 2.]
   self.assertEqual(expected, wrap([1.]) + wrap([2.]))
   self.assertEqual(expected, wrap([1.]) + [2.])
   self.assertEqual(expected, [1.] + wrap([2.]))
Пример #27
0
 def testListWrapperBasic(self):
   """_ListWrapper behaves like a built-in list under comparison and `+`.

   Unlike List, _ListWrapper is used to automatically replace lists, so the
   matching built-in list comparisons are asserted alongside for reference.
   """
   wrap = data_structures._ListWrapper
   a = tracking.AutoTrackable()
   b = tracking.AutoTrackable()
   # Equality.
   self.assertEqual([a, a], [a, a])
   self.assertEqual(wrap([a, a]), wrap([a, a]))
   self.assertEqual([a, a], wrap([a, a]))
   self.assertEqual(wrap([a, a]), [a, a])
   # Inequality.
   self.assertNotEqual([a, a], [b, a])
   self.assertNotEqual(wrap([a, a]), wrap([b, a]))
   self.assertNotEqual([a, a], wrap([b, a]))
   # Ordering.
   self.assertLess([a], [a, b])
   self.assertLess(wrap([a]), wrap([a, b]))
   self.assertLessEqual([a], [a, b])
   self.assertLessEqual(wrap([a]), wrap([a, b]))
   self.assertGreater([a, b], [a])
   self.assertGreater(wrap([a, b]), wrap([a]))
   self.assertGreaterEqual([a, b], [a])
   self.assertGreaterEqual(wrap([a, b]), wrap([a]))
   # Conversion, concatenation and type identity.
   self.assertEqual([a], wrap([a]))
   self.assertEqual([a], list(data_structures.List([a])))
   self.assertEqual([a, a], wrap([a]) + [a])
   self.assertEqual([a, a], [a] + wrap([a]))
   self.assertIsInstance(wrap([a]), list)
Пример #28
0
 def testAcceptsNonTrackableContent(self):
   """Wrapping ints is allowed and equality with a plain list holds."""
   self.assertEqual(data_structures._ListWrapper([1, 2, 3]), [1, 2, 3])
Пример #29
0
 def testPickle(self):
   """A _ListWrapper survives a pickle round trip as an equal list."""
   # The source wrapper is not kept alive past dumps(), so the assertion
   # can only be satisfied by the deserialized copy.
   payload = pickle.dumps(data_structures._ListWrapper([1, 2]))
   restored = pickle.loads(payload)
   self.assertEqual([1, 2], restored)
Пример #30
0
 def testListChangesWrapper(self):
   """Appending to the wrapped list is visible through the wrapper."""
   backing_list = []
   wrapper = data_structures._ListWrapper(backing_list)
   backing_list.append(1)
   self.assertEqual([1], wrapper)
Пример #31
0
 def testListChangesWrapper(self):
   """The wrapper observes mutations made directly on its backing list."""
   backing_list = []
   wrapper = data_structures._ListWrapper(backing_list)
   backing_list.append(1)
   self.assertEqual([1], wrapper)
Пример #32
0
 def testNotHashable(self):
   """A _ListWrapper, like a built-in list, cannot be hashed."""
   # Construct the wrapper before entering assertRaises so the TypeError
   # under test can only originate from hash(); a TypeError from the
   # constructor (e.g. a missing wrapped-list argument) must not let this
   # test pass vacuously.
   wrapper = data_structures._ListWrapper([])
   with self.assertRaises(TypeError):
     hash(wrapper)
Пример #33
0
 def testSameStructure(self):
   """A wrapped shallow copy has the same nest structure as the original."""
   plain = [1]
   wrapped_copy = data_structures._ListWrapper(copy.copy(plain))
   nest.assert_same_structure(plain, wrapped_copy)