コード例 #1
0
    def __init__(self, train_env):
        """Constructs around the training trace.

        Replays the entire training trace once. At each step it records the
        observed state together with that set's cache lines, ranked from
        highest to lowest reuse distance (as reported by
        ``train_env.next_access_time``). It then steps the environment by
        evicting the top-ranked line whenever an eviction is requested.

        Args:
          train_env (CacheReplacementEnv): environment wrapping the training
            trace. The configs of the environment should be the same as those
            used at test time.
        """
        # set_id --> list of (state, ranked lines)
        self._train_accesses = collections.defaultdict(list)
        # Optimization: (set_id, address) --> list of indices i into
        # self._train_accesses[set_id] whose recorded state accessed `address`,
        # i.e. self._train_accesses[set_id][i][0].access == address.
        self._address2index = collections.defaultdict(list)
        state = train_env.reset()
        with tqdm.tqdm(desc="Training") as pbar:
            while True:
                reuse_distances = {
                    line: train_env.next_access_time(line)
                    for line in state.cache_lines
                }

                # Ranked highest reuse distance to lowest reuse distance
                ranked_lines = sorted(state.cache_lines,
                                      key=lambda line: reuse_distances[line],
                                      reverse=True)
                set_accesses = self._train_accesses[state.set_id]
                # Index the entry about to be appended under its address.
                self._address2index[(state.set_id, state.access)].append(
                    len(set_accesses))
                set_accesses.append((state, ranked_lines))

                # Evict the line with the highest reuse distance. Only index
                # ranked_lines when an eviction is actually requested, so an
                # empty state.cache_lines on a non-evicting step does not
                # raise IndexError.
                if state.evict:
                    action = state.cache_lines.index(ranked_lines[0])
                else:
                    action = -1

                state, _, done, _ = train_env.step(action)
                pbar.update(1)
                if done:
                    break

        # set_id --> history of State
        self._test_access_history = collections.defaultdict(list)
        # Dict for memoizing: (set_id, index, index) --> matching suffix length
        self._suffix_cache = {}
        # Fall-back policy
        self._lru = policy.LRU()
コード例 #2
0
import s4lru
import tqdm

if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("policy_type", help="type of replacement policy to use")
  args = parser.parse_args()
  chosen = args.policy_type

  # Evaluation setup: the sample trace replayed through a cache configured
  # from spec_llc.json.
  trace_path = "traces/sample_trace.csv"
  config = cfg.Config.from_files_and_bindings(["spec_llc.json"], [])
  env = environment.CacheReplacementEnv(config, trace_path, 0)

  # Build the replacement policy named on the command line. Every branch
  # tests a distinct string, so their order is irrelevant.
  if chosen == "lru":
    replacement_policy = policy.LRU()
  elif chosen == "random":
    # RNG seeded with 0 so repeated runs are reproducible.
    replacement_policy = policy.RandomPolicy(np.random.RandomState(0))
  elif chosen == "s4lru":
    replacement_policy = s4lru.S4LRU(config.get("associativity"))
  elif chosen == "belady":
    # The Belady policy is handed the evaluation environment itself.
    replacement_policy = belady.BeladyPolicy(env)
  elif chosen == "belady_nearest_neighbors":
    # NOTE(review): this "training" environment replays the very same trace
    # as the evaluation env, so neighbours come from the test trace itself;
    # confirm this is intended for the demo.
    train_env = environment.CacheReplacementEnv(config, trace_path, 0)
    replacement_policy = belady.BeladyNearestNeighborsPolicy(train_env)
  else:
    raise ValueError(f"Unsupported policy type: {args.policy_type}")

  state = env.reset()
  total_reward = 0
  steps = 0
  with tqdm.tqdm() as pbar:
    while True: