def benchmarkCacheHit50thKeyKnownSubtype(self):
  """Benchmarks a lookup whose subtype match the cache has already observed.

  The cache holds `num_total_checks - 1` keys built from zero tensors plus one
  MockSubtypeOf2(2) entry, for `num_total_checks` keys in total. A
  MockSubtypeOf2(3) key is looked up once with subtyping enabled, so that the
  cache observes it (to memorize the subtype). That same key is then looked
  up `iterations` times under `timeit`.
  """
  cache = function_cache.FunctionCache()
  args_per_call = 5
  num_total_checks = 50

  # One key fewer than the total: the MockSubtypeOf2 entry added below is the
  # last (50th) key.
  keys = [
      function_cache.make_cache_key(
          [array_ops.zeros([i, j]) for j in range(args_per_call)])
      for i in range(num_total_checks - 1)
  ]
  for key in keys:
    cache.add(*key, "testing")

  cache.add(
      function_cache.FunctionCacheKey(MockSubtypeOf2(2), None),
      function_trace_type.WeakrefDeletionObserver(), "testing")

  # The timed key must be the one used for the warm-up lookup, and it must
  # differ from the stored MockSubtypeOf2(2) key. Otherwise the benchmark
  # times an exact hit instead of a memorized subtype lookup.
  lookup_key = function_cache.FunctionCacheKey(MockSubtypeOf2(3), None)
  cache.lookup(lookup_key, True)

  iterations = 10000
  subtyping_time = timeit.timeit(
      lambda: cache.lookup(lookup_key, True), number=iterations)

  self.report_benchmark(
      name="cache_hit_50th_key_known_subtype",
      iters=iterations,
      wall_time=subtyping_time,
      metrics=[{
          "name": "cache_hit_50th_key_known_subtype_avg_ms",
          "value": subtyping_time / iterations * 1000
      }])
def testFunctionCacheKeyRespectsSupertype(self):
  """Checks most_specific_common_subtype on keys wrapping MockSupertypes2With3.

  Combining the key for 2 with the key for 1 must yield the key for 3. The
  reverse direction (key for 1 combined with key for 2) must yield None.
  """
  context = function_cache.ExecutionContext(1, 1, 1, 1, 1, 1)

  def cache_key(value):
    return function_cache.FunctionCacheKey(MockSupertypes2With3(value),
                                           context)

  key_one = cache_key(1)
  key_two = cache_key(2)

  common = key_two.most_specific_common_subtype([key_one])
  self.assertEqual(common, cache_key(3))

  self.assertIsNone(key_one.most_specific_common_subtype([key_two]))
def testFunctionCacheKeyRespectsSubtype(self):
  """Checks is_subtype_of on FunctionCacheKeys wrapping MockSubtypeOf2 values.

  The key for 2 must be a subtype of the key for 1, but not the other way
  round. Two distinct keys that both wrap 1 must not be subtypes.
  """
  context = function_cache.ExecutionContext(1, 1, 1, 1, 1, 1)
  one, two, another_one = (
      function_cache.FunctionCacheKey(MockSubtypeOf2(value), context)
      for value in (1, 2, 1))

  self.assertTrue(two.is_subtype_of(one))
  self.assertFalse(one.is_subtype_of(two))
  # Same wrapped value, separate key objects: still expected to be False.
  self.assertFalse(another_one.is_subtype_of(one))
def testFunctionCacheKeyRespectsEquality(self):
  """Keys wrapping equal GenericType values compare equal and hash equal.

  Keys wrapping different values must compare unequal.
  """
  context = function_cache.ExecutionContext(1, 1, 1, 1, 1, 1)

  def make_key(obj):
    return function_cache.FunctionCacheKey(
        function_trace_type.GenericType(obj), context)

  original, different, duplicate = map(make_key, (1, 2, 1))

  self.assertNotEqual(original, different)
  self.assertEqual(original, duplicate)
  self.assertEqual(hash(original), hash(duplicate))
def testFirstMostSpecificFunctionCacheKeyIsLookedUp(self):
  """The entry added first wins when two stored shapes can serve a lookup.

  Entries (1, 2, None) -> "a" and (1, None, 3) -> "b" are added in that
  order. A subtyping lookup of (1, 2, 3) must return "a".
  """
  context = function_cache.ExecutionContext(1, 1, 1, 1, 1, 1)
  cache = function_cache.FunctionCache()

  for dims, value in (((1, 2, None), "a"), ((1, None, 3), "b")):
    cache.add(
        function_cache.FunctionCacheKey(MockShape(*dims), context),
        trace_type.WeakrefDeletionObserver(), value)

  query = function_cache.FunctionCacheKey(MockShape(1, 2, 3), context)
  self.assertEqual(cache.lookup(query, True), "a")
def setup():
  """Resets the enclosing `cache` to a known set of entries.

  Every entry in the enclosing `keys` is unpacked as the leading arguments
  to `cache.add`, and a MockSubtypeOf2(3) entry is added last. All entries
  map to "testing".
  """
  cache.clear()
  for cache_key_and_observer in keys:
    cache.add(*cache_key_and_observer, "testing")
  extra_key = function_cache.FunctionCacheKey(MockSubtypeOf2(3), None)
  cache.add(extra_key, function_trace_type.WeakrefDeletionObserver(),
            "testing")
def testDeleteRemovesConcreteFunctions(self):
  """Deleting a key removes its entry for exact and subtype-based lookups.

  Two scenarios are covered:
  1. A key from make_cache_key(1) is added, found exactly, deleted, and then
     no longer found.
  2. A MockSubtypeOf2(2) key is added. It is found exactly, and a
     MockSubtypeOf2(3) key finds it through subtyping. After the
     MockSubtypeOf2(2) key is deleted, neither lookup returns anything.
  """
  cache = function_cache.FunctionCache()

  def mock_key(value):
    return function_cache.FunctionCacheKey(MockSubtypeOf2(value), None)

  # Scenario 1: exact-match entry built via make_cache_key.
  first_key, first_observer = function_cache.make_cache_key(1)
  cache.add(first_key, first_observer, "test_1")
  self.assertEqual(cache.lookup(first_key, False), "test_1")
  cache.delete(first_key)
  self.assertIsNone(cache.lookup(first_key, False))

  # Scenario 2: an entry that is also reachable through subtyping.
  stored = mock_key(2)
  cache.add(stored, trace_type.WeakrefDeletionObserver(), "test_2")
  self.assertEqual(cache.lookup(stored, False), "test_2")
  probe = mock_key(3)
  self.assertEqual(cache.lookup(probe, True), "test_2")

  cache.delete(stored)
  for lookup_key, use_subtyping in ((stored, False), (probe, True)):
    self.assertIsNone(cache.lookup(lookup_key, use_subtyping))
def testMostSpecificFunctionCacheKeyIsOrderAgnostic(self):
  """Lookup results do not depend on the order entries were inserted in.

  Four MockShape entries ("a" through "d") are inserted into a fresh cache
  in every possible order. For each order, every probe shape must resolve,
  with subtyping enabled, to the same expected value. Each order runs as its
  own subTest, so a failure names the insertion order that broke.
  """
  ctx = function_cache.ExecutionContext(1, 1, 1, 1, 1, 1)
  keys = [(function_cache.FunctionCacheKey(MockShape(1, 1, 1), ctx), "a"),
          (function_cache.FunctionCacheKey(MockShape(1, None, 1), ctx), "b"),
          (function_cache.FunctionCacheKey(MockShape(None, None, 1), ctx), "c"),
          (function_cache.FunctionCacheKey(
              MockShape(None, None, None), ctx), "d")]

  # (probe shape dims, value the subtyping lookup must return).
  expected_lookups = [
      ((1, 1, 1), "a"),
      ((1, 2, 1), "b"),
      ((2, 2, 1), "c"),
      ((2, 2, 2), "d"),
  ]

  for permutation in itertools.permutations(keys):
    with self.subTest(insertion_order=[value for _, value in permutation]):
      cache = function_cache.FunctionCache()
      for key, value in permutation:
        cache.add(key, trace_type.WeakrefDeletionObserver(), value)

      for dims, expected in expected_lookups:
        probe = function_cache.FunctionCacheKey(MockShape(*dims), ctx)
        self.assertEqual(cache.lookup(probe, True), expected)