def test_memoization_with_default_arguments(self):
    """Memoize with all-default settings: check statistics, cache keys, and clearing."""
    # 5 repeated calls with 10 (1 miss + 4 hits each), one call with 20 (1 miss each)
    for _ in range(5):
        f1(10)
        f2(10)
    f1(20)
    f2(20)
    self.assertEqual(exec_times['f1'], 2)
    self.assertEqual(exec_times['f2'], 2)

    # both caches must report identical default configuration and statistics
    for info in (f1.cache_info(), f2.cache_info()):
        self.assertIsNone(info.max_size)
        self.assertEqual(info.algorithm, CachingAlgorithmFlag.LRU)
        self.assertIsNone(info.ttl)
        self.assertTrue(info.thread_safe)
        self.assertEqual(info.hits, 4)
        self.assertEqual(info.misses, 2)
        self.assertEqual(info.current_size, 2)

    # both argument values must be present as cache keys
    expected_keys = [make_key((value, ), None) for value in (10, 20)]
    for func in (f1, f2):
        for expected in expected_keys:
            self.assertIn(expected, func._cache)

    f1.cache_clear()
    f2.cache_clear()
    self._check_empty_cache_after_clearing(f1)
    self._check_empty_cache_after_clearing(f2)
def _general_test(self, tested_function, algorithm, hits, misses, in_cache, not_in_cache):
    """Drive a max_size=5 cached function through a fixed call pattern.

    First fills the cache with 21 distinct misses, then issues 10 scripted
    calls and checks the hit/miss counters and the surviving keys against
    the caller-supplied expectations for the given eviction algorithm.
    """
    # reset counters and cache
    exec_times[tested_function.__name__] = 0
    tested_function.cache_clear()

    # 21 distinct arguments -> 21 misses, cache holds the 5 most recent
    for value in range(20):
        tested_function(value)
    tested_function(99)
    self.assertEqual(exec_times[tested_function.__name__], 21)

    info = tested_function.cache_info()
    self.assertEqual(info.max_size, 5)
    self.assertEqual(info.algorithm, algorithm)
    self.assertIsNone(info.ttl)
    self.assertIsNotNone(info.thread_safe)
    self.assertEqual(info.hits, 0)
    self.assertEqual(info.misses, 21)
    self.assertEqual(info.current_size, 5)
    for expected in (make_key((value, ), None) for value in (99, 19, 18, 17, 16)):
        self.assertIn(expected, tested_function._cache)

    # 10 consecutive calls here
    for value in (16, 17, 18, 16, 17, 18, 19, 15, 100, 16):
        tested_function(value)

    info = tested_function.cache_info()
    self.assertEqual(info.hits, hits)
    self.assertEqual(info.misses, misses)
    self.assertEqual(info.current_size, 5)
    for expected in (make_key((value, ), None) for value in in_cache):
        self.assertIn(expected, tested_function._cache)
    # everything evicted (plus the early keys) must be gone
    for absent in (make_key((value, ), None)
                   for value in chain(not_in_cache, range(0, 15))):
        self.assertNotIn(absent, tested_function._cache)
def wrapper(*args, **kwargs):
    """The actual wrapper.

    Lock-guarded LFU-style lookup. The lock is intentionally RELEASED while
    ``user_function`` runs, so after recomputing we must re-check the cache:
    another thread may have inserted the same key in the meantime.
    """
    nonlocal hits, misses
    key = make_key(args, kwargs)
    cache_expired = False
    with lock:
        result = _access_lfu_cache(cache, key, sentinel)
        if result is not sentinel:
            if values_toolkit.is_cache_value_valid(result):
                # fresh entry: count the hit and return the cached result
                hits += 1
                return values_toolkit.retrieve_result_from_cache_value(
                    result)
            else:
                # entry exists but its TTL has elapsed; fall through to recompute
                cache_expired = True
        misses += 1
    # compute outside the lock so other threads are not blocked by user code
    result = user_function(*args, **kwargs)
    with lock:
        if key in cache:
            if cache_expired:
                # update cache with new ttl
                cache[key].value = values_toolkit.make_cache_value(
                    result, ttl)
            else:
                # result added to the cache while the lock was released
                # no need to add again
                pass
        else:
            # first insertion for this key; may evict per LFU policy when full
            _insert_into_lfu_cache(
                cache, key, values_toolkit.make_cache_value(result, ttl),
                lfu_freq_list_root, max_size)
    return result
def wrapper(*args, **kwargs):
    """The actual wrapper.

    Bounded cache backed by a dict plus a circular doubly-linked list whose
    ``root`` marks the oldest entry. NOTE(review): a hit does NOT move the
    node within the list, so eviction here is by insertion order
    (FIFO-style) rather than by recency — confirm against the decorator
    that builds this wrapper.

    The lock is released while ``user_function`` runs, so the cache is
    re-checked before inserting the freshly computed result.
    """
    nonlocal hits, misses, root, full
    key = make_key(args, kwargs)
    cache_expired = False
    with lock:
        node = cache.get(key, sentinel)
        if node is not sentinel:
            if values_toolkit.is_cache_value_valid(node[_VALUE]):
                hits += 1
                return values_toolkit.retrieve_result_from_cache_value(
                    node[_VALUE])
            else:
                # stale entry: recompute below, then overwrite in place
                cache_expired = True
        misses += 1
    # run user code without holding the lock
    result = user_function(*args, **kwargs)
    with lock:
        if key in cache:
            if cache_expired:
                # update cache with new ttl
                cache[key][_VALUE] = values_toolkit.make_cache_value(
                    result, ttl)
            else:
                # result added to the cache while the lock was released
                # no need to add again
                pass
        elif full:
            # reuse the root node for the new entry instead of allocating:
            # switch root to the oldest element in the cache
            old_root = root
            root = root[_NEXT]
            # keep references of root[_KEY] and root[_VALUE] to prevent arbitrary GC
            old_key = root[_KEY]
            old_value = root[_VALUE]
            # overwrite the content of the old root
            old_root[_KEY] = key
            old_root[_VALUE] = values_toolkit.make_cache_value(result, ttl)
            # clear the content of the new root
            root[_KEY] = root[_VALUE] = None
            # delete from cache
            del cache[old_key]
            # save the result to the cache
            cache[key] = old_root
        else:
            # add a node to the linked list
            last = root[_PREV]
            node = [
                last, root, key,
                values_toolkit.make_cache_value(result, ttl)
            ]  # new node
            cache[key] = root[_PREV] = last[
                _NEXT] = node  # save result to the cache
            # check whether the cache is full
            full = (cache.__len__() >= max_size)
    return result
def _general_ttl_test(self, tested_function, arg=1, kwargs=None):
    """Exercise TTL expiry: hit before the deadline, recompute after it."""
    # reset counters and cache
    exec_times[tested_function.__name__] = 0
    tested_function.cache_clear()

    def invoke(value, keyword_args):
        if keyword_args is None:
            tested_function(value)
        else:
            tested_function(value, **keyword_args)

    key = make_key((arg, ), kwargs)
    invoke(arg, kwargs)
    time.sleep(0.25)  # wait for a short time

    info = tested_function.cache_info()
    self.assertEqual(info.hits, 0)
    self.assertEqual(info.misses, 1)
    self.assertEqual(info.current_size, 1)
    self.assertIn(key, tested_function._cache)

    invoke(arg, kwargs)  # this WILL NOT call the tested function
    info = tested_function.cache_info()
    self.assertEqual(info.hits, 1)
    self.assertEqual(info.misses, 1)
    self.assertEqual(info.current_size, 1)
    self.assertIn(key, tested_function._cache)
    self.assertEqual(exec_times[tested_function.__name__], 1)

    time.sleep(0.35)  # wait until the cache expires
    info = tested_function.cache_info()
    self.assertEqual(info.current_size, 1)

    invoke(arg, kwargs)  # this WILL call the tested function
    info = tested_function.cache_info()
    self.assertEqual(info.hits, 1)
    self.assertEqual(info.misses, 2)
    self.assertEqual(info.current_size, 1)
    self.assertIn(key, tested_function._cache)
    self.assertEqual(exec_times[tested_function.__name__], 2)

    # The previous call should have been cached, so it must not call the function again
    invoke(arg, kwargs)  # this SHOULD NOT call the tested function
    info = tested_function.cache_info()
    self.assertEqual(info.hits, 2)
    self.assertEqual(info.misses, 2)
    self.assertEqual(info.current_size, 1)
    self.assertIn(key, tested_function._cache)
    self.assertEqual(exec_times[tested_function.__name__], 2)
def _general_unhashable_arguments_test(self, tested_function):
    """Unhashable arguments must be keyed by value: mutating the argument
    yields a different key, restoring it yields the original key again."""

    def mutate(value):
        # change the argument in place so its cache key differs
        if isinstance(value, list):
            value.append(0)
        elif isinstance(value, dict):
            value['foo'] = 'bar'
        else:
            raise TypeError

    def restore(value):
        # undo mutate() so the original cache key is produced again
        if isinstance(value, list):
            value.pop()
        elif isinstance(value, dict):
            del value['foo']
        else:
            raise TypeError

    test_args = ([1, 2, 3], {
        'this': 'is unhashable'
    }, ['yet', ['another', ['complex', {
        'type, ': 'isn\'t it?'
    }]]])
    for arg in test_args:
        # reset counters and cache
        exec_times[tested_function.__name__] = 0
        tested_function.cache_clear()

        key = make_key((arg, ), None)
        tested_function(arg)  # miss #1
        self.assertIn(key, tested_function._cache)

        mutate(arg)
        key = make_key((arg, ), None)
        tested_function(arg)  # miss #2 (different key)
        self.assertIn(key, tested_function._cache)

        restore(arg)
        key = make_key((arg, ), None)
        tested_function(arg)  # hit (original key again)
        self.assertIn(key, tested_function._cache)

        self.assertEqual(exec_times[tested_function.__name__], 2)
        info = tested_function.cache_info()
        self.assertEqual(info.hits, 1)
        self.assertEqual(info.misses, 2)
        self.assertEqual(info.current_size, 2)
def wrapper(*args, **kwargs):
    """The actual wrapper.

    Unbounded cache (no eviction visible here). Only the hit/miss counters
    are guarded by ``lock``; the cache dict itself is read and written
    outside it — presumably relying on the atomicity of single dict
    operations in CPython (confirm against the decorator's thread-safety
    contract). Concurrent misses for the same key may each run
    ``user_function`` and overwrite each other's entry.
    """
    nonlocal hits, misses
    key = make_key(args, kwargs)
    # lookup outside the lock; sentinel distinguishes "absent" from any value
    value = cache.get(key, sentinel)
    if value is not sentinel and values_toolkit.is_cache_value_valid(
            value):
        with lock:
            hits += 1
        return values_toolkit.retrieve_result_from_cache_value(value)
    else:
        with lock:
            misses += 1
        # compute and (re)store; also refreshes an expired entry's TTL
        result = user_function(*args, **kwargs)
        cache[key] = values_toolkit.make_cache_value(result, ttl)
        return result
def _general_multithreading_test(self, tested_function, algorithm):
    """Hammer a max_size=5 cache from several threads and sanity-check stats."""
    number_of_keys = 30000
    number_of_threads = 4

    # reset counters and cache, then verify the pristine configuration
    exec_times[tested_function.__name__] = 0
    tested_function.cache_clear()
    info = tested_function.cache_info()
    self.assertEqual(info.max_size, 5)
    self.assertEqual(info.algorithm, algorithm)
    self.assertIsNone(info.ttl)
    self.assertTrue(info.thread_safe)
    self.assertEqual(info.current_size, 0)

    def spawn_and_join(worker):
        # run `worker` concurrently on number_of_threads threads
        workers = [Thread(target=worker) for _ in range(number_of_threads)]
        for thread in workers:
            thread.start()
        for thread in workers:
            thread.join()

    # Test must-hit: only 5 distinct keys, so the cache never evicts
    def run_must_hit():
        keys = list(range(5)) * int(number_of_keys / 5)
        random.shuffle(keys)
        for i in keys:
            tested_function(i)

    spawn_and_join(run_must_hit)
    self.assertGreaterEqual(exec_times[tested_function.__name__], 5)
    info = tested_function.cache_info()
    self.assertLessEqual(info.hits, number_of_keys * number_of_threads - 5)
    self.assertGreaterEqual(info.misses, 5)
    self.assertEqual(info.current_size, 5)
    for key in (make_key((x, ), None) for x in range(5)):
        self.assertIn(key, tested_function._cache)

    # Test can-miss: 20 distinct keys against a 5-entry cache forces evictions
    def run_can_miss():
        keys = list(range(20)) * int(number_of_keys / 20)
        random.shuffle(keys)
        for i in keys:
            tested_function(i)

    spawn_and_join(run_can_miss)
    executed_times = exec_times[tested_function.__name__]
    self.assertLessEqual(executed_times, number_of_keys * number_of_threads)
    self.assertGreaterEqual(executed_times, 20)
    info = tested_function.cache_info()
    self.assertGreaterEqual(info.hits, 0)
    self.assertLessEqual(info.misses, number_of_keys * number_of_threads)
    self.assertEqual(info.current_size, 5)