import os
import shutil
from shutil import rmtree

# NOTE: import paths follow genrl's current layout; they may differ across versions.
from genrl.agents import A2C, DDPG, DQN, PPO1, SAC, TD3, VPG
from genrl.environments import VectorEnv
from genrl.trainers import OffPolicyTrainer, OnPolicyTrainer


def test_ppo1(self):
    env = VectorEnv("CartPole-v0")
    algo = PPO1("mlp", env, rollout_size=128)
    trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
    shutil.rmtree("./logs")

def test_ppo1_cnn(self):
    env = VectorEnv("Pong-v0", env_type="atari")
    algo = PPO1("cnn", env, rollout_size=128)
    trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
    shutil.rmtree("./logs")

def test_vpg():
    env = VectorEnv("CartPole-v0")
    algo = VPG("mlp", env)
    trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
    shutil.rmtree("./logs")

def test_a2c_shared():
    env = VectorEnv("CartPole-v0", 1)
    algo = A2C("mlp", env, shared_layers=(32, 32), rollout_size=128)
    trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
    shutil.rmtree("./logs")

def test_save_params():
    """Test saving algorithm state dict."""
    env = VectorEnv("CartPole-v0", 1)
    algo = PPO1("mlp", env)
    trainer = OnPolicyTrainer(
        algo, env, ["stdout"], save_model="test_ckpt", save_interval=1, epochs=1
    )
    trainer.train()
    assert len(os.listdir("test_ckpt/PPO1_CartPole-v0")) != 0

def test_load_params():
    """Test loading algorithm parameters."""
    env = VectorEnv("CartPole-v0", 1)
    algo = PPO1("mlp", env)
    trainer = OnPolicyTrainer(
        algo, env, epochs=0, load_model="test_ckpt/PPO1_CartPole-v0/0-log-0.pt"
    )
    trainer.train()
    rmtree("logs")

def test_custom_ppo1(self):
    env = VectorEnv("CartPole-v0", 1)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    actorcritic = custom_actorcritic(state_dim, action_dim)
    algo = PPO1(actorcritic, env)
    trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
    shutil.rmtree("./logs")

def test_custom_vpg(self):
    env = VectorEnv("CartPole-v0", 1)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.n
    policy = custom_policy(state_dim, action_dim)
    algo = VPG(policy, env)
    trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
    shutil.rmtree("./logs")

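# `custom_policy` and `custom_actorcritic` used above are helpers defined
# elsewhere in the test suite. A minimal sketch of what they could look like
# follows, assuming genrl exposes MlpPolicy and MlpActorCritic bases under
# genrl.core (the class names, import path, and constructor signatures here
# are assumptions and may differ by version).
from genrl.core import MlpActorCritic, MlpPolicy


class custom_policy(MlpPolicy):
    def __init__(self, state_dim, action_dim, **kwargs):
        super().__init__(state_dim, action_dim, **kwargs)


class custom_actorcritic(MlpActorCritic):
    def __init__(self, state_dim, action_dim, **kwargs):
        super().__init__(state_dim, action_dim, **kwargs)
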
def test_load_params(self):
    """Test loading algorithm parameters."""
    env = VectorEnv("CartPole-v0", 1)
    algo = PPO1("mlp", env)
    trainer = OnPolicyTrainer(
        algo,
        env,
        epochs=0,
        load_hyperparams="test_ckpt/PPO1_CartPole-v0/0-log-0.toml",
        load_weights="test_ckpt/PPO1_CartPole-v0/0-log-0.pt",
    )
    trainer.train()
    rmtree("logs")

def main(args):
    ALGOS = {
        "sac": SAC,
        "a2c": A2C,
        "ppo": PPO1,
        "ddpg": DDPG,
        "td3": TD3,
        "vpg": VPG,
        "dqn": DQN,
    }
    algo = ALGOS[args.algo.lower()]
    env = VectorEnv(
        args.env, n_envs=args.n_envs, parallel=not args.serial, env_type=args.env_type
    )
    logger = get_logger(args.log)

    # Lowercase the algo name here as well, so the branch matches the
    # case-insensitive ALGOS lookup above.
    if args.algo.lower() in ["ppo", "vpg", "a2c"]:
        agent = algo(
            args.arch, env, rollout_size=args.rollout_size
        )  # , batch_size=args.batch_size)
        trainer = OnPolicyTrainer(
            agent,
            env,
            logger,
            epochs=args.epochs,
            render=args.render,
            log_interval=args.log_interval,
        )
    else:
        agent = algo(
            args.arch, env, replay_size=args.replay_size, batch_size=args.batch_size
        )
        trainer = OffPolicyTrainer(
            agent,
            env,
            logger,
            epochs=args.epochs,
            render=args.render,
            warmup_steps=args.warmup_steps,
            log_interval=args.log_interval,
        )

    trainer.train()
    trainer.evaluate()
    env.render()

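# `main` expects an `args` namespace. A minimal argparse setup that would
# produce it is sketched below; the flag names mirror the attributes `main`
# reads, while the defaults are illustrative assumptions, not the project's
# own values.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train an RL agent with genrl")
    parser.add_argument("--algo", default="ppo", help="one of: sac, a2c, ppo, ddpg, td3, vpg, dqn")
    parser.add_argument("--env", default="CartPole-v0")
    parser.add_argument("--env-type", default="gym")
    parser.add_argument("--n-envs", type=int, default=2)
    parser.add_argument("--serial", action="store_true", help="disable parallel envs")
    parser.add_argument("--arch", default="mlp")
    parser.add_argument("--log", default="stdout")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--render", action="store_true")
    parser.add_argument("--log-interval", type=int, default=10)
    # On-policy agents read --rollout-size; off-policy agents read the replay args.
    parser.add_argument("--rollout-size", type=int, default=2048)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--replay-size", type=int, default=1000000)
    parser.add_argument("--warmup-steps", type=int, default=1000)

    main(parser.parse_args())
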
def test_a2c_discrete(self):
    env = VectorEnv("CartPole-v0", 1)
    algo = A2C("mlp", env, rollout_size=128)
    trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    trainer.train()
    trainer.evaluate()
    shutil.rmtree("./logs")

def test_on_policy_trainer():
    env = VectorEnv("CartPole-v1", 2)
    algo = PPO1("mlp", env, rollout_size=128)
    trainer = OnPolicyTrainer(
        algo, env, ["stdout"], epochs=2, evaluate_episodes=2, max_timesteps=300
    )
    assert not trainer.off_policy
    trainer.train()
    trainer.evaluate()
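

# For symmetry with test_on_policy_trainer, a sketch of an off-policy
# counterpart using DQN and OffPolicyTrainer as invoked in `main` above.
# The specific keyword values here are illustrative assumptions.
def test_off_policy_trainer():
    env = VectorEnv("CartPole-v1", 2)
    algo = DQN("mlp", env, batch_size=64, replay_size=1000)
    trainer = OffPolicyTrainer(algo, env, ["stdout"], epochs=2, warmup_steps=500)
    assert trainer.off_policy
    trainer.train()
    trainer.evaluate()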