def test_prioritized_dqn(self):
    """Smoke-test PrioritizedReplayDQN with the "cnn" network on Pong-v0.

    Builds the agent on an Atari ``VectorEnv``, checks that the model it
    created is a ``CnnValue``, then runs a short ``OffPolicyTrainer``
    session (4 epochs, episodes capped at 200 steps) with CSV logging.
    The log directory is removed afterwards even if training raises.
    """
    # NOTE(review): another definition named ``test_prioritized_dqn`` is
    # visible further down in this source. If both live in the same
    # class/module, the later one replaces this one and this test is never
    # collected -- confirm and give the two variants distinct names.
    logdir = "./logs"

    env = VectorEnv("Pong-v0", env_type="atari")
    algo = PrioritizedReplayDQN("cnn", env, batch_size=5, replay_size=100)
    assert isinstance(algo.model, CnnValue)

    try:
        trainer = OffPolicyTrainer(
            algo,
            env,
            log_mode=["csv"],
            logdir=logdir,
            max_ep_len=200,
            epochs=4,
            warmup_steps=10,
            start_update=10,
        )
        trainer.train()
    finally:
        # Always clean up so a failed run does not leave log files behind
        # for subsequent tests; ignore_errors keeps a missing directory from
        # masking the original exception.
        shutil.rmtree(logdir, ignore_errors=True)
def test_prioritized_dqn(self):
    """Smoke-test PrioritizedReplayDQN with the "mlp" network on CartPole-v0.

    Builds the agent on a ``VectorEnv``, checks that the model it created
    is an ``MlpValue`` and that its replay buffer is a
    ``PrioritizedBuffer``, then runs a short ``OffPolicyTrainer`` session
    (4 epochs, episodes capped at 200 steps) with CSV logging. The log
    directory is removed afterwards even if training raises.
    """
    # NOTE(review): an earlier definition named ``test_prioritized_dqn`` is
    # visible above in this source. If both live in the same class/module,
    # this definition replaces the earlier one, which is then never
    # collected -- confirm and give the two variants distinct names.
    logdir = "./logs"

    env = VectorEnv("CartPole-v0")
    algo = PrioritizedReplayDQN("mlp", env, batch_size=5, replay_size=100)
    assert isinstance(algo.model, MlpValue)
    assert isinstance(algo.replay_buffer, PrioritizedBuffer)

    try:
        trainer = OffPolicyTrainer(
            algo,
            env,
            log_mode=["csv"],
            logdir=logdir,
            max_ep_len=200,
            epochs=4,
            warmup_steps=10,
            start_update=10,
        )
        trainer.train()
    finally:
        # Always clean up so a failed run does not leave log files behind
        # for subsequent tests; ignore_errors keeps a missing directory from
        # masking the original exception.
        shutil.rmtree(logdir, ignore_errors=True)