Example #1
class FIFOEvictionStrategy(EvictionStrategy):
    def __init__(self, config: Dict[str, any], result_dir: str,
                 cache_stats: CacheInformation):
        super().__init__(config, result_dir, cache_stats)
        self.fifo = OrderedDict()
        self.logger = logging.getLogger(__name__)
        self.renewable_ops = {ObservationType.Hit, ObservationType.Write}
        name = 'fifo_eviction_strategy'
        self.performance_logger = create_file_logger(
            name=f'{name}_performance_logger', result_dir=result_dir)

        self._incomplete_experiences = TTLCache(InMemoryStorage())
        self._incomplete_experiences.expired_entry_callback(
            self._observe_expired_incomplete_experience)

    def observe(self, key: str, observation_type: ObservationType,
                info: Dict[str, any]):
        if observation_type == ObservationType.Write:
            ttl = info['ttl']
            observation_time = time.time()
            self.fifo[key] = {'ttl': ttl, 'observation_time': observation_time}
        elif observation_type in {
                ObservationType.Expiration, ObservationType.Invalidate
        }:
            self.logger.debug(f"Key {key} expired")
            if key in self.fifo:
                del self.fifo[key]

        action_taken = self._incomplete_experiences.get(key)
        if action_taken is not None:
            if observation_type == ObservationType.Invalidate:
                # eviction followed by invalidation.
                self.performance_logger.info(f'{self.episode_num},TrueEvict')
            elif observation_type == ObservationType.Miss:
                self.performance_logger.info(f'{self.episode_num},FalseEvict')
                # Miss after making an eviction decision
            self._incomplete_experiences.delete(key)

    def _observe_expired_incomplete_experience(
            self, key: str, observation_type: ObservationType,
            info: Dict[str, any]):
        self.performance_logger.info(f'{self.episode_num},TrueEvict')

    def trim_cache(self, cache: TTLCache) -> List[str]:
        while True:
            # Loop because TTLCache entries might expire concurrently and cause
            # a race condition; only act on keys that are still in the cache.
            eviction_item = self.fifo.popitem(last=False)
            eviction_key = eviction_item[0]
            eviction_value = eviction_item[1]
            if cache.contains(eviction_key):
                decision_time = time.time()
                ttl_left = (eviction_value['observation_time'] +
                            eviction_value['ttl']) - decision_time
                self._incomplete_experiences.set(eviction_key, 'evict',
                                                 ttl_left)
                cache.delete(eviction_key)
                return [eviction_key]
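
A note on the bookkeeping above: once a key is chosen for eviction, the decision is stored in _incomplete_experiences for exactly the TTL the entry had left, so a later Invalidate confirms the eviction (TrueEvict), a Miss contradicts it (FalseEvict), and a silent expiry of the pending record also counts as TrueEvict. A standalone illustration of the watch-window arithmetic (all times and TTLs here are made-up values, not from the original code):

import time

ttl = 60.0                              # TTL the entry was written with
observation_time = time.time() - 40.0   # the write happened 40 seconds ago
decision_time = time.time()             # trim_cache decides to evict now

ttl_left = (observation_time + ttl) - decision_time   # ~20 seconds remain
# the pending 'evict' record is kept for ttl_left seconds; if it expires without
# a Miss arriving, the expiry callback logs the eviction as TrueEvict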
Example #2
class RLTtlStrategy(TtlStrategy):
    """RL driven TTL estimation strategy."""
    def __init__(self, config: Dict[str, any], result_dir: str,
                 cache_stats: CacheInformation):
        super().__init__(config, result_dir, cache_stats)
        self.observation_seen = 0
        self.cum_reward = 0
        self.checkpoint_steps = config['checkpoint_steps']

        self._incomplete_experiences = TTLCache(InMemoryStorage())
        self._incomplete_experiences.expired_entry_callback(
            self._observe_expiry_eviction)
        self.non_terminal_observations = {
            ObservationType.EvictionPolicy, ObservationType.Expiration
        }
        agent_config = config['agent_config']
        self.maximum_ttl = config['max_ttl']
        self.experimental_reward = config.get('experimental_reward', False)
        fields_in_state = len(TTLAgentSystemState.__slots__)
        self.agent = Agent.from_spec(
            agent_config,
            state_space=FloatBox(shape=(fields_in_state, )),
            action_space=FloatBox(low=0, high=self.maximum_ttl, shape=(1, )))

        # TODO refactor into common RL interface for all strategies
        self.logger = logging.getLogger(__name__)
        name = 'rl_ttl_strategy'
        self.reward_logger = create_file_logger(name=f'{name}_reward_logger',
                                                result_dir=self.result_dir)
        self.loss_logger = create_file_logger(name=f'{name}_loss_logger',
                                              result_dir=self.result_dir)
        self.ttl_logger = create_file_logger(name=f'{name}_ttl_logger',
                                             result_dir=self.result_dir)
        self.observation_logger = create_file_logger(
            name=f'{name}_observation_logger', result_dir=self.result_dir)
        self.key_vocab = Vocabulary()
        self.errors = create_file_logger(name=f'{name}_error_logger',
                                         result_dir=self.result_dir)

    def estimate_ttl(self, key: str, values: Dict[str, any],
                     operation_type: OperationType) -> float:
        observation_time = time.time()
        encoded_key = self.key_vocab.add_or_get_id(key)
        cache_utility = self.cache_stats.cache_utility

        state = TTLAgentSystemState(encoded_key=encoded_key,
                                    hit_count=0,
                                    step_code=0,
                                    cache_utility=cache_utility,
                                    operation_type=operation_type.value)

        state_as_numpy = state.to_numpy()
        agent_action = self.agent.get_action(state_as_numpy)
        action = agent_action.item()

        incomplete_experience = TTLAgentObservedExperience(
            state=state,
            agent_action=agent_action,
            starting_state=state.copy(),
            observation_time=observation_time)
        self._incomplete_experiences.set(key, incomplete_experience,
                                         self.maximum_ttl)

        return action

    def observe(self, key: str, observation_type: ObservationType,
                info: Dict[str, any]):
        observed_experience = self._incomplete_experiences.get(key)

        if observed_experience is None or observation_type == ObservationType.Expiration:
            return  # haven't had to make a decision on it

        current_time = time.time()
        stored_state = observed_experience.state
        if observation_type == ObservationType.Hit:
            stored_state.hit_count += 1
        # elif observation_type in self.non_terminal_observations:
        #     # it was evicted by another policy don't attempt to learn stuff from this
        #     pass

        estimated_ttl = observed_experience.agent_action.item()
        first_observation_time = observed_experience.observation_time
        real_ttl = current_time - first_observation_time
        stored_state.step_code = observation_type.value
        stored_state.cache_utility = self.cache_stats.cache_utility
        self.reward_agent(observation_type, observed_experience, real_ttl)

        if observation_type != ObservationType.Hit:
            self.ttl_logger.info(
                f'{self.episode_num},{observation_type.name},{key},{estimated_ttl},{real_ttl},{stored_state.hit_count}'
            )
            self._incomplete_experiences.delete(key)

        self.observation_seen += 1
        if self.observation_seen % self.checkpoint_steps == 0:
            self.logger.info(
                f'Observation seen so far: {self.observation_seen}, reward so far: {self.cum_reward}'
            )
        if observation_type not in self.non_terminal_observations:
            self.observation_logger.info(f'{key},{observation_type}')

    def _observe_expiry_eviction(self, key: str,
                                 observation_type: ObservationType,
                                 info: Dict[str, any]):
        """Observe decisions taken that hasn't been observed by main cache. e.g. don't cache -> ttl up -> no miss"""
        self.observation_logger.info(
            f'{self.episode_num},{key},{observation_type}')
        experience = info['value']  # type: TTLAgentObservedExperience
        self.ttl_logger.info(
            f'{self.episode_num},{observation_type.name},{key},{experience.agent_action.item()},'
            f'{experience.agent_action.item()},{experience.state.hit_count}')
        experience.state.step_code = observation_type.value

        self.reward_agent(observation_type, experience,
                          experience.agent_action.item())

    def reward_agent(self, observation_type: ObservationType,
                     experience: TTLAgentObservedExperience,
                     real_ttl: float) -> float:
        # reward more utilisation of the cache capacity given more hits
        final_state = experience.state

        difference_in_ttl = -abs((experience.agent_action.item() + 1) /
                                 max(min(real_ttl, self.maximum_ttl), 1))
        # reward = final_state.hit_count - abs(difference_in_ttl * self.cache_stats.cache_utility)

        if observation_type == ObservationType.Hit:
            reward = 1
            terminal = False
        # elif observation_type == ObservationType.EvictionPolicy:
        #     reward = 0
        #     terminal = True
        else:
            reward = difference_in_ttl
            if abs(difference_in_ttl) < 10:
                reward = 10
            terminal = True
            self.logger.debug(
                f'Hits: {final_state.hit_count}, ttl diff: {difference_in_ttl}, Reward: {reward}'
            )

        self.agent.observe(
            preprocessed_states=experience.starting_state.to_numpy(),
            actions=experience.agent_action,
            internals=[],
            rewards=reward,
            next_states=final_state.to_numpy(),
            terminals=terminal)

        self.cum_reward += reward
        self.reward_logger.info(f'{self.episode_num},{reward}')
        loss = self.agent.update()
        if loss is not None:
            self.loss_logger.info(f'{self.episode_num},{loss[0]}')

        return reward

    def close(self):
        super().close()
        for (k, v) in list(self._incomplete_experiences.items()):
            self.ttl_logger.info(
                f'{self.episode_num},{ObservationType.EndOfEpisode.name},{k},{v.agent_action.item()},'
                f'{v.agent_action.item()},{v.state.hit_count}')

        self._incomplete_experiences.clear()
        try:
            self.agent.reset()
        except Exception as e:
            self.errors.info(e)
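
The reward shaping inside reward_agent can be read in isolation: a hit is worth 1; otherwise the agent is penalised by the ratio of its TTL estimate to the observed lifetime, with near-accurate estimates bumped to a flat bonus of 10. A minimal standalone sketch of that shaping (the function name and sample values are illustrative, not part of the original code):

def ttl_reward(estimated_ttl: float, real_ttl: float, max_ttl: float, hit: bool) -> float:
    if hit:
        return 1
    # same shaping as reward_agent above: negative ratio of estimate to real lifetime
    difference_in_ttl = -abs((estimated_ttl + 1) / max(min(real_ttl, max_ttl), 1))
    # estimates within a factor of ~10 of the real TTL earn a flat reward of 10
    return 10 if abs(difference_in_ttl) < 10 else difference_in_ttl


print(ttl_reward(30.0, 28.5, 600.0, hit=False))   # ratio ~1.09 -> bumped to 10
print(ttl_reward(500.0, 5.0, 600.0, hit=False))   # heavy overestimate -> ~-100.2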
Example #3
class RLMultiTasksStrategy(BaseStrategy):
    """RL driven multi task strategy - Caching, eviction, and ttl estimation."""
    def __init__(self, config: Dict[str, any], result_dir: str,
                 cache_stats: CacheInformation):
        super().__init__(config, result_dir, cache_stats)
        self.supported_observations = {
            ObservationType.Hit, ObservationType.Miss,
            ObservationType.Invalidate
        }

        # evaluation specific variables
        self.observation_seen = 0
        self.cum_reward = 0
        self.checkpoint_steps = config['checkpoint_steps']

        self._incomplete_experiences = TTLCache(InMemoryStorage())
        self._incomplete_experiences.expired_entry_callback(
            self._observe_expiry_eviction)
        self.non_terminal_observations = {
            ObservationType.EvictionPolicy, ObservationType.Expiration
        }

        agent_config = config['agent_config']
        self.maximum_ttl = config['max_ttl']

        fields_in_state = len(MultiTaskAgentSystemState.__slots__)

        action_space = RLDict({
            'ttl': IntBox(low=0, high=self.maximum_ttl),
            'eviction': IntBox(low=0, high=2)
        })

        self.agent = Agent.from_spec(
            agent_config,
            state_space=FloatBox(shape=(fields_in_state, )),
            action_space=action_space)

        # TODO refactor into common RL interface for all strategies
        self.logger = logging.getLogger(__name__)
        name = 'rl_multi_strategy'
        self.reward_logger = create_file_logger(name=f'{name}_reward_logger',
                                                result_dir=self.result_dir)
        self.loss_logger = create_file_logger(name=f'{name}_loss_logger',
                                              result_dir=self.result_dir)
        self.ttl_logger = create_file_logger(name=f'{name}_ttl_logger',
                                             result_dir=self.result_dir)
        self.observation_logger = create_file_logger(
            name=f'{name}_observation_logger', result_dir=self.result_dir)
        self.performance_logger = create_file_logger(
            name=f'{name}_performance_logger', result_dir=self.result_dir)
        self.key_vocab = Vocabulary()

    def observe(self, key: str, observation_type: ObservationType,
                info: Dict[str, any]):
        observed_experience = self._incomplete_experiences.get(key)

        if observed_experience is None:
            return  # haven't had to make a decision on it

        current_time = time.time()
        stored_state = observed_experience.state

        stored_state.step_code = observation_type.value
        stored_state.cache_utility = self.cache_stats.cache_utility

        if observation_type == ObservationType.Hit:
            stored_state.hit_count += 1
        else:
            # Include eviction, invalidation, and miss
            estimated_ttl = observed_experience.agent_action['ttl'].item()
            first_observation_time = observed_experience.observation_time
            real_ttl = current_time - first_observation_time
            # log the difference between the estimated ttl and real ttl
            self.ttl_logger.info(
                f'{self.episode_num},{observation_type.name},{key},{estimated_ttl},{real_ttl},{stored_state.hit_count}'
            )
            self._incomplete_experiences.delete(key)

        self.reward_agent(observation_type, observed_experience)

        self.observation_seen += 1
        if self.observation_seen % self.checkpoint_steps == 0:
            self.logger.info(
                f'Observation seen so far: {self.observation_seen}, reward so far: {self.cum_reward}'
            )
        if observation_type not in self.non_terminal_observations:
            self.observation_logger.info(f'{key},{observation_type}')

    def trim_cache(self, cache: TTLCache) -> List[str]:
        # trim_cache isn't called often, so it is OK for this operation to be
        # expensive: produce an action for every key currently tracked.
        keys_to_evict = []

        for (key,
             stored_experience) in list(self._incomplete_experiences.items()):
            action = self.agent.get_action(
                stored_experience.state.to_numpy())['eviction']
            evict = (action.flatten() == 1).item()
            if evict:
                cache.delete(key)
                keys_to_evict.append(key)
            # update stored value for eviction action
            stored_experience.agent_action['eviction'] = action
            stored_experience.manual_eviction = True

        if len(keys_to_evict) == 0:
            self.logger.error('trim_cache: no keys were evicted.')

        return keys_to_evict

    def should_cache(self, key: str, values: Dict[str, str], ttl: int,
                     operation_type: OperationType) -> bool:
        # only cache objects whose TTL is longer than 10 seconds (maybe make this configurable?)
        return ttl > 10

    def estimate_ttl(self, key: str, values: Dict[str, any],
                     operation_type: OperationType) -> float:
        # TODO check if it is in the observed queue

        observation_time = time.time()
        encoded_key = self.key_vocab.add_or_get_id(key)
        cache_utility = self.cache_stats.cache_utility

        state = MultiTaskAgentSystemState(encoded_key=encoded_key,
                                          hit_count=0,
                                          ttl=0,
                                          step_code=0,
                                          cache_utility=cache_utility,
                                          operation_type=operation_type.value)

        state_as_numpy = state.to_numpy()
        agent_action = self.agent.get_action(state_as_numpy)
        action = agent_action['ttl'].item()

        incomplete_experience = MultiTaskAgentObservedExperience(
            state=state,
            agent_action=agent_action,
            starting_state=state.copy(),
            observation_time=observation_time,
            manual_eviction=False)
        self._incomplete_experiences.set(key, incomplete_experience,
                                         self.maximum_ttl)

        return action

    def reward_agent(self, observation_type: ObservationType,
                     experience: MultiTaskAgentObservedExperience) -> int:
        # reward more utilisation of the cache capacity given more hits
        final_state = experience.state

        # difference_in_ttl = -abs((experience.agent_action.item() + 1) / min(real_ttl, self.maximum_ttl))
        reward = 0
        terminal = False

        if observation_type == ObservationType.Invalidate and (
                experience.agent_action['ttl'] < 10 or
                (experience.agent_action['eviction'].flatten() == 1).item()):
            # evicted or not cached, then followed by an invalidation
            reward = 10
            terminal = True
        elif observation_type in {ObservationType.Invalidate, ObservationType.Miss}:
            reward = -10
            terminal = True
        elif observation_type == ObservationType.Hit:
            reward = 1
            terminal = False

        if experience.manual_eviction:
            if observation_type == ObservationType.Expiration:
                reward = -10
                terminal = True
            elif observation_type == ObservationType.Hit:
                reward = 2

            self.performance_metric_for_eviction(experience, observation_type)

        self.agent.observe(
            preprocessed_states=experience.starting_state.to_numpy(),
            actions=experience.agent_action,
            internals=[],
            rewards=reward,
            next_states=final_state.to_numpy(),
            terminals=terminal)

        self.cum_reward += reward
        self.reward_logger.info(f'{self.episode_num},{reward}')
        # TODO use self.agent.update_schedule to decide when to call update
        loss = self.agent.update()
        if loss is not None:
            self.loss_logger.info(f'{self.episode_num},{loss[0]}')

        return reward

    def _observe_expiry_eviction(self, key: str,
                                 observation_type: ObservationType,
                                 info: Dict[str, any]):
        """Observe decisions taken that hasn't been observed by main cache. e.g. don't cache -> ttl up -> no miss"""
        self.observation_logger.info(
            f'{self.episode_num},{key},{observation_type}')
        experience = info['value']  # type: MultiTaskAgentObservedExperience
        self.ttl_logger.info(
            f'{self.episode_num},{observation_type.name},{key},{experience.agent_action["ttl"].item()},'
            f'{experience.agent_action["ttl"].item()},{experience.state.hit_count}'
        )
        experience.state.step_code = observation_type.value

        self.reward_agent(observation_type, experience)

    def performance_metric_for_eviction(
            self, stored_experience: MultiTaskAgentObservedExperience,
            observation_type: ObservationType) -> int:
        should_evict = (
            stored_experience.agent_action['eviction'].flatten() == 1).item()

        if observation_type == ObservationType.Expiration:
            if should_evict:
                # reward if should evict didn't observe any follow up miss
                self.performance_logger.info(f'{self.episode_num},TrueEvict')
            # else didn't evict
            else:
                # reward for not evicting a key that received more hits.
                # or 0 if it didn't evict but also didn't get any hits
                gain_for_not_evicting = stored_experience.state.hit_count - stored_experience.starting_state.hit_count
                if gain_for_not_evicting > 0:
                    self.performance_logger.info(
                        f'{self.episode_num},TrueMiss')
                else:
                    self.performance_logger.info(
                        f'{self.episode_num},MissEvict')

                return gain_for_not_evicting

        elif observation_type == ObservationType.Invalidate:
            # Set/Delete, remove entry from the cache.
            # reward an eviction followed by invalidation.
            if should_evict:
                self.performance_logger.info(f'{self.episode_num},TrueEvict')
            else:
                # punish not evicting a key that got invalidated after.
                self.performance_logger.info(f'{self.episode_num},MissEvict')

        elif observation_type == ObservationType.Miss:
            if should_evict:
                self.performance_logger.info(f'{self.episode_num},FalseEvict')
            # Miss after making an eviction decision
            # Punish, a read after an eviction decision

    def close(self):
        for (k, v) in list(self._incomplete_experiences.items()):
            self.ttl_logger.info(
                f'{self.episode_num},{ObservationType.EndOfEpisode.name},{k},{v.agent_action["ttl"].item()},'
                f'{v.agent_action["ttl"].item()},{v.state.hit_count}')
            self.performance_logger.info(f'{self.episode_num},TrueMiss')
        super().close()
        self._incomplete_experiences.clear()
        try:
            self.agent.reset()
        except Exception:
            pass  # ignore agent reset failures at episode end
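
The composite action returned by this agent is a dict with a 'ttl' head and an 'eviction' head, and the code above reads it in two places (estimate_ttl and trim_cache). A small illustration of how such a dict action is interpreted (the numpy arrays stand in for the agent's output and carry made-up values):

import numpy as np

agent_action = {'ttl': np.array([42]), 'eviction': np.array([[1]])}

estimated_ttl = agent_action['ttl'].item()                        # 42 -> TTL in seconds
should_evict = (agent_action['eviction'].flatten() == 1).item()   # True -> delete the key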
Example #4
class RLCachingStrategy(CachingStrategy):

    def __init__(self, config: Dict[str, any], result_dir: str, cache_stats: CacheInformation):
        super().__init__(config, result_dir, cache_stats)
        # evaluation specific variables
        self.observation_seen = 0
        self.episode_reward = 0
        self.checkpoint_steps = config['checkpoint_steps']

        self._incomplete_experiences = TTLCache(InMemoryStorage())
        self._incomplete_experiences.expired_entry_callback(self._observe_expired_incomplete_experience)

        self.experimental_reward = config.get('experimental_reward', False)
        agent_config = config['agent_config']
        self.converter = CachingStrategyRLConverter()
        # action space: should cache: true or false
        # state space: [capacity (1), query key(1), query result set(num_indexes)]
        fields_in_state = len(CachingAgentSystemState.__slots__)
        self.agent = Agent.from_spec(agent_config,
                                     state_space=FloatBox(shape=(fields_in_state,)),
                                     action_space=IntBox(2))

        self.logger = logging.getLogger(__name__)
        name = 'rl_caching_strategy'
        self.reward_logger = create_file_logger(name=f'{name}_reward_logger', result_dir=self.result_dir)
        self.loss_logger = create_file_logger(name=f'{name}_loss_logger', result_dir=self.result_dir)
        self.observation_logger = create_file_logger(name=f'{name}_observation_logger', result_dir=self.result_dir)
        self.entry_hits_logger = create_file_logger(name=f'{name}_entry_hits_logger', result_dir=self.result_dir)

        self.key_vocab = Vocabulary()

    def should_cache(self, key: str, values: Dict[str, str], ttl: int, operation_type: OperationType) -> bool:
        # TODO what about the case of a cache key that already exists in the incomplete experiences?
        assert self._incomplete_experiences.get(key) is None, \
            "should_cache is assumed to be first call and key shouldn't be in the cache"
        observation_time = time.time()

        encoded_key = self.key_vocab.add_or_get_id(key)
        state = CachingAgentSystemState(encoded_key=encoded_key,
                                        ttl=ttl,
                                        hit_count=0,
                                        step_code=0,
                                        operation_type=operation_type.value)

        agent_action = self.agent.get_action(state.to_numpy())
        incomplete_experience_entry = CachingAgentIncompleteExperienceEntry(state=state,
                                                                            agent_action=agent_action,
                                                                            starting_state=state.copy(),
                                                                            observation_time=observation_time)

        action = self.converter.agent_to_system_action(agent_action)
        self._incomplete_experiences.set(key, incomplete_experience_entry, ttl)

        return action

    def observe(self, key: str, observation_type: ObservationType, info: Dict[str, any]):
        # TODO include stats/capacity information in the info dict
        experience = self._incomplete_experiences.get(key)  # type: CachingAgentIncompleteExperienceEntry
        if experience is None:
            return  # if I haven't had to make a decision on this, ignore it.

        self.observation_logger.info(f'{self.episode_num},{key},{observation_type.name}')
        if observation_type == ObservationType.Hit:
            experience.state.hit_count += 1

        else:
            self._reward_experience(key, experience, observation_type)

        self.observation_seen += 1
        if self.observation_seen % self.checkpoint_steps == 0:
            self.logger.info(f'Observation seen so far: {self.observation_seen}, reward so far: {self.episode_reward}')

    def _observe_expired_incomplete_experience(self, key: str, observation_type: ObservationType, info: Dict[str, any]):
        """Observe decisions taken that hasn't been observed by main cache. e.g. don't cache -> ttl up -> no miss"""
        assert observation_type == ObservationType.Expiration
        experience = info['value']
        self._reward_experience(key, experience, observation_type)

    def _reward_experience(self,
                           key: str,
                           experience: CachingAgentIncompleteExperienceEntry,
                           observation_type: ObservationType):
        state = experience.state
        state.step_code = observation_type.value

        self._incomplete_experiences.delete(key)

        self.entry_hits_logger.info(f'{self.episode_num},{key},{experience.state.hit_count}')
        reward = self.converter.system_to_agent_reward(experience)
        if self.experimental_reward:
            # TODO add cache utility to state and reward
            pass

        self.agent.observe(preprocessed_states=experience.starting_state.to_numpy(),
                           actions=experience.agent_action,
                           internals=[],
                           rewards=reward,
                           next_states=experience.state.to_numpy(),
                           terminals=False)

        self.episode_reward += reward
        self.reward_logger.info(f'{self.episode_num},{reward}')
        self.logger.debug(f'Key: {key} is in terminal state because: {str(observation_type)}')
        loss = self.agent.update()
        if loss is not None:
            self.loss_logger.info(f'{self.episode_num},{loss[0]}')

    def close(self):
        super().close()
        self.agent.reset()
        self._incomplete_experiences.clear()
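
CachingStrategyRLConverter is not shown in this example; only its interface is visible from the call sites above (agent_to_system_action turns the IntBox(2) action into a bool, system_to_agent_reward maps a finished experience to a scalar). A hypothetical minimal version that is merely consistent with those call sites, not the project's actual implementation:

class CachingStrategyRLConverter:
    """Hypothetical stand-in; the real converter's logic is not shown above."""

    def agent_to_system_action(self, agent_action) -> bool:
        # assumed encoding for the IntBox(2) action: 1 = cache, 0 = don't cache
        return int(agent_action) == 1

    def system_to_agent_reward(self, experience) -> float:
        # assumed shaping: hits accumulated while the entry lived are the payoff
        return float(experience.state.hit_count)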
Example #5
class LRUEvictionStrategy(EvictionStrategy):
    def __init__(self, config: Dict[str, any], result_dir: str,
                 cache_stats: CacheInformation):
        super().__init__(config, result_dir, cache_stats)
        self.lru = OrderedDict()
        self.logger = logging.getLogger(__name__)
        name = 'lru_eviction_strategy'
        self.performance_logger = create_file_logger(
            name=f'{name}_performance_logger', result_dir=result_dir)

        self._incomplete_experiences = TTLCache(InMemoryStorage())
        self._incomplete_experiences.expired_entry_callback(
            self._observe_expired_incomplete_experience)

    def observe(self, key: str, observation_type: ObservationType,
                info: Dict[str, any]):
        try:
            stored_values = self.lru.pop(key)
        except KeyError:
            # item not observed in the cache before
            self.logger.debug(
                f"Key: {key} not in LRU monitor. Current LRU size: {len(self.lru)}"
            )
            stored_values = None
        # add the key on a write; refresh its recency on a hit
        if observation_type == ObservationType.Write:
            ttl = info['ttl']
            observation_time = time.time()
            self.lru[key] = {'ttl': ttl, 'observation_time': observation_time}

        elif observation_type == ObservationType.Hit:
            if stored_values is not None:
                observation_time = time.time()
                self.lru[key] = {
                    'ttl': stored_values['ttl'],
                    'observation_time': observation_time
                }

        elif observation_type in {
                ObservationType.Expiration, ObservationType.Invalidate
        }:
            self.logger.debug(f"Key {key} expired")
            assert key not in self.lru, "Expired key should have been deleted."

        action_taken = self._incomplete_experiences.get(key)
        if action_taken is not None:
            if observation_type == ObservationType.Invalidate:
                # eviction followed by invalidation.
                self.performance_logger.info(f'{self.episode_num},TrueEvict')
            elif observation_type == ObservationType.Miss:
                self.performance_logger.info(f'{self.episode_num},FalseEvict')
                # Miss after making an eviction decision
            self._incomplete_experiences.delete(key)

    def _observe_expired_incomplete_experience(
            self, key: str, observation_type: ObservationType,
            info: Dict[str, any]):
        self.performance_logger.info(f'{self.episode_num},TrueEvict')

    def trim_cache(self, cache: TTLCache) -> List[str]:
        while True:
            eviction_item = self.lru.popitem(last=False)
            eviction_key = eviction_item[0]
            eviction_value = eviction_item[1]

            if cache.contains(eviction_key):
                # TTLCache might expire and cause a race condition
                decision_time = time.time()
                ttl_left = (eviction_value['observation_time'] +
                            eviction_value['ttl']) - decision_time
                self._incomplete_experiences.set(eviction_key, 'evict',
                                                 ttl_left)
                cache.delete(eviction_key)
                return [eviction_key]
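
The recency bookkeeping above is plain OrderedDict insertion order: a hit pops the key and reinserts it at the end, and trim_cache pops from the front. A tiny standalone illustration (keys and TTL values are made up):

from collections import OrderedDict

lru = OrderedDict()
lru['a'] = {'ttl': 60}
lru['b'] = {'ttl': 60}

entry = lru.pop('a')    # a hit on 'a': remove it...
lru['a'] = entry        # ...and reinsert it, so 'a' is now the most recently used

oldest_key, _ = lru.popitem(last=False)   # trim_cache would evict the least recently used
print(oldest_key)                         # -> 'b'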
Example #6
class RLEvictionStrategy(EvictionStrategy):
    def __init__(self, config: Dict[str, any], result_dir: str,
                 cache_stats: CacheInformation):
        super().__init__(config, result_dir, cache_stats)
        # evaluation specific variables
        self.observation_seen = 0
        self.episode_reward = 0
        self.checkpoint_steps = config['checkpoint_steps']

        self._incomplete_experiences = TTLCache(InMemoryStorage())
        self._incomplete_experiences.expired_entry_callback(
            self._observe_expired_incomplete_experience)
        self.view_of_the_cache = {}  # type: Dict[str, Dict[str, any]]
        self._end_episode_observation = {
            ObservationType.Invalidate, ObservationType.Miss,
            ObservationType.Expiration
        }

        # TODO refactor into common RL interface for all strategies
        # Agent configuration (can be shared with others)
        agent_config = config['agent_config']
        fields_in_state = len(EvictionAgentSystemState.__slots__)
        self.converter = EvictionStrategyRLConverter(self.result_dir)

        # State: fields to observe in question
        # Action: to evict or not that key
        self.agent = Agent.from_spec(
            agent_config,
            state_space=FloatBox(shape=(fields_in_state, )),
            action_space=IntBox(low=0, high=2))

        self.logger = logging.getLogger(__name__)
        name = 'rl_eviction_strategy'
        self.reward_logger = create_file_logger(name=f'{name}_reward_logger',
                                                result_dir=self.result_dir)
        self.loss_logger = create_file_logger(name=f'{name}_loss_logger',
                                              result_dir=self.result_dir)
        self.observation_logger = create_file_logger(
            name=f'{name}_observation_logger', result_dir=self.result_dir)
        self.key_vocab = Vocabulary()

    def trim_cache(self, cache: TTLCache) -> List[str]:
        # trim_cache isn't called often, so it is OK for this operation to be
        # expensive: produce an action for every key currently tracked.
        keys_to_evict = []
        keys_to_not_evict = []

        for (key, cached_key) in list(self.view_of_the_cache.items()):
            agent_system_state = cached_key['state']

            agent_action = self.agent.get_action(agent_system_state.to_numpy())
            should_evict = self.converter.agent_to_system_action(agent_action)

            decision_time = time.time()
            incomplete_experience = EvictionAgentIncompleteExperienceEntry(
                agent_system_state, agent_action, agent_system_state.copy(),
                decision_time)

            # observe the key for only the ttl period that is left for this key
            ttl_left = (cached_key['observation_time'] +
                        agent_system_state.ttl) - decision_time
            self._incomplete_experiences.set(key=key,
                                             values=incomplete_experience,
                                             ttl=ttl_left)
            if should_evict:
                del self.view_of_the_cache[key]
                keys_to_evict.append(key)
                if not cache.contains(key, clean_expire=False):
                    # race condition, clean up and move on
                    self._incomplete_experiences.delete(key)
                cache.delete(key)
            else:
                keys_to_not_evict.append(key)

        return keys_to_evict

    def observe(self, key: str, observation_type: ObservationType,
                info: Dict[str, any]):
        self.observation_logger.info(
            f'{self.episode_num},{key},{observation_type}')

        observed_key = self.converter.vocabulary.add_or_get_id(key)
        stored_experience = self._incomplete_experiences.get(key)
        if observation_type == ObservationType.Write:
            if stored_experience is not None:
                # race condition
                reward = self.converter.system_to_agent_reward(
                    stored_experience, ObservationType.Miss, self.episode_num)
                state = stored_experience.state
                action = stored_experience.agent_action
                new_state = state.copy()
                new_state.step_code = ObservationType.Miss.value

                self._reward_agent(state.to_numpy(), new_state.to_numpy(),
                                   action, reward)
                self._incomplete_experiences.delete(key)

            # New item to write into cache view and observe.
            ttl = info['ttl']
            observation_time = time.time()
            self.view_of_the_cache[key] = {
                'state':
                EvictionAgentSystemState(encoded_key=observed_key,
                                         ttl=ttl,
                                         hit_count=0,
                                         step_code=observation_type.value),
                'observation_time':
                observation_time
            }

        elif observation_type == ObservationType.Hit:
            # Cache hit, update the hit record of this key in the cache
            stored_view = self.view_of_the_cache[key]['state']
            stored_view.hit_count += 1

        elif observation_type in self._end_episode_observation:
            if stored_experience:
                reward = self.converter.system_to_agent_reward(
                    stored_experience, observation_type, self.episode_num)
                state = stored_experience.state
                action = stored_experience.agent_action
                new_state = state.copy()
                new_state.step_code = observation_type.value

                self._reward_agent(state.to_numpy(), new_state.to_numpy(),
                                   action, reward)
                self._incomplete_experiences.delete(key)

            if key in self.view_of_the_cache:
                del self.view_of_the_cache[key]

        self.observation_seen += 1
        if self.observation_seen % self.checkpoint_steps == 0:
            self.logger.info(
                f'Observation seen so far: {self.observation_seen}, reward so far: {self.episode_reward}'
            )

    def _observe_expired_incomplete_experience(
            self, key: str, observation_type: ObservationType,
            info: Dict[str, any]):
        """Observe decisions taken that hasn't been observed by main cache. e.g. don't cache -> ttl up -> no miss"""
        assert observation_type == ObservationType.Expiration
        self.observation_logger.info(
            f'{self.episode_num},{key},{observation_type}')

        experience = info['value']  # type: EvictionAgentIncompleteExperienceEntry
        reward = self.converter.system_to_agent_reward(experience,
                                                       observation_type,
                                                       self.episode_num)
        starting_state = experience.starting_state
        action = experience.agent_action
        new_state = experience.state.copy()
        new_state.step_code = observation_type.value

        self._reward_agent(starting_state.to_numpy(), new_state.to_numpy(),
                           action, reward)

    def _reward_agent(self, state: np.ndarray, new_state: np.ndarray,
                      agent_action: np.ndarray, reward: int):
        self.agent.observe(preprocessed_states=state,
                           actions=agent_action,
                           internals=[],
                           rewards=reward,
                           next_states=new_state,
                           terminals=False)
        self.reward_logger.info(f'{self.episode_num},{reward}')

        loss = self.agent.update()
        if loss is not None:
            self.loss_logger.info(f'{self.episode_num},{loss[0]}')

    def close(self):
        super().close()
        self._incomplete_experiences.clear()
        self.agent.reset()