Example #1
0
    def test_add(self):
        """Single and batched ``add`` calls advance the index/size counters.

        An ``add`` call that omits required keys must raise ``KeyError`` and
        leave both counters unchanged.
        """
        capacity = 256
        obs_shape = (15, 15)
        act_dim = 5

        env_dict = {
            "obs": {"shape": obs_shape},
            "act": {"shape": act_dim},
            "rew": {},
            "next_obs": {"shape": obs_shape},
            "done": {},
        }
        rb = ReplayBuffer(capacity, env_dict)

        def expect_counts(count):
            # Before the buffer wraps, next index and stored size coincide.
            self.assertEqual(rb.get_next_index(), count)
            self.assertEqual(rb.get_stored_size(), count)

        expect_counts(0)

        first_obs = np.zeros(obs_shape)
        first_act = np.ones(act_dim)
        first_next_obs = first_obs + 1

        rb.add(obs=first_obs, act=first_act, rew=1,
               next_obs=first_next_obs, done=0)
        expect_counts(1)

        # An incomplete transition is rejected and nothing is stored.
        with self.assertRaises(KeyError):
            rb.add(obs=first_obs)
        expect_counts(1)

        # Two transitions added at once, stacked along the leading axis.
        rb.add(obs=np.stack((first_obs, first_obs)),
               act=np.stack((first_act, first_act)),
               rew=(1, 0),
               next_obs=np.stack((first_next_obs, first_next_obs)),
               done=(0.0, 1.0))
        expect_counts(3)
Example #2
0
    def test_next_obs(self):
        """Check that ``next_of="obs"`` rebuilds ``next_obs`` from the next ``obs``.

        Covers index/size bookkeeping, rejection of an incomplete transition,
        sampling from a one-element buffer, and, after the buffer has wrapped,
        that each stored transition's ``next_obs`` equals the ``obs`` of its
        chronological successor.
        """
        buffer_size = 256
        obs_shape = (15, 15)
        act_dim = 5

        rb = ReplayBuffer(buffer_size, {
            "obs": {
                "shape": obs_shape,
                "dtype": np.ubyte
            },
            "act": {
                "shape": act_dim
            },
            "rew": {},
            "done": {}
        }, next_of="obs")

        self.assertEqual(rb.get_next_index(), 0)
        self.assertEqual(rb.get_stored_size(), 0)

        obs = np.zeros(obs_shape, dtype=np.ubyte)
        act = np.ones(act_dim)
        rew = 1
        done = 0

        rb.add(obs=obs, act=act, rew=rew, next_obs=obs, done=done)

        self.assertEqual(rb.get_next_index(), 1)
        self.assertEqual(rb.get_stored_size(), 1)

        with self.assertRaises(KeyError):
            rb.add(obs=obs)

        self.assertEqual(rb.get_next_index(), 1)
        self.assertEqual(rb.get_stored_size(), 1)

        # Only one transition is stored, and its next_obs was all zeros, so
        # every sampled next_obs must be zeros as well.
        np.testing.assert_array_equal(rb.sample(32)["next_obs"], 0)

        # Number of distinct values representable by np.ubyte (256).
        ubyte_range = np.iinfo(np.ubyte).max + 1
        for i in range(512):
            # Wrap explicitly: under NumPy >= 2 (NEP 50) scaling a ubyte array
            # by a Python int >= 256 raises OverflowError instead of upcasting.
            obs = np.full(obs_shape, i % ubyte_range, dtype=np.ubyte)
            # obs + 1 stays ubyte and wraps 255 -> 0, matching the stored dtype.
            rb.add(obs=obs, act=act, rew=rew, next_obs=obs + 1, done=done)

        sample = rb._encode_sample(range(buffer_size))

        # The buffer is full, so the next index is the oldest slot. Rolling by
        # -ith puts the oldest transition in row 0 and the newest in the last
        # row. Each row's next_obs must equal the following row's obs. The
        # newest row has no successor inside the buffer, so the [1:] / [:-1]
        # slicing leaves it out.
        ith = rb.get_next_index()
        np.testing.assert_allclose(
            np.roll(sample["obs"], -ith, axis=0)[1:],
            np.roll(sample["next_obs"], -ith, axis=0)[:-1])
Example #3
0
    def test_buffer(self):
        """Overfilling keeps the latest ``buffer_size`` transitions, and
        ``clear`` empties the buffer and resets its write index."""
        buffer_size = 256
        obs_shape = (15, 15)
        act_dim = 5

        # Deliberately not a multiple of buffer_size. The next index is then
        # non-zero before clear(), so the post-clear index check is meaningful.
        N = 2 * buffer_size + buffer_size // 2

        erb = ReplayBuffer(buffer_size, {"obs": {"shape": obs_shape},
                                         "act": {"shape": act_dim},
                                         "rew": {},
                                         "next_obs": {"shape": obs_shape},
                                         "done": {}})

        for i in range(N):
            obs = np.full(obs_shape, i, dtype=np.double)
            act = np.full(act_dim, i, dtype=np.double)
            rew = i
            next_obs = obs + 1
            done = 0

            erb.add(obs=obs, act=act, rew=rew, next_obs=next_obs, done=done)

        self.assertEqual(erb.get_next_index(), N % buffer_size)
        self.assertEqual(erb.get_stored_size(), buffer_size)

        es = erb._encode_sample(range(buffer_size))
        # Only the most recent buffer_size transitions survive the wrap-around.
        np.testing.assert_allclose(np.sort(es["obs"][:, 0, 0]),
                                   np.arange(N - buffer_size, N))
        np.testing.assert_allclose(es["next_obs"], es["obs"] + 1)

        self.assertEqual(len(erb.sample(32)["obs"]), 32)

        erb.clear()

        self.assertEqual(erb.get_next_index(), 0)
        self.assertEqual(erb.get_stored_size(), 0)