def test_categorical_dqn(self):
    """Smoke-test ``CategoricalDQN`` with the ``"mlp"`` architecture on CartPole-v0.

    Checks that the agent reports ``dqn_type == "categorical"``, has
    ``noisy`` set, and uses an ``MlpCategoricalValue`` model, then runs a
    short ``OffPolicyTrainer`` session with CSV logging into ``./logs``.
    The log directory is removed afterwards, even when training raises.
    """
    env = VectorEnv("CartPole-v0")
    algo = CategoricalDQN("mlp", env, batch_size=5, replay_size=100)

    assert algo.dqn_type == "categorical"
    assert algo.noisy
    assert isinstance(algo.model, MlpCategoricalValue)

    # The trainer is built inside the ``try`` because the log directory may
    # already be created at construction time (presumably by the logger --
    # TODO confirm); either way a failure must not leave ./logs behind.
    try:
        trainer = OffPolicyTrainer(
            algo,
            env,
            log_mode=["csv"],
            logdir="./logs",
            max_ep_len=200,
            epochs=4,
            warmup_steps=10,
            start_update=10,
        )
        trainer.train()
    finally:
        # Best-effort cleanup: ignore_errors keeps a missing directory from
        # masking the real exception raised by the trainer.
        shutil.rmtree("./logs", ignore_errors=True)
def test_categorical_dqn(self):
    """Smoke-test ``CategoricalDQN`` with the ``"cnn"`` architecture on Pong-v0.

    Checks that the agent reports ``dqn_type == "categorical"``, has
    ``noisy`` set, and uses a ``CnnCategoricalValue`` model, then runs a
    short ``OffPolicyTrainer`` session with CSV logging into ``./logs``.
    The log directory is removed afterwards, even when training raises.
    """
    # NOTE(review): another test in this file is also named
    # ``test_categorical_dqn``; if both live in the same class, only the
    # later definition is collected -- confirm they are in separate classes.
    env = VectorEnv("Pong-v0", env_type="atari")
    algo = CategoricalDQN(
        "cnn",
        env,
        batch_size=5,
        replay_size=100,
        value_layers=[1, 1],  # presumably kept tiny so the test stays fast
    )

    assert algo.dqn_type == "categorical"
    assert algo.noisy
    assert isinstance(algo.model, CnnCategoricalValue)

    # The trainer is built inside the ``try`` because the log directory may
    # already be created at construction time (presumably by the logger --
    # TODO confirm); either way a failure must not leave ./logs behind.
    try:
        trainer = OffPolicyTrainer(
            algo,
            env,
            log_mode=["csv"],
            logdir="./logs",
            max_ep_len=200,
            epochs=4,
            warmup_steps=10,
            start_update=10,
            max_timesteps=100,
        )
        trainer.train()
    finally:
        # Best-effort cleanup: ignore_errors keeps a missing directory from
        # masking the real exception raised by the trainer.
        shutil.rmtree("./logs", ignore_errors=True)