def test_on_policy_trainer():
    """Run ``OnPolicyTrainer`` with PPO1 on CartPole-v1 for a single epoch.

    Asserts that the trainer's ``off_policy`` attribute is falsy, then calls
    ``train()``.
    """
    logger = Logger()
    env = gym.make("CartPole-v1")
    algo = PPO1("mlp", env)
    trainer = OnPolicyTrainer(algo, env, logger, epochs=1)
    # Truthiness check rather than ``== False`` (PEP 8, flake8 E712).
    assert not trainer.off_policy
    trainer.train()
def test_on_policy_trainer():
    """Train and then evaluate PPO1 through ``OnPolicyTrainer``.

    The environment is ``VectorEnv("CartPole-v1", 2)``; the trainer is built
    with ``["stdout"]`` as its third positional argument, ``epochs=1`` and
    ``evaluate_episodes=2``. The trainer's ``off_policy`` attribute must be
    falsy before training starts.
    """
    vec_env = VectorEnv("CartPole-v1", 2)
    agent = PPO1("mlp", vec_env)
    log_modes = ["stdout"]
    runner = OnPolicyTrainer(
        agent,
        vec_env,
        log_modes,
        epochs=1,
        evaluate_episodes=2,
    )

    assert not runner.off_policy

    runner.train()
    runner.evaluate()
def test_evaluate(self):
    """Evaluate a PPO1 agent after a ``learn()`` run.

    Builds PPO1 on CartPole-v0 with ``epochs=1``, calls ``learn()``, then
    passes the agent to ``evaluate`` with ``num_timesteps=10``.
    """
    cartpole = gym.make("CartPole-v0")
    agent = PPO1("mlp", cartpole, epochs=1)
    agent.learn()

    evaluate(agent, num_timesteps=10)
def test_ppo1(self):
    """Train PPO1 on Pendulum-v0 for one epoch via ``OnPolicyTrainer``.

    The trainer receives a ``Logger("./logs", ["csv"])`` and ``render=False``.
    The ``./logs`` tree is removed once training finishes, including when
    ``train()`` raises.
    """
    env = gym.make("Pendulum-v0")
    algo = PPO1("mlp", env, layers=[1, 1])
    logger = Logger("./logs", ["csv"])
    trainer = OnPolicyTrainer(algo, env, logger, epochs=1, render=False)
    try:
        trainer.train()
    finally:
        # Clean up in ``finally`` so a failed run does not leave ./logs behind
        # for later tests.
        shutil.rmtree("./logs")
def test_save_params(self):
    """Check that ``learn()`` leaves files in the checkpoint directory.

    PPO1 is built on CartPole-v0 with ``epochs=1`` and
    ``save_model="test_ckpt"``; after ``learn()`` the directory
    ``test_ckpt/PPO1_CartPole-v0`` must contain at least one entry.
    """
    cartpole = gym.make("CartPole-v0")
    agent = PPO1("mlp", cartpole, epochs=1, save_model="test_ckpt")
    agent.learn()

    saved_entries = os.listdir("test_ckpt/PPO1_CartPole-v0")
    assert saved_entries
def test_ppo1(self):
    """Train and evaluate PPO1 on Pendulum-v0 via ``OnPolicyTrainer``.

    The trainer is built with ``log_mode=["csv"]``, ``logdir="./logs"``,
    ``epochs=1`` and ``evaluate_episodes=2``. The ``./logs`` tree is removed
    afterwards, including when ``train()`` or ``evaluate()`` raises.
    """
    env = gym.make("Pendulum-v0")
    algo = PPO1("mlp", env, layers=[1, 1])
    trainer = OnPolicyTrainer(
        algo, env, log_mode=["csv"], logdir="./logs", epochs=1, evaluate_episodes=2
    )
    try:
        trainer.train()
        trainer.evaluate()
    finally:
        # Clean up in ``finally`` so a failed run does not leave ./logs behind
        # for later tests.
        shutil.rmtree("./logs")
def test_ppo1():
    """Train PPO1 with the "mlp" network on ``VectorEnv("CartPole-v0", 1)``.

    The trainer is built with ``log_mode=["csv"]``, ``logdir="./logs"`` and
    ``epochs=1``. The ``./logs`` tree is removed afterwards, including when
    ``train()`` raises.
    """
    env = VectorEnv("CartPole-v0", 1)
    algo = PPO1("mlp", env)
    trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    try:
        trainer.train()
    finally:
        # Clean up in ``finally`` so a failed run does not leave ./logs behind
        # for later tests.
        shutil.rmtree("./logs")
def test_ppo1_cnn():
    """Train PPO1 with the "cnn" network on an Atari-type ``VectorEnv``.

    The environment is ``VectorEnv("Pong-v0", 1, env_type="atari")``; the
    trainer is built with ``log_mode=["csv"]``, ``logdir="./logs"`` and
    ``epochs=1``. The ``./logs`` tree is removed afterwards, including when
    ``train()`` raises.
    """
    env = VectorEnv("Pong-v0", 1, env_type="atari")
    algo = PPO1("cnn", env)
    trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    try:
        trainer.train()
    finally:
        # Clean up in ``finally`` so a failed run does not leave ./logs behind
        # for later tests.
        shutil.rmtree("./logs")
def test_load_params(self):
    """Construct PPO1 with ``load_model`` pointing at a checkpoint file.

    The checkpoint path is ``test_ckpt/PPO1_CartPole-v0/0-log-0.pt``.
    NOTE(review): presumably that file is produced by the save-params test
    and read during construction -- confirm the expected run order.

    The ``test_ckpt`` tree is removed afterwards, including when construction
    raises.
    """
    env = gym.make("CartPole-v0")
    try:
        # The constructed agent is not used afterwards, so it is not bound
        # to a name.
        PPO1(
            "mlp",
            env,
            epochs=1,
            load_model="test_ckpt/PPO1_CartPole-v0/0-log-0.pt",
        )
    finally:
        # Clean up in ``finally`` so a failed load does not leave the
        # checkpoint directory behind.
        rmtree("test_ckpt")
def test_load_params(self):
    """Construct PPO1 on a ``VectorEnv`` with ``load_model`` set.

    The environment is ``VectorEnv("CartPole-v0", 1)`` and the checkpoint path
    is ``test_ckpt/PPO1_CartPole-v0/0-log-0.pt``. The ``logs`` tree is removed
    afterwards, including when construction raises.

    NOTE(review): the checkpoint lives under ``test_ckpt`` but the directory
    removed here is ``logs`` -- confirm that leaving ``test_ckpt`` in place is
    intended.
    """
    env = VectorEnv("CartPole-v0", 1)
    try:
        # The constructed agent is not used afterwards, so it is not bound
        # to a name.
        PPO1(
            "mlp",
            env,
            epochs=1,
            load_model="test_ckpt/PPO1_CartPole-v0/0-log-0.pt",
        )
    finally:
        # Clean up in ``finally`` so a failed load does not skip the removal.
        rmtree("logs")
def test_load_params(self):
    """Run ``OnPolicyTrainer.train()`` with ``load_model`` set and ``epochs=0``.

    PPO1 is built on ``VectorEnv("CartPole-v0", 1)``; the trainer is given
    ``load_model="test_ckpt/PPO1_CartPole-v0/0-log-0.pt"``. The ``logs`` tree
    is removed afterwards, including when ``train()`` raises.
    """
    env = VectorEnv("CartPole-v0", 1)
    algo = PPO1("mlp", env)
    trainer = OnPolicyTrainer(
        algo, env, epochs=0, load_model="test_ckpt/PPO1_CartPole-v0/0-log-0.pt"
    )
    try:
        trainer.train()
    finally:
        # Clean up in ``finally`` so a failed run does not leave logs behind
        # for later tests.
        rmtree("logs")
def test_save_params(self):
    """Check that training through ``OnPolicyTrainer`` leaves checkpoint files.

    The trainer is built with ``["stdout"]`` as its third positional argument,
    ``save_model="test_ckpt"``, ``save_interval=1`` and ``epochs=1``. After
    ``train()`` the directory ``test_ckpt/PPO1_CartPole-v0`` must contain at
    least one entry.
    """
    vec_env = VectorEnv("CartPole-v0", 1)
    agent = PPO1("mlp", vec_env)
    log_modes = ["stdout"]
    runner = OnPolicyTrainer(
        agent,
        vec_env,
        log_modes,
        save_model="test_ckpt",
        save_interval=1,
        epochs=1,
    )
    runner.train()

    saved_entries = os.listdir("test_ckpt/PPO1_CartPole-v0")
    assert saved_entries