コード例 #1
0
    def test_atari_env(self):
        """
        Tests working of Atari Wrappers and the AtariEnv function
        """
        env = VectorEnv("Pong-v0", env_type="atari")
        algo = DQN("cnn", env, replay_size=100)

        trainer = OffPolicyTrainer(algo, env, epochs=1, max_timesteps=50)
        trainer.train()
        shutil.rmtree("./logs")
コード例 #2
0
ファイル: test_dqn.py プロジェクト: threewisemonkeys-as/genrl
 def test_vanilla_dqn(self):
     env = VectorEnv("CartPole-v0")
     algo = DQN("mlp", env, batch_size=5, replay_size=100)
     assert isinstance(algo.model, MlpValue)
     trainer = OffPolicyTrainer(
         algo,
         env,
         log_mode=["csv"],
         logdir="./logs",
         max_ep_len=200,
         epochs=4,
         warmup_steps=10,
         start_update=10,
     )
     trainer.train()
     shutil.rmtree("./logs")
コード例 #3
0
ファイル: test_dqn_cnn.py プロジェクト: Sharad24/genrl
 def test_vanilla_dqn(self):
     env = VectorEnv("Pong-v0", env_type="atari")
     algo = DQN("cnn", env, batch_size=5, replay_size=100, value_layers=[1, 1])
     assert isinstance(algo.model, CnnValue)
     trainer = OffPolicyTrainer(
         algo,
         env,
         log_mode=["csv"],
         logdir="./logs",
         max_ep_len=200,
         epochs=4,
         warmup_steps=10,
         start_update=10,
         max_timesteps=100,
     )
     trainer.train()
     shutil.rmtree("./logs")
コード例 #4
0
    def test_atari_env(self):
        """
        Tests working of Atari Wrappers and the AtariEnv function
        """
        env = VectorEnv("Pong-v0", env_type="atari")
        algo = DQN("cnn",
                   env,
                   batch_size=5,
                   replay_size=100,
                   value_layers=[1, 1])

        trainer = OffPolicyTrainer(algo,
                                   env,
                                   epochs=5,
                                   max_ep_len=200,
                                   warmup_steps=10,
                                   start_update=10)
        trainer.train()
        shutil.rmtree("./logs")