def testWeakStructRespectsContainerTypes(self):
  """WeakStructRef equality considers the container type, not just leaves.

  Two refs built from distinct list objects holding the same tensors compare
  equal, while a ref over a list and a ref over a tuple with identical
  contents do not.
  """
  alice = tf.constant(1., name='alice')
  bob = tf.constant(2., name='bob')
  make_ref = cache_util.WeakStructRef

  # Separately constructed lists with the same tensors -> equal refs.
  self.assertEqual(make_ref([alice, bob]), make_ref([alice, bob]))

  # Same tensors, but list vs. tuple container -> unequal refs.
  self.assertNotEqual(make_ref([alice, bob]), make_ref((alice, bob)))
def testStructRefCallbackFiresOnce(self):
  """The ref's callback fires exactly once, with the ref itself as the key.

  In eager mode, deleting the first leaf tensor from the structure marks the
  ref dead and invokes the callback. Further deletions (another leaf, then
  the whole dict) must not invoke it again. In graph mode the ref is
  expected to stay alive even after the dict is deleted.
  """
  # Tensors are created inline so the dict holds the only references to them;
  # binding them to locals would keep them alive and defeat the test.
  structure = {
      'a': tf.constant(1., name='alice'),
      'b': tf.constant(2., name='bob'),
      'c': tf.constant(3., name='carol'),
  }
  fired = []

  def on_collect(key):
    fired.append(key)

  ref = cache_util.WeakStructRef(structure, callback=on_collect)
  self.assertTrue(ref.alive)

  if not tf.executing_eagerly():
    # Graph mode: references stay alive after the dict is dropped.
    del structure
    self.assertTrue(ref.alive)
    return

  self.assertEqual(fired, [])

  del structure['a']  # Goodbye, Alice!
  self.assertFalse(ref.alive)
  self.assertEqual(fired, [ref])

  del structure['b']  # Goodbye, Bob!
  self.assertEqual(fired, [ref])

  del structure  # Goodbye, everybody!
  self.assertEqual(fired, [ref])
def testInputsAreCached(self, steps=4):
  """Chained forward calls are cached and fully recoverable via inverse.

  Runs `self.forward` `steps` times. Each call should invoke the bijector's
  forward exactly once and record one forward and one inverse cache entry.
  Running `self.inverse` the same number of times should then be served
  entirely from cache (no bijector inverse calls). The result should be
  element-wise identical (`is`) to the original input. Finally, dropping the
  last strong reference should release the start structure and empty both
  key lists.

  Args:
    steps: Number of forward (and then inverse) applications.
  """
  struct = self.test_arg()
  weak_start = cache_util.WeakStructRef(struct)

  # Call forward a few times. (A leftover debug `print` of every intermediate
  # structure was removed here; it only added noise to the test output.)
  for _ in range(steps):
    struct = self.forward(struct)

  self.assertLen(self.forward_keys, steps)
  self.assertEqual(self.bijector.forward_call_count, steps)
  self.assertLen(self.cache.items(direction='forward'), steps)
  self.assertLen(self.inverse_keys, steps)
  self.assertEqual(self.bijector.inverse_call_count, 0)
  self.assertLen(self.cache.items(direction='inverse'), steps)

  # Now invert our calls.
  for _ in range(steps):
    struct = self.inverse(struct)

  self.assertLen(self.forward_keys, 1)  # Has cached attrs.
  self.assertEqual(self.bijector.forward_call_count, steps)  # No new calls.
  self.assertLen(self.inverse_keys, 0)  # Refs are all gone.
  self.assertEqual(self.bijector.inverse_call_count, 0)  # All cache hits.

  # Original is recoverable. Contents are referentially equal.
  self.assertTrue(weak_start.alive)
  tf.nest.map_structure(self.assertIs, struct, weak_start())

  # Dropping the last strong reference should release everything.
  struct = None
  self.assertFalse(weak_start.alive)
  self.assertLen(self.forward_keys, 0)
  self.assertLen(self.inverse_keys, 0)
def testStructRefIsWeak(self):
  """WeakStructRef does not keep its structure's tensors alive in eager mode.

  Two independent refs are taken to the same dict. After the dict is
  deleted, both refs should report dead under eager execution. In graph mode
  both should still report alive.
  """
  # Tensors are created inline so the dict holds the only references to them.
  tensors = {
      'a': tf.constant(1., name='alice'),
      'b': tf.constant(2., name='bob'),
      'c': tf.constant(3., name='carol')}
  refs = (cache_util.WeakStructRef(tensors),
          cache_util.WeakStructRef(tensors))
  self.assertTrue(refs[0].alive)

  del tensors

  # Eager: references get cleaned up. Graph: references stay alive.
  check = self.assertFalse if tf.executing_eagerly() else self.assertTrue
  for ref in refs:
    check(ref.alive)
def testWeakStructCopiesContainers(self):
  """WeakStructRef snapshots the container structure at construction time.

  Mutating the original outer list in place afterwards must not affect what
  the ref returns. The ref should still match a copy taken before the
  mutation, and should no longer share structure with the mutated list.
  """
  alice = tf.constant(1., name='alice')
  bob = tf.constant(2., name='bob')
  carol = tf.constant(3., name='carol')
  original = [alice, {'x': bob, 'y': carol}]

  # Take the ref, and an independent structural copy, before mutating.
  ref = cache_util.WeakStructRef(original)
  snapshot = tf.nest.map_structure(lambda leaf: leaf, original)

  # Reverse the outer list in place (same list object, new element order).
  original.reverse()

  self.assertEqual(ref(), snapshot)
  with self.assertRaises(ValueError):  # pylint: disable=g-error-prone-assert-raises
    tf.nest.assert_same_structure(ref(), original)
def testStructRefChecksSubkey(self):
  """Refs to the same tensor with different subkeys are distinct.

  Both equality and hashing must take the subkey into account.
  """
  alice = tf.constant([1., 2., 3.], dtype=tf.float32, name='alice')
  ref_a, ref_b = (
      cache_util.WeakStructRef(alice, subkey=key) for key in ('a', 'b'))

  self.assertNotEqual(ref_a, ref_b)
  self.assertNotEqual(hash(ref_a), hash(ref_b))