Example #1
0
 def test_ppo1(self):
     """Train PPO1 with an MLP policy on CartPole for one epoch, logging as CSV."""
     cartpole = VectorEnv("CartPole-v0")
     agent = PPO1("mlp", cartpole, rollout_size=128)
     trainer_options = dict(log_mode=["csv"], logdir="./logs", epochs=1)
     OnPolicyTrainer(agent, cartpole, **trainer_options).train()
     # The trainer was pointed at ./logs; remove it so later tests start clean.
     shutil.rmtree("./logs")
Example #2
0
 def test_ppo1_cnn(self):
     """Train PPO1 with a CNN policy on the Atari Pong env for one epoch."""
     pong = VectorEnv("Pong-v0", env_type="atari")
     agent = PPO1("cnn", pong, rollout_size=128)
     trainer_options = dict(log_mode=["csv"], logdir="./logs", epochs=1)
     OnPolicyTrainer(agent, pong, **trainer_options).train()
     # Clean up the log directory the trainer was configured to use.
     shutil.rmtree("./logs")
Example #3
0
def test_vpg():
    """Train VPG with an MLP policy on CartPole for a single epoch."""
    cartpole = VectorEnv("CartPole-v0")
    agent = VPG("mlp", cartpole)
    trainer_options = dict(log_mode=["csv"], logdir="./logs", epochs=1)
    OnPolicyTrainer(agent, cartpole, **trainer_options).train()
    # Remove the log directory configured above.
    shutil.rmtree("./logs")
Example #4
0
def test_a2c_shared():
    """Train A2C whose actor and critic share two 32-unit hidden layers."""
    cartpole = VectorEnv("CartPole-v0", 1)
    agent = A2C("mlp", cartpole, shared_layers=(32, 32), rollout_size=128)
    trainer_options = dict(log_mode=["csv"], logdir="./logs", epochs=1)
    OnPolicyTrainer(agent, cartpole, **trainer_options).train()
    # Remove the log directory configured above.
    shutil.rmtree("./logs")
Example #5
0
 def test_a2c_discrete(self):
     """Train A2C on the discrete-action CartPole env, then run evaluation."""
     cartpole = VectorEnv("CartPole-v0", 1)
     agent = A2C("mlp", cartpole, rollout_size=128)
     trainer_options = dict(log_mode=["csv"], logdir="./logs", epochs=1)
     runner = OnPolicyTrainer(agent, cartpole, **trainer_options)
     runner.train()
     runner.evaluate()
     # Remove the log directory configured above.
     shutil.rmtree("./logs")
Example #6
0
def test_on_policy_trainer():
    """Run OnPolicyTrainer end to end: two epochs of PPO1 training, then evaluation."""
    vec_env = VectorEnv("CartPole-v1", 2)
    ppo = PPO1("mlp", vec_env, rollout_size=128)
    runner = OnPolicyTrainer(
        ppo,
        vec_env,
        ["stdout"],
        epochs=2,
        evaluate_episodes=2,
        max_timesteps=300,
    )
    # The on-policy trainer must not flag itself as off-policy.
    assert not runner.off_policy
    runner.train()
    runner.evaluate()
Example #7
0
def main(args):
    """Build an environment and agent from parsed CLI ``args``, then train and evaluate.

    ``args.algo`` is matched case-insensitively against the algorithm table.
    ``ppo``, ``vpg`` and ``a2c`` are built with a rollout size and trained by
    ``OnPolicyTrainer``. Every other algorithm is built with a replay buffer
    and batch size and trained by ``OffPolicyTrainer``.

    Raises:
        KeyError: if ``args.algo`` does not name a known algorithm.
    """
    ALGOS = {
        "sac": SAC,
        "a2c": A2C,
        "ppo": PPO1,
        "ddpg": DDPG,
        "td3": TD3,
        "vpg": VPG,
        "dqn": DQN,
    }
    on_policy_algos = {"ppo", "vpg", "a2c"}

    # Normalise the name once so the lookup and the on/off-policy dispatch
    # agree. Otherwise e.g. "PPO" would resolve to PPO1 via the lowercased
    # lookup but fail a case-sensitive membership test. It would then be
    # built with off-policy arguments and handed to OffPolicyTrainer.
    algo_name = args.algo.lower()
    algo = ALGOS[algo_name]
    env = VectorEnv(args.env,
                    n_envs=args.n_envs,
                    parallel=not args.serial,
                    env_type=args.env_type)

    logger = get_logger(args.log)

    if algo_name in on_policy_algos:
        # NOTE: batch_size is not forwarded to on-policy agents here.
        agent = algo(args.arch, env, rollout_size=args.rollout_size)
        trainer = OnPolicyTrainer(
            agent,
            env,
            logger,
            epochs=args.epochs,
            render=args.render,
            log_interval=args.log_interval,
        )
    else:
        agent = algo(args.arch,
                     env,
                     replay_size=args.replay_size,
                     batch_size=args.batch_size)
        trainer = OffPolicyTrainer(
            agent,
            env,
            logger,
            epochs=args.epochs,
            render=args.render,
            warmup_steps=args.warmup_steps,
            log_interval=args.log_interval,
        )

    trainer.train()
    trainer.evaluate()
    env.render()
Example #8
0
def test_save_params():
    """
    Check that training with ``save_model`` writes a non-empty checkpoint directory.

    The checkpoint is not cleaned up here; presumably a later loading test
    reads from the same directory.
    """
    cartpole = VectorEnv("CartPole-v0", 1)
    agent = PPO1("mlp", cartpole)
    runner = OnPolicyTrainer(
        agent,
        cartpole,
        ["stdout"],
        save_model="test_ckpt",
        save_interval=1,
        epochs=1,
    )
    runner.train()

    saved_files = os.listdir("test_ckpt/PPO1_CartPole-v0")
    assert saved_files
Example #9
0
def test_load_params():
    """
    Check that a trainer can be constructed from a saved ``.pt`` checkpoint.

    ``epochs=0`` means no training steps are requested; only the loading path
    is exercised.
    """
    cartpole = VectorEnv("CartPole-v0", 1)
    agent = PPO1("mlp", cartpole)
    checkpoint = "test_ckpt/PPO1_CartPole-v0/0-log-0.pt"
    runner = OnPolicyTrainer(agent, cartpole, epochs=0, load_model=checkpoint)
    runner.train()

    # Remove the "logs" directory left behind by the run.
    rmtree("logs")
Example #10
0
    def test_custom_ppo1(self):
        """Train PPO1 with a user-supplied actor-critic network instead of "mlp"."""
        cartpole = VectorEnv("CartPole-v0", 1)
        network = custom_actorcritic(
            cartpole.observation_space.shape[0], cartpole.action_space.n
        )
        agent = PPO1(network, cartpole)

        trainer_options = dict(log_mode=["csv"], logdir="./logs", epochs=1)
        OnPolicyTrainer(agent, cartpole, **trainer_options).train()
        # Remove the log directory configured above.
        shutil.rmtree("./logs")
Example #11
0
    def test_custom_vpg(self):
        """Train VPG with a user-supplied policy network instead of "mlp"."""
        cartpole = VectorEnv("CartPole-v0", 1)
        network = custom_policy(
            cartpole.observation_space.shape[0], cartpole.action_space.n
        )
        agent = VPG(network, cartpole)

        trainer_options = dict(log_mode=["csv"], logdir="./logs", epochs=1)
        OnPolicyTrainer(agent, cartpole, **trainer_options).train()
        # Remove the log directory configured above.
        shutil.rmtree("./logs")
Example #12
0
    def test_load_params(self):
        """
        Check that a trainer can restore hyperparameters (.toml) and weights (.pt).

        ``epochs=0`` means no training steps are requested; only the loading
        path is exercised.
        """
        cartpole = VectorEnv("CartPole-v0", 1)
        agent = PPO1("mlp", cartpole)
        checkpoint_files = dict(
            load_hyperparams="test_ckpt/PPO1_CartPole-v0/0-log-0.toml",
            load_weights="test_ckpt/PPO1_CartPole-v0/0-log-0.pt",
        )
        runner = OnPolicyTrainer(agent, cartpole, epochs=0, **checkpoint_files)
        runner.train()

        # Remove the "logs" directory left behind by the run.
        rmtree("logs")