Example 1
def example2():
    env = ResetCostWrapper(make_env("Hopper-v3"))
    print(env)

    # Random actor: ignore the observation and sample an action uniformly.
    def actor(*args, **kwargs):
        return env.action_space.sample()

    interactions = TransitionGenerator(env, actor, max_step=1005)

    for step, state, next_state, action, reward, done, info in interactions:
        print(step, done, state[0][0], next_state[0][0], reward,
              info["is_terminal_state"])
Example 2
def example1():
    env = make_env("Hopper-v3")
    env = ResetCostWrapper(env, reset_cost=float("nan"), terminal_step=300)
    print(env)

    step = 0
    env.reset()
    while True:
        action = env.action_space.sample()
        next_state, reward, done, info = env.step(action)
        print(info)
        step += 1
        if done:
            break

    print(step)
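
Both examples above depend on ResetCostWrapper, whose source is not shown in this listing. One plausible reading of its parameters (reset_cost, terminal_step) and of the is_terminal_state flag printed in Example 1 is a wrapper that turns the task into a fixed-length one: when the underlying episode terminates, it charges a reset cost, resets in place, and only signals done after terminal_step steps. The sketch below is an assumption-laden stand-in, not the repository's implementation.

import gym

class ResetCostWrapperSketch(gym.Wrapper):
    """Illustrative stand-in for ResetCostWrapper (pre-0.26 gym API assumed)."""

    def __init__(self, env, reset_cost=100.0, terminal_step=1000):
        super().__init__(env)
        self._reset_cost = reset_cost
        self._terminal_step = terminal_step
        self._t = 0

    def reset(self, **kwargs):
        self._t = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        next_state, reward, done, info = self.env.step(action)
        self._t += 1
        info["is_terminal_state"] = done  # did the underlying episode end?
        if done:
            # Pay the reset cost and restart the inner episode in place.
            reward = reward - self._reset_cost
            next_state = self.env.reset()
        # The wrapped episode ends only after terminal_step outer steps.
        return next_state, reward, self._t >= self._terminal_step, info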
Example 3
def train_sac():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", default="HalfCheetah-v3", type=str)
    parser.add_argument("--seed", default=None, type=int)
    parser.add_argument("--num_envs", default=1, type=int)
    parser.add_argument("--max_step", default=10 ** 6, type=int)
    parser.add_argument("--eval_interval", type=int, default=10 ** 4)
    parser.add_argument("--num_evaluate", type=int, default=10)
    parser.add_argument("--logging_interval", type=int, default=10 ** 3)
    parser.add_argument("--gamma", default=0.99, type=float)
    parser.add_argument("--replay_start_size", default=10 ** 4, type=int)
    parser.add_argument("--num_videos", type=int, default=3)
    parser.add_argument("--log_level", type=int, default=logging.INFO)
    args = parser.parse_args()

    wandb.init(project="rl_algos_example", name="soft_actor_critic", tags=[args.env_id])

    wandb.config.update(args)

    logging.basicConfig(level=args.log_level)
    logger = logging.getLogger(__name__)

    # fix seed
    manual_seed(args.seed)

    # make environment
    env = vectorize_env(env_id=args.env_id, num_envs=args.num_envs, seed=args.seed)

    dim_state = env.observation_space.shape[-1]
    dim_action = env.action_space.shape[-1]

    logger.info(f"env = {env}")
    logger.info(f"dim_state = {dim_state}")
    logger.info(f"dim_action = {dim_action}")
    logger.info(f"action_space = {env.action_space}")
    logger.info(f"max_episode_steps = {env.spec.max_episode_steps}")

    # make agent
    agent = SAC(
        dim_state=dim_state,
        dim_action=dim_action,
        gamma=args.gamma,
        replay_start_size=args.replay_start_size,
    )

    evaluator = Evaluator(
        env=make_env(args.env_id, args.seed),
        eval_interval=args.eval_interval,
        num_evaluate=args.num_evaluate,
    )
    recoder = (
        Recoder(
            env=make_env(args.env_id, args.seed),
            record_interval=args.max_step // args.num_videos,
        )
        if args.num_videos > 0
        else None
    )

    agent = training(
        env=env,
        agent=agent,
        max_steps=args.max_step,
        logging_interval=args.logging_interval,
        recorder=recoder,
        evaluator=evaluator,
    )
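
The helpers used here (make_env, vectorize_env, manual_seed, Evaluator, Recoder, training, SAC) come from the surrounding project and are not shown in this listing. As an assumption only, a vectorized environment of the shape this script expects (observation_space.shape[-1] giving the per-environment state dimension) could be built with gym's own SyncVectorEnv, roughly as sketched below; the project's vectorize_env may differ in wrappers, seeding, or async execution.

import gym

def vectorize_env_sketch(env_id, num_envs, seed=None):
    # Build num_envs copies of the environment and step them in lockstep.
    def make_thunk(rank):
        def _thunk():
            env = gym.make(env_id)
            if seed is not None:
                env.seed(seed + rank)              # pre-0.26 gym seeding
                env.action_space.seed(seed + rank)
            return env
        return _thunk

    return gym.vector.SyncVectorEnv([make_thunk(i) for i in range(num_envs)])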
Example 4
def train_td3():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", type=str, default="HalfCheetah-v3")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--num_envs", type=int, default=1)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--policy_update_delay", type=int, default=2)
    parser.add_argument("--max_step", type=int, default=10**6)
    parser.add_argument("--eval_interval", type=int, default=10**4)
    parser.add_argument("--logging_interval", type=int, default=10**3)
    parser.add_argument("--num_evaluate", type=int, default=10)
    parser.add_argument("--num_videos", type=int, default=3)
    parser.add_argument("--log_level", type=int, default=logging.INFO)
    args = parser.parse_args()

    wandb.init(project="td3", tags=["td3", args.env_id], config=args)

    wandb.config.update(args)

    logging.basicConfig(level=args.log_level)
    logger = logging.getLogger(__name__)

    manual_seed(args.seed)

    env = vectorize_env(env_id=args.env_id,
                        num_envs=args.num_envs,
                        seed=args.seed)
    dim_state = env.observation_space.shape[-1]
    dim_action = env.action_space.shape[-1]

    logger.info(f"env = {env}")
    logger.info(f"dim_state = {dim_state}")
    logger.info(f"dim_action = {dim_action}")
    logger.info(f"action_space = {env.action_space}")
    logger.info(f"max_episode_steps = {env.spec.max_episode_steps}")

    agent = TD3(
        dim_state=dim_state,
        dim_action=dim_action,
        gamma=args.gamma,
        batch_size=args.batch_size,
        policy_update_delay=args.policy_update_delay,
        policy_optimizer_kwargs={"lr": args.lr},
        q_optimizer_kwargs={"lr": args.lr},
    )

    evaluator = Evaluator(
        env=make_env(args.env_id, args.seed),
        eval_interval=args.eval_interval,
        num_evaluate=args.num_evaluate,
    )

    recoder = (Recoder(
        env=make_env(args.env_id, args.seed),
        record_interval=args.max_step // args.num_videos,
    ) if args.num_videos > 0 else None)

    agent = training(
        env=env,
        agent=agent,
        max_steps=args.max_step,
        logging_interval=args.logging_interval,
        recorder=recoder,
        evaluator=evaluator,
    )
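
Assuming this function is the entry point of a script (the file name train_td3.py below is hypothetical; every flag comes from the argparse definitions above), a typical invocation would be:

python train_td3.py --env_id HalfCheetah-v3 --seed 0 --max_step 1000000 --num_videos 0

Passing --num_videos 0 skips video recording, since a Recoder is only created when num_videos > 0.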
Example 5
def train_trpo():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", type=str, default="Hopper-v3")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--num_envs", type=int, default=5)
    parser.add_argument("--update_interval", type=int, default=None)
    parser.add_argument("--gamma", type=float, default=0.995)
    parser.add_argument("--lambd", type=float, default=0.97)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--vf_epoch", type=int, default=5)
    parser.add_argument("--vf_batch_size", type=int, default=64)
    parser.add_argument("--conjugate_gradient_damping",
                        type=float,
                        default=1e-1)
    parser.add_argument("--use_state_normalizer", action="store_true")
    parser.add_argument("--max_step", type=int, default=10**6)
    parser.add_argument("--eval_interval", type=int, default=5 * 10**4)
    parser.add_argument("--num_evaluate", type=int, default=10)
    parser.add_argument("--num_videos", type=int, default=3)
    parser.add_argument("--log_level", type=int, default=logging.INFO)
    args = parser.parse_args()

    wandb.init(project="trpo", tags=["trpo", args.env_id], config=args)

    logging.basicConfig(level=args.log_level)
    logger = logging.getLogger(__name__)

    manual_seed(args.seed)

    env = vectorize_env(env_id=args.env_id,
                        num_envs=args.num_envs,
                        seed=args.seed)
    dim_state = env.observation_space.shape[-1]
    dim_action = env.action_space.shape[-1]

    logger.info(f"env = {env}")
    logger.info(f"dim_state = {dim_state}")
    logger.info(f"dim_action = {dim_action}")
    logger.info(f"action_space = {env.action_space}")
    logger.info(f"max_episode_steps = {env.spec.max_episode_steps}")

    agent = TRPO(
        dim_state,
        dim_action,
        gamma=args.gamma,
        lambd=args.lambd,
        entropy_coef=0.0,
        vf_epoch=args.vf_epoch,
        vf_batch_size=args.vf_batch_size,
        vf_optimizer_kwargs={"lr": args.lr},
        state_normalizer=(ZScoreFilter(dim_state)
                          if args.use_state_normalizer else None),
        conjugate_gradient_damping=args.conjugate_gradient_damping,
        update_interval=(args.num_envs * 1000
                         if args.update_interval is None
                         else args.update_interval),
    )

    evaluator = Evaluator(
        env=make_env(args.env_id, args.seed),
        eval_interval=args.eval_interval,
        num_evaluate=args.num_evaluate,
    )

    recoder = (Recoder(
        env=make_env(args.env_id, args.seed),
        record_interval=args.max_step // args.num_videos,
    ) if args.num_videos > 0 else None)

    agent = training(
        env=env,
        agent=agent,
        max_steps=args.max_step,
        recorder=recoder,
        evaluator=evaluator,
    )
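
ZScoreFilter is only referenced here via --use_state_normalizer; its implementation is not part of this listing. A minimal running z-score filter in that spirit might look like the sketch below. This is an assumption, not the repository's class, and it normalizes one observation at a time using an online mean/variance estimate.

import numpy as np

class ZScoreFilterSketch:
    """Running z-score normalization with an online mean/variance estimate."""

    def __init__(self, dim, eps=1e-8):
        self.mean = np.zeros(dim)
        self.var = np.zeros(dim)
        self.count = 0
        self.eps = eps

    def __call__(self, state):
        state = np.asarray(state, dtype=np.float64)
        self.count += 1
        delta = state - self.mean
        self.mean += delta / self.count
        # Online (population) variance update.
        self.var += (delta * (state - self.mean) - self.var) / self.count
        return (state - self.mean) / np.sqrt(self.var + self.eps)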
Example 6
def _make_env(*env_args, **env_kwargs):
    return ResetCostWrapper(make_env(*env_args, **env_kwargs))
Example 7
def _make_env(*args, **kwargs):
    return ResetCostWrapper(make_env(*args, **kwargs))