def test_get_prioritized_experience(p_exp_handler: PrioritizedExperienceHandler):
    """Check sampling, ordering and reinsertion of the prioritized experience handler.

    Relies on the ``p_exp_handler`` fixture (defined elsewhere) having been
    populated with a small, known set of transitions that includes at least
    one terminal transition.
    """
    # The fixture must contain at least one terminal transition.
    terminal_states = list(p_exp_handler.term_states)
    assert len(terminal_states) > 0

    # Asking for more samples than are stored must yield no batch.
    (too_many, _, _, _, _, _) = p_exp_handler.get_prioritized_experience(999)
    assert too_many is None

    batch = p_exp_handler.get_prioritized_experience(2)
    states, actions, rewards, states_tp1, terminal, inds = batch
    assert states is not None

    # Exactly one sampled transition is terminal; the test expects it first,
    # with an all-zero next state.
    assert np.sum(terminal) == 1
    first_next = states_tp1[0]
    assert np.all(first_next == np.zeros(first_next.shape))

    # Expected sample order: the state filled with 2s, then the one filled with 1s.
    sample_shape = states[0].shape
    assert np.all(states[0] == 2 * np.ones(sample_shape))
    assert np.all(states[1] == np.ones(sample_shape))

    # The two samples' actions and rewards are expected to sum to 2 + 1.
    assert np.sum(actions) == 3
    assert np.sum(rewards) == 3

    # Every non-terminal sample's next state equals its state plus one.
    for idx, state in enumerate(states):
        if not terminal[idx]:
            assert np.all(state + 1 == states_tp1[idx])

    # Reinsert each popped index into the tree (passed as both arguments to insert).
    for idx in inds:
        p_exp_handler.tree.insert(idx, idx)

    # Per the original author: index 0 holds value inf, so the root has no right child.
    assert p_exp_handler.tree.root.right is None
def train_prioritized(self, exp_handler: "PrioritizedExperienceHandler", discount, cnn):
    """Run one Q-learning update on a prioritized minibatch and refresh its priorities.

    Samples ``self.mini_batch`` transitions from ``exp_handler`` and builds
    one-step targets ``r + discount * max_a' Q(s', a')``. The bootstrap term
    is dropped for terminal transitions. The method then trains ``cnn`` on
    those targets for the taken action of each sample only, appends the
    returned cost to ``self.costList``, and hands the new absolute TD errors
    back to ``exp_handler`` for the popped indices. Nothing happens when the
    handler cannot supply a batch (it returns ``None`` states).

    Args:
        exp_handler: replay buffer providing ``get_prioritized_experience(n)``
            (returning states, actions, rewards, next states, terminal flags,
            popped indices) and ``set_new_td_errors(errors, indices)``.
        discount: discount factor applied to the bootstrapped next-state value.
        cnn: network providing ``get_output(states)`` and
            ``train(states, targets, mask) -> (cost, outputs)``.
    """
    states, actions, rewards, state_tp1s, terminal, mb_inds_popped = (
        exp_handler.get_prioritized_experience(self.mini_batch))
    if states is None:
        return

    # Best action value of each next state; outputs are indexed as
    # (sample, action) below, so assumed shape (batch, num_actions).
    max_tp1 = np.max(cnn.get_output(state_tp1s), axis=1)
    # Cast first: `1 - terminal` raises TypeError when terminal is a bool array.
    not_terminal = 1 - np.asarray(terminal, dtype=self.dtype)
    # Build a fresh array rather than `rewards += ...` so the handler's reward
    # array is never mutated in place (and integer rewards are not an issue).
    targets = rewards + not_terminal * discount * max_tp1

    batch_rows = np.arange(len(actions))
    rewVals = np.zeros((len(actions), self.num_actions), dtype=self.dtype)
    rewVals[batch_rows, actions] = targets
    # Train exactly the taken action of every sample. Deriving the mask from
    # non-zero targets would silently skip samples whose target is 0.
    mask = np.zeros_like(rewVals)
    mask[batch_rows, actions] = 1

    cost, output_states = cnn.train(states, rewVals, mask)
    self.costList.append(cost)

    # Priority = |target - Q(s, a_taken)|, using the predictions returned by
    # cnn.train (NOTE(review): assumed to be the pre-update outputs for `states`).
    q_taken = output_states[batch_rows, actions]
    td_errors = np.abs(targets - q_taken)
    exp_handler.set_new_td_errors(td_errors, mb_inds_popped)