コード例 #1
0
 def policy_improve_on_click(obj):
     """Button callback: run one policy-improvement step and redraw.

     Closes the previous figure, clears ``output`` (with ``wait=True``),
     applies ``policy_improvement`` to the enclosing ``policy`` and
     ``value_table``, then re-plots the value table and the policy inside
     ``output``. ``obj`` (the clicked button) is not used.
     """
     plt.close()
     _fig, (value_ax, policy_ax) = plt.subplots(ncols=2, figsize=(14, 6))
     output.clear_output(wait=True)
     policy_improvement(env, policy, value_table)
     with output:
         plot_value_table(env, value_table, ax=value_ax, vmin=vmin, vmax=vmax)
         plot_policy(env, policy, ax=policy_ax)
         plt.show()
コード例 #2
0
    def finish_iteration(self):
        """Run the rest of the current iteration silently, then plot once.

        ``next_step(fast_execution=True)`` is called repeatedly until
        ``self.i`` changes. Afterwards three panels are drawn: the
        ``plot_s_pmf`` view of the current state/action, the value table
        (fixed colour limits -8..0) and the policy.
        """
        starting_i = self.i
        while self.i == starting_i:
            self.next_step(fast_execution=True)

        # Final snapshot of the environment, value table and policy.
        self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
        pmf_ax, value_ax, policy_ax = self.ax
        plot_s_pmf(self.env, self.s, self.a, ax=pmf_ax)
        plot_value_table(self.env, self.value_table, ax=value_ax, vmin=-8, vmax=0)
        plot_policy(self.env, self.policy, ax=policy_ax)
        plt.show()
コード例 #3
0
    def finish_episode(self):
        """Run ``next_step`` silently until ``self.i`` changes, then report.

        The text returned by each fast step is collected. Afterwards the
        environment with the agent, the value table and the policy are drawn,
        and the collected log is printed after the figure is shown.
        """
        current_i = self.i
        # Gather the per-step logs in a list and join once, instead of growing
        # a string with += inside a loop of unbounded length.
        step_logs = []
        while self.i == current_i:
            step_logs.append(str(self.next_step(fast_execution=True)))
        print_string = ''.join(step_logs)

        self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
        plot_all_states(self.env, self.ax[0])
        plot_agent(self.s, self.ax[0])
        plot_value_table(self.env,
                         self.value_table,
                         self.ax[1],
                         vmin=self.vmin,
                         vmax=self.vmax)
        plot_policy(self.env, self.policy, self.ax[2])
        plt.show()
        print(print_string)
コード例 #4
0
def policy_iteration_wrapper():
    """Display an interactive, button-driven policy-iteration demo.

    Two buttons are shown above a shared ``widgets.Output`` area.
    "Policy eval iteration" calls ``policy_evaluation_one_step`` and
    "Policy improvement" calls ``policy_improvement`` on the local
    ``policy`` / ``value_table`` arrays. Those helpers are expected to update
    the arrays in place, since they are never reassigned here. Each click then
    redraws the value table and the policy.
    """
    env.reset()
    # Hand-written starting policy: one row of action probabilities per state.
    policy = np.array([[1, 0, 0, 0], [0.5, 0.5, 0, 0], [0, 1, 0, 0],
                       [0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 1, 0],
                       [0, 0, 1, 0], [1, 0, 0, 0]])
    # One value per policy row, so the two tables always have matching length.
    value_table = np.zeros(len(policy))

    btn_eval = widgets.Button(description='Policy eval iteration')
    btn_improve = widgets.Button(description='Policy improvement')

    display(btn_eval)
    display(btn_improve)

    output = widgets.Output()

    # Fixed colour limits so successive value-table plots share one scale.
    vmin = -8
    vmax = 0

    def draw():
        """Plot the value table and the policy side by side and show them."""
        _, ax = plt.subplots(ncols=2, figsize=(14, 6))
        plot_value_table(env, value_table, ax=ax[0], vmin=vmin, vmax=vmax)
        plot_policy(env, policy, ax=ax[1])
        plt.show()

    def update_and_redraw(update):
        """Apply ``update(env, policy, value_table)`` and redraw in ``output``.

        The figure is created inside the ``with output`` context, the same way
        as the initial draw, rather than before entering it.
        """
        plt.close()
        output.clear_output(True)
        update(env, policy, value_table)
        with output:
            draw()

    with output:
        draw()

    def policy_eval_on_click(obj):
        update_and_redraw(policy_evaluation_one_step)

    def policy_improve_on_click(obj):
        update_and_redraw(policy_improvement)

    btn_eval.on_click(policy_eval_on_click)
    btn_improve.on_click(policy_improve_on_click)
    display(output)
コード例 #5
0
    def next_step(self, fast_execution=False):
        """Advance the interactive TD(0) walkthrough by exactly one phase.

        The phases cycle in this order:
        'choosing_actions' -> 'showing_sampled_action' -> 'carrying_out_action'
        -> 'showing_td_target' -> 'updating_value_table'
        -> 'showing_new_value_table' -> 'choosing_actions'.

        Args:
            fast_execution: when True, no figures are drawn and nothing is
                printed. The 'updating_value_table' phase instead appends its
                log text to the returned string.

        Returns:
            The log text for this step. It is non-empty only for the
            'updating_value_table' phase when ``fast_execution`` is True.

        Raises:
            ValueError: if ``self.phase`` is not one of the phases above.
        """
        print_string = ''
        if self.phase == 'choosing_actions':
            # Show the agent at its current state together with the policy there.
            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_env_agent_and_policy_at_state(self.env, self.s,
                                                   self.policy, self.ax[0])
                plot_value_table(self.env,
                                 self.value_table,
                                 self.ax[1],
                                 vmin=self.vmin,
                                 vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print('Sampling action...\n')

            self.phase = 'showing_sampled_action'

        elif self.phase == 'showing_sampled_action':
            # Draw an action index in 0..3 from the policy row of the current state.
            self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])

            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_env_agent_and_chosen_action(self.env, self.s,
                                                 self.sampled_a, self.ax[0])
                plot_value_table(self.env,
                                 self.value_table,
                                 self.ax[1],
                                 vmin=self.vmin,
                                 vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')

            self.phase = 'carrying_out_action'

        elif self.phase == 'carrying_out_action':
            # Step the environment and record [s, a, r, s', terminal_flag].
            old_s = self.s
            s_prime, r, t = self.env.step(self.sampled_a)
            self.current_episode.append([old_s, self.sampled_a, r, s_prime, t])

            if not fast_execution:
                # Previous position is drawn with alpha=0.3, the new one opaque.
                # NOTE(review): self.s is not assigned in this method; it
                # presumably reads the environment's current state — confirm.
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(old_s, self.ax[0], alpha=0.3)
                plot_agent(self.s, self.ax[0])
                plot_value_table(self.env,
                                 self.value_table,
                                 self.ax[1],
                                 vmin=self.vmin,
                                 vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print(
                    f'Transition: <{old_s}, {a_index_to_symbol[self.sampled_a]}, {r:.1f}, {s_prime}>'
                )

            self.phase = 'showing_td_target'

        elif self.phase == 'showing_td_target':
            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(self.s, self.ax[0])
                plot_value_table(self.env,
                                 self.value_table,
                                 self.ax[1],
                                 vmin=self.vmin,
                                 vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()

            # TD target R + gamma * V(s') with a hard-coded discount gamma = 0.9.
            # It is stored on self so the next phase can use it.
            old_s, a, r, s_prime, _ = self.current_episode[-1]
            s_prime_idx = self.env.states.index(s_prime)
            self.td_target = r + 0.9 * self.value_table[s_prime_idx]
            if not fast_execution:
                print(
                    f'Transition: <{old_s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}>\n'
                )
                print('Updating value table:')
                print(
                    f'TD target = R + 𝛾 * V({s_prime}) = {r:.1f} + 0.9*{self.value_table[s_prime_idx]:.2f} = '
                    f'{self.td_target:.2f}')

            self.phase = 'updating_value_table'

        elif self.phase == 'updating_value_table':
            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(self.s, self.ax[0])
                plot_value_table(self.env,
                                 self.value_table,
                                 self.ax[1],
                                 vmin=self.vmin,
                                 vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()

            # alpha*target + (1-alpha)*V(s) equals V(s) + alpha*(target - V(s)),
            # which is the form shown in the log below.
            old_s, a, r, s_prime, _ = self.current_episode[-1]
            old_s_idx = self.env.states.index(old_s)
            new_value = self.alpha * self.td_target + (
                1 - self.alpha) * self.value_table[old_s_idx]

            # Fast mode returns the log instead of printing it.
            if fast_execution:
                print_string += f'Transition: <{old_s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}>\n'
                print_string += f'TD target = {self.td_target:.2f}\n'
                print_string += f'V({old_s}) ← {self.value_table[old_s_idx]:.2f} + {self.alpha:.2f} * ' \
                                f'({self.td_target:.2f} - {self.value_table[old_s_idx]:.2f}) = {new_value:.2f}\n\n'
            else:
                print(
                    f'Transition: <{old_s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}>\n'
                )
                print('Updating value table:')
                print(f'TD target = {self.td_target:.2f}')
                print(
                    f'V({old_s}) ← {self.value_table[old_s_idx]:.2f} + {self.alpha:.2f} * '
                    f'({self.td_target:.2f} - {self.value_table[old_s_idx]:.2f}) = {new_value:.2f}\n'
                )

            # The write happens after logging so the log shows the old value.
            self.value_table[old_s_idx] = new_value
            self.phase = 'showing_new_value_table'

        elif self.phase == 'showing_new_value_table':
            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(self.s, self.ax[0])
                plot_value_table(self.env,
                                 self.value_table,
                                 self.ax[1],
                                 vmin=self.vmin,
                                 vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()

            old_s, a, r, s_prime, t = self.current_episode[-1]
            old_s_idx = self.env.states.index(old_s)

            if not fast_execution:
                print(
                    f'Transition: <{old_s}, {a_index_to_symbol[a]}, {r:.1f}, {s_prime}>\n'
                )
                print('Updating value table:')
                print(f'TD target = {self.td_target:.2f}')
                print(f'V({old_s}) = {self.value_table[old_s_idx]:.2f}')

            # A terminal transition ends the episode: reset the environment,
            # clear the recorded transitions and increment self.i.
            if t:
                if not fast_execution:
                    print('\nTerminal state, reseting environment.')
                self.env.reset()
                self.current_episode = []
                self.i += 1
            self.phase = 'choosing_actions'

        else:
            raise ValueError(f'Phase {self.phase} not recognized.')
        return print_string
コード例 #6
0
        # Set policy to greedy w.r.t. Q-values computed above
        new_policy[np.argmax(q)] = 1
        policy[i_s] = new_policy


def epsilon_greedy_pi_from_q_table(env, q_table, epsilon, n_actions=4):
    """Build an epsilon-greedy policy from a table of action values.

    Every action receives probability ``epsilon / n_actions``. The greedy
    action of each state (``np.argmax`` of its row, so ties go to the lowest
    index) receives an extra ``1 - epsilon``, so every row sums to 1.

    Args:
        env: object whose ``states`` sequence fixes the number of rows.
        q_table: indexable by state index; ``q_table[i]`` holds the
            ``n_actions`` action values of ``env.states[i]``.
        epsilon: exploration probability, assumed to lie in [0, 1]
            (not validated).
        n_actions: number of actions, i.e. columns of the returned policy.
            Defaults to 4, the previously hard-coded value.

    Returns:
        Float array of shape ``(len(env.states), n_actions)``.
    """
    n_states = len(env.states)
    # Uniform exploration mass for every (state, action) pair.
    policy = np.full((n_states, n_actions), epsilon / n_actions, dtype=float)
    for i_s in range(n_states):
        # Put the remaining probability mass on the greedy action.
        policy[i_s, np.argmax(q_table[i_s])] += 1 - epsilon
    return policy


if __name__ == '__main__':
    # Manual smoke test: plot a random value table and a random stochastic
    # policy on the 3x3 grid world.
    from lib.grid_world import grid_world_3x3 as env
    from lib.plot_utils import plot_policy, plot_value_table
    import matplotlib.pyplot as plt

    np.random.seed(2)
    env.reset()
    n_states = len(env.states)
    # Size the value table from the environment: exactly one entry per state.
    value_table = np.random.rand(n_states)
    # Each row is an independent Dirichlet(0.1, 0.1, 0.1, 0.1) sample, i.e. a
    # random probability distribution over the 4 actions.
    policy = np.empty((n_states, 4))
    for i in range(n_states):
        policy[i, :] = np.random.dirichlet([0.1, 0.1, 0.1, 0.1])

    plot_value_table(env, value_table)
    plot_policy(env, policy)
    plt.show()
コード例 #7
0
    def next_step(self, fast_execution=False):
        """Advance the interactive Monte Carlo prediction walkthrough by one phase.

        Phase order:
        'choosing_actions' -> 'showing_sampled_action' -> 'carrying_out_action'.
        From 'carrying_out_action' the next phase is 'computing_returns' when a
        terminal state was reached, otherwise 'choosing_actions' again.
        Then 'computing_returns' -> 'presenting_computed_returns'
        -> 'updating_values' -> 'show_updated_value_function'
        -> 'choosing_actions' (new episode).

        Args:
            fast_execution: when True, no figures are drawn and nothing is
                printed. The 'updating_values' phase instead appends its log
                text to the returned string.

        Returns:
            The log text for this step. It is non-empty only for the
            'updating_values' phase when ``fast_execution`` is True.

        Raises:
            ValueError: if ``self.phase`` is not one of the phases above.
        """
        print_string = ''
        if self.phase == 'choosing_actions':
            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_env_agent_and_policy_at_state(self.env, self.s,
                                                   self.policy, self.ax[0])
                plot_value_table(self.env,
                                 self.value_table,
                                 self.ax[1],
                                 vmin=self.vmin,
                                 vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print('Sampling action...\n')
                print(self.get_current_episode_as_string())

            self.phase = 'showing_sampled_action'

        elif self.phase == 'showing_sampled_action':
            # Draw an action index in 0..3 from the policy row of the current state.
            self.sampled_a = np.random.choice(4, p=self.policy[self.s_idx])

            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_env_agent_and_chosen_action(self.env, self.s,
                                                 self.sampled_a, self.ax[0])
                plot_value_table(self.env,
                                 self.value_table,
                                 self.ax[1],
                                 vmin=self.vmin,
                                 vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print(f'Action sampled: {a_index_to_symbol[self.sampled_a]}\n')
                print(self.get_current_episode_as_string())

            self.phase = 'carrying_out_action'

        elif self.phase == 'carrying_out_action':
            # Step the environment and record [s, a, r, s'] for this episode.
            old_s = self.s
            s_prime, r, t = self.env.step(self.sampled_a)
            self.current_episode.append([old_s, self.sampled_a, r, s_prime])

            if not fast_execution:
                # Previous position is drawn with alpha=0.3, the new one opaque.
                # NOTE(review): self.s is not assigned in this method; it
                # presumably reads the environment's current state — confirm.
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(old_s, self.ax[0], alpha=0.3)
                plot_agent(self.s, self.ax[0])
                plot_value_table(self.env,
                                 self.value_table,
                                 self.ax[1],
                                 vmin=self.vmin,
                                 vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print(self.get_current_episode_as_string())

            # Values are only updated once the episode has finished.
            if t:
                if not fast_execution:
                    print('A terminal state has been reached: end of episode.')
                self.phase = 'computing_returns'
            else:
                self.phase = 'choosing_actions'

        elif self.phase == 'computing_returns':

            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(self.s, self.ax[0])
                plot_value_table(self.env,
                                 self.value_table,
                                 self.ax[1],
                                 vmin=self.vmin,
                                 vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print(self.get_current_episode_as_string())
                print('\nComputing returns:')

            # Discounted returns, computed backwards with hard-coded gamma = 0.9.
            # The last transition's return is its reward, and
            # G_i = r_i + 0.9 * G_{i+1}. Printed indices are 1-based (G_{i+1}).
            self.current_episode_returns = [0] * len(self.current_episode)
            for i in reversed(range(len(self.current_episode))):
                transition = self.current_episode[i]
                if i == len(self.current_episode) - 1:
                    if not fast_execution:
                        print(
                            f'G_{i+1} = {transition[2]:.1f},\t\t\t\tat {transition[0]}'
                        )
                    self.current_episode_returns[i] = transition[2]
                else:
                    self.current_episode_returns[i] = transition[
                        2] + 0.9 * self.current_episode_returns[i + 1]
                    if not fast_execution:
                        print(
                            f'G_{i+1} = {transition[2]:.1f} + 0.9 * {self.current_episode_returns[i+1]:.2f}',
                            end='')
                        print(
                            f' = {self.current_episode_returns[i]:.2f},\tat {transition[0]}'
                        )

            self.phase = 'presenting_computed_returns'

        elif self.phase == 'presenting_computed_returns':
            # Display-only phase: re-show the episode and the computed returns.
            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(self.s, self.ax[0])
                plot_value_table(self.env,
                                 self.value_table,
                                 self.ax[1],
                                 vmin=self.vmin,
                                 vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print(self.get_current_episode_as_string())

                print('\nComputing returns:')
                for i, (transition, g) in reversed(
                        list(
                            enumerate(
                                zip(self.current_episode,
                                    self.current_episode_returns)))):
                    print(f'G_{i+1} = {g:.2f},\tat {transition[0]}')

            self.phase = 'updating_values'

        elif self.phase == 'updating_values':
            # Fast mode returns the log instead of printing it.
            if fast_execution:
                print_string += f'{self.get_current_episode_as_string()}\n'
            else:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(self.s, self.ax[0])
                plot_value_table(self.env,
                                 self.value_table,
                                 self.ax[1],
                                 vmin=self.vmin,
                                 vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print(self.get_current_episode_as_string())
                print('\nComputing returns:')

            # Print computed returns for all visited states
            for i, (transition, g) in reversed(
                    list(
                        enumerate(
                            zip(self.current_episode,
                                self.current_episode_returns)))):
                msg = f'G_{i+1} = {g:.2f},\tat {transition[0]}'
                if fast_execution:
                    print_string += f'{msg}\n'
                else:
                    print(msg)

            if fast_execution:
                print_string += '\n'
            else:
                print('\nUpdating values:')

            # Constant-step-size update V(s) <- alpha*G + (1-alpha)*V(s),
            # applied once per transition, starting from the last one. A state
            # visited several times is therefore updated once per visit.
            for transition, g in zip(reversed(self.current_episode),
                                     reversed(self.current_episode_returns)):
                s_idx = self.env.states.index(transition[0])
                new_value = self.alpha * g + (
                    1 - self.alpha) * self.value_table[s_idx]
                cur_value = self.value_table[s_idx]
                msg = f'V({transition[0]}) = {cur_value:.2f} + {self.alpha:.2f} * ({g:.2f} - {cur_value:.2f}) ' \
                      f'= {new_value:.2f}'
                self.value_table[s_idx] = new_value
                if fast_execution:
                    print_string += f'{msg}\n'
                else:
                    print(msg)

            self.phase = 'show_updated_value_function'

        elif self.phase == 'show_updated_value_function':
            if not fast_execution:
                self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
                plot_all_states(self.env, self.ax[0])
                plot_agent(self.s, self.ax[0])
                plot_value_table(self.env,
                                 self.value_table,
                                 self.ax[1],
                                 vmin=self.vmin,
                                 vmax=self.vmax)
                plot_policy(self.env, self.policy, self.ax[2])
                plt.show()
                print(self.get_current_episode_as_string())

                print('\nComputing returns:')
                for i, (transition, g) in reversed(
                        list(
                            enumerate(
                                zip(self.current_episode,
                                    self.current_episode_returns)))):
                    print(f'G_{i+1} = {g:.2f},\tat {transition[0]}')

                print('\nUpdating values:')
                for transition, g in zip(
                        reversed(self.current_episode),
                        reversed(self.current_episode_returns)):
                    s_idx = self.env.states.index(transition[0])
                    print(
                        f'V({transition[0]}) = {self.value_table[s_idx]:.2f}')

            # Start a new episode.
            # NOTE(review): this phase ends an episode but never modifies
            # self.i; if callers loop until self.i changes, confirm the
            # counter is advanced elsewhere.
            self.current_episode = []
            self.env.reset()
            self.phase = 'choosing_actions'

        else:
            raise ValueError(f'Phase {self.phase} not recognized.')

        return print_string
コード例 #8
0
    def next_step(self, fast_execution=False):
        """Advance the interactive iterative policy-evaluation walkthrough.

        States are swept in index order, and each one passes through three
        phases:

        * 'going_through_actions': one call per action accumulates
          Q(s, a) = sum_s' p(s'|s, a) * (R(s, a) + 0.9 * V(s')) into
          ``self.q_values``. Terminal states skip straight to the next phase.
        * 'computing_final_v': V(s) = sum_a pi(a|s) * Q(s, a), or 0 for a
          terminal state, is written into ``self.value_table`` in place.
          States later in the same sweep therefore already see the new value.
        * 'updating_value_table': the per-state accumulators are reset and the
          sweep moves to the next state. After the last state it wraps to
          state 0 and increments ``self.i``.

        Args:
            fast_execution: when True, all plotting and printing is skipped.
        """

        if not fast_execution:
            # Shared plotting for all phases
            self.fig, self.ax = plt.subplots(ncols=3, figsize=(20, 6))
            plot_s_pmf(self.env, self.s, self.a, ax=self.ax[0])
            plot_value_table(self.env,
                             self.value_table,
                             ax=self.ax[1],
                             vmin=-8,
                             vmax=0)
            plot_policy(self.env, self.policy, ax=self.ax[2])

            # Phase-specific plotting: outline the current state in red on the
            # policy panel (about to average Q over pi) or on the value panel
            # (its new value has just been written).
            if self.phase == 'computing_final_v' and self.env.terminal_table[
                    self.s_i] != 1:
                self.ax[2].add_patch(
                    matplotlib.patches.Rectangle(self.s,
                                                 1.0,
                                                 1.0,
                                                 edgecolor='r',
                                                 facecolor='None',
                                                 lw=5))
            if self.phase == 'updating_value_table':
                self.ax[1].add_patch(
                    matplotlib.patches.Rectangle(self.s,
                                                 1.0,
                                                 1.0,
                                                 edgecolor='r',
                                                 facecolor='None',
                                                 lw=5))
            plt.show()

            print(f'Iteration: {self.i}\nState: {self.s}\n')

        if self.phase == 'going_through_actions':
            # Special handling for terminal states: jump past the last action
            # so the phase switch below fires immediately.
            if self.env.terminal_table[self.s_i] == 1:
                self.a = 4
            else:
                # Get pmf over possible next states
                s_prime_idxs, s_prime_probs = get_pmf_possible_s_primes(
                    self.env, self.s, self.a)

                # Accumulate Q(s, a) and keep each expectation term for display
                r = self.env.reward_table[self.s_i, self.a]
                terms = []
                for s_prime_idx, s_prime_prob in zip(s_prime_idxs,
                                                     s_prime_probs):
                    terms.append(f'{s_prime_prob:.1f}*({r:.1f} + '
                                 f'0.9*{self.value_table[s_prime_idx]:.2f})')
                    self.q_values[self.a] += s_prime_prob * (
                        r + 0.9 * self.value_table[s_prime_idx])

                # Joining with ' + ' leaves no dangling separator, so there is
                # no doubled space before ' = '.
                q_message = (f'Q(s, {a_index_to_symbol[self.a]}) = '
                             f'{" + ".join(terms)} = {self.q_values[self.a]:.2f}')
                self.q_messages += f'\n{q_message}'
                if not fast_execution:
                    print(self.q_messages)

                self.a += 1

            if self.a == 4:
                self.a = 3  # Show the same action as before
                self.phase = 'computing_final_v'

        elif self.phase == 'computing_final_v':
            if self.env.terminal_table[self.s_i] == 1:
                final_value_message = 'State is terminal. Setting its value to zero.\nV(s) = 0'
                if not fast_execution:
                    print(final_value_message)
                final_value = 0
            else:
                # Print computed Q messages
                if not fast_execution:
                    for i, q_value in enumerate(self.q_values):
                        print(f'Q(s, {a_index_to_symbol[i]}) = {q_value:.2f}')

                # V(s) = sum_a pi(a|s) * Q(s, a); keep the terms for display
                final_value = 0
                terms = []
                for q_value, pi in zip(self.q_values, self.policy[self.s_i]):
                    terms.append(f'{q_value:.2f}*{pi:.1f}')
                    final_value += q_value * pi
                # The sum has no opening parenthesis, so none is closed here.
                final_value_message = f'\nV(s) = {" + ".join(terms)} = {final_value:.2f}'
                if not fast_execution:
                    print(final_value_message)

            # Update value table and move to next phase
            self.value_table[self.s_i] = final_value
            self.phase = 'updating_value_table'

        elif self.phase == 'updating_value_table':
            if not fast_execution:
                # Print computed Q messages
                for i, q_value in enumerate(self.q_values):
                    print(f'Q(s, {a_index_to_symbol[i]}) = {q_value:.2f}')

                # Print resulting value of current state
                print(f'\nV(s) = {self.value_table[self.s_i]:.2f}')

            # Reset action counter, move to next state, erase previous q messages
            self.a = 0
            self.s_i += 1
            self.q_messages = ''
            self.q_values = [0, 0, 0, 0]
            self.phase = 'going_through_actions'

            # If reached last state, restart from first state and increment iteration number
            if self.s_i == len(self.env.states):
                self.s_i = 0
                self.i += 1