Example #1
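These snippets are smoke tests for GenRL's OffPolicyTrainer across SAC, DDPG, TD3, and the DQN family. They all assume a common set of imports; a minimal sketch, assuming the GenRL package layout (module paths vary between GenRL releases):

import shutil

import gym

from genrl import DDPG, DQN, SAC, TD3
from genrl.deep.common import (
    NormalActionNoise,
    OffPolicyTrainer,
    OrnsteinUhlenbeckActionNoise,
)
from genrl.environments import VectorEnv

This first test trains SAC with a deliberately tiny [1, 1] MLP on a 2-copy vectorized Pendulum-v0 for one epoch, logs to CSV, and removes the log directory afterwards.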
def test_sac():
    env = VectorEnv("Pendulum-v0", 2)
    algo = SAC("mlp", env, layers=[1, 1])

    trainer = OffPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
    shutil.rmtree("./logs")
Example #2
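DDPG on the same vectorized Pendulum-v0, with Gaussian exploration noise (NormalActionNoise).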
def test_ddpg():
    env = VectorEnv("Pendulum-v0", 2)
    algo = DDPG("mlp", env, noise=NormalActionNoise, layers=[1, 1])

    trainer = OffPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
    shutil.rmtree("./logs")
Example #3
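Checks the trainer's off_policy flag, then trains and evaluates DDPG with a deliberately small replay buffer of 100 transitions.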
def test_off_policy_trainer():
    env = VectorEnv("Pendulum-v0", 2)
    algo = DDPG("mlp", env, replay_size=100)
    trainer = OffPolicyTrainer(algo, env, ["stdout"], epochs=1, evaluate_episodes=2)
    assert trainer.off_policy
    trainer.train()
    trainer.evaluate()
Example #4
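SAC again, this time on a single (non-vectorized) gym environment with rendering disabled.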
    def test_sac(self):
        env = gym.make("Pendulum-v0")
        algo = SAC("mlp", env, layers=[1, 1])

        trainer = OffPolicyTrainer(
            algo, env, log_mode=["csv"], logdir="./logs", epochs=1, render=False
        )
        trainer.train()
        shutil.rmtree("./logs")
Example #5
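TD3 with Ornstein-Uhlenbeck exploration noise, followed by a two-episode evaluation pass.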
    def test_td3(self):
        env = gym.make("Pendulum-v0")
        algo = TD3("mlp", env, noise=OrnsteinUhlenbeckActionNoise, layers=[1, 1])

        trainer = OffPolicyTrainer(
            algo, env, log_mode=["csv"], logdir="./logs", epochs=1, evaluate_episodes=2
        )
        trainer.train()
        trainer.evaluate()
        shutil.rmtree("./logs")
Example #6
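DDPG with Gaussian noise on a single gym environment, trained and then evaluated.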
    def test_ddpg(self):
        env = gym.make("Pendulum-v0")
        algo = DDPG("mlp", env, noise=NormalActionNoise, layers=[1, 1])

        trainer = OffPolicyTrainer(
            algo, env, log_mode=["csv"], logdir="./logs", epochs=1, evaluate_episodes=2
        )
        trainer.train()
        trainer.evaluate()
        shutil.rmtree("./logs")
Example #7
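A smoke test for the Atari wrappers: a CNN-based DQN on Pong-v0, capped at 200 steps per epoch.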
    def test_atari_env(self):
        """
        Tests that the Atari wrappers and the AtariEnv function work
        """
        env = VectorEnv("Pong-v0", env_type="atari")
        algo = DQN("cnn", env)

        trainer = OffPolicyTrainer(
            algo, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
        )
        trainer.train()
        shutil.rmtree("./logs")
Example #8
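TD3 with Ornstein-Uhlenbeck noise on a 2-copy vectorized Pendulum-v0.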
def test_td3():
    env = VectorEnv("Pendulum-v0", 2)
    algo = TD3("mlp", env, noise=OrnsteinUhlenbeckActionNoise, layers=[1, 1])

    trainer = OffPolicyTrainer(algo,
                               env,
                               log_mode=["csv"],
                               logdir="./logs",
                               epochs=1)
    trainer.train()
    shutil.rmtree("./logs")
Example #9
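Vanilla DQN with an MLP policy on a 2-copy vectorized CartPole-v0.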
def test_dqn():
    env = VectorEnv("CartPole-v0", 2)
    # DQN
    algo = DQN("mlp", env)

    trainer = OffPolicyTrainer(algo,
                               env,
                               log_mode=["csv"],
                               logdir="./logs",
                               epochs=1)
    trainer.train()
    shutil.rmtree("./logs")
Example #10
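Categorical (distributional) DQN with a CNN policy on vectorized Pong-v0.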
def test_categorical_dqn_cnn():
    env = VectorEnv("Pong-v0", n_envs=2, env_type="atari")

    # Categorical DQN
    algo = DQN("cnn", env, categorical_dqn=True)

    trainer = OffPolicyTrainer(algo,
                               env,
                               log_mode=["csv"],
                               logdir="./logs",
                               epochs=1,
                               steps_per_epoch=200)
    trainer.train()
    shutil.rmtree("./logs")
Example #11
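Double DQN with a prioritized replay buffer and a CNN policy on vectorized Pong-v0.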
def test_double_dqn_cnn():
    env = VectorEnv("Pong-v0", n_envs=2, env_type="atari")

    # Double DQN with prioritized replay buffer
    algo = DQN("cnn", env, double_dqn=True, prioritized_replay=True)

    trainer = OffPolicyTrainer(algo,
                               env,
                               log_mode=["csv"],
                               logdir="./logs",
                               epochs=1,
                               steps_per_epoch=200)
    trainer.train()
    shutil.rmtree("./logs")
Example #12
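Runs the main DQN variants on CartPole-v0: vanilla (with evaluation), double DQN with prioritized replay, noisy DQN, dueling DQN, and categorical DQN.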
    def test_dqn(self):
        env = gym.make("CartPole-v0")
        # DQN
        algo = DQN("mlp", env)

        trainer = OffPolicyTrainer(
            algo, env, log_mode=["csv"], logdir="./logs", epochs=1, evaluate_episodes=2
        )
        trainer.train()
        trainer.evaluate()
        shutil.rmtree("./logs")

        # Double DQN with prioritized replay buffer
        algo1 = DQN("mlp", env, double_dqn=True, prioritized_replay=True)

        trainer = OffPolicyTrainer(
            algo1, env, log_mode=["csv"], logdir="./logs", epochs=1, render=False
        )
        trainer.train()
        shutil.rmtree("./logs")

        # Noisy DQN
        algo2 = DQN("mlp", env, noisy_dqn=True)

        trainer = OffPolicyTrainer(
            algo2, env, log_mode=["csv"], logdir="./logs", epochs=1, render=False
        )
        trainer.train()
        shutil.rmtree("./logs")

        # Dueling DQN
        algo3 = DQN("mlp", env, dueling_dqn=True)

        trainer = OffPolicyTrainer(
            algo3, env, log_mode=["csv"], logdir="./logs", epochs=1, render=False
        )
        trainer.train()
        shutil.rmtree("./logs")

        # Categorical DQN
        algo4 = DQN("mlp", env, categorical_dqn=True)

        trainer = OffPolicyTrainer(
            algo4, env, log_mode=["csv"], logdir="./logs", epochs=1, render=False
        )
        trainer.train()
        shutil.rmtree("./logs")
Example #13
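The same sweep of DQN variants with CNN policies on vectorized Pong-v0, capped at 200 steps per epoch.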
    def test_dqn_cnn(self):
        env = VectorEnv("Pong-v0", n_envs=2, env_type="atari")

        # DQN
        algo = DQN("cnn", env)

        trainer = OffPolicyTrainer(
            algo, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
        )
        trainer.train()
        shutil.rmtree("./logs")

        # Double DQN with prioritized replay buffer
        algo1 = DQN("cnn", env, double_dqn=True, prioritized_replay=True)

        trainer = OffPolicyTrainer(
            algo1, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
        )
        trainer.train()
        shutil.rmtree("./logs")

        # Noisy DQN
        algo2 = DQN("cnn", env, noisy_dqn=True)

        trainer = OffPolicyTrainer(
            algo2, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
        )
        trainer.train()
        shutil.rmtree("./logs")

        # Dueling DDQN
        algo3 = DQN("cnn", env, dueling_dqn=True, double_dqn=True)

        trainer = OffPolicyTrainer(
            algo3, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
        )
        trainer.train()
        shutil.rmtree("./logs")

        # Categorical DQN
        algo4 = DQN("cnn", env, categorical_dqn=True)

        trainer = OffPolicyTrainer(
            algo4, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
        )
        trainer.train()
        shutil.rmtree("./logs")
Example #14
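The DQN variant sweep once more, with CNN policies on a single Breakout-v0 environment.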
    def test_dqn_cnn(self):
        env = gym.make("Breakout-v0")

        # DQN
        algo = DQN("cnn", env)

        trainer = OffPolicyTrainer(
            algo, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
        )
        trainer.train()
        shutil.rmtree("./logs")

        # Double DQN with prioritized replay buffer
        algo1 = DQN("cnn", env, double_dqn=True, prioritized_replay=True)

        trainer = OffPolicyTrainer(
            algo1, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
        )
        trainer.train()
        shutil.rmtree("./logs")

        # Noisy DQN
        algo2 = DQN("cnn", env, noisy_dqn=True)

        trainer = OffPolicyTrainer(
            algo2, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
        )
        trainer.train()
        shutil.rmtree("./logs")

        # Dueling DQN
        algo3 = DQN("cnn", env, dueling_dqn=True)

        trainer = OffPolicyTrainer(
            algo3, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
        )
        trainer.train()
        shutil.rmtree("./logs")

        # Categorical DQN
        algo4 = DQN("cnn", env, categorical_dqn=True)

        trainer = OffPolicyTrainer(
            algo4, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
        )
        trainer.train()
        shutil.rmtree("./logs")
Example #15
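DQN variants on vectorized CartPole-v0; the double-DQN case is left commented out.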
def test_dqn():
    env = VectorEnv("CartPole-v0", 2)
    # DQN
    algo = DQN("mlp", env)

    trainer = OffPolicyTrainer(algo,
                               env,
                               log_mode=["csv"],
                               logdir="./logs",
                               epochs=1)
    trainer.train()
    shutil.rmtree("./logs")

    # Double DQN with prioritized replay buffer
    # algo1 = DQN("mlp", env, double_dqn=True, prioritized_replay=True)

    # trainer = OffPolicyTrainer(algo1, env, log_mode=["csv"], logdir="./logs", epochs=1)
    # trainer.train()
    # shutil.rmtree("./logs")

    # Noisy DQN
    algo2 = DQN("mlp", env, noisy_dqn=True)

    trainer = OffPolicyTrainer(algo2,
                               env,
                               log_mode=["csv"],
                               logdir="./logs",
                               epochs=1)
    trainer.train()
    shutil.rmtree("./logs")

    # Dueling DDQN
    algo3 = DQN("mlp", env, dueling_dqn=True, double_dqn=True)

    trainer = OffPolicyTrainer(algo3,
                               env,
                               log_mode=["csv"],
                               logdir="./logs",
                               epochs=1)
    trainer.train()
    shutil.rmtree("./logs")

    # Categorical DQN
    algo4 = DQN("mlp", env, categorical_dqn=True)

    trainer = OffPolicyTrainer(algo4,
                               env,
                               log_mode=["csv"],
                               logdir="./logs",
                               epochs=1)
    trainer.train()
    shutil.rmtree("./logs")
Example #16
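A minimal off-policy check: TD3 on a vectorized environment, asserting the trainer's off_policy flag before training.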
def test_off_policy_trainer():
    env = VectorEnv("Pendulum-v0", 2)
    algo = TD3("mlp", env)
    trainer = OffPolicyTrainer(algo, env, ["stdout"], epochs=1)
    assert trainer.off_policy
    trainer.train()
Example #17
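The same check against a single gym environment.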
def test_off_policy_trainer():
    env = gym.make("Pendulum-v0")
    algo = TD3("mlp", env)
    trainer = OffPolicyTrainer(algo, env, ["stdout"], epochs=1)
    assert trainer.off_policy
    trainer.train()