Ejemplo n.º 1
0
def generate_saved_latents():
    """Compute latent state representations of saved agents and optionally pickle them.

    For every latent type ('conv', 'rwd_conv') and gridworld version 1-5 this:

    1. looks up the agent's run id in the module-level ``ids`` mapping
       (``ids[latent_type]['gridworld:gridworld-v<version>']``);
    2. tries to convert ``../../Data/network_objs/<run_id>.pt`` into a
       weights-only copy at ``./../../Data/agents/<run_id>.pt``. This is best
       effort: a failure is reported as a warning and processing continues;
    3. builds the gym environment, calls ``latents`` on the saved agent and
       writes the representation of each state into row ``state`` of an
       ``(env.nstates, env.nstates)`` array;
    4. if the module-level ``save`` flag is truthy, pickles that array into
       ``../../modules/Agents/RepresentationLearning/Learned_Rep_pickles/``.
    """
    import warnings

    for latent_type in ['conv', 'rwd_conv']:
        # get agent ids to load/save/generate latents
        example_run_ids = ids[latent_type]
        for version in [1, 2, 3, 4, 5]:
            # get environment
            env_id = f'gridworld:gridworld-v{version}'
            run_id = example_run_ids[env_id]
            agent_path = f'./../../Data/agents/{run_id}.pt'

            # make sure saved agent is in the form of a state_dict of weights
            # instead of the agent object
            try:
                convert_agent_to_weight_dict(
                    f'../../Data/network_objs/{run_id}.pt',
                    destination_path=agent_path)
            except Exception as exc:
                # Best effort: e.g. presumably the agent was already converted,
                # or the network object is missing. Report the failure instead of
                # silently swallowing it (a bare `except:` also hid Ctrl-C).
                warnings.warn(
                    f'could not convert agent {run_id} to a weight dict: {exc!r}')

            # make gym environment (close the figure that creating it opens)
            env = gym.make(env_id)
            plt.close()

            # compute latents by loading network, passing appropriate tensor,
            # getting top fc layer activity
            reps, _, _, _ = latents(env, agent_path, type=latent_type)

            latent_array = np.zeros((env.nstates, env.nstates))
            for state, rep in reps.items():
                latent_array[state] = rep

            if save:
                # env_id[-12:] is the 'gridworld-vN' suffix of the gym id
                pickle_path = (
                    '../../modules/Agents/RepresentationLearning/Learned_Rep_pickles/'
                    f'{latent_type}{run_id[0:8]}_{env_id[-12:]}.p')
                with open(pickle_path, 'wb') as f:
                    pickle.dump(latent_array, f)
Ejemplo n.º 2
0
def get_mem_map(gb, env_name, rep, pct, ind=0):
    """Build a policy map by recalling a saved memory cache for every usable state.

    Args:
        gb: grouped run table. ``gb.get_group((env_name, rep, cache_limit))`` must
            yield the run ids for that condition (presumably a pandas groupby;
            confirm).
        env_name: gym id of the gridworld to evaluate.
        rep: representation name. Either ``'latents'`` or one of 'onehot',
            'random', 'place_cell', 'analytic successor'.
        pct: key into ``cache_limits[env_name]`` selecting the cache size group.
        ind: position of the run to load within the selected group.

    Returns:
        A numpy structured array of shape ``env.shape`` with one ``'f8'`` field
        per action in ``env.action_list``, filled for every usable state.
    """
    env = gym.make(env_name)
    plt.close()

    builders = {
        'onehot': onehot,
        'random': random,
        'place_cell': place_cell,
        'analytic successor': sr
    }
    if rep != 'latents':
        state_reps, _, _, _ = builders[rep](env)
    else:
        # Saved conv agents (keyed by gym id without the trailing character)
        # whose latent layer is used as the state representation.
        latent_agent_ids = {
            'gridworld:gridworld-v1': 'c34544ac-45ed-492c-b2eb-4431b403a3a8',
            'gridworld:gridworld-v3': '32301262-cd74-4116-b776-57354831c484',
            'gridworld:gridworld-v4': 'b50926a2-0186-4bb9-81ec-77063cac6861',
            'gridworld:gridworld-v5': '15b5e27b-444f-4fc8-bf25-5b7807df4c7f'
        }
        agent_id = latent_agent_ids[env_name[:-1]]
        state_reps, _, _, _ = latents(env, parent_path + f'agents/{agent_id}.pt')

    group_key = (env_name, rep, cache_limits[env_name][pct])
    run_id = list(gb.get_group(group_key))[ind]

    action_fields = [(action, 'f8') for action in env.action_list]
    policy_map = np.zeros(env.shape, dtype=action_fields)

    # NOTE: unpickles a local results file; only point this at trusted data.
    with open(parent_path + f'ec_dicts/{run_id}_EC.p', 'rb') as cache_file:
        saved_cache = pickle.load(cache_file)

    # The cache limit is irrelevant here: the memory is only used for recall.
    recall_memory = Memory(entry_size=env.action_space.n, cache_limit=400)
    recall_memory.cache_list = saved_cache

    for coord in env.useable:
        key = tuple(state_reps[env.twoD2oneD(coord)])
        policy_map[coord] = tuple(recall_memory.recall_mem(key))

    return policy_map
Ejemplo n.º 3
0
def plot_inputs_and_latents(env_version, latent_type, test_index):
    """Show one state's network input, then plot a saved head's policy on the latents.

    The test environment is ``gridworld:gridworld-v<env_version>1``. The saved
    agent is the one registered in ``ids[latent_type]`` for the training
    environment ``gridworld:gridworld-v<env_version>``.

    Args:
        env_version: version suffix of the training gridworld.
        latent_type: 'conv' or 'rwd_conv'. Selects the input builder
            (``convs`` / ``reward_convs``) and the latent extraction mode.
        test_index: state index whose input tensor slices are displayed.

    Raises:
        ValueError: if ``latent_type`` is not 'conv' or 'rwd_conv'.
    """
    if latent_type not in ('conv', 'rwd_conv'):
        raise ValueError(
            f"latent_type must be 'conv' or 'rwd_conv', got {latent_type!r}")

    training_env_name = f'gridworld:gridworld-v{env_version}'
    testing_env_name = training_env_name + '1'
    env = gym.make(testing_env_name)
    plt.close()

    # get inputs states
    if latent_type == 'conv':
        inputs, _, _, _ = convs(env)
    else:
        inputs, _, _, _ = reward_convs(env)

    # assumes inputs[i][0] is a stacked (slices, rows, cols) array; TODO confirm
    example_input = inputs[test_index][0]
    tensor_slices = example_input.shape[0]
    # NOTE(review): the last axis is never drawn on. With n+1 columns `ax` is
    # always an array (matplotlib squeezes a lone Axes), which may be its purpose.
    fig, ax = plt.subplots(1, tensor_slices + 1)
    for item in range(tensor_slices):
        ax[item].imshow(example_input[item], cmap='bone_r')
        ax[item].set_aspect('equal')
    plt.show()

    run_id = ids[latent_type][training_env_name]

    # get corresponding latent states
    path_to_agent = f'./../../../Data/agents/{run_id}.pt'

    # NOTE(review): input size 400 and 4 actions are hard-coded. Presumably they
    # match the saved head's architecture; confirm.
    empty = head_AC(400, 4, lr=0.005)
    full = load_saved_head_weights(empty, path_to_agent)

    state_reps, _, _, _ = latents(env, path_to_agent, type=latent_type)

    policy_map = np.zeros(env.shape,
                          dtype=[(x, 'f8') for x in env.action_list])
    for state2d in env.useable:
        latent_state = state_reps[env.twoD2oneD(state2d)]
        pol, _ = full(latent_state)
        policy_map[state2d] = tuple(pol)

    plot_polmap(env, policy_map)
Ejemplo n.º 4
0
# Representation builders for non-latent runs, keyed by the accepted `rep_type`.
rep_types = dict(
    onehot=onehot,
    random=random,
    place_cell=place_cell,
    sr=sr,
)

if rep_type != 'latents':
    state_reps, representation_name, input_dims, _ = rep_types[rep_type](env)
else:
    # Saved conv agents (by gym id) whose latent layer provides the state
    # representation; there is no entry for gridworld-v2.
    conv_ids = {
        'gridworld:gridworld-v1': 'c34544ac-45ed-492c-b2eb-4431b403a3a8',
        'gridworld:gridworld-v3': '32301262-cd74-4116-b776-57354831c484',
        'gridworld:gridworld-v4': 'b50926a2-0186-4bb9-81ec-77063cac6861',
        'gridworld:gridworld-v5': '15b5e27b-444f-4fc8-bf25-5b7807df4c7f'
    }
    run_id = conv_ids[f'gridworld:gridworld-v{version}']
    agent_path = relative_path_to_data + f'agents/saved_agents/{run_id}.pt'
    state_reps, representation_name, input_dims, _ = latents(env, agent_path)

# Actor-critic head sized to the chosen representation, paired with a Memory
# holding up to `cache_size_for_env` entries.
AC_head_agent = head_AC(input_dims,
                        env.action_space.n,
                        lr=learning_rate)

memory = Memory(
    entry_size=env.action_space.n,
    cache_limit=cache_size_for_env,
    distance=distance_metric,
)

agent = Agent(AC_head_agent,
              memory=memory,
              state_representations=state_reps)

ex = flat_expt(agent, env)
print(
    f"Experiment running {env.unwrapped.spec.id} \nRepresentation: {representation_name} \nCache Limit:{cache_size_for_env} \nDistance: {distance_metric}"
)
Ejemplo n.º 5
0
plt.close()
# make new env to run test in
test_env = gym.make(test_env_name)
plt.close()

if latent_type in ('conv', 'rwd_conv'):
    # NOTE(review): this rebinds `ids`, a name other code in this file uses
    # as the latent-type -> run-id table.
    ids = {'conv': conv_ids, 'rwd_conv': rwd_conv_ids}

    id_dict = ids[latent_type]
    run_id = id_dict[training_env_name]
    # load latent states to use as state representations to actor-critic heads
    agent_path = relative_path_to_data + f'agents/{run_id}.pt'

    # compute latents by loading network, passing appropriate tensor, getting
    # top fc layer activity.
    # NOTE(review): latents come from `train_env` while the run uses `test_env`;
    # presumably intentional (reuse the training-env representation). Confirm.
    state_reps, representation_name, input_dims, _ = latents(train_env,
                                                             agent_path,
                                                             type=latent_type)

elif latent_type in ['sr', 'onehot']:
    rep_Type = {'sr': sr, 'onehot': onehot}
    state_reps, representation_name, input_dims, _ = rep_Type[latent_type](
        test_env)

else:
    # Fail here with a clear message instead of a NameError on `input_dims`
    # when the head network is built.
    raise ValueError(
        "latent_type must be one of 'conv', 'rwd_conv', 'sr', 'onehot'; "
        f"got {latent_type!r}")
if load_weights:
    # load weights to head_ac network from previously learned agent
    # NOTE(review): `agent_path` and `run_id` are only bound in the
    # 'conv'/'rwd_conv' representation branch. With latent_type 'sr' or
    # 'onehot' this raises NameError unless they were set earlier; confirm.
    empty_net = head_AC(input_dims, test_env.action_space.n, lr=learning_rate)
    AC_head_agent = load_saved_head_weights(empty_net, agent_path)
    # remember which saved run the head weights were taken from
    loaded_from = run_id
else:
    # construct a new head_AC without loading any saved weights
    AC_head_agent = head_AC(input_dims,
                            test_env.action_space.n,
Ejemplo n.º 6
0
def plot_squares(what_to_plot='all_state_rep', current_cmap='viridis_r'):
    """Plot the gridworld next to each state representation or their distances.

    The first panel shows the layout of ``envs_to_plot[1]`` with a green square
    at a hard-coded position (``rwd_colrow``) and a blue circle on the reference
    state ``sim_ind``. Each following panel shows one representation type,
    according to ``what_to_plot``:

    - ``'single_state_rep'``: representation vector of ``sim_ind`` laid out on
      the grid;
    - ``'all_state_rep'``: matrix of all representations (one row per state);
    - ``'all_state_sim_sliced'``: pairwise Chebyshev distance matrix with
      obstacle rows/columns removed;
    - ``'single_state_sim'``: distances from ``sim_ind`` to every state, laid
      out on the grid;
    - ``'all_state_sim'``: full pairwise Chebyshev distance matrix.

    Each panel is divided by its NaN-ignoring maximum and shown on [0, 1]. The
    figure is saved to ``../figures/CH2/<last 3 chars of env id>_<what_to_plot>.svg``
    and then shown.

    Args:
        what_to_plot: one of the options listed above.
        current_cmap: matplotlib colormap for the representation panels.

    Raises:
        ValueError: if ``what_to_plot`` is not one of the supported options.
    """
    valid_plots = ('single_state_rep', 'all_state_rep', 'all_state_sim_sliced',
                   'single_state_sim', 'all_state_sim')
    if what_to_plot not in valid_plots:
        raise ValueError(
            f'what_to_plot must be one of {valid_plots}, got {what_to_plot!r}')

    relative_path_to_data = '../../Data/'  # from within Tests/CH1
    # 'latents' is deliberately not listed; the branch below handles it if added.
    rep_types = {
        'random': random,
        'onehot': onehot,
        'place_cell': place_cell,
        'sr': sr
    }

    for test_env_name in [envs_to_plot[1]]:
        sim_ind = 169  # flat index of the reference state
        envno = envs_to_plot.index(test_env_name)
        # make new env to run test in
        env = gym.make(test_env_name)
        plt.close()

        fig, ax = plt.subplots(1, len(rep_types) + 1, figsize=(14, 2))

        rwd_colrow = (16, 9) if test_env_name[-2:] == '51' else (14, 14)
        rect = plt.Rectangle(rwd_colrow, 1, 1, color='g', alpha=0.3)
        agt_rowcol = env.oneD2twoD(sim_ind)
        agt_colrow = (agt_rowcol[1] + 0.5, agt_rowcol[0] + 0.5)
        circ = plt.Circle(agt_colrow, radius=0.3, color='blue')
        ax[0].pcolor(grids[envno],
                     cmap='bone_r',
                     edgecolors='k',
                     linewidths=0.1)
        ax[0].axis(xmin=0, xmax=20, ymin=0, ymax=20)
        ax[0].set_aspect('equal')
        ax[0].add_patch(rect)
        ax[0].add_patch(circ)
        ax[0].get_xaxis().set_visible(False)
        ax[0].get_yaxis().set_visible(False)
        ax[0].invert_yaxis()

        # Obstacle entries are used as flat (1-D) state indices below.
        obstacles = np.asarray(env.obstacle, dtype=int)

        for j, rep_type in enumerate(rep_types):
            if rep_type == 'latents':
                conv_ids = {
                    'gridworld:gridworld-v1':
                    'c34544ac-45ed-492c-b2eb-4431b403a3a8',
                    'gridworld:gridworld-v3':
                    '32301262-cd74-4116-b776-57354831c484',
                    'gridworld:gridworld-v4':
                    'b50926a2-0186-4bb9-81ec-77063cac6861',
                    'gridworld:gridworld-v5':
                    '15b5e27b-444f-4fc8-bf25-5b7807df4c7f'
                }
                run_id = conv_ids[test_env_name[:-1]]
                agent_path = relative_path_to_data + f'agents/{run_id}.pt'
                state_reps, _, _, _ = latents(env, agent_path)
            else:
                state_reps, _, _, _ = rep_types[rep_type](env)

            # One row per state; rows that are never filled, and obstacle rows,
            # stay NaN.
            reps_as_matrix = np.full((400, 400), np.nan)
            for k, v in state_reps.items():
                reps_as_matrix[k] = v
                if k in env.obstacle:
                    reps_as_matrix[k, :] = np.nan

            # Pairwise Chebyshev distances between state representations, with
            # obstacle rows/columns blanked out.
            RS = squareform(pdist(reps_as_matrix, metric='chebyshev'))
            RS[obstacles, :] = np.nan
            RS[:, obstacles] = np.nan

            if what_to_plot == 'single_state_rep':
                panel = (reps_as_matrix[sim_ind] /
                         np.nanmax(reps_as_matrix)).reshape(env.shape)
            elif what_to_plot == 'all_state_rep':
                panel = reps_as_matrix / np.nanmax(reps_as_matrix)
            elif what_to_plot == 'all_state_sim_sliced':
                # Drop obstacle rows and columns entirely. np.delete with an index
                # array does not depend on the order of env.obstacle (deleting one
                # index at a time required it to be sorted ascending).
                sliced = np.delete(np.delete(RS, obstacles, axis=0),
                                   obstacles, axis=1)
                panel = sliced / np.nanmax(sliced)
            elif what_to_plot == 'single_state_sim':
                panel = RS[sim_ind].reshape(env.shape) / np.nanmax(RS)
            else:  # 'all_state_sim'
                panel = RS / np.nanmax(RS)

            a = ax[j + 1].imshow(panel, vmin=0, vmax=1, cmap=current_cmap)

            if j == len(rep_types) - 1:
                divider = make_axes_locatable(ax[j + 1])
                cax = divider.append_axes('right', size='5%', pad=0.05)
                plt.colorbar(a, cax=cax)
            ax[j + 1].get_xaxis().set_visible(False)
            ax[j + 1].get_yaxis().set_visible(False)
        plt.savefig(f'../figures/CH2/{test_env_name[-3:]}_{what_to_plot}.svg')
        plt.show()
Ejemplo n.º 7
0
def plot_dist_to_neighbours():
    """Scatter state-space distance against representation distance from a reference state.

    For the environment ``envs_to_plot[1]``, the first panel shows the layout
    with a green square at a hard-coded position (``rwd_colrow``) and a blue
    circle on the reference state ``sim_ind``. Each following panel covers one
    representation type. For every other state it plots the graph distance
    from ``sim_ind`` (x axis, from ``compute_graph_distance_matrix``) against
    the Chebyshev distance between the two states' representations (y axis).
    Obstacle states get NaN representation distances. The figure is saved to
    ``../figures/CH2/distance.svg`` and then shown.
    """
    relative_path_to_data = '../../Data/'  # from within Tests/CH1
    # 'latents' is deliberately not listed; the branch below handles it if added.
    rep_types = {
        'random': random,
        'onehot': onehot,
        'place_cell': place_cell,
        'sr': sr
    }

    for test_env_name in [envs_to_plot[1]]:
        sim_ind = 169  # flat index of the reference state
        envno = envs_to_plot.index(test_env_name)

        # make new env to run test in
        env = gym.make(test_env_name)
        plt.close()

        # make graph of env states
        G = make_env_graph(env)
        gd = compute_graph_distance_matrix(G, env)
        # distance from sim_ind to all other states (sim_ind itself removed)
        dist_in_state_space = np.delete(gd[sim_ind], sim_ind)

        fig, ax = plt.subplots(1, len(rep_types) + 1, figsize=(14, 2))

        rwd_colrow = (16, 9) if test_env_name[-2:] == '51' else (14, 14)
        rect = plt.Rectangle(rwd_colrow, 1, 1, color='g', alpha=0.3)
        agt_rowcol = env.oneD2twoD(sim_ind)
        agt_colrow = (agt_rowcol[1] + 0.5, agt_rowcol[0] + 0.5)
        circ = plt.Circle(agt_colrow, radius=0.3, color='blue')
        ax[0].pcolor(grids[envno],
                     cmap='bone_r',
                     edgecolors='k',
                     linewidths=0.1)
        ax[0].axis(xmin=0, xmax=20, ymin=0, ymax=20)
        ax[0].set_aspect('equal')
        ax[0].add_patch(rect)
        ax[0].add_patch(circ)
        ax[0].get_xaxis().set_visible(False)
        ax[0].get_yaxis().set_visible(False)
        ax[0].invert_yaxis()

        # Obstacle entries are used as flat (1-D) state indices below.
        obstacles = np.asarray(env.obstacle, dtype=int)

        for j, rep_type in enumerate(rep_types):
            if rep_type == 'latents':
                conv_ids = {
                    'gridworld:gridworld-v1':
                    'c34544ac-45ed-492c-b2eb-4431b403a3a8',
                    'gridworld:gridworld-v3':
                    '32301262-cd74-4116-b776-57354831c484',
                    'gridworld:gridworld-v4':
                    'b50926a2-0186-4bb9-81ec-77063cac6861',
                    'gridworld:gridworld-v5':
                    '15b5e27b-444f-4fc8-bf25-5b7807df4c7f'
                }
                run_id = conv_ids[test_env_name[:-1]]
                agent_path = relative_path_to_data + f'agents/{run_id}.pt'
                state_reps, _, _, _ = latents(env, agent_path)
            else:
                state_reps, _, _, _ = rep_types[rep_type](env)

            # One row per state; rows that are never filled, and obstacle rows,
            # stay NaN.
            reps_as_matrix = np.full((400, 400), np.nan)
            for k, v in state_reps.items():
                reps_as_matrix[k] = v
                if k in env.obstacle:
                    reps_as_matrix[k, :] = np.nan

            # Pairwise Chebyshev distances between state representations, with
            # obstacle rows/columns blanked out.
            RS = squareform(pdist(reps_as_matrix, metric='chebyshev'))
            RS[obstacles, :] = np.nan
            RS[:, obstacles] = np.nan

            dist_in_rep_space = np.delete(RS[sim_ind], sim_ind)
            ax[j + 1].scatter(dist_in_state_space,
                              dist_in_rep_space,
                              color='#b40426',
                              alpha=0.2,
                              linewidths=0.5)
            ax[j + 1].set_xlabel("D(s,s')")
            ax[j + 1].set_ylim([0.4, 1.1])
            # only the first scatter panel carries y tick labels and a y label
            if j != 0:
                ax[j + 1].set_yticklabels([])
            else:
                ax[j + 1].set_ylabel("D(R(s), R(s'))")

        plt.savefig('../figures/CH2/distance.svg')
        plt.show()
Ejemplo n.º 8
0
# Collect the hidden-layer activity of `network` for every entry of
# `state_reps`, keyed by state index.
ls = {}
for index, inp in state_reps.items():
    # do a forward pass
    network(inp)
    # get hidden_layer activity: read network.h_act after the forward pass.
    # [0] keeps the first row; presumably a batch of size 1, confirm.
    ls[index] = network.h_act.detach().numpy()[0]

# Show every slice of the input tensor for one example state.
test_index = 100
# assumes state_reps[i][0] is a stacked (slices, rows, cols) array; TODO confirm
tensor_slices = state_reps[test_index][0].shape[0]
# NOTE(review): with tensor_slices == 1, plt.subplots returns a single Axes
# (not an array), so ax[item] below would fail.
fig, ax = plt.subplots(1, tensor_slices)
for item in range(tensor_slices):
    ax[item].imshow(state_reps[test_index][0][item], cmap='bone_r')
    ax[item].set_aspect('equal')

# Latent representations from the same saved agent, computed in the training
# environment and in the test environment.
latent_reps1, _, __, ___ = latents(train_env, path_to_agent, type=latent_type)
latent_reps2, _, __, ___ = latents(test_env, path_to_agent, type=latent_type)

# Compare the latent vector of state 0 in both environments and their
# difference. Each is wrapped in a list and transposed, so a 1-D vector is
# drawn as a single column.
fig, ax = plt.subplots(1, 3)
ax[0].imshow(np.asarray([latent_reps1[0]]).T, aspect='auto')
ax[1].imshow(np.asarray([latent_reps2[0]]).T, aspect='auto')
ax[2].imshow(np.asarray([latent_reps2[0] - latent_reps1[0]]).T, aspect='auto')
plt.show()


def plot_inputs_and_latents(env_version, latent_type, test_index):
    training_env_name = f'gridworld:gridworld-v{env_version}'
    testing_env_name = training_env_name + '1'
    env_name = testing_env_name
    env = gym.make(env_name)
    plt.close()