def test_a2c_discrete(self):
    """Smoke-test A2C on a discrete-action env: train one epoch, evaluate, clean up logs."""
    env = VectorEnv("CartPole-v0", 1)
    algo = A2C("mlp", env, rollout_size=128)
    trainer = OnPolicyTrainer(algo, env, log_mode=["csv"], logdir="./logs", epochs=1)
    try:
        trainer.train()
        trainer.evaluate()
    finally:
        # Remove the CSV log directory even if training/evaluation raises,
        # so a failed run doesn't leave state behind for later tests.
        shutil.rmtree("./logs")
def test_on_policy_trainer():
    """Smoke-test the on-policy trainer with PPO1 on vectorized CartPole."""
    vec_env = VectorEnv("CartPole-v1", 2)
    agent = PPO1("mlp", vec_env, rollout_size=128)
    trainer = OnPolicyTrainer(
        agent,
        vec_env,
        ["stdout"],
        epochs=2,
        evaluate_episodes=2,
        max_timesteps=300,
    )
    # An on-policy trainer must not report itself as off-policy.
    assert not trainer.off_policy
    trainer.train()
    trainer.evaluate()
def main(args):
    """Build the requested agent/trainer from CLI args, train, evaluate, and render.

    Args:
        args: parsed argparse namespace with fields: algo, env, n_envs, serial,
            env_type, log, arch, rollout_size, replay_size, batch_size, epochs,
            render, log_interval, warmup_steps.
    """
    ALGOS = {
        "sac": SAC,
        "a2c": A2C,
        "ppo": PPO1,
        "ddpg": DDPG,
        "td3": TD3,
        "vpg": VPG,
        "dqn": DQN,
    }
    # Normalize the algorithm name ONCE and use it everywhere below.
    # Previously the on-policy check used the raw args.algo, so e.g. "PPO"
    # resolved to PPO1 in the lookup but was routed to the off-policy trainer.
    algo_name = args.algo.lower()
    algo = ALGOS[algo_name]
    env = VectorEnv(
        args.env,
        n_envs=args.n_envs,
        parallel=not args.serial,
        env_type=args.env_type,
    )
    logger = get_logger(args.log)

    if algo_name in ("ppo", "vpg", "a2c"):
        # On-policy agents collect fresh rollouts each update.
        agent = algo(args.arch, env, rollout_size=args.rollout_size)
        trainer = OnPolicyTrainer(
            agent,
            env,
            logger,
            epochs=args.epochs,
            render=args.render,
            log_interval=args.log_interval,
        )
    else:
        # Off-policy agents sample minibatches from a replay buffer.
        agent = algo(
            args.arch,
            env,
            replay_size=args.replay_size,
            batch_size=args.batch_size,
        )
        trainer = OffPolicyTrainer(
            agent,
            env,
            logger,
            epochs=args.epochs,
            render=args.render,
            warmup_steps=args.warmup_steps,
            log_interval=args.log_interval,
        )

    trainer.train()
    trainer.evaluate()
    env.render()