def testWrapperNotEqualToWrapped(self):
  """Wrappers compare equal to each other but refuse comparison with raw objects.

  Two wrappers around the same object are equal, and each exposes that
  object through `unwrapped`. Comparing a wrapper against the unwrapped
  object with `==` or `!=` raises TypeError. A set-membership test only
  reaches that comparison once the raw object's hash collides with the
  wrapper's.
  """

  class SettableHash(object):
    """Helper object whose hash can be changed after construction."""

    def __init__(self):
      self.hash_value = 8675309

    def __hash__(self):
      return self.hash_value

  target = SettableHash()
  first_wrapper = object_identity._ObjectIdentityWrapper(target)
  second_wrapper = object_identity._ObjectIdentityWrapper(target)

  # A wrapper equals itself and any other wrapper around the same object.
  for other in (first_wrapper, second_wrapper):
    self.assertEqual(first_wrapper, other)
  # Each wrapper hands back the object it was built from.
  for wrapper in (first_wrapper, second_wrapper):
    self.assertEqual(target, wrapper.unwrapped)

  # Mixing wrapped and unwrapped operands is rejected for == and != alike.
  # bool(...) gives the with-block a real expression to evaluate.
  with self.assertRaises(TypeError):
    bool(target == first_wrapper)
  with self.assertRaises(TypeError):
    bool(first_wrapper != target)

  # The hashes differ here, so the set lookup never compares the two objects.
  self.assertNotIn(target, {first_wrapper})

  # Force a hash collision: the lookup must now compare target with the
  # wrapper, and that comparison raises.
  target.hash_value = id(target)
  with self.assertRaises(TypeError):
    bool(target in {first_wrapper})
def testNestMapStructure(self):
  """tf.nest.map_structure accepts identity wrappers as dict keys.

  Two single-entry dicts that share a wrapper key are zipped together. The
  result should map that key to a tuple of the two wrapper values.
  """
  key, first_value, second_value = (
      object_identity._ObjectIdentityWrapper(name)
      for name in ('k', 'v1', 'v2'))
  # The variadic lambda packs the two leaves into a tuple (first, second).
  paired = tf.nest.map_structure(
      lambda *leaves: leaves, {key: first_value}, {key: second_value})
  self.assertEqual(paired, {key: (first_value, second_value)})
def testNestFlatten(self):
  """tf.nest.flatten treats identity wrappers as leaves and keeps their order."""
  wrap = object_identity._ObjectIdentityWrapper
  a, b, c = [wrap(name) for name in ('a', 'b', 'c')]
  # Wrappers are nested inside lists and a tuple at different depths.
  nested = [[[(a, b)]], c]
  self.assertEqual(tf.nest.flatten(nested), [a, b, c])