示例#1
0
def test_off_policy_trainer():
    """Smoke-test ``OffPolicyTrainer`` driving a DDPG agent.

    Builds a ``VectorEnv`` over ``Pendulum-v0``, constructs the trainer with
    ``epochs=1`` and ``evaluate_episodes=2``, asserts that its ``off_policy``
    attribute is truthy, and then calls ``train()`` followed by
    ``evaluate()``.
    """
    # NOTE(review): the second VectorEnv argument is presumably the number of
    # parallel environments -- confirm against the VectorEnv signature.
    vec_env = VectorEnv("Pendulum-v0", 2)
    agent = DDPG("mlp", vec_env, replay_size=100)

    log_targets = ["stdout"]
    runner = OffPolicyTrainer(
        agent, vec_env, log_targets, epochs=1, evaluate_episodes=2
    )

    assert runner.off_policy
    runner.train()
    runner.evaluate()
示例#2
0
    def test_ddpg(self):
        """Smoke-test a DDPG agent with ``NormalActionNoise`` on Pendulum-v0.

        Trains and evaluates through ``OffPolicyTrainer`` with CSV logging
        directed at ``./logs``, then deletes that directory.
        """
        log_dir = "./logs"
        pendulum = gym.make("Pendulum-v0")
        agent = DDPG("mlp", pendulum, noise=NormalActionNoise, layers=[1, 1])

        trainer_options = dict(
            log_mode=["csv"],
            logdir=log_dir,
            epochs=1,
            evaluate_episodes=2,
        )
        runner = OffPolicyTrainer(agent, pendulum, **trainer_options)
        runner.train()
        runner.evaluate()

        # Remove the log directory the trainer was pointed at.
        shutil.rmtree(log_dir)
示例#3
0
    def test_td3(self):
        """Smoke-test a TD3 agent with ``OrnsteinUhlenbeckActionNoise``.

        Uses Pendulum-v0, runs ``train()`` and ``evaluate()`` through
        ``OffPolicyTrainer`` with CSV logging under ``./logs``, and removes
        that directory afterwards.
        """
        pendulum_env = gym.make("Pendulum-v0")
        td3_agent = TD3(
            "mlp",
            pendulum_env,
            noise=OrnsteinUhlenbeckActionNoise,
            layers=[1, 1],
        )

        td3_trainer = OffPolicyTrainer(
            td3_agent,
            pendulum_env,
            log_mode=["csv"],
            logdir="./logs",
            epochs=1,
            evaluate_episodes=2,
        )
        td3_trainer.train()
        td3_trainer.evaluate()

        # Clean up the CSV log directory passed to the trainer.
        shutil.rmtree("./logs")
示例#4
0
    def test_dqn(self):
        """Smoke-test plain DQN and each DQN variant with ``OffPolicyTrainer``.

        The plain agent is trained and evaluated on CartPole-v0. Each variant
        (double + prioritized replay, noisy, dueling, categorical) is then
        built and trained through its own trainer. The ``./logs`` directory is
        removed after every run.
        """
        env = gym.make("CartPole-v0")

        # DQN
        algo = DQN("mlp", env)

        trainer = OffPolicyTrainer(
            algo, env, log_mode=["csv"], logdir="./logs", epochs=1, evaluate_episodes=2
        )
        trainer.train()
        trainer.evaluate()
        shutil.rmtree("./logs")

        # Constructor flags for each variant, trained one at a time below.
        variant_flags = (
            # Double DQN with prioritized replay buffer
            {"double_dqn": True, "prioritized_replay": True},
            # Noisy DQN
            {"noisy_dqn": True},
            # Dueling DQN
            {"dueling_dqn": True},
            # Categorical DQN
            {"categorical_dqn": True},
        )

        for flags in variant_flags:
            variant_algo = DQN("mlp", env, **flags)

            # Train the variant itself. Passing the plain `algo` here would
            # leave the variant untested and just re-train vanilla DQN.
            trainer = OffPolicyTrainer(
                variant_algo,
                env,
                log_mode=["csv"],
                logdir="./logs",
                epochs=1,
                render=False,
            )
            trainer.train()
            shutil.rmtree("./logs")