Esempio n. 1
0
File: v8.py Progetto: ymd-h/cpprb
    def test_sample(self):
        """Add single and batched transitions, then sample and re-prioritize.

        Checks that an explicit priority raises the buffer's max priority,
        that a sampled batch carries one weight and one index per requested
        transition with every index inside the stored range, and that
        ``update_priorities`` accepts the sampled indexes.
        """
        buffer_size = 500
        obs_shape = (84, 84, 3)
        act_dim = 4
        batch_size = 64

        rb = PrioritizedReplayBuffer(buffer_size, {
            "obs": {
                "shape": obs_shape
            },
            "act": {
                "shape": act_dim
            },
            "rew": {},
            "done": {}
        },
                                     next_of="obs")

        obs = np.zeros(obs_shape)
        act = np.ones(act_dim)
        rew = 1
        done = 0

        # Single transition without an explicit priority.
        rb.add(obs=obs, act=act, rew=rew, next_obs=obs, done=done)

        ps = 1.5

        # Single transition with an explicit priority larger than any so far.
        rb.add(obs=obs,
               act=act,
               rew=rew,
               next_obs=obs,
               done=done,
               priorities=ps)

        self.assertAlmostEqual(rb.get_max_priority(), 1.5)

        # Batched adds: two transitions stacked along a new leading axis.
        obs = np.stack((obs, obs))
        act = np.stack((act, act))
        rew = (1, 0)
        done = (0.0, 1.0)

        rb.add(obs=obs, act=act, rew=rew, next_obs=obs, done=done)

        ps = (0.2, 0.4)
        rb.add(obs=obs,
               act=act,
               rew=rew,
               next_obs=obs,
               done=done,
               priorities=ps)

        sample = rb.sample(batch_size)

        w = sample["weights"]
        i = sample["indexes"]

        # One weight and one index per sampled transition, and every index
        # must refer to one of the transitions stored above.
        self.assertEqual(len(w), batch_size)
        self.assertEqual(len(i), batch_size)
        self.assertTrue(np.all(np.asarray(i) < rb.get_stored_size()))

        rb.update_priorities(i, w * w)
Esempio n. 2
0
File: v8.py Progetto: ymd-h/cpprb
    def test_add(self):
        """Add single and batched transitions, then verify storage and clear().

        Exercises ``add`` with and without explicit priorities, checks that
        every added transition is counted, and checks that ``clear`` resets
        both the write index and the stored size to zero.
        """
        buffer_size = 500
        obs_shape = (84, 84, 3)
        act_dim = 10

        # next_of takes the field name directly; ("obs") was just a
        # parenthesized string, not a tuple, so this is behavior-identical.
        rb = PrioritizedReplayBuffer(buffer_size, {
            "obs": {
                "shape": obs_shape
            },
            "act": {
                "shape": act_dim
            },
            "rew": {},
            "done": {}
        },
                                     next_of="obs")

        obs = np.zeros(obs_shape)
        act = np.ones(act_dim)
        rew = 1
        done = 0

        # Single transition without an explicit priority.
        rb.add(obs=obs, act=act, rew=rew, next_obs=obs, done=done)

        ps = 1.5

        # Single transition with an explicit priority larger than any so far.
        rb.add(obs=obs,
               act=act,
               rew=rew,
               next_obs=obs,
               done=done,
               priorities=ps)

        self.assertAlmostEqual(rb.get_max_priority(), 1.5)

        # Batched adds: two transitions stacked along a new leading axis.
        obs = np.stack((obs, obs))
        act = np.stack((act, act))
        rew = (1, 0)
        done = (0.0, 1.0)

        rb.add(obs=obs, act=act, rew=rew, next_obs=obs, done=done)

        ps = (0.2, 0.4)
        rb.add(obs=obs,
               act=act,
               rew=rew,
               next_obs=obs,
               done=done,
               priorities=ps)

        # 1 + 1 + 2 + 2 transitions were added to a buffer with room for 500,
        # so nothing has wrapped yet.
        self.assertEqual(rb.get_stored_size(), 6)
        self.assertEqual(rb.get_next_index(), 6)

        rb.clear()
        self.assertEqual(rb.get_next_index(), 0)
        self.assertEqual(rb.get_stored_size(), 0)