def finish_iteration(self):
    initial_i = self.i
    # Run fast steps until the iteration counter advances
    while initial_i == self.i:
        self.next_step(fast_execution=True)
    # Plot the resulting state of the env, value table, and policy
    self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
    plot_s_pmf(self.env, self.s, self.a, ax=self.ax[0])
    plot_value_table(self.env, self.value_table, ax=self.ax[1], vmin=-8,
                     vmax=0)
    plot_policy(self.env, self.policy, ax=self.ax[2])
    plt.show()
def finish_episode(self):
    print_string = ''
    current_i = self.i
    # Run fast steps until the episode counter advances, collecting output
    while self.i == current_i:
        print_string += f'{self.next_step(fast_execution=True)}'
    self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
    plot_all_states(self.env, self.ax[0])
    plot_agent(self.s, self.ax[0])
    plot_value_table(self.env, self.value_table, self.ax[1], vmin=self.vmin,
                     vmax=self.vmax)
    plot_policy(self.env, self.policy, self.ax[2])
    plt.show()
    print(print_string)
def policy_iteration_wrapper():
    env.reset()
    value_table = np.zeros(9)
    policy = np.array([[1, 0, 0, 0],
                       [0.5, 0.5, 0, 0],
                       [0, 1, 0, 0],
                       [0, 0, 0, 1],
                       [0, 0, 1, 0],
                       [0, 1, 0, 0],
                       [0, 0, 1, 0],
                       [0, 0, 1, 0],
                       [1, 0, 0, 0]])
    btn_eval = widgets.Button(description='Policy eval iteration')
    btn_improve = widgets.Button(description='Policy improvement')
    display(btn_eval)
    display(btn_improve)
    output = widgets.Output()
    vmin = -8
    vmax = 0
    with output:
        fig, ax = plt.subplots(ncols=2, figsize=(14, 6))
        plot_value_table(env, value_table, ax=ax[0], vmin=vmin, vmax=vmax)
        plot_policy(env, policy, ax=ax[1])
        plt.show()

    def policy_eval_on_click(obj):
        plt.close()
        fig, ax = plt.subplots(ncols=2, figsize=(14, 6))
        output.clear_output(True)
        policy_evaluation_one_step(env, policy, value_table)
        with output:
            plot_value_table(env, value_table, ax=ax[0], vmin=vmin, vmax=vmax)
            plot_policy(env, policy, ax=ax[1])
            plt.show()

    def policy_improve_on_click(obj):
        plt.close()
        fig, ax = plt.subplots(ncols=2, figsize=(14, 6))
        output.clear_output(True)
        policy_improvement(env, policy, value_table)
        with output:
            plot_value_table(env, value_table, ax=ax[0], vmin=vmin, vmax=vmax)
            plot_policy(env, policy, ax=ax[1])
            plt.show()

    btn_eval.on_click(policy_eval_on_click)
    btn_improve.on_click(policy_improve_on_click)
    display(output)
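# A minimal non-interactive sketch of the loop the two buttons above drive by
# hand: alternate evaluation sweeps and improvement steps until the policy is
# stable. Assumptions: `policy_evaluation_one_step` and `policy_improvement`
# mutate their arguments in place (mirroring how the click handlers use them),
# and `env` is the same grid world as above.
def policy_iteration_sketch(n_eval_sweeps=10, max_rounds=100):
    env.reset()
    value_table = np.zeros(len(env.states))
    policy = np.ones((len(env.states), 4)) * 0.25  # uniform random start
    for _ in range(max_rounds):
        old_policy = policy.copy()
        for _ in range(n_eval_sweeps):
            policy_evaluation_one_step(env, policy, value_table)
        policy_improvement(env, policy, value_table)
        if np.allclose(policy, old_policy):  # policy stable: we are done
            break
    return value_table, policy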
def next_step(self, fast_execution=False):
    print_string = ''
    if self.phase == 'choosing_actions':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_policy_at_state(self.env, self.s, self.policy,
                                               self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1],
                             vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print('Sampling action...\n')
        self.phase = 'showing_sampled_action'
    elif self.phase == 'showing_sampled_action':
        self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a,
                                             self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1],
                             vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
        self.phase = 'carrying_out_action'
    elif self.phase == 'carrying_out_action':
        old_s = self.s
        s_prime, r, t = self.env.step(self.sampled_a)
        self.current_episode.append([old_s, self.sampled_a, r, s_prime, t])
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(old_s, self.ax[0], alpha=0.3)
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1],
                             vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(f'Transition: <{old_s}, {a_index_to_symbol[self.sampled_a]}, '
                  f'{r:.1f}, {s_prime}>')
        self.phase = 'showing_td_target'
    elif self.phase == 'showing_td_target':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1],
                             vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
        old_s, a, r, s_prime, _ = self.current_episode[-1]
        s_prime_idx = self.env.states.index(s_prime)
        # TD target bootstraps from the current estimate of V(s')
        self.td_target = r + 0.9 * self.value_table[s_prime_idx]
        if not fast_execution:
            print(f'Transition: <{old_s}, {a_index_to_symbol[a]}, {r:.1f}, '
                  f'{s_prime}>\n')
            print('Updating value table:')
            print(f'TD target = R + 𝛾 * V({s_prime}) = {r:.1f} + '
                  f'0.9*{self.value_table[s_prime_idx]:.2f} = '
                  f'{self.td_target:.2f}')
        self.phase = 'updating_value_table'
    elif self.phase == 'updating_value_table':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1],
                             vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
        old_s, a, r, s_prime, _ = self.current_episode[-1]
        old_s_idx = self.env.states.index(old_s)
        # Alpha-weighted average towards the TD target, equivalent to
        # V(s) <- V(s) + alpha * (td_target - V(s))
        new_value = self.alpha * self.td_target + (
            1 - self.alpha) * self.value_table[old_s_idx]
        if fast_execution:
            print_string += (f'Transition: <{old_s}, {a_index_to_symbol[a]}, '
                             f'{r:.1f}, {s_prime}>\n')
            print_string += f'TD target = {self.td_target:.2f}\n'
            print_string += (
                f'V({old_s}) ← {self.value_table[old_s_idx]:.2f} + '
                f'{self.alpha:.2f} * ({self.td_target:.2f} - '
                f'{self.value_table[old_s_idx]:.2f}) = {new_value:.2f}\n\n')
        else:
            print(f'Transition: <{old_s}, {a_index_to_symbol[a]}, {r:.1f}, '
                  f'{s_prime}>\n')
            print('Updating value table:')
            print(f'TD target = {self.td_target:.2f}')
            print(f'V({old_s}) ← {self.value_table[old_s_idx]:.2f} + '
                  f'{self.alpha:.2f} * ({self.td_target:.2f} - '
                  f'{self.value_table[old_s_idx]:.2f}) = {new_value:.2f}\n')
        self.value_table[old_s_idx] = new_value
        self.phase = 'showing_new_value_table'
    elif self.phase == 'showing_new_value_table':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1],
                             vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
        old_s, a, r, s_prime, t = self.current_episode[-1]
        old_s_idx = self.env.states.index(old_s)
        if not fast_execution:
            print(f'Transition: <{old_s}, {a_index_to_symbol[a]}, {r:.1f}, '
                  f'{s_prime}>\n')
            print('Updating value table:')
            print(f'TD target = {self.td_target:.2f}')
            print(f'V({old_s}) = {self.value_table[old_s_idx]:.2f}')
        if t:
            if not fast_execution:
                print('\nTerminal state, resetting environment.')
            self.env.reset()
            self.current_episode = []
            self.i += 1
        self.phase = 'choosing_actions'
    else:
        raise ValueError(f'Phase {self.phase} not recognized.')
    return print_string
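# A stand-alone sketch of the tabular TD(0) rule the phases above animate,
# stripped of all plotting. Assumptions: `env.reset()` returns the start
# state (the class reads it via self.s instead), `env.step(a)` returns
# (s_prime, r, t) as above, and gamma=0.9 matches the hard-coded discount.
def td0_episode(env, policy, value_table, alpha=0.1, gamma=0.9):
    s = env.reset()  # assumed to return the start state
    t = False
    while not t:
        s_idx = env.states.index(s)
        a = np.random.choice(4, p=policy[s_idx])  # sample a ~ pi(.|s)
        s_prime, r, t = env.step(a)
        # TD target bootstraps from the current estimate of V(s')
        td_target = r + gamma * value_table[env.states.index(s_prime)]
        # V(s) <- V(s) + alpha * (TD target - V(s)); algebraically the same
        # alpha-weighted average used in the 'updating_value_table' phase
        value_table[s_idx] += alpha * (td_target - value_table[s_idx])
        s = s_prime
    return value_table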
        # Set policy to greedy w.r.t. Q-values computed above
        new_policy[np.argmax(q)] = 1
        policy[i_s] = new_policy


def epsilon_greedy_pi_from_q_table(env, q_table, epsilon):
    policy = np.zeros((len(env.states), 4))
    for i_s, s in enumerate(env.states):
        q_values = q_table[i_s]
        a_star = np.argmax(q_values)
        # Spread epsilon uniformly over all four actions, then give the
        # greedy action the remaining 1 - epsilon probability mass
        policy[i_s, :] = [epsilon / 4] * 4
        policy[i_s, a_star] += 1 - epsilon
    return policy


if __name__ == '__main__':
    from lib.grid_world import grid_world_3x3 as env
    from lib.plot_utils import plot_policy, plot_value_table
    import matplotlib.pyplot as plt

    np.random.seed(2)
    env.reset()
    value_table = np.random.rand(len(env.states))
    policy = np.ones((len(env.states), 4)) * 0.25
    for i, s in enumerate(env.states):
        policy[i, :] = np.random.dirichlet([0.1, 0.1, 0.1, 0.1])
    plot_value_table(env, value_table)
    plot_policy(env, policy)
    plt.show()
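    # Demo of epsilon_greedy_pi_from_q_table on random Q-values (a sketch;
    # the q_table here is arbitrary, not learned). With epsilon = 0.2 every
    # action keeps probability 0.05 and the greedy action gains the
    # remaining 0.8, so each row still sums to 1.
    q_table = np.random.rand(len(env.states), 4)
    eps_pi = epsilon_greedy_pi_from_q_table(env, q_table, epsilon=0.2)
    assert np.allclose(eps_pi.sum(axis=1), 1.0)  # rows are valid pmfs
    plot_policy(env, eps_pi)
    plt.show()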
def next_step(self, fast_execution=False):
    print_string = ''
    if self.phase == 'choosing_actions':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_policy_at_state(self.env, self.s, self.policy,
                                               self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1],
                             vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print('Sampling action...\n')
            print(self.get_current_episode_as_string())
        self.phase = 'showing_sampled_action'
    elif self.phase == 'showing_sampled_action':
        self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a,
                                             self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1],
                             vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
            print(self.get_current_episode_as_string())
        self.phase = 'carrying_out_action'
    elif self.phase == 'carrying_out_action':
        old_s = self.s
        s_prime, r, t = self.env.step(self.sampled_a)
        self.current_episode.append([old_s, self.sampled_a, r, s_prime])
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(old_s, self.ax[0], alpha=0.3)
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1],
                             vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(self.get_current_episode_as_string())
        if t:
            if not fast_execution:
                print('A terminal state has been reached: end of episode.')
            self.phase = 'computing_returns'
        else:
            self.phase = 'choosing_actions'
    elif self.phase == 'computing_returns':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1],
                             vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(self.get_current_episode_as_string())
            print('\nComputing returns:')
        # Compute discounted returns backwards: G_t = r_t + 0.9 * G_{t+1}
        self.current_episode_returns = [0] * len(self.current_episode)
        for i in reversed(range(len(self.current_episode))):
            transition = self.current_episode[i]
            if i == len(self.current_episode) - 1:
                if not fast_execution:
                    print(f'G_{i+1} = {transition[2]:.1f},\t\t\t\tat '
                          f'{transition[0]}')
                self.current_episode_returns[i] = transition[2]
            else:
                self.current_episode_returns[i] = transition[
                    2] + 0.9 * self.current_episode_returns[i + 1]
                if not fast_execution:
                    print(f'G_{i+1} = {transition[2]:.1f} + 0.9 * '
                          f'{self.current_episode_returns[i+1]:.2f}', end='')
                    print(f' = {self.current_episode_returns[i]:.2f},\tat '
                          f'{transition[0]}')
        self.phase = 'presenting_computed_returns'
    elif self.phase == 'presenting_computed_returns':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1],
                             vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(self.get_current_episode_as_string())
            print('\nComputing returns:')
            for i, (transition, g) in reversed(
                    list(
                        enumerate(
                            zip(self.current_episode,
                                self.current_episode_returns)))):
                print(f'G_{i+1} = {g:.2f},\tat {transition[0]}')
        self.phase = 'updating_values'
    elif self.phase == 'updating_values':
        if fast_execution:
            print_string += f'{self.get_current_episode_as_string()}\n'
        else:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1],
                             vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(self.get_current_episode_as_string())
            print('\nComputing returns:')
        # Print computed returns for all visited states
        for i, (transition, g) in reversed(
                list(
                    enumerate(
                        zip(self.current_episode,
                            self.current_episode_returns)))):
            msg = f'G_{i+1} = {g:.2f},\tat {transition[0]}'
            if fast_execution:
                print_string += f'{msg}\n'
            else:
                print(msg)
        if fast_execution:
            print_string += '\n'
        else:
            print('\nUpdating values:')
        for transition, g in zip(reversed(self.current_episode),
                                 reversed(self.current_episode_returns)):
            s_idx = self.env.states.index(transition[0])
            # Alpha-weighted average towards the observed return G
            new_value = self.alpha * g + (
                1 - self.alpha) * self.value_table[s_idx]
            cur_value = self.value_table[s_idx]
            msg = f'V({transition[0]}) = {cur_value:.2f} + {self.alpha:.2f} ' \
                  f'* ({g:.2f} - {cur_value:.2f}) = {new_value:.2f}'
            self.value_table[s_idx] = new_value
            if fast_execution:
                print_string += f'{msg}\n'
            else:
                print(msg)
        self.phase = 'show_updated_value_function'
    elif self.phase == 'show_updated_value_function':
        if not fast_execution:
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1],
                             vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()
            print(self.get_current_episode_as_string())
            print('\nComputing returns:')
            for i, (transition, g) in reversed(
                    list(
                        enumerate(
                            zip(self.current_episode,
                                self.current_episode_returns)))):
                print(f'G_{i+1} = {g:.2f},\tat {transition[0]}')
            print('\nUpdating values:')
            for transition, g in zip(reversed(self.current_episode),
                                     reversed(self.current_episode_returns)):
                s_idx = self.env.states.index(transition[0])
                print(f'V({transition[0]}) = {self.value_table[s_idx]:.2f}')
        self.current_episode = []
        self.env.reset()
        self.phase = 'choosing_actions'
    else:
        raise ValueError(f'Phase {self.phase} not recognized.')
    return print_string
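# A stand-alone sketch of the constant-alpha, every-visit Monte Carlo update
# performed in the 'updating_values' phase above, without the visualization.
# `episode` is a list of [s, a, r, s_prime] transitions like
# self.current_episode; gamma=0.9 matches the hard-coded discount.
def mc_update(env, episode, value_table, alpha=0.1, gamma=0.9):
    g = 0.0
    # Walk the episode backwards: G_t = r_t + gamma * G_{t+1}
    for s, a, r, s_prime in reversed(episode):
        g = r + gamma * g
        s_idx = env.states.index(s)
        # V(s) <- V(s) + alpha * (G - V(s)); algebraically the same
        # alpha-weighted average used above
        value_table[s_idx] += alpha * (g - value_table[s_idx])
    return value_table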
def next_step(self, fast_execution=False):
    if not fast_execution:
        # Shared plotting for all phases
        self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
        plot_s_pmf(self.env, self.s, self.a, ax=self.ax[0])
        plot_value_table(self.env, self.value_table, ax=self.ax[1], vmin=-8,
                         vmax=0)
        plot_policy(self.env, self.policy, ax=self.ax[2])
        # Phase-specific plotting: highlight the state being worked on
        if self.phase == 'computing_final_v' and self.env.terminal_table[
                self.s_i] != 1:
            self.ax[2].add_patch(
                matplotlib.patches.Rectangle(self.s, 1.0, 1.0, edgecolor='r',
                                             facecolor='None', lw=5))
        if self.phase == 'updating_value_table':
            self.ax[1].add_patch(
                matplotlib.patches.Rectangle(self.s, 1.0, 1.0, edgecolor='r',
                                             facecolor='None', lw=5))
        plt.show()
        print(f'Iteration: {self.i}\nState: {self.s}\n')

    if self.phase == 'going_through_actions':
        # Special handling for terminal states: skip straight past all actions
        if self.env.terminal_table[self.s_i] == 1:
            self.a = 4
        else:
            # Get pmf over possible next states
            s_prime_idxs, s_prime_probs = get_pmf_possible_s_primes(
                self.env, self.s, self.a)
            # Compute Q values
            q_message = f'Q(s, {a_index_to_symbol[self.a]}) = '
            r = self.env.reward_table[self.s_i, self.a]
            for s_prime_idx, s_prime_prob in zip(s_prime_idxs, s_prime_probs):
                q_message += f'{s_prime_prob:.1f}*({r:.1f} + ' \
                             f'0.9*{self.value_table[s_prime_idx]:.2f}) + '
                self.q_values[self.a] += s_prime_prob * (
                    r + 0.9 * self.value_table[s_prime_idx])
            # Prepare message to be printed (strip the trailing ' + ')
            q_message = q_message[:-3]
            q_message += f' = {self.q_values[self.a]:.2f}'
            self.q_messages += f'\n{q_message}'
            if not fast_execution:
                print(self.q_messages)
            self.a += 1
        if self.a == 4:
            self.a = 3  # Show the same action as before
            self.phase = 'computing_final_v'
    elif self.phase == 'computing_final_v':
        if self.env.terminal_table[self.s_i] == 1:
            final_value_message = ('State is terminal. Setting its value to '
                                   'zero.\nV(s) = 0')
            if not fast_execution:
                print(final_value_message)
            final_value = 0
        else:
            # Print computed Q messages
            if not fast_execution:
                for i, q_value in enumerate(self.q_values):
                    print(f'Q(s, {a_index_to_symbol[i]}) = {q_value:.2f}')
            # Build the resulting value of the current state:
            # V(s) = sum_a pi(a|s) * Q(s, a)
            final_value_message = '\nV(s) = '
            final_value = 0
            for q_value, pi in zip(self.q_values, self.policy[self.s_i]):
                final_value_message += f'{q_value:.2f}*{pi:.1f} + '
                final_value += q_value * pi
            final_value_message = (f'{final_value_message[:-3]} = '
                                   f'{final_value:.2f}')
            if not fast_execution:
                print(final_value_message)
        # Update value table and move to next phase
        self.value_table[self.s_i] = final_value
        self.phase = 'updating_value_table'
    elif self.phase == 'updating_value_table':
        if not fast_execution:
            # Print computed Q messages
            for i, q_value in enumerate(self.q_values):
                print(f'Q(s, {a_index_to_symbol[i]}) = {q_value:.2f}')
            # Print resulting value of current state
            print(f'\nV(s) = {self.value_table[self.s_i]:.2f}')
        # Reset action counter, move to next state, erase previous q messages
        self.a = 0
        self.s_i += 1
        self.q_messages = ''
        self.q_values = [0, 0, 0, 0]
        self.phase = 'going_through_actions'
        # If we reached the last state, restart from the first state and
        # increment the iteration counter
        if self.s_i == len(self.env.states):
            self.s_i = 0
            self.i += 1
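# A batch sketch of the expected update this class steps through one action
# at a time: a single policy-evaluation sweep computing
# V(s) = sum_a pi(a|s) * sum_s' p(s'|s,a) * (r(s,a) + gamma * V(s')).
# It reuses the env interface assumed above (states, reward_table,
# terminal_table, get_pmf_possible_s_primes); gamma=0.9 matches the
# hard-coded discount. Like the class, it updates value_table in place, so
# states later in the sweep already see updated values.
def policy_evaluation_sweep(env, policy, value_table, gamma=0.9):
    for s_i, s in enumerate(env.states):
        if env.terminal_table[s_i] == 1:
            value_table[s_i] = 0  # terminal states have value zero
            continue
        v = 0.0
        for a in range(4):
            r = env.reward_table[s_i, a]
            s_prime_idxs, s_prime_probs = get_pmf_possible_s_primes(env, s, a)
            # Q(s, a) = sum_s' p(s'|s,a) * (r + gamma * V(s'))
            q = sum(p * (r + gamma * value_table[i])
                    for i, p in zip(s_prime_idxs, s_prime_probs))
            v += policy[s_i, a] * q
        value_table[s_i] = v
    return value_table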