Ejemplo n.º 1
0
 def _sample_next_states(self, batch_idxs):
     '''
     Gather the next_states for a batch by reading self.states at batch_idxs + 1.
     Any batch position whose shifted index equals self.true_size is guarded:
     it is read from index 0 as a placeholder and then overwritten with self.latest_next_state.
     NOTE(review): presumably index self.true_size is past the last stored state - confirm against the class.
     '''
     # the next state of the state at idx is looked up at idx + 1 (this builds a new object, batch_idxs is not modified)
     shifted_idxs = batch_idxs + 1
     # positions within the batch (not within self.states) whose shifted idx equals self.true_size
     overflow_locs = np.argwhere(shifted_idxs == self.true_size).flatten()
     if overflow_locs.size == 0:
         # nothing to guard, every shifted idx can be fetched directly
         return util.cond_multiget(self.states, shifted_idxs)
     # 0 is a safe placeholder: after the +1 above no genuine shifted idx can be 0
     shifted_idxs[overflow_locs] = 0
     next_states = util.cond_multiget(self.states, shifted_idxs)
     # overwrite the placeholder reads with the separately stored latest next state
     next_states[overflow_locs] = self.latest_next_state
     return next_states
Ejemplo n.º 2
0
 def sample(self):
     '''
     Draws self.batch_size indices via self.sample_idxs, stores them on self.batch_idxs,
     and returns the corresponding batch as a dict with one entry per name in self.data_keys.
     The 'next_states' entry comes from self._sample_next_states; every other entry is
     fetched from the attribute of the same name with util.cond_multiget.
     e.g.
     batch = {
         'states'     : states,
         'actions'    : actions,
         'rewards'    : rewards,
         'next_states': next_states,
         'dones'      : dones}
     '''
     # keep the drawn indices on the instance so they stay available after sampling
     self.batch_idxs = self.sample_idxs(self.batch_size)
     return {
         key: (
             self._sample_next_states(self.batch_idxs)
             if key == 'next_states'
             else util.cond_multiget(getattr(self, key), self.batch_idxs)
         )
         for key in self.data_keys
     }