Exemple #1
0
def test_visitations(grid, agent):
    """Compute expected visitation counts for ``agent`` and save them as an image.

    Builds a ``GridworldMdp`` from ``grid``, attaches it to ``agent``, gathers
    the agent's action distribution for every cell (walls are hard-coded to
    STAY), and hands the flattened policy, the MDP's transition matrix and a
    one-hot start-state vector to ``expected_counts``. The result is reshaped
    to the grid and written with ``plt.imsave`` to a file named "democounts".

    No assertions are made: this is a visual sanity check of the
    expected_counts calculation (a suspected einsum error).

    Args:
        grid: grid description accepted by ``GridworldMdp(grid=...)``.
            ``len(grid)`` is used for both dimensions, so the grid is
            presumably square -- TODO confirm.
        agent: object providing ``set_mdp`` and ``get_action_distribution``.
    """
    from gridworld.gridworld import GridworldMdp, Direction
    from utils import Distribution

    n_actions = len(Direction.ALL_DIRECTIONS)

    mdp = GridworldMdp(grid=grid)
    agent.set_mdp(mdp)

    side = len(grid)

    def policy_at(x, y):
        # Walls are invalid states and the MDP will refuse to give an action
        # for them. The VIN's architecture still needs an action distribution
        # there, so walls always get a deterministic STAY.
        if mdp.walls[y][x]:
            chosen = Distribution({Direction.STAY: 1})
        else:
            chosen = agent.get_action_distribution((x, y))
        return chosen.as_numpy_array(Direction.get_number_from_direction,
                                     n_actions)

    # Outer index is x, inner index is y (same visiting order as a nested
    # comprehension with x on the outside).
    per_column = []
    for x in range(side):
        per_column.append([policy_at(x, y) for y in range(side)])
    action_dists = np.array(per_column)

    # Only the start state is needed from the numpy conversion.
    _, _, start_state = mdp.convert_to_numpy_input()

    trans = mdp.get_transition_matrix()

    # One-hot start distribution, indexed [y][x], then flattened.
    start_grid = np.zeros((side, side))
    start_grid[start_state[1]][start_state[0]] = 1
    initial_states = start_grid.reshape(-1)

    policy = flatten_policy(action_dists)

    # 20 and 0.9 are presumably the horizon and discount factor -- confirm
    # against expected_counts.
    demo_counts = expected_counts(policy, trans, initial_states, 20, 0.9)

    import matplotlib.pyplot as plt
    plt.imsave("democounts", demo_counts.reshape((side, side)))
Exemple #2
0
def test_coherence(grid, agent):
    """Print three successive einsum propagation steps of the start state.

    Builds a ``GridworldMdp`` from ``grid``, attaches it to ``agent``, gathers
    the agent's action distribution for every cell (walls are hard-coded to
    STAY), and then repeatedly applies
    ``np.einsum("i,ij,ijk -> k", counts, policy, trans)`` starting from a
    one-hot start-state vector. The start state, the initial vector and the
    result of each of the three steps are printed, reshaped to the grid.

    Args:
        grid: grid description accepted by ``GridworldMdp(grid=...)``.
            ``len(grid)`` is used for both dimensions, so the grid is
            presumably square -- TODO confirm.
        agent: object providing ``set_mdp`` and ``get_action_distribution``.

    Returns:
        The vector after the third einsum step, reshaped to
        ``(len(grid), len(grid))``.
    """
    from gridworld.gridworld import GridworldMdp, Direction
    from utils import Distribution

    n_actions = len(Direction.ALL_DIRECTIONS)

    mdp = GridworldMdp(grid=grid)
    agent.set_mdp(mdp)

    side = len(grid)

    def policy_at(x, y):
        # Walls are invalid states and the MDP will refuse to give an action
        # for them. The VIN's architecture still needs an action distribution
        # there, so walls always get a deterministic STAY.
        if mdp.walls[y][x]:
            chosen = Distribution({Direction.STAY: 1})
        else:
            chosen = agent.get_action_distribution((x, y))
        return chosen.as_numpy_array(Direction.get_number_from_direction,
                                     n_actions)

    # Outer index is x, inner index is y (same visiting order as a nested
    # comprehension with x on the outside).
    per_column = []
    for x in range(side):
        per_column.append([policy_at(x, y) for y in range(side)])
    action_dists = np.array(per_column)

    # Only the start state is needed from the numpy conversion.
    _, _, start_state = mdp.convert_to_numpy_input()

    print("Start state for given mdp:", start_state)

    trans = mdp.get_transition_matrix()

    # One-hot start distribution, indexed [y][x], then flattened.
    start_grid = np.zeros((side, side))
    start_grid[start_state[1]][start_state[0]] = 1
    initial_states = start_grid.reshape(-1)

    policy = flatten_policy(action_dists)

    gshape = (side, side)
    divider = '-' * 20

    print("initial states")
    print(divider)
    print(initial_states.reshape(gshape))

    # Each step contracts the current state vector with the policy and the
    # transition tensor, yielding the next state vector.
    counts = initial_states
    step_titles = (
        "first expected counts",
        "second expected counts",
        "third expected counts",
    )
    for title in step_titles:
        counts = np.einsum("i,ij,ijk -> k", counts, policy, trans)
        print(title)
        print(divider)
        print(counts.reshape(gshape))

    return counts.reshape(gshape)