def get_mem_maps():
    # Build a map of recalled episodic-cache policies for every environment in
    # master_dict, keyed by environment and cache size.
    ec_maps = {}
    for key in master_dict.keys():
        env = gym.make(key[:-1])
        print("reward at", env.rewards.keys())
        plt.close()
        print(key, len(env.useable))
        ec_maps[key] = {}
        latents, _, __, ___ = load_saved_latents(env)
        for j, cache_size in enumerate(master_dict[key].keys()):
            print(j, cache_size)
            v_list = master_dict[key][cache_size]
            policy_map = np.zeros(env.shape, dtype=[(x, 'f8') for x in env.action_list])

            # load the episodic cache dictionary saved for this run
            with open(data_dir + f'ec_dicts/{v_list[0]}_EC.p', 'rb') as f:
                cache_list = pickle.load(f)

            mem = Memory(entry_size=env.action_space.n, cache_limit=400)
            mem.cache_list = cache_list

            for state2d in env.useable:
                state1d = env.twoD2oneD(state2d)
                state_rep = latents[state1d]
                policy_map[state2d] = tuple(mem.recall_mem(tuple(state_rep)))

            ec_maps[key][cache_size] = policy_map
    return ec_maps
def get_xy_maps(env, state_reps, data, trial_num=-1):
    blank_mem = Memory(cache_limit=400, entry_size=4)
    blank_mem.cache_list = data['ec_dicts'][trial_num]

    xy_pol_grid = np.zeros((2, *env.shape))
    polar_grid = np.zeros(env.shape)
    polar_grid[:] = np.nan
    dxdy = np.array([(0, 1), (0, -1), (1, 0), (-1, 0)])  # D U R L

    for key, value in state_reps.items():
        twoD = env.oneD2twoD(key)
        sr_rep = value
        pol = blank_mem.recall_mem(sr_rep)
        xy = np.dot(pol, dxdy)
        xy_pol_grid[0, twoD] = xy[0]
        xy_pol_grid[1, twoD] = xy[1]
        '''rads = np.arctan(xy[1]/xy[0])
        degs = rads*(180/np.pi)
        if xy[0]>=0 and xy[1]>=0: #Q1
            theta = degs
        elif xy[0]<0: #Q2 and Q3
            theta = degs+180
        elif xy[0]>=0 and xy[1]<=0:
            theta = degs+360
        else:
            theta = -1
        polar_grid[twoD] = theta'''
    return xy_pol_grid, polar_grid
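# Example usage (sketch): visualize the recalled policy as a vector field with a
# quiver plot. Assumes `env`, `state_reps`, and `data` are already defined as in the
# surrounding analysis code; the plotting choices here are illustrative only.
xy_map, _ = get_xy_maps(env, state_reps, data, trial_num=-1)
fig, ax = plt.subplots(1, 1)
ax.quiver(xy_map[0], xy_map[1])  # x and y components of the mean policy direction
ax.invert_yaxis()  # match the gridworld's row-major orientation
plt.show()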
def reconstruct_policy_map(ec_dict):
    blank_mem = Memory(cache_limit=400, entry_size=4)
    blank_mem.cache_list = ec_dict

    policy_array = np.zeros((*env.shape, 4))
    policy_array[:] = np.nan

    for key, value in state_reps.items():
        twoD = env.oneD2twoD(key)
        sr_rep = value
        pol = blank_mem.recall_mem(sr_rep)
        policy_array[twoD] = pol

    return policy_array
def sample_from_ec_pol(state_reps, ec_dict, **kwargs):
    # Roll out a single trajectory by sampling actions from the episodic-cache
    # policy, stopping at the rewarded state (reward == 10.) or after 250 steps.
    blank_mem = Memory(cache_limit=400, entry_size=4)
    blank_mem.cache_list = ec_dict

    start_state = kwargs.get('start_state', np.random.choice(list(state_reps.keys())))

    trajectory = []
    env.set_state(start_state)
    state = start_state
    trajectory.append(env.oneD2twoD(state))

    for i in range(250):
        policy = blank_mem.recall_mem(state_reps[state])
        action = np.random.choice(np.arange(4), p=policy)
        next_state, reward, done, info = env.step(action)
        state = next_state
        trajectory.append(env.oneD2twoD(state))
        if reward == 10.:
            break

    return trajectory
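# Example usage (sketch): sample one trajectory from the cache of the final training
# trial and inspect the visited (row, col) states. Assumes `env`, `state_reps`, and
# `data` are defined as in the surrounding analysis code.
traj = sample_from_ec_pol(state_reps, data['ec_dicts'][-1])
print(f'trajectory length: {len(traj)}')
print(traj)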
def reconstruct_xy_map(ec_dict):
    blank_mem = Memory(cache_limit=400, entry_size=4)
    blank_mem.cache_list = ec_dict

    xy_array = np.zeros((*env.shape, 2))
    xy_array[:] = np.nan
    dxdy = np.array([(0, 1), (0, -1), (1, 0), (-1, 0)])  # D U R L

    for key, value in state_reps.items():
        twoD = env.oneD2twoD(key)
        sr_rep = value
        pol = blank_mem.recall_mem(sr_rep)
        xycoords = np.dot(pol, dxdy)
        xy_array[twoD] = xycoords

    return xy_array
def get_mem_map(gb, env_name, rep, pct, ind=0):
    example_env = gym.make(env_name)
    plt.close()

    rep_types = {
        'onehot': onehot,
        'random': random,
        'place_cell': place_cell,
        'analytic successor': sr
    }

    if rep == 'latents':
        conv_ids = {
            'gridworld:gridworld-v1': 'c34544ac-45ed-492c-b2eb-4431b403a3a8',
            'gridworld:gridworld-v3': '32301262-cd74-4116-b776-57354831c484',
            'gridworld:gridworld-v4': 'b50926a2-0186-4bb9-81ec-77063cac6861',
            'gridworld:gridworld-v5': '15b5e27b-444f-4fc8-bf25-5b7807df4c7f'
        }
        load_id = conv_ids[f'{env_name[:-1]}']
        agent_path = parent_path + f'agents/{load_id}.pt'
        state_reps, representation_name, input_dims, _ = latents(example_env, agent_path)
    else:
        state_reps, representation_name, input_dims, _ = rep_types[rep](example_env)

    run_id = list(gb.get_group((env_name, rep, cache_limits[env_name][pct])))[ind]

    policy_map = np.zeros(example_env.shape, dtype=[(x, 'f8') for x in example_env.action_list])
    with open(parent_path + f'ec_dicts/{run_id}_EC.p', 'rb') as f:
        cache_list = pickle.load(f)

    # cache limit doesn't matter since we are only using the memory for recall
    mem = Memory(entry_size=example_env.action_space.n, cache_limit=400)
    mem.cache_list = cache_list

    for state2d in example_env.useable:
        state1d = example_env.twoD2oneD(state2d)
        state_rep = tuple(state_reps[state1d])
        #print(state_rep in cache_list.keys())
        policy_map[state2d] = tuple(mem.recall_mem(state_rep))

    return policy_map
def get_avg_incidence_of_memories(data):
    # Accumulate, across trials, a running average of the recalled policy at each
    # state together with a count of how many cache entries landed on that state.
    n_visits = np.zeros(env.shape)
    ec_pol_grid = np.zeros(env.shape, dtype=[(x, 'f8') for x in env.action_list])

    for i in range(len(data['ec_dicts'])):
        print(i)
        blank_mem = Memory(cache_limit=400, entry_size=4)
        blank_mem.cache_list = data['ec_dicts'][i]

        for k, key in enumerate(blank_mem.cache_list.keys()):
            twoD = env.oneD2twoD(blank_mem.cache_list[key][2])
            old_policy = ec_pol_grid[twoD]
            current_policy = blank_mem.recall_mem(key)

            # incremental running average of the recalled policy at this state
            average = []
            for x, y in zip(old_policy, current_policy):
                z = x + (y - x) / (k + 1)
                average.append(z)

            ec_pol_grid[twoD] = tuple(average)
            n_visits[twoD] += 1

    return ec_pol_grid, n_visits
def get_mem_maps(data, trial_num=-1, full_mem=True):
    # Reconstruct the episodic-cache policy map for one trial. With full_mem=True the
    # policy is recalled at every state representation; otherwise only at states that
    # actually have an entry in the cache.
    blank_mem = Memory(cache_limit=400, entry_size=4)
    blank_mem.cache_list = data['ec_dicts'][trial_num]

    ec_pol_grid = np.zeros(env.shape, dtype=[(x, 'f8') for x in env.action_list])
    if full_mem:
        for key, value in state_reps.items():
            twoD = env.oneD2twoD(key)
            sr_rep = value
            pol = blank_mem.recall_mem(sr_rep)
            ec_pol_grid[twoD] = tuple(pol)
    else:
        for ec_key in blank_mem.cache_list.keys():
            twoD = env.oneD2twoD(blank_mem.cache_list[ec_key][2])
            pol = blank_mem.recall_mem(ec_key)
            ec_pol_grid[twoD] = tuple(pol)

    return ec_pol_grid
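# Example usage (sketch): recover the greedy action at each useable state from the
# last trial's cache. Assumes `env`, `state_reps`, and `data` are defined as above;
# the structured-dtype record is unpacked with .tolist() before taking the argmax.
pol_map = get_mem_maps(data, trial_num=-1, full_mem=True)
greedy_actions = np.full(env.shape, -1)
for state2d in env.useable:
    greedy_actions[state2d] = int(np.argmax(pol_map[state2d].tolist()))
print(greedy_actions)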
def reconstruct_mem(env, ec_dicts, state_reps, index):
    ec_dict = ec_dicts[index]
    blank_mem = Memory(cache_limit=400, entry_size=4)
    blank_mem.cache_list = ec_dict

    # note: grid dimensions are hard-coded to 20x20 here
    ec_pols = np.zeros((20, 20), dtype=[(x, 'f8') for x in env.action_list])
    ec_vals = np.zeros((20, 20))
    ec_vals[:] = np.nan

    for ind in env.useable:
        index_1d = env.twoD2oneD(ind)
        index_2d = ind
        dict_key = state_reps[index_1d]

        # value estimate: nan-mean over the first column of the nearest cache entry's stored array
        nearest_act, d = blank_mem.similarity_measure(dict_key)
        val = np.nanmean(ec_dict[nearest_act][0][:, 0])
        pol = blank_mem.recall_mem(dict_key)

        ec_vals[index_2d] = val
        ec_pols[index_2d] = tuple(pol)

    return ec_vals, ec_pols
def get_ec_policy_map(env, state_reps, data, trial_num, full=True):
    ec_pol_grid = np.zeros(env.shape, dtype=[(x, 'f8') for x in env.action_list])
    blank_mem = Memory(cache_limit=400, entry_size=4)
    blank_mem.cache_list = data['ec_dicts'][trial_num]

    if full:
        for k in state_reps.keys():
            twoD = env.oneD2twoD(k)
            sr_rep = state_reps[k]
            pol = blank_mem.recall_mem(sr_rep)
            ec_pol_grid[twoD] = tuple(pol)
    else:
        for k in blank_mem.cache_list.keys():
            oneDstate = blank_mem.cache_list[k][2]
            twoD = env.oneD2twoD(oneDstate)
            sr_rep = state_reps[oneDstate]
            pol = blank_mem.recall_mem(sr_rep)
            ec_pol_grid[twoD] = tuple(pol)

    return ec_pol_grid
def get_KLD(data, probe_state, trial_num):
    # For each state with a cache entry, compute the KL divergence between the policy
    # recalled at a probe state and the policy recalled at that state, along with the
    # entropy of the recalled policy.
    probe_rep = state_reps[probe_state]

    KLD_array = np.zeros(env.shape)
    KLD_array[:] = np.nan
    entropy_array = np.zeros(env.shape)
    entropy_array[:] = np.nan
    ec_pol_grid = np.zeros((*env.shape, 4))

    blank_mem = Memory(cache_limit=400, entry_size=4)
    blank_mem.cache_list = data['ec_dicts'][trial_num]
    probe_pol = blank_mem.recall_mem(probe_rep)

    for sr_rep in blank_mem.cache_list.keys():
        k = blank_mem.cache_list[sr_rep][2]
        pol = blank_mem.recall_mem(sr_rep)
        twoD = env.oneD2twoD(k)

        KLD_array[twoD] = sum(rel_entr(list(probe_pol), list(pol)))
        ec_pol_grid[twoD][:] = pol
        entropy_array[twoD] = entropy(pol, base=2)

    return KLD_array, ec_pol_grid, entropy_array
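# Example usage (sketch): compute and plot the divergence and entropy maps for the
# final trial. The probe state index (105) is an arbitrary illustration; assumes
# `env`, `state_reps`, and `data` are defined as in the surrounding analysis code.
kld_map, pol_grid, ent_map = get_KLD(data, probe_state=105, trial_num=-1)
fig, ax = plt.subplots(1, 2)
ax[0].imshow(kld_map)
ax[0].set_title('KL divergence from probe policy')
ax[1].imshow(ent_map)
ax[1].set_title('policy entropy (bits)')
plt.show()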
if rep_type == 'latents':
    conv_ids = {
        'gridworld:gridworld-v1': 'c34544ac-45ed-492c-b2eb-4431b403a3a8',
        'gridworld:gridworld-v3': '32301262-cd74-4116-b776-57354831c484',
        'gridworld:gridworld-v4': 'b50926a2-0186-4bb9-81ec-77063cac6861',
        'gridworld:gridworld-v5': '15b5e27b-444f-4fc8-bf25-5b7807df4c7f'
    }
    run_id = conv_ids[f'gridworld:gridworld-v{version}']
    agent_path = relative_path_to_data + f'agents/saved_agents/{run_id}.pt'
    state_reps, representation_name, input_dims, _ = latents(env, agent_path)
else:
    state_reps, representation_name, input_dims, _ = rep_types[rep_type](env)

AC_head_agent = head_AC(input_dims, env.action_space.n, lr=learning_rate)
memory = Memory(entry_size=env.action_space.n, cache_limit=cache_size_for_env, distance=distance_metric)
agent = Agent(AC_head_agent, memory=memory, state_representations=state_reps)

ex = flat_expt(agent, env)
print(f"Experiment running {env.unwrapped.spec.id} \nRepresentation: {representation_name} \nCache Limit: {cache_size_for_env} \nDistance: {distance_metric}")
ex.run(num_trials, num_events, snapshot_logging=False)
ex.record_log(env_name=test_env_name, representation_type=representation_name,
              n_trials=num_trials, n_steps=num_events,
              dir=relative_path_to_data, file=write_to_file)
for _ in range(1):
    ## generate the environment object
    env = gym.make(env_name)
    plt.close()

    ## get state representations to be used
    state_reps, representation_name, input_dims, _ = rep_types[representation_type](env)

    ## create an actor-critic network and associated agent
    network = Network(input_dims=[input_dims], fc1_dims=200, fc2_dims=200,
                      output_dims=env.action_space.n, lr=0.0005)
    memory = Memory(entry_size=env.action_space.n, cache_limit=400, mem_temp=1)
    agent = Agent(network, state_representations=state_reps, memory=memory)

    # create an experiment class instance
    ex = expt(agent, env)
    ex.run(num_trials, num_events)
    ex.record_log(env_name=env_name, representation_type=representation_name,
                  n_trials=num_trials, n_steps=num_events,
                  dir='../../Data/', file=write_to_file)

'''
# print results of training
from modules.Utils import one_hot_state

# get environment
env_name = 'gridworld:gridworld-v4'
env = gym.make(env_name)
plt.close()

# make collection of one-hot state representations
oh_state_reps = {}
for state in env.useable:
    oh_state_reps[env.twoD2oneD(state)] = one_hot_state(env, env.twoD2oneD(state))

input_dims = len(oh_state_reps[list(oh_state_reps.keys())[0]])

network = Network(input_dims=[input_dims], fc1_dims=200, fc2_dims=200,
                  output_dims=env.action_space.n, lr=0.0005)
memory = Memory(entry_size=env.action_space.n, cache_limit=env.nstates)
agent = Agent(network, memory, state_representations=oh_state_reps)

ex = expt(agent, env)
num_trials = 2000
num_events = 250
ex.run(num_trials, num_events)

# print results of training
fig, ax = plt.subplots(2, 1, sharex=True)
ax[0].plot(ex.data['total_reward'])
ax[1].plot(ex.data['loss'][0], label='P_loss')
ax[1].plot(ex.data['loss'][1], label='V_loss')
ax[1].legend()
plt.show()
ref_key = list(e_cache.keys())[0]
ref_state = e_cache[ref_key][2]
print("reference state: ", ref_state, env.oneD2twoD(ref_state))

entry = np.asarray(ref_key)
mem_cache = np.asarray(list(e_cache.keys()))
sts_ids = [e_cache[x][2] for x in list(e_cache.keys())]

# distance from the reference entry to every stored key (Chebyshev metric)
distances = cdist([entry], mem_cache, metric='chebyshev')[0]
print(min(distances), "closest state is", e_cache[tuple(mem_cache[np.argmin(distances)])][2])

dist = np.zeros(env.nstates)
dist[:] = np.nan
for ind, item in enumerate(sts_ids):
    dist[item] = distances[ind]

fig, ax = plt.subplots(1, 1)
ax.imshow(dist.reshape(env.shape))
plt.show()

memory = Memory(cache_limit=env.nstates, entry_size=env.action_space.n, distance='euclidean')
memory.cache_list = e_cache
closest_key, min_dist = memory.similarity_measure(ref_key)
print(f'ep recall closest dist = {min_dist}, state = {memory.cache_list[closest_key][2]}')
    'gridworld:gridworld-v41': {
        100: 384,
        75: 288,
        50: 192,
        25: 96
    },
    'gridworld:gridworld-v51': {
        100: 286,
        75: 214,
        50: 143,
        25: 71
    }
}

cache_size_for_env = cache_limits[test_env_name][cache_size]
print(cache_size_for_env)

memory = Memory(entry_size=test_env.action_space.n, cache_limit=cache_size_for_env)
agent = Agent(AC_head_agent, memory=memory, state_representations=state_reps)

ex = expt(agent, test_env)
ex.run(100, 250, printfreq=1)

for i in ex.trajectories:
    print(i)

fig, ax = plt.subplots(1, 1)
ax.pcolor(test_env.grid, cmap='bone_r', edgecolors='k', linewidths=0.1)
ax.set_aspect('equal')
ax.invert_yaxis()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
from modules.Agents import Agent
from modules.Experiments import flat_expt

sys.path.append('../../../')

write_to_file = 'flat_ac_training.csv'
version = 1
env_name = f'gridworld:gridworld-v{version}'
representation_type = 'latent'
num_trials = 5000
num_events = 250

# make gym environment
env = gym.make(env_name)
plt.close()

state_reps, representation_name, input_dims, _ = rep_types[representation_type](env)

for _ in range(1):
    empty_net = head_AC(input_dims, env.action_space.n, lr=0.0005)
    memory = Memory(entry_size=4, cache_limit=400)
    agent = Agent(empty_net, memory, state_representations=state_reps)

    ex = flat_expt(agent, env)
    ex.run(num_trials, num_events, snapshot_logging=False)
    ex.record_log(env_name=env_name, representation_type=representation_name,
                  n_trials=num_trials, n_steps=num_events,
                  dir='./Data/', file=write_to_file)