def example_6_6():
    """Reproduce Example 6.6 (Cliff Walking): Q-learning vs. Sarsa.

    Runs both agents on ``TheCliff`` for ``EX_6_6_N_EPS`` episodes under each
    of ``EX_6_6_N_SEEDS`` seeds. It accumulates the per-episode reward curves
    returned by the agents and divides them by the number of seeds. Each
    averaged curve goes through ``smooth_rewards(..., EX_6_6_N_AVG)``, which
    the y-label describes as an average over the last ``EX_6_6_N_AVG``
    episodes. Both curves are drawn on a single axes. The figure is saved to
    ``example6.6.png`` and then displayed.
    """
    fig, ax = plt.subplots()
    fig.suptitle(f'Example 6.6 (Averaged over {EX_6_6_N_SEEDS} seeds)')
    ax.set_xlabel('Episodes')
    ax.set_ylabel(
        f'(Average of last {EX_6_6_N_AVG}) sum of rewards during episodes')
    ax.set_yticks(EX_6_6_YTICKS)
    ax.set_ylim(bottom=min(EX_6_6_YTICKS))

    n_ep = EX_6_6_N_EPS
    env = TheCliff()
    # Both agents share the Example 6.5 step size / epsilon and are undiscounted.
    qlearning_alg = QLearning(env, step_size=EX_6_5_STEP_SIZE,
                              gamma=UNDISCOUNTED, eps=EX_6_5_EPS)
    sarsa_alg = Sarsa(env, step_size=EX_6_5_STEP_SIZE, gamma=UNDISCOUNTED,
                      eps=EX_6_5_EPS)

    # Running sums over seeds; divided by EX_6_6_N_SEEDS before plotting.
    qlearning_rew = np.zeros(n_ep)
    sarsa_rew = np.zeros(n_ep)
    for seed in range(EX_6_6_N_SEEDS):
        print(f"seed={seed}")
        # NOTE(review): the same agent objects are reused for every seed. This
        # assumes `seed()` also resets learned estimates; confirm. Otherwise
        # later seeds start from the values learned under earlier seeds.
        qlearning_alg.seed(seed)
        qlearning_rew += qlearning_alg.q_learning(n_ep)
        sarsa_alg.seed(seed)
        # `rews=True` makes Sarsa return rewards (presumably per-episode sums,
        # matching what q_learning returns; TODO confirm).
        sarsa_rew += sarsa_alg.on_policy_td_control(n_ep, rews=True)

    # Draw on the axes configured above instead of pyplot's implicit
    # "current" axes, so the curves always land on the labelled plot.
    ax.plot(smooth_rewards(qlearning_rew / EX_6_6_N_SEEDS, EX_6_6_N_AVG),
            color='r', label='Q learning')
    ax.plot(smooth_rewards(sarsa_rew / EX_6_6_N_SEEDS, EX_6_6_N_AVG),
            color='b', label='Sarsa')
    ax.legend()
    fig.savefig('example6.6.png')
    plt.show()
def plot_sarsa(ax, n_ep, label=None, diags=False, stay=False, stoch=False,
               seed=0):
    """Train a Sarsa agent on a WindyGridworld variant and plot its curve.

    Args:
        ax: Matplotlib axes to draw on. If ``None``, the current pyplot axes
            (``plt.gca()``) are used, as the previous implementation did.
        n_ep: Number of episodes passed to ``Sarsa.on_policy_td_control``.
        label: Legend label for the curve. It is not set when ``label`` is
            falsy (``None`` or ``""``).
        diags, stay, stoch: Passed positionally to ``WindyGridworld``.
            Presumably they toggle diagonal moves, a "stay" action and
            stochastic wind; TODO confirm against WindyGridworld.
        seed: Value passed to ``Sarsa.seed`` before training.

    The plotted series is whatever ``on_policy_td_control(n_ep)`` returns
    with its default options. It presumably tracks progress per episode;
    TODO confirm what it measures.
    """
    env = WindyGridworld(diags, stay, stoch)
    alg = Sarsa(env, step_size=EX_6_5_STEP_SIZE, gamma=UNDISCOUNTED,
                eps=EX_6_5_EPS)
    alg.seed(seed)
    curve = alg.on_policy_td_control(n_ep)

    # The original accepted `ax` but drew via plt.plot, which targets the
    # current axes and silently ignores the caller's choice of subplot.
    target = plt.gca() if ax is None else ax
    kwargs = {"label": label} if label else {}
    target.plot(curve, **kwargs)