def example_replay_buffer():
    """Demonstrate ReplayBuffer usage on a single vectorized Swimmer-v3 env.

    Collects transitions with a random policy, stores them in a replay
    buffer, samples a training batch, and shows that slicing a
    TrainingBatch yields a view whose in-place updates are reflected in
    the full batch.
    """
    env = vectorize_env("Swimmer-v3", 1)

    def actor(state):
        # Random policy: ignore the observation entirely.
        return env.action_space.sample()

    interactions = TransitionGenerator(env, actor, max_step=1000)
    buffer = ReplayBuffer(10**4)
    for steps, states, next_states, actions, rewards, dones, info in interactions:
        terminals = is_state_terminal(env, steps, dones)
        # `env_index` (was `id`, which shadowed the builtin) identifies the
        # sub-environment each per-env transition came from.
        for env_index, (state, next_state, action, reward, terminal, done) in enumerate(
            zip(states, next_states, actions, rewards, terminals, dones)
        ):
            # `id=` is the ReplayBuffer API's keyword, not the builtin.
            buffer.append(
                id=env_index,
                state=state,
                next_state=next_state,
                action=action,
                reward=reward,
                terminal=terminal,
                reset=done,
            )
    s = buffer.sample(10)
    batch = TrainingBatch(**s, device="cuda")
    print(batch.reward)
    print(batch[0:5].reward)
    # Slices appear to be views: mutating the slice mutates the batch.
    batch[0:5].reward += 100.0
    print(batch.reward)
def example_episode_buffer():
    """Demonstrate EpisodeBuffer usage with three vectorized Hopper-v3 envs.

    Collects random-policy transitions keyed by sub-environment, groups
    them into episodes, and shows per-episode / flattened access as well
    as reverse iteration over one episode's transitions.
    """
    env = vectorize_env(env_id="Hopper-v3", num_envs=3)

    def actor(state):
        # Random policy: ignore the observation entirely.
        return env.action_space.sample()

    interactions = TransitionGenerator(env, actor, max_step=300)
    buffer = EpisodeBuffer()
    for steps, states, next_states, actions, rewards, dones, info in interactions:
        terminals = is_state_terminal(env, steps, dones)
        # `env_index` (was `id`, which shadowed the builtin) identifies the
        # sub-environment each per-env transition came from.
        for env_index, (state, next_state, action, reward, terminal, done) in enumerate(
            zip(states, next_states, actions, rewards, terminals, dones)
        ):
            # `id=` is the EpisodeBuffer API's keyword, not the builtin.
            buffer.append(
                id=env_index,
                state=state,
                next_state=next_state,
                action=action,
                reward=reward,
                terminal=terminal,
                reset=done,
            )
    episodes = buffer.get_episodes()
    batch = EpisodicTrainingBatch(episodes, "cuda")
    print(batch[0].reward)
    # In-place update of a single episode's rewards.
    batch[0].reward += 100
    print(batch.flatten.reward)
    for episode in batch:
        # Walk one episode backwards, then stop after the first episode.
        for transition in reversed(episode):
            print(transition.reset)
        break
def example3():
    """Run a random policy in reset-cost-wrapped Hopper-v3 envs and print
    the terminal-state flag computed for every interaction step."""

    def _make_env(*env_args, **env_kwargs):
        # Every sub-environment gets the reset-cost wrapper applied.
        return ResetCostWrapper(make_env(*env_args, **env_kwargs))

    env = vectorize_env(env_id="Hopper-v3", num_envs=3, seed=1, env_fn=_make_env)
    print(env)

    def actor(*args, **kwargs):
        return env.action_space.sample()

    generator = TransitionGenerator(env, actor, max_step=1005)
    for step, state, next_state, action, reward, done, info in generator:
        print(step, is_state_terminal(env, step, done, info))
def train_sac():
    """Train a SAC agent on a gym MuJoCo task configured via CLI flags.

    Parses hyperparameters from the command line, initializes wandb
    logging, builds the vectorized training env plus evaluator/recorder,
    and runs the training loop.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", default="HalfCheetah-v3", type=str)
    parser.add_argument("--seed", default=None, type=int)
    parser.add_argument("--num_envs", default=1, type=int)
    parser.add_argument("--max_step", default=10**6, type=int)
    parser.add_argument("--eval_interval", type=int, default=10**4)
    parser.add_argument("--num_evaluate", type=int, default=10)
    parser.add_argument("--logging_interval", type=int, default=10**3)
    parser.add_argument("--gamma", default=0.99, type=float)
    parser.add_argument("--replay_start_size", default=10**4, type=int)
    parser.add_argument("--num_videos", type=int, default=3)
    parser.add_argument("--log_level", type=int, default=logging.INFO)
    args = parser.parse_args()

    wandb.init(project="rl_algos_example", name="soft_actor_critic", tags=[args.env_id])
    wandb.config.update(args)

    logging.basicConfig(level=args.log_level)
    logger = logging.getLogger(__name__)

    # fix seed
    manual_seed(args.seed)

    # make environment
    env = vectorize_env(env_id=args.env_id, num_envs=args.num_envs, seed=args.seed)
    dim_state = env.observation_space.shape[-1]
    dim_action = env.action_space.shape[-1]
    logger.info(f"env = {env}")
    logger.info(f"dim_state = {dim_state}")
    logger.info(f"dim_action = {dim_action}")
    logger.info(f"action_space = {env.action_space}")
    logger.info(f"max_episode_steps = {env.spec.max_episode_steps}")

    # make agent
    agent = SAC(
        dim_state=dim_state,
        dim_action=dim_action,
        gamma=args.gamma,
        replay_start_size=args.replay_start_size,
    )

    evaluator = Evaluator(
        env=make_env(args.env_id, args.seed),
        eval_interval=args.eval_interval,
        num_evaluate=args.num_evaluate,
    )

    # Renamed local from `recoder` to `recorder` for consistency with the
    # `recorder=` keyword below (the project class itself is named Recoder).
    recorder = (
        Recoder(
            env=make_env(args.env_id, args.seed),
            record_interval=args.max_step // args.num_videos,
        )
        if args.num_videos > 0
        else None
    )

    agent = training(
        env=env,
        agent=agent,
        max_steps=args.max_step,
        logging_interval=args.logging_interval,
        recorder=recorder,
        evaluator=evaluator,
    )
def train_td3():
    """Train a TD3 agent on a gym MuJoCo task configured via CLI flags.

    Parses hyperparameters from the command line, initializes wandb
    logging, builds the vectorized training env plus evaluator/recorder,
    and runs the training loop.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", type=str, default="HalfCheetah-v3")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--num_envs", type=int, default=1)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--policy_update_delay", type=int, default=2)
    parser.add_argument("--max_step", type=int, default=10**6)
    parser.add_argument("--eval_interval", type=int, default=10**4)
    parser.add_argument("--logging_interval", type=int, default=10**3)
    parser.add_argument("--num_evaluate", type=int, default=10)
    parser.add_argument("--num_videos", type=int, default=3)
    parser.add_argument("--log_level", type=int, default=logging.INFO)
    args = parser.parse_args()

    wandb.init(project="td3", tags=["td3", args.env_id], config=args)
    # NOTE(review): redundant with config=args above; kept to preserve behavior.
    wandb.config.update(args)

    logging.basicConfig(level=args.log_level)
    logger = logging.getLogger(__name__)

    # fix seed
    manual_seed(args.seed)

    # make environment
    env = vectorize_env(env_id=args.env_id, num_envs=args.num_envs, seed=args.seed)
    dim_state = env.observation_space.shape[-1]
    dim_action = env.action_space.shape[-1]
    logger.info(f"env = {env}")
    logger.info(f"dim_state = {dim_state}")
    logger.info(f"dim_action = {dim_action}")
    logger.info(f"action_space = {env.action_space}")
    logger.info(f"max_episode_steps = {env.spec.max_episode_steps}")

    # make agent: both policy and Q optimizers share the same learning rate
    agent = TD3(
        dim_state=dim_state,
        dim_action=dim_action,
        gamma=args.gamma,
        batch_size=args.batch_size,
        policy_update_delay=args.policy_update_delay,
        policy_optimizer_kwargs={"lr": args.lr},
        q_optimizer_kwargs={"lr": args.lr},
    )

    evaluator = Evaluator(
        env=make_env(args.env_id, args.seed),
        eval_interval=args.eval_interval,
        num_evaluate=args.num_evaluate,
    )

    # Renamed local from `recoder` to `recorder` for consistency with the
    # `recorder=` keyword below (the project class itself is named Recoder).
    recorder = (
        Recoder(
            env=make_env(args.env_id, args.seed),
            record_interval=args.max_step // args.num_videos,
        )
        if args.num_videos > 0
        else None
    )

    agent = training(
        env=env,
        agent=agent,
        max_steps=args.max_step,
        logging_interval=args.logging_interval,
        recorder=recorder,
        evaluator=evaluator,
    )
def train_trpo():
    """Train a TRPO agent on a gym MuJoCo task configured via CLI flags.

    Parses hyperparameters from the command line, initializes wandb
    logging, builds the vectorized training env plus evaluator/recorder,
    and runs the training loop.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", type=str, default="Hopper-v3")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--num_envs", type=int, default=5)
    parser.add_argument("--update_interval", type=int, default=None)
    parser.add_argument("--gamma", type=float, default=0.995)
    parser.add_argument("--lambd", type=float, default=0.97)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--vf_epoch", type=int, default=5)
    parser.add_argument("--vf_batch_size", type=int, default=64)
    parser.add_argument("--conjugate_gradient_damping", type=float, default=1e-1)
    parser.add_argument("--use_state_normalizer", action="store_true")
    parser.add_argument("--max_step", type=int, default=10**6)
    parser.add_argument("--eval_interval", type=int, default=5 * 10**4)
    parser.add_argument("--num_evaluate", type=int, default=10)
    parser.add_argument("--num_videos", type=int, default=3)
    parser.add_argument("--log_level", type=int, default=logging.INFO)
    args = parser.parse_args()

    wandb.init(project="trpo", tags=["trpo", args.env_id], config=args)

    logging.basicConfig(level=args.log_level)
    logger = logging.getLogger(__name__)

    # fix seed
    manual_seed(args.seed)

    # make environment
    env = vectorize_env(env_id=args.env_id, num_envs=args.num_envs, seed=args.seed)
    dim_state = env.observation_space.shape[-1]
    dim_action = env.action_space.shape[-1]
    logger.info(f"env = {env}")
    logger.info(f"dim_state = {dim_state}")
    logger.info(f"dim_action = {dim_action}")
    logger.info(f"action_space = {env.action_space}")
    logger.info(f"max_episode_steps = {env.spec.max_episode_steps}")

    # make agent: update_interval defaults to 1000 steps per sub-environment
    agent = TRPO(
        dim_state,
        dim_action,
        gamma=args.gamma,
        lambd=args.lambd,
        entropy_coef=0.0,
        vf_epoch=args.vf_epoch,
        vf_batch_size=args.vf_batch_size,
        vf_optimizer_kwargs={"lr": args.lr},
        state_normalizer=ZScoreFilter(dim_state) if args.use_state_normalizer else None,
        conjugate_gradient_damping=args.conjugate_gradient_damping,
        update_interval=args.num_envs * 1000
        if args.update_interval is None
        else args.update_interval,
    )

    evaluator = Evaluator(
        env=make_env(args.env_id, args.seed),
        eval_interval=args.eval_interval,
        num_evaluate=args.num_evaluate,
    )

    # Renamed local from `recoder` to `recorder` for consistency with the
    # `recorder=` keyword below (the project class itself is named Recoder).
    recorder = (
        Recoder(
            env=make_env(args.env_id, args.seed),
            record_interval=args.max_step // args.num_videos,
        )
        if args.num_videos > 0
        else None
    )

    # NOTE(review): unlike train_sac/train_td3, logging_interval is not
    # forwarded to training() here — presumably relies on its default; confirm.
    agent = training(
        env=env,
        agent=agent,
        max_steps=args.max_step,
        recorder=recorder,
        evaluator=evaluator,
    )