예제 #1
0
 def testSetSlice_cannotSaveIfTrackableModified(self):
     # A whole-list slice assignment that swaps a variable element for a plain
     # value is expected to leave the wrapper unsaveable.
     replaced_var = resource_variable_ops.ResourceVariable(1.)
     kept_var = resource_variable_ops.ResourceVariable(1.)
     wrapper = data_structures.ListWrapper([1, 2, replaced_var, kept_var])
     wrapper[:] = (2, 8, 9, kept_var)
     expected_contents = [2, 8, 9, kept_var]
     self.assertEqual(wrapper, expected_contents)
     self.assertUnableToSave(wrapper, "Unable to save .*__setslice__")
예제 #2
0
    def testFunctionCaching(self):
        # Tracing with a plain list and then with a ListWrapper holding one
        # tensor must hand back the very same concrete function object.
        @def_function.function
        def f(list_input):
            increment = constant_op.constant(1.)
            return list_input[0] + increment

        plain_arg = [constant_op.constant(2.)]
        trace_from_list = f.get_concrete_function(plain_arg)
        wrapped_arg = data_structures.ListWrapper([constant_op.constant(3.)])
        trace_from_wrapper = f.get_concrete_function(wrapped_arg)
        self.assertIs(trace_from_list, trace_from_wrapper)
예제 #3
0
 def testSort(self):
     # The input is already ordered, so sorting leaves the contents unchanged.
     wrapper = data_structures.ListWrapper([[1], [2], [3], [4]])
     wrapper.sort()
     self.assertAllEqual(wrapper, [[1], [2], [3], [4]])
     # Saving is refused even though this particular sort changed nothing.
     # That is deliberate: if sort were only sometimes trackable on a
     # ListWrapper, users would face a failure that is hard to debug.
     self.assertUnableToSave(wrapper, "Unable to save .*sort")
예제 #4
0
 def testIMulPositive(self):
     # In-place repetition by a positive factor: the variable ends up tracked
     # under both of its positions, and restoring the checkpoint written
     # before the multiplication brings back the variable's saved value.
     var = variables.Variable(1.)
     wrapper = data_structures.ListWrapper([1, 2, 3, 4, var])
     children_before = wrapper._trackable_children()  # pylint: disable=protected-access
     self.assertDictEqual({"4": var}, children_before)

     checkpoint = util.Checkpoint(l=wrapper)
     ckpt_prefix = os.path.join(self.get_temp_dir(), "ckpt")
     save_path = checkpoint.save(ckpt_prefix)

     # Change the value after saving so that the restore is observable.
     var.assign(5.)
     wrapper *= 2
     self.assertEqual(wrapper, [1, 2, 3, 4, var, 1, 2, 3, 4, var])
     children_after = wrapper._trackable_children()  # pylint: disable=protected-access
     self.assertDictEqual({"4": var, "9": var}, children_after)

     checkpoint.restore(save_path)
     self.assertAllClose(1., var.numpy())
예제 #5
0
 def testIMulNegative(self):
     # In-place multiplication by a negative factor must match what a built-in
     # list produces, and the wrapper is expected to be unsaveable afterwards.
     wrapper = data_structures.ListWrapper([1, 2, 3, [4]])
     wrapper *= -1
     builtin_result = [1, 2, 3, [4]] * -1
     self.assertEqual(wrapper, builtin_result)
     self.assertUnableToSave(wrapper, "Unable to save")
예제 #6
0
 def testSetSlice_extend(self):
     # Assigning four items over a two-item tail slice grows the list.
     wrapper = data_structures.ListWrapper([1, 2, 3, 4])
     wrapper[2:] = (1, 2, 3, 4)
     expected_contents = [1, 2, 1, 2, 3, 4]
     self.assertEqual(wrapper, expected_contents)
예제 #7
0
 def testSetSlice_truncate(self):
     # Assigning an empty list over the full slice empties the wrapper.
     wrapper = data_structures.ListWrapper([1, 2, 3, 4])
     replacement = []
     wrapper[:] = replacement
     self.assertEqual(wrapper, [])
예제 #8
0
 def testSetSlice_canSaveForNonTrackableItems(self):
     # Replacing plain values with other plain values through a slice leaves
     # the wrapper with no trackable children.
     wrapper = data_structures.ListWrapper([1, 2, 3, 4])
     wrapper[:] = (2, 8, 9, 0)
     self.assertEqual(wrapper, [2, 8, 9, 0])
     # pylint: disable=protected-access
     wrapper._maybe_initialize_trackable()
     tracked_children = wrapper._trackable_children()
     # pylint: enable=protected-access
     self.assertEqual(len(tracked_children), 0)
예제 #9
0
 def testSameStructure(self):
     # A wrapper built from a copy of a list must pass nest's same-structure
     # check against that list.
     plain = [1]
     wrapped = data_structures.ListWrapper(copy.copy(plain))
     nest.assert_same_structure(plain, wrapped)
예제 #10
0
 def testDelItem(self):
     # Deleting by index removes the element; afterwards the wrapper is
     # expected to refuse saving.
     wrapper = data_structures.ListWrapper([1, 2, 3, [4]])
     del wrapper[0]
     remaining = [2, 3, [4]]
     self.assertEqual(wrapper, remaining)
     self.assertUnableToSave(wrapper, "Unable to save .*__delitem__")
예제 #11
0
 def testNotHashable(self):
     # Build the wrapper outside the assertRaises block and give it an explicit
     # list. Otherwise a TypeError raised by the constructor itself (for
     # example, for a missing argument) would satisfy the assertion without
     # hash() ever being exercised.
     wrapper = data_structures.ListWrapper([])
     with self.assertRaises(TypeError):
         hash(wrapper)
예제 #12
0
 def testListChangesWrapper(self):
     # Appending to the wrapped list must be visible through the wrapper.
     backing_list = []
     wrapper = data_structures.ListWrapper(backing_list)
     backing_list.append(1)
     self.assertEqual([1], wrapper)
예제 #13
0
 def testWrapperChangesList(self):
     # Appending through the wrapper must be visible in the wrapped list.
     backing_list = []
     wrapper = data_structures.ListWrapper(backing_list)
     wrapper.append(1)
     self.assertEqual([1], backing_list)
예제 #14
0
 def testAcceptsNonTrackableContent(self):
     # A wrapper around plain integers compares equal to the same plain list.
     plain_values = [1, 2, 3]
     wrapper = data_structures.ListWrapper([1, 2, 3])
     self.assertEqual(wrapper, plain_values)
예제 #15
0
 def testListWrapperBasic(self):
     # ListWrapper is substituted for built-in lists automatically, so (unlike
     # List) it has to compare the way the built-in list type does. Each
     # wrapper assertion below is paired with the equivalent plain-list one.
     wrap = data_structures.ListWrapper
     obj_a = tracking.AutoTrackable()
     obj_b = tracking.AutoTrackable()

     # Equality, for every pairing of plain list and wrapper.
     self.assertEqual([obj_a, obj_a], [obj_a, obj_a])
     self.assertEqual(wrap([obj_a, obj_a]), wrap([obj_a, obj_a]))
     self.assertEqual([obj_a, obj_a], wrap([obj_a, obj_a]))
     self.assertEqual(wrap([obj_a, obj_a]), [obj_a, obj_a])

     # Inequality.
     self.assertNotEqual([obj_a, obj_a], [obj_b, obj_a])
     self.assertNotEqual(wrap([obj_a, obj_a]), wrap([obj_b, obj_a]))
     self.assertNotEqual([obj_a, obj_a], wrap([obj_b, obj_a]))

     # Ordering comparisons.
     self.assertLess([obj_a], [obj_a, obj_b])
     self.assertLess(wrap([obj_a]), wrap([obj_a, obj_b]))
     self.assertLessEqual([obj_a], [obj_a, obj_b])
     self.assertLessEqual(wrap([obj_a]), wrap([obj_a, obj_b]))
     self.assertGreater([obj_a, obj_b], [obj_a])
     self.assertGreater(wrap([obj_a, obj_b]), wrap([obj_a]))
     self.assertGreaterEqual([obj_a, obj_b], [obj_a])
     self.assertGreaterEqual(wrap([obj_a, obj_b]), wrap([obj_a]))

     # Conversion to list, concatenation on either side, and isinstance.
     self.assertEqual([obj_a], wrap([obj_a]))
     self.assertEqual([obj_a], list(data_structures.List([obj_a])))
     self.assertEqual([obj_a, obj_a], wrap([obj_a]) + [obj_a])
     self.assertEqual([obj_a, obj_a], [obj_a] + wrap([obj_a]))
     self.assertIsInstance(wrap([obj_a]), list)

     # Adding a TensorShape to a wrapper yields something with as_list().
     expected_dims = tensor_shape.TensorShape([None, 2]).as_list()
     combined_shape = wrap([None]) + tensor_shape.TensorShape([2])
     self.assertEqual(expected_dims, combined_shape.as_list())
예제 #16
0
 def testDelSlice(self):
     # Deleting a one-element slice removes that element; afterwards the
     # wrapper is expected to refuse saving.
     wrapper = data_structures.ListWrapper([1, 2, 3, [4]])
     del wrapper[2:3]
     remaining = [1, 2, [4]]
     self.assertEqual(wrapper, remaining)
     self.assertUnableToSave(wrapper, "Unable to save .*__delslice__")
예제 #17
0
 def testListAddOrder(self):
     # Concatenation keeps left-to-right order whichever operand is wrapped.
     wrap = data_structures.ListWrapper
     in_order = [1., 2.]
     self.assertEqual(in_order, wrap([1.]) + wrap([2.]))
     self.assertEqual(in_order, wrap([1.]) + [2.])
     self.assertEqual(in_order, [1.] + wrap([2.]))
예제 #18
0
 def testPickle(self):
     # Round-trip a wrapper through pickle. The source object is deleted before
     # loading, so the result is rebuilt from the serialized bytes only.
     wrapper = data_structures.ListWrapper([1, 2])
     payload = pickle.dumps(wrapper)
     del wrapper
     restored = pickle.loads(payload)
     self.assertEqual([1, 2], restored)