def test_segtree(data):
    """Exercise prefix-sum sampling on a large ``SegmentTree``.

    Fills a 100000-leaf tree with uniform random values, then 100000 times
    draws a batch of 64 query points scaled by the tree total
    (``tree.reduce()``) and resolves them with ``get_prefix_sum_idx``.
    Results are discarded; only the calls themselves are exercised.

    :param data: unused; kept so the signature (and whatever fixture wiring
        may supply it) stays unchanged.
    """
    size = 100000
    tree = SegmentTree(size)
    tree[np.arange(size)] = np.random.rand(size)
    # The counter is never read: iterate a plain int range rather than the
    # float ndarray that ``np.arange(1e5)`` would allocate.
    for _ in range(100000):
        scalar = np.random.rand(64) * tree.reduce()
        tree.get_prefix_sum_idx(scalar)
def test_segtree():
    """Check ``SegmentTree`` against a naive numpy reference.

    For each of the ``sum``/``max``/``min`` reductions this checks the
    initial leaf values, out-of-range indexing, single-element and batched
    (fancy-index) updates, and ``reduce`` over explicit and default ranges,
    on both a small and a large tree. It then checks that
    ``get_prefix_sum_idx`` on a sum-tree returns an index whose prefix sums
    bracket the query, including integer-valued leaves, zero-valued leaves,
    and a query equal to the total (which must raise ``AssertionError``).
    """
    for op, init in zip(['sum', 'max', 'min'], [0., -np.inf, np.inf]):
        realop = getattr(np, op)
        # small test
        actual_len = 8
        tree = SegmentTree(actual_len, op)  # 1-15. 8-15 are leaf nodes
        assert len(tree) == actual_len
        # every leaf starts at ``init`` (0 / -inf / inf for sum / max / min)
        assert np.all([tree[i] == init for i in range(actual_len)])
        with pytest.raises(IndexError):
            tree[actual_len]
        naive = np.full([actual_len], init)
        for _ in range(1000):
            # random choose a place to perform single update
            index = np.random.randint(actual_len)
            value = np.random.rand()
            naive[index] = value
            tree[index] = value
            for i in range(actual_len):
                for j in range(i + 1, actual_len):
                    ref = realop(naive[i:j])
                    out = tree.reduce(i, j)
                    assert np.allclose(ref, out)
        # default bounds: ``start`` / ``end`` behave like slice bounds
        assert np.allclose(tree.reduce(start=1), realop(naive[1:]))
        assert np.allclose(tree.reduce(end=-1), realop(naive[:-1]))
        # batch setitem
        for _ in range(1000):
            index = np.random.choice(actual_len, size=4)
            value = np.random.rand(4)
            naive[index] = value
            tree[index] = value
            assert np.allclose(realop(naive), tree.reduce())
            for _ in range(10):
                left = np.random.randint(actual_len)
                right = np.random.randint(left + 1, actual_len + 1)
                assert np.allclose(
                    realop(naive[left:right]), tree.reduce(left, right))
        # large test
        actual_len = 16384
        tree = SegmentTree(actual_len, op)
        naive = np.full([actual_len], init)
        for _ in range(1000):
            index = np.random.choice(actual_len, size=64)
            value = np.random.rand(64)
            naive[index] = value
            tree[index] = value
            assert np.allclose(realop(naive), tree.reduce())
            for _ in range(10):
                left = np.random.randint(actual_len)
                right = np.random.randint(left + 1, actual_len + 1)
                assert np.allclose(
                    realop(naive[left:right]), tree.reduce(left, right))

    # test prefix-sum-idx
    actual_len = 8
    tree = SegmentTree(actual_len)
    naive = np.random.rand(actual_len)
    tree[np.arange(actual_len)] = naive
    for _ in range(1000):
        scalar = np.random.rand() * naive.sum()
        index = tree.get_prefix_sum_idx(scalar)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
    # corner case here
    # ``np.int`` was removed in NumPy 1.24; the builtin ``int`` is what the
    # alias always meant.
    naive = np.ones(actual_len, int)
    tree[np.arange(actual_len)] = naive
    for scalar in range(actual_len):
        index = tree.get_prefix_sum_idx(scalar * 1.)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
    # zero-valued leaves must never be selected
    tree = SegmentTree(10)
    tree[np.arange(3)] = np.array([0.1, 0, 0.1])
    assert np.allclose(
        tree.get_prefix_sum_idx(np.array([0, .1, .1 + 1e-6, .2 - 1e-6])),
        [0, 0, 2, 2])
    # a query equal to the total sum is out of range
    with pytest.raises(AssertionError):
        tree.get_prefix_sum_idx(.2)
    # test large prefix-sum-idx
    actual_len = 16384
    tree = SegmentTree(actual_len)
    naive = np.random.rand(actual_len)
    tree[np.arange(actual_len)] = naive
    for _ in range(1000):
        scalar = np.random.rand() * naive.sum()
        index = tree.get_prefix_sum_idx(scalar)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()


# profile
if __name__ == '__main__':
    size = 100000
    bsz = 64
    naive = np.random.rand(size)
    tree = SegmentTree(size)
    tree[np.arange(size)] = naive

    def sample_npbuf():
        """Sample ``bsz`` indices with probability proportional to ``naive``."""
        return np.random.choice(size, bsz, p=naive / naive.sum())

    def sample_tree():
        """Sample ``bsz`` indices via the sum-tree's prefix-sum search."""
        scalar = np.random.rand(bsz) * tree.reduce()
        return tree.get_prefix_sum_idx(scalar)

    # ``setup`` runs each sampler once, untimed, before the measured calls
    # (presumably as a warm-up).
    print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000))
    print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
class PrioritizedReplayBuffer(ReplayBuffer):
    """Implementation of Prioritized Experience Replay. arXiv:1511.05952.

    :param float alpha: the prioritization exponent.
    :param float beta: the importance sample soft coefficient.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """

    def __init__(self, size: int, alpha: float, beta: float, **kwargs: Any) -> None:
        # will raise KeyError in PrioritizedVectorReplayBuffer
        # super().__init__(size, **kwargs)
        ReplayBuffer.__init__(self, size, **kwargs)
        assert alpha > 0.0 and beta >= 0.0
        self._alpha, self._beta = alpha, beta
        # Running max / min of the *raw* priorities (before ``** alpha``);
        # the segment tree below stores the exponentiated values.
        self._max_prio = self._min_prio = 1.0
        # save weight directly in this class instead of self._meta
        self.weight = SegmentTree(size)
        self.__eps = np.finfo(np.float32).eps.item()
        self.options.update(alpha=alpha, beta=beta)

    def init_weight(self, index: Union[int, np.ndarray]) -> None:
        """Give the slots at ``index`` the largest priority seen so far."""
        self.weight[index] = self._max_prio**self._alpha

    def update(self, buffer: ReplayBuffer) -> np.ndarray:
        """Copy ``buffer`` in via the base class and max-prioritise the new slots."""
        indices = super().update(buffer)
        self.init_weight(indices)
        return indices

    def add(
        self,
        batch: Batch,
        buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Add ``batch`` via the base class and max-prioritise the written slots.

        Returns the base class's ``(ptr, ep_rew, ep_len, ep_idx)`` unchanged.
        """
        ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids)
        self.init_weight(ptr)
        return ptr, ep_rew, ep_len, ep_idx

    def sample_index(self, batch_size: int) -> np.ndarray:
        """Sample ``batch_size`` indices with probability proportional to priority.

        Falls back to the base class when ``batch_size`` is not positive or
        the buffer is empty.
        """
        if batch_size > 0 and len(self) > 0:
            scalar = np.random.rand(batch_size) * self.weight.reduce()
            return self.weight.get_prefix_sum_idx(scalar)  # type: ignore
        else:
            return super().sample_index(batch_size)

    def get_weight(self, index: Union[int, np.ndarray]) -> Union[float, np.ndarray]:
        """Get the importance sampling weight.

        The "weight" in the returned Batch is the weight on loss function to
        de-bias the sampling process (some transition tuples are sampled more
        often so their losses are weighted less).
        """
        # important sampling weight calculation
        # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
        # simplified formula: (p_j/p_min)**(-beta)
        # ``self.weight`` holds p = prio**alpha, but ``_min_prio`` is a raw
        # prio, so it must be raised to ``alpha`` as well for p_min to match.
        return (self.weight[index] / self._min_prio**self._alpha)**(-self._beta)

    def update_weight(self, index: np.ndarray, new_weight: Union[np.ndarray, torch.Tensor]) -> None:
        """Update priority weight by index in this buffer.

        :param np.ndarray index: index you want to update weight.
        :param np.ndarray new_weight: new priority weight you want to update.
        """
        # eps keeps every priority strictly positive, so p_min is never 0
        weight = np.abs(to_numpy(new_weight)) + self.__eps
        self.weight[index] = weight**self._alpha
        # ``.max()`` / ``.min()`` raise on a zero-size array; an empty update
        # has nothing to contribute to the running bounds anyway.
        if weight.size > 0:
            self._max_prio = max(self._max_prio, weight.max())
            self._min_prio = min(self._min_prio, weight.min())

    def __getitem__(self, index: Union[slice, int, List[int], np.ndarray]) -> Batch:
        """Fetch data via the base class and attach its importance weight."""
        if isinstance(index, slice):  # change slice to np array
            # buffer[:] will get all available data
            indice = self.sample_index(0) if index == slice(None) \
                else self._indices[:len(self)][index]
        else:
            indice = index
        batch = super().__getitem__(indice)
        batch.weight = self.get_weight(indice)
        return batch
class PrioritizedReplayBuffer(ReplayBuffer):
    """Implementation of Prioritized Experience Replay. arXiv:1511.05952.

    :param float alpha: the prioritization exponent.
    :param float beta: the importance sample soft coefficient.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed
        explanation.
    """

    def __init__(self, size: int, alpha: float, beta: float, **kwargs: Any) -> None:
        super().__init__(size, **kwargs)
        assert alpha > 0.0 and beta >= 0.0
        self._alpha, self._beta = alpha, beta
        # Running max / min of the *raw* priorities (before ``** alpha``);
        # the segment tree below stores the exponentiated values.
        self._max_prio = self._min_prio = 1.0
        # save weight directly in this class instead of self._meta
        self.weight = SegmentTree(size)
        self.__eps = np.finfo(np.float32).eps.item()

    def add(
        self,
        obs: Any,
        act: Any,
        rew: Union[Number, np.number, np.ndarray],
        done: Union[Number, np.number, np.bool_],
        obs_next: Any = None,
        # NOTE(review): shared mutable defaults; harmless here only because
        # this method forwards them to the base class without modifying them.
        info: Optional[Union[dict, Batch]] = {},
        policy: Optional[Union[dict, Batch]] = {},
        weight: Optional[Union[Number, np.number]] = None,
        **kwargs: Any,
    ) -> Tuple[int, Union[float, np.ndarray]]:
        """Add one entry with priority ``weight`` (default: largest seen so far).

        The priority is written to slot ``self._index`` before delegating to
        :meth:`ReplayBuffer.add` (assumes that is the slot the base class
        fills next); the base class's return value is passed through.
        """
        if weight is None:
            weight = self._max_prio
        else:
            # Same eps as ``update_weight``: a zero priority would drive
            # ``_min_prio`` to 0 and make every importance weight degenerate.
            weight = np.abs(weight) + self.__eps
            self._max_prio = max(self._max_prio, weight)
            self._min_prio = min(self._min_prio, weight)
        self.weight[self._index] = weight**self._alpha
        return super().add(obs, act, rew, done, obs_next, info, policy, **kwargs)

    def sample_index(self, batch_size: int) -> np.ndarray:
        """Sample ``batch_size`` indices with probability proportional to priority.

        Falls back to the base class when ``batch_size`` is not positive or
        the buffer is empty.
        """
        if batch_size > 0 and self._size > 0:
            scalar = np.random.rand(batch_size) * self.weight.reduce()
            return self.weight.get_prefix_sum_idx(scalar)
        else:
            return super().sample_index(batch_size)

    def get_weight(
            self, index: Union[slice, int, np.integer, np.ndarray]) -> np.ndarray:
        """Get the importance sampling weight.

        The "weight" in the returned Batch is the weight on loss function to
        de-bias the sampling process (some transition tuples are sampled more
        often so their losses are weighted less).
        """
        # important sampling weight calculation
        # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
        # simplified formula: (p_j/p_min)**(-beta)
        # ``self.weight`` holds p = prio**alpha, but ``_min_prio`` is a raw
        # prio, so it must be raised to ``alpha`` as well for p_min to match.
        return (self.weight[index] / self._min_prio**self._alpha)**(-self._beta)

    def update_weight(
        self, index: np.ndarray, new_weight: Union[np.ndarray, torch.Tensor],
    ) -> None:
        """Update priority weight by index in this buffer.

        :param np.ndarray index: index you want to update weight.
        :param np.ndarray new_weight: new priority weight you want to update.
        """
        # eps keeps every priority strictly positive, so p_min is never 0
        weight = np.abs(to_numpy(new_weight)) + self.__eps
        self.weight[index] = weight**self._alpha
        # ``.max()`` / ``.min()`` raise on a zero-size array; an empty update
        # has nothing to contribute to the running bounds anyway.
        if weight.size > 0:
            self._max_prio = max(self._max_prio, weight.max())
            self._min_prio = min(self._min_prio, weight.min())

    def __getitem__(self, index: Union[slice, int, np.integer, np.ndarray]) -> Batch:
        """Fetch data via the base class and attach its importance weight."""
        batch = super().__getitem__(index)
        batch.weight = self.get_weight(index)
        return batch
class PrioritizedReplayBuffer(ReplayBuffer):
    """Implementation of Prioritized Experience Replay. arXiv:1511.05952.

    :param float alpha: the prioritization exponent.
    :param float beta: the importance sample soft coefficient.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed
        explanation.
    """

    def __init__(
        self, size: int, alpha: float, beta: float, **kwargs: Any
    ) -> None:
        super().__init__(size, **kwargs)
        assert alpha > 0.0 and beta >= 0.0
        self._alpha, self._beta = alpha, beta
        # Running max / min of the *raw* priorities (before ``** alpha``);
        # the segment tree below stores the exponentiated values.
        self._max_prio = self._min_prio = 1.0
        # save weight directly in this class instead of self._meta
        self.weight = SegmentTree(size)
        self.__eps = np.finfo(np.float32).eps.item()

    def add(
        self,
        obs: Any,
        act: Any,
        rew: Union[Number, np.number, np.ndarray],
        done: Union[Number, np.number, np.bool_],
        obs_next: Any = None,
        # NOTE(review): shared mutable defaults; harmless here only because
        # this method forwards them to the base class without modifying them.
        info: Optional[Union[dict, Batch]] = {},
        policy: Optional[Union[dict, Batch]] = {},
        weight: Optional[Union[Number, np.number]] = None,
        **kwargs: Any,
    ) -> None:
        """Add data into replay buffer with priority ``weight``.

        ``weight`` defaults to the largest priority seen so far. It is written
        to slot ``self._index`` before delegating to :meth:`ReplayBuffer.add`
        (assumes that is the slot the base class fills next).
        """
        if weight is None:
            weight = self._max_prio
        else:
            # Same eps as ``update_weight``: a zero priority would drive
            # ``_min_prio`` to 0 and make every importance weight degenerate.
            weight = np.abs(weight) + self.__eps
            self._max_prio = max(self._max_prio, weight)
            self._min_prio = min(self._min_prio, weight)
        self.weight[self._index] = weight ** self._alpha
        super().add(obs, act, rew, done, obs_next, info, policy, **kwargs)

    def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
        """Get a random sample from buffer with priority probability.

        Return all the data in the buffer if batch_size is 0.

        :return: Sample data and its corresponding index inside the buffer.

        The "weight" in the returned Batch is the weight on loss function to
        de-bias the sampling process (some transition tuples are sampled more
        often so their losses are weighted less).
        """
        assert self._size > 0, "Cannot sample a buffer with 0 size!"
        if batch_size == 0:
            # every stored slot, starting at the current write position
            indice = np.concatenate([
                np.arange(self._index, self._size),
                np.arange(0, self._index),
            ])
        else:
            scalar = np.random.rand(batch_size) * self.weight.reduce()
            indice = self.weight.get_prefix_sum_idx(scalar)
        batch = self[indice]
        # important sampling weight calculation
        # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
        # simplified formula: (p_j/p_min)**(-beta)
        # ``batch.weight`` is the stored p = prio**alpha, but ``_min_prio`` is
        # a raw prio, so it must be raised to ``alpha`` for p_min to match.
        batch.weight = (
            batch.weight / self._min_prio ** self._alpha
        ) ** (-self._beta)
        return batch, indice

    def update_weight(
        self, indice: np.ndarray, new_weight: Union[np.ndarray, torch.Tensor]
    ) -> None:
        """Update priority weight by indice in this buffer.

        :param np.ndarray indice: indice you want to update weight.
        :param np.ndarray new_weight: new priority weight you want to update.
        """
        # eps keeps every priority strictly positive, so p_min is never 0
        weight = np.abs(to_numpy(new_weight)) + self.__eps
        self.weight[indice] = weight ** self._alpha
        # ``.max()`` / ``.min()`` raise on a zero-size array; an empty update
        # has nothing to contribute to the running bounds anyway.
        if weight.size > 0:
            self._max_prio = max(self._max_prio, weight.max())
            self._min_prio = min(self._min_prio, weight.min())

    def __getitem__(
        self, index: Union[slice, int, np.integer, np.ndarray]
    ) -> Batch:
        """Return the stored fields at ``index``.

        ``weight`` here is the raw stored sampling priority (prio**alpha),
        not the importance weight that :meth:`sample` computes.
        """
        return Batch(
            obs=self.get(index, "obs"),
            act=self.act[index],
            rew=self.rew[index],
            done=self.done[index],
            obs_next=self.get(index, "obs_next"),
            info=self.get(index, "info"),
            policy=self.get(index, "policy"),
            weight=self.weight[index],
        )