# Episode 1: run one real episode with agent_0, then replay that exact
# (s, a, r, s', a') experience through agent_5 and agent_50 so all three
# agents share a common first episode.
agent_0.run_episode((2, 0), Nplanning_loops=0)
sarsnL = learn_tracker_0.get_episode_sarsn_list(0)
agent_5.run_episode((2, 0), Nplanning_loops=5, iter_sarsn=iter(sarsnL))
agent_50.run_episode((2, 0), Nplanning_loops=50, iter_sarsn=iter(sarsnL))

# episodes 2 to 50
for i in range(49):
    print(i, end=' ')
    agent_0.run_episode((2, 0), Nplanning_loops=0)
    agent_5.run_episode((2, 0), Nplanning_loops=5)
    agent_50.run_episode((2, 0), Nplanning_loops=50)

# Plot steps per episode, skipping the shared first episode.
fig, ax = plt.subplots()

step_0L = learn_tracker_0.steps_per_episode()[1:]
ax.plot(step_0L, 'c', label='0 planning steps')

step_5L = learn_tracker_5.steps_per_episode()[1:]
ax.plot(step_5L, 'g', label='5 planning steps')

step_50L = learn_tracker_50.steps_per_episode()[1:]
ax.plot(step_50L, 'r', label='50 planning steps')

ax.legend()
ax.set(title='Figure 8.2 Dyna Maze\n(common 1st episode)')
#ax.axhline(y=0, color='k')
#ax.axvline(x=0, color='k')
plt.ylabel('Steps per Episode')
plt.xlabel('Episodes')
plt.ylim(0, 800)
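# The experiment above depends on the agent's internal planning loop that
# Nplanning_loops controls. The actual implementation is not shown here, so
# the following is a minimal, hypothetical sketch of the tabular Dyna-Q
# planning step (Sutton & Barto, Ch. 8): after each real step, replay
# simulated transitions from a learned model. The names Q, model, actions,
# alpha, and gamma are placeholders, not objects from the code above.
import random

def planning_updates(Q, model, n_planning_loops, actions, alpha=0.1, gamma=0.95):
    """Q: dict[(state, action)] -> value; model: dict[(state, action)] -> (reward, next_state)."""
    for _ in range(n_planning_loops):
        s, a = random.choice(list(model.keys()))   # previously observed (state, action)
        r, s_next = model[(s, a)]                  # model's remembered outcome
        q_sa = Q.get((s, a), 0.0)
        best_next = max(Q.get((s_next, a2), 0.0) for a2 in actions)
        # one-step Q-learning backup on the simulated transition
        Q[(s, a)] = q_sa + alpha * (r + gamma * best_next - q_sa)

# With n_planning_loops=0 this reduces to plain one-step Q-learning; larger
# values reuse the model more aggressively, which is what separates the
# 0 / 5 / 50 planning-step curves in the plot above.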
                          # (continuation of a solver call whose opening lines are not shown;
                          #  it returns `policy` and fills `learn_tracker` used below)
                          save_pickle_file='',
                          use_list_of_start_states=False,  # use list OR single start state of environment.
                          do_summ_print=True, show_last_change=True,
                          fmt_Q='%g', fmt_R='%g',
                          max_num_episodes=170, min_num_episodes=10,
                          max_abserr=0.001, gamma=1.0,
                          iteration_prints=0,
                          max_episode_steps=10000,
                          epsilon=0.1, const_epsilon=True,
                          alpha=0.5, const_alpha=True)

print('_' * 55)
score = gridworld.get_policy_score(policy, start_state_hash=None, step_limit=1000)
print('Policy Score =', score, ' = (r_sum, n_steps, msg)')

steps_per_episodeL = learn_tracker.steps_per_episode()

print(gridworld.get_info())

episode = make_episode(gridworld.start_state_hash, policy, gridworld,
                       gridworld.terminal_set, max_steps=20)
epi_summ_print(episode, policy, gridworld, show_rewards=False,
               show_env_states=True, none_str='*')
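# For reference, a hypothetical sketch (not gridworld.get_policy_score itself)
# of how a score tuple of the form (r_sum, n_steps, msg) can be computed:
# follow the policy from the start state, accumulate reward, and stop at a
# terminal state or the step limit. policy_func and env_step are placeholders
# for whatever lookup and step methods the real policy/environment objects expose.
def score_policy(policy_func, env_step, start_state, terminal_set, step_limit=1000):
    """policy_func(s) -> action; env_step(s, a) -> (next_state, reward)."""
    s, r_sum, n_steps = start_state, 0.0, 0
    while s not in terminal_set:
        if n_steps >= step_limit:
            return r_sum, n_steps, 'step limit reached'
        a = policy_func(s)
        s, r = env_step(s, a)
        r_sum += r
        n_steps += 1
    return r_sum, n_steps, 'terminal state reached'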
policy, action_value = \
    sarsa_epsilon_greedy(maze_q, learn_tracker=learn_tracker_q,
                         initial_Qsa=0.0,  # init non-terminal_set of V(s) (terminal_set=0.0)
                         read_pickle_file=read_pickle_file,
                         save_pickle_file='blocking_sarsa',
                         use_list_of_start_states=False,  # use list OR single start state of environment.
                         do_summ_print=False, show_last_change=False,
                         fmt_Q='%g', fmt_R='%g',
                         show_banner=False,
                         max_num_episodes=1, min_num_episodes=11,
                         max_abserr=0.001, gamma=GAMMA,
                         iteration_prints=0,
                         max_episode_steps=3000,
                         epsilon=EPSILON, alpha=ALPHA)

time_stamp = sum(learn_tracker_q.steps_per_episode())
read_pickle_file = 'blocking_sarsa'

# Accumulate the cumulative-reward-per-step curve into running averages
# so repeated runs can be averaged point by point.
cum_rew_qL = learn_tracker_q.cum_reward_per_step()
while len(q_raveL) < len(cum_rew_qL):
    q_raveL.append(RunningAve())
for i, r in enumerate(cum_rew_qL):
    q_raveL[i].add_val(r)

fig, ax = plt.subplots()

cum_rew_qL = [R.get_ave() for R in q_raveL]
ax.plot(cum_rew_qL, 'c', label='SARSA', linewidth=3)

# Digitized Sutton & Barto values
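# The averaging above relies on a RunningAve helper with add_val()/get_ave()
# methods. The actual class in the source package may differ in detail;
# this is a minimal sketch of that interface.
class RunningAve:
    """Incrementally maintain the mean of the values passed to add_val()."""
    def __init__(self):
        self.n = 0
        self.total = 0.0

    def add_val(self, v):
        self.n += 1
        self.total += v

    def get_ave(self):
        return self.total / self.n if self.n else 0.0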