Code example #1
    def update_epsilon(self, epsilon):
        self.epsilon = epsilon
        self.policy = epsilon_greedy_pi_from_q_table(self.env, self.q_table, self.epsilon)

        # Display plots
        self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
        plot_all_states(self.env, self.ax[0])
        plot_agent(self.s, self.ax[0])
        plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
        plot_policy(self.env, self.policy, self.ax[2])
        plt.show()
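The helper epsilon_greedy_pi_from_q_table is called here (and in every later example) but is not part of these excerpts. A minimal sketch of what such a helper could look like, assuming the Q-table has shape (len(env.states), 4) and the policy is stored as an array of per-state action probabilities (an assumption, not the original implementation):

    import numpy as np

    def epsilon_greedy_pi_from_q_table(env, q_table, epsilon):
        # Hypothetical sketch: every action gets probability epsilon / 4, and the
        # action with the highest Q-value in each state gets the remaining 1 - epsilon.
        n_states, n_actions = q_table.shape
        policy = np.full((n_states, n_actions), epsilon / n_actions)
        greedy_actions = q_table.argmax(axis=1)
        policy[np.arange(n_states), greedy_actions] += 1.0 - epsilon
        return policy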
Code example #2
    def finish_episode(self):
        print_string = ''
        current_i = self.i
        while self.i == current_i:
            print_string += f'{self.next_step(fast_execution=True)}'

        self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
        plot_all_states(self.env, self.ax[0])
        plot_agent(self.s, self.ax[0])
        plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
        plot_policy(self.env, self.policy, self.ax[2])
        plt.show()
        print(print_string)
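This finish_episode variant fast-forwards next_step(fast_execution=True) until the episode counter self.i changes (it is incremented when an episode ends, as in code example #5 below), collects the text output, and only then redraws the three panels. A hypothetical driver, where the class name and constructor arguments are assumptions not shown in these excerpts:

    # Hypothetical usage sketch; SarsaDemo and its constructor are assumptions.
    demo = SarsaDemo(env, alpha=0.1, epsilon=0.1)
    for _ in range(4):
        demo.next_step()      # one visualization phase per call, with plots
    demo.finish_episode()     # silently run to the end of the episode, then plot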
Code example #3
    def finish_episode(self):
        print_string = ''
        while self.phase != 'showing_new_q_value_terminal':
            print_string += f'{self.next_step(fast_execution=True)}'
        print_string += f'{self.next_step(fast_execution=True)}'

        self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
        plot_all_states(self.env, self.ax[0])
        plot_agent(self.s, self.ax[0])
        plot_q_table(self.env,
                     self.q_table,
                     self.ax[1],
                     vmin=self.vmin,
                     vmax=self.vmax)
        plot_policy(self.env, self.greedy_policy, self.ax[2])
        plt.show()
        print(print_string)
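The mapping a_index_to_symbol used in the printed output of every next_step variant is also not included in these excerpts. A plausible stand-in, assuming four movement actions rendered as arrows (the actual action ordering may differ):

    # Assumed mapping from action index to a printable arrow symbol.
    # The ordering expected by env.step is not shown in these excerpts.
    a_index_to_symbol = {0: '↑', 1: '→', 2: '↓', 3: '←'}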
Code example #4
    def next_step(self, fast_execution=False):
        print_string = ''
        if self.phase == 'choosing_actions':
            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_env_agent_and_policy_at_state(self.env, self.s,
                                                   self.exploratory_policy,
                                                   self.ax[0])
                plot_q_table(self.env,
                             self.q_table,
                             self.ax[1],
                             vmin=self.vmin,
                             vmax=self.vmax)
                plot_policy(self.env, self.greedy_policy, self.ax[2])
                plt.show()
                print('Sampling action...\n')

            self.phase = 'showing_sampled_action'

        elif self.phase == 'showing_sampled_action':
            self.sampled_a = np.random.choice(
                4, p=self.exploratory_policy[self.s_idx])

            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_env_agent_and_chosen_action(self.env, self.s,
                                                 self.sampled_a, self.ax[0])
                plot_q_table(self.env,
                             self.q_table,
                             self.ax[1],
                             vmin=self.vmin,
                             vmax=self.vmax)
                plot_policy(self.env, self.greedy_policy, self.ax[2])
                plt.show()
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')

            self.phase = 'carrying_out_action'

        elif self.phase == 'carrying_out_action':
            old_s = self.s
            s_prime, r, t = self.env.step(self.sampled_a)
            self.current_transition = [old_s, self.sampled_a, r, s_prime, t]

            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(old_s, self.ax[0], alpha=0.3)
                plot_agent(self.s, self.ax[0])
                plot_q_table(self.env,
                             self.q_table,
                             self.ax[1],
                             vmin=self.vmin,
                             vmax=self.vmax)
                plot_policy(self.env, self.greedy_policy, self.ax[2])
                plt.show()
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')

            if t:
                if not fast_execution:
                    print('A terminal state has been reached: end of episode.')
                self.phase = 'update_q_value_terminal'
            else:
                self.phase = 'computing_td_target'

        elif self.phase == 'computing_td_target':
            # Compute TD target
            s, a, r, s_prime, _ = self.current_transition
            s_prime_idx = self.env.states.index(s_prime)
            # Greedy action at s' (argmax over Q(s', ·), ties broken at random)
            greedy_a = np.random.choice(4, p=self.greedy_policy[s_prime_idx])
            td_target = r + 0.9 * self.q_table[s_prime_idx, greedy_a]

            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(self.s, self.ax[0])
                plot_agent(s, self.ax[0], alpha=0.3)
                plot_q_table(self.env,
                             self.q_table,
                             self.ax[1],
                             vmin=self.vmin,
                             vmax=self.vmax)
                plot_policy(self.env, self.greedy_policy, self.ax[2])
                plt.show()

            if not fast_execution:
                print(
                    f'TD target = R + 𝛾 * Q({s_prime}, {a_index_to_symbol[greedy_a]}) = '
                    f'{r:.2f} + 0.9 * {self.q_table[s_prime_idx, greedy_a]:.2f} = '
                    f'{td_target:.2f}')

            self.phase = 'computing_new_q_value'

        elif self.phase == 'computing_new_q_value':
            # Compute TD target
            s, a, r, s_prime, _ = self.current_transition
            s_idx = self.env.states.index(s)
            s_prime_idx = self.env.states.index(s_prime)
            greedy_a = np.random.choice(4, p=self.greedy_policy[s_prime_idx])
            td_target = r + 0.9 * self.q_table[s_prime_idx, greedy_a]

            # Compute new value
            new_value = self.alpha * td_target + (
                1 - self.alpha) * self.q_table[s_idx, a]

            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(self.s, self.ax[0])
                plot_agent(s, self.ax[0], alpha=0.3)
                plot_q_table(self.env,
                             self.q_table,
                             self.ax[1],
                             vmin=self.vmin,
                             vmax=self.vmax)
                plot_policy(self.env, self.greedy_policy, self.ax[2])
                plt.show()

            if not fast_execution:
                print(f'TD target = {td_target:.2f}')
                print(
                    f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * '
                    f'({td_target:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}'
                )

            self.phase = 'showing_new_q_value'

        elif self.phase == 'showing_new_q_value':

            # Compute TD target
            s, a, r, s_prime, _ = self.current_transition
            s_idx = self.env.states.index(s)
            s_prime_idx = self.env.states.index(s_prime)
            greedy_a = np.random.choice(4, p=self.greedy_policy[s_prime_idx])
            td_target = r + 0.9 * self.q_table[s_prime_idx, greedy_a]

            # Compute new value and update q-table
            new_value = self.alpha * td_target + (
                1 - self.alpha) * self.q_table[s_idx, a]
            self.q_table[s_idx, a] = new_value
            self.exploratory_policy = epsilon_greedy_pi_from_q_table(
                self.env, self.q_table, self.epsilon)
            self.greedy_policy = epsilon_greedy_pi_from_q_table(
                self.env, self.q_table, 0)

            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(self.s, self.ax[0])
                plot_agent(s, self.ax[0], alpha=0.3)
                plot_q_table(self.env,
                             self.q_table,
                             self.ax[1],
                             vmin=self.vmin,
                             vmax=self.vmax)
                plot_policy(self.env, self.greedy_policy, self.ax[2])
                plt.show()

                print(f'TD target = {td_target:.2f}')
                print(f'Q({s}, {a_index_to_symbol[a]}) = {new_value:.2f}')

            self.phase = 'choosing_actions'

        elif self.phase == 'update_q_value_terminal':
            s, a, r, s_prime, t = self.current_transition
            s_idx = self.env.states.index(s)
            new_value = self.alpha * r + (1 - self.alpha) * self.q_table[s_idx, a]

            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(self.s, self.ax[0])
                plot_agent(s, self.ax[0], alpha=0.3)
                plot_q_table(self.env,
                             self.q_table,
                             self.ax[1],
                             vmin=self.vmin,
                             vmax=self.vmax)
                plot_policy(self.env, self.greedy_policy, self.ax[2])
                plt.show()

                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
                print('A terminal state has been reached: end of episode.\n')
                print(f'TD target = R + 𝛾 * 0 = {r:.2f}')
                print(
                    f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * '
                    f'({r:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}'
                )

            self.q_table[s_idx, a] = new_value
            self.exploratory_policy = epsilon_greedy_pi_from_q_table(
                self.env, self.q_table, self.epsilon)
            self.greedy_policy = epsilon_greedy_pi_from_q_table(
                self.env, self.q_table, 0)
            self.phase = 'showing_new_q_value_terminal'

        elif self.phase == 'showing_new_q_value_terminal':
            if not fast_execution:
                s, a, r, s_prime, t = self.current_transition
                s_idx = self.env.states.index(s)
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(self.s, self.ax[0])
                plot_agent(s, self.ax[0], alpha=0.3)
                plot_q_table(self.env,
                             self.q_table,
                             self.ax[1],
                             vmin=self.vmin,
                             vmax=self.vmax)
                plot_policy(self.env, self.greedy_policy, self.ax[2])
                plt.show()

                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
                print('A terminal state has been reached: end of episode.\n')
                print(f'TD target = R + 𝛾 * 0 = {r:.2f}')
                print(
                    f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f}'
                )

            self.current_transition = []
            self.env.reset()
            self.phase = 'choosing_actions'

        else:
            raise ValueError(f'Phase {self.phase} not recognized.')
        return print_string
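Stripped of the plotting phases, code example #4 performs a one-step Q-learning update with discount factor 0.9: the TD target bootstraps from a greedy action at s' (sampling from the greedy policy is, up to tie-breaking, a max over Q(s', ·)), and terminal transitions use R + 𝛾 · 0. A condensed sketch of that update, under the same env.states / q_table layout used above:

    def q_learning_update(q_table, env, transition, alpha, gamma=0.9):
        # Condensed sketch of the update spread across the phases above.
        s, a, r, s_prime, terminal = transition
        s_idx = env.states.index(s)
        if terminal:
            td_target = r                     # no bootstrap at a terminal state
        else:
            s_prime_idx = env.states.index(s_prime)
            td_target = r + gamma * q_table[s_prime_idx].max()
        # Same as Q(s, a) + alpha * (td_target - Q(s, a)), as in the printed output.
        q_table[s_idx, a] = alpha * td_target + (1 - alpha) * q_table[s_idx, a]
        return q_table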
Code example #5
    def next_step(self, fast_execution=False):
        print_string = ''
        if self.phase == 'initial_sampling_action':
            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_env_agent_and_policy_at_state(self.env, self.s, self.policy, self.ax[0])
                plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print('Sampling action...\n')

            self.phase = 'initial_showing_sampled_action'

        elif self.phase == 'initial_showing_sampled_action':
            self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])

            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, self.ax[0])
                plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')

            self.phase = 'carrying_out_action'

        elif self.phase == 'carrying_out_action':
            old_s = self.s
            s_prime, r, t = self.env.step(self.sampled_a)
            self.current_transition = [old_s, self.sampled_a, r, s_prime, t]

            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(old_s, self.ax[0], alpha=0.3)
                plot_agent(self.s, self.ax[0])
                plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')

            if not t:
                self.phase = 'sampling_action'
            else:
                self.phase = 'compute_q_for_terminal_s_prime'

        elif self.phase == 'sampling_action':
            if not fast_execution:
                old_s = self.current_transition[0]
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_env_agent_and_policy_at_state(self.env, self.s, self.policy, self.ax[0])
                plot_agent(old_s, self.ax[0], alpha=0.3)
                plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print('Sampling action...\n')

            self.phase = 'showing_sampled_action'

        elif self.phase == 'showing_sampled_action':
            self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])
            self.current_transition.append(self.sampled_a)

            if not fast_execution:
                old_s = self.current_transition[0]
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, self.ax[0])
                plot_agent(old_s, self.ax[0], alpha=0.3)
                plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')

            self.phase = 'showing_transition_done'

        elif self.phase == 'showing_transition_done':
            if not fast_execution:
                old_s = self.current_transition[0]
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, self.ax[0])
                plot_agent(old_s, self.ax[0], alpha=0.3)
                plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()

                s, a, r, s_prime, t, a_prime = self.current_transition
                print(f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, {a_index_to_symbol[a_prime]}>\n')

            self.phase = 'showing_td_target'

        elif self.phase == 'showing_td_target':
            if not fast_execution:
                s, a, r, s_prime, t, a_prime = self.current_transition
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, self.ax[0])
                plot_agent(s, self.ax[0], alpha=0.3)
                plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()

                print(f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, {a_index_to_symbol[a_prime]}>\n')

                s_prime_idx = self.env.states.index(s_prime)
                td_target = r + 0.9 * self.q_table[s_prime_idx, a_prime]
                print(f'TD target: R + 𝛾 * Q({s_prime}, {a_index_to_symbol[a_prime]}) = '
                      f'{r:.2f} + 0.9 * {self.q_table[s_prime_idx, a_prime]:.2f} = '
                      f'{td_target:.2f}')

            self.phase = 'show_q_value_computation'

        elif self.phase == 'show_q_value_computation':
            s, a, r, s_prime, t, a_prime = self.current_transition
            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, self.ax[0])
                plot_agent(s, self.ax[0], alpha=0.3)
                plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()

            s_idx = self.env.states.index(s)
            s_prime_idx = self.env.states.index(s_prime)
            td_target = r + 0.9 * self.q_table[s_prime_idx, a_prime]
            new_value = self.alpha * td_target + (1 - self.alpha) * self.q_table[s_idx, a]

            if fast_execution:
                msg = f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, {a_index_to_symbol[a_prime]}>\n' \
                      f'TD target: R + 𝛾 * Q({s_prime}, {a_index_to_symbol[a_prime]}) = ' \
                      f'{r:.2f} + 0.9 * {self.q_table[s_prime_idx, a_prime]:.2f} = {td_target:.2f}\n'  \
                      f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * ' \
                      f'({td_target:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}\n\n'
                print_string += msg
            else:
                print(f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, {a_index_to_symbol[a_prime]}>\n')
                print(f'TD target: {td_target:.2f}')
                print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * '
                      f'({td_target:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}')

            self.q_table[s_idx, a] = new_value
            self.policy = epsilon_greedy_pi_from_q_table(self.env, self.q_table, self.epsilon)

            self.phase = 'showing_updated_q_value'

        elif self.phase == 'showing_updated_q_value':
            if not fast_execution:
                s, a, r, s_prime, t, a_prime = self.current_transition
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_env_agent_and_chosen_action(self.env, self.s, self.sampled_a, self.ax[0])
                plot_agent(s, self.ax[0], alpha=0.3)
                plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()

                print(f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, {a_index_to_symbol[a_prime]}>\n')

                s_prime_idx = self.env.states.index(s_prime)
                td_target = r + 0.9 * self.q_table[s_prime_idx, a_prime]
                print(f'TD target: {td_target:.2f}')

                s_idx = self.env.states.index(s)
                print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f}')

            self.phase = 'carrying_out_action'

        elif self.phase == 'compute_q_for_terminal_s_prime':
            s, a, r, s_prime, t = self.current_transition
            s_idx = self.env.states.index(s)
            new_value = self.alpha * r + (1 - self.alpha) * self.q_table[s_idx, a]

            if fast_execution:
                print_string += f'Transition to terminal state: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, ∅>\n'
                print_string += f'TD target = R + 𝛾 * 0 = {r:.2f}\n'
                print_string += f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * ' \
                                f'({r:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}'

            else:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(s, self.ax[0], alpha=0.3)
                plot_agent(s_prime, self.ax[0])
                plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print('Terminal state reached.\n')
                print(f'TD target = R + 𝛾 * 0 = {r:.2f}')
                print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * '
                      f'({r:.2f} - {self.q_table[s_idx, a]:.2f}) = {new_value:.2f}')

            self.q_table[s_idx, a] = new_value
            self.policy = epsilon_greedy_pi_from_q_table(self.env, self.q_table, self.epsilon)

            self.phase = 'showing_updated_q_value_terminal_s_prime'

        elif self.phase == 'showing_updated_q_value_terminal_s_prime':
            self.env.reset()
            self.i += 1

            if not fast_execution:
                s, a, r, s_prime, t = self.current_transition
                s_idx = self.env.states.index(s)

                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(s, self.ax[0], alpha=0.3)
                plot_agent(s_prime, self.ax[0])
                plot_q_table(self.env, self.q_table, self.ax[1], vmin=self.vmin, vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()

                print('Terminal state reached.\n')
                print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f}\n')
                print('Resetting environment.')

            self.phase = 'initial_sampling_action'

        else:
            raise ValueError(f'Phase {self.phase} not recognized.')
        return print_string
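Code example #5 differs from #4 only in the bootstrap term: the TD target uses Q(s', a') for the action a' that was actually sampled at s', which is the SARSA rule. Condensed into a single function (a sketch under the same layout assumptions as above):

    def sarsa_update(q_table, env, s, a, r, s_prime, a_prime, alpha,
                     gamma=0.9, terminal=False):
        # Condensed sketch of the SARSA update from code example #5.
        s_idx = env.states.index(s)
        if terminal:
            td_target = r                     # TD target = R + gamma * 0
        else:
            s_prime_idx = env.states.index(s_prime)
            td_target = r + gamma * q_table[s_prime_idx, a_prime]
        q_table[s_idx, a] = alpha * td_target + (1 - alpha) * q_table[s_idx, a]
        return q_table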
Code example #6
    def next_step(self, fast_execution=False):
        print_string = ''
        if self.phase == 'choosing_actions':
            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_env_agent_and_policy_at_state(self.env, self.s,
                                                   self.policy, self.ax[0])
                plot_q_table(self.env,
                             self.q_table,
                             self.ax[1],
                             vmin=self.vmin,
                             vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print('Sampling action...\n')
                print(self.get_current_episode_as_string())

            self.phase = 'showing_sampled_action'

        elif self.phase == 'showing_sampled_action':
            self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])

            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_env_agent_and_chosen_action(self.env, self.s,
                                                 self.sampled_a, self.ax[0])
                plot_q_table(self.env,
                             self.q_table,
                             self.ax[1],
                             vmin=self.vmin,
                             vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
                print(self.get_current_episode_as_string())

            self.phase = 'carrying_out_action'

        elif self.phase == 'carrying_out_action':
            old_s = self.s
            s_prime, r, t = self.env.step(self.sampled_a)
            self.current_episode.append([old_s, self.sampled_a, r, s_prime])

            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(old_s, self.ax[0], alpha=0.3)
                plot_agent(self.s, self.ax[0])
                plot_q_table(self.env,
                             self.q_table,
                             self.ax[1],
                             vmin=self.vmin,
                             vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print(self.get_current_episode_as_string())

            if t:
                if not fast_execution:
                    print('A terminal state has been reached: end of episode.')
                self.phase = 'computing_returns'
            else:
                self.phase = 'choosing_actions'

        elif self.phase == 'computing_returns':
            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(self.s, self.ax[0])
                plot_q_table(self.env,
                             self.q_table,
                             self.ax[1],
                             vmin=self.vmin,
                             vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print(self.get_current_episode_as_string())
                print('Computing returns:')

            self.current_episode_returns = [0] * len(self.current_episode)

            for i in reversed(range(len(self.current_episode))):
                transition = self.current_episode[i]
                if i == len(self.current_episode) - 1:
                    if not fast_execution:
                        print(
                            f'G_{i + 1} = {transition[2]:.1f},\t\t\t\tat {transition[0]}'
                        )
                    self.current_episode_returns[i] = transition[2]
                else:
                    self.current_episode_returns[i] = (
                        transition[2] + 0.9 * self.current_episode_returns[i + 1])
                    if not fast_execution:
                        print(
                            f'G_{i+1} = {transition[2]:.1f} + 0.9 * {self.current_episode_returns[i+1]:.2f}',
                            end='')
                        print(
                            f' = {self.current_episode_returns[i]:.2f},\tat {transition[0]}'
                        )

            self.phase = 'presenting_computed_returns'

        elif self.phase == 'presenting_computed_returns':
            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(self.s, self.ax[0])
                plot_q_table(self.env,
                             self.q_table,
                             self.ax[1],
                             vmin=self.vmin,
                             vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print(self.get_current_episode_as_string())

                print('Computing returns:')
                for i, (transition, g) in reversed(
                        list(
                            enumerate(
                                zip(self.current_episode,
                                    self.current_episode_returns)))):
                    print(f'G_{i+1} = {g:.2f},\tat {transition[0]}')

            self.phase = 'updating_values'

        elif self.phase == 'updating_values':
            if fast_execution:
                print_string += f'{self.get_current_episode_as_string()}'
                print_string += '\nComputing returns:\n'
            else:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(self.s, self.ax[0])
                plot_q_table(self.env,
                             self.q_table,
                             self.ax[1],
                             vmin=self.vmin,
                             vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print(self.get_current_episode_as_string())
                print('Computing returns:')

            for i, (transition, g) in reversed(
                    list(
                        enumerate(
                            zip(self.current_episode,
                                self.current_episode_returns)))):
                msg = f'G_{i+1} = {g:.2f},\tat {transition[0]}'
                if fast_execution:
                    print_string += f'{msg}\n'
                else:
                    print(msg)

            if fast_execution:
                print_string += '\nUpdating values:\n'
            else:
                print('\nUpdating values:')

            for transition, g in zip(reversed(self.current_episode),
                                     reversed(self.current_episode_returns)):
                s_idx = self.env.states.index(transition[0])
                a = transition[1]

                new_value = self.alpha * g + (
                    1 - self.alpha) * self.q_table[s_idx, a]
                msg = f'Q({transition[0]}, {a_index_to_symbol[a]}) = ' \
                      f'{self.q_table[s_idx, a]:.2f} + {self.alpha:.2f} * ({g:.2f} - {self.q_table[s_idx, a]:.2f}) = ' \
                      f'{new_value:.2f}'

                if fast_execution:
                    print_string += f'{msg}\n'
                else:
                    print(msg)
                self.q_table[s_idx, a] = new_value

            self.policy = epsilon_greedy_pi_from_q_table(
                self.env, self.q_table, self.epsilon)
            self.phase = 'show_updated_value_function'

        elif self.phase == 'show_updated_value_function':
            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(self.s, self.ax[0])
                plot_q_table(self.env,
                             self.q_table,
                             self.ax[1],
                             vmin=self.vmin,
                             vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print(self.get_current_episode_as_string())

                print('Computing returns:')
                for i, (transition, g) in reversed(
                        list(
                            enumerate(
                                zip(self.current_episode,
                                    self.current_episode_returns)))):
                    print(f'G_{i+1} = {g:.2f},\tat {transition[0]}')

                print('\nUpdating values:')
                for transition, g in zip(
                        reversed(self.current_episode),
                        reversed(self.current_episode_returns)):
                    s_idx = self.env.states.index(transition[0])
                    print(
                        f'Q({transition[0]}, {a_index_to_symbol[transition[1]]}) = '
                        f'{self.q_table[s_idx, transition[1]]:.2f}')

            self.current_episode = []
            self.env.reset()
            self.phase = 'choosing_actions'

        else:
            raise ValueError(f'Phase {self.phase} not recognized.')
        return print_string
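Code example #6 is the Monte Carlo variant: it stores the whole episode, computes discounted returns backwards (G_T = R_T, then G_t = R_t + 0.9 * G_{t+1}), and moves every visited Q(s, a) toward its return with step size alpha. The same computation, condensed into a sketch under the same layout assumptions:

    def monte_carlo_update(q_table, env, episode, alpha, gamma=0.9):
        # Condensed sketch of the 'computing_returns' and 'updating_values' phases.
        # episode is a list of [s, a, r, s_prime] transitions, as built above.
        g = 0.0
        for s, a, r, s_prime in reversed(episode):
            g = r + gamma * g                 # discounted return from this step on
            s_idx = env.states.index(s)
            q_table[s_idx, a] = alpha * g + (1 - alpha) * q_table[s_idx, a]
        return q_table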