Example #1
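Collect transitions from a single vectorized Swimmer-v3 environment with a random actor, store them in a ReplayBuffer, and sample a TrainingBatch on the GPU. The final lines modify a slice of the batch in place, presumably to show that the slice shares storage with the full batch. Imports are omitted here and in the snippets below; vectorize_env, TransitionGenerator, ReplayBuffer, TrainingBatch and is_state_terminal come from the library these examples accompany.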
def example_replay_buffer():
    env = vectorize_env("Swimmer-v3", 1)

    def actor(state):
        # random policy: ignore the observation and sample from the action space
        return env.action_space.sample()

    interactions = TransitionGenerator(env, actor, max_step=1000)

    # replay buffer holding at most 10 ** 4 transitions
    buffer = ReplayBuffer(10 ** 4)
    for steps, states, next_states, actions, rewards, dones, info in interactions:
        # mark which done flags correspond to true terminal states rather than time-limit resets
        terminals = is_state_terminal(env, steps, dones)
        for id, (state, next_state, action, reward, terminal, done) in enumerate(
            zip(states, next_states, actions, rewards, terminals, dones)
        ):
            buffer.append(
                id=id,
                state=state,
                next_state=next_state,
                action=action,
                reward=reward,
                terminal=terminal,
                reset=done,
            )
    # draw a random minibatch of 10 transitions and move it to the GPU
    s = buffer.sample(10)
    batch = TrainingBatch(**s, device="cuda")
    print(batch.reward)
    print(batch[0:5].reward)
    batch[0:5].reward += 100.0
    print(batch.reward)
Example #2
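The same collection loop on three parallel Hopper-v3 environments, stored in an EpisodeBuffer instead. get_episodes() returns complete episodes, and EpisodicTrainingBatch gives access to them per episode (batch[0], iteration) as well as flattened across episodes (batch.flatten).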
def example_episode_buffer():
    env = vectorize_env(env_id="Hopper-v3", num_envs=3)

    def actor(state):
        return env.action_space.sample()

    interactions = TransitionGenerator(env, actor, max_step=300)

    buffer = EpisodeBuffer()

    for steps, states, next_states, actions, rewards, dones, info in interactions:
        terminals = is_state_terminal(env, steps, dones)
        for id, (state, next_state, action, reward, terminal, done) in enumerate(
            zip(states, next_states, actions, rewards, terminals, dones)
        ):
            buffer.append(
                id=id,
                state=state,
                next_state=next_state,
                action=action,
                reward=reward,
                terminal=terminal,
                reset=done,
            )

    # retrieve complete episodes and wrap them in a batch on the GPU
    episodes = buffer.get_episodes()
    batch = EpisodicTrainingBatch(episodes, "cuda")

    print(batch[0].reward)
    batch[0].reward += 100
    print(batch.flatten.reward)
    for episode in batch:
        for transition in reversed(episode):
            print(transition.reset)
        break
Example #3
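Passing a custom env_fn to vectorize_env so that every sub-environment is wrapped in ResetCostWrapper, then stepping a random actor and printing what is_state_terminal reports at each step.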
def example3():
    def _make_env(*env_args, **env_kwargs):
        # wrap each sub-environment in ResetCostWrapper before vectorization
        return ResetCostWrapper(make_env(*env_args, **env_kwargs))

    env = vectorize_env(env_id="Hopper-v3",
                        num_envs=3,
                        seed=1,
                        env_fn=_make_env)
    print(env)

    def actor(*args, **kwargs):
        return env.action_space.sample()

    interactions = TransitionGenerator(env, actor, max_step=1005)

    for step, state, next_state, action, reward, done, info in interactions:
        print(step, is_state_terminal(env, step, done, info))
Example #4
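A complete SAC training script: hyperparameters come from argparse, the run and its hyperparameters are registered with Weights & Biases, and the vectorized environment, agent, Evaluator and optional video Recoder are handed to training(). It additionally assumes import argparse, logging and wandb at module level.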
def train_sac():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", default="HalfCheetah-v3", type=str)
    parser.add_argument("--seed", default=None, type=int)
    parser.add_argument("--num_envs", default=1, type=int)
    parser.add_argument("--max_step", default=10 ** 6, type=int)
    parser.add_argument("--eval_interval", type=int, default=10 ** 4)
    parser.add_argument("--num_evaluate", type=int, default=10)
    parser.add_argument("--logging_interval", type=int, default=10 ** 3)
    parser.add_argument("--gamma", default=0.99, type=float)
    parser.add_argument("--replay_start_size", default=10 ** 4, type=int)
    parser.add_argument("--num_videos", type=int, default=3)
    parser.add_argument("--log_level", type=int, default=logging.INFO)
    args = parser.parse_args()

    wandb.init(project="rl_algos_example", name="soft_actor_critic", tags=[args.env_id])

    wandb.config.update(args)

    logging.basicConfig(level=args.log_level)
    logger = logging.getLogger(__name__)

    # fix seed
    manual_seed(args.seed)

    # make environment
    env = vectorize_env(env_id=args.env_id, num_envs=args.num_envs, seed=args.seed)

    dim_state = env.observation_space.shape[-1]
    dim_action = env.action_space.shape[-1]

    logger.info(f"env = {env}")
    logger.info(f"dim_state = {dim_state}")
    logger.info(f"dim_action = {dim_action}")
    logger.info(f"action_space = {env.action_space}")
    logger.info(f"max_episode_steps = {env.spec.max_episode_steps}")

    # make agent
    agent = SAC(
        dim_state=dim_state,
        dim_action=dim_action,
        gamma=args.gamma,
        replay_start_size=args.replay_start_size,
    )

    evaluator = Evaluator(
        env=make_env(args.env_id, args.seed),
        eval_interval=args.eval_interval,
        num_evaluate=args.num_evaluate,
    )
    recorder = (
        Recoder(
            env=make_env(args.env_id, args.seed),
            record_interval=args.max_step // args.num_videos,
        )
        if args.num_videos > 0
        else None
    )

    agent = training(
        env=env,
        agent=agent,
        max_steps=args.max_step,
        logging_interval=args.logging_interval,
        recorder=recorder,
        evaluator=evaluator,
    )
Example #5
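The same script skeleton for TD3; only the agent construction differs (batch size, policy update delay and per-optimizer learning rates).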
def train_td3():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", type=str, default="HalfCheetah-v3")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--num_envs", type=int, default=1)
    parser.add_argument("--gamma", type=float, default=0.99)
    parser.add_argument("--lr", type=float, default=3e-4)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--policy_update_delay", type=int, default=2)
    parser.add_argument("--max_step", type=int, default=10**6)
    parser.add_argument("--eval_interval", type=int, default=10**4)
    parser.add_argument("--logging_interval", type=int, default=10**3)
    parser.add_argument("--num_evaluate", type=int, default=10)
    parser.add_argument("--num_videos", type=int, default=3)
    parser.add_argument("--log_level", type=int, default=logging.INFO)
    args = parser.parse_args()

    wandb.init(project="td3", tags=["td3", args.env_id], config=args)

    logging.basicConfig(level=args.log_level)
    logger = logging.getLogger(__name__)

    manual_seed(args.seed)

    env = vectorize_env(env_id=args.env_id,
                        num_envs=args.num_envs,
                        seed=args.seed)
    dim_state = env.observation_space.shape[-1]
    dim_action = env.action_space.shape[-1]

    logger.info(f"env = {env}")
    logger.info(f"dim_state = {dim_state}")
    logger.info(f"dim_action = {dim_action}")
    logger.info(f"action_space = {env.action_space}")
    logger.info(f"max_episode_steps = {env.spec.max_episode_steps}")

    agent = TD3(
        dim_state=dim_state,
        dim_action=dim_action,
        gamma=args.gamma,
        batch_size=args.batch_size,
        policy_update_delay=args.policy_update_delay,
        policy_optimizer_kwargs={"lr": args.lr},
        q_optimizer_kwargs={"lr": args.lr},
    )

    evaluator = Evaluator(
        env=make_env(args.env_id, args.seed),
        eval_interval=args.eval_interval,
        num_evaluate=args.num_evaluate,
    )

    recorder = (Recoder(
        env=make_env(args.env_id, args.seed),
        record_interval=args.max_step // args.num_videos,
    ) if args.num_videos > 0 else None)

    agent = training(
        env=env,
        agent=agent,
        max_steps=args.max_step,
        logging_interval=args.logging_interval,
        recorder=recorder,
        evaluator=evaluator,
    )
Example #6
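The TRPO variant of the training script. It adds the GAE coefficient lambd, conjugate-gradient damping, an optional ZScoreFilter state normalizer, and an update interval that defaults to 1000 steps per parallel environment.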
def train_trpo():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", type=str, default="Hopper-v3")
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument("--num_envs", type=int, default=5)
    parser.add_argument("--update_interval", type=int, default=None)
    parser.add_argument("--gamma", type=float, default=0.995)
    parser.add_argument("--lambd", type=float, default=0.97)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--vf_epoch", type=int, default=5)
    parser.add_argument("--vf_batch_size", type=int, default=64)
    parser.add_argument("--conjugate_gradient_damping",
                        type=float,
                        default=1e-1)
    parser.add_argument("--use_state_normalizer", action="store_true")
    parser.add_argument("--max_step", type=int, default=10**6)
    parser.add_argument("--eval_interval", type=int, default=5 * 10**4)
    parser.add_argument("--num_evaluate", type=int, default=10)
    parser.add_argument("--num_videos", type=int, default=3)
    parser.add_argument("--log_level", type=int, default=logging.INFO)
    args = parser.parse_args()

    wandb.init(project="trpo", tags=["trpo", args.env_id], config=args)

    logging.basicConfig(level=args.log_level)
    logger = logging.getLogger(__name__)

    manual_seed(args.seed)

    env = vectorize_env(env_id=args.env_id,
                        num_envs=args.num_envs,
                        seed=args.seed)
    dim_state = env.observation_space.shape[-1]
    dim_action = env.action_space.shape[-1]

    logger.info(f"env = {env}")
    logger.info(f"dim_state = {dim_state}")
    logger.info(f"dim_action = {dim_action}")
    logger.info(f"action_space = {env.action_space}")
    logger.info(f"max_episode_steps = {env.spec.max_episode_steps}")

    agent = TRPO(
        dim_state,
        dim_action,
        gamma=args.gamma,
        lambd=args.lambd,
        entropy_coef=0.0,
        vf_epoch=args.vf_epoch,
        vf_batch_size=args.vf_batch_size,
        vf_optimizer_kwargs={"lr": args.lr},
        state_normalizer=(ZScoreFilter(dim_state)
                          if args.use_state_normalizer else None),
        conjugate_gradient_damping=args.conjugate_gradient_damping,
        update_interval=(args.num_envs * 1000
                         if args.update_interval is None
                         else args.update_interval),
    )

    evaluator = Evaluator(
        env=make_env(args.env_id, args.seed),
        eval_interval=args.eval_interval,
        num_evaluate=args.num_evaluate,
    )

    recorder = (Recoder(
        env=make_env(args.env_id, args.seed),
        record_interval=args.max_step // args.num_videos,
    ) if args.num_videos > 0 else None)

    agent = training(
        env=env,
        agent=agent,
        max_steps=args.max_step,
        recorder=recorder,
        evaluator=evaluator,
    )