def test_record_episode_statistics_reset_info():
    """RecordEpisodeStatistics must pass reset() results through unchanged.

    Checks both the plain ``reset()`` form (a bare observation) and the
    ``reset(return_info=True)`` form (an ``(observation, info)`` pair).
    """
    wrapped = RecordEpisodeStatistics(gym.make("CartPole-v1"))
    space = wrapped.observation_space

    # Plain reset: only an observation is returned.
    assert space.contains(wrapped.reset())

    # reset(return_info=True): an observation plus an info dict.
    first_obs, reset_info = wrapped.reset(return_info=True)
    assert space.contains(first_obs)
    assert isinstance(reset_info, dict)
def test_record_episode_statistics_with_vectorenv(num_envs):
    """Finished sub-environments of a sync vector env report episode statistics.

    Steps one step past the episode time limit. On each step, the first
    sub-environment flagged as done must expose an ``"episode"`` entry
    containing ``"r"``, ``"l"`` and ``"t"``.
    """
    vec_env = RecordEpisodeStatistics(
        gym.vector.make("CartPole-v0", num_envs=num_envs, asynchronous=False)
    )
    vec_env.reset()
    step_limit = vec_env.env.envs[0].spec.max_episode_steps
    for _ in range(step_limit + 1):
        _, _, dones, infos = vec_env.step(vec_env.action_space.sample())
        for idx, info in enumerate(infos):
            if not dones[idx]:
                continue
            assert "episode" in info
            assert all(key in info["episode"] for key in ("r", "l", "t"))
            # Only the first finished sub-environment of this step is checked.
            break
# Example #3
# 0
def main():
    """Build the Taxi-v3 agent, restore saved weights if present, and validate it."""
    env_name = "Taxi-v3"
    # Network layer sizes.
    state_units, hid_units = 16, 8
    # Search hyper-parameters, passed positionally to MCTS below.
    dirichlet_alpha = 0.25
    exploration_fraction = 0.25
    pb_c_base, pb_c_init = 19652, 1.25
    discount = 0.99
    num_simulations = 100
    checkpoint = "model_last.pth"

    device = get_device(True)

    # Episode statistics are recorded on the raw env, beneath the
    # observation transform applied by TaxiObservationWrapper.
    env = TaxiObservationWrapper(RecordEpisodeStatistics(gym.make(env_name)))

    obs_size = env.observation_space.nvec.sum()
    network = Network(obs_size, env.action_space.n, state_units, hid_units)
    search = MCTS(
        dirichlet_alpha,
        exploration_fraction,
        pb_c_base,
        pb_c_init,
        discount,
        num_simulations,
    )
    agent = Agent(network, search)
    trainer = Trainer()

    if os.path.exists(checkpoint):
        agent.load_model(checkpoint, device)

    trainer.validate(env, agent, network)
# Example #4
# 0
def test_record_episode_statistics(env_id, deque_size):
    """Episode counters reset per episode and the stat queues reach ``deque_size``.

    Runs five episodes with random actions. Each finished episode must report
    ``"r"``, ``"l"`` and ``"t"`` under ``info["episode"]``. Afterwards both
    queues must hold exactly ``deque_size`` entries (assumes ``deque_size <= 5``).
    """
    env = RecordEpisodeStatistics(gym.make(env_id), deque_size)
    num_episodes = 5

    for _episode in range(num_episodes):
        env.reset()
        # The running counters start from zero after every reset.
        assert env.episode_return == 0.0
        assert env.episode_length == 0
        for _step in range(env.spec.max_episode_steps):
            _, _, done, info = env.step(env.action_space.sample())
            if not done:
                continue
            assert "episode" in info
            assert all(key in info["episode"] for key in ("r", "l", "t"))
            break

    assert len(env.return_queue) == deque_size
    assert len(env.length_queue) == deque_size
# Example #5
# 0
def test_record_episode_statistics_with_vectorenv(num_envs, asynchronous):
    """Vectorized info carries episode statistics exactly when a sub-env finishes.

    Steps one step past the episode time limit. On the first step where any
    sub-environment is done, ``infos`` must contain ``"episode"`` (with
    ``"r"``, ``"l"``, ``"t"``) and the ``"_episode"`` mask, and that mask must
    equal ``dones``. On every earlier step neither key may be present.

    Args:
        num_envs: number of CartPole-v1 sub-environments.
        asynchronous: whether ``gym.vector.make`` builds an async vector env.
    """
    envs = gym.vector.make("CartPole-v1",
                           render_mode=None,
                           num_envs=num_envs,
                           asynchronous=asynchronous)
    envs = RecordEpisodeStatistics(envs)
    try:
        if asynchronous:
            # The async vector env does not expose its sub-envs as
            # `envs.env.envs`, so build a throwaway copy from the env factory
            # just to read its spec, and close it immediately.
            probe_env = envs.env_fns[0]()
            try:
                max_episode_step = probe_env.spec.max_episode_steps
            finally:
                probe_env.close()
        else:
            max_episode_step = envs.env.envs[0].spec.max_episode_steps

        envs.reset()
        for _ in range(max_episode_step + 1):
            _, _, dones, infos = envs.step(envs.action_space.sample())
            if any(dones):
                assert "episode" in infos
                assert "_episode" in infos
                assert all(infos["_episode"] == dones)
                assert all(item in infos["episode"] for item in ("r", "l", "t"))
                break
            assert "episode" not in infos
            assert "_episode" not in infos
    finally:
        # Always release the vector env (and any worker processes it owns),
        # even when an assertion above fails.
        envs.close()
# Example #6
# 0
def test_wrong_wrapping_order():
    """Stepping RecordEpisodeStatistics wrapped around VectorListInfo must fail.

    Applying VectorListInfo *before* RecordEpisodeStatistics is the wrong
    order; the first ``step()`` is expected to raise ``AssertionError``.
    """
    misordered = RecordEpisodeStatistics(
        VectorListInfo(gym.vector.make("CartPole-v1", num_envs=3))
    )
    misordered.reset()

    with pytest.raises(AssertionError):
        misordered.step(misordered.action_space.sample())
# Example #7
# 0
def test_info_to_list_statistics():
    """VectorListInfo over RecordEpisodeStatistics yields per-env episode stats.

    The reset info must be a list with one entry per sub-environment. On every
    step, a done sub-environment's entry must contain float ``"r"``, ``"l"``
    and ``"t"`` under ``"episode"``, and a non-done entry must not contain
    ``"episode"`` at all.
    """
    base_env = gym.vector.make(ENV_ID, num_envs=NUM_ENVS)
    wrapped = VectorListInfo(RecordEpisodeStatistics(base_env))
    _, reset_info = wrapped.reset(seed=SEED, return_info=True)
    wrapped.action_space.seed(SEED)
    assert isinstance(reset_info, list)
    assert len(reset_info) == NUM_ENVS

    for _ in range(ENV_STEPS):
        _, _, dones, per_env_info = wrapped.step(wrapped.action_space.sample())
        for idx, done in enumerate(dones):
            env_info = per_env_info[idx]
            if not done:
                assert "episode" not in env_info
                continue
            assert "episode" in env_info
            for key in ("r", "l", "t"):
                assert key in env_info["episode"]
                assert isinstance(env_info["episode"][key], float)
# Example #8
# 0
def main():
    """Train the Taxi-v3 agent and save its weights to ``model_last.pth``.

    Resumes from ``model_last.pth`` if it exists. Training can be stopped
    early with Ctrl-C (KeyboardInterrupt); the model is saved either way.
    """
    seed = 1
    env_name = "Taxi-v3"
    # Network layer sizes.
    state_units, hid_units = 16, 8
    # Search hyper-parameters.
    dirichlet_alpha = 0.25
    exploration_fraction = 0.25
    pb_c_base, pb_c_init = 19652, 1.25
    discount = 0.99
    num_simulations = 100
    # Training schedule.
    window_size = 100
    nb_self_play = 5
    num_unroll_steps = 5
    td_steps = 200
    batch_size = 64
    lr = 1e-4
    nb_train_update = 20
    nb_train_epochs = 10000
    max_grad_norm = 0.5
    checkpoint = "model_last.pth"
    ent_c = 0.2

    device = get_device(True)

    env = TaxiObservationWrapper(RecordEpisodeStatistics(gym.make(env_name)))

    # Seed every RNG source before the model is built.
    np.random.seed(seed)
    torch.manual_seed(seed)
    env.seed(seed)
    np.set_printoptions(formatter={"float": "{: 0.3f}".format})

    obs_size = env.observation_space.nvec.sum()
    network = Network(obs_size, env.action_space.n, state_units, hid_units)
    search = MCTS(
        dirichlet_alpha,
        exploration_fraction,
        pb_c_base,
        pb_c_init,
        discount,
        num_simulations,
    )
    agent = Agent(network, search)
    trainer = Trainer()
    optimizer = Ralamb(network.parameters(), lr=lr)

    if os.path.exists(checkpoint):
        agent.load_model(checkpoint, device)

    # Positional order must match Trainer.train's signature.
    train_args = (
        env,
        agent,
        network,
        optimizer,
        window_size,
        nb_self_play,
        num_unroll_steps,
        td_steps,
        discount,
        batch_size,
        nb_train_update,
        nb_train_epochs,
        max_grad_norm,
        checkpoint,
        ent_c,
    )

    print("Train start")
    try:
        trainer.train(*train_args)
    except KeyboardInterrupt:
        print("Keyboard interrupt")
    print("Train complete")

    agent.save_model(checkpoint)
# Example #9
# 0
def main(cfg):
    """Train a PPO actor/critic pair on CartPole-v1 and optionally save the weights.

    Collects one transition per step into ``memory``. Every
    ``cfg.params.batch_size`` steps it runs ``cfg.params.epochs`` update
    passes over mini-batches, then decays the learning rates linearly.
    Metrics are logged to wandb.

    Args:
        cfg: Hydra config. Reads ``cfg.exp.{seed, torch_deterministic,
            save_weights, model_dir}`` and ``cfg.params.{actor_lr, critic_lr,
            mini_batch_size, batch_size, total_timesteps, epochs, gae_gamma,
            gae_lambda, actor_loss_clip, critic_loss_clip, max_grad_norm}``.
    """
    random.seed(cfg.exp.seed)
    np.random.seed(cfg.exp.seed)
    torch.manual_seed(cfg.exp.seed)
    torch.backends.cudnn.deterministic = cfg.exp.torch_deterministic

    # so that the environment automatically resets
    env = SyncVectorEnv([
        lambda: RecordEpisodeStatistics(gym.make('CartPole-v1'))
    ])

    actor, critic = Actor(), Critic()
    actor_optim = Adam(actor.parameters(), eps=1e-5, lr=cfg.params.actor_lr)
    critic_optim = Adam(critic.parameters(), eps=1e-5, lr=cfg.params.critic_lr)
    memory = Memory(mini_batch_size=cfg.params.mini_batch_size, batch_size=cfg.params.batch_size)
    obs = env.reset()
    global_rewards = []

    NUM_UPDATES = (cfg.params.total_timesteps // cfg.params.batch_size) * cfg.params.epochs
    cur_timestep = 0

    def calc_factor(cur_timestep: int) -> float:
        """Return the linear learning-rate decay factor for ``cur_timestep``.

        1.0 before the first update, falling towards 0 as the number of
        completed batch updates approaches the total.
        """
        update_number = cur_timestep // cfg.params.batch_size
        total_updates = cfg.params.total_timesteps // cfg.params.batch_size
        fraction = 1.0 - update_number / total_updates
        return fraction

    actor_scheduler = LambdaLR(actor_optim, lr_lambda=calc_factor, verbose=True)
    critic_scheduler = LambdaLR(critic_optim, lr_lambda=calc_factor, verbose=True)

    while cur_timestep < cfg.params.total_timesteps:
        # keep playing the game
        obs = torch.as_tensor(obs, dtype=torch.float32)
        with torch.no_grad():
            dist = actor(obs)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            value = critic(obs)
        action = action.cpu().numpy()
        value = value.cpu().numpy()
        log_prob = log_prob.cpu().numpy()
        obs_, reward, done, info = env.step(action)

        if done[0]:
            episode_reward = info[0]['episode']['r']
            # Record the finished episode *before* averaging. Otherwise the
            # first episode averages an empty list (NaN + RuntimeWarning), and
            # the printed average disagrees with the one sent to wandb.
            global_rewards.append(episode_reward)
            avg_reward = np.mean(global_rewards[-10:])
            tqdm.write(f'Reward: {episode_reward}, Avg Reward: {avg_reward:.3f}')
            wandb.log({'Avg_Reward': avg_reward, 'Reward': episode_reward})

        memory.remember(obs.squeeze(0).cpu().numpy(), action.item(), log_prob.item(), reward.item(), done.item(), value.item())
        obs = obs_
        cur_timestep += 1

        # if the current timestep is a multiple of the batch size, then we need to update the model
        if cur_timestep % cfg.params.batch_size == 0:
            for _epoch in tqdm(range(cfg.params.epochs), desc=f'Num updates: {cfg.params.epochs * (cur_timestep // cfg.params.batch_size)} / {NUM_UPDATES}'):
                # sample a batch from memory of experiences
                old_states, old_actions, old_log_probs, old_rewards, old_dones, old_values, batch_indices = memory.sample()
                old_log_probs = torch.tensor(old_log_probs, dtype=torch.float32)
                old_actions = torch.tensor(old_actions, dtype=torch.float32)
                advantage = calculate_advantage(old_rewards, old_values, old_dones, gae_gamma=cfg.params.gae_gamma, gae_lambda=cfg.params.gae_lambda)

                advantage = torch.tensor(advantage, dtype=torch.float32)
                old_values = torch.tensor(old_values, dtype=torch.float32)
                # Critic targets must come from the *raw* GAE advantages, so
                # compute them once here. The per-mini-batch normalisation
                # below affects only the policy loss and never modifies
                # `advantage` in place.
                returns_all = old_values + advantage

                for mini_batch_index in batch_indices:
                    # remember: Normalization of advantage is done on mini batch, not the entire batch
                    mb_advantage = advantage[mini_batch_index]
                    mb_advantage = (mb_advantage - mb_advantage.mean()) / (mb_advantage.std() + 1e-8)
                    mb_states = torch.tensor(old_states[mini_batch_index], dtype=torch.float32).unsqueeze(0)
                    mb_old_values = old_values[mini_batch_index]
                    mb_returns = returns_all[mini_batch_index]

                    dist = actor(mb_states)
                    log_probs = dist.log_prob(old_actions[mini_batch_index]).squeeze(0)
                    entropy = dist.entropy().squeeze(0)

                    log_ratio = log_probs - old_log_probs[mini_batch_index]
                    ratio = torch.exp(log_ratio)

                    with torch.no_grad():
                        approx_kl = ((old_log_probs[mini_batch_index] - log_probs)**2).mean()
                        wandb.log({'Approx_KL': approx_kl})

                    # Clipped surrogate objective (negated for minimisation).
                    actor_loss = -torch.min(
                        ratio * mb_advantage,
                        torch.clamp(ratio, 1 - cfg.params.actor_loss_clip, 1 + cfg.params.actor_loss_clip) * mb_advantage
                    ).mean()

                    values = critic(mb_states).squeeze(-1)

                    # Clipped value loss: the pessimistic max of the unclipped
                    # and clipped squared errors.
                    critic_loss = torch.max(
                        (values - mb_returns)**2,
                        (mb_old_values + torch.clamp(
                            values - mb_old_values, -cfg.params.critic_loss_clip, cfg.params.critic_loss_clip
                            ) - mb_returns
                        )**2
                    ).mean()

                    wandb.log({'Actor_Loss': actor_loss.item(), 'Critic_Loss': critic_loss.item(), 'Entropy': entropy.mean().item()})
                    loss = actor_loss + 0.25 * critic_loss - 0.01 * entropy.mean()
                    actor_optim.zero_grad()
                    critic_optim.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(actor.parameters(), cfg.params.max_grad_norm)
                    nn.utils.clip_grad_norm_(critic.parameters(), cfg.params.max_grad_norm)

                    actor_optim.step()
                    critic_optim.step()

            memory.reset()
            actor_scheduler.step(cur_timestep)
            critic_scheduler.step(cur_timestep)

            # Explained variance of the critic's old predictions against the
            # (raw-advantage) returns of the last update epoch.
            y_pred, y_true = old_values.cpu().numpy(), returns_all.cpu().numpy()
            var_y = np.var(y_true)
            explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y
            wandb.log({'Explained_Var': explained_var})

    env.close()

    if cfg.exp.save_weights:
        torch.save(actor.state_dict(), Path(f'{hydra.utils.get_original_cwd()}/{cfg.exp.model_dir}/actor.pth'))
        torch.save(critic.state_dict(), Path(f'{hydra.utils.get_original_cwd()}/{cfg.exp.model_dir}/critic.pth'))