def test_noisy_dqn_mlp(self):
    """Smoke-test NoisyDQN with an MLP value network on CartPole-v0.

    Checks that the agent is configured as a noisy DQN backed by
    ``MlpNoisyValue``, then runs a short ``OffPolicyTrainer`` session that
    logs CSV output to ``./logs``.

    Renamed from ``test_noisy_dqn``: a later method in this file reuses that
    name, and a second ``def`` with the same name in a class body replaces
    the first, so this test was never collected.
    """
    env = VectorEnv("CartPole-v0")
    algo = NoisyDQN("mlp", env, batch_size=5, replay_size=100)

    assert algo.dqn_type == "noisy"
    assert algo.noisy
    assert isinstance(algo.model, MlpNoisyValue)

    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        max_ep_len=200,
        epochs=4,
        warmup_steps=10,
        start_update=10,
    )
    try:
        trainer.train()
    finally:
        # Always remove the log directory so a failing run does not leave
        # ./logs behind for later tests; ignore_errors keeps a cleanup
        # problem from masking the original failure.
        shutil.rmtree("./logs", ignore_errors=True)
def test_noisy_dqn(self):
    """Smoke-test NoisyDQN with a CNN value network on the Atari Pong-v0 env.

    Checks that the agent is configured as a noisy DQN backed by
    ``CnnNoisyValue``, then runs a short ``OffPolicyTrainer`` session that
    logs CSV output to ``./logs``. Training is capped via ``max_timesteps``.
    """
    env = VectorEnv("Pong-v0", env_type="atari")
    # value_layers=[1, 1] presumably keeps the value head tiny so the test
    # stays fast -- NOTE(review): confirm intent.
    algo = NoisyDQN(
        "cnn", env, batch_size=5, replay_size=100, value_layers=[1, 1]
    )

    assert algo.dqn_type == "noisy"
    assert algo.noisy
    assert isinstance(algo.model, CnnNoisyValue)

    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        max_ep_len=200,
        epochs=4,
        warmup_steps=10,
        start_update=10,
        max_timesteps=100,
    )
    try:
        trainer.train()
    finally:
        # Always remove the log directory so a failing run does not leave
        # ./logs behind for later tests; ignore_errors keeps a cleanup
        # problem from masking the original failure.
        shutil.rmtree("./logs", ignore_errors=True)