Ejemplo n.º 1
0
def test_dqn():
    """Smoke-test several MLP DQN variants on a CartPole ``VectorEnv``.

    Every variant is trained for one epoch with CSV logging into ``./logs``;
    that directory is deleted after each run so the next variant starts clean.
    """
    env = VectorEnv("CartPole-v0", 2)

    # Extra keyword arguments handed to DQN, one entry per variant, in run order.
    # Double DQN with a prioritized replay buffer
    # (double_dqn=True, prioritized_replay=True) is currently disabled here.
    variants = (
        {},                                          # DQN
        {"noisy_dqn": True},                         # Noisy DQN
        {"dueling_dqn": True, "double_dqn": True},   # Dueling DDQN
        {"categorical_dqn": True},                   # Categorical DQN
    )

    for extra_kwargs in variants:
        agent = DQN("mlp", env, **extra_kwargs)
        trainer = OffPolicyTrainer(
            agent, env, log_mode=["csv"], logdir="./logs", epochs=1
        )
        trainer.train()
        shutil.rmtree("./logs")
Ejemplo n.º 2
0
    def test_dqn(self):
        """Smoke-test vanilla DQN and its variants on CartPole.

        The vanilla agent is trained and then evaluated for two episodes. Each
        variant is then trained for one epoch with its *own* agent. CSV logs go
        to ``./logs``, which is removed after every run.
        """
        env = gym.make("CartPole-v0")

        # DQN (also exercises evaluation)
        algo = DQN("mlp", env)
        trainer = OffPolicyTrainer(
            algo, env, log_mode=["csv"], logdir="./logs", epochs=1, evaluate_episodes=2
        )
        trainer.train()
        trainer.evaluate()
        shutil.rmtree("./logs")

        # Extra keyword arguments handed to DQN, one entry per variant.
        variants = (
            {"double_dqn": True, "prioritized_replay": True},  # Double DQN + PER
            {"noisy_dqn": True},  # Noisy DQN
            {"dueling_dqn": True},  # Dueling DQN
            {"categorical_dqn": True},  # Categorical DQN
        )
        for extra_kwargs in variants:
            # Each trainer must receive the variant's own agent; handing it the
            # vanilla ``algo`` would leave the variant completely untested.
            variant_algo = DQN("mlp", env, **extra_kwargs)
            trainer = OffPolicyTrainer(
                variant_algo,
                env,
                log_mode=["csv"],
                logdir="./logs",
                epochs=1,
                render=False,
            )
            trainer.train()
            shutil.rmtree("./logs")
Ejemplo n.º 3
0
    def test_dqn_cnn(self):
        """Smoke-test CNN DQN and its variants on Breakout.

        Each variant gets its *own* agent and is trained for one 200-step
        epoch with CSV logging into ``./logs``; the directory is removed after
        every run.
        """
        env = gym.make("Breakout-v0")

        # Extra keyword arguments handed to DQN, one entry per variant, in run order.
        variants = (
            {},  # DQN
            {"double_dqn": True, "prioritized_replay": True},  # Double DQN + PER
            {"noisy_dqn": True},  # Noisy DQN
            {"dueling_dqn": True},  # Dueling DQN
            {"categorical_dqn": True},  # Categorical DQN
        )
        for extra_kwargs in variants:
            # Build and train the agent for *this* variant; reusing a single
            # vanilla agent across trainers would never exercise the variants.
            algo = DQN("cnn", env, **extra_kwargs)
            trainer = OffPolicyTrainer(
                algo, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
            )
            trainer.train()
            shutil.rmtree("./logs")
Ejemplo n.º 4
0
    def test_dqn_cnn(self):
        """Smoke-test CNN DQN variants on an Atari Pong ``VectorEnv``.

        Every variant is trained for a single 200-step epoch with CSV logs
        written to ``./logs``; the directory is wiped after each variant.
        """
        pong = VectorEnv("Pong-v0", n_envs=2, env_type="atari")

        # (label, extra DQN kwargs) pairs, executed in this order.
        variants = [
            ("DQN", {}),
            ("Double DQN with prioritized replay buffer",
             {"double_dqn": True, "prioritized_replay": True}),
            ("Noisy DQN", {"noisy_dqn": True}),
            ("Dueling DDQN", {"dueling_dqn": True, "double_dqn": True}),
            ("Categorical DQN", {"categorical_dqn": True}),
        ]

        for _label, dqn_kwargs in variants:
            agent = DQN("cnn", pong, **dqn_kwargs)
            trainer = OffPolicyTrainer(
                agent,
                pong,
                log_mode=["csv"],
                logdir="./logs",
                epochs=1,
                steps_per_epoch=200,
            )
            trainer.train()
            shutil.rmtree("./logs")
Ejemplo n.º 5
0
    def test_atari_env(self):
        """
        Smoke-test the Atari wrapper path of VectorEnv.

        Builds a Pong ``VectorEnv`` with ``env_type="atari"`` and trains a CNN
        DQN agent on it for one 200-step epoch.
        """
        pong = VectorEnv("Pong-v0", env_type="atari")
        agent = DQN("cnn", pong)

        runner = OffPolicyTrainer(agent, pong, epochs=1, steps_per_epoch=200)
        runner.train()
        # NOTE(review): no logdir is passed, so this assumes the trainer's
        # default log directory is "./logs" — confirm against OffPolicyTrainer.
        shutil.rmtree("./logs")
Ejemplo n.º 6
0
def test_dqn():
    """Smoke-test a vanilla MLP DQN on a CartPole ``VectorEnv`` for one epoch."""
    log_dir = "./logs"
    cartpole = VectorEnv("CartPole-v0", 2)
    agent = DQN("mlp", cartpole)

    trainer = OffPolicyTrainer(
        agent, cartpole, log_mode=["csv"], logdir=log_dir, epochs=1
    )
    trainer.train()

    # Drop the CSV logs produced by the run.
    shutil.rmtree(log_dir)
Ejemplo n.º 7
0
def test_categorical_dqn_cnn():
    """Smoke-test a categorical CNN DQN on an Atari Pong ``VectorEnv``.

    Trains for one 200-step epoch with CSV logging into ``./logs`` and deletes
    that directory afterwards.
    """
    log_dir = "./logs"
    pong = VectorEnv("Pong-v0", n_envs=2, env_type="atari")
    agent = DQN("cnn", pong, categorical_dqn=True)

    trainer = OffPolicyTrainer(
        agent, pong, log_mode=["csv"], logdir=log_dir, epochs=1, steps_per_epoch=200
    )
    trainer.train()
    shutil.rmtree(log_dir)
Ejemplo n.º 8
0
def test_double_dqn_cnn():
    """Smoke-test Double DQN with prioritized replay on Atari Pong.

    Uses a CNN policy, one 200-step epoch and CSV logs in ``./logs``, which
    are removed once training finishes.
    """
    pong = VectorEnv("Pong-v0", n_envs=2, env_type="atari")
    agent = DQN("cnn", pong, double_dqn=True, prioritized_replay=True)

    trainer_kwargs = dict(
        log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
    )
    trainer = OffPolicyTrainer(agent, pong, **trainer_kwargs)
    trainer.train()
    shutil.rmtree("./logs")
argument_parser.add_argument("-b", "--batch-size", type=int, default=64)
argument_parser.add_argument("-l", "--length", type=int, default=None)
argument_parser.add_argument("--enable-cuda", action="store_true")
args = argument_parser.parse_args()

# Use the GPU only when it was requested *and* is actually present; otherwise
# fall back to the CPU, warning if CUDA was requested but is unavailable.
# The availability check lives at torch.cuda.is_available() (there is no
# torch.cuda_is_available attribute).
device = "cpu"
if args.enable_cuda:
    if torch.cuda.is_available():
        device = "cuda"
    else:
        warnings.warn("cuda is not available. Defaulting to cpu")

# Super Mario Bros environment restricted to the SIMPLE_MOVEMENT action set,
# then wrapped with MarioEnv.
env = gym_super_mario_bros.make("SuperMarioBros-v0")
env = JoypadSpace(env, SIMPLE_MOVEMENT)
env = MarioEnv(env)
agent = DQN("cnn", env, replay_size=100000, epsilon_decay=100000)
# NOTE(review): "AdversariaTrainer" may be a misspelling of
# "AdversarialTrainer" — confirm against the import at the top of the file.
trainer = AdversariaTrainer(
    agent=agent,
    env=env,
    dataset=args.input_path,
    possible_actions=SIMPLE_MOVEMENT,
    device=device,
    length=args.length,
    off_policy=True,
    evaluate_episodes=1,
)
trainer.train(epochs=args.epochs, lr=args.lr, batch_size=args.batch_size)
trainer.evaluate(render=True)