import copy
import random
from collections import deque, namedtuple

import numpy as np
import torch

# `device` and `SumTree` are assumed to be defined elsewhere in this repository,
# e.g. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class ReplayBuffer:
    """Fixed-size buffer to store experience tuples."""

    def __init__(self, buffer_size, batch_size, td_eps, seed, p_replay_alpha,
                 reward_scale=False, error_clip=False, error_max=1.0,
                 error_init=False, use_tree=False, err_init=1.0):
        """Initialize a ReplayBuffer object.

        Params
        ======
            buffer_size (int): maximum size of buffer
            batch_size (int): size of each training batch
            td_eps (float): small constant added to avoid a zero TD error
            seed (int): random seed
            p_replay_alpha (float): exponent applied to the TD error for priority sampling
            reward_scale (bool): scale the reward down by a factor of 10
            error_clip (bool): clip the TD error to [-error_max, error_max]
            error_max (float): clipping bound and initial priority for new experiences
            error_init (bool): give new experiences the current maximum priority
            use_tree (bool): use the SumTree backend instead of the deque
            err_init (float): unused
        """
        self.use_tree = use_tree
        self.memory = deque(maxlen=buffer_size)
        self.tree = SumTree(buffer_size)  # create tree instance
        self.batch_size = batch_size
        self.buffer_size = buffer_size
        self.td_eps = td_eps
        self.experience = namedtuple("Experience", field_names=[
            "state", "action", "reward", "next_state", "done", "td_error"])
        self.seed = random.seed(seed)
        self.p_replay_alpha = p_replay_alpha
        self.reward_scale = reward_scale
        self.error_clip = error_clip
        self.error_init = error_init
        self.error_max = error_max
        self.memory_index = np.zeros([self.buffer_size, 1])  # priority array, for quicker calculation
        self.memory_pointer = 0

    def add(self, state, action, reward, next_state, done, td_error):
        """Add a new experience to memory.

        td_error: absolute TD error of the new experience
        """
        # reward scaling
        if self.reward_scale:
            reward = reward / 10.0  # scale reward down by a factor of 10

        # error clipping
        if self.error_clip:
            td_error = np.clip(td_error, -self.error_max, self.error_max)

        # apply alpha power to turn the TD error into a priority
        td_error = (td_error ** self.p_replay_alpha) + self.td_eps

        # make sure every experience is visited at least once:
        # initialise it with the current maximum priority
        if self.error_init:
            td_max = np.max(self.memory_index)
            td_error = self.error_max if td_max == 0 else td_max

        e = self.experience(np.expand_dims(state, 0), action, reward,
                            np.expand_dims(next_state, 0), done, td_error)

        if self.use_tree:
            self.tree.add(td_error, e)  # store the priority and the experience data
        else:
            self.memory.append(e)
            ### memory index (priority array kept in sync with the deque) ###
            if self.memory_pointer >= self.buffer_size:
                self.memory_index = np.roll(self.memory_index, -1)
                self.memory_index[-1] = td_error  # FIFO
            else:
                self.memory_index[self.memory_pointer] = td_error
                self.memory_pointer += 1

    def update(self, td_updated, index):
        """Update the TD error values while preserving the order of the memory.

        td_updated: absolute values; np.array of shape (1, batch_size, 1)
        index: positions in the deque, or leaf indices when the tree is used
        """
        td_updated = td_updated.squeeze()  # (batch_size,)

        # error clipping
        if self.error_clip:
            td_updated = np.clip(td_updated, -self.error_max, self.error_max)

        # apply alpha power to turn the TD errors into priorities
        td_updated = (td_updated ** self.p_replay_alpha) + self.td_eps

        # debug check (disabled): snapshot the memory to verify it stays in sync
        # tmp_memory = copy.deepcopy(self.memory)

        i = 0
        while i < len(index):
            if self.use_tree:
                self.tree.update(index[i], td_updated[i])
            else:
                self.memory.rotate(-index[i])  # move the target index to the front
                e = self.memory.popleft()
                td_i = td_updated[i].reshape(1, 1)
                e1 = self.experience(e.state, e.action, e.reward,
                                     e.next_state, e.done, td_i)
                self.memory.appendleft(e1)  # append the updated experience
                self.memory.rotate(index[i])  # restore the original order
                ### memory index ###
                self.memory_index[index[i]] = td_i
            i += 1  # increment

        # debug checks (disabled): the deque and the priority array stay in sync
        # for i in range(len(self.memory)):
        #     assert self.memory_index[i] == self.memory[i].td_error
        #     if i in index:
        #         assert td_updated[list(index).index(i)] == self.memory[i].td_error
        #     else:
        #         assert tmp_memory[i].td_error == self.memory[i].td_error

    def sample(self, p_replay_beta):
        """Sample a batch of experiences from memory (deque backend)."""
        l = len(self.memory)
        p_dist = (self.memory_index[:l] / np.sum(self.memory_index[:l])).squeeze()
        assert np.abs(np.sum(p_dist) - 1) < 1e-5
        assert len(p_dist) == l

        # draw a sample of indices from the priority distribution
        sample_ind = np.random.choice(l, self.batch_size, p=p_dist)

        # debug check (disabled): snapshot the memory to verify that the
        # rotations below leave its order intact
        # tmp_memory = copy.deepcopy(self.memory)

        # gather the selected experiences; rotate the deque to avoid slow
        # mid-list indexing
        es, ea, er, en, ed = [], [], [], [], []
        for i in sample_ind:
            self.memory.rotate(-i)
            e = copy.deepcopy(self.memory[0])
            es.append(e.state)
            ea.append(e.action)
            er.append(e.reward)
            en.append(e.next_state)
            ed.append(e.done)
            self.memory.rotate(i)

        # debug check (disabled):
        # for i in range(len(tmp_memory)):
        #     assert tmp_memory[i].td_error == self.memory[i].td_error

        states = torch.from_numpy(np.vstack(es)).float().to(device)
        actions = torch.from_numpy(np.vstack(ea)).long().to(device)
        rewards = torch.from_numpy(np.vstack(er)).float().to(device)
        next_states = torch.from_numpy(np.vstack(en)).float().to(device)
        dones = torch.from_numpy(np.vstack(ed).astype(np.uint8)).float().to(device)

        # importance-sampling weight adjustment
        selected_td_p = p_dist[sample_ind]  # sampling probability of each selected experience

        # sanity check: when prioritisation is active, the mean priority of the
        # selected experiences should not fall below the memory average
        if p_replay_beta > 0:
            if np.mean(self.memory_index[sample_ind]) < np.mean(self.memory_index[:l]):
                print(np.mean(self.memory_index[sample_ind]),
                      np.mean(self.memory_index[:l]))

        # w_i = (1/N * 1/P(i)) ** beta, normalised by the maximum weight
        # equivalent formulation:
        # weight = (np.array(selected_td_p) * l) ** -p_replay_beta
        # max_weight = (np.min(selected_td_p) * self.batch_size) ** -p_replay_beta
        weight = (1 / selected_td_p * 1 / l) ** p_replay_beta
        weight = weight / np.max(weight)  # normalise by the max weight
        weight = torch.from_numpy(np.array(weight)).float().to(device)
        assert weight.requires_grad == False

        return (states, actions, rewards, next_states, dones), weight, sample_ind

    def sample_tree(self, p_replay_beta):
        """Sample a batch of experiences from the SumTree backend."""
        # containers for the minibatch
        e_s, e_a, e_r, e_n, e_d = [], [], [], [], []
        sample_ind = np.empty((self.batch_size,), dtype=np.int32)
        sampled_td_score = np.empty((self.batch_size, 1))
        weight = np.empty((self.batch_size, 1))

        # Calculate the priority segment: as in the PER paper, divide the
        # range [0, p_total] into batch_size equal segments
        td_score_segment = self.tree.total_td_score / self.batch_size

        i = 0
        while i < self.batch_size:
            # a value is uniformly sampled from each segment
            a, b = td_score_segment * i, td_score_segment * (i + 1)
            value = np.random.uniform(a, b)

            # retrieve the experience whose cumulative priority covers this value
            leaf_index, td_score, data = self.tree.get_leaf(value)

            # P(j)
            sampling_p = td_score / self.tree.total_td_score
            sampled_td_score[i, 0] = td_score

            # IS weight: w_i = (1/N * 1/P(i)) ** beta == (N * P(i)) ** -beta,
            # normalised by the maximum weight below
            weight[i, 0] = (1 / self.buffer_size * 1 / sampling_p) ** p_replay_beta

            sample_ind[i] = leaf_index
            e_s.append(data.state)
            e_a.append(data.action)
            e_r.append(data.reward)
            e_n.append(data.next_state)
            e_d.append(data.done)
            i += 1  # increment

        # alternative: normalise by the maximum possible weight over the whole tree
        # p_min = np.min(self.tree.tree[-self.buffer_size:]) / self.tree.total_td_score
        # if p_min == 0:
        #     p_min = self.td_eps  # avoid division by zero
        # max_weight = (1 / p_min * 1 / self.buffer_size) ** p_replay_beta

        # apply max weight adjustment: normalise by the maximum weight in the batch
        max_weight = np.max(weight)
        weight = weight / max_weight

        # debug check (disabled):
        # assert np.mean(sampled_td_score) >= np.mean(self.tree.tree[-self.buffer_size:])

        states = torch.from_numpy(np.vstack(e_s)).float().to(device)
        actions = torch.from_numpy(np.vstack(e_a)).long().to(device)
        rewards = torch.from_numpy(np.vstack(e_r)).float().to(device)
        next_states = torch.from_numpy(np.vstack(e_n)).float().to(device)
        dones = torch.from_numpy(np.vstack(e_d).astype(np.uint8)).float().to(device)
        weight = torch.from_numpy(weight).float().to(device)
        assert weight.requires_grad == False

        return (states, actions, rewards, next_states, dones), weight, sample_ind

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.memory)
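# ---------------------------------------------------------------------------
# The buffers in this file rely on a SumTree class that is defined elsewhere in
# the repository. What follows is only a minimal sketch of the interface they
# assume (add, update, get_leaf, total_td_score, and a flat `tree` array whose
# last `capacity` entries are the leaf priorities); the repository's actual
# implementation may differ in its details.
# ---------------------------------------------------------------------------

class SumTreeSketch:
    """Binary sum tree: leaves hold priorities, internal nodes hold their sums."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)        # internal nodes + leaves
        self.data = np.zeros(capacity, dtype=object)  # stored experiences
        self.data_pointer = 0

    @property
    def total_td_score(self):
        return self.tree[0]  # the root holds the sum of all priorities

    def add(self, td_score, data):
        """Store an experience with its priority, overwriting FIFO when full."""
        tree_index = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_index, td_score)
        self.data_pointer = (self.data_pointer + 1) % self.capacity

    def update(self, tree_index, td_score):
        """Set a leaf priority and propagate the change up to the root."""
        change = td_score - self.tree[tree_index]
        self.tree[tree_index] = td_score
        while tree_index != 0:
            tree_index = (tree_index - 1) // 2
            self.tree[tree_index] += change

    def get_leaf(self, value):
        """Descend to the leaf whose cumulative priority interval covers value."""
        parent = 0
        while True:
            left, right = 2 * parent + 1, 2 * parent + 2
            if left >= len(self.tree):  # reached a leaf
                leaf = parent
                break
            if value <= self.tree[left]:
                parent = left
            else:
                value -= self.tree[left]
                parent = right
        data_index = leaf - self.capacity + 1
        return leaf, self.tree[leaf], self.data[data_index]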
class ReplayBuffer:
    """Replay buffer with an optional prioritized (SumTree) backend."""

    def __init__(self, buffer_size, num_agents, state_size, action_size, use_PER=False):
        self.buffer_size = buffer_size
        self.use_PER = use_PER
        self.num_agents = num_agents
        self.state_size = state_size
        self.action_size = action_size
        if use_PER:
            self.tree = SumTree(buffer_size)  # create tree instance
        else:
            self.memory = deque(maxlen=buffer_size)
        self.leaves_count = 0

    def add_tree(self, data, td_default=1.0):
        """PER function. Add a new experience to memory.

        The experience is stored with the current maximum priority in the tree
        (or td_default if the tree is empty) so that it is sampled at least once.
        """
        td_max = np.max(self.tree.tree[-self.buffer_size:])
        if td_max == 0.0:
            td_max = td_default
        self.tree.add(td_max, data)  # maximise its chance of being selected
        self.leaves_count = min(self.leaves_count + 1, self.buffer_size)

    def add(self, data):
        """Add an experience into the (uniform) buffer."""
        self.memory.append(data)

    def sample_tree(self, batch_size, p_replay_beta, td_eps=1e-4):
        """PER function. Stratified (segment-wise) sampling from the SumTree."""
        s_samp, a_samp, r_samp, d_samp, ns_samp = ([] for _ in range(5))
        sample_ind = np.empty((batch_size,), dtype=np.int32)
        weight = np.empty((batch_size, 1))

        # split the total priority mass into batch_size equal segments
        td_score_segment = self.tree.total_td_score / batch_size

        for i in range(batch_size):
            # a value is uniformly sampled from each segment
            _start, _end = i * td_score_segment, (i + 1) * td_score_segment
            value = np.random.uniform(_start, _end)

            # get the experience whose cumulative priority covers this value
            leaf_index, td_score, data = self.tree.get_leaf(value)

            # sampling probability of this experience over all priorities
            sampling_p = td_score / self.tree.total_td_score

            # importance-sampling weight adjustment
            weight[i, 0] = (1 / sampling_p * 1 / self.leaves_count) ** p_replay_beta

            sample_ind[i] = leaf_index
            s_samp.append(data.states)
            a_samp.append(data.actions)
            r_samp.append(data.rewards)
            d_samp.append(data.dones)
            ns_samp.append(data.next_states)

        # alternative: normalise by the maximum weight over the entire memory
        # p_max = np.max(self.tree.tree[-self.buffer_size:]) / self.tree.total_td_score
        # if p_max == 0:
        #     p_max = td_eps  # avoid division by zero
        # max_weight = (1 / p_max * 1 / self.leaves_count) ** p_replay_beta
        # max_weight = np.max(weight)

        weight_n = toTorch(weight)  # normalisation by max_weight left disabled (see above)

        return (s_samp, a_samp, r_samp, d_samp, ns_samp, weight_n, sample_ind)

    def sample(self, batch_size):
        """Uniformly sample from the buffer."""
        sample_ind = np.random.choice(len(self.memory), batch_size)
        s_samp, a_samp, r_samp, d_samp, ns_samp = ([] for _ in range(5))
        i = 0
        while i < batch_size:  # a while loop is faster here
            self.memory.rotate(-sample_ind[i])
            e = self.memory[0]
            s_samp.append(e.states)
            a_samp.append(e.actions)
            r_samp.append(e.rewards)
            d_samp.append(e.dones)
            ns_samp.append(e.next_states)
            self.memory.rotate(sample_ind[i])
            i += 1

        # the last two values keep the signature compatible with the PER path
        return (s_samp, a_samp, r_samp, d_samp, ns_samp, 1.0, [])

    def update_tree(self, td_updated, index, p_replay_alpha, td_eps=1e-4):
        """PER function. Update the TD error values while preserving order.

        td_updated: absolute values; np.array of shape (1, batch_size, 1)
        index: leaf indices returned by sample_tree
        """
        # apply alpha power to turn the TD errors into priorities
        td_updated = (td_updated.squeeze() ** p_replay_alpha) + td_eps
        for i in range(len(index)):
            self.tree.update(index[i], td_updated[i])

    def __len__(self):
        if self.use_PER:
            return self.leaves_count
        return len(self.memory)
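# ---------------------------------------------------------------------------
# A minimal usage sketch for the PER path of the buffer above, assuming the
# repository's SumTree and toTorch helpers are available in scope. The
# Experience namedtuple, the sizes (2 agents, 24-d states, 2-d actions) and the
# hyperparameter values are illustrative assumptions only; add_tree,
# sample_tree and update_tree are the calls actually provided by the class.
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    from collections import namedtuple
    import numpy as np

    Experience = namedtuple(
        "Experience", ["states", "actions", "rewards", "dones", "next_states"])

    buffer = ReplayBuffer(buffer_size=10000, num_agents=2,
                          state_size=24, action_size=2, use_PER=True)

    # fill the buffer with dummy transitions; new entries get the max priority
    for _ in range(100):
        e = Experience(states=np.random.randn(2, 24),
                       actions=np.random.randn(2, 2),
                       rewards=np.zeros(2),
                       dones=np.zeros(2),
                       next_states=np.random.randn(2, 24))
        buffer.add_tree(e)

    # stratified sample with importance-sampling weights (beta is typically
    # annealed from ~0.4 towards 1 over the course of training)
    s, a, r, d, ns, weights, leaf_ind = buffer.sample_tree(
        batch_size=32, p_replay_beta=0.4)

    # after the learning step, push the new absolute TD errors back into the tree
    new_td_errors = np.abs(np.random.randn(1, 32, 1))  # stand-in for |delta| from the critic
    buffer.update_tree(new_td_errors, leaf_ind, p_replay_alpha=0.6)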