def test_off_policy_trainer():
    """Smoke-test DDPG with OffPolicyTrainer: train then evaluate.

    Uses a tiny replay buffer and short horizons so the test runs fast;
    only checks that the pipeline executes and the trainer reports itself
    as off-policy.
    """
    env = VectorEnv("Pendulum-v0", 2)
    agent = DDPG("mlp", env, replay_size=100)
    trainer = OffPolicyTrainer(
        agent,
        env,
        ["stdout"],
        epochs=2,
        evaluate_episodes=2,
        max_ep_len=300,
        max_timesteps=300,
    )
    # Sanity check before running: this trainer must be the off-policy variant.
    assert trainer.off_policy
    trainer.train()
    trainer.evaluate()
def test_vanilla_dqn(self):
    """Smoke-test vanilla DQN (MLP value net): train, evaluate, clean up logs.

    Small batch/replay sizes and a single-unit value network keep the run
    fast; the CSV logger writes under ./logs, which is always removed.
    """
    env = VectorEnv("CartPole-v0")
    algo = DQN("mlp", env, batch_size=5, replay_size=100, value_layers=[1, 1])
    # The "mlp" network type must produce an MLP-based value model.
    assert isinstance(algo.model, MlpValue)
    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        max_ep_len=200,
        epochs=4,
        warmup_steps=10,
        start_update=10,
    )
    try:
        trainer.train()
        trainer.evaluate()
    finally:
        # Remove the CSV log directory even when train/evaluate raises, so a
        # failing test does not leak ./logs into later tests. ignore_errors
        # keeps the success-path behavior identical (dir exists, removal ok).
        shutil.rmtree("./logs", ignore_errors=True)
def test_double_dqn(self):
    """Smoke-test Double DQN (CNN value net) on Atari: train, evaluate, clean up logs.

    Small batch/replay sizes and a single-unit value network keep the run
    fast; the CSV logger writes under ./logs, which is always removed.
    """
    env = VectorEnv("Pong-v0", env_type="atari")
    algo = DoubleDQN("cnn", env, batch_size=5, replay_size=100, value_layers=[1, 1])
    # The "cnn" network type must produce a CNN-based value model.
    assert isinstance(algo.model, CnnValue)
    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        max_ep_len=200,
        epochs=4,
        warmup_steps=10,
        start_update=10,
    )
    try:
        trainer.train()
        trainer.evaluate()
    finally:
        # Remove the CSV log directory even when train/evaluate raises, so a
        # failing test does not leak ./logs into later tests. ignore_errors
        # keeps the success-path behavior identical (dir exists, removal ok).
        shutil.rmtree("./logs", ignore_errors=True)