Example #1
    def check_mdp(mdp):
        # Nested helper from test_numpy_conversion (see Example #10); it
        # relies on `self` from the enclosing unittest.TestCase method.
        new_mdp = GridworldMdp.from_numpy_input(
            *mdp.convert_to_numpy_input())
        self.assertEqual(new_mdp.height, mdp.height)
        self.assertEqual(new_mdp.width, mdp.width)
        self.assertEqual(new_mdp.walls, mdp.walls)
        self.assertEqual(new_mdp.rewards, mdp.rewards)
        self.assertEqual(new_mdp.start_state, mdp.start_state)
Example #2
    def test_constructor_invalid_inputs(self):
        # Height and width must be at least 2.
        with self.assertRaises(AssertionError):
            mdp = GridworldMdp(['X', 'X', 'X'])
        with self.assertRaises(AssertionError):
            mdp = GridworldMdp([['X', 'X', 'X']])

        with self.assertRaises(AssertionError):
            # Borders must be present.
            mdp = GridworldMdp(['  A', '3X ', '  1'])

        with self.assertRaises(AssertionError):
            # There can't be more than one agent.
            mdp = GridworldMdp(['XXXXX', 'XA 3X', 'X3 AX', 'XXXXX'])

        with self.assertRaises(AssertionError):
            # There must be one agent.
            mdp = GridworldMdp(['XXXXX', 'X  3X', 'X3  X', 'XXXXX'])

        with self.assertRaises(AssertionError):
            # There must be at least one reward.
            mdp = GridworldMdp(['XXXXX', 'XAX X', 'X   X', 'XXXXX'])

        with self.assertRaises(AssertionError):
            # B is not a valid element.
            mdp = GridworldMdp(['XXXXX', 'XB  X', 'X  3X', 'XXXXX'])
Example #3
    def test_random_gridworld_generation(self):
        random.seed(314159)
        mdp = GridworldMdp.generate_random(8, 8, 0, 1)
        self.assertEqual(mdp.height, 8)
        self.assertEqual(mdp.width, 8)
        mdp_string = str(mdp)
        self.assertEqual(mdp_string.count('X'), 28)
        self.assertEqual(mdp_string.count(' '), 34)
        self.assertEqual(mdp_string.count('A'), 1)
        self.assertEqual(mdp_string.count('3'), 1)
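
The expected character counts above follow from the grid geometry: with wall probability 0 only the 8x8 border is walls, and the assertions say the interior holds exactly one agent cell and one reward cell. A quick arithmetic check (plain Python, not part of the test suite):

# Where the magic numbers in the assertions come from.
height, width = 8, 8
border_cells = 2 * width + 2 * (height - 2)   # 28 'X' characters
interior_cells = (height - 2) * (width - 2)   # 36 cells inside the border
empty_cells = interior_cells - 1 - 1          # minus the agent and the reward cell
print(border_cells, empty_cells)              # 28 34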
Example #4
    def setUp(self):
        self.grid1 = [['X', 'X', 'X', 'X', 'X'], ['X', ' ', ' ', 'A', 'X'],
                      ['X', '3', 'X', ' ', 'X'], ['X', ' ', ' ', '1', 'X'],
                      ['X', 'X', 'X', 'X', 'X']]
        self.grid2 = [
            'XXXXXXXXX', 'X9X X  AX', 'X X X   X', 'X       X', 'XXXXXXXXX'
        ]
        self.grid3 = [['X', 'X', 'X', 'X', 'X'], ['X', 3.5, 'X', -10, 'X'],
                      ['X', ' ', '0', ' ', 'X'], ['X', ' ', ' ', 'A', 'X'],
                      ['X', 'X', 'X', 'X', 'X']]

        self.mdp1 = GridworldMdp(self.grid1, living_reward=0)
        self.mdp2 = GridworldMdp(self.grid2, noise=0.2)
        self.mdp3 = GridworldMdp(self.grid3)
Example #5
    def test_gridworld_planner(self):
        def check_model_equivalent(model, query, weights, mdp, num_iters):
            with tf.compat.v1.Session() as sess:
                sess.run(model.initialize_op)
                (qvals, ) = model.compute(['q_values'],
                                          sess,
                                          mdp,
                                          query,
                                          weight_inits=weights)

            agent = OptimalAgent(gamma=model.gamma, num_iters=num_iters)
            for i, proxy in enumerate(model.proxy_reward_space):
                for idx, val in zip(query, proxy):
                    mdp.rewards[idx] = val
                agent.set_mdp(mdp)
                check_qvals_equivalent(qvals[i], agent, mdp)

        def check_qvals_equivalent(qvals, agent, mdp):
            for state in mdp.get_states():
                if mdp.is_terminal(state):
                    continue
                x, y = state
                for action in mdp.get_actions(state):
                    expected_q = agent.qvalue(state, action)
                    action_num = Direction.get_number_from_direction(action)
                    actual_q = qvals[y, x, action_num]
                    # Using softmax, not max, so expect limited accuracy
                    self.assertAlmostEqual(expected_q, actual_q, places=2)

        np.random.seed(1)
        random.seed(1)
        dim = 4
        grid = GridworldMdp.generate_random(8, 8, 0.1, dim)
        mdp = GridworldMdpWithDistanceFeatures(grid)
        mdp.rewards = np.random.randint(-9, 10, size=[dim])
        query = [0, 3]
        other_weights = mdp.rewards[1:3]
        # Use beta_planner = 1000 so that softmax is approximately max
        model = GridworldModel(dim, 0.9, len(query), 2, 1, None, 1, 1000, [],
                               0.1, False, True, 8, 8, 25)
        check_model_equivalent(model, query, other_weights, mdp, 25)
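
The comments in this test rely on the planner using a softmax (soft value iteration) rather than a hard max, so a large beta_planner makes the soft backup nearly equal to the true maximum. A minimal numpy sketch of that fact (an illustration, not GridworldModel's implementation; soft_max is a hypothetical helper name):

import numpy as np

def soft_max(q, beta):
    # Numerically stable (1 / beta) * log(sum(exp(beta * q))).
    m = np.max(q)
    return m + np.log(np.sum(np.exp(beta * (q - m)))) / beta

q = np.array([1.0, 0.5, -2.0, 0.9])   # Q-values for one state's actions
print(np.max(q))                      # 1.0
print(soft_max(q, beta=1.0))          # ~1.94, far from the hard max
print(soft_max(q, beta=1000.0))       # ~1.0, effectively the hard max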
Example #6
            train_inferences.append(inference)


    # Set up env and agent for gridworld
    elif args.mdp_type == 'gridworld':
        # Create train and test MDPs
        test_inferences = []

        for i in range(args.num_test_envs):
            test_grid, test_goals = GridworldMdp.generate_random(
                args,
                height,
                width,
                0.35,
                args.feature_dim,
                None,
                living_reward=-0.01,
                print_grid=False,
                decorrelate=args.decorrelate_test_feat
            )
            mdp = GridworldMdpWithDistanceFeatures(
                test_grid,
                test_goals,
                args,
                dist_scale,
                living_reward=-0.01,
                noise=0
            )
            env = GridworldEnvironment(mdp)
Example #7
class TestGridworld(unittest.TestCase):
    def setUp(self):
        self.grid1 = [['X', 'X', 'X', 'X', 'X'], ['X', ' ', ' ', 'A', 'X'],
                      ['X', '3', 'X', ' ', 'X'], ['X', ' ', ' ', '1', 'X'],
                      ['X', 'X', 'X', 'X', 'X']]
        self.grid2 = [
            'XXXXXXXXX', 'X9X X  AX', 'X X X   X', 'X       X', 'XXXXXXXXX'
        ]
        self.grid3 = [['X', 'X', 'X', 'X', 'X'], ['X', 3.5, 'X', -10, 'X'],
                      ['X', ' ', '0', ' ', 'X'], ['X', ' ', ' ', 'A', 'X'],
                      ['X', 'X', 'X', 'X', 'X']]

        self.mdp1 = GridworldMdp(self.grid1, living_reward=0)
        self.mdp2 = GridworldMdp(self.grid2, noise=0.2)
        self.mdp3 = GridworldMdp(self.grid3)

    def test_str(self):
        expected = '\n'.join([''.join(row) for row in self.grid1])
        self.assertEqual(str(self.mdp1), expected)
        expected = '\n'.join(self.grid2)
        self.assertEqual(str(self.mdp2), expected)
        expected = '\n'.join(['XXXXX', 'XRXNX', 'X 0 X', 'X  AX', 'XXXXX'])
        self.assertEqual(str(self.mdp3), expected)

    def test_constructor_invalid_inputs(self):
        # Height and width must be at least 2.
        with self.assertRaises(AssertionError):
            mdp = GridworldMdp(['X', 'X', 'X'])
        with self.assertRaises(AssertionError):
            mdp = GridworldMdp([['X', 'X', 'X']])

        with self.assertRaises(AssertionError):
            # Borders must be present.
            mdp = GridworldMdp(['  A', '3X ', '  1'])

        with self.assertRaises(AssertionError):
            # There can't be more than one agent.
            mdp = GridworldMdp(['XXXXX', 'XA 3X', 'X3 AX', 'XXXXX'])

        with self.assertRaises(AssertionError):
            # There must be one agent.
            mdp = GridworldMdp(['XXXXX', 'X  3X', 'X3  X', 'XXXXX'])

        with self.assertRaises(AssertionError):
            # There must be at least one reward.
            mdp = GridworldMdp(['XXXXX', 'XAX X', 'X   X', 'XXXXX'])

        with self.assertRaises(AssertionError):
            # B is not a valid element.
            mdp = GridworldMdp(['XXXXX', 'XB  X', 'X  3X', 'XXXXX'])

    def test_start_state(self):
        self.assertEqual(self.mdp1.get_start_state(), (3, 1))
        self.assertEqual(self.mdp2.get_start_state(), (7, 1))
        self.assertEqual(self.mdp3.get_start_state(), (3, 3))

    def test_reward_parsing(self):
        self.assertEqual(self.mdp1.rewards, {(1, 2): 3, (3, 3): 1})
        self.assertEqual(self.mdp2.rewards, {(1, 1): 9})
        self.assertEqual(self.mdp3.rewards, {
            (1, 1): 3.5,
            (2, 2): 0,
            (3, 1): -10
        })

    def test_actions(self):
        a = [Direction.NORTH, Direction.SOUTH, Direction.EAST, Direction.WEST]
        all_acts = set(a)
        exit_acts = set([Direction.EXIT])
        no_acts = set([])

        self.assertEqual(set(self.mdp1.get_actions((0, 0))), no_acts)
        self.assertEqual(set(self.mdp1.get_actions((1, 1))), all_acts)
        self.assertEqual(set(self.mdp1.get_actions((1, 2))), exit_acts)
        self.assertEqual(set(self.mdp2.get_actions((6, 2))), all_acts)
        self.assertEqual(set(self.mdp2.get_actions((3, 1))), all_acts)
        self.assertEqual(set(self.mdp3.get_actions((2, 2))), exit_acts)

    def test_rewards(self):
        grid1_reward_table = {
            ((3, 3), Direction.EXIT): 1,
            ((1, 2), Direction.EXIT): 3
        }
        grid2_reward_table = {((1, 1), Direction.EXIT): 9}
        grid3_reward_table = {
            ((1, 1), Direction.EXIT): 3.5,
            ((2, 2), Direction.EXIT): 0,
            ((3, 1), Direction.EXIT): -10
        }
        self.check_all_rewards(self.mdp1, grid1_reward_table, 0)
        self.check_all_rewards(self.mdp2, grid2_reward_table, -0.01)
        self.check_all_rewards(self.mdp3, grid3_reward_table, -0.01)

    def check_all_rewards(self, mdp, reward_lookup_table, default):
        for state in mdp.get_states():
            for action in mdp.get_actions(state):
                expected = reward_lookup_table.get((state, action), default)
                self.assertEqual(mdp.get_reward(state, action), expected)

    def test_transitions(self):
        n, s = Direction.NORTH, Direction.SOUTH
        e, w = Direction.EAST, Direction.WEST
        exit_action = Direction.EXIT

        # Grid 1: No noise
        result = self.mdp1.get_transition_states_and_probs((1, 3), n)
        self.assertEqual(set(result), set([((1, 2), 1)]))
        result = self.mdp1.get_transition_states_and_probs((1, 2), exit_action)
        self.assertEqual(set(result), set([(self.mdp1.terminal_state, 1)]))
        result = self.mdp1.get_transition_states_and_probs((1, 1), n)
        self.assertEqual(set(result), set([((1, 1), 1)]))

        # Grid 2: Noise of 0.2
        result = set(self.mdp2.get_transition_states_and_probs((1, 2), n))
        self.assertEqual(result, set([((1, 1), 0.8), ((1, 2), 0.2)]))
        result = set(self.mdp2.get_transition_states_and_probs((6, 2), w))
        self.assertEqual(result,
                         set([((5, 2), 0.8), ((6, 1), 0.1), ((6, 3), 0.1)]))
        result = set(self.mdp2.get_transition_states_and_probs((7, 3), e))
        self.assertEqual(result, set([((7, 3), 0.9), ((7, 2), 0.1)]))
        result = set(self.mdp2.get_transition_states_and_probs((5, 1), s))
        self.assertEqual(result,
                         set([((5, 2), 0.8), ((5, 1), 0.1), ((6, 1), 0.1)]))
        result = self.mdp2.get_transition_states_and_probs((3, 1), n)
        self.assertEqual(set(result), set([((3, 1), 1)]))
        result = self.mdp2.get_transition_states_and_probs((1, 1), exit_action)
        self.assertEqual(set(result), set([(self.mdp2.terminal_state, 1)]))

    def test_states_reachable(self):
        def check_grid(grid):
            self.assertEqual(set(grid.get_states()), self.dfs(grid))

        # Some of the states in self.mdp1 are not reachable, since the agent
        # can't move out of a state with reward in it, so don't check grid1.
        for grid in [self.mdp2, self.mdp3]:
            check_grid(grid)

    def dfs(self, grid):
        visited = set()

        def helper(state):
            if state in visited:
                return
            visited.add(state)
            for action in grid.get_actions(state):
                for next_state, _ in grid.get_transition_states_and_probs(
                        state, action):
                    helper(next_state)

        helper(grid.get_start_state())
        return visited

    def test_environment(self):
        env = GridworldEnvironment(self.mdp3)
        self.assertEqual(env.get_current_state(), (3, 3))
        next_state, reward = env.perform_action(Direction.NORTH)
        self.assertEqual(next_state, (3, 2))
        self.assertEqual(reward, -0.01)
        self.assertEqual(env.get_current_state(), next_state)
        self.assertFalse(env.is_done())
        env.reset()
        self.assertEqual(env.get_current_state(), (3, 3))
        self.assertFalse(env.is_done())
        next_state, reward = env.perform_action(Direction.WEST)
        self.assertEqual(next_state, (2, 3))
        self.assertEqual(reward, -0.01)
        self.assertEqual(env.get_current_state(), next_state)
        self.assertFalse(env.is_done())
        next_state, reward = env.perform_action(Direction.NORTH)
        self.assertEqual(next_state, (2, 2))
        self.assertEqual(reward, -0.01)
        self.assertEqual(env.get_current_state(), next_state)
        self.assertFalse(env.is_done())
        next_state, reward = env.perform_action(Direction.EXIT)
        self.assertEqual(next_state, self.mdp3.terminal_state)
        self.assertEqual(reward, 0)
        self.assertEqual(env.get_current_state(), next_state)
        self.assertTrue(env.is_done())
        env.reset()
        self.assertFalse(env.is_done())

    def test_random_gridworld_generation(self):
        random.seed(314159)
        mdp = GridworldMdp.generate_random(8, 8, 0, 1)
        self.assertEqual(mdp.height, 8)
        self.assertEqual(mdp.width, 8)
        mdp_string = str(mdp)
        self.assertEqual(mdp_string.count('X'), 28)
        self.assertEqual(mdp_string.count(' '), 34)
        self.assertEqual(mdp_string.count('A'), 1)
        self.assertEqual(mdp_string.count('3'), 1)
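
The assertions in test_transitions imply a simple noise model: the intended move succeeds with probability 1 - noise, each of the two perpendicular moves happens with probability noise / 2, and any move into a wall leaves the agent where it is. A standalone sketch of that behaviour (an illustration of what the assertions pin down, not GridworldMdp's actual code; noisy_transitions is a hypothetical helper):

from collections import defaultdict

def noisy_transitions(walls, state, move, noise=0.2):
    # walls[y][x] is True for wall cells; state is (x, y); move is a unit
    # vector (dx, dy).
    x, y = state
    dx, dy = move
    candidates = [((dx, dy), 1 - noise),      # intended direction
                  ((dy, dx), noise / 2),      # one perpendicular direction
                  ((-dy, -dx), noise / 2)]    # the other perpendicular direction
    probs = defaultdict(float)
    for (mx, my), p in candidates:
        nx, ny = x + mx, y + my
        # Bumping into a wall keeps the agent in place.
        target = (x, y) if walls[ny][nx] else (nx, ny)
        probs[target] += p
    return dict(probs)

grid2 = ['XXXXXXXXX', 'X9X X  AX', 'X X X   X', 'X       X', 'XXXXXXXXX']
walls = [[c == 'X' for c in row] for row in grid2]
print(noisy_transitions(walls, (1, 2), (0, -1)))  # ~{(1, 1): 0.8, (1, 2): 0.2}
print(noisy_transitions(walls, (7, 3), (1, 0)))   # ~{(7, 3): 0.9, (7, 2): 0.1}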
Example #8
def generate_example(agent, config, other_agents=[], goals=None):
    """Generates an example Gridworld and corresponding agent actions.

    agent: The agent that acts in the generated MDP.
    config: Configuration parameters.
    other_agents: List of Agents that we wish to distinguish `agent` from. In
      particular, for every other agent, for our randomly chosen training
      examples, we report the number of examples (states) on which `agent` and
      the other agent would choose different actions.

    Returns: A tuple of five items:
      image: Numpy array of size imsize x imsize, each element is 1 if there is
             a wall at that location, 0 otherwise.
      rewards: Numpy array of size imsize x imsize, each element is the reward
               obtained at that state. (Most will be zero.)
      start_state: The starting state for the gridworld (a tuple (x, y)).
      action_dists: Numpy array of size imsize x imsize x num_actions. The
                    probability distributions over actions for each state.
      num_different: Numpy array of size `len(other_agents)`. `num_different[i]`
                     is the number of states where `other_agents[i]` would
                     choose a different action compared to `agent`.
    
    For every state (x, y), the action taken by the agent is drawn from the
    distribution action_dists[y, x, :]. This can be used to train a planning
    module to recreate the actions of the agent.
    """
    imsize = config.imsize
    num_actions = config.num_actions
    if config.simple_mdp:
        assert False, 'simple_mdp no longer supported'
        # pr_wall, pr_reward = config.wall_prob, config.reward_prob
        # mdp = GridworldMdp.generate_random(imsize, imsize, pr_wall, pr_reward)
    else:
        num_rewards, noise = config.num_rewards, config.noise
        mdp = GridworldMdp.generate_random_connected(imsize, imsize,
                                                     num_rewards, noise, goals)

    def dist_to_numpy(dist):
        return dist.as_numpy_array(Direction.get_number_from_direction,
                                   num_actions)

    def action(state):
        # Walls are invalid states and the MDP will refuse to give an action for
        # them. However, the VIN's architecture requires it to provide an action
        # distribution for walls too, so hardcode it to always be STAY.
        x, y = state
        if mdp.walls[y][x]:
            return dist_to_numpy(Distribution({Direction.STAY: 1}))
        return dist_to_numpy(agent.get_action_distribution(state))

    agent.set_mdp(mdp)
    action_dists = [[action((x, y)) for x in range(imsize)]
                    for y in range(imsize)]
    action_dists = np.array(action_dists)

    def calculate_different(other_agent):
        """
        Return the number of states in minibatches on which the action chosen by
        `agent` is different from the action chosen by `other_agent`.
        """
        other_agent.set_mdp(mdp)

        def differs(s):
            x, y = s
            action_dist = action_dists[y][x]
            dist = dist_to_numpy(other_agent.get_action_distribution(s))
            # Two action distributions are "different" if they are sufficiently
            # far away from each other according to some distance metric.
            # TODO(rohinmshah): L2 norm is not the right distance metric for
            # probability distributions, maybe use something else?
            # Not KL divergence, since it may be undefined
            return np.linalg.norm(action_dist -
                                  dist) > config.action_distance_threshold

        return sum([
            sum([(1 if differs((x, y)) else 0) for x in range(imsize)])
            for y in range(imsize)
        ])

    num_different = np.array([calculate_different(o) for o in other_agents])
    walls, rewards, start_state = mdp.convert_to_numpy_input()
    return walls, rewards, start_state, action_dists, num_different
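
A short sketch of consuming generate_example's output; `agent`, `other_agent` and `config` (an object with the imsize, num_actions, num_rewards, noise, simple_mdp and action_distance_threshold fields read above) are assumed to already exist:

import numpy as np

walls, rewards, start_state, action_dists, num_different = generate_example(
    agent, config, other_agents=[other_agent])

x, y = start_state
# action_dists is built row by row, so the distribution for state (x, y)
# lives at action_dists[y, x, :].
probs = action_dists[y, x, :]
sampled_action_index = np.random.choice(len(probs), p=probs)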
Example #9
import agents
import tensorflow as tf
# GridworldMdp, Mdp and print_training_example are assumed to be imported
# from the project's gridworld / utility modules.
# imsize = 16
# pr_wall = 0.05
# pr_reward = 0.0

grid = ['XXXXXXXXX', 'X9XAX   X', 'X X X   X', 'X       X', 'XXXXXXXXX']

preference_grid = [
    'XXXXXXXXXXXXXX', 'XXXXXX4XXXXXXX', 'XXXXXX XXXXXXX', 'XXXXX     XXXX',
    'XXXXX XXX  2XX', 'XXXXX XXX XXXX', 'XXXX1 XXX XXXX', 'XXXXX XXX XXXX',
    'XXXXX XXX XXXX', 'XXXXX XXX XXXX', 'X1        XXXX', 'XXXXX XX1XXXXX',
    'XXXXXAXXXXXXXX', 'XXXXXXXXXXXXXX'
]

mdp = GridworldMdp(preference_grid)
# mdp = GridworldMdp.generate_random(imsize, imsize, pr_wall, pr_reward)
# agent = agents.OptimalAgent()
agent = agents.SophisticatedTimeDiscountingAgent(2, 0.01)
agent.set_mdp(mdp)

env = Mdp(mdp)
trajectory = env.perform_rollout(agent, max_iter=20, print_step=1000)
print_training_example(mdp, trajectory)
print(agent.reward)

# class NeuralAgent(Agent):

# 	def __init__(self, save_dir):
# 		Agent.__init__(self)
# 		self.sess = tf.Session(graph=tf.Graph())
Example #10
class TestGridworld(unittest.TestCase):
    def setUp(self):
        self.grid1 = [['X', 'X', 'X', 'X', 'X'], ['X', ' ', ' ', 'A', 'X'],
                      ['X', '3', 'X', ' ', 'X'], ['X', ' ', ' ', '1', 'X'],
                      ['X', 'X', 'X', 'X', 'X']]
        self.grid2 = [
            'XXXXXXXXX', 'X9X X  AX', 'X X X   X', 'X       X', 'XXXXXXXXX'
        ]
        self.grid3 = [['X', 'X', 'X', 'X', 'X'], ['X', 3.5, 'X', -10, 'X'],
                      ['X', ' ', '1', ' ', 'X'], ['X', ' ', ' ', 'A', 'X'],
                      ['X', 'X', 'X', 'X', 'X']]

        self.mdp1 = GridworldMdp(self.grid1, living_reward=0)
        self.mdp2 = GridworldMdp(self.grid2, noise=0.2)
        self.mdp3 = GridworldMdp(self.grid3)

    def test_str(self):
        expected = '\n'.join([''.join(row) for row in self.grid1])
        self.assertEqual(str(self.mdp1), expected)
        expected = '\n'.join(self.grid2)
        self.assertEqual(str(self.mdp2), expected)
        expected = '\n'.join(['XXXXX', 'XRXNX', 'X 1 X', 'X  AX', 'XXXXX'])
        self.assertEqual(str(self.mdp3), expected)

    def test_constructor_invalid_inputs(self):
        # Height and width must be at least 2.
        with self.assertRaises(AssertionError):
            mdp = GridworldMdp(['X', 'X', 'X'])
        with self.assertRaises(AssertionError):
            mdp = GridworldMdp([['X', 'X', 'X']])

        with self.assertRaises(AssertionError):
            # Borders must be present.
            mdp = GridworldMdp(['  A', '3X ', '  1'])

        with self.assertRaises(AssertionError):
            # There can't be more than one agent.
            mdp = GridworldMdp(['XXXXX', 'XA 3X', 'X3 AX', 'XXXXX'])

        with self.assertRaises(AssertionError):
            # There must be one agent.
            mdp = GridworldMdp(['XXXXX', 'X  3X', 'X3  X', 'XXXXX'])

        with self.assertRaises(AssertionError):
            # There must be at least one reward.
            mdp = GridworldMdp(['XXXXX', 'XAX X', 'X   X', 'XXXXX'])

        with self.assertRaises(AssertionError):
            # B is not a valid element.
            mdp = GridworldMdp(['XXXXX', 'XB  X', 'X  3X', 'XXXXX'])

    def test_start_state(self):
        self.assertEqual(self.mdp1.get_start_state(), (3, 1))
        self.assertEqual(self.mdp2.get_start_state(), (7, 1))
        self.assertEqual(self.mdp3.get_start_state(), (3, 3))

    def test_reward_parsing(self):
        self.assertEqual(self.mdp1.rewards, {(1, 2): 3, (3, 3): 1})
        self.assertEqual(self.mdp2.rewards, {(1, 1): 9})
        self.assertEqual(self.mdp3.rewards, {
            (1, 1): 3.5,
            (2, 2): 1,
            (3, 1): -10
        })

    def test_actions(self):
        a = [
            Direction.NORTH, Direction.SOUTH, Direction.EAST, Direction.WEST,
            Direction.STAY
        ]
        all_acts = set(a)
        self.assertEqual(set(Direction.ALL_DIRECTIONS), all_acts)

        with self.assertRaises(ValueError):
            self.mdp1.get_actions((0, 0))

        self.assertEqual(set(self.mdp1.get_actions((1, 1))), all_acts)
        self.assertEqual(set(self.mdp1.get_actions((1, 2))), all_acts)
        self.assertEqual(set(self.mdp2.get_actions((6, 2))), all_acts)
        self.assertEqual(set(self.mdp2.get_actions((3, 1))), all_acts)
        self.assertEqual(set(self.mdp3.get_actions((2, 2))), all_acts)

    def test_rewards(self):
        grid1_reward_table = {(3, 3): 1, (1, 2): 3}
        grid2_reward_table = {(1, 1): 9}
        grid3_reward_table = {(1, 1): 3.5, (2, 2): 1, (3, 1): -10}
        self.check_all_rewards(self.mdp1, grid1_reward_table, 0)
        self.check_all_rewards(self.mdp2, grid2_reward_table, -0.01)
        self.check_all_rewards(self.mdp3, grid3_reward_table, -0.01)

    def check_all_rewards(self, mdp, reward_lookup_table, living_reward):
        for state in mdp.get_states():
            for action in mdp.get_actions(state):
                expected = 0
                if state in reward_lookup_table:
                    expected += reward_lookup_table[state]
                if action != Direction.STAY:
                    expected += living_reward
                self.assertEqual(mdp.get_reward(state, action), expected)

    def test_transitions(self):
        n, s = Direction.NORTH, Direction.SOUTH
        e, w = Direction.EAST, Direction.WEST
        stay_action = Direction.STAY

        # Grid 1: No noise
        with self.assertRaises(ValueError):
            self.mdp1.get_transition_states_and_probs((0, 0), stay_action)

        result = self.mdp1.get_transition_states_and_probs((1, 3), n)
        self.assertEqual(set(result), set([((1, 2), 1)]))
        result = self.mdp1.get_transition_states_and_probs((1, 2), stay_action)
        self.assertEqual(set(result), set([((1, 2), 1)]))
        result = self.mdp1.get_transition_states_and_probs((1, 1), n)
        self.assertEqual(set(result), set([((1, 1), 1)]))

        # Grid 2: Noise of 0.2
        result = set(self.mdp2.get_transition_states_and_probs((1, 2), n))
        self.assertEqual(result, set([((1, 1), 0.8), ((1, 2), 0.2)]))
        result = set(self.mdp2.get_transition_states_and_probs((6, 2), w))
        self.assertEqual(result,
                         set([((5, 2), 0.8), ((6, 1), 0.1), ((6, 3), 0.1)]))
        result = set(self.mdp2.get_transition_states_and_probs((7, 3), e))
        self.assertEqual(result, set([((7, 3), 0.9), ((7, 2), 0.1)]))
        result = set(self.mdp2.get_transition_states_and_probs((5, 1), s))
        self.assertEqual(result,
                         set([((5, 2), 0.8), ((5, 1), 0.1), ((6, 1), 0.1)]))
        result = self.mdp2.get_transition_states_and_probs((3, 1), n)
        self.assertEqual(set(result), set([((3, 1), 1)]))
        result = self.mdp2.get_transition_states_and_probs((1, 1), stay_action)
        self.assertEqual(set(result), set([((1, 1), 1)]))

    def test_states_reachable(self):
        def check_grid(grid):
            self.assertEqual(set(grid.get_states()), self.dfs(grid))

        for grid in [self.mdp1, self.mdp2, self.mdp3]:
            check_grid(grid)

    def dfs(self, grid):
        visited = set()

        def helper(state):
            if state in visited:
                return
            visited.add(state)
            for action in grid.get_actions(state):
                for next_state, _ in grid.get_transition_states_and_probs(
                        state, action):
                    helper(next_state)

        helper(grid.get_start_state())
        return visited

    def test_environment(self):
        env = Mdp(self.mdp3)
        self.assertEqual(env.get_current_state(), (3, 3))
        next_state, reward = env.perform_action(Direction.NORTH)
        self.assertEqual(next_state, (3, 2))
        self.assertEqual(reward, -0.01)
        self.assertEqual(env.get_current_state(), next_state)
        self.assertFalse(env.is_done())
        env.reset()
        self.assertEqual(env.get_current_state(), (3, 3))
        self.assertFalse(env.is_done())
        next_state, reward = env.perform_action(Direction.WEST)
        self.assertEqual(next_state, (2, 3))
        self.assertEqual(reward, -0.01)
        self.assertEqual(env.get_current_state(), next_state)
        self.assertFalse(env.is_done())
        next_state, reward = env.perform_action(Direction.NORTH)
        self.assertEqual(next_state, (2, 2))
        self.assertEqual(reward, -0.01)
        self.assertEqual(env.get_current_state(), next_state)
        self.assertFalse(env.is_done())
        next_state, reward = env.perform_action(Direction.STAY)
        self.assertEqual(next_state, (2, 2))
        self.assertEqual(reward, 1)
        self.assertEqual(env.get_current_state(), next_state)
        self.assertFalse(env.is_done())
        env.reset()
        self.assertFalse(env.is_done())
        self.assertEqual(env.get_current_state(), (3, 3))

    def test_numpy_conversion(self):
        def check_mdp(mdp):
            new_mdp = GridworldMdp.from_numpy_input(
                *mdp.convert_to_numpy_input())
            self.assertEqual(new_mdp.height, mdp.height)
            self.assertEqual(new_mdp.width, mdp.width)
            self.assertEqual(new_mdp.walls, mdp.walls)
            self.assertEqual(new_mdp.rewards, mdp.rewards)
            self.assertEqual(new_mdp.start_state, mdp.start_state)

        check_mdp(self.mdp1)
        check_mdp(self.mdp2)
        check_mdp(self.mdp3)

    def test_random_gridworld_generation(self):
        set_seeds(314159)
        mdp = GridworldMdp.generate_random(8, 8, 0, 0)
        self.assertEqual(mdp.height, 8)
        self.assertEqual(mdp.width, 8)
        mdp_string = str(mdp)
        self.assertEqual(mdp_string.count('X'), 28)
        self.assertEqual(mdp_string.count(' '), 34)
        self.assertEqual(mdp_string.count('A'), 1)
        self.assertEqual(mdp_string.count('3'), 1)