def test_get_prioritized_experience(p_exp_handler: PrioritizedExperienceHandler):
    assert len(list(p_exp_handler.term_states)) > 0  # assert there is a terminal state

    states, _, _, _, _, _ = p_exp_handler.get_prioritized_experience(999)
    # experience handler should return nothing when num_requested > size
    assert states is None

    states, actions, rewards, states_tp1, terminal, inds = p_exp_handler.get_prioritized_experience(2)
    assert states is not None
    # we should have gotten a terminal transition, which should be the first element
    assert np.sum(terminal) == 1
    assert np.all(states_tp1[0] == np.zeros(states_tp1[0].shape))

    # states should be 2 then 1
    assert np.all(states[0] == 2*np.ones(states[0].shape))
    assert np.all(states[1] == np.ones(states[0].shape))

    # make sure we got correct actions/rewards
    assert np.sum(actions) == 3  # 2+1
    assert np.sum(rewards) == 3  # 2+1

    # make sure all state_tp1s are correct
    for ind in range(states.shape[0]):
        if not terminal[ind]:
            assert np.all(states[ind]+1 == states_tp1[ind])

    # reinsert
    for ind in inds:
        p_exp_handler.tree.insert(ind, ind)

    # right should be nothing because 0 has value inf
    assert p_exp_handler.tree.root.right is None
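
# These tests all receive the same `p_exp_handler` and build on one another's state
# (experiences added in one test are sampled and trimmed in the others). The real fixture
# is not shown in this excerpt; a minimal sketch of what it could look like, assuming the
# constructor takes only a max length (a max_len of 2 is what test_trim below expects):
import pytest  # normally imported at the top of the file


@pytest.fixture(scope='module')
def p_exp_handler():
    return PrioritizedExperienceHandler(2)
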
def test_add_experience(p_exp_handler: PrioritizedExperienceHandler):
    for add in range(3):
        state = np.ones((2, 10, 10)) * add
        action = add
        reward = add
        p_exp_handler.add_experience(state, action, reward)

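    # each new experience is presumably inserted with infinite priority so it gets sampled
    # before anything that already has a measured TD error; extra_vals holds the experience
    # index, and index 0 was added first, so it ends up at the root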
    assert p_exp_handler.tree.root is not None
    assert p_exp_handler.tree.root.value == np.inf
    assert p_exp_handler.tree.root.extra_vals == 0
    def train_prioritized(self, exp_handler: PrioritizedExperienceHandler, discount, cnn):
        # generate minibatch data
        states, actions, rewards, state_tp1s, terminal, mb_inds_popped = exp_handler.get_prioritized_experience(self.mini_batch)

        if states is not None:
            r_tp1 = cnn.get_output(state_tp1s)
            max_tp1 = np.max(r_tp1, axis=1)
            rewards += (1-terminal) * discount * max_tp1

            rewVals = np.zeros((self.mini_batch, self.num_actions), dtype=self.dtype)
            arange = np.arange(self.mini_batch)
            rewVals[arange, actions] = rewards

            mask = np.zeros((self.mini_batch, self.num_actions), dtype=self.dtype)
            nonZero = np.where(rewVals != 0)
            mask[nonZero[0], nonZero[1]] = 1
            cost, output_states = cnn.train(states, rewVals, mask)
            self.costList.append(cost)

            # update prioritized exp handler with new td_error
            max_states = np.max(output_states, axis=1)
            td_errors = np.abs(max_tp1 - max_states)

            exp_handler.set_new_td_errors(td_errors, mb_inds_popped)
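
# A small self-contained numpy illustration (not part of TrainHandler) of how the targets,
# mask, and TD errors in train_prioritized above are built. Toy numbers: mini_batch=2,
# num_actions=3, and the second sample is a terminal transition:
import numpy as np

discount = 0.99
actions = np.array([2, 1])                   # action index taken for each sampled state
rewards = np.array([1.0, 0.0])               # immediate rewards
terminal = np.array([0, 1])                  # 1 marks a terminal transition
max_tp1 = np.array([5.0, 7.0])               # max_a Q(s_tp1, a) from the network

# non-terminal samples get the bootstrapped target r + discount * max_a Q(s_tp1, a)
rewards = rewards + (1 - terminal) * discount * max_tp1   # -> [5.95, 0.0]

# scatter each target into the column of the action actually taken
rewVals = np.zeros((2, 3))
rewVals[np.arange(2), actions] = rewards     # row 0: [0, 0, 5.95], row 1: all zeros

# the mask keeps only those entries so untaken actions contribute nothing to the cost;
# note that a target of exactly 0 (a terminal transition with zero reward) is masked out too
mask = np.zeros_like(rewVals)
mask[rewVals != 0] = 1

# the handler's new priorities are the absolute TD errors between the bootstrapped value
# and the network's current value of the sampled states (output_states in the code above)
max_states = np.array([4.0, 6.5])            # stand-in for np.max(output_states, axis=1)
td_errors = np.abs(max_tp1 - max_states)     # -> [1.0, 0.5]
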
class PrioritizedExperienceLearner(learner):
    def __init__(self, skip_frame, num_actions, load=None):
        super().__init__()

        rand_vals = (1, 0.1, 10000 / skip_frame)  # starting at 1, anneal eGreedy policy to 0.1 over 10,000/skip_frame steps
        self.action_handler = ActionHandler(ActionPolicy.eGreedy, rand_vals)

        self.exp_handler = PrioritizedExperienceHandler(1000000 / skip_frame)
        self.train_handler = TrainHandler(32, num_actions)
        self.cnn = CNN((None, skip_frame, 86, 80), num_actions, 0.1)

        self.discount = 0.99

        if load is not None:
            self.cnn.load(load)

    def frames_processed(self, frames, action_performed, reward):
        self.exp_handler.add_experience(frames, self.action_handler.game_action_to_action_ind(action_performed), reward)
        self.train_handler.train_prioritized(self.exp_handler, self.discount, self.cnn)
        self.action_handler.anneal()

    def plot_tree(self):
        self.exp_handler.tree.plot()

    def get_action(self, game_input):
        return self.cnn.get_output(game_input)[0]

    def game_over(self):
        self.exp_handler.trim()  # trim the experience replay down to its max length
        self.exp_handler.add_terminal()  # mark the last stored state as terminal

    def get_game_action(self, game_input):
        return self.action_handler.action_vect_to_game_action(self.get_action(game_input))

    def set_legal_actions(self, legal_actions):
        self.action_handler.set_legal_actions(legal_actions)

    def save(self, file):
        self.cnn.save(file)

    def get_cost_list(self):
        return self.train_handler.costList
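
# A rough sketch of how PrioritizedExperienceLearner could be driven during play. The
# `emulator` object and its methods (get_legal_actions, is_terminal, get_frames, act) are
# hypothetical stand-ins for whatever environment wrapper the project actually uses:
def run_episode(learner, emulator, skip_frame):
    learner.set_legal_actions(emulator.get_legal_actions())
    total_reward = 0
    while not emulator.is_terminal():
        frames = emulator.get_frames(skip_frame)          # stacked frames, e.g. shape (skip_frame, 86, 80)
        action = learner.get_game_action(frames)
        reward = emulator.act(action)
        learner.frames_processed(frames, action, reward)  # store the experience and train one minibatch
        total_reward += reward
    learner.game_over()  # trim the replay memory and mark the terminal state
    return total_reward
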
def test_trim(p_exp_handler: PrioritizedExperienceHandler):
    old_state_size = len(p_exp_handler.states)
    p_exp_handler.trim()

    assert p_exp_handler.size < old_state_size
    assert len(p_exp_handler.states) < old_state_size
    assert p_exp_handler.size == p_exp_handler.max_len
    assert len(p_exp_handler.states) == p_exp_handler.max_len
    assert len(p_exp_handler.states) == p_exp_handler.size
    assert list(p_exp_handler.term_states)[0] == p_exp_handler.size - 1

    # check tree extra_vals have been updated
    assert p_exp_handler.tree.get_size() == p_exp_handler.size
    # index 0 should still be at the root because its value is still inf
    assert p_exp_handler.tree.root.value == np.inf
    assert p_exp_handler.tree.root.extra_vals == 0
    # value 2 should be the only other node left, and its index should now be 1
    assert p_exp_handler.tree.root.left.value == 2
    assert p_exp_handler.tree.root.left.extra_vals == 1

    # check term_states is still a set
    assert isinstance(p_exp_handler.term_states, set)

    # test that indices of trimmed terminal states are removed from term_states
    # add new states and trim
    for add in range(1):
        state = np.ones((2, 10, 10)) * add
        action = add
        reward = add
        p_exp_handler.add_experience(state, action, reward)
    p_exp_handler.trim()

    # the left value should be deleted
    assert p_exp_handler.tree.root.left is None
    # right value should be 1 because we deleted state 1 (and it used to be 2)
    assert p_exp_handler.tree.root.right.extra_vals == 1
    assert len(list(p_exp_handler.term_states)) == 0
def test_add_terminal(p_exp_handler: PrioritizedExperienceHandler):
    state_size = len(p_exp_handler.states)
    p_exp_handler.add_terminal()
    assert p_exp_handler.tree.root.right.value == np.inf
    assert p_exp_handler.tree.root.right.right.extra_vals == state_size-1