Exemplo n.º 1
0
def test_on_policy_trainer():
    """Smoke-test OnPolicyTrainer driving PPO1 on a single CartPole-v1 env.

    Asserts the trainer does not report itself as off-policy, then runs a
    one-epoch training loop.
    """
    logger = Logger()
    env = gym.make("CartPole-v1")
    algo = PPO1("mlp", env)
    trainer = OnPolicyTrainer(algo, env, logger, epochs=1)
    # Truthiness test instead of `== False` (PEP 8: don't compare booleans
    # with ==); matches the other on-policy trainer test in this file.
    assert not trainer.off_policy
    trainer.train()
Exemplo n.º 2
0
def test_on_policy_trainer():
    """Smoke-test OnPolicyTrainer with PPO1 on a VectorEnv of CartPole-v1.

    Logs to stdout, checks the trainer is not flagged off-policy, trains for
    one epoch and then runs an evaluation pass (``evaluate_episodes=2``).
    """
    vec_env = VectorEnv("CartPole-v1", 2)
    agent = PPO1("mlp", vec_env)
    on_policy_trainer = OnPolicyTrainer(
        agent,
        vec_env,
        ["stdout"],
        epochs=1,
        evaluate_episodes=2,
    )
    assert not on_policy_trainer.off_policy
    on_policy_trainer.train()
    on_policy_trainer.evaluate()
Exemplo n.º 3
0
 def test_evaluate(self):
     """
     Train PPO1 for one epoch on CartPole-v0, then evaluate it for 10 timesteps.
     """
     cartpole = gym.make("CartPole-v0")
     agent = PPO1("mlp", cartpole, epochs=1)
     agent.learn()
     evaluate(agent, num_timesteps=10)
Exemplo n.º 4
0
    def test_ppo1(self):
        """Train PPO1 for one epoch on Pendulum-v0 with a CSV logger in ./logs.

        ``./logs`` is removed in a ``finally`` clause so that a failed run
        does not leave log files behind for later tests.
        """
        env = gym.make("Pendulum-v0")
        algo = PPO1("mlp", env, layers=[1, 1])
        logger = Logger("./logs", ["csv"])

        trainer = OnPolicyTrainer(algo, env, logger, epochs=1, render=False)
        try:
            trainer.train()
        finally:
            # NOTE(review): if ./logs was never created, rmtree raises
            # FileNotFoundError here, chained onto any training error, just
            # as the original unconditional call would on success.
            shutil.rmtree("./logs")
Exemplo n.º 5
0
    def test_save_params(self):
        """
        Test that learning with ``save_model`` writes a checkpoint.

        ``test_ckpt`` is deliberately left on disk. NOTE(review): presumably a
        later ``test_load_params`` reads
        ``test_ckpt/PPO1_CartPole-v0/0-log-0.pt`` from it and removes it.
        """
        env = gym.make("CartPole-v0")
        algo = PPO1("mlp", env, epochs=1, save_model="test_ckpt")
        algo.learn()

        # An empty directory listing is falsy, so this fails when nothing was saved.
        assert os.listdir("test_ckpt/PPO1_CartPole-v0"), "no checkpoint files were saved"
Exemplo n.º 6
0
    def test_ppo1(self):
        """Train and evaluate PPO1 on Pendulum-v0 with CSV logging to ./logs.

        ``./logs`` is removed in a ``finally`` clause so that a failure in
        training or evaluation does not leave log files behind.
        """
        env = gym.make("Pendulum-v0")
        algo = PPO1("mlp", env, layers=[1, 1])

        trainer = OnPolicyTrainer(
            algo, env, log_mode=["csv"], logdir="./logs", epochs=1, evaluate_episodes=2
        )
        try:
            trainer.train()
            trainer.evaluate()
        finally:
            # NOTE(review): raises FileNotFoundError if ./logs was never
            # created, the same as the original unconditional cleanup.
            shutil.rmtree("./logs")
Exemplo n.º 7
0
def test_ppo1():
    """Train PPO1 for one epoch on a VectorEnv of CartPole-v0, logging CSV to ./logs.

    ``./logs`` is removed in a ``finally`` clause so that a failed run does
    not leave log files behind for later tests.
    """
    env = VectorEnv("CartPole-v0", 1)
    algo = PPO1("mlp", env)
    trainer = OnPolicyTrainer(algo,
                              env,
                              log_mode=["csv"],
                              logdir="./logs",
                              epochs=1)
    try:
        trainer.train()
    finally:
        # NOTE(review): raises FileNotFoundError if ./logs was never created,
        # the same as the original unconditional cleanup.
        shutil.rmtree("./logs")
Exemplo n.º 8
0
def test_ppo1_cnn():
    """Train a CNN-policy PPO1 for one epoch on an Atari Pong-v0 VectorEnv.

    CSV logs go to ./logs, which is removed in a ``finally`` clause so that a
    failed run does not leave files behind for later tests.
    """
    env = VectorEnv("Pong-v0", 1, env_type="atari")
    algo = PPO1("cnn", env)
    trainer = OnPolicyTrainer(algo,
                              env,
                              log_mode=["csv"],
                              logdir="./logs",
                              epochs=1)
    try:
        trainer.train()
    finally:
        # NOTE(review): raises FileNotFoundError if ./logs was never created,
        # the same as the original unconditional cleanup.
        shutil.rmtree("./logs")
Exemplo n.º 9
0
    def test_load_params(self):
        """
        Test constructing PPO1 from a previously saved checkpoint.

        Expects ``test_ckpt/PPO1_CartPole-v0/0-log-0.pt`` to already exist
        (NOTE(review): presumably written by ``test_save_params``). The test
        passes if construction with ``load_model`` does not raise. The
        ``test_ckpt`` directory is removed even when loading fails.
        """
        try:
            env = gym.make("CartPole-v0")
            # The instance itself is not used; construction is the check.
            PPO1("mlp",
                 env,
                 epochs=1,
                 load_model="test_ckpt/PPO1_CartPole-v0/0-log-0.pt")
        finally:
            rmtree("test_ckpt")
Exemplo n.º 10
0
    def test_load_params(self):
        """
        Test building PPO1 on a VectorEnv from a saved checkpoint, then remove ``logs``.

        NOTE(review): this relies on ``test_ckpt/PPO1_CartPole-v0/0-log-0.pt``
        having been written earlier. It deletes ``logs`` rather than
        ``test_ckpt``, and rmtree raises if ``logs`` does not exist. Confirm
        that PPO1 creates ``logs`` here.
        """
        checkpoint_path = "test_ckpt/PPO1_CartPole-v0/0-log-0.pt"
        vec_env = VectorEnv("CartPole-v0", 1)
        agent = PPO1("mlp", vec_env, epochs=1, load_model=checkpoint_path)

        rmtree("logs")
Exemplo n.º 11
0
    def test_load_params(self):
        """
        Test that OnPolicyTrainer accepts a saved checkpoint via ``load_model``.

        Expects ``test_ckpt/PPO1_CartPole-v0/0-log-0.pt`` to already exist.
        ``epochs=0`` presumably means no training epochs run, so ``train()``
        mainly exercises setup. NOTE(review): confirm.

        ``logs`` is removed in a ``finally`` clause so that a failing ``train()``
        does not leave it behind.
        """
        env = VectorEnv("CartPole-v0", 1)
        algo = PPO1("mlp", env)
        trainer = OnPolicyTrainer(
            algo,
            env,
            epochs=0,
            load_model="test_ckpt/PPO1_CartPole-v0/0-log-0.pt")
        try:
            trainer.train()
        finally:
            # NOTE(review): raises FileNotFoundError if `logs` was never
            # created, the same as the original unconditional cleanup.
            rmtree("logs")
Exemplo n.º 12
0
    def test_save_params(self):
        """
        Test that OnPolicyTrainer with ``save_model`` writes a checkpoint.

        Uses ``save_interval=1`` with one epoch and logs to stdout. The
        ``test_ckpt`` directory is intentionally left on disk (NOTE(review):
        presumably for a later load test to consume).
        """
        env = VectorEnv("CartPole-v0", 1)
        algo = PPO1("mlp", env)
        trainer = OnPolicyTrainer(algo,
                                  env, ["stdout"],
                                  save_model="test_ckpt",
                                  save_interval=1,
                                  epochs=1)
        trainer.train()

        # An empty directory listing is falsy, so this fails when nothing was saved.
        assert os.listdir("test_ckpt/PPO1_CartPole-v0"), "no checkpoint files were saved"