import math

import numpy as np


class Memory(object):

    def __init__(self, batch_size, max_size, beta):
        self.batch_size = batch_size  # mini-batch size
        # round the capacity down to a power of two so the sum tree is a complete binary tree
        self.max_size = 2**math.floor(math.log2(max_size))
        self.beta = beta
        self._sum_tree = SumTree(self.max_size)

    def store_transition(self, s, a, r, s_, done):
        self._sum_tree.add((s, a, r, s_, done))

    def get_mini_batches(self):
        n_sample = self.batch_size if self._sum_tree.size >= self.batch_size else self._sum_tree.size
        total = self._sum_tree.get_total()

        # split the total priority into n_sample segments and draw one value per segment
        step = total // n_sample
        points_transitions_probs = []
        for i in range(n_sample):
            v = np.random.uniform(i * step, (i + 1) * step - 1)
            t = self._sum_tree.sample(v)
            points_transitions_probs.append(t)

        points, transitions, probs = zip(*points_transitions_probs)

        # importance-sampling ratios, normalised by the maximum possible ratio
        max_importance_ratio = (n_sample * self._sum_tree.get_min())**-self.beta
        importance_ratio = [(n_sample * probs[i])**-self.beta / max_importance_ratio
                            for i in range(len(probs))]

        return points, tuple(np.array(e) for e in zip(*transitions)), importance_ratio

    def update(self, points, td_error):
        for i in range(len(points)):
            self._sum_tree.update(points[i], td_error[i])
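# None of the SumTree classes these buffers depend on are reproduced in this
# section, and each implementation assumes a slightly different interface.
# The sketch below is only a minimal, generic sum tree for reference: the
# method names, array layout, and the (index, priority, data) return shape
# are assumptions, not the exact API expected by the Memory class above.
import numpy as np


class SumTree:
    """Binary sum tree: leaves hold priorities, each internal node holds the sum
    of its children, so prefix-sum sampling and updates are O(log n).
    Generic reference sketch only; the buffers in this section each expect
    their own variant of this class."""

    def __init__(self, capacity):
        self.capacity = capacity                  # number of leaves
        self.tree = np.zeros(2 * capacity - 1)    # internal nodes followed by leaves
        self.data = [None] * capacity             # transitions stored alongside the leaves
        self.write = 0                            # next leaf to overwrite (circular)
        self.size = 0

    def total(self):
        return self.tree[0]                       # root holds the sum of all priorities

    def add(self, priority, data):
        leaf = self.write + self.capacity - 1
        self.data[self.write] = data
        self.update(leaf, priority)
        self.write = (self.write + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def update(self, leaf, priority):
        change = priority - self.tree[leaf]
        self.tree[leaf] = priority
        while leaf != 0:                          # propagate the change up to the root
            leaf = (leaf - 1) // 2
            self.tree[leaf] += change

    def sample(self, value):
        """Descend from the root to the leaf whose cumulative priority covers `value`."""
        idx = 0
        while 2 * idx + 1 < len(self.tree):       # stop once idx is a leaf
            left, right = 2 * idx + 1, 2 * idx + 2
            if value <= self.tree[left]:
                idx = left
            else:
                value -= self.tree[left]
                idx = right
        return idx, self.tree[idx], self.data[idx - self.capacity + 1]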
import numpy as np


class Memory(object):

    def __init__(self, capacity, batch_size):
        self.capacity = capacity
        self.batch_size = batch_size
        self.tree = SumTree(capacity=capacity)
        self.alpha = 0.6        # priority exponent
        self.beta = 0.4         # importance-sampling exponent
        self.p_epsilon = 1e-4   # small constant so no priority is exactly zero

    def _get_priority(self, priorities):
        priorities = priorities + self.p_epsilon
        priorities = np.minimum(priorities, 1.0)  # clip priorities to 1.0
        priorities = np.power(priorities, self.alpha)
        return priorities

    def store(self, transition):
        # new transitions get the current maximum priority so they are sampled at least once
        max_p = np.max(self.tree.tree[-self.capacity:])
        if max_p == 0:
            max_p = 1.0
        self.tree.add(transition, max_p)

    def sample(self):
        # split the total priority into batch_size equal segments and draw one sample per segment
        avg_p = self.tree.total_p() / self.batch_size
        batch_tree_idx, batch_p, batch_data = [], [], []
        for i in range(self.batch_size):
            a, b = avg_p * i, avg_p * (i + 1)
            s = np.random.uniform(a, b)
            tree_idx, p, data = self.tree.sample(s)
            batch_tree_idx.append(tree_idx)
            batch_p.append(p)
            batch_data.append(data)

        # importance-sampling weights, normalised by the largest weight in the batch
        batch_p = np.array(batch_p) / self.tree.total_p()
        batch_weight = np.power(batch_p * self.capacity, -self.beta)
        batch_weight = batch_weight / max(batch_weight)
        batch_tree_idx, batch_data, batch_weight = map(
            np.array, [batch_tree_idx, batch_data, batch_weight])
        return batch_tree_idx, batch_data, batch_weight

    def update(self, tree_idx, priorities):
        priorities = self._get_priority(priorities)
        for index, p in zip(tree_idx, priorities):
            self.tree.update(index, p)
import math


class PrioritizeReplayBuffer(ReplayBuffer):
    """Prioritized experience replay."""

    def __init__(
        self,
        buffer_size,
        batch_size,
        seed,
        beta_start=0.4,
        delta_beta=1e-5,
        alpha=0.6,
        eps=1e-8,
    ):
        """Initialize PER.

        Args:
            buffer_size (int): Size of replay buffer. The actual size will be
                the first power of 2 greater than buffer_size.
            batch_size (int): Size of batches to draw.
            seed (float): Seed.
            beta_start (float): Initial value for beta (importance sampling exponent).
            delta_beta (float): Beta increment at each time step.
            alpha (float): Priority exponent.
            eps (float): Small positive number added to priorities so that
                zero-TD-error transitions keep a non-zero sampling probability.
        """
        # Depth of sum tree
        depth = int(math.log2(buffer_size)) + 1
        super(PrioritizeReplayBuffer, self).__init__(2**depth, batch_size, seed)

        # Initialize sum tree to keep track of the sum of priorities
        self.priorities = SumTree(depth)
        # Current max priority
        self.max_p = 1.0
        # PER parameters
        self.alpha = alpha
        self.eps = eps
        self.beta = beta_start
        self.delta_beta = delta_beta

    def add(self, state, action, reward, next_state, done):
        """Add transition inside the replay buffer."""
        # Add in the sum tree with current max priority
        self.priorities.add(self.max_p, self.index)
        super().add(state, action, reward, next_state, done)

    def sample(self):
        """Get sample."""
        # Get indices to sample from sum tree.
        # Store these indices to compute importance sampling later.
        self.last_indices = self.priorities.sample(self.batch_size)
        # Return transitions corresponding to these indices
        return [self.data[i] for i in self.last_indices]

    def update_priorities(self, td_error):
        """Update priorities."""
        # Compute new priorities
        new_priorities = (abs(td_error) + self.eps)**self.alpha
        # Update sum tree
        self.priorities.update(self.last_indices, new_priorities)
        # Update the current max priority
        self.max_p = max(self.max_p, max(new_priorities))

    def importance_sampling(self):
        """Compute importance sampling weights of the last sample."""
        # Get probabilities
        probs = self.priorities.get(self.last_indices) / self.priorities.total_sum
        # Compute weights, normalised by the largest weight in the batch
        weights = (len(self) * probs)**(-self.beta)
        weights /= max(weights)
        # Anneal beta towards 1
        self.beta = min(self.beta + self.delta_beta, 1)
        # Return weights
        return weights
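# Despite the interface differences, every class in this section implements the
# same arithmetic from prioritised experience replay: priorities
# p_i = (|delta_i| + eps)**alpha, sampling probabilities P(i) = p_i / sum_k p_k,
# and importance-sampling weights w_i = (N * P(i))**(-beta), normalised by the
# largest weight in the batch. The two helpers below are only an illustrative,
# self-contained restatement of that arithmetic; their names and default
# hyperparameters are assumptions, not code from any of the classes above.
import numpy as np


def per_priority(td_errors, alpha=0.6, eps=1e-6):
    """Priority assigned to each transition: p_i = (|delta_i| + eps) ** alpha."""
    return (np.abs(td_errors) + eps) ** alpha


def per_is_weights(sample_probs, n, beta=0.4):
    """Importance-sampling weights w_i = (N * P(i)) ** (-beta), where P(i) is a
    transition's priority divided by the tree total and N is the number of
    stored transitions, normalised so the largest weight in the batch is 1."""
    weights = (n * np.asarray(sample_probs)) ** (-beta)
    return weights / weights.max()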
from collections import deque, namedtuple

import numpy as np
import torch


class PrioritizedReplayBuffer:
    """
    Memory buffer responsible for Prioritized Experience Replay.

    This buffer stores up to memory_size experiences in a circular, array-like
    data structure. Each experience is also associated with a probability
    weight. Batches may be sampled (with replacement) from the implied
    probability distribution. The provided weights should be non-negative, but
    are not required to add up to 1.
    """

    def __init__(self, device, memory_size, update_every=4, seed=0):
        """
        Initializes the data structure

        :param device: (torch.device) Object representing the device where to allocate tensors
        :param memory_size: (int) Maximum capacity of memory buffer
        :param update_every: (int) Number of steps between update operations
        :param seed: (int) Seed used for PRNG
        """
        self.device = device
        self.probability_weights = SumTree(capacity=memory_size, seed=seed)
        self.elements = deque(maxlen=memory_size)
        self.update_every = update_every
        self.step = 0
        self.experience = namedtuple(
            "Experience",
            field_names=["state", "action", "reward", "next_state", "done"])

    def add(self, state, action, reward, next_state, done):
        """
        Adds an experience tuple (s, a, r, s', done) to memory

        :param state: (array-like) State value from experience tuple
        :param action: (int) Action value from experience tuple
        :param reward: (float) Reward value from experience tuple
        :param next_state: (array-like) Next state value from experience tuple
        :param done: (bool) Done flag from experience tuple
        """
        e = self.experience(state, action, reward, next_state, done)
        self.elements.append(e)
        self.step += 1
        # Every update_every steps, register the new experiences in the sum tree
        # with the maximum initial weight
        if self.step >= self.update_every:
            self.probability_weights.add(self.step)
            self.step = 0

    def sample(self, batch_size, alpha, beta):
        """
        Samples a batch of examples with replacement from the buffer.

        :param batch_size: (int) Number of samples to sample
        :param alpha: (float) PER probability hyperparameter
        :param beta: (float) PER probability hyperparameter
        :return:
            states: (list) States from sampled experiences
            actions: (list) Actions from sampled experiences
            rewards: (list) Rewards from sampled experiences
            next_states: (list) Next states from sampled experiences
            dones: (list) Done flags from sampled experiences
            indexes: (list) Indexes of sampled experiences
        """
        indexes = self.probability_weights.sample(batch_size=batch_size, alpha=alpha, beta=beta)
        experiences = [self.elements[i] for i in indexes]

        # Copy experience tensors to device
        states = torch.from_numpy(np.vstack([e.state for e in experiences])).float().to(self.device)
        actions = torch.from_numpy(np.vstack([e.action for e in experiences])).long().to(self.device)
        rewards = torch.from_numpy(np.vstack([e.reward for e in experiences])).float().to(self.device)
        next_states = torch.from_numpy(np.vstack([e.next_state for e in experiences])).float().to(self.device)
        dones = torch.from_numpy(np.vstack([e.done for e in experiences]).astype(np.uint8)).float().to(self.device)

        return states, actions, rewards, next_states, dones, indexes

    def update(self, indexes, weights):
        """
        Updates the probability weights associated with the provided indexes.

        :param indexes: (array indexes) Indexes to have weights updated
        :param weights: (list) New weights for the provided indexes
        """
        self.probability_weights.update(indexes, weights)

    def __len__(self):
        """Return the current size of internal memory."""
        return len(self.probability_weights)
import random
from collections import namedtuple

import numpy as np


class PrioritisedReplayBuffer():
    """A prioritised replay buffer.

    Creates a sum tree and uses it to store a fixed number of experience
    tuples. When sampled, experiences are returned with greater priority given
    to those with the highest absolute TD-error.
    """

    def __init__(self, buffer_size, alpha, beta_zero, beta_increment_size=0.001,
                 epsilon=0.1, max_priority=1., seed=None):
        """Prioritised replay buffer initialiser.

        Args:
            buffer_size (int): capacity of the replay buffer.
            alpha (float): priority scaling hyperparameter.
            beta_zero (float): importance sampling scaling hyperparameter.
            beta_increment_size (float): beta annealing rate.
            epsilon (float): base priority to ensure non-zero sampling probability.
            max_priority (float): initial maximum priority.
            seed (int): seed for random number generator.
        """
        random.seed(seed)
        self.sum_tree = SumTree(buffer_size)
        self.memory = {}
        self.experience = namedtuple(
            "experience", ["state", "action", "reward", "next_state", "done"])
        self.buffer_size = buffer_size
        self.beta_increment_size = beta_increment_size
        self.max_priority = max_priority**alpha
        self.min_priority = max_priority**alpha
        self.last_min_update = 0
        self.alpha = alpha
        self.beta = beta_zero
        self.epsilon = epsilon

    def add(self, state, action, reward, next_state, done):
        """Creates an experience tuple and adds it to the replay buffer."""
        experience = self.experience(state, action, reward, next_state, done)
        current_tree_idx = self.sum_tree.input_pointer
        self.memory[current_tree_idx] = experience
        # new experiences enter the tree with the current maximum priority
        self.sum_tree.add(self.max_priority)

    def sample(self, batch_size):
        """Returns a batch of experiences sampled according to their priority."""
        idx_list = []
        weights = []
        states = []
        actions = []
        rewards = []
        next_states = []
        done_list = []

        # one uniform draw per equal-sized segment of the total priority
        segment = self.sum_tree.total() / batch_size
        sample_list = [
            random.uniform(segment * i, segment * (i + 1))
            for i in range(batch_size)
        ]
        max_weight = self.min_priority**(-self.beta)

        for s in sample_list:
            idx, priority = self.sum_tree.sample(s)
            idx_list.append(idx)
            weight = priority**(-self.beta) / max_weight
            weights.append(weight)

            sample = self.memory[idx]
            state, action, reward, next_state, done = sample
            states.append(state)
            actions.append(action)
            rewards.append(reward)
            next_states.append(next_state)
            done_list.append(done)

        return states, actions, rewards, next_states, done_list, idx_list, weights

    def update(self, idx_list, td_error):
        """Updates the priorities of the sampled experiences."""
        priority_list = (td_error + self.epsilon)**self.alpha
        self.max_priority = max(self.max_priority, priority_list.max())

        # track the minimum priority (needed for the maximum IS weight); refresh
        # it from the tree once it has gone stale for a full buffer cycle
        list_min_priority = priority_list.min()
        if list_min_priority <= self.min_priority:
            self.min_priority = list_min_priority
            self.last_min_update = 0
        else:
            self.last_min_update += 1
            if self.last_min_update >= self.buffer_size:
                self.min_priority = np.array([
                    node.val
                    for node in self.sum_tree.tree_array[-self.buffer_size:]
                ]).min()
                self.last_min_update = 0

        for i, idx in enumerate(idx_list):
            priority = min(self.max_priority, priority_list[i])
            self.sum_tree.update(idx, priority)

        self.beta = min(1, self.beta + self.beta_increment_size)

    def __len__(self):
        """Return number of experiences in the replay buffer."""
        return len(self.memory)
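# The IS weights these buffers return are meant to scale the per-sample TD loss
# before the gradient step, and the resulting absolute TD errors are then fed
# back through update() / update_priorities() to refresh the priorities. The
# function below is a hedged sketch of that wiring, not code from any of the
# classes above; it assumes q_pred, q_target and is_weights are 1-D torch
# tensors of equal length.
import torch


def weighted_td_loss(q_pred, q_target, is_weights):
    """Importance-sampling-weighted squared TD error, plus the |TD errors|
    that go back into the buffer as new priorities."""
    td_error = q_target - q_pred
    loss = (is_weights * td_error.pow(2)).mean()
    return loss, td_error.detach().abs().cpu().numpy()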