def __init__(self, train_env):
    """Constructs around the training trace.

    Args:
        train_env (CacheReplacementEnv): environment wrapping the training
            trace. The configs of the environment should be the same as
            those used at test time.
    """
    # set_id --> list of (state, ranked lines)
    self._train_accesses = collections.defaultdict(list)
    # Optimization: (set_id, address) --> list of indices into
    # self._train_accesses[set_id] whose stored state accessed that address.
    self._address2index = collections.defaultdict(list)

    state = train_env.reset()
    with tqdm.tqdm(desc='"Training"') as pbar:
        while True:
            reuse_distances = {
                line: train_env.next_access_time(line)
                for line in state.cache_lines
            }
            # Ranked highest reuse distance to lowest reuse distance
            ranked_lines = sorted(
                state.cache_lines,
                key=lambda line: reuse_distances[line],
                reverse=True)
            self._address2index[(state.set_id, state.access)].append(
                len(self._train_accesses[state.set_id]))
            self._train_accesses[state.set_id].append((state, ranked_lines))

            # Belady's policy: evict the line reused furthest in the future.
            line_to_evict = ranked_lines[0]
            action = (state.cache_lines.index(line_to_evict)
                      if state.evict else -1)
            state, _, done, _ = train_env.step(action)
            pbar.update(1)
            if done:
                break

    # set_id --> history of State
    self._test_access_history = collections.defaultdict(list)
    # Dict for memoizing: (set_id, index, index) --> matching suffix length
    self._suffix_cache = {}
    # Fall-back policy
    self._lru = policy.LRU()
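
# Illustrative sketch (an assumption, not code from this repo): one way the
# "(set_id, index, index) --> matching suffix length" memoization declared
# above could be computed. Given a training access sequence and a test-time
# access history, it returns the length of their longest common suffix,
# caching intermediate results so repeated lookups stay cheap. The function
# name and signature are hypothetical.
def suffix_match_length(train_seq, test_seq, i, j, cache):
    """Longest common suffix length of train_seq[:i] and test_seq[:j]."""
    if i == 0 or j == 0:
        return 0
    key = (i, j)
    if key not in cache:
        if train_seq[i - 1] == test_seq[j - 1]:
            cache[key] = 1 + suffix_match_length(
                train_seq, test_seq, i - 1, j - 1, cache)
        else:
            cache[key] = 0
    return cache[key]

# e.g. suffix_match_length("abcxy", "zzxy", 5, 4, {}) returns 2 ("xy").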
import argparse

import numpy as np
import tqdm

import belady
import cfg
import environment
import policy
import s4lru


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "policy_type", help="type of replacement policy to use")
    args = parser.parse_args()

    trace_path = "traces/sample_trace.csv"
    config = cfg.Config.from_files_and_bindings(["spec_llc.json"], [])
    env = environment.CacheReplacementEnv(config, trace_path, 0)

    if args.policy_type == "belady":
        replacement_policy = belady.BeladyPolicy(env)
    elif args.policy_type == "lru":
        replacement_policy = policy.LRU()
    elif args.policy_type == "s4lru":
        replacement_policy = s4lru.S4LRU(config.get("associativity"))
    elif args.policy_type == "belady_nearest_neighbors":
        # Uses a second environment over the same trace as its "training" pass.
        train_env = environment.CacheReplacementEnv(config, trace_path, 0)
        replacement_policy = belady.BeladyNearestNeighborsPolicy(train_env)
    elif args.policy_type == "random":
        replacement_policy = policy.RandomPolicy(np.random.RandomState(0))
    else:
        raise ValueError(f"Unsupported policy type: {args.policy_type}")

    state = env.reset()
    total_reward = 0
    steps = 0
    with tqdm.tqdm() as pbar:
        while True:
            # Assumes each policy exposes an action(state) method returning
            # the eviction index expected by env.step().
            action = replacement_policy.action(state)
            state, reward, done, _ = env.step(action)
            total_reward += reward
            steps += 1
            pbar.update(1)
            if done:
                break
    # Assuming the reward is 1 per cache hit and 0 per miss, this is the
    # aggregate hit rate over the trace.
    print(f"Hit rate: {total_reward / steps:.4f}")
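
# Example invocation (the filename "main.py" is a stand-in for this script;
# the trace and config paths are hard-coded above):
#
#   python main.py s4lru
#   python main.py belady_nearest_neighbors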