Beispiel #1
0
def test_sac_shared():
    """Smoke-test SAC with shared layers through OffPolicyTrainer.

    Builds a SAC agent that uses ``shared_layers`` on a vectorized
    ``Pendulum-v0`` environment, trains it briefly with CSV logging to
    ``./logs``, and removes ``./logs`` afterwards.
    """
    # NOTE(review): the second VectorEnv argument is presumably the number of
    # parallel environments — confirm against VectorEnv's signature.
    env = VectorEnv("Pendulum-v0", 2)
    algo = SAC(
        "mlp",
        env,
        batch_size=5,
        shared_layers=[1, 1],
        policy_layers=[1, 1],
        value_layers=[1, 1],
        max_timesteps=100,
    )

    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        epochs=5,
        max_ep_len=500,
        warmup_steps=10,
        start_update=10,
        max_timesteps=100,
    )
    try:
        trainer.train()
    finally:
        # Always remove the shared ./logs directory, even when training
        # fails, so stale logs never leak into later tests. ignore_errors
        # keeps a missing directory from masking the original exception.
        shutil.rmtree("./logs", ignore_errors=True)
Beispiel #2
0
    def test_atari_env(self):
        """
        Test that the Atari wrappers and the AtariEnv function work.

        Builds a ``Pong-v0`` environment with ``env_type="atari"``, trains a
        CNN DQN for a single short epoch, then removes ``./logs``.
        """
        env = VectorEnv("Pong-v0", env_type="atari")
        algo = DQN("cnn", env, replay_size=100)

        trainer = OffPolicyTrainer(algo, env, epochs=1, max_timesteps=50)
        try:
            trainer.train()
        finally:
            # No logdir is passed above; this assumes the trainer's default
            # log directory is ./logs — TODO confirm. Cleanup is best-effort
            # and runs even on failure so it never masks a training error.
            shutil.rmtree("./logs", ignore_errors=True)
Beispiel #3
0
def test_off_policy_trainer():
    """Verify OffPolicyTrainer reports itself as off-policy, then train and evaluate.

    Uses a DDPG agent on a vectorized ``Pendulum-v0`` environment, with
    ``["stdout"]`` passed as the third positional trainer argument
    (presumably ``log_mode`` — confirm against OffPolicyTrainer's signature).
    """
    vec_env = VectorEnv("Pendulum-v0", 2)
    agent = DDPG("mlp", vec_env, replay_size=100)

    off_policy_trainer = OffPolicyTrainer(
        agent,
        vec_env,
        ["stdout"],
        epochs=2,
        evaluate_episodes=2,
        max_ep_len=300,
        max_timesteps=300,
    )

    # The trainer must identify itself as off-policy before we exercise it.
    assert off_policy_trainer.off_policy
    off_policy_trainer.train()
    off_policy_trainer.evaluate()
Beispiel #4
0
 def test_double_dqn(self):
     """Check DoubleDQN builds an MlpValue model and trains on CartPole-v0.

     Trains briefly with CSV logging to ``./logs`` and removes ``./logs``
     afterwards.
     """
     env = VectorEnv("CartPole-v0")
     algo = DoubleDQN("mlp", env, batch_size=5, replay_size=100)
     assert isinstance(algo.model, MlpValue)
     trainer = OffPolicyTrainer(
         algo,
         env,
         log_mode=["csv"],
         logdir="./logs",
         max_ep_len=200,
         epochs=4,
         warmup_steps=10,
         start_update=10,
     )
     try:
         trainer.train()
     finally:
         # Clean up even on failure so stale logs don't leak into later
         # tests; ignore_errors prevents masking the original exception.
         shutil.rmtree("./logs", ignore_errors=True)
 def test_prioritized_dqn(self):
     """Check PrioritizedReplayDQN builds a CnnValue model on Atari Pong-v0.

     Trains briefly with CSV logging to ``./logs`` and removes ``./logs``
     afterwards.
     """
     env = VectorEnv("Pong-v0", env_type="atari")
     algo = PrioritizedReplayDQN("cnn", env, batch_size=5, replay_size=100)
     assert isinstance(algo.model, CnnValue)
     trainer = OffPolicyTrainer(
         algo,
         env,
         log_mode=["csv"],
         logdir="./logs",
         max_ep_len=200,
         epochs=4,
         warmup_steps=10,
         start_update=10,
     )
     try:
         trainer.train()
     finally:
         # Clean up even on failure so stale logs don't leak into later
         # tests; ignore_errors prevents masking the original exception.
         shutil.rmtree("./logs", ignore_errors=True)
Beispiel #6
0
 def test_noisy_dqn(self):
     """Check NoisyDQN is configured as noisy and trains on CartPole-v0.

     Asserts ``dqn_type == "noisy"``, the ``noisy`` flag, and an
     MlpNoisyValue model, then trains briefly with CSV logging to
     ``./logs`` and removes ``./logs`` afterwards.
     """
     env = VectorEnv("CartPole-v0")
     algo = NoisyDQN("mlp", env, batch_size=5, replay_size=100, value_layers=[1, 1])
     assert algo.dqn_type == "noisy"
     assert algo.noisy
     assert isinstance(algo.model, MlpNoisyValue)
     trainer = OffPolicyTrainer(
         algo,
         env,
         log_mode=["csv"],
         logdir="./logs",
         max_ep_len=200,
         epochs=4,
         warmup_steps=10,
         start_update=10,
     )
     try:
         trainer.train()
     finally:
         # Clean up even on failure so stale logs don't leak into later
         # tests; ignore_errors prevents masking the original exception.
         shutil.rmtree("./logs", ignore_errors=True)
Beispiel #7
0
    def test_atari_env(self):
        """
        Test that the Atari wrappers and the AtariEnv function work.

        Builds a ``Pong-v0`` environment with ``env_type="atari"``, trains a
        small CNN DQN for a few short epochs, then removes ``./logs``.
        """
        env = VectorEnv("Pong-v0", env_type="atari")
        algo = DQN("cnn",
                   env,
                   batch_size=5,
                   replay_size=100,
                   value_layers=[1, 1])

        trainer = OffPolicyTrainer(algo,
                                   env,
                                   epochs=5,
                                   max_ep_len=200,
                                   warmup_steps=10,
                                   start_update=10)
        try:
            trainer.train()
        finally:
            # No logdir is passed above; this assumes the trainer's default
            # log directory is ./logs — TODO confirm. Cleanup is best-effort
            # and runs even on failure so it never masks a training error.
            shutil.rmtree("./logs", ignore_errors=True)
Beispiel #8
0
 def test_dueling_dqn(self):
     """Check DuelingDQN builds a CnnDuelingValue model on Atari Pong-v0.

     Asserts ``dqn_type == "dueling"``, then trains briefly with CSV logging
     to ``./logs`` and removes ``./logs`` afterwards.
     """
     env = VectorEnv("Pong-v0", env_type="atari")
     algo = DuelingDQN(
         "cnn", env, batch_size=5, replay_size=100, value_layers=[1, 1]
     )
     assert algo.dqn_type == "dueling"
     assert isinstance(algo.model, CnnDuelingValue)
     trainer = OffPolicyTrainer(
         algo,
         env,
         log_mode=["csv"],
         logdir="./logs",
         max_ep_len=200,
         epochs=4,
         warmup_steps=10,
         start_update=10,
         max_timesteps=100,
     )
     try:
         trainer.train()
     finally:
         # Clean up even on failure so stale logs don't leak into later
         # tests; ignore_errors prevents masking the original exception.
         shutil.rmtree("./logs", ignore_errors=True)
Beispiel #9
0
 def test_prioritized_dqn(self):
     """Check PrioritizedReplayDQN wiring on CartPole-v0.

     Asserts an MlpValue model and a PrioritizedBuffer replay buffer, then
     trains briefly with CSV logging to ``./logs`` and removes ``./logs``
     afterwards.
     """
     env = VectorEnv("CartPole-v0")
     algo = PrioritizedReplayDQN("mlp",
                                 env,
                                 batch_size=5,
                                 replay_size=100,
                                 value_layers=[1, 1])
     assert isinstance(algo.model, MlpValue)
     assert isinstance(algo.replay_buffer, PrioritizedBuffer)
     trainer = OffPolicyTrainer(
         algo,
         env,
         log_mode=["csv"],
         logdir="./logs",
         max_ep_len=200,
         epochs=4,
         warmup_steps=10,
         start_update=10,
     )
     try:
         trainer.train()
     finally:
         # Clean up even on failure so stale logs don't leak into later
         # tests; ignore_errors prevents masking the original exception.
         shutil.rmtree("./logs", ignore_errors=True)
Beispiel #10
0
def test_td3():
    """Smoke-test TD3 with Ornstein-Uhlenbeck action noise on Pendulum-v0.

    Trains briefly with CSV logging to ``./logs`` and removes ``./logs``
    afterwards.
    """
    # NOTE(review): the second VectorEnv argument is presumably the number of
    # parallel environments — confirm against VectorEnv's signature.
    env = VectorEnv("Pendulum-v0", 2)
    algo = TD3(
        "mlp",
        env,
        batch_size=5,
        noise=OrnsteinUhlenbeckActionNoise,
        policy_layers=[1, 1],
        value_layers=[1, 1],
    )

    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        epochs=5,
        max_ep_len=500,
        warmup_steps=10,
        start_update=10,
    )
    try:
        trainer.train()
    finally:
        # Clean up even on failure so stale logs don't leak into later
        # tests; ignore_errors prevents masking the original exception.
        shutil.rmtree("./logs", ignore_errors=True)
Beispiel #11
0
def test_ddpg():
    """Smoke-test DDPG with normal (Gaussian) action noise on Pendulum-v0.

    Trains briefly with CSV logging to ``./logs`` and removes ``./logs``
    afterwards.
    """
    # NOTE(review): the second VectorEnv argument is presumably the number of
    # parallel environments — confirm against VectorEnv's signature.
    env = VectorEnv("Pendulum-v0", 2)
    algo = DDPG(
        "mlp",
        env,
        batch_size=5,
        noise=NormalActionNoise,
        policy_layers=[1, 1],
        value_layers=[1, 1],
    )

    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        epochs=4,
        max_ep_len=200,
        warmup_steps=10,
        start_update=10,
    )
    try:
        trainer.train()
    finally:
        # Clean up even on failure so stale logs don't leak into later
        # tests; ignore_errors prevents masking the original exception.
        shutil.rmtree("./logs", ignore_errors=True)