def test_trim(exp_handler: ExperienceHandler):
    """Check that ``trim`` shrinks the handler and drops deleted terminal states."""
    size_before = len(exp_handler.states)

    exp_handler.trim()

    # Trimming must discard data and leave the handler within ``max_len``,
    # with ``size`` agreeing with the number of stored states.
    stored_states = len(exp_handler.states)
    assert exp_handler.size < size_before
    assert stored_states < size_before
    assert exp_handler.size <= exp_handler.max_len
    assert stored_states <= exp_handler.max_len
    assert stored_states == exp_handler.size

    # Minus 2 because a state was added after the terminal.
    surviving_terminal = exp_handler.term_states.pop()
    assert surviving_terminal == len(exp_handler.rewards) - 2

    # Check term_states is still a set.
    assert isinstance(exp_handler.term_states, set)

    # Terminal states that have been deleted must be removed as well:
    # add a new terminal, add new states on top of it, then trim again.
    exp_handler.add_terminal()
    for value in range(10):
        exp_handler.add_experience(np.ones((2, 10, 10)) * value, value, value)
    exp_handler.trim()

    assert len(exp_handler.term_states) == 0
def test_get_random_experience(exp_handler: ExperienceHandler):
    """Check random sampling before and after one more experience is added."""
    # Assert there is a terminal state.
    assert len(exp_handler.term_states) > 0

    # The experience handler should return nothing when
    # size < replay_start_size.
    states, _, _, _, _, _ = exp_handler.get_random_experience(1)
    assert states is None

    add_experience(exp_handler, 9)
    batch = exp_handler.get_random_experience(9)
    states, actions, rewards, states_tp1, terminal, inds = batch
    assert states is not None

    # If we got a terminal state, states_tp1 should be zeros for it.
    terminal_rows = np.where(terminal)[0]
    zero_state = np.zeros(states_tp1[0].shape)
    assert np.all(states_tp1[terminal_rows] == zero_state)

    # Make sure we got the correct actions/rewards for the sampled indices.
    index_total = np.sum(inds)
    assert np.sum(actions) == index_total
    assert np.sum(rewards) == index_total
def test_add_terminal(exp_handler: ExperienceHandler):
    """Check that ``add_terminal`` records the index of the latest reward."""
    # The terminal index is based on the action/reward index.
    reward_count = len(exp_handler.rewards)

    exp_handler.add_terminal()

    expected_index = reward_count - 1
    assert expected_index in exp_handler.term_states
def add_experience(exp_handler: ExperienceHandler, data: int):
    """Add one experience whose state, action and reward all encode ``data``.

    The state is a ``(2, 10, 10)`` array filled with ``data``; the action and
    the reward are ``data`` itself.
    """
    filled_state = np.ones((2, 10, 10)) * data
    exp_handler.add_experience(filled_state, data, data)