def test_dueling_dqn(self):
    """Smoke-test DuelingDQN with the "mlp" network on CartPole-v0.

    Checks that the algorithm reports ``dqn_type == "dueling"`` and that
    its model is an ``MlpDuelingValue``, then runs a short
    ``OffPolicyTrainer`` session that logs CSV output to ``./logs``.
    """
    # NOTE(review): another test in this file is also named
    # ``test_dueling_dqn``; if both live in the same class the later
    # definition shadows the earlier one -- confirm they are in
    # separate classes.
    env = VectorEnv("CartPole-v0")
    algo = DuelingDQN("mlp", env, batch_size=5, replay_size=100)

    assert algo.dqn_type == "dueling"
    assert isinstance(algo.model, MlpDuelingValue)

    # The log directory is handed to the trainer, so everything from here
    # on is guarded: the directory is removed even if training raises.
    try:
        trainer = OffPolicyTrainer(
            algo,
            env,
            log_mode=["csv"],
            logdir="./logs",
            max_ep_len=200,
            epochs=4,
            warmup_steps=10,
            start_update=10,
        )
        trainer.train()
    finally:
        # ignore_errors: a missing directory must not mask the real failure.
        shutil.rmtree("./logs", ignore_errors=True)
def test_dueling_dqn(self):
    """Smoke-test DuelingDQN with the "cnn" network on Atari Pong-v0.

    Checks that the algorithm reports ``dqn_type == "dueling"`` and that
    its model is a ``CnnDuelingValue``, then runs a short
    ``OffPolicyTrainer`` session (capped with ``max_timesteps=100``) that
    logs CSV output to ``./logs``.
    """
    # NOTE(review): another test in this file is also named
    # ``test_dueling_dqn``; if both live in the same class the later
    # definition shadows the earlier one -- confirm they are in
    # separate classes.
    env = VectorEnv("Pong-v0", env_type="atari")
    algo = DuelingDQN(
        "cnn",
        env,
        batch_size=5,
        replay_size=100,
        value_layers=[1, 1],
    )

    assert algo.dqn_type == "dueling"
    assert isinstance(algo.model, CnnDuelingValue)

    # The log directory is handed to the trainer, so everything from here
    # on is guarded: the directory is removed even if training raises.
    try:
        trainer = OffPolicyTrainer(
            algo,
            env,
            log_mode=["csv"],
            logdir="./logs",
            max_ep_len=200,
            epochs=4,
            warmup_steps=10,
            start_update=10,
            max_timesteps=100,
        )
        trainer.train()
    finally:
        # ignore_errors: a missing directory must not mask the real failure.
        shutil.rmtree("./logs", ignore_errors=True)