def test_eps_greedy(self):
        """Check that a fixed-probability ``EpsilonGreedy`` only emits valid actions.

        Builds a DQN with ``self.create_dqn()``, wraps its action space in an
        ``EpsilonGreedy`` strategy created with ``init_random_prob=0.5`` and
        rolls the environment for 100 steps, asserting that every predicted
        action is contained in ``env.action_space``.
        """
        # Named ``dqn_locals`` rather than ``locals`` so the builtin is not shadowed.
        dqn, dqn_locals = self.create_dqn()
        dqn.init()
        env = dqn_locals['env']
        eps = EpsilonGreedy(action_space=dqn.env_spec.action_space,
                            init_random_prob=0.5)
        st = env.reset()
        for _ in range(100):
            ac = eps.predict(obs=st, sess=self.sess, batch_flag=False, algo=dqn)
            # Validate before stepping so an invalid action fails on this
            # assertion instead of somewhere inside ``env.step``.
            self.assertTrue(env.action_space.contains(ac))
            st_new, _, done, _ = env.step(action=ac)
            # Feed the strategy the latest observation on the next iteration;
            # start a fresh episode once the current one has ended.
            st = env.reset() if done else st_new
    def test_eps_with_scheduler(self):
        """Check that ``EpsilonGreedy`` follows a ``LinearScheduler`` decay.

        The strategy is created with a ``LinearScheduler`` going from
        ``initial_p`` to ``final_p`` over ``schedule_timesteps`` steps, driven
        by a step counter owned by this test. After each prediction the value
        returned by ``eps.parameters('random_prob_func')()`` is compared with
        the expected linear interpolation for the current step.
        """
        initial_p, final_p, schedule_timesteps = 1.0, 0.0, 10

        # Named ``dqn_locals`` rather than ``locals`` so the builtin is not shadowed.
        dqn, dqn_locals = self.create_dqn()
        env = dqn_locals['env']

        # The counter lives in this method instead of a module-level global, so
        # the test always starts at step 0 and neither depends on nor leaks
        # state shared with other tests.
        timestep = 0

        def func():
            # Closure over ``timestep``: sees the value current at call time.
            return timestep

        dqn.init()
        eps = EpsilonGreedy(action_space=dqn.env_spec.action_space,
                            prob_scheduler=LinearScheduler(initial_p=initial_p, t_fn=func,
                                                           schedule_timesteps=schedule_timesteps,
                                                           final_p=final_p),
                            init_random_prob=1.0)
        st = env.reset()
        for _ in range(schedule_timesteps):
            ac = eps.predict(obs=st, sess=self.sess, batch_flag=False, algo=dqn)
            st_new, _, done, _ = env.step(action=ac)
            # Expected linear decay at the current step.
            expected_prob = initial_p - (initial_p - final_p) / schedule_timesteps * timestep
            self.assertAlmostEqual(eps.parameters('random_prob_func')(), expected_prob)
            timestep += 1
            # Feed the strategy the latest observation on the next iteration;
            # start a fresh episode once the current one has ended.
            st = env.reset() if done else st_new