def test_ppo1(self):
    """Smoke-test PPO1 with the "mlp" architecture on CartPole-v0.

    Builds a ``VectorEnv`` for ``"CartPole-v0"``, creates a ``PPO1`` agent
    with ``rollout_size=128`` and runs ``OnPolicyTrainer.train()`` for one
    epoch with ``log_mode=["csv"]`` and ``logdir="./logs"``.  The test
    passes when nothing raises.

    ``./logs`` is removed afterwards.  When trainer setup or training fails,
    the directory is still removed (best effort) so a failed run does not
    leave artifacts behind for later tests; the original error is re-raised.
    """
    env = VectorEnv("CartPole-v0")
    algo = PPO1("mlp", env, rollout_size=128)
    try:
        # NOTE(review): the log directory may already be created by the
        # trainer constructor (not visible here), so construction is guarded
        # as well.
        trainer = OnPolicyTrainer(
            algo, env, log_mode=["csv"], logdir="./logs", epochs=1
        )
        trainer.train()
    except Exception:
        # ignore_errors keeps a missing directory from masking the real
        # failure that is being re-raised.
        shutil.rmtree("./logs", ignore_errors=True)
        raise
    # Success path unchanged: a missing ./logs still surfaces as an error.
    shutil.rmtree("./logs")
def test_ppo1_cnn(self):
    """Smoke-test PPO1 with the "cnn" architecture on Pong-v0.

    Builds a ``VectorEnv`` for ``"Pong-v0"`` with ``env_type="atari"``,
    creates a ``PPO1`` agent with ``rollout_size=128`` and runs
    ``OnPolicyTrainer.train()`` for one epoch with ``log_mode=["csv"]`` and
    ``logdir="./logs"``.  The test passes when nothing raises.

    ``./logs`` is removed afterwards.  When trainer setup or training fails,
    the directory is still removed (best effort) so a failed run does not
    leave artifacts behind for later tests; the original error is re-raised.
    """
    env = VectorEnv("Pong-v0", env_type="atari")
    algo = PPO1("cnn", env, rollout_size=128)
    try:
        # NOTE(review): the log directory may already be created by the
        # trainer constructor (not visible here), so construction is guarded
        # as well.
        trainer = OnPolicyTrainer(
            algo, env, log_mode=["csv"], logdir="./logs", epochs=1
        )
        trainer.train()
    except Exception:
        # ignore_errors keeps a missing directory from masking the real
        # failure that is being re-raised.
        shutil.rmtree("./logs", ignore_errors=True)
        raise
    # Success path unchanged: a missing ./logs still surfaces as an error.
    shutil.rmtree("./logs")
def test_on_policy_trainer():
    """Run OnPolicyTrainer end to end: train, then evaluate.

    Uses a ``VectorEnv`` built from ``"CartPole-v1"`` with ``2`` as its
    second argument and a ``PPO1`` "mlp" agent with ``rollout_size=128``.
    The trainer is configured with ``["stdout"]`` logging, ``epochs=2``,
    ``evaluate_episodes=2`` and ``max_timesteps=300``.  Before training the
    test checks that the trainer does not report itself as off-policy.
    """
    vec_env = VectorEnv("CartPole-v1", 2)
    agent = PPO1("mlp", vec_env, rollout_size=128)

    runner = OnPolicyTrainer(
        agent,
        vec_env,
        ["stdout"],
        epochs=2,
        evaluate_episodes=2,
        max_timesteps=300,
    )

    # An on-policy trainer must not be flagged as off-policy.
    assert not runner.off_policy

    runner.train()
    runner.evaluate()
def test_save_params():
    """Check that training with ``save_model`` set writes checkpoint files.

    Trains a ``PPO1`` "mlp" agent on ``"CartPole-v0"`` for one epoch with
    ``save_model="test_ckpt"`` and ``save_interval=1``, then asserts that
    ``test_ckpt/PPO1_CartPole-v0`` contains at least one entry.  The
    checkpoint directory is not removed here.
    """
    vec_env = VectorEnv("CartPole-v0", 1)
    agent = PPO1("mlp", vec_env)

    runner = OnPolicyTrainer(
        agent,
        vec_env,
        ["stdout"],
        save_model="test_ckpt",
        save_interval=1,
        epochs=1,
    )
    runner.train()

    # os.listdir returns a list, so truthiness is the same as "len != 0".
    saved_entries = os.listdir("test_ckpt/PPO1_CartPole-v0")
    assert saved_entries
def test_load_params():
    """Check that a trainer can be built from a saved checkpoint.

    Constructs an ``OnPolicyTrainer`` for a ``PPO1`` "mlp" agent on
    ``"CartPole-v0"`` with ``epochs=0`` and
    ``load_model="test_ckpt/PPO1_CartPole-v0/0-log-0.pt"``, then calls
    ``train()``.  The checkpoint file is not created here; it must already
    exist at that path.

    ``logs`` is removed afterwards.  When loading or training fails, the
    directory is still removed (best effort) so a failed run does not leave
    artifacts behind for later tests; the original error is re-raised.
    """
    env = VectorEnv("CartPole-v0", 1)
    algo = PPO1("mlp", env)
    try:
        # NOTE(review): the checkpoint is presumably read by the trainer
        # constructor or by train() (not visible here); both are guarded.
        trainer = OnPolicyTrainer(
            algo,
            env,
            epochs=0,
            load_model="test_ckpt/PPO1_CartPole-v0/0-log-0.pt",
        )
        trainer.train()
    except Exception:
        # ignore_errors keeps a missing directory from masking the real
        # failure that is being re-raised.
        rmtree("logs", ignore_errors=True)
        raise
    # Success path unchanged: a missing logs directory still raises here.
    rmtree("logs")
def test_custom_ppo1(self):
    """Smoke-test PPO1 with a user-supplied actor-critic network.

    Reads the state dimension from ``env.observation_space.shape[0]`` and
    the action dimension from ``env.action_space.n``, builds a network with
    ``custom_actorcritic(state_dim, action_dim)`` and passes it to ``PPO1``
    in place of an architecture name.  Training runs for one epoch with
    ``log_mode=["csv"]`` and ``logdir="./logs"``; the test passes when
    nothing raises.

    ``./logs`` is removed afterwards.  When trainer setup or training fails,
    the directory is still removed (best effort) so a failed run does not
    leave artifacts behind for later tests; the original error is re-raised.
    """
    env = VectorEnv("CartPole-v0", 1)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    actorcritic = custom_actorcritic(state_dim, action_dim)

    algo = PPO1(actorcritic, env)
    try:
        # NOTE(review): the log directory may already be created by the
        # trainer constructor (not visible here), so construction is guarded
        # as well.
        trainer = OnPolicyTrainer(
            algo, env, log_mode=["csv"], logdir="./logs", epochs=1
        )
        trainer.train()
    except Exception:
        # ignore_errors keeps a missing directory from masking the real
        # failure that is being re-raised.
        shutil.rmtree("./logs", ignore_errors=True)
        raise
    # Success path unchanged: a missing ./logs still surfaces as an error.
    shutil.rmtree("./logs")
def test_load_params(self):
    """Check that a trainer can be built from saved hyperparameters and weights.

    Constructs an ``OnPolicyTrainer`` for a ``PPO1`` "mlp" agent on
    ``"CartPole-v0"`` with ``epochs=0``,
    ``load_hyperparams="test_ckpt/PPO1_CartPole-v0/0-log-0.toml"`` and
    ``load_weights="test_ckpt/PPO1_CartPole-v0/0-log-0.pt"``, then calls
    ``train()``.  Neither file is created here; both must already exist.

    ``logs`` is removed afterwards.  When loading or training fails, the
    directory is still removed (best effort) so a failed run does not leave
    artifacts behind for later tests; the original error is re-raised.
    """
    env = VectorEnv("CartPole-v0", 1)
    algo = PPO1("mlp", env)
    try:
        # NOTE(review): the files are presumably read by the trainer
        # constructor or by train() (not visible here); both are guarded.
        trainer = OnPolicyTrainer(
            algo,
            env,
            epochs=0,
            load_hyperparams="test_ckpt/PPO1_CartPole-v0/0-log-0.toml",
            load_weights="test_ckpt/PPO1_CartPole-v0/0-log-0.pt",
        )
        trainer.train()
    except Exception:
        # ignore_errors keeps a missing directory from masking the real
        # failure that is being re-raised.
        rmtree("logs", ignore_errors=True)
        raise
    # Success path unchanged: a missing logs directory still raises here.
    rmtree("logs")