import pytest

# The SumTree import is not shown in this section; a sketch of a compatible
# implementation appears after the tests.


class TestSumTree:  # class name assumed; the test methods below take `self`
    def test_indexing_non_existing_item_raises_error(self):
        tree = SumTree(2)
        tree.append(0.1, 10)
        with pytest.raises(IndexError):
            _ = tree[1]
    def test_max_length_can_only_be_power_of_two(self):
        with pytest.raises(ValueError):
            _ = SumTree(3)
        with pytest.raises(ValueError):
            _ = SumTree(5)
    def test_sum_is_correct(self):
        tree = SumTree(2)
        tree.append(0.1, 10)
        tree.append(0.2, 20)
        assert tree.sum == pytest.approx(0.3, 0.01)
    def test_retrieve_correct_at_priority_boundaries(self):
        tree = SumTree(2)
        tree.append(0.1, 10)
        tree.append(0.2, 20)
        assert tree.retrieve(0.1) == 0
        assert tree.retrieve(0.2) == 1
    def test_retrieve_returns_object_correctly(self):
        tree = SumTree(2)
        tree.append(0.1, 10)
        tree.append(0.2, 10)
        assert tree.retrieve(0.05) == 0
        assert tree.retrieve(0.15) == 1
    def test_append_adds_items(self):
        tree = SumTree(2)
        tree.append(0.1, 10)
        tree.append(0.1, 20)
        assert tree[0] == 10
        assert tree[1] == 20
    def test_retrieve_correct_on_two_level_tree(self):
        tree = SumTree(4)
        tree.append(0.1, 10)
        tree.append(0.1, 10)
        tree.append(0.1, 10)
        # cumulative priorities: 0.1, 0.2, 0.3; a draw of 0.25 exceeds the
        # first two, so it lands on the third leaf (index 2)
        assert tree.retrieve(0.25) == 2
    def test_retrieve_return_last_item_when_priority_gt_sum(self):
        expected_item = 20
        tree = SumTree(4)
        tree.append(0.1, 10)
        tree.append(0.1, expected_item)
        priority = tree.sum + 10.0
        index = tree.retrieve(priority)
        assert tree[index] == expected_item
    def test_sum_correct_after_overflow(self):
        tree = SumTree(2)
        tree.append(0.1, 10)
        tree.append(0.2, 20)
        tree.append(0.3, 20)
        assert tree.sum == pytest.approx(0.5, 0.01)
    def test_sum_of_incomplete_tree_is_correct(self):
        tree = SumTree(4)
        tree.append(0.1, 10)
        tree.append(0.2, 20)
        tree.append(0.3, 30)
        assert tree.sum == pytest.approx(0.6, 0.01)
    def test_append_overflow_replaces_oldest_item(self):
        tree = SumTree(2)
        tree.append(0.1, 10)
        tree.append(0.1, 20)
        tree.append(0.1, 30)
        assert tree[0] == 30
        assert tree[1] == 20
    def test_sum_of_empty_tree_is_zero(self):
        tree = SumTree(2)
        assert tree.sum == 0.0
    def test_docstring_example_works_as_described(self):
        tree = SumTree(2)
        tree.append(0.1, "item object")
        assert tree[0] == "item object"
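# The tests above pin down the SumTree contract: power-of-two capacity,
# ring-buffer appends that overwrite the oldest item, O(log n) priority
# updates, and proportional retrieval. The repository's actual SumTree is
# not shown in this section, so the class below is only a minimal sketch
# consistent with those tests; the internal names (_nodes, _items, _cursor)
# and the flat-array layout are assumptions, not the repo's code.


class SumTree:
    """Binary sum tree over a ring buffer of (priority, item) pairs."""

    def __init__(self, capacity):
        if capacity < 1 or capacity & (capacity - 1) != 0:
            raise ValueError("capacity must be a power of two")
        self.capacity = capacity
        # Flat array: index 0 is the root, the last `capacity` slots are leaves.
        self._nodes = [0.0] * (2 * capacity - 1)
        self._items = [None] * capacity
        self._cursor = 0  # next write position; wraps to overwrite the oldest
        self._size = 0

    @property
    def sum(self):
        return self._nodes[0]  # the root accumulates the total priority

    def __len__(self):
        return self._size

    def __getitem__(self, index):
        if not 0 <= index < self._size:
            raise IndexError("index out of range")
        return self._items[index]

    def append(self, priority, item):
        self._items[self._cursor] = item
        self.update_priority(self._cursor, priority)
        self._cursor = (self._cursor + 1) % self.capacity
        self._size = min(self._size + 1, self.capacity)

    def get_priority(self, index):
        return self._nodes[index + self.capacity - 1]

    def update_priority(self, index, priority):
        node = index + self.capacity - 1  # leaf position in the flat array
        delta = priority - self._nodes[node]
        while True:  # propagate the change up to the root
            self._nodes[node] += delta
            if node == 0:
                break
            node = (node - 1) // 2

    def retrieve(self, priority):
        # Clamp so that priorities beyond the total land on the last leaf.
        priority = min(priority, self.sum)
        node = 0
        while node < self.capacity - 1:  # descend until a leaf is reached
            left = 2 * node + 1
            if priority <= self._nodes[left]:
                node = left
            else:
                priority -= self._nodes[left]
                node = left + 1
        return node - (self.capacity - 1)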
import random
from collections import deque, namedtuple

import numpy as np

# SumTree is imported from the repo's implementation module (path not shown
# in this section); the sketch above matches the interface used here.


class PrioritizedReplayBuffer:
    """Prioritized Experience Replay with Proportional Prioritization.

    In reinforcement learning, prioritizing which transitions are replayed
    can make experience replay more effective than replaying all transitions
    uniformly.

    Related paper: https://arxiv.org/pdf/1511.05952.pdf
    """

    def __init__(
        self,
        buffer_size,
        batch_size,
        epsilon=0.001,
        alpha=0.6,
        beta=0.6,
        seed=None,
    ):
        """Initialize an instance of Prioritized Experience Replay.

        Args:
            buffer_size (int): maximum size of buffer
            batch_size (int): size of training batches to sample
            epsilon (float): value that is added to all priorities
            alpha (float): exponent which determines how much prioritization
                is used, with alpha == 0 corresponding to the uniform case
            beta (float): importance sampling bias correction exponent, where
                beta == 1 corresponds to full bias correction
            seed (int): optional, seed for randomness

        Raises:
            ValueError: buffer_size must be a power of two
        """
        self._storage = SumTree(buffer_size)  # ValueError propagates as documented
        self.batch_size = batch_size
        self.epsilon = epsilon
        self.alpha = alpha
        self.beta = beta
        self.Experience = namedtuple(
            "Experience", ["state", "action", "reward", "next_state", "done"])
        self.highest_priority = 0.1
        self.highest_isweight = 0.
        self._sampled_indices = deque()
        if seed is not None:
            random.seed(seed)

    def add(self, state, action, reward, next_state, done, priority=None):
        """Adds a new experience to the buffer.

        Without an explicit priority, the experience enters at the highest
        priority seen so far.
        """
        exp = self.Experience(state, action, reward, next_state, done)
        if priority is not None:
            priority = pow(abs(priority) + self.epsilon, self.alpha)
            if priority > self.highest_priority:
                self.highest_priority = priority
        else:
            priority = self.highest_priority
        self._storage.append(priority, exp)

    def update_priorities(self, new_priorities):
        """Updates priorities for the previously sampled batch of experiences."""
        if len(new_priorities) != len(self._sampled_indices):
            raise ValueError(
                "sample() should be called right before calling this "
                "method, and the length of 'new_priorities' should match "
                "batch_size")
        new_priorities = np.power(
            np.abs(new_priorities) + self.epsilon, self.alpha)
        max_priority = np.max(new_priorities)
        if max_priority > self.highest_priority:
            self.highest_priority = max_priority
        for index, new_priority in zip(self._sampled_indices, new_priorities):
            self._storage.update_priority(index, new_priority.item())

    def sample(self):
        """Randomly samples a batch of experiences from the buffer.
        Raises:
            IndexError: if there are not enough items in the buffer; len()
                needs to be at least batch_size

        Returns:
            tuple: importance sampling weights followed by the unpacked
                batch: (weights, states, actions, rewards, next_states, dones)
        """
        if len(self._storage) < self.batch_size:
            raise IndexError(
                "not enough items in buffer to sample(), try again later")
        self._sampled_indices.clear()
        # Stratified sampling: split [0, sum) into batch_size equal ranges
        # and draw one priority uniformly from each range.
        range_size = self.sum() / self.batch_size
        sample_priorities = deque()
        for i in range(self.batch_size):
            priority = random.uniform(range_size * i, range_size * (i + 1))
            index = self._storage.retrieve(priority)
            self._sampled_indices.append(index)
            sample_priorities.append(self._storage.get_priority(index))
        experiences = [self._storage[idx] for idx in self._sampled_indices]
        # Per the paper, w_i = (N * P(i)) ** -beta with P(i) = p_i / sum.
        probabilities = np.array(sample_priorities) / self.sum()
        importance_sampling_weights = np.power(
            len(self._storage) * probabilities, -self.beta)
        max_weight = np.max(importance_sampling_weights)
        if max_weight > self.highest_isweight:
            self.highest_isweight = max_weight
        importance_sampling_weights /= self.highest_isweight
        return (
            importance_sampling_weights,
        ) + self._unpack_samples(experiences)

    def sum(self):
        """Returns the total sum of priorities within the buffer."""
        return self._storage.sum

    def _unpack_samples(self, samples):
        """Unpacks a list of Experience samples into a tuple of per-field lists."""
        states = [exp.state for exp in samples]
        actions = [exp.action for exp in samples]
        rewards = [exp.reward for exp in samples]
        next_states = [exp.next_state for exp in samples]
        dones = [exp.done for exp in samples]
        return (states, actions, rewards, next_states, dones)

    def __getitem__(self, index):
        """Returns the item from the Replay Buffer at the given `index`.

        Args:
            index (int): index of the item

        Raises:
            IndexError: if index is out of range

        Returns:
            object: the object at location `index`
        """
        return self._storage[index]

    def __len__(self):
        """Return the current size of the internal buffer."""
        return len(self._storage)
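# A hypothetical end-to-end usage sketch of PrioritizedReplayBuffer. The
# transitions and TD errors below are placeholder values, not anything from
# the repository; only the call sequence (add -> sample -> update_priorities)
# reflects the API above.

buffer = PrioritizedReplayBuffer(buffer_size=8, batch_size=4, seed=42)

# Fill the buffer; without an explicit priority, each transition enters at
# the highest priority seen so far.
for step in range(6):
    buffer.add(state=step, action=0, reward=1.0, next_state=step + 1,
               done=False)

# sample() returns the importance sampling weights first, then the batch.
weights, states, actions, rewards, next_states, dones = buffer.sample()

# After computing TD errors for the sampled batch (random placeholders
# here), feed them back; the buffer applies (|delta| + epsilon) ** alpha.
td_errors = np.random.rand(4)
buffer.update_priorities(td_errors)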