class ReplayMemory:
    def __init__(self, memory_size):
        self.memory_size = memory_size
        self.memory = SumTree(memory_size)
        self.epsilon = 0.0001  # small amount to avoid zero priority
        self.alpha = 0.6  # adj_pri = pri^alpha
        self.beta = 0.4  # importance-sampling exponent, annealed from its initial value to 1
        self.beta_max = 1
        self.beta_increment_per_sampling = 0.001
        self.abs_err_upper = 1.  # clipped TD error

    def add(self, row):
        max_p = np.max(self.memory.tree[-self.memory.capacity:])  # max adj_pri among the leaves
        if max_p == 0:
            max_p = self.abs_err_upper
        self.memory.add(max_p, row)  # new transitions get the current maximum priority

    def get_batch(self, batch_size):
        leaf_idx = np.empty(batch_size, dtype=np.int32)
        batch_memory = np.empty(batch_size, dtype=object)
        ISWeights = np.empty(batch_size)
        pri_seg = self.memory.total_p / batch_size  # adj_pri segment
        self.beta = np.min([self.beta_max, self.beta + self.beta_increment_per_sampling])  # max = 1
        # P(i) = adj_pri(i) / sum_k adj_pri(k)
        # ISWeight_j = (N * P(j))^(-beta) / max_i[(N * P(i))^(-beta)] = (P(j) / min_i P(i))^(-beta)
        min_prob = np.min(
            self.memory.tree[self.memory.capacity - 1:self.memory.capacity - 1 + self.memory.counter]
        ) / self.memory.total_p
        for i in range(batch_size):  # stratified sampling: one value from each segment
            a, b = pri_seg * i, pri_seg * (i + 1)  # interval
            v = np.random.uniform(a, b)
            idx, p, data = self.memory.get_leaf(v)
            prob = p / self.memory.total_p
            ISWeights[i] = np.power(prob / min_prob, -self.beta)
            leaf_idx[i], batch_memory[i] = idx, data
        return leaf_idx, batch_memory, ISWeights

    def update_sum_tree(self, tree_idx, td_errors):
        for ti, td_error in zip(tree_idx, td_errors):
            p = self._calculate_priority(td_error)
            self.memory.update(ti, p)

    def _calculate_priority(self, td_error):
        priority = abs(td_error) + self.epsilon
        clipped_pri = np.minimum(priority, self.abs_err_upper)
        return np.power(clipped_pri, self.alpha)

    @property
    def length(self):
        return self.memory.counter

    def load_memory(self, memory):
        self.memory = memory

    def get_memory(self):
        return self.memory
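# The ReplayMemory above assumes a SumTree with add/get_leaf/update methods and
# tree, capacity, counter, and total_p attributes; that class is not shown here.
# Below is a minimal sketch of such a sum tree under those assumptions, not the
# original implementation.
import numpy as np


class SumTree:
    def __init__(self, capacity):
        self.capacity = capacity                      # number of leaf nodes
        self.tree = np.zeros(2 * capacity - 1)        # each parent stores the sum of its children
        self.data = np.empty(capacity, dtype=object)  # transitions stored next to the leaf priorities
        self.counter = 0                              # number of stored transitions, capped at capacity
        self.data_pointer = 0                         # next write position (circular)

    def add(self, priority, data):
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_idx, priority)
        self.data_pointer = (self.data_pointer + 1) % self.capacity
        self.counter = min(self.counter + 1, self.capacity)

    def update(self, tree_idx, priority):
        change = priority - self.tree[tree_idx]
        self.tree[tree_idx] = priority
        while tree_idx != 0:                          # propagate the change up to the root
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_leaf(self, v):
        parent = 0
        while True:
            left, right = 2 * parent + 1, 2 * parent + 2
            if left >= len(self.tree):                # reached a leaf
                leaf = parent
                break
            if v <= self.tree[left]:
                parent = left
            else:
                v -= self.tree[left]
                parent = right
        return leaf, self.tree[leaf], self.data[leaf - self.capacity + 1]

    @property
    def total_p(self):
        return self.tree[0]                           # root stores the total priority mass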
class Memory(object):
    def __init__(self, capacity, state_size=37, epsilon=0.001, alpha=0.4, beta=0.3,
                 beta_increment_per_sampling=0.001, abs_err_upper=1):
        self.tree = SumTree(capacity)
        self.epsilon = epsilon  # avoid zero priority, so every transition keeps a chance of being sampled
        self.alpha = alpha  # trade-off between priority and randomness: alpha = 0 is uniform sampling, alpha = 1 is pure priority
        self.beta = beta  # importance-sampling exponent, annealed from small to 1 so late-training corrections get full weight
        self.beta_increment_per_sampling = beta_increment_per_sampling
        self.abs_err_upper = abs_err_upper  # clipped abs error
        self.state_size = state_size

    # Save an experience in memory
    def store(self, state, action, reward, next_state, done):
        transition = [state, action, reward, next_state, done]
        max_p = np.max(self.tree.tree[-self.tree.capacity:])
        # If there is no priority yet, fall back to the clipped upper bound
        if max_p == 0:
            max_p = self.abs_err_upper
        self.tree.add(max_p, transition)  # new transitions get the current maximum priority

    # Sample n experiences using prioritized experience replay
    def sample(self, n):
        b_idx = np.empty((n,), dtype=np.int32)
        states = np.empty((n, self.state_size))
        actions = np.empty((n,))
        rewards = np.empty((n,))
        next_states = np.empty((n, self.state_size))
        dones = np.empty((n,))
        ISWeights = np.empty((n,))  # IS -> importance sampling
        pri_seg = self.tree.total_p / n  # priority segment
        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])  # anneal beta towards 1
        # min_prob = np.min(self.tree.tree[-self.tree.capacity:]) / self.tree.total_p  # for normalizing ISWeights
        for i in range(n):
            a, b = pri_seg * i, pri_seg * (i + 1)
            v = np.random.uniform(a, b)
            idx, p, data = self.tree.get_leaf(v)
            prob = p / self.tree.total_p
            ISWeights[i] = np.power(prob, -self.beta)
            b_idx[i] = idx
            states[i, :] = data[0]
            actions[i] = data[1]
            rewards[i] = data[2]
            next_states[i, :] = data[3]
            dones[i] = data[4]
        states = torch.from_numpy(states).float().to(device)
        actions = torch.from_numpy(actions.reshape(-1, 1)).long().to(device)
        rewards = torch.from_numpy(rewards.reshape(-1, 1)).float().to(device)
        next_states = torch.from_numpy(next_states).float().to(device)
        dones = torch.from_numpy(dones.reshape(-1, 1).astype(np.uint8)).float().to(device)
        ISWeights = torch.from_numpy(ISWeights.reshape(-1, 1)).float().to(device)
        return b_idx, states, actions, rewards, next_states, dones, ISWeights

    # Update the priorities from the new TD errors
    def batch_update(self, tree_idx, abs_errors):
        abs_errors = np.abs(abs_errors) + self.epsilon  # take abs and avoid zero priority
        clipped_errors = np.minimum(abs_errors, self.abs_err_upper)
        ps = np.power(clipped_errors, self.alpha)
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)

    def __len__(self):
        return self.tree.length()
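# Hypothetical sketch of how the batch returned by Memory.sample would be used in a
# DQN-style update: the ISWeights scale the squared TD errors, and the new absolute
# TD errors are written back with batch_update. q_expected, q_targets, and `memory`
# are stand-ins for illustration, not part of the code above.
import numpy as np
import torch

batch_size = 4
q_expected = torch.rand(batch_size, 1, requires_grad=True)  # Q(s, a) from the online network
q_targets = torch.rand(batch_size, 1)                       # r + gamma * max_a' Q_target(s', a')
is_weights = torch.rand(batch_size, 1)                      # ISWeights from Memory.sample

td_errors = q_targets - q_expected
loss = (is_weights * td_errors.pow(2)).mean()               # importance-weighted MSE
loss.backward()

new_priorities = td_errors.abs().detach().cpu().numpy().flatten()
# memory.batch_update(b_idx, new_priorities)                # feed updated |TD errors| back into the tree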
class PriorityMemory(SimpleMemory):
    PER_e = 0.01  # small constant so no experience ends up with zero probability of being sampled
    PER_a = 0.6   # trade-off between sampling only high-priority experiences and sampling uniformly
    PER_b = 0.4   # importance-sampling exponent, annealed from its initial value to 1
    PER_b_increment_per_sampling = 0.001
    absolute_error_upper = 1.  # clipped abs error

    def __init__(self, obs_dim, act_dim, size, act_dtype):
        SimpleMemory.__init__(self, obs_dim, act_dim, size, act_dtype)
        self.tree = SumTree(size)
        self.tree_lock = Lock()

    def store(self, obs, act, rew, next_obs, done):
        # Find the current maximum priority
        max_priority = np.max(self.tree.tree[-self.tree.capacity:])
        # A new experience must not get priority 0, otherwise it would never be selected,
        # so fall back to the clipped upper bound
        if max_priority == 0:
            max_priority = self.absolute_error_upper
        insertion_pos = super().store(obs, act, rew, next_obs, done)
        self.tree_lock.acquire()
        insertion_pos_tree = self.tree.add(max_priority)  # new experiences get the max priority
        self.tree_lock.release()
        assert insertion_pos == insertion_pos_tree

    def sample_batch(self, batch_size):
        mem_idxs = np.empty((batch_size,), dtype=np.int32)
        tree_idxs = np.empty((batch_size,), dtype=np.int32)
        b_ISWeights = np.empty((batch_size, 1), dtype=np.float32)

        # As in the paper, divide the range [0, p_total] into batch_size segments
        priority_segment = self.tree.total_priority / batch_size
        # Anneal PER_b towards 1 each time a new minibatch is sampled
        self.PER_b = np.min([1., self.PER_b + self.PER_b_increment_per_sampling])

        # The maximum weight corresponds to the minimum priority in the tree
        p_min = self.tree.p_min
        assert p_min > 0
        max_weight = (p_min * batch_size) ** (-self.PER_b)
        assert max_weight > 0

        for i in range(batch_size):
            # Sample a value uniformly from each segment
            a, b = priority_segment * i, priority_segment * (i + 1)
            value = np.random.uniform(a, b)
            # Retrieve the experience that corresponds to this value
            assert self.tree.data_pointer > 0
            self.tree_lock.acquire()
            index, priority = self.tree.get_leaf(value)
            self.tree_lock.release()
            assert priority > 0, "### index {}".format(index)
            # P(j)
            sampling_probabilities = priority / self.tree.total_priority
            # w_j = (1/N * 1/P(j))^b / max_i w_i == (N * P(j))^-b / max_i w_i
            b_ISWeights[i, 0] = batch_size * sampling_probabilities
            assert b_ISWeights[i, 0] > 0
            b_ISWeights[i, 0] = np.power(b_ISWeights[i, 0], -self.PER_b)
            b_ISWeights[i, 0] = b_ISWeights[i, 0] / max_weight
            mem_idxs[i] = index - self.max_size + 1
            tree_idxs[i] = index

        return self.obs1_buf[mem_idxs], \
               self.acts_buf[mem_idxs], \
               self.rews_buf[mem_idxs], \
               self.obs2_buf[mem_idxs], \
               self.done_buf[mem_idxs], \
               tree_idxs, \
               b_ISWeights

    def batch_update(self, tree_idx, abs_errors):
        """Update the priorities in the tree from the new absolute TD errors."""
        abs_errors = np.abs(abs_errors) + self.PER_e  # take abs and avoid zero priority
        clipped_errors = np.minimum(abs_errors, self.absolute_error_upper)
        ps = np.power(clipped_errors, self.PER_a)
        self.tree_lock.acquire()
        for ti, p in zip(tree_idx, ps):
            self.tree.update(ti, p)
        self.tree_lock.release()
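# Numeric illustration of the weight normalization used in sample_batch above:
# w_j = (N * P(j))^(-beta) / max_i w_i, where the maximum weight belongs to the
# transition with the smallest sampling probability. The probabilities below are
# made-up values for the example only.
import numpy as np

batch_size = 4
beta = 0.4
probs = np.array([0.05, 0.10, 0.25, 0.60])      # P(j) for four sampled transitions
p_min = probs.min()

max_weight = (p_min * batch_size) ** (-beta)
weights = (batch_size * probs) ** (-beta) / max_weight

print(weights)  # the rarest transition gets weight 1.0; the others are scaled below 1.0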
class Memory(object):
    """
    This SumTree code is a modified version; the original code is from:
    https://github.com/jaara/AI-blog/blob/master/Seaquest-DDQN-PER.py
    """
    beta = MEMORY_BETA

    def __init__(self):
        self.limit = MEMORY_CAPACITY
        self.err_tree = SumTree(MEMORY_CAPACITY)
        self.action_shape = (0, MEMORY_ACTION_CNT)
        self.reward_shape = (0, MEMORY_REWARD_CNT)
        self.terminal_shape = self.action_shape
        self.observation_shape = (0, MEMORY_CRITIC_FEATURE_NUM)
        self.store_times = 0
        self.Transition = namedtuple(
            'Transition', ('state', 'action', 'reward', 'next_state', 'terminal'))

    def size(self):
        return self.limit if self.store_times > self.limit else self.store_times

    def sample(self, batch_size):
        idxes = np.empty(self.reward_shape, dtype=np.int32)
        isw = np.empty(self.reward_shape, dtype=np.float32)
        obs0 = np.empty(self.observation_shape, dtype=np.float32)
        obs1 = np.empty(self.observation_shape, dtype=np.float32)
        actions = np.empty(self.action_shape, dtype=np.float32)
        rewards = np.empty(self.reward_shape, dtype=np.float32)
        terminals = np.empty(self.terminal_shape, dtype=bool)
        nan_state = np.array([np.nan] * self.observation_shape[1])
        self.beta = np.min([1., self.beta + MEMORY_BETA_INC_RATE])  # max = 1
        max_td_err = np.max(self.err_tree.tree[-self.err_tree.capacity:])
        idx_set = set()
        # Sample at most batch_size * 2 times to collect batch_size distinct instances
        for i in range(batch_size * 2):
            v = np.random.uniform(0, self.err_tree.total_p)
            idx, td_err, trans = self.err_tree.get_leaf(v)
            if batch_size == len(idx_set):
                break
            if idx not in idx_set:
                idx_set.add(idx)
            else:
                continue
            if (trans.state == 0).all():
                continue
            idxes = np.row_stack((idxes, np.array([idx])))
            isw = np.row_stack((isw, np.array([
                np.power(self._getPriority(td_err) / max_td_err, -self.beta)
            ])))
            obs0 = np.row_stack((obs0, trans.state))
            obs1 = np.row_stack(
                (obs1, nan_state if trans.terminal.all() else trans.next_state))
            actions = np.row_stack((actions, trans.action))
            rewards = np.row_stack((rewards, trans.reward))
            terminals = np.row_stack((terminals, trans.terminal))
        result = {
            'obs0': array_min2d(obs0),
            'actions': array_min2d(actions),
            'rewards': array_min2d(rewards),
            'obs1': array_min2d(obs1),
            'terminals': array_min2d(terminals),
        }
        return idxes, result, isw

    def _getPriority(self, error):
        return (error + EPSILON) ** MEMORY_ALPHA

    def append(self, obs0, action, reward, obs1, terminal, err, training=True):
        if not training:
            return
        trans = self.Transition(obs0, action, reward, obs1, terminal)
        self.err_tree.add(self._getPriority(err), trans)
        self.store_times += 1

    def batch_update(self, tree_idx, errs):
        errs = np.abs(errs) + EPSILON  # take abs and avoid zero priority
        ps = np.power(errs, MEMORY_ALPHA)
        for ti, p in zip(tree_idx, ps):
            self.err_tree.update(ti, p[0])

    @property
    def nb_entries(self):
        return self.store_times
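# The Memory class above reads its hyperparameters and shapes from module-level
# constants defined elsewhere. The values below are illustrative assumptions only,
# not the original configuration.
MEMORY_CAPACITY = 100000        # maximum number of stored transitions
MEMORY_BETA = 0.4               # initial importance-sampling exponent
MEMORY_BETA_INC_RATE = 0.001    # beta annealing step per sample() call
MEMORY_ALPHA = 0.6              # priority exponent
EPSILON = 0.01                  # small constant added to the TD error before exponentiation
MEMORY_ACTION_CNT = 1           # action dimensionality (placeholder)
MEMORY_REWARD_CNT = 1           # reward dimensionality (placeholder)
MEMORY_CRITIC_FEATURE_NUM = 8   # observation feature count (placeholder)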