Beispiel #1
0
def plot_dist_to_neighbours(test_env_name, sim_ind, state_reps, geodesic_dist=True, single_pos=True):
    """Plot representation-space distances measured from state ``sim_ind``.

    Builds the environment ``test_env_name``, computes its graph distance
    matrix and the pairwise Chebyshev distances between state
    representations, then draws one of three figures:

    * ``geodesic_dist=True``: scatter of graph distance against
      representation distance from ``sim_ind`` to every other state, saved to
      ``../figures/CH1/latent_distance{sim_ind}.svg``.
    * ``geodesic_dist=False, single_pos=True``: heat map over the grid of the
      normalised representation distance from ``sim_ind``, with the reference
      cell outlined, saved to
      ``../figures/CH1/representation_similarity/latent_sim{sim_ind}.svg``.
    * ``geodesic_dist=False, single_pos=False``: heat map of the full
      normalised distance matrix with obstacle rows/columns removed. This
      figure is not saved.

    The current figure is closed before returning.

    Args:
        test_env_name: Environment id passed to ``gym.make``.
        sim_ind: Index of the reference state (row of both distance matrices).
        state_reps: Mapping from state index to representation vector. Must
            contain key ``0``; its last dimension sets the representation
            width.
        geodesic_dist: Selects the scatter plot when true.
        single_pos: When ``geodesic_dist`` is false, selects the single-state
            heat map (true) or the full-matrix heat map (false).

    Returns:
        None.
    """
    # make new env to run test in
    env = gym.make(test_env_name)
    plt.close()

    # make graph of env states
    G = make_env_graph(env)
    gd = compute_graph_distance_matrix(G, env)
    # distance from sim_ind to all other states (self-distance dropped)
    dist_in_state_space = np.delete(gd[sim_ind], sim_ind)

    cmap = linc_coolwarm_r

    # Representations may be stored as (dim,) or (1, dim); the last axis is
    # the representation width in both layouts.
    rep_dim = state_reps[0].shape[-1]
    # NOTE(review): 400 rows are hard-coded; presumably a 20x20 grid -- confirm.
    reps_as_matrix = np.full((400, rep_dim), np.nan)

    for k, v in state_reps.items():
        reps_as_matrix[k] = v
        if k in env.obstacle:
            # obstacle states never contribute a representation
            reps_as_matrix[k, :] = np.nan

    RS = squareform(pdist(reps_as_matrix, metric='chebyshev'))
    for obstacle_state in env.obstacle:
        RS[obstacle_state, :] = np.nan
        RS[:, obstacle_state] = np.nan

    dist_in_rep_space = np.delete(RS[sim_ind], sim_ind)

    if geodesic_dist:
        plt.scatter(dist_in_state_space, dist_in_rep_space, color=LINCLAB_COLS['red'], alpha=0.4, linewidths=0.5)
        plt.xlabel("D(s,s')")
        plt.ylabel("D(R(s), R(s'))")
        plt.ylim([0., 1.1])
        plt.savefig(f'../figures/CH1/latent_distance{sim_ind}.svg')
    elif single_pos:
        a = plt.imshow(RS[sim_ind].reshape(env.shape) / np.nanmax(RS[sim_ind]), cmap=cmap, vmin=0, vmax=1)
        r, c = env.oneD2twoD(sim_ind)
        # outline the reference cell; imshow cells are centred on integer coords
        plt.gca().add_patch(plt.Rectangle(np.add((c, r), (-0.5, -0.5)), .99, .99, edgecolor='k', fill=False, alpha=1))
        plt.colorbar(a)
        plt.gca().get_xaxis().set_visible(False)
        plt.gca().get_yaxis().set_visible(False)
        plt.savefig(f'../figures/CH1/representation_similarity/latent_sim{sim_ind}.svg')
    else:
        # Drop all obstacle rows and columns in one step each, so the result
        # does not depend on the ordering of env.obstacle.
        obstacle_idx = np.asarray(env.obstacle, dtype=int)
        sliced = np.delete(RS, obstacle_idx, axis=0)
        sliced = np.delete(sliced, obstacle_idx, axis=1)

        a = plt.imshow(sliced / np.nanmax(sliced), vmin=0, vmax=1, cmap=linc_coolwarm)
        plt.colorbar(a)

    #plt.show()
    plt.close()
Beispiel #2
0
def get_graph_dist_from_state(envs_to_plot, sim_ind):
    """Return, per environment, the graph-distance row for state ``sim_ind``.

    Args:
        envs_to_plot: Iterable of environment ids accepted by ``gym.make``.
        sim_ind: Index of the reference state.

    Returns:
        A list with one entry per environment, in input order; each entry is
        row ``sim_ind`` of that environment's graph distance matrix.
    """
    rows = []
    for env_name in envs_to_plot:
        env = gym.make(env_name)
        # presumably gym.make opens a figure that is not wanted here -- confirm
        plt.close()
        distance_matrix = compute_graph_distance_matrix(make_env_graph(env), env)
        rows.append(distance_matrix[sim_ind])
    return rows
Beispiel #3
0
def attempt_opt_pol(env):
    """Build a per-state action preference map that points towards the reward.

    For every usable cell, an action is marked when the neighbouring cell it
    leads to has a strictly smaller graph distance to the (first) reward
    location than the cell itself. The marked vector is then passed through
    ``softmax`` and ``np.nan_to_num``.

    Action indices along the last axis: 0 = down (row + 1), 1 = up (row - 1),
    2 = right (col + 1), 3 = left (col - 1).

    Args:
        env: Environment exposing ``rewards`` (dict keyed by 2D location),
            ``useable`` (iterable of ``(row, col)``) and ``twoD2oneD``.

    Returns:
        Array of shape ``(20, 20, 4)``. Cells not in ``env.useable`` stay NaN.
    """
    # NOTE(review): grid size is hard-coded to 20x20 -- confirm it matches env.
    n_rows, n_cols = 20, 20
    # (action index, row offset, col offset)
    moves = ((0, 1, 0), (1, -1, 0), (2, 0, 1), (3, 0, -1))

    rwd_loc = env.twoD2oneD(list(env.rewards.keys())[0])
    G = make_env_graph(env)
    gd = compute_graph_distance_matrix(G, env)

    # Distance to the reward for each usable cell; everything else stays NaN,
    # so comparisons against those cells below evaluate to False.
    num_steps_to_rwd = np.full((n_rows, n_cols), np.nan)
    for (i, j) in env.useable:
        # i = row, j = col
        num_steps_to_rwd[i, j] = gd[env.twoD2oneD((i, j)), rwd_loc]

    opt_pol_matrix = np.full((n_rows, n_cols, len(moves)), np.nan)
    for (r, c) in env.useable:
        steps_from_here = num_steps_to_rwd[r, c]

        for action, d_row, d_col in moves:
            n_row, n_col = r + d_row, c + d_col
            if not (0 <= n_row < n_rows and 0 <= n_col < n_cols):
                continue  # neighbour would fall off the grid
            if num_steps_to_rwd[n_row, n_col] < steps_from_here:
                opt_pol_matrix[r, c, action] = 1

        # NOTE(review): unmarked actions are still NaN here. If `softmax`
        # propagates NaN (as scipy.special.softmax does), the whole row becomes
        # NaN and nan_to_num turns it into zeros -- confirm this is intended.
        opt_pol_matrix[r, c, :] = np.nan_to_num(softmax(opt_pol_matrix[r, c, :]))

    return opt_pol_matrix


#opt_pol = attempt_opt_pol(env)

#op = opt_pol_map(gym.make('gridworld:gridworld-v1'))
#op[9:,:] = [0., 0.,0,0]
#plot_pref_pol(env, opt_pol)
Beispiel #4
0
def plot_grid_of_shit(test_env_name, state_reps, sim_ind, metric):
    """Show and save a heat map of representation distance from ``sim_ind``.

    Pairwise distances between state representations are computed with the
    given ``pdist`` metric, the row for ``sim_ind`` is normalised by its
    maximum, reshaped to the grid and drawn with the reference cell outlined.

    Args:
        test_env_name: Environment id passed to ``gym.make``; its last two
            characters are used in the output file name.
        state_reps: Mapping from state index to representation vector. Must
            contain key ``0``; its last dimension sets the representation
            width.
        sim_ind: Index of the reference state.
        metric: Metric name forwarded to ``scipy.spatial.distance.pdist``;
            also used in the output file name.

    Returns:
        None. The figure is saved and then shown.
    """
    # make new env to run test in
    env = gym.make(test_env_name)
    plt.close()

    cmap = linc_coolwarm_r

    # Representations may be stored as (dim,) or (1, dim); the last axis is
    # the representation width in both layouts.
    rep_dim = state_reps[0].shape[-1]
    # NOTE(review): 400 rows are hard-coded; presumably a 20x20 grid -- confirm.
    reps_as_matrix = np.full((400, rep_dim), np.nan)

    for k, v in state_reps.items():
        reps_as_matrix[k] = v
        if k in env.obstacle:
            # obstacle states never contribute a representation
            reps_as_matrix[k, :] = np.nan

    RS = squareform(pdist(reps_as_matrix, metric=metric))
    for obstacle_state in env.obstacle:
        RS[obstacle_state, :] = np.nan
        RS[:, obstacle_state] = np.nan

    a = plt.imshow(RS[sim_ind].reshape(env.shape) / np.nanmax(RS[sim_ind]), cmap=cmap)
    r, c = env.oneD2twoD(sim_ind)
    # outline the reference cell; imshow cells are centred on integer coords
    plt.gca().add_patch(plt.Rectangle(np.add((c, r), (-0.5, -0.5)), .99, .99, edgecolor='k', fill=False, alpha=1))
    plt.colorbar(a)
    plt.gca().get_xaxis().set_visible(False)
    plt.gca().get_yaxis().set_visible(False)
    # NOTE(review): `rep_name` is not defined in this function; it must exist
    # as a module-level name at call time or this line raises NameError.
    plt.savefig(f'../figures/CH1/representation_similarity/distance_metrics/{test_env_name[-2:]}{rep_name}{metric}.svg')
    plt.show()
Beispiel #5
0
# Accumulate state-occupancy counts over all runs listed in `id_list`.
# `id_list`, `all_occ`, `all_v`, `state_occ_map` and `env` are expected to be
# defined earlier in the script (not visible in this section).
for id_num in id_list:
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # use on result files that come from a trusted source.
    with open(f'../../Data/results/{id_num}_data.p', 'rb') as f:
        data = pickle.load(f)

    state_occ_vec = data['occupancy']
    all_occ += state_occ_vec  # running sum across runs
    all_visits = np.sum(state_occ_vec)  # total visits in this run only
    all_v.append(all_visits)
    print(all_occ)

# Scatter the summed occupancy onto the 2D map; cells outside env.useable
# stay NaN.
state_occ_map[:] = np.nan
for r,c in env.useable:
    ind = env.twoD2oneD((r,c))
    state_occ_map[r,c] = all_occ[ind]

G = make_env_graph(env)
gd = compute_graph_distance_matrix(G, env)

# presumably state 105 is the reward location -- confirm
dist_from_reward = gd[:,105]

#a = plt.imshow(state_occ_map/(all_visits/len(env.useable)),vmax=4)
#plt.colorbar(a)
#plt.show()
print(np.nanmax(dist_from_reward))
# One bin per integer distance 0..max (NaN distances are ignored by nanmax).
janky_histo = np.zeros(int(np.nanmax(dist_from_reward))+1)
num_times_for_dist = np.zeros(int(np.nanmax(dist_from_reward))+1)

# `state_occ_vec` here is the vector from the last run loaded above.
for ind in range(len(state_occ_vec)):
    print(dist_from_reward[ind])
    if np.isnan(dist_from_reward[ind]):
        # states with NaN distance are skipped
        pass
Beispiel #6
0
def plot_world(world, **kwargs):
    """Draw a grid world with its jumps and reward cells.

    Args:
        world: Environment exposing ``shape``, ``grid``, ``action_list``,
            ``jump``, ``rewards`` and, when state labels are requested,
            ``useable`` and ``twoD2oneD``.
        **kwargs: Optional settings:
            scale (float): Figure inches per grid cell. Default ``0.35``.
            title (str): Axes title. Default ``'Grid World'``.
            ax_labels (bool): Keep axis ticks when true. Default ``False``.
            states (bool): Annotate each usable cell with its graph distance
                to state 105. Default ``False``.
            invert (bool): Use the ``'bone'`` colormap instead of
                ``'bone_r'``. Default ``False``.

    Returns:
        Tuple ``(fig, ax)`` of the created matplotlib figure and axes.
    """
    scale = kwargs.get('scale', 0.35)
    title = kwargs.get('title', 'Grid World')
    ax_labels = kwargs.get('ax_labels', False)
    state_labels = kwargs.get('states', False)
    cmap = 'bone' if kwargs.get('invert', False) else 'bone_r'
    r, c = world.shape

    fig = plt.figure(figsize=(c * scale, r * scale))
    ax = fig.add_subplot(1, 1, 1)

    ax.pcolor(world.grid,
              edgecolors='k',
              linewidths=0.75,
              cmap=cmap,
              vmin=0,
              vmax=1)

    # Arrow components for jump transitions; NaN entries draw no arrow.
    U = np.full((r, c), np.nan)
    V = np.full((r, c), np.nan)

    if len(world.action_list) > 4 and world.jump is not None:
        for (a, b), (a2, b2) in world.jump.items():
            U[a, b] = (b2 - b)
            # row difference is negated because the y axis is inverted below
            V[a, b] = (a - a2)

    C, R = np.meshgrid(np.arange(0, c) + 0.5, np.arange(0, r) + 0.5)
    ax.quiver(C, R, U, V, scale=1, units='xy')

    for rwd_loc, rwd_value in world.rewards.items():
        rwd_r, rwd_c = rwd_loc
        colorcode = 'red' if rwd_value < 0 else 'darkgreen'
        ax.add_patch(
            plt.Rectangle((rwd_c, rwd_r),
                          width=1,
                          height=1,
                          linewidth=2,
                          facecolor=colorcode,
                          alpha=0.5))

    if state_labels:
        # The distance matrix is only needed for the labels, and it is built
        # from the `world` argument rather than a module-level environment.
        G = make_env_graph(world)
        gd = compute_graph_distance_matrix(G, world)
        for (i, j) in world.useable:
            # i = row, j = col
            oneD = world.twoD2oneD((i, j))
            # NOTE(review): 105 is hard-coded; presumably the reward state -- confirm.
            ax.text(j + 0.5, i + 0.7, s=f'{gd[oneD,105]}', ha='center')

    ax.invert_yaxis()
    ax.set_aspect('equal')
    if not ax_labels:
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])
    ax.set_title(title)

    return fig, ax
Beispiel #7
0
def plot_dist_to_neighbours():
    """Compare graph distance with representation distance for several encodings.

    For the second entry of the module-level ``envs_to_plot`` list, draws a
    row of panels: the grid with the reward cell and the reference state
    (index 169), followed by one scatter plot per representation type showing
    graph distance against Chebyshev distance in representation space from
    the reference state to every other state. The figure is saved to
    ``../figures/CH2/distance.svg`` and then shown.

    Relies on module-level names: ``envs_to_plot``, ``grids`` and the
    representation builders ``random``, ``onehot``, ``place_cell``, ``sr``
    (and ``latents`` if that entry is re-enabled in ``rep_types``).

    Returns:
        None.
    """
    for test_env_name in [envs_to_plot[1]]:
        sim_ind = 169
        envno = envs_to_plot.index(test_env_name)

        # make new env to run test in
        env = gym.make(test_env_name)
        plt.close()

        # make graph of env states
        G = make_env_graph(env)
        gd = compute_graph_distance_matrix(G, env)
        # distance from sim_ind to all other states (self-distance dropped)
        dist_in_state_space = np.delete(gd[sim_ind], sim_ind)

        rep_types = {
            'random': random,
            'onehot': onehot,
            'place_cell': place_cell,
            'sr': sr
        }  #, 'latents':latents}
        # one panel for the grid plus one per representation type
        fig, ax = plt.subplots(1, len(rep_types) + 1, figsize=(14, 2))

        # reward cell as (col, row), chosen by environment version suffix
        if test_env_name[-2:] == '51':
            rwd_colrow = (16, 9)
        else:
            rwd_colrow = (14, 14)

        rect = plt.Rectangle(rwd_colrow, 1, 1, color='g', alpha=0.3)
        agt_row, agt_col = env.oneD2twoD(sim_ind)[0], env.oneD2twoD(sim_ind)[1]
        circ = plt.Circle((agt_col + 0.5, agt_row + 0.5), radius=0.3, color='blue')
        ax[0].pcolor(grids[envno],
                     cmap='bone_r',
                     edgecolors='k',
                     linewidths=0.1)
        ax[0].axis(xmin=0, xmax=20, ymin=0, ymax=20)
        ax[0].set_aspect('equal')
        ax[0].add_patch(rect)
        ax[0].add_patch(circ)
        ax[0].get_xaxis().set_visible(False)
        ax[0].get_yaxis().set_visible(False)
        ax[0].invert_yaxis()

        for j, rep_type in enumerate(rep_types.keys()):
            if rep_type == 'latents':
                # Only reached if 'latents' is added back to rep_types.
                relative_path_to_data = '../../Data/'  # from within Tests/CH1
                conv_ids = {
                    'gridworld:gridworld-v1':
                    'c34544ac-45ed-492c-b2eb-4431b403a3a8',
                    'gridworld:gridworld-v3':
                    '32301262-cd74-4116-b776-57354831c484',
                    'gridworld:gridworld-v4':
                    'b50926a2-0186-4bb9-81ec-77063cac6861',
                    'gridworld:gridworld-v5':
                    '15b5e27b-444f-4fc8-bf25-5b7807df4c7f'
                }
                # keys omit the last character of the environment id
                run_id = conv_ids[test_env_name[:-1]]
                agent_path = relative_path_to_data + f'agents/{run_id}.pt'
                state_reps, representation_name, input_dims, _ = latents(
                    env, agent_path)
            else:
                state_reps, representation_name, input_dims, _ = rep_types[
                    rep_type](env)

            # NOTE(review): a 400x400 matrix is hard-coded, i.e. 400 states with
            # 400-dimensional representations -- confirm for every rep type.
            reps_as_matrix = np.full((400, 400), np.nan)

            for k, v in state_reps.items():
                reps_as_matrix[k] = v
                if k in env.obstacle:
                    # obstacle states never contribute a representation
                    reps_as_matrix[k, :] = np.nan

            RS = squareform(pdist(reps_as_matrix, metric='chebyshev'))
            for obstacle_state in env.obstacle:
                RS[obstacle_state, :] = np.nan
                RS[:, obstacle_state] = np.nan

            dist_in_rep_space = np.delete(RS[sim_ind], sim_ind)

            panel = ax[j + 1]
            panel.scatter(dist_in_state_space,
                          dist_in_rep_space,
                          color='#b40426',
                          alpha=0.2,
                          linewidths=0.5)
            panel.set_xlabel("D(s,s')")
            panel.set_ylim([0.4, 1.1])
            # only the first scatter panel carries the y label and tick labels
            if j != 0:
                panel.set_yticklabels([])
            else:
                panel.set_ylabel("D(R(s), R(s'))")

        plt.savefig('../figures/CH2/distance.svg')
        plt.show()