import random
from collections import namedtuple

import numpy as np
import torch

# device is used by PrioritizedReplayBuffer.sample further below
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class PERBuffer:
    """ref: https://github.com/rlcode/per/blob/master/prioritized_memory.py"""

    def __init__(self, size, e=0.01, a=0.6, beta=0.4, beta_increment_per_sampling=0.001):
        self.size = size
        self.tree = SumTree(size)
        self.e = e          # small constant so no transition ends up with zero priority
        self.a = a          # exponent controlling how strongly priorities are used
        self.beta = beta    # importance-sampling exponent, annealed towards 1
        self.beta_increment_per_sampling = beta_increment_per_sampling

    def _get_priority(self, error):
        return (np.abs(error) + self.e) ** self.a

    def add(self, error, sample):
        p = self._get_priority(error)
        self.tree.add(p, sample)

    def update(self, idx, error):
        p = self._get_priority(error)
        self.tree.update(idx, p)

    def sample(self, n):
        batch = []
        idxs = []
        segment = self.tree.total() / n
        priorities = []

        # Anneal beta towards 1 so the importance-sampling correction grows over time.
        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])

        for i in range(n):
            # Draw one value uniformly from each of n equal slices of the total priority mass.
            a = segment * i
            b = segment * (i + 1)
            s = random.uniform(a, b)
            (idx, p, data) = self.tree.get(s)
            priorities.append(p)
            batch.append(data)
            idxs.append(idx)

        sampling_probabilities = np.array(priorities) / self.tree.total()
        is_weight = np.power(self.tree.n_entries * sampling_probabilities, -self.beta)
        is_weight /= is_weight.max()

        return batch, idxs, is_weight

    def __len__(self):
        return self.tree.n_entries
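# The buffer classes in this section all rely on a SumTree that is not shown here.
# Below is a minimal sketch of one, modeled on the rlcode repository referenced above,
# exposing total()/add()/update()/get()/n_entries as used by PERBuffer and Memory.
# PrioritizedReplayBuffer further down expects a variant with different attribute
# names (total, get_leaf, N, tree_data, buffer_size), so treat this as an
# illustration rather than a drop-in for all three classes.
class SumTree:
    """Binary sum tree: leaves hold priorities, internal nodes hold sums of their children."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = np.zeros(2 * capacity - 1)        # internal nodes followed by leaves
        self.data = np.zeros(capacity, dtype=object)  # stored transitions, one per leaf
        self.write = 0                                # next leaf to overwrite (ring buffer)
        self.n_entries = 0                            # number of transitions stored so far

    def _propagate(self, idx, change):
        parent = (idx - 1) // 2
        self.tree[parent] += change
        if parent != 0:
            self._propagate(parent, change)

    def _retrieve(self, idx, s):
        left = 2 * idx + 1
        right = left + 1
        if left >= len(self.tree):   # reached a leaf
            return idx
        if s <= self.tree[left]:
            return self._retrieve(left, s)
        return self._retrieve(right, s - self.tree[left])

    def total(self):
        return self.tree[0]

    def add(self, p, data):
        idx = self.write + self.capacity - 1
        self.data[self.write] = data
        self.update(idx, p)
        self.write = (self.write + 1) % self.capacity
        self.n_entries = min(self.n_entries + 1, self.capacity)

    def update(self, idx, p):
        change = p - self.tree[idx]
        self.tree[idx] = p
        self._propagate(idx, change)

    def get(self, s):
        idx = self._retrieve(0, s)
        data_idx = idx - self.capacity + 1
        return idx, self.tree[idx], self.data[data_idx]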
class Memory:  # stored as ( s, a, r, s_ ) in SumTree
    e = 0.01
    a = 0.6
    beta = 0.4
    beta_increment_per_sampling = 0.001

    def __init__(self, capacity):
        self.tree = SumTree(capacity)
        self.capacity = capacity

    def _get_priority(self, error):
        return (np.abs(error) + self.e) ** self.a

    def add(self, error, sample):
        p = self._get_priority(error)
        self.tree.add(p, sample)

    def sample(self, n):
        batch = []
        idxs = []
        segment = self.tree.total() / n
        priorities = []

        # Anneal beta towards 1 over the course of training.
        self.beta = np.min([1., self.beta + self.beta_increment_per_sampling])

        for i in range(n):
            a = segment * i
            b = segment * (i + 1)
            s = random.uniform(a, b)
            (idx, p, data) = self.tree.get(s)
            priorities.append(p)
            batch.append(data)
            idxs.append(idx)

        sampling_probabilities = np.array(priorities) / self.tree.total()
        is_weight = np.power(self.tree.n_entries * sampling_probabilities, -self.beta)
        is_weight /= is_weight.max()

        return batch, idxs, is_weight

    def update(self, idx, error):
        p = self._get_priority(error)
        self.tree.update(idx, p)
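# A minimal usage sketch for Memory, assuming the SumTree sketch above: transitions
# are added with their |TD error| as the initial priority, a batch is drawn, and the
# priorities of the sampled leaves are refreshed after the learning step. The random
# numbers below merely stand in for real TD errors.
memory = Memory(capacity=1000)

for _ in range(100):
    transition = (np.zeros(4), 0, 1.0, np.zeros(4))   # (s, a, r, s_)
    memory.add(error=np.random.rand(), sample=transition)

batch, idxs, is_weight = memory.sample(32)

for idx in idxs:
    memory.update(idx, error=np.random.rand())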
class PrioritizedReplayBuffer():
    alpha = 0.6
    beta = 0.4
    epsilon = 0.01
    beta_increment_per_sampling = 0.0001
    abs_err_upper = 1.0  # clipped abs error

    def __init__(self, buffer_size, batch_size):
        self.batch_size = batch_size
        self.experience = namedtuple(
            "Experience",
            field_names=["state", "action", "reward", "next_state", "done"])
        self.memory = SumTree(buffer_size)

    def get_priority(self, reward):
        return (np.abs(reward) + self.epsilon) ** self.alpha

    def add(self, state, action, reward, next_state, done):
        """Add a new experience to memory."""
        # priority = self.get_priority(reward)
        e = self.experience(state, action, reward, next_state, done)
        # New transitions get the largest priority seen so far, so they are
        # sampled at least once before their TD error is known.
        max_priority = np.max(self.memory.tree_data[-self.memory.buffer_size:])
        if max_priority == 0:
            max_priority = self.abs_err_upper
        self.memory.add(max_priority, e)

    def update(self, idx, error):
        error += self.epsilon
        clipped_errors = np.minimum(error, self.abs_err_upper)
        priority = clipped_errors ** self.alpha
        self.memory.update(idx, priority)

    def sample(self):
        experiences = []
        priorities = []
        indexs = []
        priority_segment = self.memory.total / self.batch_size

        self.beta = min(1.0, self.beta + self.beta_increment_per_sampling)

        for i in range(self.batch_size):
            low_bound, high_bound = i * priority_segment, (i + 1) * priority_segment
            value = np.random.uniform(low_bound, high_bound)
            leaf_index, priority, data = self.memory.get_leaf(value)
            indexs.append(leaf_index)
            priorities.append(priority)
            experiences.append(data)

        priorities = np.array(priorities) / self.memory.total
        weights = np.power(priorities * self.memory.N, -self.beta)
        weights /= weights.max()

        states = torch.from_numpy(
            np.vstack([e.state for e in experiences if e is not None])).float().to(device)
        actions = torch.from_numpy(
            np.vstack([e.action for e in experiences if e is not None])).long().to(device)
        rewards = torch.from_numpy(
            np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
        next_states = torch.from_numpy(
            np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)
        dones = torch.from_numpy(
            np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)

        return (states, actions, rewards, next_states, dones, indexs, weights)

    def __len__(self):
        return self.memory.N
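# A hypothetical training step showing how the output of PrioritizedReplayBuffer.sample
# is typically consumed: the importance-sampling weights scale the per-sample TD loss,
# and the new |TD errors| are written back as priorities. `q_net`, `target_net`,
# `optimizer`, and `gamma` are placeholders, not part of the snippets above.
import torch


def learn(buffer, q_net, target_net, optimizer, gamma=0.99):
    states, actions, rewards, next_states, dones, idxs, weights = buffer.sample()
    weights = torch.from_numpy(weights).float().to(device).unsqueeze(1)

    q_expected = q_net(states).gather(1, actions)
    q_next = target_net(next_states).detach().max(1, keepdim=True)[0]
    q_target = rewards + gamma * q_next * (1 - dones)

    td_errors = q_target - q_expected
    loss = (weights * td_errors.pow(2)).mean()   # importance-weighted MSE

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Refresh the priorities of the sampled leaves with their new absolute TD errors.
    for idx, err in zip(idxs, td_errors.detach().abs().cpu().numpy().flatten()):
        buffer.update(idx, err)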