def test_cache_thread():
    """The background TTL thread purges an entry once its TTL has elapsed.

    ``datetime.datetime`` and ``time.sleep`` are patched so the cache's TTL
    thread runs against a controllable mock clock. The cache's
    ``purge_expired_entries`` is wrapped to count completed purges. This lets
    the test wait until the thread has finished a purge before asserting on
    the cache contents.
    """

    def cache_with_purge_count(cache):
        # Wrap purge_expired_entries so each completed purge bumps a counter.
        # The increment runs *after* the real purge, so seeing the counter
        # change means that purge has fully finished.
        cache.purge_count = 0
        purge = cache.purge_expired_entries

        def purge_then_increment():
            purge()
            cache.purge_count += 1

        cache.purge_expired_entries = purge_then_increment
        return cache

    # Upper bound, in real seconds, on how long to wait for the TTL thread.
    wait_timeout_seconds = 10

    mockdatetime = MockDatetime(datetime(2019, 4, 1, 10))
    cache = cache_with_purge_count(
        ReadthroughTTLCache(timedelta(hours=2), lambda x: "payload")
    )
    mocksleep = MockSleep(mockdatetime)

    with patch("datetime.datetime", mockdatetime):
        with patch("time.sleep", mocksleep.sleep):
            cache.get("key_a", force_store=True)
            cache.start_ttl_thread()

            # after one hour
            mockdatetime.set_now(datetime(2019, 4, 1, 11))
            assert "key_a" in cache
            assert mocksleep.wakeups_count == 0

            # after two hours one minute
            before_timechange_purge_count = cache.purge_count
            mockdatetime.set_now(datetime(2019, 4, 1, 12, 1))
            # Wait for the TTL thread to complete a purge. The wait is bounded
            # by the real monotonic clock (only datetime.datetime and
            # time.sleep are patched), so a thread that never purges makes
            # the test fail instead of hanging forever.
            deadline = time.monotonic() + wait_timeout_seconds
            while cache.purge_count == before_timechange_purge_count:
                assert time.monotonic() < deadline, (
                    "TTL thread did not purge expired entries in time"
                )
                time.sleep(0)
            assert "key_a" not in cache
def test_cache_purges_after_ttl():
    """An explicit purge keeps live entries and evicts ones past the TTL.

    The cache has a 2-hour TTL. ``key_a`` is force-stored at 10:00 on the mock
    clock. It survives a purge at 11:00 and is gone after a purge at 12:01.
    """
    start = datetime(2019, 4, 1, 10)
    clock = MockDatetime(start)
    ttl_cache = ReadthroughTTLCache(timedelta(hours=2), lambda x: "payload")

    # (time elapsed since the store, whether key_a should survive the purge)
    checkpoints = (
        (timedelta(hours=1), True),
        (timedelta(hours=2, minutes=1), False),
    )

    with patch("datetime.datetime", clock):
        ttl_cache.get("key_a", force_store=True)
        for elapsed, expect_cached in checkpoints:
            clock.set_now(start + elapsed)
            ttl_cache.purge_expired_entries()
            assert ("key_a" in ttl_cache) == expect_cached
def test_doesnt_cache_unless_accessed_within_ttl():
    """A plain ``get`` does not store a key until it is requested again within the TTL.

    The cache has a 4-hour TTL. The first ``get`` at 10:00 (no ``force_store``)
    leaves ``key_a`` out of the cache. After a second ``get`` at 12:00,
    ``key_a`` is present at 13:00.
    """
    clock = MockDatetime(datetime(2019, 4, 1, 10))
    ttl_cache = ReadthroughTTLCache(timedelta(hours=4), lambda x: "payload")

    def move_clock_to(hour):
        # Advance the mock clock to the given hour on the same day.
        clock.set_now(datetime(2019, 4, 1, hour))

    with patch("datetime.datetime", clock):
        ttl_cache.get("key_a")  # first access alone must not store the key

        move_clock_to(11)  # one hour later
        assert "key_a" not in ttl_cache

        move_clock_to(12)  # two hours later: second access within the TTL
        ttl_cache.get("key_a")

        move_clock_to(13)  # three hours later
        assert "key_a" in ttl_cache
def test_force_store():
    """``get(..., force_store=True)`` stores the value on the first access.

    The cache's ``items_storage`` is replaced by a MagicMock that forwards
    ``__getitem__``, ``__setitem__`` and ``__contains__`` to the real storage
    while counting reads. The test checks that exactly one storage read
    happens across the force-storing ``get`` and the follow-up ``get``.
    """

    def with_spied_storage(cache):
        # Swap in a MagicMock that delegates to the real storage and counts
        # every __getitem__ call in cache.storage_access_count.
        cache.storage_access_count = 0
        real_storage = cache.items_storage
        real_getitem = real_storage.__getitem__

        def counting_getitem(key):
            cache.storage_access_count += 1
            return real_getitem(key)

        spy = MagicMock()
        spy.__contains__.side_effect = real_storage.__contains__
        spy.__setitem__.side_effect = real_storage.__setitem__
        spy.__getitem__.side_effect = counting_getitem
        cache.items_storage = spy
        return cache

    ttl_cache = with_spied_storage(
        ReadthroughTTLCache(timedelta(hours=2), lambda x: "payload")
    )
    ttl_cache.get("key_a", force_store=True)

    assert "key_a" in ttl_cache
    assert ttl_cache.get("key_a") == "payload"
    assert ttl_cache.storage_access_count == 1
MODELS_NAMES = [ "defectenhancementtask", "component", "regression", "stepstoreproduce", "spambug", "testlabelselect", "testgroupselect", ] DEFAULT_EXPIRATION_TTL = 7 * 24 * 3600 # A week redis = Redis.from_url(os.environ.get("REDIS_URL", "redis://localhost/0")) MODEL_CACHE: ReadthroughTTLCache[str, Model] = ReadthroughTTLCache( timedelta(hours=1), load_model ) MODEL_CACHE.start_ttl_thread() cctx = zstandard.ZstdCompressor(level=10) def setkey(key: str, value: bytes, compress: bool = False) -> None: LOGGER.debug(f"Storing data at {key}: {value!r}") if compress: value = cctx.compress(value) redis.set(key, value) redis.expire(key, DEFAULT_EXPIRATION_TTL) def classify_bug(model_name: str, bug_ids: Collection[int], bugzilla_token: str) -> str:
MODELS_NAMES = [ "defectenhancementtask", "component", "regression", "stepstoreproduce", "spambug", "testlabelselect", "testgroupselect", ] DEFAULT_EXPIRATION_TTL = 7 * 24 * 3600 # A week redis = Redis.from_url(os.environ.get("REDIS_URL", "redis://localhost/0")) MODEL_CACHE: ReadthroughTTLCache[str, Model] = ReadthroughTTLCache( timedelta(hours=1), lambda m: Model.load(f"{m}model") ) MODEL_CACHE.start_ttl_thread() cctx = zstandard.ZstdCompressor(level=10) def setkey(key: str, value: bytes, compress: bool = False) -> None: LOGGER.debug(f"Storing data at {key}: {value!r}") if compress: value = cctx.compress(value) redis.set(key, value) redis.expire(key, DEFAULT_EXPIRATION_TTL) def classify_bug(model_name: str, bug_ids: Sequence[int], bugzilla_token: str) -> str: