def _sample_next_states(self, batch_idxs):
    '''Gather the next state for every sampled index.

    The next state of the experience at index i is read from self.states at
    index i + 1. The one index that cannot be read that way is the one whose
    successor equals self.true_size. Those rows are filled from
    self.latest_next_state instead.

    Args:
        batch_idxs: sampled indices into self.states. Assumed to be a numpy
            integer array, since `+ 1` and `==` are applied elementwise.
            TODO confirm against sample_idxs.

    Returns:
        Whatever util.cond_multiget returns for the shifted indices. Rows
        whose successor index hit self.true_size are overwritten with
        self.latest_next_state.

    NOTE(review): the boundary test compares against self.true_size. If
    self.states is filled as a ring buffer that overwrites its oldest entries
    once full, the newest entry may no longer sit at true_size - 1. Confirm
    that the correct next state is still chosen after wraparound.
    '''
    # Build a fresh array of successor indices; the caller's batch_idxs is
    # never modified.
    succ_idxs = batch_idxs + 1
    # Positions whose successor would be the slot just past the stored data.
    overflow_pos = np.argwhere(succ_idxs == self.true_size).flatten()
    if overflow_pos.size == 0:
        # Common case: every successor index is already stored in self.states.
        return util.cond_multiget(self.states, succ_idxs)
    # Point the overflowing positions at index 0 so the gather below stays in
    # bounds (presumably valid whenever the buffer is non-empty). Then replace
    # those rows with the cached most recent next state.
    succ_idxs[overflow_pos] = 0
    next_states = util.cond_multiget(self.states, succ_idxs)
    next_states[overflow_pos] = self.latest_next_state
    return next_states
def sample(self):
    '''
    Draw a batch of self.batch_size experiences and return it as a dict.

    The freshly sampled indices are stored on self.batch_idxs, then one
    entry is gathered for every name in self.data_keys, in that order. The
    'next_states' key is delegated to self._sample_next_states. Every other
    key is read from the attribute of the same name with util.cond_multiget.
    Example shape of the result:
    batch = {
        'states'     : states,
        'actions'    : actions,
        'rewards'    : rewards,
        'next_states': next_states,
        'dones'      : dones}
    '''
    self.batch_idxs = self.sample_idxs(self.batch_size)

    def _gather(key):
        # next_states has its own sampling routine; every other key names a
        # stored attribute that is indexed directly.
        if key == 'next_states':
            return self._sample_next_states(self.batch_idxs)
        return util.cond_multiget(getattr(self, key), self.batch_idxs)

    return {key: _gather(key) for key in self.data_keys}