Example #1
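All of the snippets below use the GenRL reinforcement-learning library and are shown without their surrounding imports. A minimal set matching the names used here might look as follows; note that the exact module paths are an assumption, since GenRL has reorganized its packages across versions, so check your installed release:

import shutil  # used to clean up the ./logs directory after training

# NOTE: assumed import paths; adjust for your installed GenRL version.
from genrl.agents import (
    A2C, DDPG, DQN, DoubleDQN, DuelingDQN, NoisyDQN,
    PPO1, PrioritizedReplayDQN, SAC, TD3, VPG,
)
from genrl.core import NormalActionNoise, OrnsteinUhlenbeckActionNoise
from genrl.environments import VectorEnv
from genrl.trainers import OffPolicyTrainer, OnPolicyTrainer

# Helper classes asserted against below (MlpValue, CnnValue, MlpNoisyValue,
# CnnDuelingValue, PrioritizedBuffer) and get_logger also come from GenRL's
# internals; their paths likewise vary by version.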
def test_sac_shared():
    env = VectorEnv("Pendulum-v0", 2)
    algo = SAC(
        "mlp",
        env,
        batch_size=5,
        shared_layers=[1, 1],
        policy_layers=[1, 1],
        value_layers=[1, 1],
        max_timesteps=100,
    )

    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        epochs=5,
        max_ep_len=500,
        warmup_steps=10,
        start_update=10,
        max_timesteps=100,
    )
    trainer.train()
    shutil.rmtree("./logs")
Example #2
    def test_atari_env(self):
        """
        Tests working of Atari Wrappers and the AtariEnv function
        """
        env = VectorEnv("Pong-v0", env_type="atari")
        algo = DQN("cnn", env, replay_size=100)

        trainer = OffPolicyTrainer(algo, env, epochs=1, max_timesteps=50)
        trainer.train()
        shutil.rmtree("./logs")
Example #3
    def test_double_dqn(self):
        env = VectorEnv("CartPole-v0")
        algo = DoubleDQN("mlp", env, batch_size=5, replay_size=100)
        assert isinstance(algo.model, MlpValue)
        trainer = OffPolicyTrainer(
            algo,
            env,
            log_mode=["csv"],
            logdir="./logs",
            max_ep_len=200,
            epochs=4,
            warmup_steps=10,
            start_update=10,
        )
        trainer.train()
        shutil.rmtree("./logs")
Example #4
    def test_prioritized_dqn(self):
        env = VectorEnv("Pong-v0", env_type="atari")
        algo = PrioritizedReplayDQN("cnn", env, batch_size=5, replay_size=100)
        assert isinstance(algo.model, CnnValue)
        trainer = OffPolicyTrainer(
            algo,
            env,
            log_mode=["csv"],
            logdir="./logs",
            max_ep_len=200,
            epochs=4,
            warmup_steps=10,
            start_update=10,
        )
        trainer.train()
        shutil.rmtree("./logs")
Example #5
    def test_noisy_dqn(self):
        env = VectorEnv("CartPole-v0")
        algo = NoisyDQN("mlp", env, batch_size=5, replay_size=100, value_layers=[1, 1])
        assert algo.dqn_type == "noisy"
        assert algo.noisy
        assert isinstance(algo.model, MlpNoisyValue)
        trainer = OffPolicyTrainer(
            algo,
            env,
            log_mode=["csv"],
            logdir="./logs",
            max_ep_len=200,
            epochs=4,
            warmup_steps=10,
            start_update=10,
        )
        trainer.train()
        shutil.rmtree("./logs")
Example #6
    def test_atari_env(self):
        """
        Tests working of Atari Wrappers and the AtariEnv function
        """
        env = VectorEnv("Pong-v0", env_type="atari")
        algo = DQN("cnn",
                   env,
                   batch_size=5,
                   replay_size=100,
                   value_layers=[1, 1])

        trainer = OffPolicyTrainer(algo,
                                   env,
                                   epochs=5,
                                   max_ep_len=200,
                                   warmup_steps=10,
                                   start_update=10)
        trainer.train()
        shutil.rmtree("./logs")
Example #7
    def test_dueling_dqn(self):
        env = VectorEnv("Pong-v0", env_type="atari")
        algo = DuelingDQN(
            "cnn", env, batch_size=5, replay_size=100, value_layers=[1, 1]
        )
        assert algo.dqn_type == "dueling"
        assert isinstance(algo.model, CnnDuelingValue)
        trainer = OffPolicyTrainer(
            algo,
            env,
            log_mode=["csv"],
            logdir="./logs",
            max_ep_len=200,
            epochs=4,
            warmup_steps=10,
            start_update=10,
            max_timesteps=100,
        )
        trainer.train()
        shutil.rmtree("./logs")
Example #8
    def test_prioritized_dqn(self):
        env = VectorEnv("CartPole-v0")
        algo = PrioritizedReplayDQN(
            "mlp", env, batch_size=5, replay_size=100, value_layers=[1, 1]
        )
        assert isinstance(algo.model, MlpValue)
        assert isinstance(algo.replay_buffer, PrioritizedBuffer)
        trainer = OffPolicyTrainer(
            algo,
            env,
            log_mode=["csv"],
            logdir="./logs",
            max_ep_len=200,
            epochs=4,
            warmup_steps=10,
            start_update=10,
        )
        trainer.train()
        shutil.rmtree("./logs")
Example #9
def main(args):
    ALGOS = {
        "sac": SAC,
        "a2c": A2C,
        "ppo": PPO1,
        "ddpg": DDPG,
        "td3": TD3,
        "vpg": VPG,
        "dqn": DQN,
    }

    algo = ALGOS[args.algo.lower()]
    env = VectorEnv(args.env,
                    n_envs=args.n_envs,
                    parallel=not args.serial,
                    env_type=args.env_type)

    logger = get_logger(args.log)
    trainer = None

    if args.algo in ["ppo", "vpg", "a2c"]:
        agent = algo(
            args.arch, env,
            rollout_size=args.rollout_size)  # , batch_size=args.batch_size)
        trainer = OnPolicyTrainer(
            agent,
            env,
            logger,
            epochs=args.epochs,
            render=args.render,
            log_interval=args.log_interval,
        )

    else:
        agent = algo(args.arch,
                     env,
                     replay_size=args.replay_size,
                     batch_size=args.batch_size)
        trainer = OffPolicyTrainer(
            agent,
            env,
            logger,
            epochs=args.epochs,
            render=args.render,
            warmup_steps=args.warmup_steps,
            log_interval=args.log_interval,
        )

    trainer.train()
    trainer.evaluate()
    env.render()
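Example #9 consumes a parsed argument namespace. A minimal argparse setup providing every attribute that main() reads could look like the sketch below; the flag names and defaults are illustrative, not GenRL's actual CLI (argparse converts dashes in flag names to underscores automatically):

import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a GenRL agent")
    parser.add_argument("--algo", default="ppo")         # key into ALGOS
    parser.add_argument("--env", default="CartPole-v0")
    parser.add_argument("--env-type", default="gym")
    parser.add_argument("--n-envs", type=int, default=2)
    parser.add_argument("--serial", action="store_true")
    parser.add_argument("--arch", default="mlp")         # "mlp" or "cnn"
    parser.add_argument("--log", default="stdout")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--render", action="store_true")
    parser.add_argument("--log-interval", type=int, default=10)
    parser.add_argument("--rollout-size", type=int, default=2048)
    parser.add_argument("--replay-size", type=int, default=100000)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--warmup-steps", type=int, default=1000)
    main(parser.parse_args())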
Example #10
def test_td3():
    env = VectorEnv("Pendulum-v0", 2)
    algo = TD3(
        "mlp",
        env,
        batch_size=5,
        noise=OrnsteinUhlenbeckActionNoise,
        policy_layers=[1, 1],
        value_layers=[1, 1],
    )

    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        epochs=5,
        max_ep_len=500,
        warmup_steps=10,
        start_update=10,
    )
    trainer.train()
    shutil.rmtree("./logs")
Example #11
def test_ddpg():
    env = VectorEnv("Pendulum-v0", 2)
    algo = DDPG(
        "mlp",
        env,
        batch_size=5,
        noise=NormalActionNoise,
        policy_layers=[1, 1],
        value_layers=[1, 1],
    )

    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        epochs=4,
        max_ep_len=200,
        warmup_steps=10,
        start_update=10,
    )
    trainer.train()
    shutil.rmtree("./logs")
Example #12
def test_off_policy_trainer():
    env = VectorEnv("Pendulum-v0", 2)
    algo = DDPG("mlp", env, replay_size=100)
    trainer = OffPolicyTrainer(
        algo,
        env,
        ["stdout"],
        epochs=2,
        evaluate_episodes=2,
        max_ep_len=300,
        max_timesteps=300,
    )
    assert trainer.off_policy
    trainer.train()
    trainer.evaluate()
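The test-style snippets above are written for pytest. Assuming they are collected in a file such as tests/test_agents.py (a hypothetical path), they can be run programmatically:

import pytest

# Run the tests quietly; the file path is illustrative.
pytest.main(["-q", "tests/test_agents.py"])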