from GridWorld.GridWorldMDPClass import GridWorldMDP
from MDP.StateAbstractionClass import StateAbstraction
from MDP.AbstractMDPClass import AbstractMDP
from MDP.ValueIterationClass import ValueIteration
from resources.AbstractionTypes import Abstr_type
from resources.AbstractionCorrupters import make_corruption
from resources.AbstractionMakers import make_abstr

import numpy as np

# Number of states to corrupt
STATE_NUM = 20

# Create abstract MDP
mdp = GridWorldMDP()
vi = ValueIteration(mdp)
vi.run_value_iteration()
q_table = vi.get_q_table()
state_abstr = make_abstr(q_table, Abstr_type.PI_STAR)
abstr_mdp = AbstractMDP(mdp, state_abstr)

# Randomly select our list of states and print them out
states_to_corrupt = np.random.choice(mdp.get_all_possible_states(),
                                     size=STATE_NUM,
                                     replace=False)
for state in states_to_corrupt:
    print(state)

# Create a corrupt MDP
corr_mdp = make_corruption(abstr_mdp, states_to_corrupt)
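
# Hedged sketch (not the repo's make_corruption): one plausible corruption
# scheme reassigns each selected ground state to a different, randomly chosen
# abstract state. It assumes the abstraction's mapping is a plain dict from
# ground state to abstract state, as the abstr_dict keyword used elsewhere in
# this repo suggests; the function name is hypothetical.
import random

def sketch_corruption(abstr_dict, states_to_corrupt):
    corrupt_dict = dict(abstr_dict)
    abstract_states = list(set(abstr_dict.values()))
    for state in states_to_corrupt:
        # Reassign to a random abstract state other than the state's own
        candidates = [a for a in abstract_states if a != abstr_dict[state]]
        if candidates:
            corrupt_dict[state] = random.choice(candidates)
    return corrupt_dict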
def is_optimal_policy_representable(vi, optimal_policy, state_abstr):
    # Reconstructed header and loop: only the intersection check below
    # survived in the original snippet. The abstr_dict attribute and the
    # get_best_actions accessor are hedged assumptions about this repo's API.
    for state_1 in state_abstr.abstr_dict:
        for state_2 in state_abstr.abstr_dict:
            if state_1 != state_2 and state_abstr.abstr_dict[state_1] == state_abstr.abstr_dict[state_2]:
                best_actions_1 = vi.get_best_actions(state_1)
                best_actions_2 = vi.get_best_actions(state_2)
                best_action_intersect = list(
                    set(best_actions_1) & set(best_actions_2))
                if len(best_action_intersect) == 0:
                    return False
    return True


def print_policy(policy):
    '''
    Print the policy
    '''
    for key in policy.keys():
        print(key, policy[key])


if __name__ == '__main__':
    # Test that the optimal ground policy for FourRooms is representable in
    # the abstraction given by A* (matching the A_STAR abstraction built below)

    # Get optimal ground policy for FourRooms
    four_rooms = GridWorldMDP(slip_prob=0.0, gamma=0.99)
    vi = ValueIteration(four_rooms)
    vi.run_value_iteration()
    optimal_policy = vi.get_optimal_policy()
    #print_policy(optimal_policy)

    # Get A* abstraction for FourRooms
    abstr = make_abstr(vi.get_q_table(), Abstr_type.A_STAR)

    print(is_optimal_policy_representable(vi, optimal_policy, abstr))
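
# Hedged helper sketch: a representability check like the one above needs the
# set of (near-)optimal actions per state. Given the {(state, action): value}
# Q-table format used throughout this repo, that set can be extracted as
# follows; the function name and the 1e-9 tie tolerance are illustrative.
from collections import defaultdict

def best_actions_per_state(q_table, tol=1e-9):
    best_value = {}
    for (state, action), value in q_table.items():
        best_value[state] = max(best_value.get(state, float('-inf')), value)
    best_actions = defaultdict(list)
    for (state, action), value in q_table.items():
        if value >= best_value[state] - tol:
            best_actions[state].append(action)
    return best_actions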
Example #3
if __name__ == '__main__':

    # Testing Apra's value iteration on FourRooms
    '''
    grid_mdp_test = GridWorldMDP(height=11, width=11, slip_prob=0.1, gamma=0.95, build_walls=True)
    value_itr = ValueIteration(grid_mdp_test, 0.0001)
    value_itr.doValueIteration(10000)
    #print(value_itr)
    result = value_itr.get_q_table()
    for key in result.keys():
        print(key[0], key[1], result[key])
    #viz = GridWorldVisualizer(grid_mdp_test, value_itr)
    #viz.visualizeLearnedPolicy()
    '''

    # Testing VI on TaxiMDP
    mdp = TaxiMDP(slip_prob=0.0, gamma=0.99)
    value_itr = ValueIteration(mdp, 0.01)
    value_itr.run_value_iteration(1000)
    result = value_itr.get_q_table()
    # Q-table keys are (state, action) pairs
    for (state, action), value in result.items():
        print(state, action, value)

# Scratch probes kept for reference:
# state = GridWorldState(1, 1)
# out = grid_mdp_test.next_possible_states(state, Dir.UP)
# print([str(k) for k in out.keys()])
# print(out.values())
#
# all_states = grid_mdp_test.get_all_possible_states()
# print([str(state) for state in all_states])
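
# Hedged sketch: the commented-out probe above suggests next_possible_states
# returns a {next_state: probability} dict. Under that assumption, a one-step
# Bellman backup (the core update that value iteration repeats to convergence)
# would look like this; the function name and the V/reward arguments are
# illustrative, not taken from the repo.
def one_step_backup(transition_probs, reward, V, gamma=0.99):
    # Expected discounted return of one action: r + gamma * E[V(s')]
    return reward + gamma * sum(p * V[s] for s, p in transition_probs.items())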
    def visualize_q_value_error(self, folder, mdp, episodes):
        """
        Create graphs showing the difference in Q-value between the true Q-value (as determined by value iteration)
        and the Q-values learned by the agents corresponding to the ensemble given by 'folder'

        :param folder: string indicating the folder containing the q-values of interest
        :param mdp: MDP (required for value iteration)
        :param episodes: list of episode numbers for which the errors will be calculated
        """
        # Locate file
        if self.experiment:
            q_value_folder = os.path.join(self.experiment.results_dir, folder)
        else:
            q_value_folder = os.path.join(self.results_dir, folder)
        if not os.path.exists(q_value_folder):
            raise ValueError('Q-value file ' + str(q_value_folder) +
                             ' does not exist')

        # Read in dataframe (on_bad_lines replaced error_bad_lines in pandas 1.3+)
        q_value_df = pd.read_csv(os.path.join(q_value_folder, 'q_values.csv'),
                                 header=0,
                                 on_bad_lines='skip')

        # Create df holding true q-values from value iteration
        vi = ValueIteration(mdp)
        vi.run_value_iteration()
        true_q_values = vi.get_q_table()
        true_q_value_lists = []
        for (state, action), value in true_q_values.items():
            true_q_value_lists.append([state, action, value])
        names = ['state', 'action', 'true q_value']
        true_q_value_df = pd.DataFrame(true_q_value_lists, columns=names)
        true_q_value_df['state'] = true_q_value_df['state'].astype(str)
        true_q_value_df['action'] = true_q_value_df['action'].astype(str)

        # Join dfs and calculate errors
        joined_df = q_value_df.merge(true_q_value_df, on=['state', 'action'])
        joined_df['error'] = joined_df['q_value'] - joined_df['true q_value']

        # Convert the stringified state back to a literal tuple
        joined_df['state'] = joined_df['state'].apply(ast.literal_eval)

        # Create 2d array of all states, top row (y = 11) first so array
        # indices match the grid layout when rendered as a heatmap
        states = [[(x, y) for x in range(1, 12)] for y in range(11, 0, -1)]

        # Graph the errors for each ensemble and episode number given
        for episode in episodes:
            for key in joined_df['ensemble_key'].unique():
                print(key)
                if key != "ground":
                    abstr_type = key.split(',')[0].strip('(')
                    if 'PI' in abstr_type:
                        abstr_type = 'Pi*'
                    elif 'A_STAR' in abstr_type:
                        abstr_type = 'A*'
                    elif 'Q_STAR' in abstr_type:
                        abstr_type = 'Q*'
                    try:
                        num = key.split(',')[1].strip(')')
                    except IndexError:
                        raise ValueError('Malformed ensemble key: ' + str(key))
                else:
                    abstr_type = 'ground'
                    num = ''
                fig, axs = plt.subplots(2, 2)

                # Subset for the given ensemble/episode
                temp_df = joined_df.loc[(joined_df['episode'] == episode)
                                        & (joined_df['ensemble_key'] == key)]

                # Average error across all ensembles
                temp_df = temp_df[['state', 'action', 'error']]
                temp_df = temp_df.groupby(['state', 'action'],
                                          as_index=False).mean()

                # Build a 2d array of errors for each action, where array
                # position corresponds to square location (heatmap-able form)
                actions = ['Dir.UP', 'Dir.DOWN', 'Dir.LEFT', 'Dir.RIGHT']
                error_dict = {}
                for grid_row in states:
                    action_rows = {action: np.array([]) for action in actions}
                    for state in grid_row:
                        state_df = temp_df.loc[temp_df['state'] == state]
                        for action in actions:
                            action_df = state_df.loc[state_df['action'] ==
                                                     action]
                            # Default to 0 where no error was recorded
                            error = (0 if action_df.empty
                                     else action_df['error'].values[0])
                            action_rows[action] = np.append(
                                action_rows[action], error)
                    # Stack each completed row under the rows built so far
                    for action in actions:
                        if action not in error_dict:
                            error_dict[action] = action_rows[action]
                        else:
                            error_dict[action] = np.vstack(
                                [error_dict[action], action_rows[action]])

                # Graph figures: one heatmap panel per action
                fig.suptitle(abstr_type + ', mdp' + num + ', episode ' +
                             str(episode))
                panels = [('Up', 'Dir.UP', 0, 0), ('Down', 'Dir.DOWN', 0, 1),
                          ('Left', 'Dir.LEFT', 1, 0),
                          ('Right', 'Dir.RIGHT', 1, 1)]
                for panel_title, action, row, col in panels:
                    axs[row, col].set_title(panel_title)
                    im = axs[row, col].imshow(error_dict[action],
                                              norm=MidpointNormalize(
                                                  vmin=-1, vmax=0,
                                                  midpoint=0),
                                              cmap=plt.get_cmap('bwr'))
                cbar_ax = fig.add_axes([0.85, 0.15, 0.01, 0.7])
                fig.colorbar(im, cax=cbar_ax)

                # Save figure; strip the '*' from the abstraction name and any
                # leading space from the mdp number when building the file name
                file_name = os.path.join(
                    q_value_folder,
                    abstr_type.rstrip('*') + '_mdp' + num.strip() + '_ep' +
                    str(episode))
                plt.savefig(file_name)
                fig.clf()
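
# MidpointNormalize is referenced above but not defined in this snippet. A
# common implementation is the standard matplotlib recipe below, which pins a
# chosen midpoint to the center of the colormap; treat it as an assumed
# stand-in for whatever this repo actually defines.
import matplotlib.colors as mcolors

class MidpointNormalize(mcolors.Normalize):
    def __init__(self, vmin=None, vmax=None, midpoint=None, clip=False):
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def __call__(self, value, clip=None):
        # Piecewise-linear map: vmin -> 0, midpoint -> 0.5, vmax -> 1
        x, y = [self.vmin, self.midpoint, self.vmax], [0, 0.5, 1]
        return np.ma.masked_array(np.interp(value, x, y))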
    def __init__(self,
                 mdp,
                 abstr_dicts=None,
                 num_corrupted_mdps=1,
                 num_agents=10,
                 num_episodes=100,
                 results_dir='exp_results/simple',
                 agent_exploration_epsilon=0.1,
                 agent_learning_rate=0.1,
                 detach_interval=None,
                 prevent_cycles=False,
                 variance_threshold=False,
                 reset_q_value=False):

        self.ground_mdp = mdp
        self.abstr_dicts = abstr_dicts
        self.num_agents = num_agents
        self.num_corrupted_mdps = num_corrupted_mdps
        self.num_episodes = num_episodes
        self.results_dir = results_dir
        self.agent_exploration_epsilon = agent_exploration_epsilon
        self.agent_learning_rate = agent_learning_rate
        self.detach_interval = detach_interval
        self.prevent_cycles = prevent_cycles
        self.variance_threshold = variance_threshold
        self.reset_q_value = reset_q_value

        # Run VI and get q-table. Used for graphing results
        vi = ValueIteration(mdp)
        vi.run_value_iteration()
        q_table = vi.get_q_table()
        self.vi_table = q_table
        self.vi = vi

        # This will hold the ensembles of agents, keyed by the MDP they run on
        # ('ground' here; the corrupt-MDP ensembles are stored separately below)
        self.agents = {}

        # Create the corrupt MDPs from the provided abstraction dictionaries
        self.corrupt_mdp_dict = {}
        if self.abstr_dicts is not None:
            if not os.path.exists(os.path.join(self.results_dir, 'corrupted')):
                os.makedirs(os.path.join(self.results_dir, 'corrupted'))
            for i in range(len(self.abstr_dicts)):
                abstr_dict = self.abstr_dicts[i]
                for j in range(self.num_corrupted_mdps):
                    # Make a state abstraction that corresponds to given abstraction dictionary
                    s_a = StateAbstraction(abstr_dict=abstr_dict, epsilon=0)
                    abstr_mdp = AbstractMDP(mdp, s_a)
                    self.corrupt_mdp_dict[('explicit errors', i,
                                           j)] = abstr_mdp

        # Create the agents on the ground MDP
        ground_agents = []
        for i in range(self.num_agents):
            temp_mdp = SimpleMDP()
            agent = AbstractionAgent(temp_mdp,
                                     epsilon=agent_exploration_epsilon,
                                     alpha=agent_learning_rate,
                                     decay_exploration=False)
            ground_agents.append(agent)
        self.agents['ground'] = ground_agents

        # Create agents on the corrupt MDPs
        self.corr_agents = {}
        for key in self.corrupt_mdp_dict.keys():
            corr_ensemble = []
            for i in range(self.num_agents):
                # Make an AbstractionAgent from a private copy of the state
                # abstraction belonging to this corrupt MDP
                temp_mdp = SimpleMDP()
                s_a = copy.deepcopy(self.corrupt_mdp_dict[key].state_abstr)
                agent = AbstractionAgent(temp_mdp,
                                         s_a,
                                         epsilon=agent_exploration_epsilon,
                                         alpha=agent_learning_rate,
                                         decay_exploration=False)
                corr_ensemble.append(agent)
            self.corr_agents[key] = corr_ensemble

        # Create another set of agents that will run the detachment algorithm
        # (consumed when detach_interval is set; built unconditionally here)
        self.corr_detach_agents = {}
        for key in self.corrupt_mdp_dict.keys():
            corr_ensemble = []
            for i in range(self.num_agents):
                print('making detach agent', i)
                temp_mdp = SimpleMDP()
                s_a = copy.deepcopy(self.corrupt_mdp_dict[key].state_abstr)
                agent = AbstractionAgent(temp_mdp,
                                         s_a,
                                         epsilon=agent_exploration_epsilon,
                                         alpha=agent_learning_rate,
                                         decay_exploration=False)
                corr_ensemble.append(agent)
            self.corr_detach_agents[key] = corr_ensemble
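
# Hedged usage sketch. The enclosing class name is not shown in this snippet,
# so CorruptionExperiment below is a hypothetical stand-in, and the abstraction
# dictionary contents are illustrative labels rather than real SimpleMDP states.
if __name__ == '__main__':
    example_abstr_dicts = [{'s0': 'abs_0', 's1': 'abs_0', 's2': 'abs_1'}]
    experiment = CorruptionExperiment(SimpleMDP(),
                                      abstr_dicts=example_abstr_dicts,
                                      num_corrupted_mdps=2,
                                      num_agents=5,
                                      num_episodes=200,
                                      detach_interval=50)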