Example #1
	def make_abstraction(self, abstr_type, epsilon, ignore_zeroes=False, threshold=1e-6):
		"""
		Create an abstraction out of the current q-table of the given type with given epsilon
		:return: new_abstr_mdp, a new abstract MDP made from the current q-table, with q-values informed by current
					q-table
		"""
		# Create a state abstraction based on the current q-table
		curr_q_table = self.get_q_table()
		new_abstr = make_abstr(curr_q_table, abstr_type, epsilon=epsilon, ignore_zeroes=ignore_zeroes, threshold=threshold)

		# Update agent's q-table for the new abstract states
		# For each new abstract state, average the state-action values of the constituent states and
		#  make that average the state-action value for the new abstract state
		new_q_table = defaultdict(lambda: 0.0)

		# Carry over all existing ground-state values unchanged
		for key, value in curr_q_table.items():
			new_q_table[key] = value

		# Get possible states of MDP
		possible_states = self.mdp.get_all_possible_states()

		# Make guess at new values for abstract states by averaging state-action values of constituent states
		#  Iterate through all abstract states
		for abstr_state in new_abstr.abstr_dict.values():
			# For each action...
			for action in self.mdp.actions:
				action_val = 0
				map_count = 0
				# ...Get the states that are grouped together and average their state-action values for that action
				for ground_state in possible_states:
					if new_abstr.get_abstr_from_ground(ground_state).data == abstr_state:
						action_val += curr_q_table[(ground_state, action)]
						map_count += 1
				if map_count != 0:
					# Since abstr_state is just an integer, we have to make a State out of it
					new_q_table[(State(data=abstr_state, is_terminal=False), action)] = action_val / map_count

		# Assign this updated q-table to the agent's q-table
		self._q_table = new_q_table

		# Update the agent's MDP to be the AbstractMDP generated by combining the state abstraction with the current
		#  MDP
		new_abstr_mdp = AbstractMDP(self.mdp, new_abstr)
		self.mdp = new_abstr_mdp

		# Return the number of abstract states and the number of ground states mapped to abstract states.
		# Dictionary keys are unique, so the ground-state count is just the number of entries in the mapping.
		ground_state_count = len(new_abstr.abstr_dict.keys())
		unique_abstr_state_count = len(set(new_abstr.abstr_dict.values()))
		return unique_abstr_state_count, ground_state_count
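# Standalone illustrative sketch (not part of the class above; all names are hypothetical):
# the averaging step used for the abstract states' Q-values, shown with plain dicts.
toy_q_table = {('g1', 'up'): 1.0, ('g2', 'up'): 3.0}
toy_abstr_dict = {'g1': 0, 'g2': 0}  # both ground states map to abstract state 0
constituent_vals = [toy_q_table[(g, 'up')] for g, abs_s in toy_abstr_dict.items() if abs_s == 0]
abstract_q_up = sum(constituent_vals) / len(constituent_vals)  # 2.0, the mean of the constituents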
def apply_noise_from_distribution(ground_mdp,
                                  abstr_type,
                                  approximation_epsilon=0.0,
                                  distribution=None,
                                  distribution_parameters=None,
                                  per_state_distribution=None,
                                  per_state_parameters=None,
                                  seed=None):
    """
    Run value iteration on ground MDP to get true abstraction of given type. Then apply noise by sampling from given
    distribution and add the sampled value to the Q-values. Then create approximate abstraction by grouping together
    based on given epsilon
    :param ground_mdp: the ground mdp with no abstractions
    :param abstr_type: what type of abstraction is desired
    :param distribution: a scipy distribution
    :param distribution_parameters: a dictionary of parameters passed to the distribution when sampling
    :param approximation_epsilon: the epsilon used in making approximate abstractions
    :param per_state_distribution: dictionary mapping states to distributions
    :param per_state_parameters: dictionary mapping states to parameters used for their per-state distributions
    """
    # Get Q-table
    vi = ValueIteration(ground_mdp)
    vi.run_value_iteration()
    q_table = vi.get_q_table()

    # Apply noise sampled from distribution to Q-table
    for (state, action), value in q_table.items():
        #print(state, action, value)
        # If there is a specific per-state distribution, apply that
        if per_state_distribution:
            if state in per_state_distribution.keys():
                dist = per_state_distribution[state]
                args = per_state_parameters[state]
                noise = dist.rvs(**args)
                q_table[(state, action)] += noise
        # Otherwise apply mdp-wide distribution
        else:
            noise = distribution.rvs(**distribution_parameters)
            #print(noise)
            q_table[(state, action)] += noise
        #print('New value:', q_table[(state, action)],'\n')

    # Make new epsilon-approximate abstraction
    new_s_a = make_abstr(q_table,
                         abstr_type,
                         epsilon=approximation_epsilon,
                         combine_zeroes=True,
                         threshold=0.0,
                         seed=seed)

    # Create abstract MDP with this corrupted s_a
    corr_mdp = AbstractMDP(ground_mdp, new_s_a)

    return corr_mdp
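# Hedged usage sketch for apply_noise_from_distribution. The per-state arguments are two
# parallel dicts keyed by ground state; 'state_a' below is a hypothetical placeholder for a
# real ground-state object, and scipy's norm.rvs accepts the loc/scale keywords that get
# expanded via dist.rvs(**args) above.
from scipy.stats import norm

example_per_state_distribution = {'state_a': norm}
example_per_state_parameters = {'state_a': {'loc': 0.0, 'scale': 0.05}}
# corr_mdp = apply_noise_from_distribution(ground_mdp, abstr_type,
#                                          approximation_epsilon=0.01,
#                                          per_state_distribution=example_per_state_distribution,
#                                          per_state_parameters=example_per_state_parameters,
#                                          seed=42)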
    def make_abstr_mdp(self, abstr_type, abstr_epsilon=0.0, seed=None):
        """
		Create an abstract MDP with the given abstraction type
		:param abstr_type: the type of abstraction
		:param abstr_epsilon: the epsilon threshold for approximate abstraction
		:return: abstr_mdp
		"""
        vi = ValueIteration(self)
        vi.run_value_iteration()
        q_table = vi.get_q_table()
        s_a = make_abstr(q_table, abstr_type, abstr_epsilon, seed=seed)
        abstr_mdp = AbstractMDP(self, s_a)
        return abstr_mdp
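    # Hedged usage sketch for make_abstr_mdp (assumes `mdp` is an instance of the class that
    # defines this method and that Abstr_type is importable, as in the later examples):
    # abstr_mdp = mdp.make_abstr_mdp(Abstr_type.Q_STAR, abstr_epsilon=0.01, seed=1234)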
    def make_online_abstraction(self,
                                abstr_type,
                                epsilon=1e-12,
                                combine_zeroes=False,
                                seed=None):
        """
        Convert the existing Q-table into the given type of abstraction
        :param abstr_type: type of abstraction to make
        :param epsilon: approximation epsilon for making abstraction
        :param combine_zeroes: if true, all states with value 0 are combined
        :param threshold: minimum threshold for what counts as a 0 state
        :param seed: ignore
        """
        approx_s_a = make_abstr(self._q_table,
                                abstr_type,
                                epsilon=epsilon,
                                combine_zeroes=combine_zeroes,
                                seed=seed)

        self.s_a = approx_s_a
        zero_count = 0
        for state in self.s_a.get_all_abstr_states():
            is_zero = True
            for a in self.mdp.actions:
                if self._q_table[(state, a)] != 0:
                    is_zero = False
            if is_zero:
                zero_count += 1
        print('zero count:', zero_count)

        # Add abstract states to transition records
        for abstr_state in self.s_a.get_all_abstr_states():
            ground_states = self.s_a.get_ground_from_abstr(abstr_state)
            self.abstr_state_action_pair_counts[abstr_state] = {}
            for action in self.mdp.actions:
                self.abstr_state_action_pair_counts[abstr_state][action] = 0
                for ground in ground_states:
                    self.abstr_state_action_pair_counts[abstr_state][
                        action] += self.state_action_pair_counts[ground][
                            action]
                    self.abstr_update_record[abstr_state][action].extend(
                        self.ground_update_record[ground][action])
        self.group_dict = self.reverse_abstr_dict(self.s_a.abstr_dict)
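# Standalone illustrative sketch (hypothetical names): the transition-record aggregation above
# just sums the ground-level state-action counts of every ground state mapped to an abstract state.
ground_counts = {'g1': {'up': 2}, 'g2': {'up': 5}}   # per-ground-state visit counts
grouped = {'A': ['g1', 'g2']}                        # abstract state -> constituent ground states
abstract_counts = {
    abs_s: {'up': sum(ground_counts[g]['up'] for g in grounds)}
    for abs_s, grounds in grouped.items()
}
# abstract_counts == {'A': {'up': 7}}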
Example #5
    # Results: 12 abstract states; 9 have 2 ground states and
    # the other 3 have 3
    """

    # Taxi MDP tests

    # Q-star, epsilon = 0
    mdp = TaxiMDP(slip_prob=0.0, gamma=0.99)
    vi = ValueIteration(mdp)
    vi.run_value_iteration()
    q_table = vi.get_q_table()

    for key in q_table.keys():
        print(key[0], key[1], q_table[key])
    abstr = make_abstr(q_table, Abstr_type.Q_STAR)

    # Count the number of ground states covered by the abstraction
    state_count = len(abstr.get_abstr_dict().keys())
    print(state_count)
    print(abstr.get_abstr_dict())

    # Write results of q-table to file
    with open('test_abstr_results.txt', 'w') as f:
        for key in q_table.keys():
            to_write = str(key[0]) + ' ' + str(key[1]) + ' ' + str(
                q_table[key]) + '\n'
            f.write(to_write)
Example #6
def main():

    # Testing what a Q* abstraction looks like in
    # four rooms

    # Make MDP and train an agent in it
    grid_mdp = GridWorldMDP(height=9, width=9, slip_prob=0.0, gamma=0.99)
    agent = Agent(grid_mdp)

    # Train the agent for 100,000 steps
    trajectory = []
    for i in range(100000):
        if i % 1000 == 0:
            print("epsilon, alpha:", agent._epsilon, agent._alpha)
        current_state, action, next_state, _ = agent.explore()
        trajectory.append(current_state)

    already_printed = []
    for state in trajectory:
        if state not in already_printed:
            already_printed.append(state)

    # Print the action values learned at each state
    for state in already_printed:
        print("values learned at state", state)
        print_action_values(agent.get_action_values(state))
        print()

    # Make an abstraction from the agent's q-table
    state_abstr = make_abstr(agent.get_q_table(),
                             Abstr_type.Q_STAR,
                             epsilon=0.05)
    print(state_abstr)

    # Testing that Pi* abstraction works
    '''
	# Create toy q_table to build abstraction from 
	q_table = {(GridWorldState(1,1), Dir.UP): 0.9,
				(GridWorldState(1,1), Dir.DOWN): 0.8,
				(GridWorldState(1,1), Dir.LEFT): 0.7,
				(GridWorldState(1,1), Dir.RIGHT): 0.6,

				# Same optimal action and action value as (1,1)
				(GridWorldState(1,2), Dir.UP): 0.9,
				(GridWorldState(1,2), Dir.DOWN): 0.0,
				(GridWorldState(1,2), Dir.LEFT): 0.2,
				(GridWorldState(1,2), Dir.RIGHT): 0.5,

				# val(UP) = 0.9 but val(DOWN) = 0.91
				(GridWorldState(2,2), Dir.UP): 0.9,
				(GridWorldState(2,2), Dir.DOWN): 0.91,
				(GridWorldState(2,2), Dir.LEFT): 0.8,
				(GridWorldState(2,2), Dir.RIGHT): 0.9,

				# val(LEFT) = 0.90000000001, max val (so optimal action is LEFT, not UP)
				(GridWorldState(2,1), Dir.UP): 0.9,
				(GridWorldState(2,1), Dir.DOWN): 0.9,
				(GridWorldState(2,1), Dir.LEFT): 0.90000000001,
				(GridWorldState(2,1), Dir.RIGHT): 0.7,

				# val(UP) = 1000, max val
				(GridWorldState(3,1), Dir.UP): 1000,
				(GridWorldState(3,1), Dir.DOWN): 0.89,
				(GridWorldState(3,1), Dir.LEFT): 0.89,
				(GridWorldState(3,1), Dir.RIGHT): 0.89}
	
	state_abstr = make_abstr(q_table, Abstr_type.PI_STAR)
	print("(1,1), (1,2), and (3,1) should all get mapped together")
	print(state_abstr)
	'''

    # Testing that A* abstraction works
    '''
	# Create toy q_table to build abstraction from 
				# Optimal action/val is UP/0.9
	q_table = {(GridWorldState(1,1), Dir.UP): 0.9,
				(GridWorldState(1,1), Dir.DOWN): 0.8,
				(GridWorldState(1,1), Dir.LEFT): 0.7,
				(GridWorldState(1,1), Dir.RIGHT): 0.6,

				# Same optimal action and action value as (1,1)
				(GridWorldState(1,2), Dir.UP): 0.9,
				(GridWorldState(1,2), Dir.DOWN): 0.0,
				(GridWorldState(1,2), Dir.LEFT): 0.2,
				(GridWorldState(1,2), Dir.RIGHT): 0.5,

				# val(UP) = 0.9 but val(DOWN) = 0.91
				(GridWorldState(2,2), Dir.UP): 0.9,
				(GridWorldState(2,2), Dir.DOWN): 0.91,
				(GridWorldState(2,2), Dir.LEFT): 0.8,
				(GridWorldState(2,2), Dir.RIGHT): 0.9,

				# val(UP) = 0.89, max val
				(GridWorldState(2,1), Dir.UP): 0.89,
				(GridWorldState(2,1), Dir.DOWN): 0.88,
				(GridWorldState(2,1), Dir.LEFT): 0.8,
				(GridWorldState(2,1), Dir.RIGHT): 0.7,

				# val(UP) = 0.93, max val 
				(GridWorldState(3,1), Dir.UP): 0.93,
				(GridWorldState(3,1), Dir.DOWN): 0.89,
				(GridWorldState(3,1), Dir.LEFT): 0.89,
				(GridWorldState(3,1), Dir.RIGHT): 0.89}
	state_abstr = make_abstr(q_table, Abstr_type.A_STAR)
	print("Epsilon = 0. (1,1) and (1,2) should be mapped together")
	print(state_abstr)

	state_abstr = make_abstr(q_table, Abstr_type.A_STAR, epsilon=0.015)
	print("Epsilon = 0.015. (1,1), (1,2), and (2,1) should all be mapped together")
	print(state_abstr)

	state_abstr = make_abstr(q_table, Abstr_type.A_STAR, epsilon=0.031)
	print("Epsilon = 0.031. (1,1), (1,2), (2,1), (3,1) should all be mapped together")
	print(state_abstr)
	'''

    # Testing that Q* abstraction function works
    '''
	# Create toy q_table to build the abstraction from
	q_table = {(GridWorldState(1,1), Dir.UP): 1.0,
				(GridWorldState(1,1), Dir.DOWN): 2.5,
				(GridWorldState(1,1), Dir.LEFT): 2.3,
				(GridWorldState(1,1), Dir.RIGHT): 5.0,

				(GridWorldState(2,1), Dir.UP): 1.0,
				(GridWorldState(2,1), Dir.DOWN): 2.5,
				(GridWorldState(2,1), Dir.LEFT): 2.3,
				(GridWorldState(2,1), Dir.RIGHT): 5.05,

				(GridWorldState(2,2), Dir.UP): 1.1,
				(GridWorldState(2,2), Dir.DOWN): 2.4,
				(GridWorldState(2,2), Dir.LEFT): 2.3,
				(GridWorldState(2,2), Dir.RIGHT): 4.8,

				(GridWorldState(1,2), Dir.UP): 1.3,
				(GridWorldState(1,2), Dir.DOWN): 2.0,
				(GridWorldState(1,2), Dir.LEFT): 2.0,
				(GridWorldState(1,2), Dir.RIGHT): 4.8
				}
	state_abstr = make_abstr(q_table, Abstr_type.Q_STAR)
	print("Epsilon = 0. No shapes should be mapped together.")
	print(str(state_abstr))

	state_abstr = make_abstr(q_table, Abstr_type.Q_STAR, epsilon=0.3)
	print("Epsilon = 0.3. (1,1), (2,1), (2,2) should all be mapped together")
	print(str(state_abstr))

	state_abstr = make_abstr(q_table, Abstr_type.Q_STAR, epsilon=0.1)
	print("Epsilon = 0.1. (1,1), (2,1) should be mapped together. (2,2) should not.")
	print(str(state_abstr))

	state_abstr = make_abstr(q_table, Abstr_type.Q_STAR, epsilon=0.5)
	print("Epsilon = 0.5. (1,1), (2,1), (1,2), (2,2) should all be mapped together")
	print(str(state_abstr))
	'''

    # Testing Q-learning in abstract Four Rooms
    '''
	# Map all the states in the bottom-right room to the same abstract state 
	abstr_dict = {} 
	for i in range(6,12):
		for j in range(1,6):
			abstr_dict[GridWorldState(i,j)] = 'oneroom'

	state_abstr = StateAbstraction(abstr_dict)

	abstr_mdp = AbstractGridWorldMDP(height=11, 
										width=11,
										slip_prob=0.0,
										gamma=0.95,
										build_walls=True,
										state_abstr=state_abstr)
	agent = Agent(abstr_mdp)

	trajectory = [] 
	for i in range(100000):
		#print("At step", i)
		#print("parameters are", agent._alpha, agent.mdp.gamma)
		current_state, action, next_state, _ = agent.explore()
		#print("At", str(current_state), "took action", action, "got to", str(next_state))
		#print("Values learned for", str(current_state), "is")
		#print_action_values(agent.get_action_values(current_state))
		trajectory.append(current_state)
		#print()

	already_printed = [] 
	for state in trajectory:
		if state not in already_printed:
			print("values learned at state", state)
			print_action_values(agent.get_action_values(state))
			already_printed.append(state)

	agent.reset_to_init()
	for i in range(25):
		current_state, action, next_state = agent.apply_best_action()
		print('At', str(current_state), 'taking action', str(action), 'now at', str(next_state))
	'''

    # Testing Q-learning in toy abstract MDP
    '''
	# Simple abstraction in a grid where all states above the start-to-goal
	# diagonal are grouped together and all states below that diagonal
	# are grouped together 
	toy_abstr = StateAbstraction({GridWorldState(2,1): 'up', 
									GridWorldState(3,1): 'up',
									GridWorldState(3,2): 'up',
									GridWorldState(4,1): 'up',
									GridWorldState(4,2): 'up',
									GridWorldState(4,3): 'up',
									GridWorldState(5,1): 'up',
									GridWorldState(5,2): 'up',
									GridWorldState(5,3): 'up',
									GridWorldState(5,4): 'up',
									GridWorldState(1,2): 'right',
									GridWorldState(1,3): 'right',
									GridWorldState(1,4): 'right',
									GridWorldState(1,5): 'right',
									GridWorldState(2,3): 'right',
									GridWorldState(2,4): 'right',
									GridWorldState(2,5): 'right',
									GridWorldState(3,4): 'right',
									GridWorldState(3,5): 'right',
									GridWorldState(4,5): 'right'})
	#print("states covered by abstraction are", toy_abstr.abstr_dict.keys())
	

	abstr_mdp = AbstractGridWorldMDP(height=5, 
							width=5, 
							slip_prob=0.0, 
							gamma=0.95, 
							build_walls=False,
							state_abstr=toy_abstr)

	#print(abstr_mdp.state_abstr.get_abstr_from_ground(GridWorldState(1,1)))
	agent = Agent(abstr_mdp)
	
	trajectory = [] 
	for i in range(10000):
		#print("At step", i)
		#print("parameters are", agent._alpha, agent.mdp.gamma)
		current_state, action, next_state, _ = agent.explore()
		#print("At", str(current_state), "took action", action, "got to", str(next_state))
		#print("Values learned for", str(current_state), "is")
		#print_action_values(agent.get_action_values(current_state))
		trajectory.append(current_state)
		#print()

	already_printed = [] 
	for state in trajectory:
		if state not in already_printed:
			print("values learned at state", state)
			print_action_values(agent.get_action_values(state))
			already_printed.append(state)
	'''

    # Testing both epsilon-greedy and application of best learned
    # policy in ground MDP
    '''
	grid_mdp = GridWorldMDP(height=9, width=9, slip_prob=0.0, gamma=0.95, build_walls=True)

	agent = Agent(grid_mdp)
	#agent.set_current_state(GridWorldState(1,1))

	print(grid_mdp.goal_location)

	# Testing if epsilon-greedy policy works properly 
	trajectory = [] 
	for i in range(10000):
		#print("At step", i)
		#print("parameters are", agent._alpha, agent.mdp.gamma)
		current_state, action, next_state, _ = agent.explore()
		#print("At", str(current_state), "took action", action, "got to", str(next_state))
		#print("Values learned for", str(current_state), "is")
		#print_action_values(agent.get_action_values(current_state))
		trajectory.append(current_state)
		#print()

	#print("Went through the following states:")
	#for state in trajectory:
	#	print(str(state))
	already_printed = [] 
	for state in trajectory:
		if state not in already_printed:
			print("values learned at state", state)
			print_action_values(agent.get_action_values(state))
			already_printed.append(state)
	#print(grid_mdp.walls)

	agent.reset_to_init()

	for i in range(25):
		current_state, action, next_state = agent.apply_best_action()
		print('At', str(current_state), 'taking action', str(action), 'now at', str(next_state))
	'''

    # Testing a few trajectories to make sure the q-table updates
    # properly
    '''
	test_trajectory = [Dir.UP, Dir.RIGHT, Dir.UP, Dir.RIGHT]
	for i in range(5):
		apply_trajectory(agent, test_trajectory)
		agent.set_current_state(GridWorldState(9,9))

	test_trajectory = [Dir.RIGHT, Dir.RIGHT, Dir.UP, Dir.UP]
	apply_trajectory(agent, test_trajectory)
	agent.set_current_state(GridWorldState(9,9))

	test_trajectory = [Dir.UP, Dir.UP, Dir.RIGHT, Dir.RIGHT]
	apply_trajectory(agent, test_trajectory)
	'''

    # Testing motion, reward at goal state, and reset to
    # initial state at terminal state
    '''
	agent = Agent(grid_mdp, go_up_right)
	for i in range(30):
		agent.act()
	print(grid_mdp.walls)
	'''

    # Testing getter for best action/value given state
    '''
	agent = Agent(grid_mdp, go_right, alpha=0.5)
	current_state = agent.get_current_state() 
	test_action = Dir.UP

	# Set q_value for init_state, Dir.UP = 1.0
	agent._set_q_value(current_state, test_action, 1.0)

	# Should give Dir.UP, 1.0 
	print("should give (Dir.UP, 1.0)", agent.get_best_action_value_pair(current_state))

	# Go right by one 
	agent.act()
	print("Currently at", agent.get_current_state())
	# Should give random action with value = 0 
	print("Should give (random_action, 0.0)", agent.get_best_action_value_pair(agent.get_current_state()))
	# Update q-values of this state
	agent._set_q_value(agent.get_current_state(), Dir.UP, -1.0)
	agent._set_q_value(agent.get_current_state(), Dir.DOWN, -1.0)
	agent._set_q_value(agent.get_current_state(), Dir.LEFT, -1.0)
	agent._set_q_value(agent.get_current_state(), Dir.RIGHT, 0.1)
	# Should give Dir.RIGHT, 0.1
	print("Should give (Dir.RIGHT, 0.1)", agent.get_best_action_value_pair(agent.get_current_state()))

	print()
	# Checking that all values were updated properly
	for action in agent.mdp.actions:
		print("action:q-value = ", action, ":", agent.get_q_value(agent.get_current_state(), action))
	'''

    # Testing single instance of the act, update flow
    # Start agent at (10,11), go one right, get reward,
    # check that update happened
    '''
from MDP.ValueIterationClass import ValueIteration
from resources.AbstractionTypes import Abstr_type
from resources.AbstractionCorrupters import make_corruption
from resources.AbstractionMakers import make_abstr

import numpy as np

# Number of states to corrupt
STATE_NUM = 20

# Create abstract MDP
mdp = GridWorldMDP()
vi = ValueIteration(mdp)
vi.run_value_iteration()
q_table = vi.get_q_table()
state_abstr = make_abstr(q_table, Abstr_type.PI_STAR)
abstr_mdp = AbstractMDP(mdp, state_abstr)

# Randomly select our list of states and print them out
states_to_corrupt = np.random.choice(mdp.get_all_possible_states(),
                                     size=STATE_NUM,
                                     replace=False)
for state in states_to_corrupt:
    print(state)

# Create a corrupt MDP
corr_mdp = make_corruption(abstr_mdp, states_to_corrupt)

for state in states_to_corrupt:
    print(abstr_mdp.get_abstr_from_ground(state),
          corr_mdp.get_abstr_from_ground(state))
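# Hedged follow-up sketch: count how many of the sampled states actually changed their abstract
# label under the corruption (assumes abstract states compare sensibly with !=).
changed = sum(1 for state in states_to_corrupt
              if abstr_mdp.get_abstr_from_ground(state) != corr_mdp.get_abstr_from_ground(state))
print('States with a changed abstract mapping:', changed, 'of', STATE_NUM)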
Example #8
    for key in q_table:
        print(key[0], key[1], q_table[key])


if __name__ == '__main__':

    # GridWorld

    # Make ground MDP
    mdp = GridWorldMDP(slip_prob=0.0)
    # Run VI to get q-table
    vi = ValueIteration(mdp)
    vi.run_value_iteration()
    q_table = vi.get_q_table()
    # Make state abstractions
    q_star_abstr = make_abstr(q_table, Abstr_type.Q_STAR)
    a_star_abstr = make_abstr(q_table, Abstr_type.A_STAR)
    pi_star_abstr = make_abstr(q_table, Abstr_type.PI_STAR)
    # Make abstract MDPs - NOTE THIS CLASS HAS BEEN DEPRECATED DO NOT USE
    q_mdp = AbstractGridWorldMDP(state_abstr=q_star_abstr)
    a_mdp = AbstractGridWorldMDP(state_abstr=a_star_abstr)
    pi_mdp = AbstractGridWorldMDP(state_abstr=pi_star_abstr)

    # This is the type of abstract MDP that should be used instead
    q2_mdp = AbstractMDP(mdp, state_abstr=q_star_abstr)

    print("VALUE OF OPTIMAL POLICY")
    print_q_table(q_table)

    print("\n\n\nQ* ABSTR")
    print(q_star_abstr)
Example #9
from resources.AbstractionTypes import Abstr_type
from resources.AbstractionMakers import make_abstr
from MDP.ValueIterationClass import ValueIteration
from scipy.stats import skewnorm, norm
import numpy as np

if __name__ == '__main__':
    np.random.seed(1234)
    args = {'loc': 0, 'scale': 0.01}
    mdp = GridWorldMDP()
    abstr_type = Abstr_type.PI_STAR
    corr_abstr_mdp = apply_noise_from_distribution(mdp,
                                                   abstr_type,
                                                   distribution=norm,
                                                   distribution_parameters=args,
                                                   approximation_epsilon=0.005,
                                                   seed=124)

    # Get true abstraction to compare it with
    vi = ValueIteration(mdp)
    vi.run_value_iteration()
    q_table = vi.get_q_table()
    true_abstr = make_abstr(q_table, abstr_type, seed=1234)

    corr_abstr_dict = corr_abstr_mdp.state_abstr.abstr_dict
    true_abstr_dict = true_abstr.abstr_dict
    for state in true_abstr_dict.keys():
        print(state)
        print('True abstr state', true_abstr_dict[state])
        print('Corr abstr state', corr_abstr_dict[state])
                best_action_intersect = list(
                    set(best_actions_1) & set(best_actions_2))
                if len(best_action_intersect) == 0:
                    return False
    return True


def print_policy(policy):
    '''
    Print the policy
    '''
    for key in policy.keys():
        print(key, policy[key])


if __name__ == '__main__':
    # Test that the optimal ground policy for FourRooms is representable in the
    # abstraction given by A*

    # Get optimal ground policy for FourRooms
    four_rooms = GridWorldMDP(slip_prob=0.0, gamma=0.99)
    vi = ValueIteration(four_rooms)
    vi.run_value_iteration()
    optimal_policy = vi.get_optimal_policy()
    #print_policy(optimal_policy)

    # Get A* abstraction for FourRooms
    abstr = make_abstr(vi.get_q_table(), Abstr_type.A_STAR)

    print(is_optimal_policy_representable(vi, optimal_policy, abstr))
# # abs_agent = Agent(abstr_grid_mdp)
# # abs_g_viz = AbstractGridWorldVisualizer(abstr_grid_mdp,abs_agent)
# # #abs_g_viz.displayAbstractMDP()
# # for i in range(100000):
# #     if i % 1000 == 0:
# #         print("epsilon, alpha:", abs_agent._epsilon, abs_agent._alpha)
# #     current_state, action, next_state,_  = abs_agent.explore()
# #
# # abs_g_viz.visualizeLearnedPolicy()

# Q-STAR - USING VI
mdp = GridWorldMDP(slip_prob=0, gamma=0.99)
vi = ValueIteration(mdp)
vi.run_value_iteration()
q_table = vi.get_q_table()
q_star_abstr = make_abstr(q_table, Abstr_type.Q_STAR, epsilon=0.01)
abstr_grid_mdp = AbstractGridWorldMDP(state_abstr=q_star_abstr)
abs_agent = Agent(abstr_grid_mdp)
abs_g_viz = AbstractGridWorldVisualizer(abstr_grid_mdp, abs_agent)
#abs_g_viz.displayAbstractMDP()
for i in range(100000):
    if i % 1000 == 0:
        print("epsilon, alpha:", abs_agent._epsilon, abs_agent._alpha)
    current_state, action, next_state, _ = abs_agent.explore()

abs_g_viz.visualizeLearnedPolicy()
'''
#A-STAR - USING VI
mdp = GridWorldMDP(slip_prob=0, gamma=0.99)
vi = ValueIteration(mdp)
vi.run_value_iteration()