class PrioritizedExperienceLearner(learner):
    """Learner combining an e-greedy action policy, a CNN and prioritized replay.

    Every call to ``frames_processed`` stores the transition in a
    ``PrioritizedExperienceHandler``, runs one prioritized training step on
    the CNN and anneals the exploration rate.
    """

    def __init__(self, skip_frame, num_actions, load=None):
        """Create the learner.

        :param skip_frame: frames per step; used as the second dimension of
            the CNN input shape and as the divisor for the annealing length
            and the replay capacity.
        :param num_actions: number of actions (passed to the train handler
            and the CNN).
        :param load: optional argument forwarded to ``CNN.load`` to restore
            saved weights; skipped when ``None``.
        """
        super().__init__()
        # Start the e-greedy policy at 1 and anneal it to 0.1 over
        # 10,000 / skip_frame.
        # NOTE(review): an older comment said 1,000,000 / skip_frame (the
        # value used for the replay capacity below); confirm which is intended.
        rand_vals = (1, 0.1, 10000 / skip_frame)
        self.action_handler = ActionHandler(ActionPolicy.eGreedy, rand_vals)
        # Note: true division, so the capacity passed here is a float.
        self.exp_handler = PrioritizedExperienceHandler(1000000 / skip_frame)
        self.train_handler = TrainHandler(32, num_actions)
        # Input shape (None, skip_frame, 86, 80); the meaning of the trailing
        # 0.1 argument is defined by CNN — TODO confirm (learning rate?).
        self.cnn = CNN((None, skip_frame, 86, 80), num_actions, 0.1)
        self.discount = 0.99
        if load is not None:
            self.cnn.load(load)

    def frames_processed(self, frames, action_performed, reward):
        """Store a transition, train once on prioritized replay, then anneal."""
        action_ind = self.action_handler.game_action_to_action_ind(action_performed)
        self.exp_handler.add_experience(frames, action_ind, reward)
        # Use the configured discount instead of a duplicated literal so that
        # changing ``self.discount`` actually affects training.
        self.train_handler.train_prioritized(self.exp_handler, self.discount, self.cnn)
        self.action_handler.anneal()

    def plot_tree(self):
        """Plot the priority tree of the experience handler."""
        self.exp_handler.tree.plot()

    def get_action(self, game_input):
        """Return the first element of the CNN output for ``game_input``."""
        return self.cnn.get_output(game_input)[0]

    def game_over(self):
        """Trim the replay memory and then record a terminal marker."""
        self.exp_handler.trim()  # trim experience replay of learner
        self.exp_handler.add_terminal()  # adds a terminal

    def get_game_action(self, game_input):
        """Map the CNN's output for ``game_input`` to a game action."""
        return self.action_handler.action_vect_to_game_action(self.get_action(game_input))

    def set_legal_actions(self, legal_actions):
        """Forward the set of legal actions to the action handler."""
        self.action_handler.set_legal_actions(legal_actions)

    def save(self, file):
        """Save the CNN weights via ``CNN.save``."""
        self.cnn.save(file)

    def get_cost_list(self):
        """Return the training cost history collected by the train handler."""
        return self.train_handler.costList
def test_trim(p_exp_handler: PrioritizedExperienceHandler):
    """Trimming shrinks the handler to ``max_len`` and keeps bookkeeping in sync.

    Checks stored states, ``size``, the priority tree and the terminal-state
    index set after one trim. Then adds one experience, trims again, and
    checks that trimmed-away tree nodes and terminal indices are gone.
    """
    handler = p_exp_handler
    states_before = len(handler.states)

    handler.trim()

    # Size bookkeeping: everything shrank and all views agree on max_len.
    assert handler.size < states_before
    assert len(handler.states) < states_before
    assert handler.size == handler.max_len
    assert len(handler.states) == handler.max_len
    assert len(handler.states) == handler.size
    assert list(handler.term_states)[0] == handler.size - 1

    # The tree's extra_vals (stored indices) must follow the trim.
    assert handler.tree.get_size() == handler.size
    root = handler.tree.root
    # Index 0 should still be at the root, with priority inf.
    assert root.value == np.inf
    assert root.extra_vals == 0
    # The node with value 2 is the only one left; its index is now 1.
    assert root.left.value == 2
    assert root.left.extra_vals == 1

    # term_states must remain a set after trimming.
    assert isinstance(handler.term_states, set)

    # Add one experience and trim again: terminal states that were trimmed
    # away must be removed as well.
    marker = 0
    handler.add_experience(np.ones((2, 10, 10)) * marker, marker, marker)
    handler.trim()

    root = handler.tree.root
    # The left node is deleted by the second trim.
    assert root.left is None
    # Right node used to point at index 2; deleting state 1 shifts it to 1.
    assert root.right.extra_vals == 1
    assert len(list(handler.term_states)) == 0