import shutil

import gym

# Import paths below assume the genrl package layout these tests were written
# against; adjust them if your version of genrl organizes modules differently.
from genrl import DDPG, DQN, TD3
from genrl.deep.common import (
    NormalActionNoise,
    OffPolicyTrainer,
    OrnsteinUhlenbeckActionNoise,
)
from genrl.environments import VectorEnv


def test_off_policy_trainer():
    env = VectorEnv("Pendulum-v0", 2)
    algo = DDPG("mlp", env, replay_size=100)
    trainer = OffPolicyTrainer(algo, env, ["stdout"], epochs=1, evaluate_episodes=2)
    assert trainer.off_policy
    trainer.train()
    trainer.evaluate()

def test_ddpg(self):
    env = gym.make("Pendulum-v0")
    algo = DDPG("mlp", env, noise=NormalActionNoise, layers=[1, 1])
    trainer = OffPolicyTrainer(
        algo, env, log_mode=["csv"], logdir="./logs", epochs=1, evaluate_episodes=2
    )
    trainer.train()
    trainer.evaluate()
    shutil.rmtree("./logs")

def test_td3(self):
    env = gym.make("Pendulum-v0")
    algo = TD3("mlp", env, noise=OrnsteinUhlenbeckActionNoise, layers=[1, 1])
    trainer = OffPolicyTrainer(
        algo, env, log_mode=["csv"], logdir="./logs", epochs=1, evaluate_episodes=2
    )
    trainer.train()
    trainer.evaluate()
    shutil.rmtree("./logs")

def test_dqn(self):
    env = gym.make("CartPole-v0")

    # Vanilla DQN
    algo = DQN("mlp", env)
    trainer = OffPolicyTrainer(
        algo, env, log_mode=["csv"], logdir="./logs", epochs=1, evaluate_episodes=2
    )
    trainer.train()
    trainer.evaluate()
    shutil.rmtree("./logs")

    # Double DQN with prioritized replay buffer
    # (the original passed the plain `algo` to every trainer below, so the
    # variants were constructed but never trained; each trainer now receives
    # its own agent)
    algo1 = DQN("mlp", env, double_dqn=True, prioritized_replay=True)
    trainer = OffPolicyTrainer(
        algo1, env, log_mode=["csv"], logdir="./logs", epochs=1, render=False
    )
    trainer.train()
    shutil.rmtree("./logs")

    # Noisy DQN
    algo2 = DQN("mlp", env, noisy_dqn=True)
    trainer = OffPolicyTrainer(
        algo2, env, log_mode=["csv"], logdir="./logs", epochs=1, render=False
    )
    trainer.train()
    shutil.rmtree("./logs")

    # Dueling DQN
    algo3 = DQN("mlp", env, dueling_dqn=True)
    trainer = OffPolicyTrainer(
        algo3, env, log_mode=["csv"], logdir="./logs", epochs=1, render=False
    )
    trainer.train()
    shutil.rmtree("./logs")

    # Categorical DQN
    algo4 = DQN("mlp", env, categorical_dqn=True)
    trainer = OffPolicyTrainer(
        algo4, env, log_mode=["csv"], logdir="./logs", epochs=1, render=False
    )
    trainer.train()
    shutil.rmtree("./logs")
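
# A hedged extra check, not in the original suite: it combines the double_dqn,
# dueling_dqn, and prioritized_replay flags that test_dqn exercises one at a
# time, on the assumption that DQN accepts them together (double_dqn and
# prioritized_replay already co-occur above; adding dueling_dqn is untested).
def test_dqn_variant_combination(self):
    env = gym.make("CartPole-v0")

    # Dueling Double DQN with a prioritized replay buffer
    algo = DQN(
        "mlp", env, double_dqn=True, dueling_dqn=True, prioritized_replay=True
    )
    trainer = OffPolicyTrainer(
        algo, env, log_mode=["csv"], logdir="./logs", epochs=1, render=False
    )
    trainer.train()
    shutil.rmtree("./logs")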