def plot_inputs_and_latents(env_version, latent_type, test_index):
    training_env_name = f'gridworld:gridworld-v{env_version}'
    testing_env_name = training_env_name + '1'
    env_name = testing_env_name

    env = gym.make(env_name)
    plt.close()

    # get input states
    if latent_type == 'conv':
        inputs, _, __, ___ = convs(env)
    elif latent_type == 'rwd_conv':
        inputs, _, __, ___ = reward_convs(env)

    # plot each channel of the example input tensor
    tensor_slices = inputs[test_index][0].shape[0]
    fig, ax = plt.subplots(1, tensor_slices + 1)
    for item in range(tensor_slices):
        ax[item].imshow(inputs[test_index][0][item], cmap='bone_r')
        ax[item].set_aspect('equal')
    plt.show()

    example_ids = ids[latent_type]
    run_id = example_ids[training_env_name]

    # get the corresponding latent states from the saved network
    path_to_agent = f'./../../../Data/agents/{run_id}.pt'
    empty = head_AC(400, 4, lr=0.005)
    full = load_saved_head_weights(empty, path_to_agent)
    state_reps, name, dim, _ = latents(env, path_to_agent, type=latent_type)

    # evaluate the loaded policy head at every usable state
    policy_map = np.zeros(env.shape, dtype=[(x, 'f8') for x in env.action_list])
    for state2d in env.useable:
        latent_state = state_reps[env.twoD2oneD(state2d)]
        pol, val = full(latent_state)
        policy_map[state2d] = tuple(pol)

    plot_polmap(env, policy_map)
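# Example usage -- a minimal sketch, assuming the module-level `ids` table maps
# latent type -> {training env name -> run id}, as the function body requires:
# plot_inputs_and_latents(env_version=1, latent_type='conv', test_index=0)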
    'sr': sr
}

if rep_type == 'latents':
    conv_ids = {
        'gridworld:gridworld-v1': 'c34544ac-45ed-492c-b2eb-4431b403a3a8',
        'gridworld:gridworld-v3': '32301262-cd74-4116-b776-57354831c484',
        'gridworld:gridworld-v4': 'b50926a2-0186-4bb9-81ec-77063cac6861',
        'gridworld:gridworld-v5': '15b5e27b-444f-4fc8-bf25-5b7807df4c7f'
    }
    run_id = conv_ids[f'gridworld:gridworld-v{version}']
    agent_path = relative_path_to_data + f'agents/saved_agents/{run_id}.pt'
    state_reps, representation_name, input_dims, _ = latents(env, agent_path)
else:
    state_reps, representation_name, input_dims, _ = rep_types[rep_type](env)

AC_head_agent = head_AC(input_dims, env.action_space.n, lr=learning_rate)
memory = Memory(entry_size=env.action_space.n, cache_limit=cache_size_for_env, distance=distance_metric)
agent = Agent(AC_head_agent, memory=memory, state_representations=state_reps)

ex = flat_expt(agent, env)
print(f"Experiment running {env.unwrapped.spec.id}"
      f"\nRepresentation: {representation_name}"
      f"\nCache Limit: {cache_size_for_env}"
      f"\nDistance: {distance_metric}")
ex.run(num_trials, num_events, snapshot_logging=False)
ex.record_log(env_name=test_env_name, representation_type=representation_name, n_trials=num_trials,
              n_steps=num_events, dir=relative_path_to_data, file=write_to_file)
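# Every rep_types entry shares the same return signature, which is what makes
# the dispatch above uniform; e.g. (illustrative call, same pattern as above):
# state_reps, representation_name, input_dims, _ = rep_types['onehot'](env)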
    # load latent states to use as state representations for the actor-critic head
    agent_path = relative_path_to_data + f'agents/{run_id}.pt'

    # save latents by loading the network, passing the appropriate tensor, and taking top fc layer activity
    state_reps, representation_name, input_dims, _ = latents(train_env, agent_path, type=latent_type)
elif latent_type in ['sr', 'onehot']:
    rep_Type = {'sr': sr, 'onehot': onehot}
    state_reps, representation_name, input_dims, _ = rep_Type[latent_type](test_env)

if load_weights:
    # load weights into the head_AC network from a previously trained agent
    empty_net = head_AC(input_dims, test_env.action_space.n, lr=learning_rate)
    AC_head_agent = load_saved_head_weights(empty_net, agent_path)
    loaded_from = run_id
else:
    AC_head_agent = head_AC(input_dims, test_env.action_space.n, lr=learning_rate)
    loaded_from = ' '

# episodic memory cache sizes per test environment, keyed by percent of full cache
cache_limits = {
    'gridworld:gridworld-v11': {
        100: 400,
        75: 300,
        50: 200,
        25: 100
    },
    'gridworld:gridworld-v31': {
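# Presumably the run then selects its cache size from this table, keyed by one
# of the percentages above (hypothetical lookup, not shown in the original):
# cache_size_for_env = cache_limits[test_env_name][pct_of_full_cache]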
num_trials = 25000
num_events = 250
relative_path_to_data = '../../Data'  # ../../Data if you are in Tests/CH2

# make gym environment (creation automatically generates a plot, so close it)
env = gym.make(env_name)
plt.close()

rep_types = {
    'onehot': onehot,
    'random': random,
    'place_cell': place_cell,
    'sr': sr,
    'latent': latents
}
state_reps, representation_name, input_dims, _ = rep_types[representation_type](env)

# build a fresh actor-critic head (no pre-trained weights are loaded here)
AC_head_agent = head_AC(input_dims, env.action_space.n, lr=0.0005)
agent = Agent(AC_head_agent, state_representations=state_reps)

ex = flat_expt(agent, env)
ex.run(num_trials, num_events, snapshot_logging=False)
ex.record_log(env_name=env_name, representation_type=representation_name, n_trials=num_trials,
              n_steps=num_events, dir=relative_path_to_data, file=write_to_file)
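# A hypothetical driver for sweeping every representation, assuming the block
# above is wrapped in a function run_experiment(representation_type):
# for representation_type in ['onehot', 'random', 'place_cell', 'sr', 'latent']:
#     run_experiment(representation_type)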
id_dict = ids[latent_type]
run_id = id_dict[training_env_name]

# create environment object -- automatically creates a plot, so close it right away
test_env = gym.make(test_env_name)
plt.close()

# load latent states to use as state representations for the actor-critic head
agent_path = relative_path_to_data + f'agents/{run_id}.pt'

# save latents by loading the network, passing the appropriate tensor, and taking top fc layer activity
state_reps, representation_name, input_dims, _ = latents(test_env, agent_path, type=latent_type)

# load weights into the head_AC network from a previously trained agent
empty_net = head_AC(input_dims, test_env.action_space.n, lr=0.0005)
AC_head_agent = load_saved_head_weights(empty_net, agent_path)
agent = Agent(AC_head_agent, state_representations=state_reps)

ex = flat_expt(agent, test_env)
ex.run(num_trials, num_events, snapshot_logging=False)
ex.record_log(env_name=test_env_name, representation_type=representation_name, n_trials=num_trials,
              n_steps=num_events, load_from=run_id, dir=relative_path_to_data, file=write_to_file)
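# For reference, `ids` above is presumably keyed by latent type and then by
# training environment name, mirroring the conv_ids table earlier (placeholder
# UUIDs only -- the real run ids live in the data directory):
# ids = {
#     'conv': {'gridworld:gridworld-v1': '<run-uuid>', ...},
#     'rwd_conv': {'gridworld:gridworld-v1': '<run-uuid>', ...},
# }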