Beispiel #1
0
 def policy_improve_on_click(obj):
     """Button callback: run one policy-improvement step and redraw both panels.

     ``obj`` is whatever the click event passes in; it is not used.  The
     callback relies on ``env``, ``policy``, ``value_table``, ``output``,
     ``vmin`` and ``vmax`` being available from the surrounding scope.
     """
     # Close the current figure before building the replacement.
     plt.close()
     _, (value_ax, policy_ax) = plt.subplots(ncols=2, figsize=(14, 6))
     output.clear_output(True)
     # The return value is ignored, so the step presumably updates ``policy``
     # in place -- confirm against policy_improvement.
     policy_improvement(env, policy, value_table)
     with output:
         plot_value_table(env, value_table, ax=value_ax, vmin=vmin, vmax=vmax)
         plot_policy(env, policy, ax=policy_ax)
         plt.show()
Beispiel #2
0
    def update_epsilon(self, epsilon):
        """Store a new epsilon, rebuild the epsilon-greedy policy and redraw.

        Args:
            epsilon: exploration rate handed to
                ``epsilon_greedy_pi_from_q_table``.
        """
        self.epsilon = epsilon
        self.policy = epsilon_greedy_pi_from_q_table(
            self.env, self.q_table, self.epsilon)

        # Three panels, left to right: environment with agent, Q-table, policy.
        self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
        env_ax, q_ax, policy_ax = self.ax
        plot_all_states(self.env, env_ax)
        plot_agent(self.s, env_ax)
        plot_q_table(self.env, self.q_table, q_ax,
                     vmin=self.vmin, vmax=self.vmax)
        plot_policy(self.env, self.policy, policy_ax)
        plt.show()
Beispiel #3
0
    def finish_episode(self):
        """Fast-forward until ``self.i`` changes, then plot and print the log.

        Steps are executed with ``fast_execution=True``; whatever text each
        step returns is collected and printed after the final figure.
        """
        starting_i = self.i
        messages = []
        while self.i == starting_i:
            messages.append(f'{self.next_step(fast_execution=True)}')

        # Three panels, left to right: environment with agent, Q-table, policy.
        self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
        env_ax, q_ax, policy_ax = self.ax
        plot_all_states(self.env, env_ax)
        plot_agent(self.s, env_ax)
        plot_q_table(self.env, self.q_table, q_ax,
                     vmin=self.vmin, vmax=self.vmax)
        plot_policy(self.env, self.policy, policy_ax)
        plt.show()
        print(''.join(messages))
Beispiel #4
0
    def finish_iteration(self):
        """Fast-forward until ``self.i`` changes, then plot the resulting state.

        Steps run with ``fast_execution=True`` and their return values are
        discarded.
        """
        start_i = self.i
        while self.i == start_i:
            self.next_step(fast_execution=True)

        # Panels, left to right: state PMF, value table, policy.
        self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
        pmf_ax, value_ax, policy_ax = self.ax
        plot_s_pmf(self.env, self.s, self.a, ax=pmf_ax)
        # Colour range is fixed to [-8, 0] here.
        plot_value_table(self.env, self.value_table, ax=value_ax,
                         vmin=-8, vmax=0)
        plot_policy(self.env, self.policy, ax=policy_ax)
        plt.show()
Beispiel #5
0
    def finish_episode(self):
        """Fast-forward through the terminal phase, then plot and print the log.

        Steps run with ``fast_execution=True`` until ``self.phase`` equals
        ``'showing_new_q_value_terminal'``; that phase is then executed as
        well.  The text returned by every step is printed after the figure.
        """
        messages = []
        while self.phase != 'showing_new_q_value_terminal':
            messages.append(f'{self.next_step(fast_execution=True)}')
        # Execute the terminal phase itself too.
        messages.append(f'{self.next_step(fast_execution=True)}')

        # Three panels, left to right: environment with agent, Q-table,
        # greedy policy.
        self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
        env_ax, q_ax, policy_ax = self.ax
        plot_all_states(self.env, env_ax)
        plot_agent(self.s, env_ax)
        plot_q_table(self.env, self.q_table, q_ax,
                     vmin=self.vmin, vmax=self.vmax)
        plot_policy(self.env, self.greedy_policy, policy_ax)
        plt.show()
        print(''.join(messages))
Beispiel #6
0
def policy_iteration_wrapper():
    """Show an interactive policy-iteration demo driven by two buttons.

    One button calls ``policy_evaluation_one_step`` and the other calls
    ``policy_improvement``.  After each click the value table and the policy
    are redrawn inside an output widget.
    """
    env.reset()
    value_table = np.zeros(9)
    # Initial policy: one row of four action probabilities per state.
    policy = np.array([
        [1, 0, 0, 0],
        [0.5, 0.5, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0],
    ])

    btn_eval = widgets.Button(description='Policy eval iteration')
    btn_improve = widgets.Button(description='Policy improvement')

    display(btn_eval)
    display(btn_improve)

    output = widgets.Output()

    # Fixed colour range for the value-table panel.
    vmin, vmax = -8, 0

    def _draw_panels(axes):
        # Left: value table, right: policy.
        plot_value_table(env, value_table, ax=axes[0], vmin=vmin, vmax=vmax)
        plot_policy(env, policy, ax=axes[1])
        plt.show()

    def _update_and_redraw(update_step):
        # Replace the current figure, apply the update, then draw the result
        # into the (cleared) output widget.
        plt.close()
        _, axes = plt.subplots(ncols=2, figsize=(14, 6))
        output.clear_output(True)
        update_step(env, policy, value_table)
        with output:
            _draw_panels(axes)

    # Initial rendering before any button has been pressed.
    with output:
        _, axes = plt.subplots(ncols=2, figsize=(14, 6))
        _draw_panels(axes)

    def policy_eval_on_click(obj):
        _update_and_redraw(policy_evaluation_one_step)

    def policy_improve_on_click(obj):
        _update_and_redraw(policy_improvement)

    btn_eval.on_click(policy_eval_on_click)
    btn_improve.on_click(policy_improve_on_click)
    display(output)
Beispiel #7
0
    def next_step(self, fast_execution=False):
        """Advance the TD value-learning walkthrough by one phase.

        The walkthrough is a state machine keyed on ``self.phase``.  Each call
        handles the current phase and stores the next one:
        ``choosing_actions`` -> ``showing_sampled_action`` ->
        ``carrying_out_action`` -> ``showing_td_target`` ->
        ``updating_value_table`` -> ``showing_new_value_table`` ->
        ``choosing_actions``.

        Args:
            fast_execution: when true, nothing is plotted or printed; the text
                describing the value update is returned instead.

        Returns:
            str: the update description collected in fast mode (only the
            ``updating_value_table`` phase contributes), otherwise ``''``.

        Raises:
            ValueError: if ``self.phase`` is not one of the phases above.
        """
        verbose = not fast_execution
        log = ''

        def _render(draw_env_panel):
            # New figure: [environment | value table | policy].
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            draw_env_panel(self.ax[0])
            plot_value_table(self.env, self.value_table, self.ax[1],
                             vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()

        def _draw_agent(ax):
            plot_all_states(self.env, ax)
            plot_agent(self.s, ax)

        def _describe(step):
            s, a, r, s_next = step[:4]
            return f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_next}>'

        phase = self.phase

        if phase == 'choosing_actions':
            if verbose:
                _render(lambda ax: plot_env_agent_and_policy_at_state(
                    self.env, self.s, self.policy, ax))
                print('Sampling action...\n')
            self.phase = 'showing_sampled_action'

        elif phase == 'showing_sampled_action':
            # Draw one of the four actions from the policy row of the
            # current state.
            self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])
            if verbose:
                _render(lambda ax: plot_env_agent_and_chosen_action(
                    self.env, self.s, self.sampled_a, ax))
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
            self.phase = 'carrying_out_action'

        elif phase == 'carrying_out_action':
            prev_s = self.s
            next_s, reward, done = self.env.step(self.sampled_a)
            step = [prev_s, self.sampled_a, reward, next_s, done]
            self.current_episode.append(step)

            if verbose:
                def _draw_move(ax):
                    # Faded marker for the old position, solid for the new.
                    plot_all_states(self.env, ax)
                    plot_agent(prev_s, ax, alpha=0.3)
                    plot_agent(self.s, ax)

                _render(_draw_move)
                print(_describe(step))
            self.phase = 'showing_td_target'

        elif phase == 'showing_td_target':
            if verbose:
                _render(_draw_agent)

            step = self.current_episode[-1]
            prev_s, action, reward, next_s, _ = step
            next_idx = self.env.states.index(next_s)
            # Bootstrapped target with a fixed discount factor of 0.9.
            self.td_target = reward + 0.9 * self.value_table[next_idx]
            if verbose:
                print(_describe(step) + '\n')
                print('Updating value table:')
                print(
                    f'TD target = R + 𝛾 * V({next_s}) = {reward:.1f} + 0.9*{self.value_table[next_idx]:.2f} = '
                    f'{self.td_target:.2f}')
            self.phase = 'updating_value_table'

        elif phase == 'updating_value_table':
            if verbose:
                _render(_draw_agent)

            step = self.current_episode[-1]
            prev_s, action, reward, next_s, _ = step
            prev_idx = self.env.states.index(prev_s)
            old_value = self.value_table[prev_idx]
            # Move the old estimate towards the TD target by step size alpha.
            new_value = self.alpha * self.td_target + (
                1 - self.alpha) * old_value

            transition_line = _describe(step)
            target_line = f'TD target = {self.td_target:.2f}'
            update_line = (
                f'V({prev_s}) ← {old_value:.2f} + {self.alpha:.2f} * '
                f'({self.td_target:.2f} - {old_value:.2f}) = {new_value:.2f}')

            if verbose:
                print(transition_line + '\n')
                print('Updating value table:')
                print(target_line)
                print(update_line + '\n')
            else:
                log += transition_line + '\n'
                log += target_line + '\n'
                log += update_line + '\n\n'

            self.value_table[prev_idx] = new_value
            self.phase = 'showing_new_value_table'

        elif phase == 'showing_new_value_table':
            if verbose:
                _render(_draw_agent)

            step = self.current_episode[-1]
            prev_s, action, reward, next_s, done = step
            prev_idx = self.env.states.index(prev_s)

            if verbose:
                print(_describe(step) + '\n')
                print('Updating value table:')
                print(f'TD target = {self.td_target:.2f}')
                print(f'V({prev_s}) = {self.value_table[prev_idx]:.2f}')

            if done:
                # Episode over: reset and bump the counter ``self.i``.
                if verbose:
                    print('\nTerminal state, reseting environment.')
                self.env.reset()
                self.current_episode = []
                self.i += 1
            self.phase = 'choosing_actions'

        else:
            raise ValueError(f'Phase {self.phase} not recognized.')
        return log
Beispiel #8
0
    def next_step(self, fast_execution=False):
        """Advance the Q-learning walkthrough by one phase.

        The walkthrough is a state machine keyed on ``self.phase``.  Each call
        handles the current phase and stores the next one:
        ``choosing_actions`` -> ``showing_sampled_action`` ->
        ``carrying_out_action``, then either
        ``computing_td_target`` -> ``computing_new_q_value`` ->
        ``showing_new_q_value`` -> ``choosing_actions`` for a non-terminal
        transition, or ``update_q_value_terminal`` ->
        ``showing_new_q_value_terminal`` -> ``choosing_actions`` for a
        terminal one.

        Actions are sampled from ``self.exploratory_policy``; the TD target
        bootstraps with an action drawn from ``self.greedy_policy``.

        Args:
            fast_execution: when true, nothing is plotted or printed.

        Returns:
            str: always the empty string; no phase adds text to it.

        Raises:
            ValueError: if ``self.phase`` is not one of the phases above.
        """
        verbose = not fast_execution
        log = ''

        def _render(draw_env_panel):
            # New figure: [environment | Q-table | greedy policy].
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            draw_env_panel(self.ax[0])
            plot_q_table(self.env, self.q_table, self.ax[1],
                         vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.greedy_policy, self.ax[2])
            plt.show()

        def _after_move(prev_s):
            # Solid marker at the current state, faded one at the previous.
            def _draw(ax):
                plot_all_states(self.env, ax)
                plot_agent(self.s, ax)
                plot_agent(prev_s, ax, alpha=0.3)
            return _draw

        def _bootstrap(reward, next_s):
            # TD target with a fixed discount factor of 0.9.  The greedy
            # action is sampled on every call, so each call consumes one draw
            # from NumPy's global random state.
            next_idx = self.env.states.index(next_s)
            greedy_a = np.random.choice(4, p=self.greedy_policy[self.s_idx])
            target = reward + 0.9 * self.q_table[next_idx, greedy_a]
            return next_idx, greedy_a, target

        def _refresh_policies():
            self.exploratory_policy = epsilon_greedy_pi_from_q_table(
                self.env, self.q_table, self.epsilon)
            self.greedy_policy = epsilon_greedy_pi_from_q_table(
                self.env, self.q_table, 0)

        phase = self.phase

        if phase == 'choosing_actions':
            if verbose:
                _render(lambda ax: plot_env_agent_and_policy_at_state(
                    self.env, self.s, self.exploratory_policy, ax))
                print('Sampling action...\n')
            self.phase = 'showing_sampled_action'

        elif phase == 'showing_sampled_action':
            self.sampled_a = np.random.choice(
                4, p=self.exploratory_policy[self.s_idx])
            if verbose:
                _render(lambda ax: plot_env_agent_and_chosen_action(
                    self.env, self.s, self.sampled_a, ax))
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
            self.phase = 'carrying_out_action'

        elif phase == 'carrying_out_action':
            prev_s = self.s
            next_s, reward, done = self.env.step(self.sampled_a)
            self.current_transition = [
                prev_s, self.sampled_a, reward, next_s, done
            ]

            if verbose:
                def _draw_move(ax):
                    plot_all_states(self.env, ax)
                    plot_agent(prev_s, ax, alpha=0.3)
                    plot_agent(self.s, ax)

                _render(_draw_move)
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')

            if done:
                if verbose:
                    print('A terminal state has been reached: end of episode.')
                self.phase = 'update_q_value_terminal'
            else:
                self.phase = 'computing_td_target'

        elif phase == 'computing_td_target':
            prev_s, action, reward, next_s, _ = self.current_transition
            next_idx, greedy_a, td_target = _bootstrap(reward, next_s)

            if verbose:
                _render(_after_move(prev_s))
                print(
                    f'TD target = R + 𝛾 * Q({next_s}, {a_index_to_symbol[greedy_a]}) = '
                    f'{reward:.2f} + 0.9 * {self.q_table[next_idx, greedy_a]:.2f} = '
                    f'{td_target:.2f}')
            self.phase = 'computing_new_q_value'

        elif phase == 'computing_new_q_value':
            prev_s, action, reward, next_s, _ = self.current_transition
            prev_idx = self.env.states.index(prev_s)
            _, _, td_target = _bootstrap(reward, next_s)

            # Only computed for display; the table is written in the next
            # phase.
            old_q = self.q_table[prev_idx, action]
            new_value = self.alpha * td_target + (1 - self.alpha) * old_q

            if verbose:
                _render(_after_move(prev_s))
                print(f'TD target = {td_target:.2f}')
                print(
                    f'Q({prev_s}, {a_index_to_symbol[action]}) = {old_q:.2f} + {self.alpha:.2f} * '
                    f'({td_target:.2f} - {old_q:.2f}) = {new_value:.2f}')
            self.phase = 'showing_new_q_value'

        elif phase == 'showing_new_q_value':
            prev_s, action, reward, next_s, _ = self.current_transition
            prev_idx = self.env.states.index(prev_s)
            _, _, td_target = _bootstrap(reward, next_s)

            # Write the new estimate and rebuild both policies from it.
            new_value = self.alpha * td_target + (
                1 - self.alpha) * self.q_table[prev_idx, action]
            self.q_table[prev_idx, action] = new_value
            _refresh_policies()

            if verbose:
                _render(_after_move(prev_s))
                print(f'TD target = {td_target:.2f}')
                print(
                    f'Q({prev_s}, {a_index_to_symbol[action]}) = {new_value:.2f}')
            self.phase = 'choosing_actions'

        elif phase == 'update_q_value_terminal':
            prev_s, action, reward, next_s, done = self.current_transition
            prev_idx = self.env.states.index(prev_s)
            old_q = self.q_table[prev_idx, action]
            # No bootstrap term here: the target is just the reward.
            new_value = self.alpha * reward + (1 - self.alpha) * old_q

            if verbose:
                _render(_after_move(prev_s))
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
                print('A terminal state has been reached: end of episode.\n')
                print(f'TD target = R + 𝛾 * 0 = {reward:.2f}')
                print(
                    f'Q({prev_s}, {a_index_to_symbol[action]}) = {old_q:.2f} + {self.alpha:.2f} * '
                    f'({reward:.2f} - {old_q:.2f}) = {new_value:.2f}')

            self.q_table[prev_idx, action] = new_value
            _refresh_policies()
            self.phase = 'showing_new_q_value_terminal'

        elif phase == 'showing_new_q_value_terminal':
            if verbose:
                prev_s, action, reward, next_s, done = self.current_transition
                prev_idx = self.env.states.index(prev_s)
                _render(_after_move(prev_s))
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
                print('A terminal state has been reached: end of episode.\n')
                print(f'TD target = R + 𝛾 * 0 = {reward:.2f}')
                print(
                    f'Q({prev_s}, {a_index_to_symbol[action]}) = {self.q_table[prev_idx, action]:.2f}'
                )

            # Episode finished: drop the transition and start over.
            self.current_transition = []
            self.env.reset()
            self.phase = 'choosing_actions'

        else:
            raise ValueError(f'Phase {self.phase} not recognized.')
        return log
        # Set policy to greedy w.r.t. Q-values computed above
        new_policy[np.argmax(q)] = 1
        policy[i_s] = new_policy


def epsilon_greedy_pi_from_q_table(env, q_table, epsilon):
    """Build an epsilon-greedy policy table from a Q-table.

    Args:
        env: object exposing ``states``; only ``len(env.states)`` is used.
        q_table: indexable by state index, each row holding the action
            values for that state.
        epsilon: probability mass spread uniformly over the four actions.

    Returns:
        numpy.ndarray: array of shape ``(len(env.states), 4)``.  Every action
        gets ``epsilon / 4``; the arg-max action of each row (first one on
        ties) receives an additional ``1 - epsilon``.
    """
    n_states = len(env.states)
    # Start from the uniform exploration share for every action ...
    policy = np.full((n_states, 4), epsilon / 4, dtype=float)
    # ... then give the remaining mass to the greedy action of each state.
    for row in range(n_states):
        policy[row, np.argmax(q_table[row])] += 1 - epsilon
    return policy


if __name__ == '__main__':
    # Manual smoke test: plot a random value table and a random policy for
    # the 3x3 grid world.
    from lib.grid_world import grid_world_3x3 as env
    from lib.plot_utils import plot_policy, plot_value_table
    import matplotlib.pyplot as plt

    np.random.seed(2)
    env.reset()
    n_states = len(env.states)
    # One random value per state, sized from the environment rather than a
    # hard-coded length.
    value_table = np.random.rand(n_states)
    # One random distribution over the four actions per state; the small
    # concentration parameters are the ones the demo has always used.
    policy = np.random.dirichlet([0.1, 0.1, 0.1, 0.1], size=n_states)

    plot_value_table(env, value_table)
    plot_policy(env, policy)
    plt.show()
Beispiel #10
0
    def next_step(self, fast_execution=False):
        """Advance the SARSA walkthrough by one phase.

        The walkthrough is a state machine keyed on ``self.phase``.  Each call
        handles the current phase and stores the next one.  An episode starts
        with ``initial_sampling_action`` ->
        ``initial_showing_sampled_action`` -> ``carrying_out_action``.  After
        a non-terminal step the chain is ``sampling_action`` ->
        ``showing_sampled_action`` -> ``showing_transition_done`` ->
        ``showing_td_target`` -> ``show_q_value_computation`` ->
        ``showing_updated_q_value`` -> ``carrying_out_action``.  After a
        terminal step it is ``compute_q_for_terminal_s_prime`` ->
        ``showing_updated_q_value_terminal_s_prime`` ->
        ``initial_sampling_action``.

        ``self.current_transition`` holds ``[s, a, r, s_prime, t]`` after a
        step and gains a sixth element, the next action, once that action
        has been sampled.

        Args:
            fast_execution: when true, nothing is plotted or printed; the text
                describing each Q-value update is returned instead.

        Returns:
            str: the update description collected in fast mode (only the two
            Q-update phases contribute), otherwise ``''``.

        Raises:
            ValueError: if ``self.phase`` is not one of the phases above.
        """
        verbose = not fast_execution
        print_string = ''

        def _render(draw_env_panel):
            # New figure: [environment | Q-table | policy].
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            draw_env_panel(self.ax[0])
            plot_q_table(self.env, self.q_table, self.ax[1],
                         vmin=self.vmin, vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()

        def _choice_with_trail(prev_s):
            # Current state with the chosen action, plus a faded marker on
            # the state the agent came from.
            def _draw(ax):
                plot_env_agent_and_chosen_action(self.env, self.s,
                                                 self.sampled_a, ax)
                plot_agent(prev_s, ax, alpha=0.3)
            return _draw

        def _terminal_move(prev_s, terminal_s):
            def _draw(ax):
                plot_all_states(self.env, ax)
                plot_agent(prev_s, ax, alpha=0.3)
                plot_agent(terminal_s, ax)
            return _draw

        def _describe(s, a, r, s_prime, a_prime):
            return (f'Transition: <{s}, {a_index_to_symbol[a]}, {r:.1f}, '
                    f'{s_prime}, {a_index_to_symbol[a_prime]}>')

        if self.phase == 'initial_sampling_action':
            if verbose:
                _render(lambda ax: plot_env_agent_and_policy_at_state(
                    self.env, self.s, self.policy, ax))
                print('Sampling action...\n')
            self.phase = 'initial_showing_sampled_action'

        elif self.phase == 'initial_showing_sampled_action':
            self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])
            if verbose:
                _render(lambda ax: plot_env_agent_and_chosen_action(
                    self.env, self.s, self.sampled_a, ax))
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
            self.phase = 'carrying_out_action'

        elif self.phase == 'carrying_out_action':
            old_s = self.s
            s_prime, r, t = self.env.step(self.sampled_a)
            self.current_transition = [old_s, self.sampled_a, r, s_prime, t]

            if verbose:
                def _draw_move(ax):
                    plot_all_states(self.env, ax)
                    plot_agent(old_s, ax, alpha=0.3)
                    plot_agent(self.s, ax)

                _render(_draw_move)
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')

            # A terminal transition has no next action to sample.
            if t:
                self.phase = 'compute_q_for_terminal_s_prime'
            else:
                self.phase = 'sampling_action'

        elif self.phase == 'sampling_action':
            if verbose:
                old_s = self.current_transition[0]

                def _draw_policy_with_trail(ax):
                    plot_env_agent_and_policy_at_state(self.env, self.s,
                                                       self.policy, ax)
                    plot_agent(old_s, ax, alpha=0.3)

                _render(_draw_policy_with_trail)
                print('Sampling action...\n')
            self.phase = 'showing_sampled_action'

        elif self.phase == 'showing_sampled_action':
            # The next action becomes the sixth element of the transition and
            # is also the action executed by the next environment step.
            self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])
            self.current_transition.append(self.sampled_a)

            if verbose:
                _render(_choice_with_trail(self.current_transition[0]))
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
            self.phase = 'showing_transition_done'

        elif self.phase == 'showing_transition_done':
            if verbose:
                s, a, r, s_prime, t, a_prime = self.current_transition
                _render(_choice_with_trail(s))
                print(_describe(s, a, r, s_prime, a_prime) + '\n')
            self.phase = 'showing_td_target'

        elif self.phase == 'showing_td_target':
            if verbose:
                s, a, r, s_prime, t, a_prime = self.current_transition
                _render(_choice_with_trail(s))
                print(_describe(s, a, r, s_prime, a_prime) + '\n')

                s_prime_idx = self.env.states.index(s_prime)
                # Fixed discount factor of 0.9.
                td_target = r + 0.9 * self.q_table[s_prime_idx, a_prime]
                print(f'TD target: R + 𝛾 * Q({s_prime}, {a_index_to_symbol[a_prime]}) = '
                      f'{r:.2f} + 0.9 * {self.q_table[s_prime_idx, a_prime]:.2f} = '
                      f'{td_target:.2f}')
            self.phase = 'show_q_value_computation'

        elif self.phase == 'show_q_value_computation':
            s, a, r, s_prime, t, a_prime = self.current_transition
            if verbose:
                # All drawing happens on the freshly created figure.
                _render(_choice_with_trail(s))

            s_idx = self.env.states.index(s)
            s_prime_idx = self.env.states.index(s_prime)
            old_q = self.q_table[s_idx, a]
            td_target = r + 0.9 * self.q_table[s_prime_idx, a_prime]
            new_value = self.alpha * td_target + (1 - self.alpha) * old_q

            transition_line = _describe(s, a, r, s_prime, a_prime)
            update_line = (
                f'Q({s}, {a_index_to_symbol[a]}) = {old_q:.2f} + {self.alpha:.2f} * '
                f'({td_target:.2f} - {old_q:.2f}) = {new_value:.2f}')

            if fast_execution:
                print_string += (
                    transition_line + '\n'
                    + f'TD target: R + 𝛾 * Q({s_prime}, {a_index_to_symbol[a_prime]}) = '
                    + f'{r:.2f} + 0.9 * {self.q_table[s_prime_idx, a_prime]:.2f} = {td_target:.2f}\n'
                    + update_line + '\n\n')
            else:
                print(transition_line + '\n')
                print(f'TD target: {td_target:.2f}')
                print(update_line)

            # Write the new estimate and rebuild the policy from it.
            self.q_table[s_idx, a] = new_value
            self.policy = epsilon_greedy_pi_from_q_table(
                self.env, self.q_table, self.epsilon)
            self.phase = 'showing_updated_q_value'

        elif self.phase == 'showing_updated_q_value':
            if verbose:
                s, a, r, s_prime, t, a_prime = self.current_transition
                _render(_choice_with_trail(s))
                print(_describe(s, a, r, s_prime, a_prime) + '\n')

                s_prime_idx = self.env.states.index(s_prime)
                # NOTE(review): recomputed from the already-updated table, so
                # it can differ from the target used for the update when
                # (s, a) equals (s_prime, a_prime).
                td_target = r + 0.9 * self.q_table[s_prime_idx, a_prime]
                print(f'TD target: {td_target:.2f}')

                s_idx = self.env.states.index(s)
                print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f}')
            self.phase = 'carrying_out_action'

        elif self.phase == 'compute_q_for_terminal_s_prime':
            s, a, r, s_prime, t = self.current_transition
            s_idx = self.env.states.index(s)
            old_q = self.q_table[s_idx, a]
            # No bootstrap term here: the target is just the reward.
            new_value = self.alpha * r + (1 - self.alpha) * old_q

            update_line = (
                f'Q({s}, {a_index_to_symbol[a]}) = {old_q:.2f} + {self.alpha:.2f} * '
                f'({r:.2f} - {old_q:.2f}) = {new_value:.2f}')

            if fast_execution:
                print_string += f'Transition to terminal state: <{s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}, ∅>\n'
                print_string += f'TD target = R + 𝛾 * 0 = {r:.2f}\n'
                print_string += update_line
            else:
                _render(_terminal_move(s, s_prime))
                print('Terminal state reached.\n')
                print(f'TD target = R + 𝛾 * 0 = {r:.2f}')
                print(update_line)

            self.q_table[s_idx, a] = new_value
            self.policy = epsilon_greedy_pi_from_q_table(
                self.env, self.q_table, self.epsilon)
            self.phase = 'showing_updated_q_value_terminal_s_prime'

        elif self.phase == 'showing_updated_q_value_terminal_s_prime':
            # Episode finished: reset and bump the counter ``self.i``.
            self.env.reset()
            self.i += 1

            if verbose:
                s, a, r, s_prime, t = self.current_transition
                s_idx = self.env.states.index(s)
                _render(_terminal_move(s, s_prime))

                print('Terminal state reached.\n')
                print(f'Q({s}, {a_index_to_symbol[a]}) = {self.q_table[s_idx, a]:.2f}\n')
                print(f'Resetting environment')
            self.phase = 'initial_sampling_action'

        else:
            raise ValueError(f'Phase {self.phase} not recognized.')
        return print_string
Beispiel #11
0
    def next_step(self, fast_execution=False):
        """Advance the episode walkthrough by exactly one phase.

        The method is a small state machine keyed on ``self.phase``:

        ``choosing_actions`` -> ``showing_sampled_action`` ->
        ``carrying_out_action`` (back to ``choosing_actions`` until the
        environment reports a terminal transition) -> ``computing_returns`` ->
        ``presenting_computed_returns`` -> ``updating_values`` ->
        ``show_updated_value_function`` -> ``choosing_actions``.

        Args:
            fast_execution: When false, every phase draws a three-panel figure
                and prints its explanation. When true, nothing is drawn or
                printed; only the ``updating_values`` phase contributes text,
                which is returned instead.

        Returns:
            The text accumulated during a fast ``updating_values`` phase, or
            an empty string for every other phase / interactive execution.

        Raises:
            ValueError: If ``self.phase`` is not one of the phases above.
        """
        verbose = not fast_execution
        collected = []

        def emit(text):
            # Fast runs hand the line back to the caller; interactive runs print it.
            if fast_execution:
                collected.append(f'{text}\n')
            else:
                print(text)

        def render(paint_left):
            # Fresh figure: caller-specific left panel, value table, policy.
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            paint_left(self.ax[0])
            plot_value_table(self.env,
                             self.value_table,
                             self.ax[1],
                             vmin=self.vmin,
                             vmax=self.vmax)
            plot_policy(self.env, self.policy, self.ax[2])
            plt.show()

        def agent_on_grid(axis):
            # Left panel used by all post-episode phases.
            plot_all_states(self.env, axis)
            plot_agent(self.s, axis)

        def computed_return_lines():
            # One line per transition, last transition first.
            indexed = list(
                enumerate(
                    zip(self.current_episode, self.current_episode_returns)))
            for i, (transition, g) in reversed(indexed):
                yield f'G_{i+1} = {g:.2f},\tat {transition[0]}'

        phase = self.phase

        if phase == 'choosing_actions':
            if verbose:
                render(lambda axis: plot_env_agent_and_policy_at_state(
                    self.env, self.s, self.policy, axis))
                print('Sampling action...\n')
                print(self.get_current_episode_as_string())
            self.phase = 'showing_sampled_action'

        elif phase == 'showing_sampled_action':
            # The action is drawn from the policy row of the current state.
            self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])
            if verbose:
                render(lambda axis: plot_env_agent_and_chosen_action(
                    self.env, self.s, self.sampled_a, axis))
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
                print(self.get_current_episode_as_string())
            self.phase = 'carrying_out_action'

        elif phase == 'carrying_out_action':
            old_s = self.s
            s_prime, r, t = self.env.step(self.sampled_a)
            self.current_episode.append([old_s, self.sampled_a, r, s_prime])

            if verbose:

                def before_and_after(axis):
                    # Faded marker for the previous state, solid for the new one.
                    plot_all_states(self.env, axis)
                    plot_agent(old_s, axis, alpha=0.3)
                    plot_agent(self.s, axis)

                render(before_and_after)
                print(self.get_current_episode_as_string())

            if t and verbose:
                print('A terminal state has been reached: end of episode.')
            self.phase = 'computing_returns' if t else 'choosing_actions'

        elif phase == 'computing_returns':
            if verbose:
                render(agent_on_grid)
                print(self.get_current_episode_as_string())
                print('\nComputing returns:')

            # Walk the episode backwards so each return can reuse its successor.
            last = len(self.current_episode) - 1
            self.current_episode_returns = [0 for _ in self.current_episode]
            for i in range(last, -1, -1):
                transition = self.current_episode[i]
                if i < last:
                    self.current_episode_returns[i] = transition[
                        2] + 0.9 * self.current_episode_returns[i + 1]
                    if verbose:
                        print(
                            f'G_{i+1} = {transition[2]:.1f} + 0.9 * {self.current_episode_returns[i+1]:.2f}',
                            end='')
                        print(
                            f' = {self.current_episode_returns[i]:.2f},\tat {transition[0]}'
                        )
                    continue
                # Final transition: the return is just its reward.
                if verbose:
                    print(
                        f'G_{i+1} = {transition[2]:.1f},\t\t\t\tat {transition[0]}'
                    )
                self.current_episode_returns[i] = transition[2]

            self.phase = 'presenting_computed_returns'

        elif phase == 'presenting_computed_returns':
            if verbose:
                render(agent_on_grid)
                print(self.get_current_episode_as_string())
                print('\nComputing returns:')
                for line in computed_return_lines():
                    print(line)
            self.phase = 'updating_values'

        elif phase == 'updating_values':
            if verbose:
                render(agent_on_grid)
            emit(self.get_current_episode_as_string())
            if verbose:
                print('\nComputing returns:')

            # Computed returns for all visited states
            for line in computed_return_lines():
                emit(line)

            if fast_execution:
                collected.append('\n')
            else:
                print('\nUpdating values:')

            # Move each visited state's value towards its return by alpha.
            for transition, g in zip(reversed(self.current_episode),
                                     reversed(self.current_episode_returns)):
                s_idx = self.env.states.index(transition[0])
                cur_value = self.value_table[s_idx]
                new_value = self.alpha * g + (1 - self.alpha) * cur_value
                self.value_table[s_idx] = new_value
                emit(
                    f'V({transition[0]}) = {cur_value:.2f} + {self.alpha:.2f} * ({g:.2f} - {cur_value:.2f}) '
                    f'= {new_value:.2f}')

            self.phase = 'show_updated_value_function'

        elif phase == 'show_updated_value_function':
            if verbose:
                render(agent_on_grid)
                print(self.get_current_episode_as_string())
                print('\nComputing returns:')
                for line in computed_return_lines():
                    print(line)

                print('\nUpdating values:')
                for transition, _g in zip(
                        reversed(self.current_episode),
                        reversed(self.current_episode_returns)):
                    s_idx = self.env.states.index(transition[0])
                    print(
                        f'V({transition[0]}) = {self.value_table[s_idx]:.2f}')

            # Episode is finished: clear the buffer and start over.
            self.current_episode = []
            self.env.reset()
            self.phase = 'choosing_actions'

        else:
            raise ValueError(f'Phase {self.phase} not recognized.')

        return ''.join(collected)
Beispiel #12
0
    def next_step(self, fast_execution=False):
        """Advance the per-state value computation walkthrough by one phase.

        For the state at index ``self.s_i`` the method cycles through three
        phases stored in ``self.phase``:

        * ``going_through_actions`` - one call per action: accumulates
          ``self.q_values[self.a]`` from the successor-state distribution and
          moves to the next action (terminal states skip this entirely).
        * ``computing_final_v`` - combines the Q values with the policy row of
          the state (or uses 0 for a terminal state) and writes the result to
          ``self.value_table``.
        * ``updating_value_table`` - shows the stored value, then resets the
          per-state scratch data and moves on to the next state, wrapping to
          state 0 and incrementing ``self.i`` after the last one.

        Args:
            fast_execution: When false, a three-panel figure is drawn and the
                explanation for the current phase is printed. When true, only
                the state updates are performed.

        NOTE(review): no fallback branch for an unrecognized ``self.phase`` is
        visible in this part of the method - confirm whether that is intended.
        """

        if not fast_execution:
            # Shared plotting for all phases
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_s_pmf(self.env, self.s, self.a, ax=self.ax[0])
            # Unlike other walkthroughs in this file, the value range is fixed
            # at [-8, 0] rather than read from self.vmin / self.vmax.
            plot_value_table(self.env,
                             self.value_table,
                             ax=self.ax[1],
                             vmin=-8,
                             vmax=0)
            plot_policy(self.env, self.policy, ax=self.ax[2])

            # Phase-specific plotting
            # Red 1x1 outline anchored at self.s on the policy panel while the
            # policy-weighted value of a non-terminal state is being computed.
            if self.phase == 'computing_final_v' and self.env.terminal_table[
                    self.s_i] != 1:
                self.ax[2].add_patch(
                    matplotlib.patches.Rectangle(self.s,
                                                 1.0,
                                                 1.0,
                                                 edgecolor='r',
                                                 facecolor='None',
                                                 lw=5))
            # Same outline on the value-table panel once the value was written.
            if self.phase == 'updating_value_table':
                self.ax[1].add_patch(
                    matplotlib.patches.Rectangle(self.s,
                                                 1.0,
                                                 1.0,
                                                 edgecolor='r',
                                                 facecolor='None',
                                                 lw=5))
            plt.show()

            print(f'Iteration: {self.i}\nState: {self.s}\n')

        if self.phase == 'going_through_actions':
            # Special handling for terminal states
            if self.env.terminal_table[self.s_i] == 1:
                # Jump straight past the last action so the check below
                # switches to 'computing_final_v' without computing Q values.
                self.a = 4
            else:
                # Get pmf over possible next states
                s_prime_idxs, s_prime_probs = get_pmf_possible_s_primes(
                    self.env, self.s, self.a)

                # Compute Q values
                # The sum is accumulated in place, so it relies on
                # self.q_values[self.a] having been reset to 0 beforehand
                # (done in the 'updating_value_table' phase below).
                # NOTE(review): 0.9 is hard-coded here - presumably the
                # discount factor; confirm it matches the rest of the project.
                q_message = f'Q(s, {a_index_to_symbol[self.a]}) = '
                r = self.env.reward_table[self.s_i, self.a]
                for s_prime_idx, s_prime_prob in zip(s_prime_idxs,
                                                     s_prime_probs):
                    q_message += f'{s_prime_prob:.1f}*({r:.1f} + ' \
                                 f'0.9*{self.value_table[s_prime_idx]:.2f}) + '
                    self.q_values[self.a] += s_prime_prob * (
                        r + 0.9 * self.value_table[s_prime_idx])

                # Prepare message to be printed
                # [:-2] drops the trailing '+ ' (a space is left before ' = ').
                q_message = q_message[:-2]
                q_message += f' = {self.q_values[self.a]:.2f}'
                # Messages of earlier actions are kept so each call reprints
                # the derivations shown so far for this state.
                self.q_messages += f'\n{q_message}'
                if not fast_execution:
                    print(self.q_messages)

                self.a += 1

            # NOTE(review): the action count 4 is hard-coded, matching the
            # four-element self.q_values list reset further below.
            if self.a == 4:
                self.a = 3  # Show the same action as before
                self.phase = 'computing_final_v'

        elif self.phase == 'computing_final_v':
            if self.env.terminal_table[self.s_i] == 1:
                final_value_message = 'State is terminal. Setting its value to zero.\nV(s) = 0'
                if not fast_execution:
                    print(final_value_message)
                final_value = 0
            else:
                # Print computed Q messages
                if not fast_execution:
                    for i, q_value in enumerate(self.q_values):
                        print(f'Q(s, {a_index_to_symbol[i]}) = {q_value:.2f}')

                # Print resulting value of current state
                # V(s) is the sum of Q(s, a) weighted by the policy row of s.
                final_value_message = '\nV(s) = '
                final_value = 0
                for q_value, pi in zip(self.q_values, self.policy[self.s_i]):
                    final_value_message += f'{q_value:.2f}*{pi:.1f} + '
                    final_value += q_value * pi
                # [:-3] drops the trailing ' + '.
                # NOTE(review): the ')' before ' = ' has no matching '(' in the
                # message built above - looks like a typo; confirm before fixing.
                final_value_message = f'{final_value_message[:-3]}) = {final_value:.2f}'
                if not fast_execution:
                    print(final_value_message)

            # Update value table and move to next phase
            self.value_table[self.s_i] = final_value
            self.phase = 'updating_value_table'

        elif self.phase == 'updating_value_table':
            if not fast_execution:
                # Print computed Q messages
                for i, q_value in enumerate(self.q_values):
                    print(f'Q(s, {a_index_to_symbol[i]}) = {q_value:.2f}')

                # Print resulting value of current state
                print(f'\nV(s) = {self.value_table[self.s_i]:.2f}')

            # Reset action counter, move to next state, erase previous q messages
            self.a = 0
            self.s_i += 1
            self.q_messages = ''
            self.q_values = [0, 0, 0, 0]
            self.phase = 'going_through_actions'

            # If reached last state, restart from first state and increment iteration number
            if self.s_i == len(self.env.states):
                self.s_i = 0
                self.i += 1