def __init__(self, conf):
    """Initialise the replay buffer state from the ``conf`` mapping.

    ``conf['size']`` is mandatory; every other key falls back to a default.

    :param conf: mapping with the keys ``size`` and optionally
        ``replace_old``, ``priority_size``, ``alpha``, ``beta_zero``,
        ``batch_size``, ``learn_start``, ``steps``, ``partition_num`` and
        ``partition_size``
    """
    self.size = conf['size']
    # Whether stored transitions may be overwritten once the buffer is full.
    self.replace_flag = conf.get('replace_old', True)
    # Maximum capacity handed to the priority heap.
    self.priority_size = conf.get('priority_size', self.size)

    # Exponent alpha from Equation (1) of the PER paper.
    self.alpha = conf.get('alpha', 0.7)
    # Initial value of the bias-correction exponent; annealed linearly
    # towards 1 (Section 3.4 of the paper).
    self.beta_zero = conf.get('beta_zero', 0.5)
    self.batch_size = conf.get('batch_size', 32)
    self.learn_start = conf.get('learn_start', 32)
    self.total_steps = conf.get('steps', 100000)

    # The total size is split into N partitions. An explicit partition
    # size wins over the partition count, which is then derived from it.
    self.partition_num = conf.get('partition_num', 100)
    if 'partition_size' in conf:
        self.partition_size = conf['partition_size']
        self.partition_num = math.floor(self.size / self.partition_size)
    else:
        self.partition_size = math.floor(self.size / self.partition_num)

    # Bookkeeping for insertion.
    self.index = 0
    self.record_size = 0
    self.isFull = False
    self._experience = {}
    self.priority_queue = BinaryHeap(self.priority_size)

    # Most recently built sampling distribution and the partition index
    # it was built for.
    self.distribution = None
    self.dist_index = 1

    # Per-step increment used to anneal beta from beta_zero to 1.
    annealing_steps = self.total_steps - self.learn_start
    self.beta_grad = (1 - self.beta_zero) / annealing_steps
class Experience(object):
    """Rank-based prioritized experience replay buffer.

    Transitions are kept in a dict keyed by a 1-based experience id, while
    their priorities are tracked by a ``BinaryHeap``. Batches are drawn by
    stratified sampling over a rank-based distribution
    ``P(i) = rank(i) ** -alpha / sum_j rank(j) ** -alpha`` (PER paper).
    """

    def __init__(self, conf):
        """Initialise the buffer from the ``conf`` mapping.

        ``conf['size']`` is mandatory; every other key has a default.

        :param conf: mapping with the keys ``size`` and optionally
            ``replace_old``, ``priority_size``, ``alpha``, ``beta_zero``,
            ``batch_size``, ``learn_start``, ``steps``, ``partition_num``
            and ``partition_size``
        """
        self.size = conf['size']
        # Whether stored transitions may be overwritten once the buffer is
        # full.
        self.replace_flag = conf.get('replace_old', True)
        # Maximum capacity handed to the priority heap.
        self.priority_size = conf.get('priority_size', self.size)
        # Exponent alpha from Equation (1) of the PER paper.
        self.alpha = conf.get('alpha', 0.7)
        # Initial bias-correction exponent, annealed linearly towards 1
        # (Section 3.4 of the paper).
        self.beta_zero = conf.get('beta_zero', 0.5)
        self.batch_size = conf.get('batch_size', 32)
        self.learn_start = conf.get('learn_start', 32)
        self.total_steps = conf.get('steps', 100000)

        # The total size is split into N partitions. An explicit partition
        # size wins over the partition count, which is then derived from it.
        self.partition_num = conf.get('partition_num', 100)
        if 'partition_size' in conf:
            self.partition_size = conf['partition_size']
            self.partition_num = math.floor(self.size / self.partition_size)
        else:
            self.partition_size = math.floor(self.size / self.partition_num)

        self.index = 0
        self.record_size = 0
        self.isFull = False
        self._experience = {}
        self.priority_queue = BinaryHeap(self.priority_size)

        # Most recently built sampling distribution and the partition index
        # it was built for.
        self.distribution = None
        self.dist_index = 1

        # Per-step increment used to anneal beta from beta_zero to 1.
        self.beta_grad = (1 - self.beta_zero) / (self.total_steps -
                                                 self.learn_start)

        # Debug code
        self.debug = {}

    def _build_distribution(self, n):
        """Build the rank-based pdf and stratum boundaries for ``n`` ranks.

        :param n: number of ranks covered by the distribution
        :return: dict with ``'pdf'`` (list of ``n`` probabilities, index 0
            is rank 1) and ``'strata_ends'`` (dict mapping segment number
            ``1..batch_size + 1`` to a rank boundary)
        """
        # P(i) = (rank i) ^ (-alpha) / sum ((rank i) ^ (-alpha))
        raw = [math.pow(rank, -self.alpha) for rank in range(1, n + 1)]
        total = math.fsum(raw)
        pdf = [p / total for p in raw]

        # Split the ranks into batch_size segments that each hold roughly
        # 1 / batch_size of the probability mass; one sample is later drawn
        # from every segment (stratified sampling, PER paper page 13).
        cdf = np.cumsum(pdf)
        strata_ends = {1: 0, self.batch_size + 1: n}
        step = 1 / self.batch_size
        index = 1
        for segment in range(2, self.batch_size + 1):
            while cdf[index] < step:
                index += 1
            strata_ends[segment] = index
            step += 1 / self.batch_size

        return {'pdf': pdf, 'strata_ends': strata_ends}

    def return_distribution(self, dist_index):
        """Return the distribution for ``dist_index``, building it if needed.

        The last built distribution is cached and reused while ``dist_index``
        does not change. If ``dist_index * partition_size`` lies outside
        ``[batch_size, priority_size]`` nothing is built and the previously
        cached distribution (possibly ``None``) is returned.

        :param dist_index: partition index the distribution should cover
        :return: dict with ``'pdf'`` and ``'strata_ends'``, or ``None``
        :raises Exception: if ``dist_index`` is smaller than the index of
            the cached distribution
        """
        if dist_index == self.dist_index and self.distribution is not None:
            # Cache hit. Checking for a built distribution (instead of
            # dist_index > 1) also avoids rebuilding on every call while
            # the buffer is still in its first partition.
            return self.distribution
        if dist_index < self.dist_index:
            raise ValueError(
                'Elements have been illegally deleted from the priority_queue in rank_based'
            )

        # Remember which partition the cache now belongs to.
        self.dist_index = dist_index
        n = dist_index * self.partition_size
        if self.batch_size <= n <= self.priority_size:
            self.distribution = self._build_distribution(n)
        return self.distribution

    def fix_index(self):
        """
        get next insert index
        :return: index, int; -1 if the buffer is full and replacing is
            disabled
        """
        # record_size grows until it reaches size and then stays there.
        # (The previous `<=` let it overshoot to size + 1.)
        if self.record_size < self.size:
            self.record_size += 1

        if self.index % self.size == 0:
            # Reached both on the very first insert (index == 0) and when
            # the write position has arrived at the last slot.
            self.isFull = len(self._experience) == self.size
            if self.isFull and not self.replace_flag:
                sys.stderr.write(
                    'Experience replay buff is full and replace is set to FALSE!\n'
                )
                return -1
            # Start (or wrap around) at slot 1. Previously an empty buffer
            # with replace disabled was also reported as full here, so
            # nothing could ever be stored.
            self.index = 1
            return self.index

        self.index += 1
        return self.index

    def store(self, experience):
        """
        store experience, suggest that experience is a tuple of (s1, a, r, s2, g, t)
        so each experience is valid
        :param experience: maybe a tuple, or list
        :return: bool, indicate insert status
        """
        insert_index = self.fix_index()
        if insert_index > 0:
            # Overwrites any previous experience stored under this id.
            self._experience[insert_index] = experience
            # Debug marker kept for dict experiences only; tuples and lists
            # cannot take a string key and used to raise TypeError here.
            if isinstance(experience, dict):
                experience['new'] = True
            # New experiences get the current maximum priority so that they
            # are picked as soon as possible.
            priority = self.priority_queue.get_max_priority()
            self.priority_queue.update(priority, insert_index)
            return True

        # Happens when replace is disabled and the buffer is already full.
        sys.stderr.write('Insert failed\n')
        return False

    def retrieve(self, indices):
        """
        get experience from indices
        :param indices: list of experience id
        :return: experience replay sample
        """
        return [self._experience[v] for v in indices]

    def rebalance(self):
        """
        rebalance priority queue
        :return: None
        """
        # Balance the heap from scratch.
        self.priority_queue.balance_tree()

    def update_priority(self, indices, delta):
        """
        update priority according indices and deltas
        :param indices: list of experience id
        :param delta: list of delta, order correspond to indices
        :return: None
        """
        # The new priority of each experience is the magnitude of its delta.
        for e_id, d in zip(indices, delta):
            self.priority_queue.update(math.fabs(d), e_id)

    def sample(self, global_step, uniform_priority=False, batch_size=32):
        """
        sample a mini batch from experience replay
        :param global_step: now training step
        :param uniform_priority: if true, draw ranks uniformly and return
            all-ones weights (no bias correction)
        :param batch_size: accepted for interface compatibility but not
            used; the batch size given at construction always applies
            because the strata are built for it
        :return: experience, list, samples
        :return: w, list, weights
        :return: rank_e_id, list, samples id, used for update priority
            (all three are False if fewer than learn_start records exist)
        """
        if self.record_size < self.learn_start:
            # Require a minimum number of stored experiences before any
            # learning starts, to avoid unstable estimates.
            sys.stderr.write(
                'Record size less than learn start! Sample failed\n')
            return False, False, False

        # While the buffer is filling up only a fraction of the partitions
        # has ranks assigned; once full this is the last partition.
        dist_index = math.floor(self.record_size / self.size *
                                self.partition_num)
        # Use the same partition size that the distribution is built with.
        partition_max = dist_index * self.partition_size

        distribution = self.return_distribution(dist_index)
        strata_ends = distribution['strata_ends']

        if uniform_priority:
            # Draw uniformly over every rank 1..n. The upper bound used to
            # be strata_ends[batch_size] + 1, which left out the last
            # stratum entirely.
            upper = strata_ends[self.batch_size + 1]
            rank_list = [
                random.randint(1, upper) for _ in range(self.batch_size)
            ]
            # No bias correction when sampling uniformly.
            w = np.ones(self.batch_size)
        else:
            # Stratified sampling: one rank from each segment.
            rank_list = []
            for segment in range(1, self.batch_size + 1):
                low = strata_ends[segment] + 1
                high = strata_ends[segment + 1]
                if low >= high:
                    rank_list.append(high)
                else:
                    rank_list.append(random.randint(low, high))

            # Linear annealing of beta (Section 3.4), capped at 1.
            beta = min(
                self.beta_zero +
                (global_step - self.learn_start - 1) * self.beta_grad, 1)
            # pdf is a 0-based list while ranks are 1-based.
            alpha_pow = [distribution['pdf'][v - 1] for v in rank_list]
            # w = (N * P(i)) ^ (-beta) / max w
            w = np.power(np.array(alpha_pow) * partition_max, -beta)
            w = np.divide(w, w.max())

        # Ranks are positions in the priority queue; convert them to
        # experience ids and fetch the stored experiences.
        rank_e_id = self.priority_queue.priority_to_experience(rank_list)
        experience = self.retrieve(rank_e_id)
        return experience, w, rank_e_id