Example #1
0
 def test_a2c_discrete(self):
     """Smoke-test A2C ("mlp") on discrete CartPole-v0 for one epoch.

     Trains and evaluates through OnPolicyTrainer with CSV logging, then
     removes the log directory. Cleanup runs in ``finally`` so a failing
     train/evaluate no longer leaves a stale ``./logs`` directory behind
     (previously rmtree was skipped on any exception).
     """
     env = VectorEnv("CartPole-v0", 1)
     algo = A2C("mlp", env, rollout_size=128)
     trainer = OnPolicyTrainer(algo,
                               env,
                               log_mode=["csv"],
                               logdir="./logs",
                               epochs=1)
     try:
         trainer.train()
         trainer.evaluate()
     finally:
         # ignore_errors: don't mask the original test failure if the
         # log directory was never created.
         shutil.rmtree("./logs", ignore_errors=True)
Example #2
0
def test_on_policy_trainer():
    """Run a short PPO1 train/evaluate cycle through OnPolicyTrainer.

    Also checks the trainer identifies itself as on-policy.
    """
    vec_env = VectorEnv("CartPole-v1", 2)
    agent = PPO1("mlp", vec_env, rollout_size=128)
    trainer = OnPolicyTrainer(
        agent,
        vec_env,
        ["stdout"],
        epochs=2,
        evaluate_episodes=2,
        max_timesteps=300,
    )
    assert not trainer.off_policy
    trainer.train()
    trainer.evaluate()
Example #3
0
def main(args):
    """Build the agent/trainer selected by CLI args, train, and evaluate.

    Args:
        args: Parsed CLI namespace. Expected attributes: ``algo``, ``env``,
            ``n_envs``, ``serial``, ``env_type``, ``log``, ``arch``,
            ``epochs``, ``render``, ``log_interval``, plus
            ``rollout_size`` for on-policy algos (ppo/vpg/a2c) or
            ``replay_size``/``batch_size``/``warmup_steps`` for
            off-policy ones (sac/ddpg/td3/dqn).

    Raises:
        KeyError: If ``args.algo`` is not one of the supported names.
    """
    ALGOS = {
        "sac": SAC,
        "a2c": A2C,
        "ppo": PPO1,
        "ddpg": DDPG,
        "td3": TD3,
        "vpg": VPG,
        "dqn": DQN,
    }

    # Normalize the algo name ONCE so the ALGOS lookup and the on-policy
    # branch below agree. Previously "PPO" resolved in ALGOS (lowered)
    # but failed the raw-string membership test and was wrongly built
    # with off-policy kwargs (replay_size/batch_size).
    algo_name = args.algo.lower()
    algo = ALGOS[algo_name]
    env = VectorEnv(args.env,
                    n_envs=args.n_envs,
                    parallel=not args.serial,
                    env_type=args.env_type)

    logger = get_logger(args.log)

    if algo_name in ["ppo", "vpg", "a2c"]:
        # On-policy agents collect fresh rollouts each update.
        agent = algo(args.arch, env, rollout_size=args.rollout_size)
        trainer = OnPolicyTrainer(
            agent,
            env,
            logger,
            epochs=args.epochs,
            render=args.render,
            log_interval=args.log_interval,
        )
    else:
        # Off-policy agents learn from a replay buffer filled during
        # warmup and training.
        agent = algo(args.arch,
                     env,
                     replay_size=args.replay_size,
                     batch_size=args.batch_size)
        trainer = OffPolicyTrainer(
            agent,
            env,
            logger,
            epochs=args.epochs,
            render=args.render,
            warmup_steps=args.warmup_steps,
            log_interval=args.log_interval,
        )

    trainer.train()
    trainer.evaluate()
    env.render()