def generate_saved_latents():
    """Build latent-state arrays for the saved agents and optionally pickle them.

    For every latent type ('conv', 'rwd_conv') and gridworld version 1-5 the
    run id is looked up in the module-level ``ids`` mapping, the saved agent is
    (best-effort) converted to a weight dict, the environment is created, and
    ``latents`` is called to obtain one representation per state. The
    representations are stacked into an ``(env.nstates, env.nstates)`` array
    which is pickled when the module-level ``save`` flag is truthy.

    Raises:
        KeyError: if ``ids`` has no entry for a latent type / environment id.
    """
    for latent_type in ['conv', 'rwd_conv']:
        # get agent ids to load/save/generate latents (same for every version)
        example_run_ids = ids[latent_type]

        for version in [1, 2, 3, 4, 5]:
            # get environment
            env_id = f'gridworld:gridworld-v{version}'
            run_id = example_run_ids[env_id]
            agent_path = f'./../../Data/agents/{run_id}.pt'

            # make sure saved agent is in the form of a state_dict of weights
            # instead of the agent object. This stays best-effort: a failure
            # is reported rather than silently swallowed, and interrupts such
            # as KeyboardInterrupt are no longer caught.
            try:
                convert_agent_to_weight_dict(
                    f'../../Data/network_objs/{run_id}.pt',
                    destination_path=agent_path)
            except Exception as err:
                print(f'Could not convert agent {run_id} to a weight dict '
                      f'({err!r}); continuing with {agent_path}')

            # make gym environment
            env = gym.make(env_id)
            plt.close()

            # save latents by loading network, passing appropriate tensor,
            # getting top fc layer activity
            reps, _name, _dim, _ = latents(env, agent_path, type=latent_type)

            # row i of the array holds the latent representation of state i
            latent_array = np.zeros((env.nstates, env.nstates))
            for state_index, state_rep in reps.items():
                latent_array[state_index] = state_rep

            if save:
                with open(
                        f'../../modules/Agents/RepresentationLearning/Learned_Rep_pickles/{latent_type}{run_id[0:8]}_{env_id[-12:]}.p',
                        'wb') as f:
                    pickle.dump(file=f, obj=latent_array.copy())
def get_mem_map(gb, env_name, rep, pct, ind=0):
    """Return the policy recalled from a saved episodic cache for every usable state.

    Args:
        gb: grouped run records; ``gb.get_group(key)`` is listed and indexed
            with ``ind`` to pick a run id.
        env_name: gym environment id to build.
        rep: representation name; 'latents' loads a saved agent, anything
            else is looked up in the local ``rep_types`` table.
        pct: key into ``cache_limits[env_name]`` used to select the group.
        ind: position of the run within the selected group.

    Returns:
        Structured array of shape ``example_env.shape`` with one 'f8' field
        per entry of ``example_env.action_list``.
    """
    example_env = gym.make(env_name)
    plt.close()

    # constructors for the hand-built representations
    rep_types = {
        'onehot': onehot,
        'random': random,
        'place_cell': place_cell,
        'analytic successor': sr
    }

    if rep != 'latents':
        builder = rep_types[rep]
        state_reps, representation_name, input_dims, _ = builder(example_env)
    else:
        conv_ids = {
            'gridworld:gridworld-v1': 'c34544ac-45ed-492c-b2eb-4431b403a3a8',
            'gridworld:gridworld-v3': '32301262-cd74-4116-b776-57354831c484',
            'gridworld:gridworld-v4': 'b50926a2-0186-4bb9-81ec-77063cac6861',
            'gridworld:gridworld-v5': '15b5e27b-444f-4fc8-bf25-5b7807df4c7f'
        }
        # the last character of env_name is dropped before the lookup
        load_id = conv_ids[f'{env_name[:-1]}']
        agent_path = parent_path + f'agents/{load_id}.pt'
        state_reps, representation_name, input_dims, _ = latents(
            example_env, agent_path)

    group_key = (env_name, rep, cache_limits[env_name][pct])
    run_id = list(gb.get_group(group_key))[ind]

    action_fields = [(action, 'f8') for action in example_env.action_list]
    policy_map = np.zeros(example_env.shape, dtype=action_fields)

    cache_file = parent_path + f'ec_dicts/{run_id}_EC.p'
    with open(cache_file, 'rb') as f:
        saved_cache = pickle.load(f)

    # cache limit doesn't matter since we are only using for recall
    mem = Memory(entry_size=example_env.action_space.n, cache_limit=400)
    mem.cache_list = saved_cache

    for state2d in example_env.useable:
        rep_key = tuple(state_reps[example_env.twoD2oneD(state2d)])
        policy_map[state2d] = tuple(mem.recall_mem(rep_key))

    return policy_map
def plot_inputs_and_latents(env_version, latent_type, test_index):
    """Show the input tensor of one state and the policy read out from its latents.

    The test environment (training environment id with '1' appended) is built,
    the channels of the input tensor for ``test_index`` are drawn, and then the
    saved actor-critic head is run on the latent representation of every
    usable state to build a policy map that is handed to ``plot_polmap``.

    Args:
        env_version: gridworld version number used to build the environment id.
        latent_type: 'conv' or 'rwd_conv'; selects the input builder and the
            run-id table in the module-level ``ids`` mapping.
        test_index: state index whose input tensor is displayed.

    Raises:
        ValueError: if ``latent_type`` is not 'conv' or 'rwd_conv'. Previously
            this surfaced later as an unbound-variable error on ``inputs``.
    """
    if latent_type not in ('conv', 'rwd_conv'):
        raise ValueError(
            f"unsupported latent_type {latent_type!r}; "
            "expected 'conv' or 'rwd_conv'")

    training_env_name = f'gridworld:gridworld-v{env_version}'
    testing_env_name = training_env_name + '1'
    env = gym.make(testing_env_name)
    plt.close()

    # get inputs states
    if latent_type == 'conv':
        inputs, _, __, ___ = convs(env)
    else:
        inputs, _, __, ___ = reward_convs(env)

    input_tensor = inputs[test_index][0]
    tensor_slices = input_tensor.shape[0]
    # one axis more than there are slices is requested, as in the original;
    # the final axis is left empty
    fig, ax = plt.subplots(1, tensor_slices + 1)
    for item in range(tensor_slices):
        ax[item].imshow(input_tensor[item], cmap='bone_r')
        ax[item].set_aspect('equal')
    plt.show()

    example_ids = ids[latent_type]
    run_id = example_ids[training_env_name]

    # get corresponding latent states
    path_to_agent = f'./../../../Data/agents/{run_id}.pt'
    # NOTE(review): head size (400 inputs, 4 actions) is hard-coded - confirm
    # it matches the saved weights for every environment.
    empty = head_AC(400, 4, lr=0.005)
    full = load_saved_head_weights(empty, path_to_agent)

    state_reps, name, dim, _ = latents(env, path_to_agent, type=latent_type)

    policy_map = np.zeros(env.shape,
                          dtype=[(x, 'f8') for x in env.action_list])
    for state2d in env.useable:
        latent_state = state_reps[env.twoD2oneD(state2d)]
        pol, _val = full(latent_state)
        policy_map[state2d] = tuple(pol)

    plot_polmap(env, policy_map)
# Experiment set-up: choose the state representation, build the actor-critic
# head, episodic memory and agent, then wrap them in an experiment object.

# constructors for the hand-built representations, keyed by name
rep_types = {
    'onehot': onehot,
    'random': random,
    'place_cell': place_cell,
    'sr': sr
}
if rep_type == 'latents':
    # run ids of saved agents, keyed by training environment id
    # NOTE(review): there is no entry for gridworld-v2, so version 2 raises
    # KeyError on the lookup below.
    conv_ids = {
        'gridworld:gridworld-v1': 'c34544ac-45ed-492c-b2eb-4431b403a3a8',
        'gridworld:gridworld-v3': '32301262-cd74-4116-b776-57354831c484',
        'gridworld:gridworld-v4': 'b50926a2-0186-4bb9-81ec-77063cac6861',
        'gridworld:gridworld-v5': '15b5e27b-444f-4fc8-bf25-5b7807df4c7f'
    }
    run_id = conv_ids[f'gridworld:gridworld-v{version}']
    agent_path = relative_path_to_data + f'agents/saved_agents/{run_id}.pt'
    state_reps, representation_name, input_dims, _ = latents(env, agent_path)
else:
    # every other representation is built directly from the environment
    state_reps, representation_name, input_dims, _ = rep_types[rep_type](env)

# head network sized by the representation dimension and the action count
AC_head_agent = head_AC(input_dims, env.action_space.n, lr=learning_rate)

memory = Memory(entry_size=env.action_space.n,
                cache_limit=cache_size_for_env,
                distance=distance_metric)

agent = Agent(AC_head_agent, memory=memory, state_representations=state_reps)

ex = flat_expt(agent, env)
# report the configuration of the run that is about to start
print(
    f"Experiment running {env.unwrapped.spec.id} \nRepresentation: {representation_name} \nCache Limit:{cache_size_for_env} \nDistance: {distance_metric}"
)
plt.close() # make new env to run test in test_env = gym.make(test_env_name) plt.close() if latent_type == 'conv' or latent_type == 'rwd_conv': ids = {'conv': conv_ids, 'rwd_conv': rwd_conv_ids} id_dict = ids[latent_type] run_id = id_dict[training_env_name] # load latent states to use as state representations to actor-critic heads agent_path = relative_path_to_data + f'agents/{run_id}.pt' # save latents by loading network, passing appropriate tensor, getting top fc layer activity state_reps, representation_name, input_dims, _ = latents(train_env, agent_path, type=latent_type) elif latent_type in ['sr', 'onehot']: rep_Type = {'sr': sr, 'onehot': onehot} state_reps, representation_name, input_dims, _ = rep_Type[latent_type]( test_env) if load_weights: # load weights to head_ac network from previously learned agent empty_net = head_AC(input_dims, test_env.action_space.n, lr=learning_rate) AC_head_agent = load_saved_head_weights(empty_net, agent_path) loaded_from = run_id else: AC_head_agent = head_AC(input_dims, test_env.action_space.n,
def plot_squares(what_to_plot='all_state_rep', current_cmap='viridis_r'):
    """Plot the gridworld beside one image panel per state-representation type.

    The first axis shows the grid (``grids[envno]``) with a green square at
    ``rwd_colrow`` and a blue circle at state ``sim_ind``. Each further axis
    shows, for one representation in ``rep_types``, the quantity selected by
    ``what_to_plot``; the last panel gets a colour bar. The figure is saved
    as an SVG and then shown.

    Args:
        what_to_plot: one of
            'single_state_rep'     - representation of state ``sim_ind``,
            'all_state_rep'        - representations of all states,
            'all_state_sim_sliced' - pairwise Chebyshev distances with the
                                     obstacle rows/columns removed,
            'single_state_sim'     - distances from ``sim_ind`` to all states,
            'all_state_sim'        - full pairwise distance matrix.
        current_cmap: matplotlib colormap name used for every panel.

    Raises:
        ValueError: if ``what_to_plot`` is not one of the modes above.
            Previously an unknown mode ran the whole pipeline and then failed
            on an unbound image handle when drawing the colour bar.
    """
    valid_plots = (
        'single_state_rep',
        'all_state_rep',
        'all_state_sim_sliced',
        'single_state_sim',
        'all_state_sim',
    )
    if what_to_plot not in valid_plots:
        raise ValueError(
            f'unknown what_to_plot {what_to_plot!r}; '
            f'expected one of {valid_plots}')

    for test_env_name in [envs_to_plot[1]]:
        # state used for the single-state plots and for the agent marker
        sim_ind = 169
        envno = envs_to_plot.index(test_env_name)

        # make new env to run test in
        env = gym.make(test_env_name)
        plt.close()

        rep_types = {
            'random': random,
            'onehot': onehot,
            'place_cell': place_cell,
            'sr': sr
        }  # 'latents' is currently left out

        fig, ax = plt.subplots(1, len(rep_types) + 1, figsize=(14, 2))

        # (column, row) of the green square; presumably the reward location
        if test_env_name[-2:] == '51':
            rwd_colrow = (16, 9)
        else:
            rwd_colrow = (14, 14)
        rect = plt.Rectangle(rwd_colrow, 1, 1, color='g', alpha=0.3)

        # centre the agent marker inside its grid cell
        agent_2d = env.oneD2twoD(sim_ind)
        agt_colrow = (agent_2d[1] + 0.5, agent_2d[0] + 0.5)
        circ = plt.Circle(agt_colrow, radius=0.3, color='blue')

        ax[0].pcolor(grids[envno],
                     cmap='bone_r',
                     edgecolors='k',
                     linewidths=0.1)
        ax[0].axis(xmin=0, xmax=20, ymin=0, ymax=20)
        ax[0].set_aspect('equal')
        ax[0].add_patch(rect)
        ax[0].add_patch(circ)
        ax[0].get_xaxis().set_visible(False)
        ax[0].get_yaxis().set_visible(False)
        ax[0].invert_yaxis()

        last_panel = len(rep_types) - 1
        for j, rep_type in enumerate(rep_types.keys()):
            relative_path_to_data = '../../Data/'  # from within Tests/CH1
            if rep_type == 'latents':
                # not reached while 'latents' is absent from rep_types
                conv_ids = {
                    'gridworld:gridworld-v1':
                    'c34544ac-45ed-492c-b2eb-4431b403a3a8',
                    'gridworld:gridworld-v3':
                    '32301262-cd74-4116-b776-57354831c484',
                    'gridworld:gridworld-v4':
                    'b50926a2-0186-4bb9-81ec-77063cac6861',
                    'gridworld:gridworld-v5':
                    '15b5e27b-444f-4fc8-bf25-5b7807df4c7f'
                }
                run_id = conv_ids[test_env_name[:-1]]
                agent_path = relative_path_to_data + f'agents/{run_id}.pt'
                state_reps, representation_name, input_dims, _ = latents(
                    env, agent_path)
            else:
                state_reps, representation_name, input_dims, _ = rep_types[
                    rep_type](env)

            # one row per state; states without a representation and
            # obstacle states stay NaN
            reps_as_matrix = np.full((400, 400), np.nan)
            for k, v in state_reps.items():
                reps_as_matrix[k] = v
                if k in env.obstacle:
                    reps_as_matrix[k, :] = np.nan

            # pairwise Chebyshev distance between state representations
            RS = squareform(pdist(reps_as_matrix, metric='chebyshev'))
            for obstacle_state in env.obstacle:
                RS[obstacle_state, :] = np.nan
                RS[:, obstacle_state] = np.nan

            if what_to_plot == 'single_state_rep':
                # plot representation of a single state
                a = ax[j + 1].imshow(
                    (reps_as_matrix[sim_ind] /
                     np.nanmax(reps_as_matrix)).reshape(env.shape),
                    vmin=0,
                    vmax=1,
                    cmap=current_cmap)
            elif what_to_plot == 'all_state_rep':
                a = ax[j + 1].imshow(reps_as_matrix /
                                     np.nanmax(reps_as_matrix),
                                     vmin=0,
                                     vmax=1,
                                     cmap=current_cmap)
            elif what_to_plot == 'all_state_sim_sliced':
                sliced = RS.copy()
                # diagnostic output: positions that are NaN before slicing
                print(np.argwhere(np.isnan(sliced)))
                # delete in reverse order, as in the original, which keeps
                # the remaining indices valid if env.obstacle is ascending
                for state1d in reversed(env.obstacle):
                    sliced = np.delete(sliced, state1d, 0)
                    sliced = np.delete(sliced, state1d, 1)
                a = ax[j + 1].imshow(sliced / np.nanmax(sliced),
                                     vmin=0,
                                     vmax=1,
                                     cmap=current_cmap)
            elif what_to_plot == 'single_state_sim':
                a = ax[j + 1].imshow(RS[sim_ind].reshape(env.shape) /
                                     np.nanmax(RS),
                                     vmin=0,
                                     vmax=1,
                                     cmap=current_cmap)
            else:  # 'all_state_sim' (guaranteed by the validation above)
                a = ax[j + 1].imshow(RS / np.nanmax(RS),
                                     vmin=0,
                                     vmax=1,
                                     cmap=current_cmap)

            # only the right-most panel carries the colour bar
            if j == last_panel:
                divider = make_axes_locatable(ax[j + 1])
                cax = divider.append_axes('right', size='5%', pad=0.05)
                plt.colorbar(a, cax=cax)

            ax[j + 1].get_xaxis().set_visible(False)
            ax[j + 1].get_yaxis().set_visible(False)

        plt.savefig(f'../figures/CH2/{test_env_name[-3:]}_{what_to_plot}.svg')
        plt.show()
def plot_dist_to_neighbours():
    """Scatter graph distance against representation distance from one state.

    For the environment ``envs_to_plot[1]`` the first axis shows the grid with
    a green square and a blue circle at state ``sim_ind``. Every further axis
    plots, for one representation type, the graph distance from ``sim_ind``
    to each other state (x) against the Chebyshev distance between their
    representations (y). The figure is written to
    '../figures/CH2/distance.svg' and then shown.
    """
    for test_env_name in [envs_to_plot[1]]:
        # reference state all distances are measured from
        sim_ind = 169
        envno = envs_to_plot.index(test_env_name)

        # make new env to run test in
        env = gym.make(test_env_name)
        plt.close()

        # make graph of env states
        G = make_env_graph(env)
        gd = compute_graph_distance_matrix(G, env)
        # distance from sim ind to all other states
        dist_in_state_space = np.delete(gd[sim_ind], sim_ind)

        rep_types = {
            'random': random,
            'onehot': onehot,
            'place_cell': place_cell,
            'sr': sr
        }  # 'latents' is currently left out
        n_panels = len(list(rep_types.items())) + 1
        fig, ax = plt.subplots(1, n_panels, figsize=(14, 2))

        # (column, row) of the green square; presumably the reward location
        rwd_colrow = (16, 9) if test_env_name[-2:] == '51' else (14, 14)
        rect = plt.Rectangle(rwd_colrow, 1, 1, color='g', alpha=0.3)
        agt_colrow = (env.oneD2twoD(sim_ind)[1] + 0.5,
                      env.oneD2twoD(sim_ind)[0] + 0.5)
        circ = plt.Circle(agt_colrow, radius=0.3, color='blue')

        grid_ax = ax[0]
        grid_ax.pcolor(grids[envno],
                       cmap='bone_r',
                       edgecolors='k',
                       linewidths=0.1)
        grid_ax.axis(xmin=0, xmax=20, ymin=0, ymax=20)
        grid_ax.set_aspect('equal')
        grid_ax.add_patch(rect)
        grid_ax.add_patch(circ)
        grid_ax.get_xaxis().set_visible(False)
        grid_ax.get_yaxis().set_visible(False)
        grid_ax.invert_yaxis()

        cmap = linc_coolwarm_r

        for j, rep_type in enumerate(rep_types.keys()):
            panel = j + 1
            relative_path_to_data = '../../Data/'  # from within Tests/CH1
            if rep_type != 'latents':
                state_reps, representation_name, input_dims, _ = rep_types[
                    rep_type](env)
            else:
                # not reached while 'latents' is absent from rep_types
                conv_ids = {
                    'gridworld:gridworld-v1':
                    'c34544ac-45ed-492c-b2eb-4431b403a3a8',
                    'gridworld:gridworld-v3':
                    '32301262-cd74-4116-b776-57354831c484',
                    'gridworld:gridworld-v4':
                    'b50926a2-0186-4bb9-81ec-77063cac6861',
                    'gridworld:gridworld-v5':
                    '15b5e27b-444f-4fc8-bf25-5b7807df4c7f'
                }
                run_id = conv_ids[test_env_name[:-1]]
                agent_path = relative_path_to_data + f'agents/{run_id}.pt'
                state_reps, representation_name, input_dims, _ = latents(
                    env, agent_path)

            # one row per state; missing and obstacle states stay NaN
            reps_as_matrix = np.full((400, 400), np.nan)
            for state_index, state_rep in state_reps.items():
                reps_as_matrix[state_index] = state_rep
                if state_index in env.obstacle:
                    reps_as_matrix[state_index, :] = np.nan

            # pairwise Chebyshev distance between state representations
            RS = squareform(pdist(reps_as_matrix, metric='chebyshev'))
            for obstacle_state in env.obstacle:
                RS[obstacle_state, :] = np.nan
                RS[:, obstacle_state] = np.nan

            dist_in_rep_space = np.delete(RS[sim_ind], sim_ind)
            print(type(dist_in_rep_space))

            # colours are computed but not passed to the scatter call below
            normalised = dist_in_rep_space / np.nanmax(dist_in_rep_space)
            rgba = [cmap(x) for x in normalised]

            ax[panel].scatter(dist_in_state_space,
                              dist_in_rep_space,
                              color='#b40426',
                              alpha=0.2,
                              linewidths=0.5)
            ax[panel].set_xlabel("D(s,s')")
            ax[panel].set_ylim([0.4, 1.1])
            # only the first scatter panel keeps its y label and tick labels
            if j == 0:
                ax[panel].set_ylabel("D(R(s), R(s'))")
            else:
                ax[panel].set_yticklabels([])

        plt.savefig('../figures/CH2/distance.svg')
        plt.show()
ls = {} for index, inp in state_reps.items(): # do a forward pass network(inp) # get hidden_layer activity ls[index] = network.h_act.detach().numpy()[0] test_index = 100 tensor_slices = state_reps[test_index][0].shape[0] fig, ax = plt.subplots(1, tensor_slices) for item in range(tensor_slices): ax[item].imshow(state_reps[test_index][0][item], cmap='bone_r') ax[item].set_aspect('equal') latent_reps1, _, __, ___ = latents(train_env, path_to_agent, type=latent_type) latent_reps2, _, __, ___ = latents(test_env, path_to_agent, type=latent_type) fig, ax = plt.subplots(1, 3) ax[0].imshow(np.asarray([latent_reps1[0]]).T, aspect='auto') ax[1].imshow(np.asarray([latent_reps2[0]]).T, aspect='auto') ax[2].imshow(np.asarray([latent_reps2[0] - latent_reps1[0]]).T, aspect='auto') plt.show() def plot_inputs_and_latents(env_version, latent_type, test_index): training_env_name = f'gridworld:gridworld-v{env_version}' testing_env_name = training_env_name + '1' env_name = testing_env_name env = gym.make(env_name) plt.close()