コード例 #1
0
def test_segtree(data):
    """Benchmark repeated prefix-sum sampling on a large SegmentTree.

    Fills a 100000-leaf tree with uniform random values in [0, 1), then
    repeatedly draws a batch of 64 scalars uniformly in [0, total) and
    resolves them to leaf indices with ``get_prefix_sum_idx`` -- the same
    pattern PrioritizedReplayBuffer.sample_index uses.

    :param data: not used by this body; kept so the signature stays the same
        (presumably supplied by a pytest fixture -- confirm against the suite).
    """
    size = 100000
    num_iters = 100000
    batch_size = 64
    tree = SegmentTree(size)
    tree[np.arange(size)] = np.random.rand(size)

    # Plain integer range: the previous np.arange(1e5) produced 100000
    # float64 values only to discard them.
    for _ in range(num_iters):
        # tree.reduce() stays inside the loop on purpose so each iteration
        # measures the full sampling cost (total + lookup).
        scalar = np.random.rand(batch_size) * tree.reduce()
        tree.get_prefix_sum_idx(scalar)
コード例 #2
0
ファイル: buffer.py プロジェクト: ZhangMaoJun/tianshou
 def __init__(
     self, size: int, alpha: float, beta: float, **kwargs: Any
 ) -> None:
     """Create a prioritized replay buffer holding up to ``size`` items.

     :param size: capacity; forwarded to the parent constructor and used as
         the leaf count of the priority SegmentTree.
     :param alpha: prioritization exponent, asserted to be > 0.
     :param beta: importance-sampling exponent, asserted to be >= 0.
     :param kwargs: forwarded unchanged to the parent constructor.
     """
     super().__init__(size, **kwargs)
     assert alpha > 0.0 and beta >= 0.0
     self._alpha = alpha
     self._beta = beta
     # Running priority extremes, both starting at 1.0.
     self._max_prio = 1.0
     self._min_prio = 1.0
     # Priorities are held in this class's own SegmentTree rather than in the
     # parent's self._meta storage.
     self.weight = SegmentTree(size)
     # float32 machine epsilon as a Python float; presumably added to
     # priorities elsewhere to keep them strictly positive.
     self.__eps = np.finfo(np.float32).eps.item()
コード例 #3
0
 def __init__(
     self, size: int, alpha: float, beta: float, **kwargs: Any
 ) -> None:
     """Create a prioritized replay buffer holding up to ``size`` items.

     :param size: capacity; forwarded to ReplayBuffer.__init__ and used as
         the leaf count of the priority SegmentTree.
     :param alpha: prioritization exponent, asserted to be > 0.
     :param beta: importance-sampling exponent, asserted to be >= 0.
     :param kwargs: forwarded unchanged to ReplayBuffer.__init__.
     """
     # ReplayBuffer.__init__ is called explicitly instead of via
     # super().__init__, because the super() route raises KeyError in
     # PrioritizedVectorReplayBuffer.
     ReplayBuffer.__init__(self, size, **kwargs)
     assert alpha > 0.0 and beta >= 0.0
     self._alpha = alpha
     self._beta = beta
     # Running priority extremes, both starting at 1.0.
     self._max_prio = 1.0
     self._min_prio = 1.0
     # Priorities are held in this class's own SegmentTree rather than in the
     # parent's self._meta storage.
     self.weight = SegmentTree(size)
     # float32 machine epsilon as a Python float; presumably added to
     # priorities elsewhere to keep them strictly positive.
     self.__eps = np.finfo(np.float32).eps.item()
     # Record the PER hyper-parameters alongside the parent's options.
     self.options.update(alpha=alpha, beta=beta)
コード例 #4
0
 def __init__(
     self, size: int, alpha: float, beta: float, **kwargs
 ) -> None:
     """Create a prioritized replay buffer holding up to ``size`` items.

     :param size: capacity; forwarded to the parent constructor and used as
         the leaf count of the priority SegmentTree.
     :param alpha: prioritization exponent, asserted to be > 0.
     :param beta: importance-sampling exponent, asserted to be >= 0.
     :param kwargs: forwarded unchanged to the parent constructor.
     """
     super().__init__(size, **kwargs)
     assert alpha > 0. and beta >= 0.
     self._alpha = alpha
     self._beta = beta
     self._min_prio = 1.
     self._max_prio = 1.
     # Stored under a leading-underscore name; the original note says this is
     # to "bypass the check" -- presumably an attribute check in the parent
     # class (NOTE(review): confirm which check).
     self._weight = SegmentTree(size)
     # float32 machine epsilon as a Python float; presumably added to
     # priorities elsewhere to keep them strictly positive.
     self.__eps = np.finfo(np.float32).eps.item()
コード例 #5
0
def test_segtree():
    """Check SegmentTree against a naive NumPy reference.

    For each reduction ('sum', 'max', 'min', starting from its identity value)
    this verifies single-index and batch ``__setitem__`` followed by
    ``reduce`` over random ranges, on an 8-leaf and a 16384-leaf tree. It then
    checks ``get_prefix_sum_idx`` on random leaves, on integer-valued leaves
    (scalars landing exactly on boundaries), on a tree with a zero-weight
    leaf, and on a large tree. When the module runs as a script it finally
    times tree-based sampling against ``np.random.choice``.
    """
    for op, init in zip(['sum', 'max', 'min'], [0., -np.inf, np.inf]):
        realop = getattr(np, op)
        # small test
        actual_len = 8
        tree = SegmentTree(actual_len, op)  # 1-15. 8-15 are leaf nodes
        assert len(tree) == actual_len
        assert np.all([tree[i] == init for i in range(actual_len)])
        with pytest.raises(IndexError):
            tree[actual_len]
        naive = np.full([actual_len], init)
        for _ in range(1000):
            # random choose a place to perform single update
            index = np.random.randint(actual_len)
            value = np.random.rand()
            naive[index] = value
            tree[index] = value
            # exhaustively compare every non-empty half-open range [i, j)
            for i in range(actual_len):
                for j in range(i + 1, actual_len):
                    ref = realop(naive[i:j])
                    out = tree.reduce(i, j)
                    assert np.allclose(ref, out)
        assert np.allclose(tree.reduce(start=1), realop(naive[1:]))
        assert np.allclose(tree.reduce(end=-1), realop(naive[:-1]))
        # batch setitem
        for _ in range(1000):
            index = np.random.choice(actual_len, size=4)
            value = np.random.rand(4)
            naive[index] = value
            tree[index] = value
            assert np.allclose(realop(naive), tree.reduce())
            for _ in range(10):
                left = np.random.randint(actual_len)
                right = np.random.randint(left + 1, actual_len + 1)
                assert np.allclose(realop(naive[left:right]),
                                   tree.reduce(left, right))
        # large test
        actual_len = 16384
        tree = SegmentTree(actual_len, op)
        naive = np.full([actual_len], init)
        for _ in range(1000):
            index = np.random.choice(actual_len, size=64)
            value = np.random.rand(64)
            naive[index] = value
            tree[index] = value
            assert np.allclose(realop(naive), tree.reduce())
            for _ in range(10):
                left = np.random.randint(actual_len)
                right = np.random.randint(left + 1, actual_len + 1)
                assert np.allclose(realop(naive[left:right]),
                                   tree.reduce(left, right))

    # test prefix-sum-idx
    actual_len = 8
    tree = SegmentTree(actual_len)
    naive = np.random.rand(actual_len)
    tree[np.arange(actual_len)] = naive
    for _ in range(1000):
        scalar = np.random.rand() * naive.sum()
        index = tree.get_prefix_sum_idx(scalar)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
    # corner case here: integer-valued leaves, so every scalar lands exactly
    # on a prefix-sum boundary.
    # ``np.int`` was removed in NumPy 1.24 (it was only an alias of the
    # builtin); passing ``int`` selects the same dtype.
    naive = np.ones(actual_len, int)
    tree[np.arange(actual_len)] = naive
    for scalar in range(actual_len):
        index = tree.get_prefix_sum_idx(scalar * 1.)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()
    # Leaf 1 has zero weight and must never be returned; a scalar equal to
    # the total (0.2) is rejected with AssertionError.
    tree = SegmentTree(10)
    tree[np.arange(3)] = np.array([0.1, 0, 0.1])
    assert np.allclose(tree.get_prefix_sum_idx(
        np.array([0, .1, .1 + 1e-6, .2 - 1e-6])), [0, 0, 2, 2])
    with pytest.raises(AssertionError):
        tree.get_prefix_sum_idx(.2)
    # test large prefix-sum-idx
    actual_len = 16384
    tree = SegmentTree(actual_len)
    naive = np.random.rand(actual_len)
    tree[np.arange(actual_len)] = naive
    for _ in range(1000):
        scalar = np.random.rand() * naive.sum()
        index = tree.get_prefix_sum_idx(scalar)
        assert naive[:index].sum() <= scalar <= naive[:index + 1].sum()

    # profile
    if __name__ == '__main__':
        # Imported here so the profiling path does not depend on a module-level
        # import being present.
        from timeit import timeit

        size = 100000
        bsz = 64
        naive = np.random.rand(size)
        tree = SegmentTree(size)
        tree[np.arange(size)] = naive

        def sample_npbuf():
            return np.random.choice(size, bsz, p=naive / naive.sum())

        def sample_tree():
            scalar = np.random.rand(bsz) * tree.reduce()
            return tree.get_prefix_sum_idx(scalar)

        print('npbuf', timeit(sample_npbuf, setup=sample_npbuf, number=1000))
        print('tree', timeit(sample_tree, setup=sample_tree, number=1000))
コード例 #6
0
class PrioritizedReplayBuffer(ReplayBuffer):
    """Implementation of Prioritized Experience Replay. arXiv:1511.05952.

    Priorities live in ``self.weight``, a SegmentTree whose leaf ``j`` holds
    ``p_j ** alpha``. ``self._max_prio`` / ``self._min_prio`` track the running
    max / min of the raw (un-exponentiated) priorities ``p_j``.

    :param float alpha: the prioritization exponent.
    :param float beta: the importance sample soft coefficient.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """
    def __init__(self, size: int, alpha: float, beta: float,
                 **kwargs: Any) -> None:
        # ReplayBuffer.__init__ is called explicitly: going through
        # super().__init__ raises KeyError in PrioritizedVectorReplayBuffer.
        ReplayBuffer.__init__(self, size, **kwargs)
        assert alpha > 0.0 and beta >= 0.0
        self._alpha, self._beta = alpha, beta
        # Raw (pre-alpha) priority extremes, both starting at 1.0.
        self._max_prio = self._min_prio = 1.0
        # save weight directly in this class instead of self._meta
        self.weight = SegmentTree(size)
        # Added to every updated priority so stored priorities (and hence
        # self._min_prio) stay strictly positive.
        self.__eps = np.finfo(np.float32).eps.item()
        self.options.update(alpha=alpha, beta=beta)

    def init_weight(self, index: Union[int, np.ndarray]) -> None:
        """Give the slot(s) at ``index`` the current maximum priority."""
        self.weight[index] = self._max_prio**self._alpha

    def update(self, buffer: ReplayBuffer) -> np.ndarray:
        """Copy ``buffer`` in via the parent class; new entries get max priority.

        :return: the indices written by the parent ``update``.
        """
        indices = super().update(buffer)
        self.init_weight(indices)
        return indices

    def add(
        self,
        batch: Batch,
        buffer_ids: Optional[Union[np.ndarray, List[int]]] = None
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Add ``batch`` via the parent class; written slots get max priority.

        :return: the parent's ``(ptr, ep_rew, ep_len, ep_idx)`` unchanged.
        """
        ptr, ep_rew, ep_len, ep_idx = super().add(batch, buffer_ids)
        self.init_weight(ptr)
        return ptr, ep_rew, ep_len, ep_idx

    def sample_index(self, batch_size: int) -> np.ndarray:
        """Draw ``batch_size`` indices with probability proportional to priority.

        Falls back to the parent implementation when ``batch_size <= 0`` or
        the buffer is empty.
        """
        if batch_size > 0 and len(self) > 0:
            # Uniform points in [0, total priority) mapped to leaves.
            scalar = np.random.rand(batch_size) * self.weight.reduce()
            return self.weight.get_prefix_sum_idx(scalar)  # type: ignore
        else:
            return super().sample_index(batch_size)

    def get_weight(self, index: Union[int,
                                      np.ndarray]) -> Union[float, np.ndarray]:
        """Get the importance sampling weight.

        The "weight" in the returned Batch is the weight on loss function to de-bias
        the sampling process (some transition tuples are sampled more often so their
        losses are weighted less).
        """
        # important sampling weight calculation
        # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
        # simplified formula: (p_j/p_min)**(-beta)
        # self.weight already stores p_j**alpha, but self._min_prio is the raw
        # minimum, so it must be raised to alpha as well; otherwise the
        # weights are off by a constant factor whenever alpha != 1 and the
        # lowest-priority transition no longer gets weight 1.
        min_weight = self._min_prio**self._alpha
        return (self.weight[index] / min_weight)**(-self._beta)

    def update_weight(self, index: np.ndarray,
                      new_weight: Union[np.ndarray, torch.Tensor]) -> None:
        """Update priority weight by index in this buffer.

        :param np.ndarray index: index you want to update weight.
        :param np.ndarray new_weight: new priority weight you want to update.
        """
        # |new_weight| + eps keeps every raw priority strictly positive.
        weight = np.abs(to_numpy(new_weight)) + self.__eps
        self.weight[index] = weight**self._alpha
        self._max_prio = max(self._max_prio, weight.max())
        self._min_prio = min(self._min_prio, weight.min())

    def __getitem__(self, index: Union[slice, int, List[int],
                                       np.ndarray]) -> Batch:
        """Return the parent's Batch for ``index`` plus its IS ``weight``."""
        if isinstance(index, slice):  # change slice to np array
            # buffer[:] will get all available data
            indice = self.sample_index(0) if index == slice(None) \
                else self._indices[:len(self)][index]
        else:
            indice = index
        batch = super().__getitem__(indice)
        batch.weight = self.get_weight(indice)
        return batch
コード例 #7
0
ファイル: buffer.py プロジェクト: ZhangMaoJun/tianshou
class PrioritizedReplayBuffer(ReplayBuffer):
    """Implementation of Prioritized Experience Replay. arXiv:1511.05952.

    Priorities live in ``self.weight``, a SegmentTree whose leaf ``j`` holds
    ``p_j ** alpha``. ``self._max_prio`` / ``self._min_prio`` track the running
    max / min of the raw (un-exponentiated) priorities ``p_j``.

    :param float alpha: the prioritization exponent.
    :param float beta: the importance sample soft coefficient.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed
        explanation.
    """
    def __init__(self, size: int, alpha: float, beta: float,
                 **kwargs: Any) -> None:
        super().__init__(size, **kwargs)
        assert alpha > 0.0 and beta >= 0.0
        self._alpha, self._beta = alpha, beta
        # Raw (pre-alpha) priority extremes, both starting at 1.0.
        self._max_prio = self._min_prio = 1.0
        # save weight directly in this class instead of self._meta
        self.weight = SegmentTree(size)
        # Added to every user-supplied priority so stored priorities (and
        # hence self._min_prio) stay strictly positive.
        self.__eps = np.finfo(np.float32).eps.item()

    def add(
        self,
        obs: Any,
        act: Any,
        rew: Union[Number, np.number, np.ndarray],
        done: Union[Number, np.number, np.bool_],
        obs_next: Any = None,
        info: Optional[Union[dict, Batch]] = {},
        policy: Optional[Union[dict, Batch]] = {},
        weight: Optional[Union[Number, np.number]] = None,
        **kwargs: Any,
    ) -> Tuple[int, Union[float, np.ndarray]]:
        """Store one transition with priority ``weight``.

        When ``weight`` is None the transition gets the current maximum
        priority; otherwise ``|weight| + eps`` is used and the running
        extremes are updated. All other arguments go to the parent ``add``.
        """
        if weight is None:
            weight = self._max_prio
        else:
            # Same treatment as update_weight: without the eps an explicit
            # zero would set self._min_prio to 0 and get_weight would divide
            # by zero from then on.
            weight = np.abs(weight) + self.__eps
            self._max_prio = max(self._max_prio, weight)
            self._min_prio = min(self._min_prio, weight)
        # Priority is written before the data; self._index is presumably the
        # slot the parent add() fills next (confirm in ReplayBuffer).
        self.weight[self._index] = weight**self._alpha
        return super().add(obs, act, rew, done, obs_next, info, policy,
                           **kwargs)

    def sample_index(self, batch_size: int) -> np.ndarray:
        """Draw ``batch_size`` indices with probability proportional to priority.

        Falls back to the parent implementation when ``batch_size <= 0`` or
        the buffer is empty.
        """
        if batch_size > 0 and self._size > 0:
            # Uniform points in [0, total priority) mapped to leaves.
            scalar = np.random.rand(batch_size) * self.weight.reduce()
            return self.weight.get_prefix_sum_idx(scalar)
        else:
            return super().sample_index(batch_size)

    def get_weight(
            self, index: Union[slice, int, np.integer,
                               np.ndarray]) -> np.ndarray:
        """Get the importance sampling weight.

        The "weight" in the returned Batch is the weight on loss function
        to de-bias the sampling process (some transition tuples are sampled
        more often so their losses are weighted less).
        """
        # important sampling weight calculation
        # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
        # simplified formula: (p_j/p_min)**(-beta)
        # self.weight already stores p_j**alpha, but self._min_prio is the raw
        # minimum, so it must be raised to alpha as well; otherwise the
        # weights are off by a constant factor whenever alpha != 1.
        min_weight = self._min_prio**self._alpha
        return (self.weight[index] / min_weight)**(-self._beta)

    def update_weight(
        self,
        index: np.ndarray,
        new_weight: Union[np.ndarray, torch.Tensor],
    ) -> None:
        """Update priority weight by index in this buffer.

        :param np.ndarray index: index you want to update weight.
        :param np.ndarray new_weight: new priority weight you want to update.
        """
        # |new_weight| + eps keeps every raw priority strictly positive.
        weight = np.abs(to_numpy(new_weight)) + self.__eps
        self.weight[index] = weight**self._alpha
        self._max_prio = max(self._max_prio, weight.max())
        self._min_prio = min(self._min_prio, weight.min())

    def __getitem__(self, index: Union[slice, int, np.integer,
                                       np.ndarray]) -> Batch:
        """Return the parent's Batch for ``index`` plus its IS ``weight``."""
        batch = super().__getitem__(index)
        batch.weight = self.get_weight(index)
        return batch
コード例 #8
0
class PrioritizedReplayBuffer(ReplayBuffer):
    """Implementation of Prioritized Experience Replay. arXiv:1511.05952.

    Priorities live in ``self.weight``, a SegmentTree whose leaf ``j`` holds
    ``p_j ** alpha``. ``self._max_prio`` / ``self._min_prio`` track the running
    max / min of the raw (un-exponentiated) priorities ``p_j``.

    :param float alpha: the prioritization exponent.
    :param float beta: the importance sample soft coefficient.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for more detailed
        explanation.
    """

    def __init__(
        self, size: int, alpha: float, beta: float, **kwargs: Any
    ) -> None:
        super().__init__(size, **kwargs)
        assert alpha > 0.0 and beta >= 0.0
        self._alpha, self._beta = alpha, beta
        # Raw (pre-alpha) priority extremes, both starting at 1.0.
        self._max_prio = self._min_prio = 1.0
        # save weight directly in this class instead of self._meta
        self.weight = SegmentTree(size)
        # Added to every user-supplied priority so stored priorities (and
        # hence self._min_prio) stay strictly positive.
        self.__eps = np.finfo(np.float32).eps.item()

    def add(
        self,
        obs: Any,
        act: Any,
        rew: Union[Number, np.number, np.ndarray],
        done: Union[Number, np.number, np.bool_],
        obs_next: Any = None,
        info: Optional[Union[dict, Batch]] = {},
        policy: Optional[Union[dict, Batch]] = {},
        weight: Optional[Union[Number, np.number]] = None,
        **kwargs: Any,
    ) -> None:
        """Add a batch of data into replay buffer.

        When ``weight`` is None the entry gets the current maximum priority;
        otherwise ``|weight| + eps`` is used and the running extremes are
        updated. All other arguments go to the parent ``add``.
        """
        if weight is None:
            weight = self._max_prio
        else:
            # Same treatment as update_weight: without the eps an explicit
            # zero would set self._min_prio to 0 and the IS-weight
            # normalization in sample() would divide by zero.
            weight = np.abs(weight) + self.__eps
            self._max_prio = max(self._max_prio, weight)
            self._min_prio = min(self._min_prio, weight)
        # Priority is written before the data; self._index is presumably the
        # slot the parent add() fills next (confirm in ReplayBuffer).
        self.weight[self._index] = weight ** self._alpha
        super().add(obs, act, rew, done, obs_next, info, policy, **kwargs)

    def sample(self, batch_size: int) -> Tuple[Batch, np.ndarray]:
        """Get a random sample from buffer with priority probability.

        Return all the data in the buffer if batch_size is 0.

        :return: Sample data and its corresponding index inside the buffer.

        The "weight" in the returned Batch is the weight on loss function
        to de-bias the sampling process (some transition tuples are sampled
        more often so their losses are weighted less).
        """
        assert self._size > 0, "Cannot sample a buffer with 0 size!"
        if batch_size == 0:
            # All stored indices, starting at self._index and wrapping round
            # (presumably oldest-first for a full circular buffer).
            indice = np.concatenate([
                np.arange(self._index, self._size),
                np.arange(0, self._index),
            ])
        else:
            # Uniform points in [0, total priority) mapped to leaves.
            scalar = np.random.rand(batch_size) * self.weight.reduce()
            indice = self.weight.get_prefix_sum_idx(scalar)
        batch = self[indice]
        # important sampling weight calculation
        # original formula: ((p_j/p_sum*N)**(-beta))/((p_min/p_sum*N)**(-beta))
        # simplified formula: (p_j/p_min)**(-beta)
        # batch.weight holds p_j**alpha (read from self.weight), while
        # self._min_prio is the raw minimum, so raise it to alpha as well.
        min_weight = self._min_prio ** self._alpha
        batch.weight = (batch.weight / min_weight) ** (-self._beta)
        return batch, indice

    def update_weight(
        self,
        indice: Union[np.ndarray],
        new_weight: Union[np.ndarray, torch.Tensor]
    ) -> None:
        """Update priority weight by indice in this buffer.

        :param np.ndarray indice: indice you want to update weight.
        :param np.ndarray new_weight: new priority weight you want to update.
        """
        # |new_weight| + eps keeps every raw priority strictly positive.
        weight = np.abs(to_numpy(new_weight)) + self.__eps
        self.weight[indice] = weight ** self._alpha
        self._max_prio = max(self._max_prio, weight.max())
        self._min_prio = min(self._min_prio, weight.min())

    def __getitem__(
        self, index: Union[slice, int, np.integer, np.ndarray]
    ) -> Batch:
        """Return the stored fields at ``index``.

        ``weight`` is the stored ``p ** alpha`` priority, not yet normalized
        into an importance-sampling weight (sample() does that).
        """
        return Batch(
            obs=self.get(index, "obs"),
            act=self.act[index],
            rew=self.rew[index],
            done=self.done[index],
            obs_next=self.get(index, "obs_next"),
            info=self.get(index, "info"),
            policy=self.get(index, "policy"),
            weight=self.weight[index],
        )