Example #1
0
def example_6_6():
    """Plot averaged Q-learning vs. Sarsa reward curves for Example 6.6.

    Both agents are built on a single ``TheCliff`` environment with the same
    step size, discount (``UNDISCOUNTED``) and epsilon. For each of
    ``EX_6_6_N_SEEDS`` seeds, each agent is seeded and run for
    ``EX_6_6_N_EPS`` episodes. The per-episode reward arrays it returns are
    summed, then divided by the number of seeds. The averaged curves are
    passed through ``smooth_rewards`` (window ``EX_6_6_N_AVG``), plotted,
    saved to ``example6.6.png`` and displayed.
    """
    fig, ax = plt.subplots()
    fig.suptitle(f'Example 6.6 (Averaged over {EX_6_6_N_SEEDS} seeds)')
    ax.set_xlabel('Episodes')
    ax.set_ylabel(
        f'(Average of last {EX_6_6_N_AVG}) sum of rewards during episodes')
    ax.set_yticks(EX_6_6_YTICKS)
    ax.set_ylim(bottom=min(EX_6_6_YTICKS))

    n_ep = EX_6_6_N_EPS
    env = TheCliff()
    # Hyperparameters are shared with Example 6.5 (the EX_6_5_* constants).
    qlearning_alg = QLearning(env,
                              step_size=EX_6_5_STEP_SIZE,
                              gamma=UNDISCOUNTED,
                              eps=EX_6_5_EPS)
    sarsa_alg = Sarsa(env,
                      step_size=EX_6_5_STEP_SIZE,
                      gamma=UNDISCOUNTED,
                      eps=EX_6_5_EPS)

    # Running totals of per-episode rewards across seeds; divided below.
    qlearning_rew = np.zeros(n_ep)
    sarsa_rew = np.zeros(n_ep)
    # NOTE(review): the same agent objects are re-seeded for each run rather
    # than rebuilt. Confirm that ``seed()`` also resets learned action values;
    # otherwise later runs start from earlier runs' estimates.
    for seed in range(EX_6_6_N_SEEDS):
        print(f"seed={seed}")
        qlearning_alg.seed(seed)
        qlearning_rew += qlearning_alg.q_learning(n_ep)
        sarsa_alg.seed(seed)
        sarsa_rew += sarsa_alg.on_policy_td_control(n_ep, rews=True)

    # Draw through the Axes/Figure objects created above instead of the
    # pyplot state machine, so the output cannot land on another figure.
    ax.plot(smooth_rewards(qlearning_rew / EX_6_6_N_SEEDS, EX_6_6_N_AVG),
            color='r',
            label='Q learning')
    ax.plot(smooth_rewards(sarsa_rew / EX_6_6_N_SEEDS, EX_6_6_N_AVG),
            color='b',
            label='Sarsa')
    ax.legend()
    fig.savefig('example6.6.png')
    plt.show()
Example #2
0
def plot_sarsa(ax,
               n_ep,
               label=None,
               diags=False,
               stay=False,
               stoch=False,
               seed=0):
    """Run Sarsa on a ``WindyGridworld`` and plot the resulting curve.

    Args:
        ax: matplotlib Axes to draw on. If ``None``, pyplot's current axes
            is used.
        n_ep: number of episodes passed to ``on_policy_td_control``.
        label: legend label for the curve; no label is set when falsy.
        diags: first positional flag forwarded to ``WindyGridworld``
            (presumably enables diagonal moves; confirm against the env).
        stay: second positional flag forwarded to ``WindyGridworld``
            (presumably adds a "stay" action; confirm against the env).
        stoch: third positional flag forwarded to ``WindyGridworld``
            (presumably makes the wind stochastic; confirm against the env).
        seed: value passed to ``Sarsa.seed`` before training.

    The plotted data is whatever ``on_policy_td_control(n_ep)`` returns when
    called without ``rews=True``.
    """
    env = WindyGridworld(diags, stay, stoch)
    alg = Sarsa(env,
                step_size=EX_6_5_STEP_SIZE,
                gamma=UNDISCOUNTED,
                eps=EX_6_5_EPS)
    alg.seed(seed)
    kwargs = {"label": label} if label else {}
    # Draw on the caller's axes rather than pyplot's *current* axes, which may
    # be a different subplot; fall back to the current axes when none given.
    target = plt.gca() if ax is None else ax
    target.plot(alg.on_policy_td_control(n_ep), **kwargs)