Example #1
    def test_too_big_max_steps(self):
        """
        Raise ValueError for too large max_steps
        """
        rb = ReplayBuffer(
            32, {
                "obs": {
                    "shape": (3, )
                },
                "act": {},
                "rew": {},
                "next_obs": {
                    "shape": (3, )
                },
                "done": {}
            })

        def update(kw, step, episode):
            raise RuntimeError

        with self.assertRaises(ValueError):
            train(rb,
                  self.env,
                  lambda obs, step, episode, is_warmup: 1.0,
                  update,
                  max_steps=int(1e+32))
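
All of the examples here use self.env, which is set up elsewhere in the test case. As a rough sketch (the DummyEnv class and the setUp body below are assumptions, not taken from the source), a minimal gym-style environment with observations of shape (3,) could look like this:

import numpy as np

class DummyEnv:
    """Minimal gym-style stub whose observations have shape (3,)."""

    def reset(self):
        # initial observation with shape (3,)
        return np.zeros(3, dtype=np.float32)

    def step(self, action):
        # return (next_obs, rew, done, info), as unpacked in test_after_step
        next_obs = np.zeros(3, dtype=np.float32)
        return next_obs, 1.0, False, {}

# in the test case's setUp():
#     self.env = DummyEnv()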
Example #2
    def test_update_count(self):
        """
        Check step and episode

        step < max_steps
        episode <= step
        """
        rb = ReplayBuffer(
            32, {
                "obs": {
                    "shape": (3, )
                },
                "act": {},
                "rew": {},
                "next_obs": {
                    "shape": (3, )
                },
                "done": {}
            })

        def update(kw, step, episode):
            self.assertLess(step, 10)
            self.assertLessEqual(episode, step)
            return 0.5

        train(rb,
              self.env,
              lambda obs, step, episode, is_warmup: 1.0,
              update,
              max_steps=10)
Example #3
    def test_episode_callback(self):
        """
        Pass custom episode_callback
        """
        rb = ReplayBuffer(
            32, {
                "obs": {
                    "shape": (3, )
                },
                "act": {},
                "rew": {},
                "next_obs": {
                    "shape": (3, )
                },
                "done": {}
            })

        def callback(episode, episode_step, episode_reward):
            self.assertEqual(episode_step, int(episode_reward))

        train(rb,
              self.env,
              lambda obs, step, episode, is_warmup: 1.0,
              lambda tr, step, episode: 0.5,
              max_steps=10,
              episode_callback=callback,
              rew_sum=lambda sum, tr: sum + 1.0,
              done_check=lambda tr: True)
Example #4
    def test_done_check(self):
        """
        Pass custom done_check which always returns `True`

        Always step == episode
        """
        rb = ReplayBuffer(
            32, {
                "obs": {
                    "shape": (3, )
                },
                "act": {},
                "rew": {},
                "next_obs": {
                    "shape": (3, )
                },
                "done": {}
            })

        def update(kw, step, episode):
            self.assertLess(step, 10)
            self.assertEqual(step, episode)
            return 0.5

        train(rb,
              self.env,
              lambda obs, step, episode, is_warmup: 1.0,
              update,
              max_steps=10,
              done_check=lambda kw: True)
Example #5
    def test_warmup(self):
        """
        Skip warmup steps

        n_warmups <= step
        """
        rb = ReplayBuffer(
            32, {
                "obs": {
                    "shape": (3, )
                },
                "act": {},
                "rew": {},
                "next_obs": {
                    "shape": (3, )
                },
                "done": {}
            })

        def update(kw, step, episode):
            self.assertGreaterEqual(step, 5)
            self.assertLess(step, 10)
            self.assertLessEqual(episode, step)
            return 0.5

        train(rb,
              self.env,
              lambda obs, step, episode, is_warmup: 1.0,
              update,
              max_steps=10,
              n_warmups=5)
Example #6
    def test_per_train(self):
        """
        Run train function with PER
        """
        rb = PrioritizedReplayBuffer(
            32, {
                "obs": {
                    "shape": (3, )
                },
                "act": {},
                "rew": {},
                "next_obs": {
                    "shape": (3, )
                },
                "done": {}
            })
        train(rb,
              self.env,
              lambda obs, step, episode, is_warmup: 1.0,
              lambda kwargs, step, episode: 0.5,
              max_steps=10)
Example #7
    def test_default_train(self):
        """
        Run train function with default arguments
        """
        rb = ReplayBuffer(
            32, {
                "obs": {
                    "shape": (3, )
                },
                "act": {},
                "rew": {},
                "next_obs": {
                    "shape": (3, )
                },
                "done": {}
            })
        train(rb,
              self.env,
              lambda obs, step, episode, is_warmup: 1.0,
              lambda kwargs, step, episode: 0.5,
              max_steps=10)
Example #8
    def test_after_step(self):
        """
        Pass custom after_step
        """
        rb = ReplayBuffer(
            32, {
                "obs": {
                    "shape": (3, )
                },
                "act": {},
                "rew": {},
                "next_obs": {
                    "shape": (3, )
                },
                "done": {}
            })

        def after_step(obs, act, step_returns, step, episode):
            next_obs, rew, done, info = step_returns
            self.assertEqual(obs.shape, next_obs.shape)
            return {
                "obs": obs,
                "act": act,
                "next_obs": next_obs,
                "rew": rew,
                "done": done
            }

        def update(kw, step, episode):
            self.assertLess(step, 10)
            return 0.5

        train(rb,
              self.env,
              lambda obs, step, episode, is_warmup: 1.0,
              update,
              max_steps=10,
              after_step=after_step)
Example #9
    def test_per_without_TD(self):
        """
        Run train function with PER without TD error

        Raise TypeError
        """
        rb = PrioritizedReplayBuffer(
            32, {
                "obs": {
                    "shape": (3, )
                },
                "act": {},
                "rew": {},
                "next_obs": {
                    "shape": (3, )
                },
                "done": {}
            })
        with self.assertRaises(TypeError):
            train(rb,
                  self.env,
                  lambda obs, step, episode, is_warmup: 1.0,
                  lambda kwargs, step, episode: None,
                  max_steps=10)