Example #1
    def test_evaluation_option(self):
        # Use a custom callback that asserts that we run exactly the
        # configured number of episodes per evaluation.
        config = (dqn.DQNConfig().environment(env="CartPole-v0").evaluation(
            evaluation_interval=2,
            evaluation_duration=2,
            evaluation_duration_unit="episodes",
            evaluation_config={
                "gamma": 0.98,
            },
        ).callbacks(callbacks_class=AssertEvalCallback))

        for _ in framework_iterator(config, frameworks=("tf", "torch")):
            trainer = config.build()
            # Given evaluation_interval=2, r0 and r2 should not contain
            # evaluation metrics, while r1 and r3 should.
            r0 = trainer.train()
            print(r0)
            r1 = trainer.train()
            print(r1)
            r2 = trainer.train()
            print(r2)
            r3 = trainer.train()
            print(r3)
            trainer.stop()

            self.assertFalse("evaluation" in r0)
            self.assertTrue("evaluation" in r1)
            self.assertFalse("evaluation" in r2)
            self.assertTrue("evaluation" in r3)
            self.assertTrue("episode_reward_mean" in r1["evaluation"])
            self.assertNotEqual(r1["evaluation"], r3["evaluation"])
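
The same setup works outside the test harness. A minimal sketch (assuming ray[rllib] is installed, with an RLlib version whose DQNConfig.evaluation() accepts the same arguments as the test above):

    import ray.rllib.algorithms.dqn as dqn

    config = (
        dqn.DQNConfig()
        .environment(env="CartPole-v0")
        .evaluation(
            evaluation_interval=2,  # evaluate on every 2nd train() call
            evaluation_duration=2,  # run exactly 2 episodes per eval
            evaluation_duration_unit="episodes",
        )
    )
    algo = config.build()
    for i in range(4):
        result = algo.train()
        # The "evaluation" key only appears on iterations where eval ran.
        if "evaluation" in result:
            print(i, result["evaluation"]["episode_reward_mean"])
    algo.stop()
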
Example #2
    def test_evaluation_option_always_attach_eval_metrics(self):
        # Use a custom callback that asserts that we run exactly the
        # configured number of episodes per evaluation.
        config = (dqn.DQNConfig().environment(env="CartPole-v0").evaluation(
            evaluation_interval=2,
            evaluation_duration=2,
            evaluation_duration_unit="episodes",
            evaluation_config={
                "gamma": 0.98,
            },
            always_attach_evaluation_results=True,
        ).callbacks(callbacks_class=AssertEvalCallback))
        for _ in framework_iterator(config, frameworks=("tf", "torch")):
            trainer = config.build()
            # Should always see latest available eval results.
            r0 = trainer.train()
            r1 = trainer.train()
            r2 = trainer.train()
            r3 = trainer.train()
            trainer.stop()

            # No eval has run yet at r0, but with
            # always_attach_evaluation_results=True the key is still
            # attached. Likewise, r2 carries the latest results (from r1)
            # even though no eval ran during that iteration.
            self.assertTrue("evaluation" in r0)
            self.assertTrue("evaluation" in r1)
            self.assertTrue("evaluation" in r2)
            self.assertTrue("evaluation" in r3)
Example #3
File: registry.py  Project: parasj/ray
def _import_dqn():
    import ray.rllib.algorithms.dqn as dqn

    return dqn.DQN, dqn.DQNConfig().to_dict()
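
Functions like _import_dqn let the registry defer the heavy rllib import until an algorithm is actually requested by name. A hypothetical sketch of that lazy-import pattern (the ALGORITHMS mapping and get_algorithm_class names are illustrative, not necessarily the exact contents of registry.py):

    # Hypothetical registry mapping algorithm names to lazy importers.
    ALGORITHMS = {"DQN": _import_dqn}

    def get_algorithm_class(name: str):
        # The importer runs (importing dqn) only when the algorithm
        # is actually requested.
        cls, default_config = ALGORITHMS[name]()
        return cls

    DQN = get_algorithm_class("DQN")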