Пример #1
0
 def test_ppo1(self):
     """Smoke-test PPO1 built with the "mlp" architecture on CartPole-v0.

     Runs a single training epoch with CSV logging into ``./logs`` and
     removes that directory afterwards, whether or not training succeeds.
     """
     env = VectorEnv("CartPole-v0")
     algo = PPO1("mlp", env, rollout_size=128)
     trainer = OnPolicyTrainer(
         algo, env, log_mode=["csv"], logdir="./logs", epochs=1
     )
     try:
         trainer.train()
     finally:
         # Clean up even when training raises, so a failed run does not
         # leave stale log files behind for later tests.
         shutil.rmtree("./logs")
Пример #2
0
 def test_ppo1_cnn(self):
     """Smoke-test PPO1 built with the "cnn" architecture on Pong-v0.

     The environment is created with ``env_type="atari"``. Runs a single
     training epoch with CSV logging into ``./logs`` and removes that
     directory afterwards, whether or not training succeeds.
     """
     env = VectorEnv("Pong-v0", env_type="atari")
     algo = PPO1("cnn", env, rollout_size=128)
     trainer = OnPolicyTrainer(
         algo, env, log_mode=["csv"], logdir="./logs", epochs=1
     )
     try:
         trainer.train()
     finally:
         # Clean up even when training raises, so a failed run does not
         # leave stale log files behind for later tests.
         shutil.rmtree("./logs")
Пример #3
0
def test_on_policy_trainer():
    """Exercise OnPolicyTrainer end to end with a PPO1 agent on CartPole-v1.

    Checks that the trainer does not report itself as off-policy, then runs
    both the training and the evaluation entry points.
    """
    vec_env = VectorEnv("CartPole-v1", 2)
    agent = PPO1("mlp", vec_env, rollout_size=128)

    # Keyword options are collected up front; the positional arguments
    # (agent, environment, log modes) keep their original order.
    options = {
        "epochs": 2,
        "evaluate_episodes": 2,
        "max_timesteps": 300,
    }
    runner = OnPolicyTrainer(agent, vec_env, ["stdout"], **options)

    assert not runner.off_policy
    runner.train()
    runner.evaluate()
Пример #4
0
def test_save_params():
    """
    Train PPO1 for one epoch with checkpoint saving enabled and verify that
    the checkpoint directory is not empty afterwards.
    """
    vec_env = VectorEnv("CartPole-v0", 1)
    agent = PPO1("mlp", vec_env)

    save_options = {
        "save_model": "test_ckpt",
        "save_interval": 1,
        "epochs": 1,
    }
    runner = OnPolicyTrainer(agent, vec_env, ["stdout"], **save_options)
    runner.train()

    # At least one entry must exist in the per-algorithm checkpoint folder.
    saved_entries = os.listdir("test_ckpt/PPO1_CartPole-v0")
    assert len(saved_entries) != 0
Пример #5
0
def test_load_params():
    """
    Test loading algorithm parameters.

    Builds a trainer with ``epochs=0`` and ``load_model`` pointing at an
    existing checkpoint file, calls ``train()``, and removes the ``logs``
    directory afterwards, whether or not ``train()`` succeeds.

    NOTE(review): the checkpoint path is presumably produced by an earlier
    save-params test; confirm the test ordering guarantees it exists.
    """
    env = VectorEnv("CartPole-v0", 1)
    algo = PPO1("mlp", env)
    trainer = OnPolicyTrainer(
        algo,
        env,
        epochs=0,
        load_model="test_ckpt/PPO1_CartPole-v0/0-log-0.pt",
    )
    try:
        trainer.train()
    finally:
        # Clean up even when train() raises, so a failed run does not leave
        # the log directory behind for later tests.
        rmtree("logs")
Пример #6
0
    def test_custom_ppo1(self):
        """Smoke-test PPO1 driven by a user-supplied actor-critic network.

        The network is built by ``custom_actorcritic`` from the first
        observation dimension and the number of discrete actions of the
        CartPole-v0 environment. Runs a single training epoch with CSV
        logging into ``./logs`` and removes that directory afterwards,
        whether or not training succeeds.
        """
        env = VectorEnv("CartPole-v0", 1)
        state_dim = env.observation_space.shape[0]
        action_dim = env.action_space.n
        actorcritic = custom_actorcritic(state_dim, action_dim)

        algo = PPO1(actorcritic, env)

        trainer = OnPolicyTrainer(
            algo,
            env,
            log_mode=["csv"],
            logdir="./logs",
            epochs=1,
        )
        try:
            trainer.train()
        finally:
            # Clean up even when training raises, so a failed run does not
            # leave stale log files behind for later tests.
            shutil.rmtree("./logs")
Пример #7
0
    def test_load_params(self):
        """
        Test loading algorithm parameters.

        Builds a trainer with ``epochs=0`` that is given both a
        hyperparameter file (``load_hyperparams``) and a weights file
        (``load_weights``), calls ``train()``, and removes the ``logs``
        directory afterwards, whether or not ``train()`` succeeds.

        NOTE(review): both checkpoint files are presumably produced by an
        earlier save-params test; confirm the test ordering guarantees they
        exist.
        """
        env = VectorEnv("CartPole-v0", 1)
        algo = PPO1("mlp", env)
        trainer = OnPolicyTrainer(
            algo,
            env,
            epochs=0,
            load_hyperparams="test_ckpt/PPO1_CartPole-v0/0-log-0.toml",
            load_weights="test_ckpt/PPO1_CartPole-v0/0-log-0.pt",
        )
        try:
            trainer.train()
        finally:
            # Clean up even when train() raises, so a failed run does not
            # leave the log directory behind for later tests.
            rmtree("logs")