Example #1
# episode 1: run the zero-planning agent, then replay its recorded
# (s, a, r, s') sequence through the 5- and 50-planning agents so that
# all three agents share a common first episode.
agent_0.run_episode((2, 0), Nplanning_loops=0)

sarsnL = learn_tracker_0.get_episode_sarsn_list(0)
agent_5.run_episode((2, 0), Nplanning_loops=5, iter_sarsn=iter(sarsnL))
agent_50.run_episode((2, 0), Nplanning_loops=50, iter_sarsn=iter(sarsnL))

# episodes 2 to 50
for i in range(49):
    print(i, end=' ')
    agent_0.run_episode((2, 0), Nplanning_loops=0)
    agent_5.run_episode((2, 0), Nplanning_loops=5)
    agent_50.run_episode((2, 0), Nplanning_loops=50)
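
# The Nplanning_loops argument above sets how many Dyna-Q planning updates follow
# each real environment step.  A minimal sketch of such a planning loop, assuming a
# tabular Q table and a deterministic model stored as plain dicts keyed by
# (state, action); the names here are illustrative, not the introrl internals.
import random

def dyna_planning_updates(Q, model, n_loops, actions, alpha=0.1, gamma=0.95):
    """Replay n_loops previously seen (state, action) pairs from the learned model."""
    if not model:
        return
    for _ in range(n_loops):
        s, a = random.choice(list(model.keys()))    # a previously visited (state, action)
        r, s_next = model[(s, a)]                   # reward / next state remembered by the model
        best_next = max(Q.get((s_next, a2), 0.0) for a2 in actions)
        q_sa = Q.get((s, a), 0.0)
        Q[(s, a)] = q_sa + alpha * (r + gamma * best_next - q_sa)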

fig, ax = plt.subplots()

step_0L = learn_tracker_0.steps_per_episode()[1:]
ax.plot(step_0L, 'c', label='0 planning steps')

step_5L = learn_tracker_5.steps_per_episode()[1:]
ax.plot(step_5L, 'g', label='5 planning steps')

step_50L = learn_tracker_50.steps_per_episode()[1:]
ax.plot(step_50L, 'r', label='50 planning steps')

ax.legend()
ax.set(title='Figure 8.2 Dyna Maze\n(common 1st episode)')
#ax.axhline(y=0, color='k')
#ax.axvline(x=0, color='k')
plt.ylabel('Steps per Episode')
plt.xlabel('Episodes')
plt.ylim(0, 800)
# call opening assumed here (sarsa_epsilon_greedy on gridworld with learn_tracker),
# following the later sarsa_epsilon_greedy example on this page; the keyword
# arguments below are as given.
policy, action_value = \
    sarsa_epsilon_greedy( gridworld, learn_tracker=learn_tracker,
                          save_pickle_file='',
                          use_list_of_start_states=False, # use list OR single start state of environment.
                          do_summ_print=True, show_last_change=True, fmt_Q='%g', fmt_R='%g',
                          max_num_episodes=170, min_num_episodes=10, max_abserr=0.001, gamma=1.0,
                          iteration_prints=0,
                          max_episode_steps=10000,
                          epsilon=0.1, const_epsilon=True,
                          alpha=0.5, const_alpha=True)
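
# sarsa_epsilon_greedy presumably iterates two pieces: epsilon-greedy action
# selection and the on-policy SARSA update.  A minimal sketch of both, using a
# plain dict Q keyed by (state, action); these helper names are illustrative,
# not the introrl implementation.
import random

def epsilon_greedy_action(Q, s, actions, epsilon=0.1):
    """With probability epsilon pick a random action, else the greedy action for s."""
    if random.random() < epsilon:
        return random.choice(actions)
    return max(actions, key=lambda a: Q.get((s, a), 0.0))

def sarsa_update(Q, s, a, r, s_next, a_next, alpha=0.5, gamma=1.0):
    """Q(s,a) <- Q(s,a) + alpha * [r + gamma * Q(s',a') - Q(s,a)]."""
    q_sa = Q.get((s, a), 0.0)
    Q[(s, a)] = q_sa + alpha * (r + gamma * Q.get((s_next, a_next), 0.0) - q_sa)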

print('_' * 55)
score = gridworld.get_policy_score(policy,
                                   start_state_hash=None,
                                   step_limit=1000)
print('Policy Score =', score, ' = (r_sum, n_steps, msg)')

steps_per_episodeL = learn_tracker.steps_per_episode()

print(gridworld.get_info())

episode = make_episode(gridworld.start_state_hash,
                       policy,
                       gridworld,
                       gridworld.terminal_set,
                       max_steps=20)

epi_summ_print(episode,
               policy,
               gridworld,
               show_rewards=False,
               show_env_states=True,
               none_str='*')
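
# make_episode above follows the learned policy from the start state until a
# terminal state is reached or max_steps runs out; get_policy_score does the same
# kind of rollout and reports (r_sum, n_steps, msg).  A minimal sketch, assuming a
# policy dict mapping state -> action and a step(state, action) function that
# returns (reward, next_state); both are illustrative stand-ins.
def rollout(start_state, policy, step, terminal_set, max_steps=20):
    """Follow policy from start_state; return (reward_sum, n_steps, msg)."""
    s, r_sum, n_steps = start_state, 0.0, 0
    while s not in terminal_set and n_steps < max_steps:
        a = policy[s]
        r, s = step(s, a)
        r_sum += r
        n_steps += 1
    msg = 'terminal state reached' if s in terminal_set else 'step limit reached'
    return r_sum, n_steps, msg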
        policy, action_value = \
            sarsa_epsilon_greedy( maze_q, learn_tracker=learn_tracker_q,
                                  initial_Qsa=0.0, # initial Q(s,a) for non-terminal states (terminal states stay 0.0)
                                  read_pickle_file=read_pickle_file,
                                  save_pickle_file='blocking_sarsa',
                                  use_list_of_start_states=False, # use list OR single start state of environment.
                                  do_summ_print=False, show_last_change=False, fmt_Q='%g', fmt_R='%g',
                                  show_banner=False,
                                  max_num_episodes=1, min_num_episodes=11, max_abserr=0.001, gamma=GAMMA,
                                  iteration_prints=0,
                                  max_episode_steps=3000,
                                  epsilon=EPSILON,
                                  alpha=ALPHA)

        time_stamp = sum(learn_tracker_q.steps_per_episode())
        read_pickle_file = 'blocking_sarsa'

    cum_rew_qL = learn_tracker_q.cum_reward_per_step()

    # grow the list of running averages to match the curve length, then fold this
    # run's cumulative-reward-per-step curve into the averages
    while len(q_raveL) < len(cum_rew_qL):
        q_raveL.append(RunningAve())
    for i, r in enumerate(cum_rew_qL):
        q_raveL[i].add_val(r)
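
# q_raveL only relies on RunningAve exposing add_val() and get_ave().  A minimal
# stand-in with that interface (illustrative; the code above uses the library's
# own running-average class):
class RunningAve:
    """Incremental mean of all values passed to add_val()."""
    def __init__(self):
        self.total = 0.0
        self.count = 0

    def add_val(self, v):
        self.total += v
        self.count += 1

    def get_ave(self):
        return self.total / self.count if self.count else 0.0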

fig, ax = plt.subplots()

cum_rew_qL = [R.get_ave() for R in q_raveL]
ax.plot(cum_rew_qL, 'c', label='SARSA', linewidth=3)

# Digitized Sutton & Barto values