def draw_env_and_agent(self, old_s=None):
    self.fig, self.ax = plt.subplots(figsize=(5, 5))
    plot_all_states(self.env, self.ax)
    plot_agent(self.s, self.ax)
    if old_s is not None:
        plot_agent(old_s, self.ax, alpha=0.3)
    plt.show()
def update_epsilon(self, epsilon):
    self.epsilon = epsilon
    self.policy = epsilon_greedy_pi_from_q_table(self.env, self.q_table, self.epsilon)

    # Display plots
    self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
    plot_all_states(self.env, self.ax[0])
    plot_agent(self.s, self.ax[0])
    plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
    plot_policy(self.env, self.policy, self.ax[2])
    plt.show()
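
# Illustrative sketch, not the actual helper used above: one plausible way a
# function like epsilon_greedy_pi_from_q_table could derive a policy array from
# a Q-table. The real helper takes (env, q_table, epsilon) and may differ in
# details such as tie-breaking; this standalone version assumes q_table is a
# NumPy array of shape (n_states, n_actions).
import numpy as np

def epsilon_greedy_from_q(q_table, epsilon):
    n_states, n_actions = q_table.shape
    # Spread the exploration probability epsilon uniformly over all actions.
    policy = np.full((n_states, n_actions), epsilon / n_actions)
    # Give the remaining (1 - epsilon) probability mass to the greedy action.
    greedy_actions = q_table.argmax(axis=1)
    policy[np.arange(n_states), greedy_actions] += 1 - epsilon
    return policy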
def finish_episode(self):
    print_string = ''
    current_i = self.i
    while self.i == current_i:
        print_string += f'{self.next_step(fast_execution=True)}'

    self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
    plot_all_states(self.env, self.ax[0])
    plot_agent(self.s, self.ax[0])
    plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
    plot_policy(self.env, self.policy, self.ax[2])
    plt.show()
    print(print_string)
def finish_episode(self):
    print_string = ''
    while self.phase != 'showing_new_q_value_terminal':
        print_string += f'{self.next_step(fast_execution=True)}'
    print_string += f'{self.next_step(fast_execution=True)}'

    self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
    plot_all_states(self.env, self.ax[0])
    plot_agent(self.s, self.ax[0])
    plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
    plot_policy(self.env, self.greedy_policy, self.ax[2])
    plt.show()
    print(print_string)
def next_step(self, fast_execution=False):
    print_string = ''
    if self.phase == 'choosing_actions':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_policy_at_state(self.env, self.s, self.policy, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print('Sampling action...\n')
        self.phase = 'showing_sampled_action'
    elif self.phase == 'showing_sampled_action':
        self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
        self.phase = 'carrying_out_action'
    elif self.phase == 'carrying_out_action':
        old_s = self.s
        s_prime, r, t = self.env.step(self.sampled_a)
        self.current_episode.append([old_s, self.sampled_a, r, s_prime, t])
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(old_s, self.ax[0], alpha=0.3)
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(f'Transition: <{old_s}, {a_index_to_symbol[self.sampled_a]}, {r:.1f}, {s_prime}>')
        self.phase = 'showing_td_target'
    elif self.phase == 'showing_td_target':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
        old_s, a, r, s_prime, _ = self.current_episode[-1]
        s_prime_idx = self.env.states.index(s_prime)
        self.td_target = r + 0.9 * self.value_table[s_prime_idx]
        if not fast_execution:
            print(f'Transition: <{old_s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}>\n')
            print('Updating value table:')
            print(f'TD target = R + 𝛾 * V({s_prime}) = {r:.1f} + 0.9*{self.value_table[s_prime_idx]:.2f} = '
                  f'{self.td_target:.2f}')
        self.phase = 'updating_value_table'
    elif self.phase == 'updating_value_table':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
        old_s, a, r, s_prime, _ = self.current_episode[-1]
        old_s_idx = self.env.states.index(old_s)
        new_value = self.alpha * self.td_target + (1 - self.alpha) * self.value_table[old_s_idx]
        if fast_execution:
            print_string += f'Transition: <{old_s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}>\n'
            print_string += f'TD target = {self.td_target:.2f}\n'
            print_string += f'V({old_s}) ← {self.value_table[old_s_idx]:.2f} + {self.alpha:.2f} * ' \
                            f'({self.td_target:.2f} - {self.value_table[old_s_idx]:.2f}) = {new_value:.2f}\n\n'
        else:
            print(f'Transition: <{old_s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}>\n')
            print('Updating value table:')
            print(f'TD target = {self.td_target:.2f}')
            print(f'V({old_s}) ← {self.value_table[old_s_idx]:.2f} + {self.alpha:.2f} * '
                  f'({self.td_target:.2f} - {self.value_table[old_s_idx]:.2f}) = {new_value:.2f}\n')
        self.value_table[old_s_idx] = new_value
        self.phase = 'showing_new_value_table'
    elif self.phase == 'showing_new_value_table':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
        old_s, a, r, s_prime, t = self.current_episode[-1]
        old_s_idx = self.env.states.index(old_s)
        if not fast_execution:
            print(f'Transition: <{old_s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}>\n')
            print('Updating value table:')
            print(f'TD target = {self.td_target:.2f}')
            print(f'V({old_s}) = {self.value_table[old_s_idx]:.2f}')
        if t:
            if not fast_execution:
                print('\nTerminal state, resetting environment.')
            self.env.reset()
            self.current_episode = []
            self.i += 1
        self.phase = 'choosing_actions'
    else:
        raise ValueError(f'Phase {self.phase} not recognized.')
    return print_string
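
# Minimal standalone sketch of the TD(0) update that the phases above walk
# through, pulled out of the step-by-step visualization machinery. Names
# (value_table, alpha, gamma) mirror the attributes used above; the class
# hard-codes gamma = 0.9.
def td0_update(value_table, s_idx, r, s_prime_idx, alpha, gamma=0.9):
    # The TD target bootstraps from the current estimate of the next state's value.
    td_target = r + gamma * value_table[s_prime_idx]
    # Move V(s) a fraction alpha of the way towards the TD target.
    value_table[s_idx] = alpha * td_target + (1 - alpha) * value_table[s_idx]
    return td_target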
def next_step(self, fast_execution=False):
    print_string = ''
    if self.phase == 'choosing_actions':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_policy_at_state(self.env, self.s, self.exploratory_policy, self.ax[0])
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.greedy_policy, self.ax[2])
            plt.show()
            print('Sampling action...\n')
        self.phase = 'showing_sampled_action'
    elif self.phase == 'showing_sampled_action':
        self.sampled_a = np.random.choice(4, p=self.exploratory_policy[self.s_idx])
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, self.ax[0])
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.greedy_policy, self.ax[2])
            plt.show()
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
        self.phase = 'carrying_out_action'
    elif self.phase == 'carrying_out_action':
        old_s = self.s
        s_prime, r, t = self.env.step(self.sampled_a)
        self.current_transition = [old_s, self.sampled_a, r, s_prime, t]
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(old_s, self.ax[0], alpha=0.3)
            plot_agent(self.s, self.ax[0])
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.greedy_policy, self.ax[2])
            plt.show()
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
        if t:
            if not fast_execution:
                print('A terminal state has been reached: end of episode.')
            self.phase = 'update_q_value_terminal'
        else:
            self.phase = 'computing_td_target'
    elif self.phase == 'computing_td_target':
        # Compute TD target
        s, a, r, s_prime, _ = self.current_transition
        s_prime_idx = self.env.states.index(s_prime)
        greedy_a = np.random.choice(4, p=self.greedy_policy[self.s_idx])
        td_target = r + 0.9 * self.q_table[s_prime_idx, greedy_a]
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_agent(s, self.ax[0], alpha=0.3)
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.greedy_policy, self.ax[2])
            plt.show()
        if not fast_execution:
            print(f'TD target = R + 𝛾 * Q({s_prime}, {a_index_to_symbol[greedy_a]}) = '
                  f'{r:.2f} + 0.9 * {self.q_table[s_prime_idx, greedy_a]:.2f} = '
                  f'{td_target:.2f}')
        self.phase = 'computing_new_q_value'
    elif self.phase == 'computing_new_q_value':
        # Compute TD target
        s, a, r, s_prime, _ = self.current_transition
        s_idx = self.env.states.index(s)
        s_prime_idx = self.env.states.index(s_prime)
        greedy_a = np.random.choice(4, p=self.greedy_policy[self.s_idx])
        td_target = r + 0.9 * self.q_table[s_prime_idx, greedy_a]
        # Compute new value
        new_value = self.alpha * td_target + (1 - self.alpha) * self.q_table[s_idx, a]
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_agent(s, self.ax[0], alpha=0.3)
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.greedy_policy, self.ax[2])
            plt.show()
        if not fast_execution:
            print(f'TD target = {td_target:.2f}')
            print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * '
                  f'({td_target:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}')
        self.phase = 'showing_new_q_value'
    elif self.phase == 'showing_new_q_value':
        # Compute TD target
        s, a, r, s_prime, _ = self.current_transition
        s_idx = self.env.states.index(s)
        s_prime_idx = self.env.states.index(s_prime)
        greedy_a = np.random.choice(4, p=self.greedy_policy[self.s_idx])
        td_target = r + 0.9 * self.q_table[s_prime_idx, greedy_a]
        # Compute new value and update q-table
        new_value = self.alpha * td_target + (1 - self.alpha) * self.q_table[s_idx, a]
        self.q_table[s_idx, a] = new_value
        self.exploratory_policy = epsilon_greedy_pi_from_q_table(self.env, self.q_table, self.epsilon)
        self.greedy_policy = epsilon_greedy_pi_from_q_table(self.env, self.q_table, 0)
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_agent(s, self.ax[0], alpha=0.3)
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.greedy_policy, self.ax[2])
            plt.show()
            print(f'TD target = {td_target:.2f}')
            print(f'Q({s}, {a_index_to_symbol[a]}) = {new_value:.2f}')
        self.phase = 'choosing_actions'
    elif self.phase == 'update_q_value_terminal':
        s, a, r, s_prime, t = self.current_transition
        s_idx = self.env.states.index(s)
        new_value = self.alpha * r + (1 - self.alpha) * self.q_table[s_idx, a]
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_agent(s, self.ax[0], alpha=0.3)
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.greedy_policy, self.ax[2])
            plt.show()
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
            print('A terminal state has been reached: end of episode.\n')
            print(f'TD target = R + 𝛾 * 0 = {r:.2f}')
            print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * '
                  f'({r:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}')
        self.q_table[s_idx, a] = new_value
        self.exploratory_policy = epsilon_greedy_pi_from_q_table(self.env, self.q_table, self.epsilon)
        self.greedy_policy = epsilon_greedy_pi_from_q_table(self.env, self.q_table, 0)
        self.phase = 'showing_new_q_value_terminal'
    elif self.phase == 'showing_new_q_value_terminal':
        if not fast_execution:
            s, a, r, s_prime, t = self.current_transition
            s_idx = self.env.states.index(s)
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_agent(s, self.ax[0], alpha=0.3)
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.greedy_policy, self.ax[2])
            plt.show()
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
            print('A terminal state has been reached: end of episode.\n')
            print(f'TD target = R + 𝛾 * 0 = {r:.2f}')
            print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f}')
        self.current_transition = []
        self.env.reset()
        self.phase = 'choosing_actions'
    else:
        raise ValueError(f'Phase {self.phase} not recognized.')
    return print_string
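
# Minimal standalone sketch of the Q-learning update performed across the
# 'computing_td_target', 'computing_new_q_value' and 'showing_new_q_value'
# phases above, assuming a NumPy q_table of shape (n_states, n_actions).
# The class samples the greedy action from a greedy policy; taking the max
# over actions is equivalent when that policy puts all mass on the argmax.
def q_learning_update(q_table, s_idx, a, r, s_prime_idx, terminal, alpha, gamma=0.9):
    # At a terminal s', the bootstrap term is 0, matching 'update_q_value_terminal'.
    td_target = r if terminal else r + gamma * q_table[s_prime_idx].max()
    q_table[s_idx, a] = alpha * td_target + (1 - alpha) * q_table[s_idx, a]
    return td_target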
def next_step(self, fast_execution=False):
    print_string = ''
    if self.phase == 'initial_sampling_action':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_policy_at_state(self.env, self.s, self.policy, self.ax[0])
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print('Sampling action...\n')
        self.phase = 'initial_showing_sampled_action'
    elif self.phase == 'initial_showing_sampled_action':
        self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, self.ax[0])
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
        self.phase = 'carrying_out_action'
    elif self.phase == 'carrying_out_action':
        old_s = self.s
        s_prime, r, t = self.env.step(self.sampled_a)
        self.current_transition = [old_s, self.sampled_a, r, s_prime, t]
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(old_s, self.ax[0], alpha=0.3)
            plot_agent(self.s, self.ax[0])
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
        if not t:
            self.phase = 'sampling_action'
        else:
            self.phase = 'compute_q_for_terminal_s_prime'
    elif self.phase == 'sampling_action':
        if not fast_execution:
            old_s = self.current_transition[0]
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_policy_at_state(self.env, self.s, self.policy, self.ax[0])
            plot_agent(old_s, self.ax[0], alpha=0.3)
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print('Sampling action...\n')
        self.phase = 'showing_sampled_action'
    elif self.phase == 'showing_sampled_action':
        self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])
        self.current_transition.append(self.sampled_a)
        if not fast_execution:
            old_s = self.current_transition[0]
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, self.ax[0])
            plot_agent(old_s, self.ax[0], alpha=0.3)
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
        self.phase = 'showing_transition_done'
    elif self.phase == 'showing_transition_done':
        if not fast_execution:
            old_s = self.current_transition[0]
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, self.ax[0])
            plot_agent(old_s, self.ax[0], alpha=0.3)
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            s, a, r, s_prime, t, a_prime = self.current_transition
            print(f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, {a_index_to_symbol[a_prime]}>\n')
        self.phase = 'showing_td_target'
    elif self.phase == 'showing_td_target':
        if not fast_execution:
            s, a, r, s_prime, t, a_prime = self.current_transition
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, self.ax[0])
            plot_agent(s, self.ax[0], alpha=0.3)
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, {a_index_to_symbol[a_prime]}>\n')
            s_prime_idx = self.env.states.index(s_prime)
            td_target = r + 0.9 * self.q_table[s_prime_idx, a_prime]
            print(f'TD target: R + 𝛾 * Q({s_prime}, {a_index_to_symbol[a_prime]}) = '
                  f'{r:.2f} + 0.9 * {self.q_table[s_prime_idx, a_prime]:.2f} = '
                  f'{td_target:.2f}')
        self.phase = 'show_q_value_computation'
    elif self.phase == 'show_q_value_computation':
        s, a, r, s_prime, t, a_prime = self.current_transition
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, self.ax[0])
            plot_agent(s, self.ax[0], alpha=0.3)
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
        s_idx = self.env.states.index(s)
        s_prime_idx = self.env.states.index(s_prime)
        td_target = r + 0.9 * self.q_table[s_prime_idx, a_prime]
        new_value = self.alpha * td_target + (1 - self.alpha) * self.q_table[s_idx, a]
        if fast_execution:
            msg = f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, {a_index_to_symbol[a_prime]}>\n' \
                  f'TD target: R + 𝛾 * Q({s_prime}, {a_index_to_symbol[a_prime]}) = ' \
                  f'{r:.2f} + 0.9 * {self.q_table[s_prime_idx, a_prime]:.2f} = {td_target:.2f}\n' \
                  f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * ' \
                  f'({td_target:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}\n\n'
            print_string += msg
        else:
            print(f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, {a_index_to_symbol[a_prime]}>\n')
            print(f'TD target: {td_target:.2f}')
            print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * '
                  f'({td_target:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}')
        self.q_table[s_idx, a] = new_value
        self.policy = epsilon_greedy_pi_from_q_table(self.env, self.q_table, self.epsilon)
        self.phase = 'showing_updated_q_value'
    elif self.phase == 'showing_updated_q_value':
        if not fast_execution:
            s, a, r, s_prime, t, a_prime = self.current_transition
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, self.ax[0])
            plot_agent(s, self.ax[0], alpha=0.3)
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, {a_index_to_symbol[a_prime]}>\n')
            s_prime_idx = self.env.states.index(s_prime)
            td_target = r + 0.9 * self.q_table[s_prime_idx, a_prime]
            print(f'TD target: {td_target:.2f}')
            s_idx = self.env.states.index(s)
            print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f}')
        self.phase = 'carrying_out_action'
    elif self.phase == 'compute_q_for_terminal_s_prime':
        s, a, r, s_prime, t = self.current_transition
        s_idx = self.env.states.index(s)
        new_value = self.alpha * r + (1 - self.alpha) * self.q_table[s_idx, a]
        if fast_execution:
            print_string += f'Transition to terminal state: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, ∅>\n'
            print_string += f'TD target = R + 𝛾 * 0 = {r:.2f}\n'
            print_string += f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * ' \
                            f'({r:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}'
        else:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(s, self.ax[0], alpha=0.3)
            plot_agent(s_prime, self.ax[0])
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print('Terminal state reached.\n')
            print(f'TD target = R + 𝛾 * 0 = {r:.2f}')
            print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * '
                  f'({r:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}')
        self.q_table[s_idx, a] = new_value
        self.policy = epsilon_greedy_pi_from_q_table(self.env, self.q_table, self.epsilon)
        self.phase = 'showing_updated_q_value_terminal_s_prime'
    elif self.phase == 'showing_updated_q_value_terminal_s_prime':
        self.env.reset()
        self.i += 1
        if not fast_execution:
            s, a, r, s_prime, t = self.current_transition
            s_idx = self.env.states.index(s)
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(s, self.ax[0], alpha=0.3)
            plot_agent(s_prime, self.ax[0])
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print('Terminal state reached.\n')
            print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f}\n')
            print('Resetting environment.')
        self.phase = 'initial_sampling_action'
    else:
        raise ValueError(f'Phase {self.phase} not recognized.')
    return print_string
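
# Minimal standalone sketch of the SARSA update spread across the phases above.
# Unlike Q-learning, the TD target uses the action a_prime actually sampled in
# s_prime by the (on-policy) epsilon-greedy policy; at a terminal s' the
# bootstrap term is 0, matching 'compute_q_for_terminal_s_prime'.
def sarsa_update(q_table, s_idx, a, r, s_prime_idx, a_prime, terminal, alpha, gamma=0.9):
    td_target = r if terminal else r + gamma * q_table[s_prime_idx, a_prime]
    q_table[s_idx, a] = alpha * td_target + (1 - alpha) * q_table[s_idx, a]
    return td_target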
def next_step(self, fast_execution=False):
    print_string = ''
    if self.phase == 'choosing_actions':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_policy_at_state(self.env, self.s, self.policy, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print('Sampling action...\n')
            print(self.get_current_episode_as_string())
        self.phase = 'showing_sampled_action'
    elif self.phase == 'showing_sampled_action':
        self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
            print(self.get_current_episode_as_string())
        self.phase = 'carrying_out_action'
    elif self.phase == 'carrying_out_action':
        old_s = self.s
        s_prime, r, t = self.env.step(self.sampled_a)
        self.current_episode.append([old_s, self.sampled_a, r, s_prime])
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(old_s, self.ax[0], alpha=0.3)
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(self.get_current_episode_as_string())
        if t:
            if not fast_execution:
                print('A terminal state has been reached: end of episode.')
            self.phase = 'computing_returns'
        else:
            self.phase = 'choosing_actions'
    elif self.phase == 'computing_returns':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(self.get_current_episode_as_string())
            print('\nComputing returns:')
        self.current_episode_returns = [0] * len(self.current_episode)
        for i in reversed(range(len(self.current_episode))):
            transition = self.current_episode[i]
            if i == len(self.current_episode) - 1:
                if not fast_execution:
                    print(f'G_{i+1} = {transition[2]:.1f},\t\t\t\tat {transition[0]}')
                self.current_episode_returns[i] = transition[2]
            else:
                self.current_episode_returns[i] = transition[2] + 0.9 * self.current_episode_returns[i + 1]
                if not fast_execution:
                    print(f'G_{i+1} = {transition[2]:.1f} + 0.9 * {self.current_episode_returns[i+1]:.2f}', end='')
                    print(f' = {self.current_episode_returns[i]:.2f},\tat {transition[0]}')
        self.phase = 'presenting_computed_returns'
    elif self.phase == 'presenting_computed_returns':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(self.get_current_episode_as_string())
            print('\nComputing returns:')
            for i, (transition, g) in reversed(
                    list(enumerate(zip(self.current_episode, self.current_episode_returns)))):
                print(f'G_{i+1} = {g:.2f},\tat {transition[0]}')
        self.phase = 'updating_values'
    elif self.phase == 'updating_values':
        if fast_execution:
            print_string += f'{self.get_current_episode_as_string()}\n'
        else:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(self.get_current_episode_as_string())
            print('\nComputing returns:')

        # Print computed returns for all visited states
        for i, (transition, g) in reversed(
                list(enumerate(zip(self.current_episode, self.current_episode_returns)))):
            msg = f'G_{i+1} = {g:.2f},\tat {transition[0]}'
            if fast_execution:
                print_string += f'{msg}\n'
            else:
                print(msg)

        if fast_execution:
            print_string += '\n'
        else:
            print('\nUpdating values:')

        for transition, g in zip(reversed(self.current_episode), reversed(self.current_episode_returns)):
            s_idx = self.env.states.index(transition[0])
            new_value = self.alpha * g + (1 - self.alpha) * self.value_table[s_idx]
            cur_value = self.value_table[s_idx]
            msg = f'V({transition[0]}) = {cur_value:.2f} + {self.alpha:.2f} * ({g:.2f} - {cur_value:.2f}) ' \
                  f'= {new_value:.2f}'
            self.value_table[s_idx] = new_value
            if fast_execution:
                print_string += f'{msg}\n'
            else:
                print(msg)
        self.phase = 'show_updated_value_function'
    elif self.phase == 'show_updated_value_function':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(self.get_current_episode_as_string())
            print('\nComputing returns:')
            for i, (transition, g) in reversed(
                    list(enumerate(zip(self.current_episode, self.current_episode_returns)))):
                print(f'G_{i+1} = {g:.2f},\tat {transition[0]}')
            print('\nUpdating values:')
            for transition, g in zip(reversed(self.current_episode), reversed(self.current_episode_returns)):
                s_idx = self.env.states.index(transition[0])
                print(f'V({transition[0]}) = {self.value_table[s_idx]:.2f}')
        self.current_episode = []
        self.env.reset()
        self.phase = 'choosing_actions'
    else:
        raise ValueError(f'Phase {self.phase} not recognized.')
    return print_string
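
# Minimal standalone sketch of the every-visit Monte Carlo update done in the
# 'computing_returns' and 'updating_values' phases above: returns are
# accumulated backwards through the episode, then each visited state's value is
# nudged towards its return. For simplicity this sketch assumes `episode` is a
# list of [s_idx, a, r, ...] transitions that already use state indices, whereas
# the class above stores states and looks indices up via env.states.index.
def monte_carlo_update(value_table, episode, alpha, gamma=0.9):
    g = 0.0
    for s_idx, _, r, *_ in reversed(episode):
        g = r + gamma * g  # return from this state onwards
        value_table[s_idx] = alpha * g + (1 - alpha) * value_table[s_idx]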