def test_dqn():
    env = VectorEnv("CartPole-v0", 2)

    # DQN
    algo = DQN("mlp", env)
    trainer = OffPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
    shutil.rmtree("./logs")

    # Double DQN with prioritized replay buffer
    # algo1 = DQN("mlp", env, double_dqn=True, prioritized_replay=True)
    # trainer = OffPolicyTrainer(algo1, env, log_mode=["csv"], logdir="./logs", epochs=1)
    # trainer.train()
    # shutil.rmtree("./logs")

    # Noisy DQN
    algo2 = DQN("mlp", env, noisy_dqn=True)
    trainer = OffPolicyTrainer(algo2, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
    shutil.rmtree("./logs")

    # Dueling DDQN
    algo3 = DQN("mlp", env, dueling_dqn=True, double_dqn=True)
    trainer = OffPolicyTrainer(algo3, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
    shutil.rmtree("./logs")

    # Categorical DQN
    algo4 = DQN("mlp", env, categorical_dqn=True)
    trainer = OffPolicyTrainer(algo4, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
    shutil.rmtree("./logs")

def test_dqn(self):
    env = gym.make("CartPole-v0")

    # DQN
    algo = DQN("mlp", env)
    trainer = OffPolicyTrainer(
        algo, env, log_mode=["csv"], logdir="./logs", epochs=1, evaluate_episodes=2
    )
    trainer.train()
    trainer.evaluate()
    shutil.rmtree("./logs")

    # Double DQN with prioritized replay buffer
    algo1 = DQN("mlp", env, double_dqn=True, prioritized_replay=True)
    trainer = OffPolicyTrainer(
        algo1, env, log_mode=["csv"], logdir="./logs", epochs=1, render=False
    )
    trainer.train()
    shutil.rmtree("./logs")

    # Noisy DQN
    algo2 = DQN("mlp", env, noisy_dqn=True)
    trainer = OffPolicyTrainer(
        algo2, env, log_mode=["csv"], logdir="./logs", epochs=1, render=False
    )
    trainer.train()
    shutil.rmtree("./logs")

    # Dueling DQN
    algo3 = DQN("mlp", env, dueling_dqn=True)
    trainer = OffPolicyTrainer(
        algo3, env, log_mode=["csv"], logdir="./logs", epochs=1, render=False
    )
    trainer.train()
    shutil.rmtree("./logs")

    # Categorical DQN
    algo4 = DQN("mlp", env, categorical_dqn=True)
    trainer = OffPolicyTrainer(
        algo4, env, log_mode=["csv"], logdir="./logs", epochs=1, render=False
    )
    trainer.train()
    shutil.rmtree("./logs")

def test_dqn_cnn(self):
    env = gym.make("Breakout-v0")

    # DQN
    algo = DQN("cnn", env)
    trainer = OffPolicyTrainer(
        algo, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
    )
    trainer.train()
    shutil.rmtree("./logs")

    # Double DQN with prioritized replay buffer
    algo1 = DQN("cnn", env, double_dqn=True, prioritized_replay=True)
    trainer = OffPolicyTrainer(
        algo1, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
    )
    trainer.train()
    shutil.rmtree("./logs")

    # Noisy DQN
    algo2 = DQN("cnn", env, noisy_dqn=True)
    trainer = OffPolicyTrainer(
        algo2, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
    )
    trainer.train()
    shutil.rmtree("./logs")

    # Dueling DQN
    algo3 = DQN("cnn", env, dueling_dqn=True)
    trainer = OffPolicyTrainer(
        algo3, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
    )
    trainer.train()
    shutil.rmtree("./logs")

    # Categorical DQN
    algo4 = DQN("cnn", env, categorical_dqn=True)
    trainer = OffPolicyTrainer(
        algo4, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
    )
    trainer.train()
    shutil.rmtree("./logs")

def test_dqn_cnn(self):
    env = VectorEnv("Pong-v0", n_envs=2, env_type="atari")

    # DQN
    algo = DQN("cnn", env)
    trainer = OffPolicyTrainer(
        algo, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
    )
    trainer.train()
    shutil.rmtree("./logs")

    # Double DQN with prioritized replay buffer
    algo1 = DQN("cnn", env, double_dqn=True, prioritized_replay=True)
    trainer = OffPolicyTrainer(
        algo1, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
    )
    trainer.train()
    shutil.rmtree("./logs")

    # Noisy DQN
    algo2 = DQN("cnn", env, noisy_dqn=True)
    trainer = OffPolicyTrainer(
        algo2, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
    )
    trainer.train()
    shutil.rmtree("./logs")

    # Dueling DDQN
    algo3 = DQN("cnn", env, dueling_dqn=True, double_dqn=True)
    trainer = OffPolicyTrainer(
        algo3, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
    )
    trainer.train()
    shutil.rmtree("./logs")

    # Categorical DQN
    algo4 = DQN("cnn", env, categorical_dqn=True)
    trainer = OffPolicyTrainer(
        algo4, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
    )
    trainer.train()
    shutil.rmtree("./logs")

def test_atari_env(self):
    """
    Tests working of Atari Wrappers and the AtariEnv function
    """
    env = VectorEnv("Pong-v0", env_type="atari")
    algo = DQN("cnn", env)
    trainer = OffPolicyTrainer(algo, env, epochs=1, steps_per_epoch=200)
    trainer.train()
    shutil.rmtree("./logs")

def test_dqn():
    env = VectorEnv("CartPole-v0", 2)

    # DQN
    algo = DQN("mlp", env)
    trainer = OffPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
    shutil.rmtree("./logs")

def test_categorical_dqn_cnn():
    env = VectorEnv("Pong-v0", n_envs=2, env_type="atari")

    # Categorical DQN
    algo = DQN("cnn", env, categorical_dqn=True)
    trainer = OffPolicyTrainer(
        algo, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
    )
    trainer.train()
    shutil.rmtree("./logs")

def test_double_dqn_cnn():
    env = VectorEnv("Pong-v0", n_envs=2, env_type="atari")

    # Double DQN with prioritized replay buffer
    algo = DQN("cnn", env, double_dqn=True, prioritized_replay=True)
    trainer = OffPolicyTrainer(
        algo, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
    )
    trainer.train()
    shutil.rmtree("./logs")

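# The split-out CNN tests above only cover the categorical and double DQN
# variants. A hypothetical counterpart for the noisy and dueling variants,
# mirroring the combined CNN test and assuming the same VectorEnv / DQN /
# OffPolicyTrainer API used throughout this file, would be a minimal sketch
# like the following (function name is illustrative):
def test_noisy_dueling_dqn_cnn():
    env = VectorEnv("Pong-v0", n_envs=2, env_type="atari")

    # Noisy DQN
    algo = DQN("cnn", env, noisy_dqn=True)
    trainer = OffPolicyTrainer(
        algo, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
    )
    trainer.train()
    shutil.rmtree("./logs")

    # Dueling DDQN
    algo1 = DQN("cnn", env, dueling_dqn=True, double_dqn=True)
    trainer = OffPolicyTrainer(
        algo1, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
    )
    trainer.train()
    shutil.rmtree("./logs")
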
argument_parser.add_argument("-b", "--batch-size", type=int, default=64)
argument_parser.add_argument("-l", "--length", type=int, default=None)
argument_parser.add_argument("--enable-cuda", action="store_true")
args = argument_parser.parse_args()

# Select the device, falling back to CPU (with a warning) if CUDA is unavailable
if args.enable_cuda:
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
        warnings.warn("CUDA is not available. Defaulting to CPU")
else:
    device = "cpu"

# Build the Super Mario Bros environment with a restricted action set
env = gym_super_mario_bros.make("SuperMarioBros-v0")
env = JoypadSpace(env, SIMPLE_MOVEMENT)
env = MarioEnv(env)

agent = DQN("cnn", env, replay_size=100000, epsilon_decay=100000)
trainer = AdversariaTrainer(
    agent=agent,
    env=env,
    dataset=args.input_path,
    possible_actions=SIMPLE_MOVEMENT,
    device=device,
    length=args.length,
    off_policy=True,
    evaluate_episodes=1,
)
trainer.train(epochs=args.epochs, lr=args.lr, batch_size=args.batch_size)
trainer.evaluate(render=True)