예제 #1
0
 def draw_env_and_agent(self, old_s=None):
     """Show the environment and the agent in a new single-panel figure.

     A fresh figure (``figsize=(5, 5)``) is created on every call and kept
     on ``self.fig`` / ``self.ax``.

     Args:
         old_s: Optional earlier agent state. When it is not ``None`` it is
             drawn as a second, faded (``alpha=0.3``) agent marker after
             the current state ``self.s``.
     """
     figure, axes = plt.subplots(figsize=(5, 5))
     self.fig = figure
     self.ax = axes

     plot_all_states(self.env, axes)

     # Current position first, then (optionally) the faded previous one.
     agents_to_draw = [(self.s, {})]
     if old_s is not None:
         agents_to_draw.append((old_s, {'alpha': 0.3}))
     for state, style in agents_to_draw:
         plot_agent(state, axes, **style)

     plt.show()
예제 #2
0
    def update_epsilon(self, epsilon):
        """Store a new epsilon, rebuild the policy from the Q-table and redraw.

        Args:
            epsilon: New value stored on ``self.epsilon`` and passed to
                ``epsilon_greedy_pi_from_q_table`` to rebuild ``self.policy``.
        """
        self.epsilon = epsilon
        self.policy = epsilon_greedy_pi_from_q_table(
            self.env,
            self.q_table,
            self.epsilon,
        )

        # Three panels: environment + agent, Q-table, policy.
        figure, panels = plt.subplots(ncols=3, figsize=(20, 6))
        self.fig = figure
        self.ax = panels

        plot_all_states(self.env, panels[0])
        plot_agent(self.s, panels[0])
        plot_q_table(
            self.env,
            self.q_table,
            panels[1],
            vmin=self.vmin,
            vmax=self.vmax,
        )
        plot_policy(self.env, self.policy, panels[2])
        plt.show()
예제 #3
0
    def finish_episode(self):
        """Fast-forward until the episode counter changes, then draw and print.

        ``next_step(fast_execution=True)`` is called repeatedly for as long as
        ``self.i`` keeps the value it had on entry. The strings returned by
        those calls are concatenated and printed after the final plot.
        """
        starting_episode = self.i
        log_chunks = []
        while self.i == starting_episode:
            log_chunks.append(f'{self.next_step(fast_execution=True)}')

        # One summary frame: environment + agent, Q-table, policy.
        figure, panels = plt.subplots(ncols=3, figsize=(20, 6))
        self.fig = figure
        self.ax = panels
        plot_all_states(self.env, panels[0])
        plot_agent(self.s, panels[0])
        plot_q_table(
            self.env,
            self.q_table,
            panels[1],
            vmin=self.vmin,
            vmax=self.vmax,
        )
        plot_policy(self.env, self.policy, panels[2])
        plt.show()

        print(''.join(log_chunks))
예제 #4
0
    def finish_episode(self):
        """Fast-forward through the terminal phase, then draw and print.

        ``next_step(fast_execution=True)`` is called until the phase
        ``'showing_new_q_value_terminal'`` is reached, and then once more so
        that this phase itself is executed. The strings returned by all calls
        are concatenated and printed after the final plot.
        """
        log_chunks = []
        while True:
            # Remember whether the step about to run is the terminal one.
            at_final_phase = self.phase == 'showing_new_q_value_terminal'
            log_chunks.append(f'{self.next_step(fast_execution=True)}')
            if at_final_phase:
                break

        # One summary frame: environment + agent, Q-table, greedy policy.
        figure, panels = plt.subplots(ncols=3, figsize=(20, 6))
        self.fig = figure
        self.ax = panels
        plot_all_states(self.env, panels[0])
        plot_agent(self.s, panels[0])
        plot_q_table(self.env, self.q_table, panels[1],
                     vmin=self.vmin, vmax=self.vmax)
        plot_policy(self.env, self.greedy_policy, panels[2])
        plt.show()

        print(''.join(log_chunks))
예제 #5
0
    def next_step(self, fast_execution=False):
        """Advance the TD value-table walkthrough by exactly one phase.

        ``self.phase`` drives a small state machine::

            choosing_actions -> showing_sampled_action -> carrying_out_action
            -> showing_td_target -> updating_value_table
            -> showing_new_value_table -> choosing_actions

        When the last recorded transition is terminal, the final phase also
        resets the environment, clears ``self.current_episode`` and
        increments ``self.i``.

        Args:
            fast_execution: If true, nothing is plotted or printed; the
                description of the value update is returned instead.

        Returns:
            str: Text describing the update when called with
            ``fast_execution=True`` in the ``updating_value_table`` phase,
            otherwise an empty string.

        Raises:
            ValueError: If ``self.phase`` is not one of the phases above.
        """
        verbose = not fast_execution
        print_string = ''
        phase = self.phase

        def open_figure():
            # Every displayed frame gets a brand-new three-panel figure.
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))

        def draw_tables_and_show():
            # Middle and right panels are identical in every frame.
            plot_value_table(self.env, self.value_table, self.ax[1],
                             vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()

        def draw_agent_frame():
            # Plain frame: all states plus the agent at its current state.
            open_figure()
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            draw_tables_and_show()

        if phase == 'choosing_actions':
            if verbose:
                open_figure()
                plot_env_agent_and_policy_at_state(
                    self.env, self.s, self.policy, self.ax[0])
                draw_tables_and_show()
                print('Sampling action...\n')
            self.phase = 'showing_sampled_action'

        elif phase == 'showing_sampled_action':
            # Four actions, drawn from the policy row of the current state.
            self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])
            if verbose:
                open_figure()
                plot_env_agent_and_chosen_action(
                    self.env, self.s, self.sampled_a, self.ax[0])
                draw_tables_and_show()
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
            self.phase = 'carrying_out_action'

        elif phase == 'carrying_out_action':
            old_s = self.s
            s_prime, r, t = self.env.step(self.sampled_a)
            self.current_episode.append(
                [old_s, self.sampled_a, r, s_prime, t])
            if verbose:
                open_figure()
                plot_all_states(self.env, self.ax[0])
                # Faded marker where the agent was, solid where it is now.
                plot_agent(old_s, self.ax[0], alpha=0.3)
                plot_agent(self.s, self.ax[0])
                draw_tables_and_show()
                print(
                    f'Transition: <{old_s}, {a_index_to_symbol[self.sampled_a]}, {r:.1f}, {s_prime}>'
                )
            self.phase = 'showing_td_target'

        elif phase == 'showing_td_target':
            if verbose:
                draw_agent_frame()

            old_s, a, r, s_prime, _ = self.current_episode[-1]
            s_prime_idx = self.env.states.index(s_prime)
            # The discount factor is hard-coded to 0.9 here.
            self.td_target = r + 0.9 * self.value_table[s_prime_idx]

            if verbose:
                print(
                    f'Transition: <{old_s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}>\n'
                )
                print('Updating value table:')
                print(
                    f'TD target = R + 𝛾 * V({s_prime}) = {r:.1f} + 0.9*{self.value_table[s_prime_idx]:.2f} = '
                    f'{self.td_target:.2f}')
            self.phase = 'updating_value_table'

        elif phase == 'updating_value_table':
            if verbose:
                draw_agent_frame()

            old_s, a, r, s_prime, _ = self.current_episode[-1]
            old_s_idx = self.env.states.index(old_s)
            # Move the old estimate towards the TD target by step size alpha.
            new_value = self.alpha * self.td_target + (
                1 - self.alpha) * self.value_table[old_s_idx]

            if verbose:
                print(
                    f'Transition: <{old_s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}>\n'
                )
                print('Updating value table:')
                print(f'TD target = {self.td_target:.2f}')
                print(
                    f'V({old_s}) ← {self.value_table[old_s_idx]:.2f} + {self.alpha:.2f} * '
                    f'({self.td_target:.2f} - {self.value_table[old_s_idx]:.2f}) = {new_value:.2f}\n'
                )
            else:
                # Silent mode: hand the same information back to the caller.
                print_string += f'Transition: <{old_s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}>\n'
                print_string += f'TD target = {self.td_target:.2f}\n'
                print_string += f'V({old_s}) ← {self.value_table[old_s_idx]:.2f} + {self.alpha:.2f} * ' \
                                f'({self.td_target:.2f} - {self.value_table[old_s_idx]:.2f}) = {new_value:.2f}\n\n'

            self.value_table[old_s_idx] = new_value
            self.phase = 'showing_new_value_table'

        elif phase == 'showing_new_value_table':
            if verbose:
                draw_agent_frame()

            old_s, a, r, s_prime, t = self.current_episode[-1]
            old_s_idx = self.env.states.index(old_s)

            if verbose:
                print(
                    f'Transition: <{old_s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}>\n'
                )
                print('Updating value table:')
                print(f'TD target = {self.td_target:.2f}')
                print(f'V({old_s}) = {self.value_table[old_s_idx]:.2f}')

            if t:
                # End of episode: start over and count the finished episode.
                if verbose:
                    print('\nTerminal state, reseting environment.')
                self.env.reset()
                self.current_episode = []
                self.i += 1
            self.phase = 'choosing_actions'

        else:
            raise ValueError(f'Phase {self.phase} not recognized.')

        return print_string
예제 #6
0
    def next_step(self, fast_execution=False):
        """Advance the Q-learning walkthrough by exactly one phase.

        ``self.phase`` drives a small state machine::

            choosing_actions -> showing_sampled_action -> carrying_out_action
              non-terminal: -> computing_td_target -> computing_new_q_value
                            -> showing_new_q_value -> choosing_actions
              terminal:     -> update_q_value_terminal
                            -> showing_new_q_value_terminal -> choosing_actions

        Actions are sampled from ``self.exploratory_policy``; the TD target
        uses an action sampled from ``self.greedy_policy`` at the successor
        state. Both policies are rebuilt from the Q-table after each update.

        Args:
            fast_execution: If true, nothing is plotted or printed.

        Returns:
            str: Always an empty string; this variant produces no textual
            log in either mode.

        Raises:
            ValueError: If ``self.phase`` is not one of the phases above.
        """
        print_string = ''
        verbose = not fast_execution

        def open_figure():
            # Every displayed frame gets a brand-new three-panel figure.
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))

        def draw_tables_and_show():
            # Middle and right panels are identical in every frame.
            plot_q_table(self.env,
                         self.q_table,
                         self.ax[1],
                         vmin=self.vmin,
                         vmax=self.vmax)
            plot_policy(self.env, self.greedy_policy, self.ax[2])
            plt.show()

        def draw_transition_frame(previous_state):
            # Agent at its current state plus a faded marker at the
            # state the transition started from.
            open_figure()
            plot_all_states(self.env, self.ax[0])
            plot_agent(self.s, self.ax[0])
            plot_agent(previous_state, self.ax[0], alpha=0.3)
            draw_tables_and_show()

        def greedy_td_target(r, s_prime):
            # The greedy action must be looked up at the successor state,
            # i.e. the same row that is used to index the Q-table below.
            # The discount factor is hard-coded to 0.9 here.
            s_prime_idx = self.env.states.index(s_prime)
            greedy_a = np.random.choice(4, p=self.greedy_policy[s_prime_idx])
            td_target = r + 0.9 * self.q_table[s_prime_idx, greedy_a]
            return s_prime_idx, greedy_a, td_target

        def refresh_policies():
            # Keep both policies in sync with the freshly updated Q-table.
            self.exploratory_policy = epsilon_greedy_pi_from_q_table(
                self.env, self.q_table, self.epsilon)
            self.greedy_policy = epsilon_greedy_pi_from_q_table(
                self.env, self.q_table, 0)

        if self.phase == 'choosing_actions':
            if verbose:
                open_figure()
                plot_env_agent_and_policy_at_state(self.env, self.s,
                                                   self.exploratory_policy,
                                                   self.ax[0])
                draw_tables_and_show()
                print('Sampling action...\n')

            self.phase = 'showing_sampled_action'

        elif self.phase == 'showing_sampled_action':
            self.sampled_a = np.random.choice(
                4, p=self.exploratory_policy[self.s_idx])

            if verbose:
                open_figure()
                plot_env_agent_and_chosen_action(self.env, self.s,
                                                 self.sampled_a, self.ax[0])
                draw_tables_and_show()
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')

            self.phase = 'carrying_out_action'

        elif self.phase == 'carrying_out_action':
            old_s = self.s
            s_prime, r, t = self.env.step(self.sampled_a)
            self.current_transition = [old_s, self.sampled_a, r, s_prime, t]

            if verbose:
                open_figure()
                plot_all_states(self.env, self.ax[0])
                plot_agent(old_s, self.ax[0], alpha=0.3)
                plot_agent(self.s, self.ax[0])
                draw_tables_and_show()
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')

            if t:
                if verbose:
                    print('A terminal state has been reached: end of episode.')
                self.phase = 'update_q_value_terminal'
            else:
                self.phase = 'computing_td_target'

        elif self.phase == 'computing_td_target':
            s, a, r, s_prime, _ = self.current_transition
            s_prime_idx, greedy_a, td_target = greedy_td_target(r, s_prime)

            if verbose:
                draw_transition_frame(s)
                print(
                    f'TD target = R + 𝛾 * Q({s_prime}, {a_index_to_symbol[greedy_a]}) = '
                    f'{r:.2f} + 0.9 * {self.q_table[s_prime_idx, greedy_a]:.2f} = '
                    f'{td_target:.2f}')

            self.phase = 'computing_new_q_value'

        elif self.phase == 'computing_new_q_value':
            # Display-only phase: the value is recomputed and applied in
            # 'showing_new_q_value'.
            s, a, r, s_prime, _ = self.current_transition
            s_idx = self.env.states.index(s)
            _, _, td_target = greedy_td_target(r, s_prime)
            new_value = self.alpha * td_target + (
                1 - self.alpha) * self.q_table[s_idx, a]

            if verbose:
                draw_transition_frame(s)
                print(f'TD target = {td_target:.2f}')
                print(
                    f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * '
                    f'({td_target:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}'
                )

            self.phase = 'showing_new_q_value'

        elif self.phase == 'showing_new_q_value':
            s, a, r, s_prime, _ = self.current_transition
            s_idx = self.env.states.index(s)
            _, _, td_target = greedy_td_target(r, s_prime)

            # Compute new value and update q-table
            new_value = self.alpha * td_target + (
                1 - self.alpha) * self.q_table[s_idx, a]
            self.q_table[s_idx, a] = new_value
            refresh_policies()

            if verbose:
                draw_transition_frame(s)
                print(f'TD target = {td_target:.2f}')
                print(f'Q({s}, {a_index_to_symbol[a]}) = {new_value:.2f}')

            self.phase = 'choosing_actions'

        elif self.phase == 'update_q_value_terminal':
            s, a, r, s_prime, t = self.current_transition
            s_idx = self.env.states.index(s)
            # No bootstrap term for a terminal successor: target is just r.
            new_value = self.alpha * r + (1 - self.alpha) * self.q_table[s_idx,
                                                                         a]

            if verbose:
                # Drawn before the update so the old Q-value is shown.
                draw_transition_frame(s)
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
                print('A terminal state has been reached: end of episode.\n')
                print(f'TD target = R + 𝛾 * 0 = {r:.2f}')
                print(
                    f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * '
                    f'({r:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}'
                )

            self.q_table[s_idx, a] = new_value
            refresh_policies()
            self.phase = 'showing_new_q_value_terminal'

        elif self.phase == 'showing_new_q_value_terminal':
            if verbose:
                s, a, r, s_prime, t = self.current_transition
                s_idx = self.env.states.index(s)
                draw_transition_frame(s)
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
                print('A terminal state has been reached: end of episode.\n')
                print(f'TD target = R + 𝛾 * 0 = {r:.2f}')
                print(
                    f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f}'
                )

            # Episode finished: forget the transition and start over.
            self.current_transition = []
            self.env.reset()
            self.phase = 'choosing_actions'

        else:
            raise ValueError(f'Phase {self.phase} not recognized.')
        return print_string
예제 #7
0
    def next_step(self, fast_execution=False):
        """Advance the SARSA walkthrough by exactly one phase.

        ``self.phase`` drives a small state machine::

            initial_sampling_action -> initial_showing_sampled_action
            -> carrying_out_action
              non-terminal: -> sampling_action -> showing_sampled_action
                            -> showing_transition_done -> showing_td_target
                            -> show_q_value_computation
                            -> showing_updated_q_value -> carrying_out_action
              terminal:     -> compute_q_for_terminal_s_prime
                            -> showing_updated_q_value_terminal_s_prime
                            -> initial_sampling_action

        ``self.current_transition`` holds ``[s, a, r, s_prime, t]`` after the
        action is carried out and gains a sixth element ``a_prime`` once the
        next action has been sampled.

        Args:
            fast_execution: If true, nothing is plotted or printed; the
                description of a Q-value update is returned instead.

        Returns:
            str: Text describing the update when called with
            ``fast_execution=True`` in the ``show_q_value_computation`` or
            ``compute_q_for_terminal_s_prime`` phase, otherwise an empty
            string.

        Raises:
            ValueError: If ``self.phase`` is not one of the phases above.
        """
        print_string = ''
        verbose = not fast_execution

        def open_figure():
            # Every displayed frame gets a brand-new three-panel figure.
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))

        def draw_tables_and_show():
            # Middle and right panels are identical in every frame.
            plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()

        def draw_chosen_action_frame(previous_state):
            # Agent with its sampled action plus a faded marker at the
            # state the transition started from.
            open_figure()
            plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, self.ax[0])
            plot_agent(previous_state, self.ax[0], alpha=0.3)
            draw_tables_and_show()

        def draw_terminal_frame(previous_state, terminal_state):
            # States are taken from the stored transition, not from the env.
            open_figure()
            plot_all_states(self.env, self.ax[0])
            plot_agent(previous_state, self.ax[0], alpha=0.3)
            plot_agent(terminal_state, self.ax[0])
            draw_tables_and_show()

        if self.phase == 'initial_sampling_action':
            if verbose:
                open_figure()
                plot_env_agent_and_policy_at_state(self.env, self.s, self.policy, self.ax[0])
                draw_tables_and_show()
                print('Sampling action...\n')

            self.phase = 'initial_showing_sampled_action'

        elif self.phase == 'initial_showing_sampled_action':
            self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])

            if verbose:
                open_figure()
                plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, self.ax[0])
                draw_tables_and_show()
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')

            self.phase = 'carrying_out_action'

        elif self.phase == 'carrying_out_action':
            old_s = self.s
            s_prime, r, t = self.env.step(self.sampled_a)
            self.current_transition = [old_s, self.sampled_a, r, s_prime, t]

            if verbose:
                open_figure()
                plot_all_states(self.env, self.ax[0])
                plot_agent(old_s, self.ax[0], alpha=0.3)
                plot_agent(self.s, self.ax[0])
                draw_tables_and_show()
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')

            # A terminal successor needs no next action, so it branches off.
            if not t:
                self.phase = 'sampling_action'
            else:
                self.phase = 'compute_q_for_terminal_s_prime'

        elif self.phase == 'sampling_action':
            if verbose:
                old_s = self.current_transition[0]
                open_figure()
                plot_env_agent_and_policy_at_state(self.env, self.s, self.policy, self.ax[0])
                plot_agent(old_s, self.ax[0], alpha=0.3)
                draw_tables_and_show()
                print('Sampling action...\n')

            self.phase = 'showing_sampled_action'

        elif self.phase == 'showing_sampled_action':
            self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])
            # The sampled action becomes a_prime of the stored transition.
            self.current_transition.append(self.sampled_a)

            if verbose:
                draw_chosen_action_frame(self.current_transition[0])
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')

            self.phase = 'showing_transition_done'

        elif self.phase == 'showing_transition_done':
            if verbose:
                draw_chosen_action_frame(self.current_transition[0])

                s, a, r, s_prime, t, a_prime = self.current_transition
                print(f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, {a_index_to_symbol[a_prime]}>\n')

            self.phase = 'showing_td_target'

        elif self.phase == 'showing_td_target':
            if verbose:
                s, a, r, s_prime, t, a_prime = self.current_transition
                draw_chosen_action_frame(s)

                print(f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, {a_index_to_symbol[a_prime]}>\n')

                s_prime_idx = self.env.states.index(s_prime)
                # The discount factor is hard-coded to 0.9 here.
                td_target = r + 0.9 * self.q_table[s_prime_idx, a_prime]
                print(f'TD target: R + 𝛾 * Q({s_prime}, {a_index_to_symbol[a_prime]}) = '
                      f'{r:.2f} + 0.9 * {self.q_table[s_prime_idx, a_prime]:.2f} = '
                      f'{td_target:.2f}')

            self.phase = 'show_q_value_computation'

        elif self.phase == 'show_q_value_computation':
            s, a, r, s_prime, t, a_prime = self.current_transition
            if verbose:
                # The figure is created first and only then drawn on; nothing
                # is plotted on the axes left over from the previous frame.
                draw_chosen_action_frame(s)

            s_idx = self.env.states.index(s)
            s_prime_idx = self.env.states.index(s_prime)
            td_target = r + 0.9 * self.q_table[s_prime_idx, a_prime]
            new_value = self.alpha * td_target + (1 - self.alpha) * self.q_table[s_idx, a]

            if fast_execution:
                # Silent mode: hand the same information back to the caller.
                msg = f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, {a_index_to_symbol[a_prime]}>\n' \
                      f'TD target: R + 𝛾 * Q({s_prime}, {a_index_to_symbol[a_prime]}) = ' \
                      f'{r:.2f} + 0.9 * {self.q_table[s_prime_idx, a_prime]:.2f} = {td_target:.2f}\n'  \
                      f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * ' \
                      f'({td_target:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}\n\n'
                print_string += msg
            else:
                print(f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, {a_index_to_symbol[a_prime]}>\n')
                print(f'TD target: {td_target:.2f}')
                print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * '
                      f'({td_target:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}')

            self.q_table[s_idx, a] = new_value
            self.policy = epsilon_greedy_pi_from_q_table(self.env, self.q_table, self.epsilon)

            self.phase = 'showing_updated_q_value'

        elif self.phase == 'showing_updated_q_value':
            if verbose:
                s, a, r, s_prime, t, a_prime = self.current_transition
                draw_chosen_action_frame(s)

                print(f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, {a_index_to_symbol[a_prime]}>\n')

                s_prime_idx = self.env.states.index(s_prime)
                # NOTE(review): this is recomputed from the already-updated
                # Q-table, so it differs from the target actually used when
                # (s, a) == (s_prime, a_prime).
                td_target = r + 0.9 * self.q_table[s_prime_idx, a_prime]
                print(f'TD target: {td_target:.2f}')

                s_idx = self.env.states.index(s)
                # State and action are formatted separately, as in every
                # other message, rather than as one tuple.
                print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f}')

            # a_prime (still in self.sampled_a) is the action carried out next.
            self.phase = 'carrying_out_action'

        elif self.phase == 'compute_q_for_terminal_s_prime':
            s, a, r, s_prime, t = self.current_transition
            s_idx = self.env.states.index(s)
            # No bootstrap term for a terminal successor: target is just r.
            new_value = self.alpha * r + (1 - self.alpha) * self.q_table[s_idx, a]

            if fast_execution:
                print_string += f'Transition to terminal state: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, ∅>\n'
                print_string += f'TD target = R + 𝛾 * 0 = {r:.2f}\n'
                print_string += f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * ' \
                                f'({r:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}'

            else:
                draw_terminal_frame(s, s_prime)
                print('Terminal state reached.\n')
                print(f'TD target = R + 𝛾 * 0 = {r:.2f}')
                print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * '
                      f'({r:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}')

            self.q_table[s_idx, a] = new_value
            self.policy = epsilon_greedy_pi_from_q_table(self.env, self.q_table, self.epsilon)

            self.phase = 'showing_updated_q_value_terminal_s_prime'

        elif self.phase == 'showing_updated_q_value_terminal_s_prime':
            # The environment is reset before drawing; the frame below still
            # shows the finished transition from self.current_transition.
            self.env.reset()
            self.i += 1

            if verbose:
                s, a, r, s_prime, t = self.current_transition
                s_idx = self.env.states.index(s)

                draw_terminal_frame(s, s_prime)

                print('Terminal state reached.\n')
                print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f}\n')
                print('Resetting environment')

            self.phase = 'initial_sampling_action'

        else:
            raise ValueError(f'Phase {self.phase} not recognized.')
        return print_string
예제 #8
0
    def next_step(self, fast_execution=False):
        """Advance the step-by-step Monte Carlo walkthrough by one phase.

        The method is a small state machine driven by ``self.phase``:

        ``choosing_actions`` -> ``showing_sampled_action`` ->
        ``carrying_out_action`` -> (back to ``choosing_actions`` until the
        environment reports a terminal transition) -> ``computing_returns``
        -> ``presenting_computed_returns`` -> ``updating_values`` ->
        ``show_updated_value_function`` -> ``choosing_actions``.

        Args:
            fast_execution: When true, nothing is plotted or printed. The
                text describing the value update is returned instead of
                being printed.

        Returns:
            str: The accumulated log text. It is only non-empty for the
            ``updating_values`` phase when ``fast_execution`` is true.

        Raises:
            ValueError: If ``self.phase`` is not one of the phases above.
        """
        print_string = ''
        # Single source of truth for the discount used both in the return
        # computation and in the explanatory message printed alongside it.
        discount = 0.9

        def show_panels(draw_left_panel):
            # Every phase shows the same value-table and policy panels; only
            # the left-hand (environment) panel differs between phases.
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            draw_left_panel(self.ax[0])
            plot_value_table(self.env,
                             self.value_table,
                             self.ax[1],
                             vmin=self.vmin,
                             vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()

        def draw_env_with_agent(axis):
            # Default left panel: all states plus the agent's current state.
            plot_all_states(self.env, axis)
            plot_agent(self.s, axis)

        def return_lines():
            # One "G_i = ..." line per transition, last transition first.
            indexed = list(
                enumerate(
                    zip(self.current_episode, self.current_episode_returns)))
            return [
                f'G_{i+1} = {g:.2f},\tat {transition[0]}'
                for i, (transition, g) in reversed(indexed)
            ]

        if self.phase == 'choosing_actions':
            if not fast_execution:
                show_panels(lambda axis: plot_env_agent_and_policy_at_state(
                    self.env, self.s, self.policy, axis))
                print('Sampling action...\n')
                print(self.get_current_episode_as_string())

            self.phase = 'showing_sampled_action'

        elif self.phase == 'showing_sampled_action':
            # Size the draw from the policy row itself instead of assuming
            # exactly four actions.
            action_probs = self.policy[self.s_idx]
            self.sampled_a = np.random.choice(len(action_probs),
                                              p=action_probs)

            if not fast_execution:
                show_panels(lambda axis: plot_env_agent_and_chosen_action(
                    self.env, self.s, self.sampled_a, axis))
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
                print(self.get_current_episode_as_string())

            self.phase = 'carrying_out_action'

        elif self.phase == 'carrying_out_action':
            old_s = self.s
            s_prime, r, t = self.env.step(self.sampled_a)
            self.current_episode.append([old_s, self.sampled_a, r, s_prime])

            if not fast_execution:

                def draw_transition(axis):
                    # Previous position is drawn faded, new one on top.
                    plot_all_states(self.env, axis)
                    plot_agent(old_s, axis, alpha=0.3)
                    plot_agent(self.s, axis)

                show_panels(draw_transition)
                print(self.get_current_episode_as_string())

            if t:
                if not fast_execution:
                    print('A terminal state has been reached: end of episode.')
                self.phase = 'computing_returns'
            else:
                self.phase = 'choosing_actions'

        elif self.phase == 'computing_returns':
            if not fast_execution:
                show_panels(draw_env_with_agent)
                print(self.get_current_episode_as_string())
                print('\nComputing returns:')

            episode = self.current_episode
            returns = [0] * len(episode)
            self.current_episode_returns = returns
            last_index = len(episode) - 1
            # Walk backwards so each return can reuse the one that follows it.
            for i, transition in reversed(list(enumerate(episode))):
                state, reward = transition[0], transition[2]
                if i == last_index:
                    if not fast_execution:
                        print(f'G_{i+1} = {reward:.1f},\t\t\t\tat {state}')
                    returns[i] = reward
                else:
                    returns[i] = reward + discount * returns[i + 1]
                    if not fast_execution:
                        print(
                            f'G_{i+1} = {reward:.1f} + {discount} * {returns[i+1]:.2f}',
                            end='')
                        print(f' = {returns[i]:.2f},\tat {state}')

            self.phase = 'presenting_computed_returns'

        elif self.phase == 'presenting_computed_returns':
            if not fast_execution:
                show_panels(draw_env_with_agent)
                print(self.get_current_episode_as_string())

                print('\nComputing returns:')
                for line in return_lines():
                    print(line)

            self.phase = 'updating_values'

        elif self.phase == 'updating_values':
            if fast_execution:
                print_string += f'{self.get_current_episode_as_string()}\n'
            else:
                show_panels(draw_env_with_agent)
                print(self.get_current_episode_as_string())
                print('\nComputing returns:')

            # Report computed returns for all visited states.
            for line in return_lines():
                if fast_execution:
                    print_string += f'{line}\n'
                else:
                    print(line)

            if fast_execution:
                print_string += '\n'
            else:
                print('\nUpdating values:')

            # Move each visited state's value towards its observed return,
            # processing the episode from its last transition to its first.
            for transition, g in zip(reversed(self.current_episode),
                                     reversed(self.current_episode_returns)):
                s_idx = self.env.states.index(transition[0])
                cur_value = self.value_table[s_idx]
                new_value = self.alpha * g + (1 - self.alpha) * cur_value
                msg = f'V({transition[0]}) = {cur_value:.2f} + {self.alpha:.2f} * ({g:.2f} - {cur_value:.2f}) ' \
                      f'= {new_value:.2f}'
                self.value_table[s_idx] = new_value
                if fast_execution:
                    print_string += f'{msg}\n'
                else:
                    print(msg)

            self.phase = 'show_updated_value_function'

        elif self.phase == 'show_updated_value_function':
            if not fast_execution:
                show_panels(draw_env_with_agent)
                print(self.get_current_episode_as_string())

                print('\nComputing returns:')
                for line in return_lines():
                    print(line)

                print('\nUpdating values:')
                for transition, _ in zip(
                        reversed(self.current_episode),
                        reversed(self.current_episode_returns)):
                    s_idx = self.env.states.index(transition[0])
                    print(
                        f'V({transition[0]}) = {self.value_table[s_idx]:.2f}')

            # The episode is finished: clear it and start over.
            self.current_episode = []
            self.env.reset()
            self.phase = 'choosing_actions'

        else:
            raise ValueError(f'Phase {self.phase} not recognized.')

        return print_string