class FIFOEvictionStrategy(EvictionStrategy):
    def __init__(self, config: Dict[str, any], result_dir: str,
                 cache_stats: CacheInformation):
        super().__init__(config, result_dir, cache_stats)
        self.fifo = OrderedDict()
        self.logger = logging.getLogger(__name__)
        self.renewable_ops = {ObservationType.Hit, ObservationType.Write}
        name = 'fifo_eviction_strategy'
        self.performance_logger = create_file_logger(
            name=f'{name}_performance_logger', result_dir=result_dir)
        self._incomplete_experiences = TTLCache(InMemoryStorage())
        self._incomplete_experiences.expired_entry_callback(
            self._observe_expired_incomplete_experience)

    def observe(self, key: str, observation_type: ObservationType,
                info: Dict[str, any]):
        if observation_type == ObservationType.Write:
            ttl = info['ttl']
            observation_time = time.time()
            self.fifo[key] = {'ttl': ttl, 'observation_time': observation_time}
        elif observation_type in {
                ObservationType.Expiration, ObservationType.Invalidate
        }:
            self.logger.debug(f"Key {key} expired")
            if key in self.fifo:
                del self.fifo[key]
            action_taken = self._incomplete_experiences.get(key)
            if action_taken is not None:
                if observation_type == ObservationType.Invalidate:
                    # eviction followed by invalidation.
                    self.performance_logger.info(f'{self.episode_num},TrueEvict')
                elif observation_type == ObservationType.Miss:
                    # Miss after making an eviction decision
                    self.performance_logger.info(f'{self.episode_num},FalseEvict')
                self._incomplete_experiences.delete(key)

    def _observe_expired_incomplete_experience(self, key: str,
                                               observation_type: ObservationType,
                                               info: Dict[str, any]):
        self.performance_logger.info(f'{self.episode_num},TrueEvict')

    def trim_cache(self, cache: TTLCache) -> List[str]:
        while True:
            # TTLCache might expire and cause a race condition
            eviction_item = self.fifo.popitem(last=False)
            eviction_key = eviction_item[0]
            eviction_value = eviction_item[1]
            if cache.contains(eviction_key):
                decision_time = time.time()
                ttl_left = (eviction_value['observation_time'] +
                            eviction_value['ttl']) - decision_time
                self._incomplete_experiences.set(eviction_key, 'evict', ttl_left)
                cache.delete(eviction_key)
                return [eviction_key]
class RLTtlStrategy(TtlStrategy):
    """RL driven TTL estimation strategy."""

    def __init__(self, config: Dict[str, any], result_dir: str,
                 cache_stats: CacheInformation):
        super().__init__(config, result_dir, cache_stats)
        self.observation_seen = 0
        self.cum_reward = 0
        self.checkpoint_steps = config['checkpoint_steps']
        self._incomplete_experiences = TTLCache(InMemoryStorage())
        self._incomplete_experiences.expired_entry_callback(
            self._observe_expiry_eviction)
        self.non_terminal_observations = {
            ObservationType.EvictionPolicy, ObservationType.Expiration
        }

        agent_config = config['agent_config']
        self.maximum_ttl = config['max_ttl']
        self.experimental_reward = config.get('experimental_reward', False)
        fields_in_state = len(TTLAgentSystemState.__slots__)
        self.agent = Agent.from_spec(
            agent_config,
            state_space=FloatBox(shape=(fields_in_state, )),
            action_space=FloatBox(low=0, high=self.maximum_ttl, shape=(1, )))

        # TODO refactor into common RL interface for all strategies
        self.logger = logging.getLogger(__name__)
        name = 'rl_ttl_strategy'
        self.reward_logger = create_file_logger(name=f'{name}_reward_logger',
                                                result_dir=self.result_dir)
        self.loss_logger = create_file_logger(name=f'{name}_loss_logger',
                                              result_dir=self.result_dir)
        self.ttl_logger = create_file_logger(name=f'{name}_ttl_logger',
                                             result_dir=self.result_dir)
        self.observation_logger = create_file_logger(
            name=f'{name}_observation_logger', result_dir=self.result_dir)
        self.key_vocab = Vocabulary()
        self.errors = create_file_logger(name=f'{name}_error_logger',
                                         result_dir=self.result_dir)

    def estimate_ttl(self, key: str, values: Dict[str, any],
                     operation_type: OperationType) -> float:
        observation_time = time.time()
        encoded_key = self.key_vocab.add_or_get_id(key)
        cache_utility = self.cache_stats.cache_utility
        state = TTLAgentSystemState(encoded_key=encoded_key,
                                    hit_count=0,
                                    step_code=0,
                                    cache_utility=cache_utility,
                                    operation_type=operation_type.value)
        state_as_numpy = state.to_numpy()
        agent_action = self.agent.get_action(state_as_numpy)
        action = agent_action.item()
        incomplete_experience = TTLAgentObservedExperience(
            state=state,
            agent_action=agent_action,
            starting_state=state.copy(),
            observation_time=observation_time)
        self._incomplete_experiences.set(key, incomplete_experience,
                                         self.maximum_ttl)
        return action

    def observe(self, key: str, observation_type: ObservationType,
                info: Dict[str, any]):
        observed_experience = self._incomplete_experiences.get(key)
        if observed_experience is None or observation_type == ObservationType.Expiration:
            return  # haven't had to make a decision on it

        current_time = time.time()
        stored_state = observed_experience.state
        if observation_type == ObservationType.Hit:
            stored_state.hit_count += 1
        # elif observation_type in self.non_terminal_observations:
        #     # it was evicted by another policy, don't attempt to learn from this
        #     pass

        estimated_ttl = observed_experience.agent_action.item()
        first_observation_time = observed_experience.observation_time
        real_ttl = current_time - first_observation_time
        stored_state.step_code = observation_type.value
        stored_state.cache_utility = self.cache_stats.cache_utility

        self.reward_agent(observation_type, observed_experience, real_ttl)

        if observation_type != ObservationType.Hit:
            self.ttl_logger.info(
                f'{self.episode_num},{observation_type.name},{key},{estimated_ttl},{real_ttl},{stored_state.hit_count}'
            )
            self._incomplete_experiences.delete(key)

        self.observation_seen += 1
        if self.observation_seen % self.checkpoint_steps == 0:
            self.logger.info(
                f'Observation seen so far: {self.observation_seen}, reward so far: {self.cum_reward}'
            )
        if observation_type not in self.non_terminal_observations:
            self.observation_logger.info(f'{key},{observation_type}')

    def _observe_expiry_eviction(self, key: str,
                                 observation_type: ObservationType,
                                 info: Dict[str, any]):
        """Observe decisions that haven't been observed by the main cache.

        e.g. don't cache -> ttl up -> no miss
        """
        self.observation_logger.info(
            f'{self.episode_num},{key},{observation_type}')
        experience = info['value']  # type: TTLAgentObservedExperience
        self.ttl_logger.info(
            f'{self.episode_num},{observation_type.name},{key},{experience.agent_action.item()},'
            f'{experience.agent_action.item()},{experience.state.hit_count}')
        experience.state.step_code = observation_type.value

        self.reward_agent(observation_type, experience,
                          experience.agent_action.item())

    def reward_agent(self, observation_type: ObservationType,
                     experience: TTLAgentObservedExperience,
                     real_ttl: float) -> int:
        # reward more utilisation of the cache capacity given more hits
        final_state = experience.state
        difference_in_ttl = -abs((experience.agent_action.item() + 1) /
                                 max(min(real_ttl, self.maximum_ttl), 1))
        # reward = final_state.hit_count - abs(difference_in_ttl * self.cache_stats.cache_utility)
        if observation_type == ObservationType.Hit:
            reward = 1
            terminal = False
        # elif observation_type == ObservationType.EvictionPolicy:
        #     reward = 0
        #     terminal = True
        else:
            reward = difference_in_ttl
            if abs(difference_in_ttl) < 10:
                reward = 10
            terminal = True

        self.logger.debug(
            f'Hits: {final_state.hit_count}, ttl diff: {difference_in_ttl}, Reward: {reward}'
        )
        self.agent.observe(
            preprocessed_states=experience.starting_state.to_numpy(),
            actions=experience.agent_action,
            internals=[],
            rewards=reward,
            next_states=final_state.to_numpy(),
            terminals=terminal)
        self.cum_reward += reward
        self.reward_logger.info(f'{self.episode_num},{reward}')
        loss = self.agent.update()
        if loss is not None:
            self.loss_logger.info(f'{self.episode_num},{loss[0]}')
        return reward

    def close(self):
        super().close()
        for (k, v) in list(self._incomplete_experiences.items()):
            self.ttl_logger.info(
                f'{self.episode_num},{ObservationType.EndOfEpisode.name},{k},{v.agent_action.item()},'
                f'{v.agent_action.item()},{v.state.hit_count}')
        self._incomplete_experiences.clear()
        try:
            self.agent.reset()
        except Exception as e:
            self.errors.info(e)
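# A minimal sketch of the config dict RLTtlStrategy reads in __init__, kept
# here for reference only. The top-level keys are the ones accessed above
# ('checkpoint_steps', 'max_ttl', optional 'experimental_reward', and
# 'agent_config'); the values shown are illustrative assumptions, not
# defaults, and the contents of 'agent_config' depend entirely on the spec
# format expected by the RL library's Agent.from_spec.
_EXAMPLE_RL_TTL_CONFIG = {
    'checkpoint_steps': 1000,      # how often to log observation/reward progress
    'max_ttl': 60 * 60,            # upper bound of the TTL action space, in seconds
    'experimental_reward': False,  # optional; defaults to False via config.get
    'agent_config': {},            # agent spec passed straight to Agent.from_spec
}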
class RLMultiTasksStrategy(BaseStrategy):
    """RL driven multi-task strategy: caching, eviction, and TTL estimation."""

    def __init__(self, config: Dict[str, any], result_dir: str,
                 cache_stats: CacheInformation):
        super().__init__(config, result_dir, cache_stats)
        self.supported_observations = {
            ObservationType.Hit, ObservationType.Miss,
            ObservationType.Invalidate
        }

        # evaluation specific variables
        self.observation_seen = 0
        self.cum_reward = 0
        self.checkpoint_steps = config['checkpoint_steps']

        self._incomplete_experiences = TTLCache(InMemoryStorage())
        self._incomplete_experiences.expired_entry_callback(
            self._observe_expiry_eviction)
        self.non_terminal_observations = {
            ObservationType.EvictionPolicy, ObservationType.Expiration
        }

        agent_config = config['agent_config']
        self.maximum_ttl = config['max_ttl']
        fields_in_state = len(MultiTaskAgentSystemState.__slots__)
        action_space = RLDict({
            'ttl': IntBox(low=0, high=self.maximum_ttl),
            'eviction': IntBox(low=0, high=2)
        })
        self.agent = Agent.from_spec(
            agent_config,
            state_space=FloatBox(shape=(fields_in_state, )),
            action_space=action_space)

        # TODO refactor into common RL interface for all strategies
        self.logger = logging.getLogger(__name__)
        name = 'rl_multi_strategy'
        self.reward_logger = create_file_logger(name=f'{name}_reward_logger',
                                                result_dir=self.result_dir)
        self.loss_logger = create_file_logger(name=f'{name}_loss_logger',
                                              result_dir=self.result_dir)
        self.ttl_logger = create_file_logger(name=f'{name}_ttl_logger',
                                             result_dir=self.result_dir)
        self.observation_logger = create_file_logger(
            name=f'{name}_observation_logger', result_dir=self.result_dir)
        self.performance_logger = create_file_logger(
            name=f'{name}_performance_logger', result_dir=self.result_dir)
        self.key_vocab = Vocabulary()

    def observe(self, key: str, observation_type: ObservationType,
                info: Dict[str, any]):
        observed_experience = self._incomplete_experiences.get(key)
        if observed_experience is None:
            return  # haven't had to make a decision on it

        current_time = time.time()
        stored_state = observed_experience.state
        stored_state.step_code = observation_type.value
        stored_state.cache_utility = self.cache_stats.cache_utility

        if observation_type == ObservationType.Hit:
            stored_state.hit_count += 1
        else:
            # Includes eviction, invalidation, and miss
            estimated_ttl = observed_experience.agent_action['ttl'].item()
            first_observation_time = observed_experience.observation_time
            real_ttl = current_time - first_observation_time
            # log the difference between the estimated ttl and real ttl
            self.ttl_logger.info(
                f'{self.episode_num},{observation_type.name},{key},{estimated_ttl},{real_ttl},{stored_state.hit_count}'
            )
            self._incomplete_experiences.delete(key)

        self.reward_agent(observation_type, observed_experience)

        self.observation_seen += 1
        if self.observation_seen % self.checkpoint_steps == 0:
            self.logger.info(
                f'Observation seen so far: {self.observation_seen}, reward so far: {self.cum_reward}'
            )
        if observation_type not in self.non_terminal_observations:
            self.observation_logger.info(f'{key},{observation_type}')

    def trim_cache(self, cache: TTLCache):
        # trim_cache isn't called often, so the operation is OK to be expensive:
        # produce an action on the whole cache
        keys_to_evict = []
        for (key, stored_experience) in list(self._incomplete_experiences.items()):
            action = self.agent.get_action(
                stored_experience.state.to_numpy())['eviction']
            evict = (action.flatten() == 1).item()
            if evict:
                cache.delete(key)
                keys_to_evict.append(key)
            # update stored value for the eviction action
            stored_experience.agent_action['eviction'] = action
            stored_experience.manual_eviction = True
        if len(keys_to_evict) == 0:
            self.logger.error('trim_cache: no keys were evicted.')
        return keys_to_evict

    def should_cache(self, key: str, values: Dict[str, str], ttl: int,
                     operation_type: OperationType) -> bool:
        # cache objects whose TTL is longer than 10 seconds (maybe make this configurable?)
        return ttl > 10

    def estimate_ttl(self, key: str, values: Dict[str, any],
                     operation_type: OperationType) -> float:
        # TODO check if it is in the observed queue
        observation_time = time.time()
        encoded_key = self.key_vocab.add_or_get_id(key)
        cache_utility = self.cache_stats.cache_utility
        state = MultiTaskAgentSystemState(encoded_key=encoded_key,
                                          hit_count=0,
                                          ttl=0,
                                          step_code=0,
                                          cache_utility=cache_utility,
                                          operation_type=operation_type.value)
        state_as_numpy = state.to_numpy()
        agent_action = self.agent.get_action(state_as_numpy)
        action = agent_action['ttl'].item()
        incomplete_experience = MultiTaskAgentObservedExperience(
            state=state,
            agent_action=agent_action,
            starting_state=state.copy(),
            observation_time=observation_time,
            manual_eviction=False)
        self._incomplete_experiences.set(key, incomplete_experience,
                                         self.maximum_ttl)
        return action

    def reward_agent(self, observation_type: ObservationType,
                     experience: MultiTaskAgentObservedExperience) -> int:
        # reward more utilisation of the cache capacity given more hits
        final_state = experience.state
        # difference_in_ttl = -abs((experience.agent_action.item() + 1) / min(real_ttl, self.maximum_ttl))
        reward = 0
        terminal = False
        if observation_type == ObservationType.Invalidate and (
                experience.agent_action['ttl'] < 10 or
                (experience.agent_action['eviction'].flatten() == 1).item()):
            # evicted or not cached, followed by an invalidate
            reward = 10
            terminal = True
        elif observation_type == ObservationType.Invalidate or observation_type == ObservationType.Miss:
            reward = -10
            terminal = True
        elif observation_type == ObservationType.Hit:
            terminal = False
            reward = 1

        if experience.manual_eviction:
            if observation_type == ObservationType.Expiration:
                reward = -10
                terminal = True
            if observation_type == ObservationType.Hit:
                reward = 2
            self.performance_metric_for_eviction(experience, observation_type)

        self.agent.observe(
            preprocessed_states=experience.starting_state.to_numpy(),
            actions=experience.agent_action,
            internals=[],
            rewards=reward,
            next_states=final_state.to_numpy(),
            terminals=terminal)
        self.cum_reward += reward
        self.reward_logger.info(f'{self.episode_num},{reward}')
        # TODO use self.agent.update_schedule to decide when to call update
        loss = self.agent.update()
        if loss is not None:
            self.loss_logger.info(f'{self.episode_num},{loss[0]}')
        return reward

    def _observe_expiry_eviction(self, key: str,
                                 observation_type: ObservationType,
                                 info: Dict[str, any]):
        """Observe decisions that haven't been observed by the main cache.

        e.g. don't cache -> ttl up -> no miss
        """
        self.observation_logger.info(
            f'{self.episode_num},{key},{observation_type}')
        experience = info['value']  # type: MultiTaskAgentObservedExperience
        self.ttl_logger.info(
            f'{self.episode_num},{observation_type.name},{key},{experience.agent_action["ttl"].item()},'
            f'{experience.agent_action["ttl"].item()},{experience.state.hit_count}'
        )
        experience.state.step_code = observation_type.value

        self.reward_agent(observation_type, experience)

    def performance_metric_for_eviction(
            self, stored_experience: MultiTaskAgentObservedExperience,
            observation_type: ObservationType) -> int:
        should_evict = (
            stored_experience.agent_action['eviction'].flatten() == 1).item()
        if observation_type == ObservationType.Expiration:
            if should_evict:
                # reward: decided to evict and didn't observe any follow-up miss
                self.performance_logger.info(f'{self.episode_num},TrueEvict')
            else:
                # reward for not evicting a key that received more hits,
                # or 0 if it didn't evict but also didn't get any hits
                gain_for_not_evicting = (stored_experience.state.hit_count -
                                         stored_experience.starting_state.hit_count)
                if gain_for_not_evicting > 0:
                    self.performance_logger.info(
                        f'{self.episode_num},TrueMiss')
                else:
                    self.performance_logger.info(
                        f'{self.episode_num},MissEvict')
                return gain_for_not_evicting
        elif observation_type == ObservationType.Invalidate:
            # Set/Delete removes the entry from the cache.
            # reward an eviction followed by an invalidation.
            if should_evict:
                self.performance_logger.info(f'{self.episode_num},TrueEvict')
            else:
                # punish not evicting a key that got invalidated afterwards.
                self.performance_logger.info(f'{self.episode_num},MissEvict')
        elif observation_type == ObservationType.Miss:
            if should_evict:
                # punish: a miss after making an eviction decision
                self.performance_logger.info(f'{self.episode_num},FalseEvict')

    def close(self):
        for (k, v) in list(self._incomplete_experiences.items()):
            self.ttl_logger.info(
                f'{self.episode_num},{ObservationType.EndOfEpisode.name},{k},{v.agent_action["ttl"].item()},'
                f'{v.agent_action["ttl"].item()},{v.state.hit_count}')
            self.performance_logger.info(f'{self.episode_num},TrueMiss')
        super().close()
        self._incomplete_experiences.clear()
        try:
            self.agent.reset()
        except Exception:
            pass
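# For reference: the multi-task agent returns a composite action keyed like
# the RLDict action space above. A minimal illustration of the shape this
# code assumes (the concrete values and array shapes below are made up):
#
#     agent_action = {
#         'ttl': np.array([120]),       # agent_action['ttl'].item() -> 120
#         'eviction': np.array([[1]]),  # (agent_action['eviction'].flatten() == 1).item() -> True
#     }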
class RLCachingStrategy(CachingStrategy):
    def __init__(self, config: Dict[str, any], result_dir: str,
                 cache_stats: CacheInformation):
        super().__init__(config, result_dir, cache_stats)

        # evaluation specific variables
        self.observation_seen = 0
        self.episode_reward = 0
        self.checkpoint_steps = config['checkpoint_steps']

        self._incomplete_experiences = TTLCache(InMemoryStorage())
        self._incomplete_experiences.expired_entry_callback(
            self._observe_expired_incomplete_experience)
        self.experimental_reward = config.get('experimental_reward', False)

        agent_config = config['agent_config']
        self.converter = CachingStrategyRLConverter()
        # action space: should cache: true or false
        # state space: [capacity (1), query key (1), query result set (num_indexes)]
        fields_in_state = len(CachingAgentSystemState.__slots__)
        self.agent = Agent.from_spec(
            agent_config,
            state_space=FloatBox(shape=(fields_in_state,)),
            action_space=IntBox(2))

        self.logger = logging.getLogger(__name__)
        name = 'rl_caching_strategy'
        self.reward_logger = create_file_logger(name=f'{name}_reward_logger',
                                                result_dir=self.result_dir)
        self.loss_logger = create_file_logger(name=f'{name}_loss_logger',
                                              result_dir=self.result_dir)
        self.observation_logger = create_file_logger(
            name=f'{name}_observation_logger', result_dir=self.result_dir)
        self.entry_hits_logger = create_file_logger(
            name=f'{name}_entry_hits_logger', result_dir=self.result_dir)
        self.key_vocab = Vocabulary()

    def should_cache(self, key: str, values: Dict[str, str], ttl: int,
                     operation_type: OperationType) -> bool:
        # TODO what about the case of a cache key that already exists in the incomplete experiences?
        assert self._incomplete_experiences.get(key) is None, \
            "should_cache is assumed to be the first call and the key shouldn't be in the cache"

        observation_time = time.time()
        encoded_key = self.key_vocab.add_or_get_id(key)
        state = CachingAgentSystemState(encoded_key=encoded_key,
                                        ttl=ttl,
                                        hit_count=0,
                                        step_code=0,
                                        operation_type=operation_type.value)

        agent_action = self.agent.get_action(state.to_numpy())
        incomplete_experience_entry = CachingAgentIncompleteExperienceEntry(
            state=state,
            agent_action=agent_action,
            starting_state=state.copy(),
            observation_time=observation_time)

        action = self.converter.agent_to_system_action(agent_action)
        self._incomplete_experiences.set(key, incomplete_experience_entry, ttl)
        return action

    def observe(self, key: str, observation_type: ObservationType,
                info: Dict[str, any]):
        # TODO include stats/capacity information in the info dict
        experience = self._incomplete_experiences.get(
            key)  # type: CachingAgentIncompleteExperienceEntry
        if experience is None:
            return  # if I haven't had to make a decision on this, ignore it.

        self.observation_logger.info(
            f'{self.episode_num},{key},{observation_type.name}')
        if observation_type == ObservationType.Hit:
            experience.state.hit_count += 1
        else:
            self._reward_experience(key, experience, observation_type)

        self.observation_seen += 1
        if self.observation_seen % self.checkpoint_steps == 0:
            self.logger.info(
                f'Observation seen so far: {self.observation_seen}, reward so far: {self.episode_reward}'
            )

    def _observe_expired_incomplete_experience(self, key: str,
                                               observation_type: ObservationType,
                                               info: Dict[str, any]):
        """Observe decisions that haven't been observed by the main cache.

        e.g. don't cache -> ttl up -> no miss
        """
        assert observation_type == ObservationType.Expiration
        experience = info['value']
        self._reward_experience(key, experience, observation_type)

    def _reward_experience(self, key: str,
                           experience: CachingAgentIncompleteExperienceEntry,
                           observation_type: ObservationType):
        state = experience.state
        state.step_code = observation_type.value
        self._incomplete_experiences.delete(key)
        self.entry_hits_logger.info(
            f'{self.episode_num},{key},{experience.state.hit_count}')

        reward = self.converter.system_to_agent_reward(experience)
        if self.experimental_reward:
            # TODO add cache utility to state and reward
            pass

        self.agent.observe(
            preprocessed_states=experience.starting_state.to_numpy(),
            actions=experience.agent_action,
            internals=[],
            rewards=reward,
            next_states=experience.state.to_numpy(),
            terminals=False)
        self.episode_reward += reward
        self.reward_logger.info(f'{self.episode_num},{reward}')
        self.logger.debug(
            f'Key: {key} is in terminal state because: {str(observation_type)}')
        loss = self.agent.update()
        if loss is not None:
            self.loss_logger.info(f'{self.episode_num},{loss[0]}')

    def close(self):
        super().close()
        self.agent.reset()
        self._incomplete_experiences.clear()
class LRUEvictionStrategy(EvictionStrategy):
    def __init__(self, config: Dict[str, any], result_dir: str,
                 cache_stats: CacheInformation):
        super().__init__(config, result_dir, cache_stats)
        self.lru = OrderedDict()
        self.logger = logging.getLogger(__name__)
        name = 'lru_eviction_strategy'
        self.performance_logger = create_file_logger(
            name=f'{name}_performance_logger', result_dir=result_dir)
        self._incomplete_experiences = TTLCache(InMemoryStorage())
        self._incomplete_experiences.expired_entry_callback(
            self._observe_expired_incomplete_experience)

    def observe(self, key: str, observation_type: ObservationType,
                info: Dict[str, any]):
        try:
            stored_values = self.lru.pop(key)
        except KeyError:
            self.logger.debug(
                f"Key: {key} not in LRU monitor. Current LRU size: {len(self.lru)}"
            )
            stored_values = None  # item not observed in cache before.

        # add/refresh lru if hit
        if observation_type == ObservationType.Write:
            ttl = info['ttl']
            observation_time = time.time()
            self.lru[key] = {'ttl': ttl, 'observation_time': observation_time}
        elif observation_type == ObservationType.Hit:
            if stored_values is not None:
                observation_time = time.time()
                self.lru[key] = {
                    'ttl': stored_values['ttl'],
                    'observation_time': observation_time
                }
        elif observation_type in {
                ObservationType.Expiration, ObservationType.Invalidate
        }:
            self.logger.debug(f"Key {key} expired")
            assert key not in self.lru, "Expired key should have been deleted."
            action_taken = self._incomplete_experiences.get(key)
            if action_taken is not None:
                if observation_type == ObservationType.Invalidate:
                    # eviction followed by invalidation.
                    self.performance_logger.info(f'{self.episode_num},TrueEvict')
                elif observation_type == ObservationType.Miss:
                    # Miss after making an eviction decision
                    self.performance_logger.info(f'{self.episode_num},FalseEvict')
                self._incomplete_experiences.delete(key)

    def _observe_expired_incomplete_experience(self, key: str,
                                               observation_type: ObservationType,
                                               info: Dict[str, any]):
        self.performance_logger.info(f'{self.episode_num},TrueEvict')

    def trim_cache(self, cache: TTLCache) -> List[str]:
        while True:
            eviction_item = self.lru.popitem(last=False)
            eviction_key = eviction_item[0]
            eviction_value = eviction_item[1]
            if cache.contains(eviction_key):
                # TTLCache might expire and cause a race condition
                decision_time = time.time()
                ttl_left = (eviction_value['observation_time'] +
                            eviction_value['ttl']) - decision_time
                self._incomplete_experiences.set(eviction_key, 'evict', ttl_left)
                cache.delete(eviction_key)
                return [eviction_key]
class RLEvictionStrategy(EvictionStrategy):
    def __init__(self, config: Dict[str, any], result_dir: str,
                 cache_stats: CacheInformation):
        super().__init__(config, result_dir, cache_stats)

        # evaluation specific variables
        self.observation_seen = 0
        self.episode_reward = 0
        self.checkpoint_steps = config['checkpoint_steps']

        self._incomplete_experiences = TTLCache(InMemoryStorage())
        self._incomplete_experiences.expired_entry_callback(
            self._observe_expired_incomplete_experience)
        self.view_of_the_cache = {}  # type: Dict[str, Dict[str, any]]
        self._end_episode_observation = {
            ObservationType.Invalidate, ObservationType.Miss,
            ObservationType.Expiration
        }

        # TODO refactor into common RL interface for all strategies
        # Agent configuration (can be shared with others)
        agent_config = config['agent_config']
        fields_in_state = len(EvictionAgentSystemState.__slots__)
        self.converter = EvictionStrategyRLConverter(self.result_dir)
        # State: fields of the key in question to observe
        # Action: whether to evict that key or not
        self.agent = Agent.from_spec(
            agent_config,
            state_space=FloatBox(shape=(fields_in_state, )),
            action_space=IntBox(low=0, high=2))

        self.logger = logging.getLogger(__name__)
        name = 'rl_eviction_strategy'
        self.reward_logger = create_file_logger(name=f'{name}_reward_logger',
                                                result_dir=self.result_dir)
        self.loss_logger = create_file_logger(name=f'{name}_loss_logger',
                                              result_dir=self.result_dir)
        self.observation_logger = create_file_logger(
            name=f'{name}_observation_logger', result_dir=self.result_dir)
        self.key_vocab = Vocabulary()

    def trim_cache(self, cache: TTLCache) -> List[str]:
        # trim_cache isn't called often, so the operation is OK to be expensive:
        # produce an action on the whole cache
        keys_to_evict = []
        keys_to_not_evict = []
        for (key, cached_key) in list(self.view_of_the_cache.items()):
            agent_system_state = cached_key['state']
            agent_action = self.agent.get_action(agent_system_state.to_numpy())
            should_evict = self.converter.agent_to_system_action(agent_action)

            decision_time = time.time()
            incomplete_experience = EvictionAgentIncompleteExperienceEntry(
                agent_system_state, agent_action, agent_system_state.copy(),
                decision_time)
            # observe the key for only the ttl period that is left for this key
            ttl_left = (cached_key['observation_time'] +
                        agent_system_state.ttl) - decision_time
            self._incomplete_experiences.set(key=key,
                                             values=incomplete_experience,
                                             ttl=ttl_left)
            if should_evict:
                del self.view_of_the_cache[key]
                keys_to_evict.append(key)
                if not cache.contains(key, clean_expire=False):
                    # race condition, clean up and move on
                    self._incomplete_experiences.delete(key)
                cache.delete(key)
            else:
                keys_to_not_evict.append(key)
        return keys_to_evict

    def observe(self, key: str, observation_type: ObservationType,
                info: Dict[str, any]):
        self.observation_logger.info(
            f'{self.episode_num},{key},{observation_type}')
        observed_key = self.converter.vocabulary.add_or_get_id(key)

        stored_experience = self._incomplete_experiences.get(key)
        if observation_type == ObservationType.Write:
            if stored_experience is not None:
                # race condition
                reward = self.converter.system_to_agent_reward(
                    stored_experience, ObservationType.Miss, self.episode_num)
                state = stored_experience.state
                action = stored_experience.agent_action
                new_state = state.copy()
                new_state.step_code = ObservationType.Miss.value
                self._reward_agent(state.to_numpy(), new_state.to_numpy(),
                                   action, reward)
                self._incomplete_experiences.delete(key)

            # New item to write into the cache view and observe.
            ttl = info['ttl']
            observation_time = time.time()
            self.view_of_the_cache[key] = {
                'state': EvictionAgentSystemState(
                    encoded_key=observed_key,
                    ttl=ttl,
                    hit_count=0,
                    step_code=observation_type.value),
                'observation_time': observation_time
            }
        elif observation_type == ObservationType.Hit:
            # Cache hit, update the hit record of this key in the cache view
            stored_view = self.view_of_the_cache[key]['state']
            stored_view.hit_count += 1
        elif observation_type in self._end_episode_observation:
            if stored_experience:
                reward = self.converter.system_to_agent_reward(
                    stored_experience, observation_type, self.episode_num)
                state = stored_experience.state
                action = stored_experience.agent_action
                new_state = state.copy()
                new_state.step_code = observation_type.value
                self._reward_agent(state.to_numpy(), new_state.to_numpy(),
                                   action, reward)
                self._incomplete_experiences.delete(key)
            if key in self.view_of_the_cache:
                del self.view_of_the_cache[key]

        self.observation_seen += 1
        if self.observation_seen % self.checkpoint_steps == 0:
            self.logger.info(
                f'Observation seen so far: {self.observation_seen}, reward so far: {self.episode_reward}'
            )

    def _observe_expired_incomplete_experience(self, key: str,
                                               observation_type: ObservationType,
                                               info: Dict[str, any]):
        """Observe decisions that haven't been observed by the main cache.

        e.g. don't cache -> ttl up -> no miss
        """
        assert observation_type == ObservationType.Expiration
        self.observation_logger.info(
            f'{self.episode_num},{key},{observation_type}')

        experience = info['value']  # type: EvictionAgentIncompleteExperienceEntry
        reward = self.converter.system_to_agent_reward(experience,
                                                       observation_type,
                                                       self.episode_num)
        starting_state = experience.starting_state
        action = experience.agent_action
        new_state = experience.state.copy()
        new_state.step_code = observation_type.value
        self._reward_agent(starting_state.to_numpy(), new_state.to_numpy(),
                           action, reward)

    def _reward_agent(self, state: np.ndarray, new_state: np.ndarray,
                      agent_action: np.ndarray, reward: int):
        self.agent.observe(preprocessed_states=state,
                           actions=agent_action,
                           internals=[],
                           rewards=reward,
                           next_states=new_state,
                           terminals=False)
        self.reward_logger.info(f'{self.episode_num},{reward}')
        loss = self.agent.update()
        if loss is not None:
            self.loss_logger.info(f'{self.episode_num},{loss[0]}')

    def close(self):
        super().close()
        self._incomplete_experiences.clear()
        self.agent.reset()
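# A standalone sketch (not used by the classes above) restating how the
# eviction strategies label decision outcomes in their performance logs.
# The label strings mirror the ones logged above; the mapping is simplified,
# e.g. RLMultiTasksStrategy additionally requires a hit-count gain before
# logging 'TrueMiss' on expiration. The helper name and the 'evicted' flag
# are hypothetical and exist only for illustration.
def _classify_eviction_outcome(evicted, follow_up):
    """Return the performance label for an eviction decision given the
    follow-up observation, or None if nothing is logged for that case."""
    if follow_up == ObservationType.Invalidate:
        # the key was invalidated anyway: evicting it early was the right call
        return 'TrueEvict' if evicted else 'MissEvict'
    if follow_up == ObservationType.Miss:
        # a read arrived after an eviction decision: the eviction was premature
        return 'FalseEvict' if evicted else None
    if follow_up == ObservationType.Expiration:
        # the TTL ran out without another read for the key
        return 'TrueEvict' if evicted else 'TrueMiss'
    return None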