import shutil

# NOTE: these tests assume GenRL's public names are importable as below; the
# exact module paths vary across GenRL versions, so adjust them as needed.
from genrl import (
    A2C, DDPG, DQN, DoubleDQN, DuelingDQN, NoisyDQN, PPO1,
    PrioritizedReplayDQN, SAC, TD3, VPG,
)
from genrl.deep.common import (
    CnnDuelingValue, CnnValue, MlpNoisyValue, MlpValue, NormalActionNoise,
    OffPolicyTrainer, OnPolicyTrainer, OrnsteinUhlenbeckActionNoise,
    PrioritizedBuffer, get_logger,
)
from genrl.environments import VectorEnv


def test_sac_shared():
    env = VectorEnv("Pendulum-v0", 2)
    algo = SAC(
        "mlp",
        env,
        batch_size=5,
        shared_layers=[1, 1],
        policy_layers=[1, 1],
        value_layers=[1, 1],
        max_timesteps=100,
    )
    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        epochs=5,
        max_ep_len=500,
        warmup_steps=10,
        start_update=10,
        max_timesteps=100,
    )
    trainer.train()
    shutil.rmtree("./logs")
def test_atari_env(self):
    """
    Tests that the Atari wrappers and the AtariEnv function work
    """
    env = VectorEnv("Pong-v0", env_type="atari")
    algo = DQN("cnn", env, replay_size=100)
    trainer = OffPolicyTrainer(algo, env, epochs=1, max_timesteps=50)
    trainer.train()
    shutil.rmtree("./logs")
def test_double_dqn(self):
    env = VectorEnv("CartPole-v0")
    algo = DoubleDQN("mlp", env, batch_size=5, replay_size=100)
    assert isinstance(algo.model, MlpValue)
    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        max_ep_len=200,
        epochs=4,
        warmup_steps=10,
        start_update=10,
    )
    trainer.train()
    shutil.rmtree("./logs")
def test_prioritized_dqn(self):
    env = VectorEnv("Pong-v0", env_type="atari")
    algo = PrioritizedReplayDQN("cnn", env, batch_size=5, replay_size=100)
    assert isinstance(algo.model, CnnValue)
    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        max_ep_len=200,
        epochs=4,
        warmup_steps=10,
        start_update=10,
    )
    trainer.train()
    shutil.rmtree("./logs")
def test_noisy_dqn(self):
    env = VectorEnv("CartPole-v0")
    algo = NoisyDQN("mlp", env, batch_size=5, replay_size=100, value_layers=[1, 1])
    assert algo.dqn_type == "noisy"
    assert algo.noisy
    assert isinstance(algo.model, MlpNoisyValue)
    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        max_ep_len=200,
        epochs=4,
        warmup_steps=10,
        start_update=10,
    )
    trainer.train()
    shutil.rmtree("./logs")
def test_atari_env_custom_layers(self):
    """
    Tests the Atari wrappers with a DQN that uses custom value layers
    """
    env = VectorEnv("Pong-v0", env_type="atari")
    algo = DQN("cnn", env, batch_size=5, replay_size=100, value_layers=[1, 1])
    trainer = OffPolicyTrainer(
        algo, env, epochs=5, max_ep_len=200, warmup_steps=10, start_update=10
    )
    trainer.train()
    shutil.rmtree("./logs")
def test_dueling_dqn(self):
    env = VectorEnv("Pong-v0", env_type="atari")
    algo = DuelingDQN("cnn", env, batch_size=5, replay_size=100, value_layers=[1, 1])
    assert algo.dqn_type == "dueling"
    assert isinstance(algo.model, CnnDuelingValue)
    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        max_ep_len=200,
        epochs=4,
        warmup_steps=10,
        start_update=10,
        max_timesteps=100,
    )
    trainer.train()
    shutil.rmtree("./logs")
def test_prioritized_dqn_mlp(self):
    env = VectorEnv("CartPole-v0")
    algo = PrioritizedReplayDQN(
        "mlp", env, batch_size=5, replay_size=100, value_layers=[1, 1]
    )
    assert isinstance(algo.model, MlpValue)
    assert isinstance(algo.replay_buffer, PrioritizedBuffer)
    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        max_ep_len=200,
        epochs=4,
        warmup_steps=10,
        start_update=10,
    )
    trainer.train()
    shutil.rmtree("./logs")
def main(args):
    ALGOS = {
        "sac": SAC,
        "a2c": A2C,
        "ppo": PPO1,
        "ddpg": DDPG,
        "td3": TD3,
        "vpg": VPG,
        "dqn": DQN,
    }
    # Normalize once so both the lookup and the on-policy check below are
    # case-insensitive.
    algo_name = args.algo.lower()
    algo = ALGOS[algo_name]
    env = VectorEnv(
        args.env, n_envs=args.n_envs, parallel=not args.serial, env_type=args.env_type
    )
    logger = get_logger(args.log)

    if algo_name in ["ppo", "vpg", "a2c"]:
        agent = algo(
            args.arch, env, rollout_size=args.rollout_size
        )  # batch_size=args.batch_size could also be passed here
        trainer = OnPolicyTrainer(
            agent,
            env,
            logger,
            epochs=args.epochs,
            render=args.render,
            log_interval=args.log_interval,
        )
    else:
        agent = algo(
            args.arch, env, replay_size=args.replay_size, batch_size=args.batch_size
        )
        trainer = OffPolicyTrainer(
            agent,
            env,
            logger,
            epochs=args.epochs,
            render=args.render,
            warmup_steps=args.warmup_steps,
            log_interval=args.log_interval,
        )
    trainer.train()
    trainer.evaluate()
    env.render()
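# A minimal sketch of the CLI wiring that main(args) above expects. The flag
# names mirror the attributes read from `args`; the defaults are illustrative
# assumptions, not values taken from the original script.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Train an RL agent with GenRL")
    parser.add_argument("--algo", default="ppo", help="one of: sac, a2c, ppo, ddpg, td3, vpg, dqn")
    parser.add_argument("--env", default="CartPole-v0", help="Gym environment id")
    parser.add_argument("--n-envs", type=int, default=2, help="number of vectorized envs")
    parser.add_argument("--serial", action="store_true", help="disable parallel vectorized envs")
    parser.add_argument("--env-type", default="gym", help="'gym' or 'atari'")
    parser.add_argument("--arch", default="mlp", help="'mlp' or 'cnn'")
    parser.add_argument("--log", default="stdout", help="passed through to get_logger")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--render", action="store_true")
    parser.add_argument("--log-interval", type=int, default=10)
    parser.add_argument("--rollout-size", type=int, default=2048, help="on-policy agents")
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--replay-size", type=int, default=1000000, help="off-policy agents")
    parser.add_argument("--warmup-steps", type=int, default=1000, help="off-policy agents")

    main(parser.parse_args())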
def test_td3():
    env = VectorEnv("Pendulum-v0", 2)
    algo = TD3(
        "mlp",
        env,
        batch_size=5,
        noise=OrnsteinUhlenbeckActionNoise,
        policy_layers=[1, 1],
        value_layers=[1, 1],
    )
    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        epochs=5,
        max_ep_len=500,
        warmup_steps=10,
        start_update=10,
    )
    trainer.train()
    shutil.rmtree("./logs")
def test_ddpg():
    env = VectorEnv("Pendulum-v0", 2)
    algo = DDPG(
        "mlp",
        env,
        batch_size=5,
        noise=NormalActionNoise,
        policy_layers=[1, 1],
        value_layers=[1, 1],
    )
    trainer = OffPolicyTrainer(
        algo,
        env,
        log_mode=["csv"],
        logdir="./logs",
        epochs=4,
        max_ep_len=200,
        warmup_steps=10,
        start_update=10,
    )
    trainer.train()
    shutil.rmtree("./logs")
def test_off_policy_trainer():
    env = VectorEnv("Pendulum-v0", 2)
    algo = DDPG("mlp", env, replay_size=100)
    trainer = OffPolicyTrainer(
        algo,
        env,
        ["stdout"],
        epochs=2,
        evaluate_episodes=2,
        max_ep_len=300,
        max_timesteps=300,
    )
    assert trainer.off_policy
    trainer.train()
    trainer.evaluate()
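# The tests above repeatedly write logs to ./logs and clean up manually with
# shutil.rmtree. A hypothetical pytest fixture (not part of the original
# suite) could factor out that pattern with a per-test temporary directory:
import pytest


@pytest.fixture
def logdir(tmp_path):
    # tmp_path is pytest's built-in per-test temporary directory; pass the
    # yielded path as OffPolicyTrainer's logdir instead of "./logs".
    yield str(tmp_path / "logs")
    # No manual cleanup needed: pytest removes tmp_path automatically.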