def testPythonMapImpl(self):
  """nest.map_structure_up_to / assert_shallow_structure accept wrappers."""
  inner = data_structures._TupleWrapper((2, ))
  wrapped = data_structures._TupleWrapper((1, inner))
  shallow = (None, (None, ))
  mapped = nest.map_structure_up_to(
      shallow, lambda leaf: leaf + 3, wrapped, check_types=True)
  self.assertEqual((4, (5, )), mapped)
  # The wrapper must also be accepted as the "input" side of a shallow check.
  nest.assert_shallow_structure((None, None), wrapped)
def testNonLayerVariables(self):
  """A wrapped variable is tracked as a trainable variable, not as a layer."""
  variable = resource_variable_ops.ResourceVariable([1.])
  wrapper = data_structures._TupleWrapper((variable, ))
  self.assertEqual([], wrapper.layers)
  # Both trainable views must report exactly the one wrapped variable.
  for attr_name in ("variables", "trainable_weights"):
    self.assertEqual([variable], getattr(wrapper, attr_name))
  self.assertEqual([], wrapper.non_trainable_variables)
def testDatasetMapNamed(self):
  """Dataset.map may return a wrapped namedtuple; elements compare as tuples."""
  nt_type = collections.namedtuple("A", ["x"])
  source = constant_op.constant([1, 2, 3])
  dataset = dataset_ops.Dataset.from_tensor_slices(source)

  def _wrap(x):
    return data_structures._TupleWrapper(nt_type(x))

  dataset = dataset.map(_wrap)
  for expected, element in enumerate(dataset, start=1):
    self.assertEqual((expected, ), self.evaluate(element))
def testFlatten(self):
  """nest flatten/pack treat wrappers exactly like the tuples they wrap."""
  nested = data_structures._TupleWrapper(
      (1, data_structures._TupleWrapper((2, ))))
  self.assertEqual([1, 2], nest.flatten(nested))
  self.assertEqual(
      nest.flatten_with_tuple_paths((1, (2, ))),
      nest.flatten_with_tuple_paths(nested))
  self.assertEqual((3, (4, )), nest.pack_sequence_as(nested, [3, 4]))

  # A wrapped namedtuple must keep its field names through a repack.
  point_type = collections.namedtuple("nt", ["x", "y"])
  point = point_type(1., 2.)
  wrapped_point = data_structures._TupleWrapper(point)
  self.assertEqual(
      nest.flatten_with_tuple_paths(point),
      nest.flatten_with_tuple_paths(wrapped_point))
  repacked = nest.pack_sequence_as(wrapped_point, [3, 4])
  self.assertEqual((3, 4), repacked)
  self.assertEqual(3, repacked.x)
def testIAdd(self):
  """`+=` rebinds the name to a new tuple and leaves the original intact."""
  var = resource_variable_ops.ResourceVariable(1.)
  wrapper = data_structures._TupleWrapper((var, ))
  before = wrapper
  wrapper += (1, )
  self.assertEqual(wrapper, (var, 1))
  # Tuples are immutable: the object still bound to `before` is unchanged.
  self.assertNotEqual(before, (var, 1))
  self.assertEqual(before, (var, ))
def testFunctionCaching(self):
  """A wrapped tuple reuses the concrete function traced for a plain tuple."""

  @def_function.function
  def add_one_to_first(tuple_input):
    return tuple_input[0] + constant_op.constant(1.)

  plain_trace = add_one_to_first.get_concrete_function(
      (constant_op.constant(2.), ))
  wrapped_input = data_structures._TupleWrapper((constant_op.constant(3.), ))
  wrapped_trace = add_one_to_first.get_concrete_function(wrapped_input)
  self.assertIs(plain_trace, wrapped_trace)
def testIMul(self):
  """`*=` rebinds the name to a new tuple and leaves the original intact.

  Note: tuple behavior differs from list behavior. Lists are mutated by
  imul/iadd, tuples assign a new object to the left hand side of the
  expression.
  """
  v = resource_variable_ops.ResourceVariable(1.)
  wrapper = data_structures._TupleWrapper((v, ))
  original = wrapper
  wrapper *= 2
  self.assertEqual(wrapper, (v, ) * 2)
  self.assertNotEqual(original, (v, ) * 2)
  # As with `+=`, the object still bound to `original` must be unchanged;
  # the inequality check above alone would not catch a partial mutation.
  self.assertEqual(original, (v, ))
def testSlicing(self):
  """Slicing a wrapper yields the matching sub-tuple of the wrapped values."""
  v1, v2, v3, v4 = [
      resource_variable_ops.ResourceVariable(1.) for _ in range(4)
  ]
  wrapper = data_structures._TupleWrapper((v1, v2, v3, v4))
  self.assertEqual(wrapper[1:], (v2, v3, v4))
  self.assertEqual(wrapper[1:-1], (v2, v3))
  self.assertEqual(wrapper[:-1], (v1, v2, v3))
def testCopy(self):
  """Shallow copies share elements, deep copies don't; copies stay immutable."""
  first = resource_variable_ops.ResourceVariable(1.)
  second = resource_variable_ops.ResourceVariable(1.)
  source = data_structures._TupleWrapper((first, second))
  shallow = copy.copy(source)
  self.assertEqual(source, (first, second))
  self.assertEqual(shallow, (first, second))
  self.assertIs(source[0], shallow[0])
  deep = copy.deepcopy(source)
  self.assertIsNot(source[0], deep[0])
  # A copy is still a tuple wrapper, so list-style mutation is unavailable.
  with self.assertRaises(AttributeError):
    shallow.append(first)
def testDatasetMap(self):
  """Dataset.map may return a wrapped plain tuple."""
  values = constant_op.constant([1, 2, 3])
  dataset = dataset_ops.Dataset.from_tensor_slices(values).map(
      lambda x: data_structures._TupleWrapper((x, )))
  for expected, element in enumerate(dataset, start=1):
    self.assertEqual((expected, ), self.evaluate(element))
def testPickle(self):
  """A wrapper survives a pickle round trip without the original alive."""
  # The wrapper is a temporary, so only the serialized bytes outlive dumps().
  payload = pickle.dumps(data_structures._TupleWrapper((1, 2)))
  restored = pickle.loads(payload)
  self.assertEqual((1, 2), restored)
def testRMul(self):
  """`int * wrapper` repeats the wrapped tuple."""
  var = resource_variable_ops.ResourceVariable(1.)
  wrapper = data_structures._TupleWrapper((var, ) * 3)
  expected = (var, var, var) * 2
  self.assertEqual(2 * wrapper, expected)
def testIMul_zero(self):
  """`*= 0` produces an empty tuple."""
  wrapper = data_structures._TupleWrapper((1, ))
  wrapper *= 0
  self.assertEqual((), wrapper)
def testHash(self):
  """Equal empty wrappers hash identically and collapse to one set entry."""
  hashed = {data_structures._TupleWrapper(), data_structures._TupleWrapper()}
  self.assertLen(hashed, 1)
  self.assertIn(data_structures._TupleWrapper(), hashed)