示例#1
0
 def testPythonMapImpl(self):
     # Maps over a nested _TupleWrapper using a plain-tuple shallow structure
     # with check_types=True, then checks that a plain 2-tuple is accepted as
     # a shallow structure of the wrapper.
     inner = data_structures._TupleWrapper((2, ))
     wrapped = data_structures._TupleWrapper((1, inner))
     shifted = nest.map_structure_up_to((None, (None, )),
                                        lambda value: value + 3,
                                        wrapped,
                                        check_types=True)
     self.assertEqual((4, (5, )), shifted)
     nest.assert_shallow_structure((None, None), wrapped)
示例#2
0
 def testNonLayerVariables(self):
     # A wrapped tuple holding a single variable (and no layers) reports the
     # variable as trainable and exposes no layers or non-trainable variables.
     variable = resource_variable_ops.ResourceVariable([1.])
     wrapper = data_structures._TupleWrapper((variable, ))
     self.assertEqual([], wrapper.layers)
     self.assertEqual([variable], wrapper.variables)
     self.assertEqual([variable], wrapper.trainable_weights)
     self.assertEqual([], wrapper.non_trainable_variables)
示例#3
0
 def testDatasetMapNamed(self):
     # A Dataset.map function may return a _TupleWrapper around a namedtuple;
     # each produced element evaluates to a 1-tuple of the input value.
     record_type = collections.namedtuple("A", ["x"])
     source = dataset_ops.Dataset.from_tensor_slices(
         constant_op.constant([1, 2, 3]))
     mapped = source.map(
         lambda item: data_structures._TupleWrapper(record_type(item)))
     for expected, element in enumerate(mapped, start=1):
         self.assertEqual((expected, ), self.evaluate(element))
示例#4
0
 def testFlatten(self):
     # nest.flatten, flatten_with_tuple_paths and pack_sequence_as treat a
     # _TupleWrapper like the tuple / namedtuple it wraps.
     nested = data_structures._TupleWrapper(
         (1, data_structures._TupleWrapper((2, ))))
     self.assertEqual([1, 2], nest.flatten(nested))
     self.assertEqual(nest.flatten_with_tuple_paths((1, (2, ))),
                      nest.flatten_with_tuple_paths(nested))
     self.assertEqual((3, (4, )), nest.pack_sequence_as(nested, [3, 4]))

     pair_type = collections.namedtuple("nt", ["x", "y"])
     pair = pair_type(1., 2.)
     wrapped_pair = data_structures._TupleWrapper(pair)
     self.assertEqual(nest.flatten_with_tuple_paths(pair),
                      nest.flatten_with_tuple_paths(wrapped_pair))
     repacked = nest.pack_sequence_as(wrapped_pair, [3, 4])
     self.assertEqual((3, 4), repacked)
     # Repacking against the wrapped namedtuple keeps named-field access.
     self.assertEqual(3, repacked.x)
示例#5
0
 def testIAdd(self):
     # `+=` rebinds the name to the concatenated value while the object it
     # previously referred to still compares equal to the original tuple.
     variable = resource_variable_ops.ResourceVariable(1.)
     wrapper = data_structures._TupleWrapper((variable, ))
     before = wrapper
     wrapper += (1, )
     self.assertEqual(wrapper, (variable, 1))
     self.assertNotEqual(before, (variable, 1))
     self.assertEqual(before, (variable, ))
示例#6
0
    def testFunctionCaching(self):
        # Tracing with a plain tuple and then with an equivalent
        # _TupleWrapper must return the very same concrete function.
        @def_function.function
        def add_one(tuple_input):
            return tuple_input[0] + constant_op.constant(1.)

        traced_with_tuple = add_one.get_concrete_function(
            (constant_op.constant(2.), ))
        wrapped_arg = data_structures._TupleWrapper(
            (constant_op.constant(3.), ))
        traced_with_wrapper = add_one.get_concrete_function(wrapped_arg)
        self.assertIs(traced_with_tuple, traced_with_wrapper)
示例#7
0
 def testIMul(self):
     # Note: tuple behavior differs from list behavior. Lists are mutated by
     # imul/iadd, tuples assign a new object to the left hand side of the
     # expression.
     variable = resource_variable_ops.ResourceVariable(1.)
     wrapper = data_structures._TupleWrapper((variable, ))
     original = wrapper
     wrapper *= 2
     self.assertEqual(wrapper, (variable, ) * 2)
     self.assertNotEqual(original, (variable, ) * 2)
     # The object referenced before `*=` must still hold its single element;
     # this pins the "new object is assigned" behavior described above.
     self.assertEqual(original, (variable, ))
示例#8
0
    def testSlicing(self):
        # Slices of a _TupleWrapper compare equal to the same slices of the
        # wrapped tuple.
        first, second, third, fourth = [
            resource_variable_ops.ResourceVariable(1.) for _ in range(4)]

        wrapper = data_structures._TupleWrapper((first, second, third, fourth))
        self.assertEqual(wrapper[1:], (second, third, fourth))
        self.assertEqual(wrapper[1:-1], (second, third))
        self.assertEqual(wrapper[:-1], (first, second, third))
示例#9
0
    def testCopy(self):
        # A shallow copy shares element objects and a deep copy does not.
        # The copy remains tuple-like, so it has no list-style `append`.
        first = resource_variable_ops.ResourceVariable(1.)
        second = resource_variable_ops.ResourceVariable(1.)

        source = data_structures._TupleWrapper((first, second))
        shallow = copy.copy(source)
        self.assertEqual(source, (first, second))
        self.assertEqual(shallow, (first, second))
        self.assertIs(source[0], shallow[0])
        deep = copy.deepcopy(source)
        self.assertIsNot(source[0], deep[0])
        with self.assertRaises(AttributeError):
            shallow.append(first)
示例#10
0
 def testDatasetMap(self):
     # A Dataset.map function may return a _TupleWrapper; each element then
     # evaluates to a 1-tuple holding the input value.
     source = dataset_ops.Dataset.from_tensor_slices(
         constant_op.constant([1, 2, 3]))
     mapped = source.map(lambda item: data_structures._TupleWrapper((item, )))
     for expected, element in enumerate(mapped, start=1):
         self.assertEqual((expected, ), self.evaluate(element))
示例#11
0
 def testPickle(self):
     # A _TupleWrapper survives a pickle round trip and compares equal to the
     # plain tuple it wrapped.
     wrapper = data_structures._TupleWrapper((1, 2))
     payload = pickle.dumps(wrapper)
     # Release the source object before deserializing, so the loaded value
     # comes only from the pickled bytes.
     del wrapper
     restored = pickle.loads(payload)
     self.assertEqual((1, 2), restored)
示例#12
0
 def testRMul(self):
     # `int * wrapper` (reflected multiplication) repeats elements like a
     # tuple: three elements doubled gives six.
     variable = resource_variable_ops.ResourceVariable(1.)
     wrapper = data_structures._TupleWrapper((variable, ) * 3)
     self.assertEqual(2 * wrapper, (variable, ) * 6)
示例#13
0
 def testIMul_zero(self):
     # In-place multiplication by zero yields a value equal to the empty tuple.
     wrapper = data_structures._TupleWrapper((1, ))
     wrapper *= 0
     self.assertEqual((), wrapper)
示例#14
0
 def testHash(self):
     # Distinct empty _TupleWrappers hash and compare equal: a set built from
     # two of them holds a single entry, and a fresh one is found in it.
     wrappers = {data_structures._TupleWrapper(),
                 data_structures._TupleWrapper()}
     self.assertLen(wrappers, 1)
     self.assertIn(data_structures._TupleWrapper(), wrappers)