def example2():
    # Collect transitions from a reset-cost wrapped Hopper-v3 with a random actor
    # and print a summary of each transition yielded by TransitionGenerator.
    env = ResetCostWrapper(make_env("Hopper-v3"))
    print(env)

    def actor(*args, **kwargs):
        return env.action_space.sample()

    interactions = TransitionGenerator(env, actor, max_step=1005)
    for step, state, next_state, action, reward, done, info in interactions:
        print(step, done, state[0][0], next_state[0][0], reward, info["is_terminal_state"])
def example1():
    # Step a reset-cost wrapped Hopper-v3 with random actions until `done`,
    # printing the info dict at every step and the episode length at the end.
    env = make_env("Hopper-v3")
    env = ResetCostWrapper(env, reset_cost=float("nan"), terminal_step=300)
    print(env)

    step = 0
    env.reset()
    while True:
        action = env.action_space.sample()
        next_state, reward, done, info = env.step(action)
        print(info)
        step += 1
        if done:
            break
    print(step)
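# Hedged sketch (an assumption, not part of the original file): run the two
# examples above when this module is executed directly.
if __name__ == "__main__":
    example1()
    example2()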
def train_sac():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", default="HalfCheetah-v3", type=str)
    parser.add_argument("--seed", default=None, type=int)
    parser.add_argument("--num_envs", default=1, type=int)
    parser.add_argument("--max_step", default=10**6, type=int)
    parser.add_argument("--eval_interval", type=int, default=10**4)
    parser.add_argument("--num_evaluate", type=int, default=10)
    parser.add_argument("--logging_interval", type=int, default=10**3)
    parser.add_argument("--gamma", default=0.99, type=float)
    parser.add_argument("--replay_start_size", default=10**4, type=int)
    parser.add_argument("--num_videos", type=int, default=3)
    parser.add_argument("--log_level", type=int, default=logging.INFO)
    args = parser.parse_args()

    wandb.init(project="rl_algos_example", name="soft_actor_critic", tags=[args.env_id])
    wandb.config.update(args)

    logging.basicConfig(level=args.log_level)
    logger = logging.getLogger(__name__)

    # fix seed
    manual_seed(args.seed)

    # make environment
    env = vectorize_env(env_id=args.env_id, num_envs=args.num_envs, seed=args.seed)
    dim_state = env.observation_space.shape[-1]
    dim_action = env.action_space.shape[-1]
    logger.info(f"env = {env}")
    logger.info(f"dim_state = {dim_state}")
    logger.info(f"dim_action = {dim_action}")
    logger.info(f"action_space = {env.action_space}")
    logger.info(f"max_episode_steps = {env.spec.max_episode_steps}")

    # make agent
    agent = SAC(
        dim_state=dim_state,
        dim_action=dim_action,
        gamma=args.gamma,
        replay_start_size=args.replay_start_size,
    )

    evaluator = Evaluator(
        env=make_env(args.env_id, args.seed),
        eval_interval=args.eval_interval,
        num_evaluate=args.num_evaluate,
    )

    recorder = (
        Recoder(
            env=make_env(args.env_id, args.seed),
            record_interval=args.max_step // args.num_videos,
        )
        if args.num_videos > 0
        else None
    )

    agent = training(
        env=env,
        agent=agent,
        max_steps=args.max_step,
        logging_interval=args.logging_interval,
        recorder=recorder,
        evaluator=evaluator,
    )
def train_td3():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", type=str, default="HalfCheetah-v3")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--num_envs", type=int, default=1)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--policy_update_delay", type=int, default=2)
    parser.add_argument("--max_step", type=int, default=10**6)
    parser.add_argument("--eval_interval", type=int, default=10**4)
    parser.add_argument("--logging_interval", type=int, default=10**3)
    parser.add_argument("--num_evaluate", type=int, default=10)
    parser.add_argument("--num_videos", type=int, default=3)
    parser.add_argument("--log_level", type=int, default=logging.INFO)
    args = parser.parse_args()

    wandb.init(project="td3", tags=["td3", args.env_id], config=args)
    wandb.config.update(args)

    logging.basicConfig(level=args.log_level)
    logger = logging.getLogger(__name__)

    # fix seed
    manual_seed(args.seed)

    # make environment
    env = vectorize_env(env_id=args.env_id, num_envs=args.num_envs, seed=args.seed)
    dim_state = env.observation_space.shape[-1]
    dim_action = env.action_space.shape[-1]
    logger.info(f"env = {env}")
    logger.info(f"dim_state = {dim_state}")
    logger.info(f"dim_action = {dim_action}")
    logger.info(f"action_space = {env.action_space}")
    logger.info(f"max_episode_steps = {env.spec.max_episode_steps}")

    # make agent
    agent = TD3(
        dim_state=dim_state,
        dim_action=dim_action,
        gamma=args.gamma,
        batch_size=args.batch_size,
        policy_update_delay=args.policy_update_delay,
        policy_optimizer_kwargs={"lr": args.lr},
        q_optimizer_kwargs={"lr": args.lr},
    )

    evaluator = Evaluator(
        env=make_env(args.env_id, args.seed),
        eval_interval=args.eval_interval,
        num_evaluate=args.num_evaluate,
    )

    recorder = (
        Recoder(
            env=make_env(args.env_id, args.seed),
            record_interval=args.max_step // args.num_videos,
        )
        if args.num_videos > 0
        else None
    )

    agent = training(
        env=env,
        agent=agent,
        max_steps=args.max_step,
        logging_interval=args.logging_interval,
        recorder=recorder,
        evaluator=evaluator,
    )
def train_trpo():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", type=str, default="Hopper-v3")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--num_envs", type=int, default=5)
    parser.add_argument("--update_interval", type=int, default=None)
    parser.add_argument("--gamma", type=float, default=0.995)
    parser.add_argument("--lambd", type=float, default=0.97)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--vf_epoch", type=int, default=5)
    parser.add_argument("--vf_batch_size", type=int, default=64)
    parser.add_argument("--conjugate_gradient_damping", type=float, default=1e-1)
    parser.add_argument("--use_state_normalizer", action="store_true")
    parser.add_argument("--max_step", type=int, default=10**6)
    parser.add_argument("--eval_interval", type=int, default=5 * 10**4)
    parser.add_argument("--num_evaluate", type=int, default=10)
    parser.add_argument("--num_videos", type=int, default=3)
    parser.add_argument("--log_level", type=int, default=logging.INFO)
    args = parser.parse_args()

    wandb.init(project="trpo", tags=["trpo", args.env_id], config=args)

    logging.basicConfig(level=args.log_level)
    logger = logging.getLogger(__name__)

    # fix seed
    manual_seed(args.seed)

    # make environment
    env = vectorize_env(env_id=args.env_id, num_envs=args.num_envs, seed=args.seed)
    dim_state = env.observation_space.shape[-1]
    dim_action = env.action_space.shape[-1]
    logger.info(f"env = {env}")
    logger.info(f"dim_state = {dim_state}")
    logger.info(f"dim_action = {dim_action}")
    logger.info(f"action_space = {env.action_space}")
    logger.info(f"max_episode_steps = {env.spec.max_episode_steps}")

    # make agent
    agent = TRPO(
        dim_state,
        dim_action,
        gamma=args.gamma,
        lambd=args.lambd,
        entropy_coef=0.0,
        vf_epoch=args.vf_epoch,
        vf_batch_size=args.vf_batch_size,
        vf_optimizer_kwargs={"lr": args.lr},
        state_normalizer=ZScoreFilter(dim_state) if args.use_state_normalizer else None,
        conjugate_gradient_damping=args.conjugate_gradient_damping,
        # default to 1000 environment steps per worker between policy updates
        update_interval=(
            args.num_envs * 1000 if args.update_interval is None else args.update_interval
        ),
    )

    evaluator = Evaluator(
        env=make_env(args.env_id, args.seed),
        eval_interval=args.eval_interval,
        num_evaluate=args.num_evaluate,
    )

    recorder = (
        Recoder(
            env=make_env(args.env_id, args.seed),
            record_interval=args.max_step // args.num_videos,
        )
        if args.num_videos > 0
        else None
    )

    agent = training(
        env=env,
        agent=agent,
        max_steps=args.max_step,
        recorder=recorder,
        evaluator=evaluator,
    )
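# Hypothetical launcher for the three trainers above.  In the original
# repository each train_* function presumably lives in its own script; this
# dispatch table is an illustrative assumption, not the repo's own layout.
if __name__ == "__main__":
    import sys

    entry_points = {"sac": train_sac, "td3": train_td3, "trpo": train_trpo}
    algo = sys.argv.pop(1)  # e.g. `python train.py td3 --seed 0`
    entry_points[algo]()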
def _make_env(*env_args, **env_kwargs):
    return ResetCostWrapper(make_env(*env_args, **env_kwargs))
def _make_env(*args, **kwargs):
    return ResetCostWrapper(make_env(*args, **kwargs))
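# Hedged usage sketch for the _make_env helpers above: plug the reset-cost
# wrapped factory into the same Evaluator API used by the training scripts, so
# that evaluation episodes also pay the reset cost.  The env_id, seed and
# interval values here are illustrative assumptions, not the repo's defaults.
def _example_reset_cost_evaluator():
    return Evaluator(
        env=_make_env("Hopper-v3", 0),
        eval_interval=10**4,
        num_evaluate=10,
    )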