    def test_inverted_pendulum_v2(self):
        env_name = "InvertedPendulum-v2"
        env, (policy, optimizer), _ = create_models(env_name, self.hidden_sizes, self.lr)
        baseline = FutureReturnBaseline()
        policy_update = REINFORCE(policy, optimizer, baseline)

        result = solve(env_name, env, policy_update, logdir)
        self.assertEqual(result, True)
    def test_lunar_lander_v2(self):
        env_name = "LunarLander-v2"
        env, (policy, optimizer), _ = create_models(env_name, self.hidden_sizes, self.lr)
        baseline = FullReturnBaseline()
        policy_update = REINFORCE(policy, optimizer, baseline)

        result = solve(env_name, env, policy_update, logdir, epochs=500)
        self.assertEqual(result, True)
    def test_cartpole_v1(self):
        env_name = "CartPole-v1"
        env, (policy, optimizer), _ = create_models(env_name, self.hidden_sizes, self.lr)
        baseline = FutureReturnBaseline()
        policy_update = REINFORCE(policy, optimizer, baseline)

        result = solve(env_name, env, policy_update, logdir)
        self.assertEqual(result, True)
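# The three tests above differ only in the baseline they plug into REINFORCE. Below is a
# minimal sketch of the two weighting schemes, under the assumption that FullReturnBaseline
# weights every step by the whole episode return while FutureReturnBaseline uses the
# reward-to-go; the helper functions are hypothetical illustrations, not part of
# rl_baselines.core:


def full_return_weights(rewards):
    # Every timestep of the episode gets the same weight: the total episode return.
    total = sum(rewards)
    return [total] * len(rewards)


def future_return_weights(rewards):
    # Each timestep is weighted only by rewards collected from that step onwards,
    # which typically reduces the variance of the gradient estimate.
    weights, running = [], 0.0
    for r in reversed(rewards):
        running += r
        weights.append(running)
    return list(reversed(weights))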
parser.add_argument("--num-envs", type=int, default=multiprocessing.cpu_count() - 1) parser.add_argument("--epochs", type=int, default=100) parser.add_argument("--batch-size", type=int, default=5000) parser.add_argument("--gamma", type=float, default=0.99) parser.add_argument("--lam", type=float, default=0.97) parser.add_argument("--lr", type=float, default=1e-2) args = parser.parse_args() logger.info("Using vanilla formulation of policy gradient.") hidden_sizes = [100] lr = 1e-2 env = make_env(args.env_name, args.num_envs) (policy, optimizer), (value, vopt) = create_models(env, hidden_sizes, args.lr, args.lr) baseline = GAEBaseline(value, gamma=args.gamma, lambda_=args.lam) policy_update = REINFORCE(policy, optimizer, baseline) vbaseline = DiscountedReturnBaseline(gamma=args.gamma, normalize=False) value_update = ValueUpdate(value, vopt, vbaseline, iters=1) update = ActorCriticUpdate(policy_update, value_update) solve( args.env_name, env, policy_update, logdir, epochs=args.epochs, batch_size=args.batch_size, )
    # Core of the policy-gradient loss: weighted negative log-likelihood of the taken actions.
    dist = policy(obs)
    log_probs = dist.log_prob(acts)
    loss = -((weights * log_probs).mean())
    return loss


if __name__ == "__main__":
    import argparse
    from rl_baselines.core import (
        solve,
        create_models,
        FullReturnBaseline,
        FutureReturnBaseline,
    )

    parser = argparse.ArgumentParser()
    parser.add_argument("--env-name", "--env", type=str, default="CartPole-v0")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--render", action="store_true")
    parser.add_argument("--lr", type=float, default=1e-2)
    args = parser.parse_args()

    logger.info("Using vanilla formulation of policy gradient.")

    hidden_sizes = [100]
    env, (policy, optimizer), _ = create_models(args.env_name, hidden_sizes, args.lr)

    baseline = FutureReturnBaseline()
    policy_update = REINFORCE(policy, optimizer, baseline)

    solve(args.env_name, env, policy_update, logdir, epochs=args.epochs)
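# A self-contained illustration of the loss expression at the top of the snippet above,
# assuming the policy returns a torch.distributions.Categorical over discrete actions
# (as it would for CartPole); the tensors below are made up for demonstration only:
import torch
from torch.distributions import Categorical

logits = torch.randn(3, 2, requires_grad=True)  # 3 observations, 2 possible actions
dist = Categorical(logits=logits)  # stands in for policy(obs)
acts = torch.tensor([0, 1, 1])  # actions actually taken
weights = torch.tensor([1.0, 0.5, -0.2])  # returns/advantages supplied by the baseline
loss = -((weights * dist.log_prob(acts)).mean())  # same form as the loss above
loss.backward()  # gradients flow back into the policy parameters (here, the logits)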
    def test_cartpole_v0(self):
        env_name = "CartPole-v0"
        env = make_env(env_name, 1)
        policy_update = self.get_update_model(env)

        result = solve(env_name, env, policy_update, logdir)
        self.assertEqual(result, True)
    def test_lunar_lander_v2(self):
        env_name = "LunarLander-v2"
        env = make_env(env_name, 1)
        policy_update = self.get_update_model(env)

        result = solve(env_name, env, policy_update, logdir, epochs=500)
        self.assertEqual(result, True)