def test_gym_env(self):
    """Smoke-test the Gym wrapper: build a vectorized env, reset, step once, close."""
    environment = VectorEnv("Pendulum-v0", env_type="gym")
    environment.reset()
    random_action = environment.sample()
    environment.step(random_action)
    environment.close()
def main(args):
    """Build the agent and trainer requested by CLI args, then train, evaluate, render.

    Args:
        args: parsed CLI namespace with at least `algo`, `env`, `n_envs`,
            `serial`, `env_type`, `log`, `arch`, `epochs`, `render`,
            `log_interval`, plus rollout/replay/batch/warmup settings.
    """
    ALGOS = {
        "sac": SAC,
        "a2c": A2C,
        "ppo": PPO1,
        "ddpg": DDPG,
        "td3": TD3,
        "vpg": VPG,
        "dqn": DQN,
    }
    # Normalise the algorithm name once so the registry lookup and the
    # on/off-policy branch below agree. Previously the branch tested the raw
    # `args.algo`, so a mixed-case value like "PPO" resolved in the registry
    # but was routed to the off-policy trainer.
    algo_name = args.algo.lower()
    algo = ALGOS[algo_name]
    env = VectorEnv(
        args.env, n_envs=args.n_envs, parallel=not args.serial, env_type=args.env_type
    )
    logger = get_logger(args.log)
    if algo_name in ["ppo", "vpg", "a2c"]:
        # On-policy agents collect fresh rollouts each update.
        agent = algo(
            args.arch, env, rollout_size=args.rollout_size
        )  # , batch_size=args.batch_size)
        trainer = OnPolicyTrainer(
            agent,
            env,
            logger,
            epochs=args.epochs,
            render=args.render,
            log_interval=args.log_interval,
        )
    else:
        # Off-policy agents learn from a replay buffer.
        agent = algo(
            args.arch, env, replay_size=args.replay_size, batch_size=args.batch_size
        )
        trainer = OffPolicyTrainer(
            agent,
            env,
            logger,
            epochs=args.epochs,
            render=args.render,
            warmup_steps=args.warmup_steps,
            log_interval=args.log_interval,
        )
    trainer.train()
    trainer.evaluate()
    env.render()
def main(args):
    """Run evolutionary hyperparameter search for an A2C agent on the given env."""
    env = VectorEnv(
        args.env, n_envs=args.n_envs, parallel=not args.serial, env_type=args.env_type
    )
    input_dim, action_dim, discrete, action_lim = get_env_properties(env, "mlp")
    network = MlpActorCritic(
        input_dim,
        action_dim,
        (1, 1),  # policy layer sizes (kept tiny)
        (1, 1),  # value layer sizes
        "V",  # type of value function
        discrete,
        action_lim=action_lim,
        activation="relu",
    )
    generic_agent = A2C(network, env, rollout_size=args.rollout_size)
    # NOTE(review): gamma candidates of 12 and 121 are outside the usual
    # [0, 1] discount range — confirm these aren't typos (e.g. 0.12 / 0.121).
    agent_parameter_choices = {
        "gamma": [12, 121],
        # 'clip_param': [0.2, 0.3],
        # 'lr_policy': [0.001, 0.002],
        # 'lr_value': [0.001, 0.002]
    }
    generate(
        args.generations,
        args.population,
        agent_parameter_choices,
        env,
        generic_agent,
        args,
    )
def test_ddpg():
    """DDPG with Gaussian action noise trains for one epoch without error."""
    environment = VectorEnv("Pendulum-v0", 2)
    agent = DDPG("mlp", environment, noise=NormalActionNoise, layers=[1, 1])
    runner = OffPolicyTrainer(
        agent, environment, log_mode=["csv"], logdir="./logs", epochs=1
    )
    runner.train()
    shutil.rmtree("./logs")
def test_off_policy_trainer():
    """OffPolicyTrainer flags itself as off-policy and completes train + evaluate."""
    environment = VectorEnv("Pendulum-v0", 2)
    agent = DDPG("mlp", environment, replay_size=100)
    runner = OffPolicyTrainer(
        agent, environment, ["stdout"], epochs=1, evaluate_episodes=2
    )
    assert runner.off_policy
    runner.train()
    runner.evaluate()
def test_sac_shared():
    """SAC with shared encoder layers trains for a few tiny, capped epochs."""
    environment = VectorEnv("Pendulum-v0", 2)
    agent = SAC(
        "mlp",
        environment,
        batch_size=5,
        shared_layers=[1, 1],
        policy_layers=[1, 1],
        value_layers=[1, 1],
        max_timesteps=100,
    )
    runner = OffPolicyTrainer(
        agent,
        environment,
        log_mode=["csv"],
        logdir="./logs",
        epochs=5,
        max_ep_len=500,
        warmup_steps=10,
        start_update=10,
        max_timesteps=100,
    )
    runner.train()
    shutil.rmtree("./logs")
def test_on_policy_trainer():
    """OnPolicyTrainer is not off-policy and completes train + evaluate."""
    environment = VectorEnv("CartPole-v1", 2)
    agent = PPO1("mlp", environment)
    runner = OnPolicyTrainer(
        agent, environment, ["stdout"], epochs=1, evaluate_episodes=2
    )
    assert not runner.off_policy
    runner.train()
    runner.evaluate()
def test_sac():
    """SAC smoke test: one epoch on Pendulum with a minimal network."""
    environment = VectorEnv("Pendulum-v0", 2)
    agent = SAC("mlp", environment, layers=[1, 1])
    runner = OffPolicyTrainer(
        agent, environment, log_mode=["csv"], logdir="./logs", epochs=1
    )
    runner.train()
    shutil.rmtree("./logs")
def test_ppo1(self):
    """PPO1 with an MLP policy trains for one epoch on CartPole."""
    environment = VectorEnv("CartPole-v0")
    agent = PPO1("mlp", environment, rollout_size=128)
    runner = OnPolicyTrainer(
        agent, environment, log_mode=["csv"], logdir="./logs", epochs=1
    )
    runner.train()
    shutil.rmtree("./logs")
def test_ppo1_cnn(self):
    """PPO1 with a CNN policy trains for one epoch on an Atari env."""
    environment = VectorEnv("Pong-v0", env_type="atari")
    agent = PPO1("cnn", environment, rollout_size=128)
    runner = OnPolicyTrainer(
        agent, environment, log_mode=["csv"], logdir="./logs", epochs=1
    )
    runner.train()
    shutil.rmtree("./logs")
def test_get_env_properties(self):
    """get_env_properties reports correct dims for discrete and continuous envs."""
    # Discrete case: CartPole has a 4-dim observation and 2 actions.
    cartpole = VectorEnv("CartPole-v0", 1)
    obs_dim, act_dim, is_discrete, _ = get_env_properties(cartpole)
    assert obs_dim == 4
    assert act_dim == 2
    assert is_discrete is True

    # Continuous case: Pendulum has a 3-dim observation and a 1-dim action
    # bounded at +/-2.0.
    pendulum = VectorEnv("Pendulum-v0", 1)
    obs_dim, act_dim, is_discrete, act_limit = get_env_properties(pendulum)
    assert obs_dim == 3
    assert act_dim == 1
    assert is_discrete is False
    assert act_limit == 2.0
def test_vpg(self):
    """VPG trains and evaluates for one epoch on CartPole."""
    environment = VectorEnv("CartPole-v0", 2)
    agent = VPG("mlp", environment, layers=[1, 1])
    runner = OnPolicyTrainer(
        agent,
        environment,
        log_mode=["csv"],
        logdir="./logs",
        epochs=1,
        evaluate_episodes=2,
    )
    runner.train()
    runner.evaluate()
    shutil.rmtree("./logs")
def test_atari_env(self):
    """Tests working of Atari Wrappers and the AtariEnv function.

    Fix: the trainer previously got no `log_mode`/`logdir`, yet the test
    removed "./logs" afterwards — cleanup could fail (or hit an unrelated
    directory) depending on the trainer's defaults. Direct logs to "./logs"
    explicitly, matching the sibling tests.
    """
    env = VectorEnv("Pong-v0", env_type="atari")
    algo = DQN("cnn", env)
    trainer = OffPolicyTrainer(
        algo, env, log_mode=["csv"], logdir="./logs", epochs=1, steps_per_epoch=200
    )
    trainer.train()
    shutil.rmtree("./logs")
def test_vpg_cnn():
    """VPG with a CNN policy runs one epoch on an Atari env."""
    environment = VectorEnv("Pong-v0", 1, env_type="atari")
    agent = VPG("cnn", environment)
    runner = OnPolicyTrainer(
        agent, environment, log_mode=["csv"], logdir="./logs", epochs=1
    )
    runner.train()
    shutil.rmtree("./logs")
def test_vpg():
    """VPG with an MLP policy runs one epoch on CartPole."""
    environment = VectorEnv("CartPole-v0", 1)
    agent = VPG("mlp", environment)
    runner = OnPolicyTrainer(
        agent, environment, log_mode=["csv"], logdir="./logs", epochs=1
    )
    runner.train()
    shutil.rmtree("./logs")
def test_td3(self):
    """TD3 with Ornstein-Uhlenbeck noise trains and evaluates on Pendulum."""
    environment = VectorEnv("Pendulum-v0", 2)
    agent = TD3("mlp", environment, noise=OrnsteinUhlenbeckActionNoise, layers=[1, 1])
    runner = OffPolicyTrainer(
        agent,
        environment,
        log_mode=["csv"],
        logdir="./logs",
        epochs=1,
        evaluate_episodes=2,
    )
    runner.train()
    runner.evaluate()
    shutil.rmtree("./logs")
def test_a2c_shared():
    """A2C with shared actor-critic layers trains for one epoch."""
    environment = VectorEnv("CartPole-v0", 1)
    agent = A2C("mlp", environment, shared_layers=(32, 32), rollout_size=128)
    runner = OnPolicyTrainer(
        agent, environment, log_mode=["csv"], logdir="./logs", epochs=1
    )
    runner.train()
    shutil.rmtree("./logs")
def test_a2c_discrete(self):
    """A2C trains and evaluates on a discrete-action env."""
    environment = VectorEnv("CartPole-v0", 1)
    agent = A2C("mlp", environment, rollout_size=128)
    runner = OnPolicyTrainer(
        agent, environment, log_mode=["csv"], logdir="./logs", epochs=1
    )
    runner.train()
    runner.evaluate()
    shutil.rmtree("./logs")
def test_on_policy_trainer():
    """OnPolicyTrainer with a step cap: trains two epochs and evaluates."""
    environment = VectorEnv("CartPole-v1", 2)
    agent = PPO1("mlp", environment, rollout_size=128)
    runner = OnPolicyTrainer(
        agent,
        environment,
        ["stdout"],
        epochs=2,
        evaluate_episodes=2,
        max_timesteps=300,
    )
    assert not runner.off_policy
    runner.train()
    runner.evaluate()
def test_dqn():
    """Vanilla DQN trains for one epoch on CartPole."""
    environment = VectorEnv("CartPole-v0", 2)
    agent = DQN("mlp", environment)
    runner = OffPolicyTrainer(
        agent, environment, log_mode=["csv"], logdir="./logs", epochs=1
    )
    runner.train()
    shutil.rmtree("./logs")
def test_dqn():
    """Run one training epoch for each DQN variant on CartPole."""
    environment = VectorEnv("CartPole-v0", 2)

    def run_one_epoch(agent):
        # Train briefly, then remove the CSV log directory so the next
        # variant starts with a clean slate.
        runner = OffPolicyTrainer(
            agent, environment, log_mode=["csv"], logdir="./logs", epochs=1
        )
        runner.train()
        shutil.rmtree("./logs")

    # Vanilla DQN
    run_one_epoch(DQN("mlp", environment))

    # Double DQN with prioritized replay buffer
    # run_one_epoch(DQN("mlp", environment, double_dqn=True, prioritized_replay=True))

    # Noisy DQN
    run_one_epoch(DQN("mlp", environment, noisy_dqn=True))

    # Dueling DDQN
    run_one_epoch(DQN("mlp", environment, dueling_dqn=True, double_dqn=True))

    # Categorical DQN
    run_one_epoch(DQN("mlp", environment, categorical_dqn=True))
def test_load_params(self):
    """test loading algorithm parameters"""
    environment = VectorEnv("CartPole-v0", 1)
    # NOTE(review): the agent is constructed but never exercised — this only
    # checks that instantiation with `load_model` succeeds. Also, `epochs`
    # is passed to the algorithm here while other tests pass it to the
    # trainer; confirm PPO1 actually accepts it.
    algo = PPO1(
        "mlp",
        environment,
        epochs=1,
        load_model="test_ckpt/PPO1_CartPole-v0/0-log-0.pt",
    )
    rmtree("logs")
def test_double_dqn_cnn():
    """Double DQN + prioritized replay with a CNN runs one capped epoch on Atari."""
    environment = VectorEnv("Pong-v0", n_envs=2, env_type="atari")
    agent = DQN("cnn", environment, double_dqn=True, prioritized_replay=True)
    runner = OffPolicyTrainer(
        agent,
        environment,
        log_mode=["csv"],
        logdir="./logs",
        epochs=1,
        steps_per_epoch=200,
    )
    runner.train()
    shutil.rmtree("./logs")
def test_categorical_dqn_cnn():
    """Categorical DQN with a CNN runs one capped epoch on Atari."""
    environment = VectorEnv("Pong-v0", n_envs=2, env_type="atari")
    agent = DQN("cnn", environment, categorical_dqn=True)
    runner = OffPolicyTrainer(
        agent,
        environment,
        log_mode=["csv"],
        logdir="./logs",
        epochs=1,
        steps_per_epoch=200,
    )
    runner.train()
    shutil.rmtree("./logs")
def test_save_params(self):
    """Training with save_model produces at least one checkpoint file."""
    environment = VectorEnv("CartPole-v0", 1)
    agent = PPO1("mlp", environment)
    runner = OnPolicyTrainer(
        agent,
        environment,
        ["stdout"],
        save_model="test_ckpt",
        save_interval=1,
        epochs=1,
    )
    runner.train()
    checkpoints = os.listdir("test_ckpt/PPO1_CartPole-v0")
    assert len(checkpoints) != 0
def test_load_params(self):
    """A trainer given load_model restores parameters from a saved checkpoint."""
    environment = VectorEnv("CartPole-v0", 1)
    agent = PPO1("mlp", environment)
    # epochs=0: exercise only the checkpoint-loading path, no training.
    runner = OnPolicyTrainer(
        agent,
        environment,
        epochs=0,
        load_model="test_ckpt/PPO1_CartPole-v0/0-log-0.pt",
    )
    runner.train()
    rmtree("logs")
def test_custom_ppo1(self):
    """PPO1 accepts a user-supplied actor-critic network."""
    environment = VectorEnv("CartPole-v0", 1)
    obs_dim = environment.observation_space.shape[0]
    n_actions = environment.action_space.n
    network = custom_actorcritic(obs_dim, n_actions)
    agent = PPO1(network, environment)
    runner = OnPolicyTrainer(
        agent, environment, log_mode=["csv"], logdir="./logs", epochs=1
    )
    runner.train()
    shutil.rmtree("./logs")
def test_custom_vpg(self):
    """VPG accepts a user-supplied policy network."""
    environment = VectorEnv("CartPole-v0", 1)
    obs_dim = environment.observation_space.shape[0]
    n_actions = environment.action_space.n
    custom_net = custom_policy(obs_dim, n_actions)
    agent = VPG(custom_net, environment)
    runner = OnPolicyTrainer(
        agent, environment, log_mode=["csv"], logdir="./logs", epochs=1
    )
    runner.train()
    shutil.rmtree("./logs")
def test_off_policy_trainer(self):
    """OffPolicyTrainer with episode/step caps trains and evaluates."""
    environment = VectorEnv("Pendulum-v0", 2)
    agent = DDPG("mlp", environment, replay_size=100)
    runner = OffPolicyTrainer(
        agent,
        environment,
        ["stdout"],
        epochs=2,
        evaluate_episodes=2,
        max_ep_len=300,
        max_timesteps=300,
    )
    assert runner.off_policy
    runner.train()
    runner.evaluate()
def test_double_dqn(self):
    """DoubleDQN builds an MLP value network and trains for a few capped epochs."""
    environment = VectorEnv("CartPole-v0")
    agent = DoubleDQN("mlp", environment, batch_size=5, replay_size=100)
    # The MLP arch string must have produced an MlpValue model.
    assert isinstance(agent.model, MlpValue)
    runner = OffPolicyTrainer(
        agent,
        environment,
        log_mode=["csv"],
        logdir="./logs",
        max_ep_len=200,
        epochs=4,
        warmup_steps=10,
        start_update=10,
    )
    runner.train()
    shutil.rmtree("./logs")