def update_epsilon(self, epsilon):
    """Change the exploration rate and redraw the learner.

    Stores ``epsilon`` on the instance and rebuilds ``self.policy`` from the
    current Q-table with ``epsilon_greedy_pi_from_q_table``. It then renders a
    fresh three-panel figure (kept in ``self.fig`` / ``self.ax``): environment
    with the agent, Q-table, and the rebuilt policy.

    Args:
        epsilon: New exploration rate.
    """
    self.epsilon = epsilon
    self.policy = epsilon_greedy_pi_from_q_table(self.env, self.q_table, self.epsilon)

    # Panels: environment + agent | Q-table | policy.
    self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
    env_ax, q_ax, policy_ax = self.ax
    plot_all_states(self.env, env_ax)
    plot_agent(self.s, env_ax)
    plot_q_table(self.env, self.q_table, q_ax, vmin=self.vmin, vmax=self.vmax)
    plot_policy(self.env, self.policy, policy_ax)
    plt.show()
def finish_episode(self):
    """Run the rest of the current episode without intermediate plots.

    Calls ``self.next_step(fast_execution=True)`` repeatedly until the episode
    counter ``self.i`` changes, gathering the text each step returns. Afterwards
    it draws the environment, agent, Q-table and policy once and prints the
    gathered log.
    """
    start_episode = self.i
    pieces = []
    while start_episode == self.i:
        pieces.append(str(self.next_step(fast_execution=True)))

    self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
    env_ax, q_ax, policy_ax = self.ax
    plot_all_states(self.env, env_ax)
    plot_agent(self.s, env_ax)
    plot_q_table(self.env, self.q_table, q_ax, vmin=self.vmin, vmax=self.vmax)
    plot_policy(self.env, self.policy, policy_ax)
    plt.show()
    print(''.join(pieces))
def finish_episode(self):
    """Run the rest of the current episode without intermediate plots.

    Steps with ``self.next_step(fast_execution=True)`` until the phase becomes
    ``'showing_new_q_value_terminal'``, then executes that phase too. Text
    returned by every step is gathered. Finally it draws the environment, agent,
    Q-table and greedy policy once and prints the gathered log.
    """
    pieces = []
    while self.phase != 'showing_new_q_value_terminal':
        pieces.append(str(self.next_step(fast_execution=True)))
    # Execute the terminal-display phase as well so the episode is fully closed.
    pieces.append(str(self.next_step(fast_execution=True)))

    self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
    env_ax, q_ax, policy_ax = self.ax
    plot_all_states(self.env, env_ax)
    plot_agent(self.s, env_ax)
    plot_q_table(self.env, self.q_table, q_ax, vmin=self.vmin, vmax=self.vmax)
    plot_policy(self.env, self.greedy_policy, policy_ax)
    plt.show()
    print(''.join(pieces))
def next_step(self, fast_execution=False):
    """Advance the step-by-step Q-learning walkthrough by one phase.

    ``self.phase`` drives a small state machine:

    ``choosing_actions`` -> ``showing_sampled_action`` -> ``carrying_out_action``
    then, for a non-terminal s': ``computing_td_target`` ->
    ``computing_new_q_value`` -> ``showing_new_q_value`` -> ``choosing_actions``;
    or, for a terminal s': ``update_q_value_terminal`` ->
    ``showing_new_q_value_terminal`` -> ``choosing_actions`` (after
    ``self.env.reset()``).

    Actions are sampled from ``self.exploratory_policy``. The TD target is
    ``R + 0.9 * Q(s', a*)``, where ``a*`` is sampled from ``self.greedy_policy``
    at s'. That policy is built with epsilon 0, so presumably it only puts mass
    on maximising actions, giving the Q-learning max target (TODO confirm in
    ``epsilon_greedy_pi_from_q_table``). For a terminal s' the target is just R.

    Args:
        fast_execution: If True, draw no figures and print nothing. Each Q-value
            update is instead described in the returned text.

    Returns:
        In fast mode, the description of the Q-value update performed in this
        phase (if any). Otherwise an empty string.

    Raises:
        ValueError: If ``self.phase`` is not a recognised phase.
    """
    print_string = ''

    def show_panels(draw_env_panel):
        # Fresh 3-panel figure: environment | Q-table | greedy policy.
        self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
        draw_env_panel(self.ax[0])
        plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
        plot_policy(self.env, self.greedy_policy, self.ax[2])
        plt.show()

    def draw_agent_after_move(ax, previous_s):
        # Current position solid, the state the transition started from faded.
        plot_all_states(self.env, ax)
        plot_agent(self.s, ax)
        plot_agent(previous_s, ax, alpha=0.3)

    def td_target_for(r, s_prime):
        # Bootstrap from a greedy action sampled at s' itself (indexed explicitly
        # rather than relying on self.s_idx happening to point at s').
        s_prime_idx = self.env.states.index(s_prime)
        greedy_a = np.random.choice(4, p=self.greedy_policy[s_prime_idx])
        q_next = self.q_table[s_prime_idx, greedy_a]
        return greedy_a, q_next, r + 0.9 * q_next

    def td_target_formula(s_prime, greedy_a, r, q_next, td_target):
        return (f'TD target = R + 𝛾 * Q({s_prime}, {a_index_to_symbol[greedy_a]}) = '
                f'{r:.2f} + 0.9 * {q_next:.2f} = '
                f'{td_target:.2f}')

    def update_formula(s, a, old_value, target, new_value):
        return (f'Q({s}, {a_index_to_symbol[a]}) = {old_value:.2f} + {self.alpha:.2f} * '
                f'({target:.2f} - {old_value:.2f}) = {new_value:.2f}')

    def refresh_policies():
        # Both policies are derived from the (just updated) Q-table.
        self.exploratory_policy = epsilon_greedy_pi_from_q_table(
            self.env, self.q_table, self.epsilon)
        self.greedy_policy = epsilon_greedy_pi_from_q_table(
            self.env, self.q_table, 0)

    if self.phase == 'choosing_actions':
        if not fast_execution:
            show_panels(lambda ax: plot_env_agent_and_policy_at_state(
                self.env, self.s, self.exploratory_policy, ax))
            print('Sampling action...\n')
        self.phase = 'showing_sampled_action'

    elif self.phase == 'showing_sampled_action':
        self.sampled_a = np.random.choice(4, p=self.exploratory_policy[self.s_idx])
        if not fast_execution:
            show_panels(lambda ax: plot_env_agent_and_chosen_action(
                self.env, self.s, self.sampled_a, ax))
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
        self.phase = 'carrying_out_action'

    elif self.phase == 'carrying_out_action':
        old_s = self.s
        s_prime, r, t = self.env.step(self.sampled_a)
        self.current_transition = [old_s, self.sampled_a, r, s_prime, t]
        if not fast_execution:
            def draw_move(ax):
                plot_all_states(self.env, ax)
                plot_agent(old_s, ax, alpha=0.3)
                plot_agent(self.s, ax)
            show_panels(draw_move)
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
        if t:
            if not fast_execution:
                print('A terminal state has been reached: end of episode.')
            self.phase = 'update_q_value_terminal'
        else:
            self.phase = 'computing_td_target'

    elif self.phase == 'computing_td_target':
        s, a, r, s_prime, _ = self.current_transition
        greedy_a, q_next, td_target = td_target_for(r, s_prime)
        if not fast_execution:
            show_panels(lambda ax: draw_agent_after_move(ax, s))
            print(td_target_formula(s_prime, greedy_a, r, q_next, td_target))
        self.phase = 'computing_new_q_value'

    elif self.phase == 'computing_new_q_value':
        # Show the update that will be applied in the next phase.
        s, a, r, s_prime, _ = self.current_transition
        s_idx = self.env.states.index(s)
        _, _, td_target = td_target_for(r, s_prime)
        old_value = self.q_table[s_idx, a]
        new_value = self.alpha * td_target + (1 - self.alpha) * old_value
        if not fast_execution:
            show_panels(lambda ax: draw_agent_after_move(ax, s))
            print(f'TD target = {td_target:.2f}')
            print(update_formula(s, a, old_value, td_target, new_value))
        self.phase = 'showing_new_q_value'

    elif self.phase == 'showing_new_q_value':
        # Apply the Q-learning update and refresh both policies.
        s, a, r, s_prime, _ = self.current_transition
        s_idx = self.env.states.index(s)
        greedy_a, q_next, td_target = td_target_for(r, s_prime)
        old_value = self.q_table[s_idx, a]
        new_value = self.alpha * td_target + (1 - self.alpha) * old_value
        if fast_execution:
            # Keep a textual record so a fast-forwarding caller has a log to show.
            print_string += (
                f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}>\n'
                + td_target_formula(s_prime, greedy_a, r, q_next, td_target) + '\n'
                + update_formula(s, a, old_value, td_target, new_value) + '\n\n')
        self.q_table[s_idx, a] = new_value
        refresh_policies()
        if not fast_execution:
            show_panels(lambda ax: draw_agent_after_move(ax, s))
            print(f'TD target = {td_target:.2f}')
            print(f'Q({s}, {a_index_to_symbol[a]}) = {new_value:.2f}')
        self.phase = 'choosing_actions'

    elif self.phase == 'update_q_value_terminal':
        s, a, r, s_prime, _ = self.current_transition
        s_idx = self.env.states.index(s)
        old_value = self.q_table[s_idx, a]
        # A terminal s' has no successor value: the TD target is just R.
        new_value = self.alpha * r + (1 - self.alpha) * old_value
        if fast_execution:
            print_string += (
                f'Transition to terminal state: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}>\n'
                f'TD target = R + 𝛾 * 0 = {r:.2f}\n'
                + update_formula(s, a, old_value, r, new_value) + '\n')
        else:
            show_panels(lambda ax: draw_agent_after_move(ax, s))
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
            print('A terminal state has been reached: end of episode.\n')
            print(f'TD target = R + 𝛾 * 0 = {r:.2f}')
            print(update_formula(s, a, old_value, r, new_value))
        self.q_table[s_idx, a] = new_value
        refresh_policies()
        self.phase = 'showing_new_q_value_terminal'

    elif self.phase == 'showing_new_q_value_terminal':
        if not fast_execution:
            s, a, r, s_prime, _ = self.current_transition
            s_idx = self.env.states.index(s)
            show_panels(lambda ax: draw_agent_after_move(ax, s))
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
            print('A terminal state has been reached: end of episode.\n')
            print(f'TD target = R + 𝛾 * 0 = {r:.2f}')
            print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f}')
        # Start a new episode.
        self.current_transition = []
        self.env.reset()
        self.phase = 'choosing_actions'

    else:
        raise ValueError(f'Phase {self.phase} not recognized.')
    return print_string
def next_step(self, fast_execution=False):
    """Advance the step-by-step SARSA walkthrough by one phase.

    ``self.phase`` drives the state machine:

    ``initial_sampling_action`` -> ``initial_showing_sampled_action`` ->
    ``carrying_out_action``. Then, for a non-terminal s': ``sampling_action`` ->
    ``showing_sampled_action`` -> ``showing_transition_done`` ->
    ``showing_td_target`` -> ``show_q_value_computation`` ->
    ``showing_updated_q_value`` -> ``carrying_out_action`` (executing the
    already-sampled a'). For a terminal s': ``compute_q_for_terminal_s_prime`` ->
    ``showing_updated_q_value_terminal_s_prime`` -> ``initial_sampling_action``
    (after ``self.env.reset()`` and ``self.i += 1``).

    The TD target is ``R + 0.9 * Q(s', a')``, where a' is the next action sampled
    from ``self.policy``. For a terminal s' it is just R.

    Args:
        fast_execution: If True, draw no figures and print nothing. Q-value
            updates are instead described in the returned text.

    Returns:
        In fast mode, the description of the Q-value update done in this phase
        (if any). Otherwise an empty string.

    Raises:
        ValueError: If ``self.phase`` is not a recognised phase.
    """
    print_string = ''

    def show_panels(draw_env_panel):
        # Fresh 3-panel figure: environment | Q-table | current policy.
        self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
        draw_env_panel(self.ax[0])
        plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
        plot_policy(self.env, self.policy, self.ax[2])
        plt.show()

    def draw_chosen_action(ax, faded_s):
        # Agent and its sampled action at the current state, plus the faded
        # state the pending transition started from.
        plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, ax)
        plot_agent(faded_s, ax, alpha=0.3)

    def draw_terminal_move(ax, s, s_prime):
        plot_all_states(self.env, ax)
        plot_agent(s, ax, alpha=0.3)
        plot_agent(s_prime, ax)

    def transition_text(s, a, r, s_prime, a_prime):
        return (f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, '
                f'{a_index_to_symbol[a_prime]}>\n')

    def td_formula(s_prime, a_prime, r, q_next, td_target):
        return (f'TD target: R + 𝛾 * Q({s_prime}, {a_index_to_symbol[a_prime]}) = '
                f'{r:.2f} + 0.9 * {q_next:.2f} = {td_target:.2f}')

    def update_formula(s, a, old_value, target, new_value):
        return (f'Q({s}, {a_index_to_symbol[a]}) = {old_value:.2f} + {self.alpha:.2f} * '
                f'({target:.2f} - {old_value:.2f}) = {new_value:.2f}')

    if self.phase == 'initial_sampling_action':
        if not fast_execution:
            show_panels(lambda ax: plot_env_agent_and_policy_at_state(
                self.env, self.s, self.policy, ax))
            print('Sampling action...\n')
        self.phase = 'initial_showing_sampled_action'

    elif self.phase == 'initial_showing_sampled_action':
        self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])
        if not fast_execution:
            show_panels(lambda ax: plot_env_agent_and_chosen_action(
                self.env, self.s, self.sampled_a, ax))
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
        self.phase = 'carrying_out_action'

    elif self.phase == 'carrying_out_action':
        old_s = self.s
        s_prime, r, t = self.env.step(self.sampled_a)
        self.current_transition = [old_s, self.sampled_a, r, s_prime, t]
        if not fast_execution:
            def draw_move(ax):
                plot_all_states(self.env, ax)
                plot_agent(old_s, ax, alpha=0.3)
                plot_agent(self.s, ax)
            show_panels(draw_move)
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
        self.phase = 'compute_q_for_terminal_s_prime' if t else 'sampling_action'

    elif self.phase == 'sampling_action':
        if not fast_execution:
            old_s = self.current_transition[0]

            def draw_policy(ax):
                plot_env_agent_and_policy_at_state(self.env, self.s, self.policy, ax)
                plot_agent(old_s, ax, alpha=0.3)
            show_panels(draw_policy)
            print('Sampling action...\n')
        self.phase = 'showing_sampled_action'

    elif self.phase == 'showing_sampled_action':
        self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])
        # Complete the SARSA tuple <s, a, r, s', a'>.
        self.current_transition.append(self.sampled_a)
        if not fast_execution:
            show_panels(lambda ax: draw_chosen_action(ax, self.current_transition[0]))
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
        self.phase = 'showing_transition_done'

    elif self.phase == 'showing_transition_done':
        if not fast_execution:
            s, a, r, s_prime, _, a_prime = self.current_transition
            show_panels(lambda ax: draw_chosen_action(ax, s))
            print(transition_text(s, a, r, s_prime, a_prime))
        self.phase = 'showing_td_target'

    elif self.phase == 'showing_td_target':
        if not fast_execution:
            s, a, r, s_prime, _, a_prime = self.current_transition
            show_panels(lambda ax: draw_chosen_action(ax, s))
            print(transition_text(s, a, r, s_prime, a_prime))
            q_next = self.q_table[self.env.states.index(s_prime), a_prime]
            td_target = r + 0.9 * q_next
            print(td_formula(s_prime, a_prime, r, q_next, td_target))
        self.phase = 'show_q_value_computation'

    elif self.phase == 'show_q_value_computation':
        s, a, r, s_prime, _, a_prime = self.current_transition
        if not fast_execution:
            # Draw only on the new figure's axes, never on the stale ones.
            show_panels(lambda ax: draw_chosen_action(ax, s))
        s_idx = self.env.states.index(s)
        q_next = self.q_table[self.env.states.index(s_prime), a_prime]
        td_target = r + 0.9 * q_next
        old_value = self.q_table[s_idx, a]
        new_value = self.alpha * td_target + (1 - self.alpha) * old_value
        if fast_execution:
            print_string += (transition_text(s, a, r, s_prime, a_prime)
                             + td_formula(s_prime, a_prime, r, q_next, td_target) + '\n'
                             + update_formula(s, a, old_value, td_target, new_value) + '\n\n')
        else:
            print(transition_text(s, a, r, s_prime, a_prime))
            print(f'TD target: {td_target:.2f}')
            print(update_formula(s, a, old_value, td_target, new_value))
        self.q_table[s_idx, a] = new_value
        self.policy = epsilon_greedy_pi_from_q_table(self.env, self.q_table, self.epsilon)
        self.phase = 'showing_updated_q_value'

    elif self.phase == 'showing_updated_q_value':
        if not fast_execution:
            s, a, r, s_prime, _, a_prime = self.current_transition
            show_panels(lambda ax: draw_chosen_action(ax, s))
            print(transition_text(s, a, r, s_prime, a_prime))
            # NOTE(review): recomputed from the already-updated Q-table, so it can
            # differ from the target actually used when (s', a') == (s, a).
            td_target = r + 0.9 * self.q_table[self.env.states.index(s_prime), a_prime]
            print(f'TD target: {td_target:.2f}')
            s_idx = self.env.states.index(s)
            print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f}')
        # a' is already sampled (self.sampled_a), so go straight to executing it.
        self.phase = 'carrying_out_action'

    elif self.phase == 'compute_q_for_terminal_s_prime':
        s, a, r, s_prime, _ = self.current_transition
        s_idx = self.env.states.index(s)
        old_value = self.q_table[s_idx, a]
        # A terminal s' has no successor action: the TD target is just R.
        new_value = self.alpha * r + (1 - self.alpha) * old_value
        if fast_execution:
            print_string += (
                f'Transition to terminal state: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, ∅>\n'
                f'TD target = R + 𝛾 * 0 = {r:.2f}\n'
                + update_formula(s, a, old_value, r, new_value))
        else:
            show_panels(lambda ax: draw_terminal_move(ax, s, s_prime))
            print('Terminal state reached.\n')
            print(f'TD target = R + 𝛾 * 0 = {r:.2f}')
            print(update_formula(s, a, old_value, r, new_value))
        self.q_table[s_idx, a] = new_value
        self.policy = epsilon_greedy_pi_from_q_table(self.env, self.q_table, self.epsilon)
        self.phase = 'showing_updated_q_value_terminal_s_prime'

    elif self.phase == 'showing_updated_q_value_terminal_s_prime':
        self.env.reset()
        self.i += 1  # episode counter, advanced once per completed episode
        if not fast_execution:
            s, a, _, s_prime, _ = self.current_transition
            s_idx = self.env.states.index(s)
            show_panels(lambda ax: draw_terminal_move(ax, s, s_prime))
            print('Terminal state reached.\n')
            print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f}\n')
            print('Resetting environment')
        self.phase = 'initial_sampling_action'

    else:
        raise ValueError(f'Phase {self.phase} not recognized.')
    return print_string
def next_step(self, fast_execution=False):
    """Advance the step-by-step constant-alpha Monte Carlo walkthrough.

    ``self.phase`` drives the state machine:

    ``choosing_actions`` -> ``showing_sampled_action`` -> ``carrying_out_action``
    (looping back to ``choosing_actions`` until a terminal state) ->
    ``computing_returns`` -> ``presenting_computed_returns`` ->
    ``updating_values`` -> ``show_updated_value_function`` ->
    ``choosing_actions`` (after clearing the episode and ``self.env.reset()``).

    Returns are computed backwards as ``G_k = R_k + 0.9 * G_{k+1}``. Each visited
    (s, a) is then moved towards its return:
    ``Q <- alpha * G + (1 - alpha) * Q``.

    Args:
        fast_execution: If True, draw no figures and print nothing. The
            ``updating_values`` phase returns its log text instead.

    Returns:
        The accumulated log text in fast mode, otherwise an empty string.

    Raises:
        ValueError: If ``self.phase`` is not a recognised phase.
    """
    log = []  # fast-mode text fragments, joined on return

    def show_panels(draw_env_panel):
        # Fresh 3-panel figure: environment | Q-table | current policy.
        self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
        draw_env_panel(self.ax[0])
        plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
        plot_policy(self.env, self.policy, self.ax[2])
        plt.show()

    def draw_agent(ax):
        plot_all_states(self.env, ax)
        plot_agent(self.s, ax)

    def returns_lines():
        # One "G_k = ..." line per step, latest step first.
        numbered = list(enumerate(zip(self.current_episode, self.current_episode_returns)))
        return [f'G_{k+1} = {g:.2f},\tat {step[0]}' for k, (step, g) in reversed(numbered)]

    phase = self.phase

    if phase == 'choosing_actions':
        if not fast_execution:
            show_panels(lambda ax: plot_env_agent_and_policy_at_state(
                self.env, self.s, self.policy, ax))
            print('Sampling action...\n')
            print(self.get_current_episode_as_string())
        self.phase = 'showing_sampled_action'

    elif phase == 'showing_sampled_action':
        self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])
        if not fast_execution:
            show_panels(lambda ax: plot_env_agent_and_chosen_action(
                self.env, self.s, self.sampled_a, ax))
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
            print(self.get_current_episode_as_string())
        self.phase = 'carrying_out_action'

    elif phase == 'carrying_out_action':
        previous_s = self.s
        s_prime, r, t = self.env.step(self.sampled_a)
        self.current_episode.append([previous_s, self.sampled_a, r, s_prime])
        if not fast_execution:
            def draw_move(ax):
                plot_all_states(self.env, ax)
                plot_agent(previous_s, ax, alpha=0.3)
                plot_agent(self.s, ax)
            show_panels(draw_move)
            print(self.get_current_episode_as_string())
        if not t:
            self.phase = 'choosing_actions'
        else:
            if not fast_execution:
                print('A terminal state has been reached: end of episode.')
            self.phase = 'computing_returns'

    elif phase == 'computing_returns':
        if not fast_execution:
            show_panels(draw_agent)
            print(self.get_current_episode_as_string())
            print('Computing returns:')
        episode = self.current_episode
        self.current_episode_returns = [0] * len(episode)
        returns = self.current_episode_returns
        last = len(episode) - 1
        # Walk the episode backwards so each return can reuse its successor's.
        for k in range(last, -1, -1):
            state = episode[k][0]
            reward = episode[k][2]
            if k == last:
                if not fast_execution:
                    print(f'G_{k + 1} = {reward:.1f},\t\t\t\tat {state}')
                returns[k] = reward
            else:
                returns[k] = reward + 0.9 * returns[k + 1]
                if not fast_execution:
                    print(f'G_{k+1} = {reward:.1f} + 0.9 * {returns[k+1]:.2f}', end='')
                    print(f' = {returns[k]:.2f},\tat {state}')
        self.phase = 'presenting_computed_returns'

    elif phase == 'presenting_computed_returns':
        if not fast_execution:
            show_panels(draw_agent)
            print(self.get_current_episode_as_string())
            print('Computing returns:')
            for line in returns_lines():
                print(line)
        self.phase = 'updating_values'

    elif phase == 'updating_values':
        def emit(text):
            # Fast mode accumulates a newline-terminated copy; otherwise print now.
            if fast_execution:
                log.append(f'{text}\n')
            else:
                print(text)

        if not fast_execution:
            show_panels(draw_agent)
        emit(self.get_current_episode_as_string())
        emit('Computing returns:')
        for line in returns_lines():
            emit(line)
        emit('\nUpdating values:')
        # Updates are applied one at a time, so a repeated (s, a) pair sees the
        # value written by its later occurrence.
        for step, g in zip(reversed(self.current_episode),
                           reversed(self.current_episode_returns)):
            s_idx = self.env.states.index(step[0])
            action = step[1]
            current = self.q_table[s_idx, action]
            new_value = self.alpha * g + (1 - self.alpha) * current
            emit(f'Q({step[0]}, {a_index_to_symbol[action]}) = '
                 f'{current:.2f} + {self.alpha:.2f} * ({g:.2f} - {current:.2f}) = '
                 f'{new_value:.2f}')
            self.q_table[s_idx, action] = new_value
        self.policy = epsilon_greedy_pi_from_q_table(self.env, self.q_table, self.epsilon)
        self.phase = 'show_updated_value_function'

    elif phase == 'show_updated_value_function':
        if not fast_execution:
            show_panels(draw_agent)
            print(self.get_current_episode_as_string())
            print('Computing returns:')
            for line in returns_lines():
                print(line)
            print('\nUpdating values:')
            for step, _ in zip(reversed(self.current_episode),
                               reversed(self.current_episode_returns)):
                s_idx = self.env.states.index(step[0])
                print(f'Q({step[0]}, {a_index_to_symbol[step[1]]}) = '
                      f'{self.q_table[s_idx, step[1]]:.2f}')
        # Start a new episode.
        self.current_episode = []
        self.env.reset()
        self.phase = 'choosing_actions'

    else:
        raise ValueError(f'Phase {self.phase} not recognized.')
    return ''.join(log)