Example #1
def get_mem_maps():
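    # For every environment key in master_dict, rebuild the episodic-control
    # policy map at each cache size: load the saved latent representations,
    # restore the pickled cache into a Memory instance, and recall a policy
    # for every useable state.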
    ec_maps = {}
    for key in master_dict.keys():
        env = gym.make(key[:-1])
        print(env.rewards.keys(), "reward at ")
        plt.close()
        print(key, len(env.useable))
        ec_maps[key] = {}
        latents, _, __, ___ = load_saved_latents(env)

        for j, cache_size in enumerate(master_dict[key].keys()):
            print(j, cache_size)
            v_list = master_dict[key][cache_size]
            policy_map = np.zeros(env.shape,
                                  dtype=[(x, 'f8') for x in env.action_list])
            # load the pickled episodic-cache (EC) dict for this run
            with open(data_dir + f'ec_dicts/{v_list[0]}_EC.p', 'rb') as f:
                cache_list = pickle.load(f)

            mem = Memory(entry_size=env.action_space.n, cache_limit=400)
            mem.cache_list = cache_list

            for state2d in env.useable:
                state1d = env.twoD2oneD(state2d)
                state_rep = latents[state1d]

                policy_map[state2d] = tuple(mem.recall_mem(tuple(state_rep)))

            ec_maps[key][cache_size] = policy_map
    return ec_maps
Example #2
def get_xy_maps(env, state_reps, data,trial_num=-1):
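    # Restore one stored episodic cache and, for every state representation,
    # project the recalled action policy onto an (x, y) vector using the dxdy
    # action offsets; the polar-angle grid is left as NaN because the
    # conversion below is commented out.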
    blank_mem = Memory(cache_limit=400, entry_size=4)
    blank_mem.cache_list = data['ec_dicts'][trial_num]
    xy_pol_grid = np.zeros((2,*env.shape))
    polar_grid = np.zeros(env.shape)
    polar_grid[:]=np.nan
    dxdy = np.array([(0,1),(0,-1),(1,0),(-1,0)]) #D U R L

    for key, value in state_reps.items():
        twoD = env.oneD2twoD(key)
        sr_rep = value
        pol = blank_mem.recall_mem(sr_rep)
        xy = np.dot(pol,dxdy)
        xy_pol_grid[0,twoD] = xy[0]
        xy_pol_grid[1,twoD] = xy[1]
        '''rads = np.arctan(xy[1]/xy[0])
        degs = rads*(180/np.pi)
        if xy[0]>=0 and xy[1]>=0: #Q1
            theta = degs
        elif xy[0]<0: #Q2 and Q3
            theta = degs+180
        elif xy[0]>=0 and xy[1]<=0:
            theta = degs+360
        else:
            theta = -1

        polar_grid[twoD] = theta'''

    return xy_pol_grid,polar_grid
Example #3
def reconstruct_policy_map(ec_dict):
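    # Restore ec_dict into a blank Memory and fill an (H, W, 4) array with
    # the policy recalled at every state in state_reps.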
    blank_mem = Memory(cache_limit=400, entry_size=4)
    blank_mem.cache_list = ec_dict

    policy_array = np.zeros((*env.shape,4))
    policy_array[:] = np.nan

    for key, value in state_reps.items():
        twoD = env.oneD2twoD(key)
        sr_rep = value
        pol = blank_mem.recall_mem(sr_rep)
        policy_array[twoD] = pol

    return policy_array
Example #4
def sample_from_ec_pol(state_reps, ec_dict,**kwargs):
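    # Roll out a single trajectory (up to 250 steps) by sampling actions from
    # the episodic policy recalled at each state, stopping once the reward of
    # 10 is collected.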
    blank_mem = Memory(cache_limit=400, entry_size=4)
    blank_mem.cache_list = ec_dict
    start_state = kwargs.get('start_state',np.random.choice(list(state_reps.keys())))
    trajectory = []
    env.set_state(start_state)
    state = start_state
    trajectory.append(env.oneD2twoD(state))
    for i in range(250):
        policy = blank_mem.recall_mem(state_reps[state])
        action = np.random.choice(np.arange(4),p=policy)
        next_state, reward, done, info = env.step(action)
        state = next_state
        trajectory.append(env.oneD2twoD(state))
        if reward ==10.:
            break
    return trajectory
Example #5
def reconstruct_xy_map(ec_dict):
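    # Same projection as get_xy_maps above, but for a single ec_dict: map each
    # recalled policy to an (x, y) vector via the dxdy action offsets.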
    blank_mem = Memory(cache_limit=400, entry_size=4)
    blank_mem.cache_list = ec_dict

    xy_array = np.zeros((*env.shape,2))
    xy_array[:] = np.nan

    dxdy = np.array([(0,1),(0,-1),(1,0),(-1,0)]) #D U R L

    for key, value in state_reps.items():
        twoD = env.oneD2twoD(key)
        sr_rep = value
        pol = blank_mem.recall_mem(sr_rep)
        xycoords = np.dot(pol,dxdy)
        xy_array[twoD] = xycoords

    return xy_array
Example #6
def get_mem_map(gb, env_name, rep, pct, ind=0):
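    # Rebuild the policy map for one logged run: select the run id from the
    # grouped dataframe gb, load its pickled episodic cache, and recall a
    # policy for every useable state under the chosen state representation.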
    example_env = gym.make(env_name)
    plt.close()

    rep_types = {
        'onehot': onehot,
        'random': random,
        'place_cell': place_cell,
        'analytic successor': sr
    }
    if rep == 'latents':
        conv_ids = {
            'gridworld:gridworld-v1': 'c34544ac-45ed-492c-b2eb-4431b403a3a8',
            'gridworld:gridworld-v3': '32301262-cd74-4116-b776-57354831c484',
            'gridworld:gridworld-v4': 'b50926a2-0186-4bb9-81ec-77063cac6861',
            'gridworld:gridworld-v5': '15b5e27b-444f-4fc8-bf25-5b7807df4c7f'
        }
        load_id = conv_ids[f'{env_name[:-1]}']
        agent_path = parent_path + f'agents/{load_id}.pt'
        state_reps, representation_name, input_dims, _ = latents(
            example_env, agent_path)
    else:
        state_reps, representation_name, input_dims, _ = rep_types[rep](
            example_env)

    run_id = list(gb.get_group(
        (env_name, rep, cache_limits[env_name][pct])))[ind]
    policy_map = np.zeros(example_env.shape,
                          dtype=[(x, 'f8') for x in example_env.action_list])

    with open(parent_path + f'ec_dicts/{run_id}_EC.p', 'rb') as f:
        cache_list = pickle.load(f)

    mem = Memory(
        entry_size=example_env.action_space.n, cache_limit=400
    )  # cache limit doesn't matter since we are only using for recall
    mem.cache_list = cache_list

    for state2d in example_env.useable:
        state1d = example_env.twoD2oneD(state2d)
        state_rep = tuple(state_reps[state1d])
        #print(state_rep in cache_list.keys())

        policy_map[state2d] = tuple(mem.recall_mem(state_rep))

    return policy_map
Example #7
def get_avg_incidence_of_memories(data):
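    # Accumulate, across all stored ec_dicts, a running average of the policy
    # recalled at each cached state together with per-state visit counts.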
    n_visits = np.zeros(env.shape)
    ec_pol_grid = np.zeros(env.shape, dtype=[(x, 'f8') for x in env.action_list])
    for i in range(len(data['ec_dicts'])):
        print(i)
        blank_mem = Memory(cache_limit=400, entry_size=4)
        blank_mem.cache_list = data['ec_dicts'][i]
        states = []
        for k, key in enumerate(blank_mem.cache_list.keys()):
            twoD = env.oneD2twoD(blank_mem.cache_list[key][2])
            old_policy = ec_pol_grid[twoD]
            current_policy = blank_mem.recall_mem(key)

            average = []
            for x,y in zip(old_policy, current_policy):
                z = x + (y-x)/(k+1)
                average.append(z)
            ec_pol_grid[twoD] = tuple(average)

            n_visits[twoD] += 1

    # NOTE: assumed return (the original snippet ends without one)
    return ec_pol_grid, n_visits
Example #8
def get_mem_maps(data,trial_num=-1,full_mem=True):
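    # Build a policy map from one stored episodic cache: either query every
    # state in state_reps (full_mem=True) or only the states actually present
    # in the cache.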
    blank_mem = Memory(cache_limit=400, entry_size=4)
    blank_mem.cache_list = data['ec_dicts'][trial_num]

    ec_pol_grid = np.zeros(env.shape, dtype=[(x, 'f8') for x in env.action_list])
    if full_mem:
        for key, value in state_reps.items():
            twoD = env.oneD2twoD(key)
            sr_rep = value
            pol = blank_mem.recall_mem(sr_rep)

            ec_pol_grid[twoD] = tuple(pol)
    else:
        for ec_key in blank_mem.cache_list.keys():
            twoD = env.oneD2twoD(blank_mem.cache_list[ec_key][2])
            pol  = blank_mem.recall_mem(ec_key)

            ec_pol_grid[twoD] = tuple(pol)

    return ec_pol_grid
Example #9
def reconstruct_mem(env, ec_dicts, state_reps, index):
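    # For every useable state, find the nearest cached entry with
    # similarity_measure, average the values stored in that entry, and recall
    # the corresponding policy; returns per-state value and policy grids.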
    ec_dict = ec_dicts[index]
    blank_mem = Memory(cache_limit=400, entry_size=4)
    blank_mem.cache_list = ec_dict
    ec_pols = np.zeros((20, 20), dtype=[(x,'f8') for x in env.action_list])
    #ec_pols[:] = np.nan
    ec_vals = np.zeros((20,20))
    ec_vals[:] = np.nan
    for ind in env.useable: #for dict_key in list(ec_dict.keys()):

        index_1d = env.twoD2oneD(ind) #ec_dict[dict_key][-1]
        index_2d = ind#env.oneD2twoD(index_1d)

        dict_key = state_reps[index_1d]
        nearest_act, d = blank_mem.similarity_measure(dict_key)
        val = np.nanmean(ec_dict[nearest_act][0][:,0])
        pol = blank_mem.recall_mem(dict_key)
        ec_vals[index_2d] = val
        ec_pols[index_2d] = tuple(pol)

    return ec_vals, ec_pols
Example #10
def get_ec_policy_map(env, state_reps, data, trial_num, full=True):
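    # Same idea as get_mem_maps above: recall a policy map either over all of
    # state_reps (full=True) or only over the states held in the cache.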
    ec_pol_grid = np.zeros(env.shape, dtype=[(x, 'f8') for x in env.action_list])

    blank_mem = Memory(cache_limit=400, entry_size=4)
    blank_mem.cache_list = data['ec_dicts'][trial_num]

    if full:
        for k in state_reps.keys():
            twoD = env.oneD2twoD(k)
            sr_rep = state_reps[k]
            pol = blank_mem.recall_mem(sr_rep)

            ec_pol_grid[twoD] = tuple(pol)
    else:
        for k in blank_mem.cache_list.keys():
            oneDstate = blank_mem.cache_list[k][2]
            twoD = env.oneD2twoD(oneDstate)
            sr_rep = state_reps[oneDstate]
            pol = blank_mem.recall_mem(sr_rep)

            ec_pol_grid[twoD] = tuple(pol)

    return ec_pol_grid
Example #11
def get_KLD(data,probe_state,trial_num):
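    # Recall the policy at probe_state, then for every state in the cache
    # compute the KL divergence between the probe policy and the recalled
    # policy, along with the recalled policy's entropy.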
    probe_rep = state_reps[probe_state]

    KLD_array = np.zeros(env.shape)
    KLD_array[:] = np.nan
    entropy_array = np.zeros(env.shape)
    entropy_array[:] = np.nan
    ec_pol_grid = np.zeros((*env.shape, 4))  # np.zeros(env.shape, dtype=[(x, 'f8') for x in env.action_list])

    blank_mem = Memory(cache_limit=400, entry_size=4)
    blank_mem.cache_list = data['ec_dicts'][trial_num]
    probe_pol = blank_mem.recall_mem(probe_rep)

    #for k in state_reps.keys():
        #sr_rep = state_reps[k]
    for sr_rep in blank_mem.cache_list.keys():
        k = blank_mem.cache_list[sr_rep][2]
        pol = blank_mem.recall_mem(sr_rep)
        twoD = env.oneD2twoD(k)
        KLD_array[twoD] = sum(rel_entr(list(probe_pol),list(pol)))
        ec_pol_grid[twoD][:] = pol
        entropy_array[twoD] = entropy(pol,base=2)

    return KLD_array,ec_pol_grid,entropy_array
Example #12
    conv_ids = {
        'gridworld:gridworld-v1': 'c34544ac-45ed-492c-b2eb-4431b403a3a8',
        'gridworld:gridworld-v3': '32301262-cd74-4116-b776-57354831c484',
        'gridworld:gridworld-v4': 'b50926a2-0186-4bb9-81ec-77063cac6861',
        'gridworld:gridworld-v5': '15b5e27b-444f-4fc8-bf25-5b7807df4c7f'
    }
    run_id = conv_ids[f'gridworld:gridworld-v{version}']
    agent_path = relative_path_to_data + f'agents/saved_agents/{run_id}.pt'
    state_reps, representation_name, input_dims, _ = latents(env, agent_path)
else:
    state_reps, representation_name, input_dims, _ = rep_types[rep_type](env)

AC_head_agent = head_AC(input_dims, env.action_space.n, lr=learning_rate)

memory = Memory(entry_size=env.action_space.n,
                cache_limit=cache_size_for_env,
                distance=distance_metric)

agent = Agent(AC_head_agent, memory=memory, state_representations=state_reps)

ex = flat_expt(agent, env)
print(
    f"Experiment running {env.unwrapped.spec.id} \nRepresentation: {representation_name} \nCache Limit:{cache_size_for_env} \nDistance: {distance_metric}"
)
ex.run(num_trials, num_events, snapshot_logging=False)
ex.record_log(env_name=test_env_name,
              representation_type=representation_name,
              n_trials=num_trials,
              n_steps=num_events,
              dir=relative_path_to_data,
              file=write_to_file)
Example #13
for _ in range(1):
    ## generate the environment object
    env = gym.make(env_name)
    plt.close()

    ## get state representations to be used
    state_reps, representation_name, input_dims, _ = rep_types[
        representation_type](env)

    ## create an actor-critic network and associated agent
    network = Network(input_dims=[input_dims],
                      fc1_dims=200,
                      fc2_dims=200,
                      output_dims=env.action_space.n,
                      lr=0.0005)
    memory = Memory(entry_size=env.action_space.n, cache_limit=400, mem_temp=1)
    agent = Agent(network, state_representations=state_reps, memory=memory)

    # create an experiment class instance
    ex = expt(agent, env)

    ex.run(num_trials, num_events)

    ex.record_log(env_name=env_name,
                  representation_type=representation_name,
                  n_trials=num_trials,
                  n_steps=num_events,
                  dir='../../Data/',
                  file=write_to_file)
Example #14
from modules.Utils import one_hot_state

# get environment
env_name = 'gridworld:gridworld-v4'
env = gym.make(env_name)
plt.close()

# make collection of one-hot state representations
oh_state_reps = {}
for state in env.useable:
    oh_state_reps[env.twoD2oneD(state)] = one_hot_state(env,env.twoD2oneD(state))

input_dims = len(oh_state_reps[list(oh_state_reps.keys())[0]])

network = Network(input_dims=[input_dims],fc1_dims=200,fc2_dims=200,output_dims=env.action_space.n, lr=0.0005)
memory = Memory(entry_size=env.action_space.n, cache_limit=env.nstates)

agent = Agent(network, memory, state_representations=oh_state_reps)

ex = expt(agent,env)

num_trials = 2000
num_events = 250

ex.run(num_trials, num_events)

# print results of training
fig, ax = plt.subplots(2,1, sharex=True)
ax[0].plot(ex.data['total_reward'])
ax[1].plot(ex.data['loss'][0], label='P_loss')
ax[1].plot(ex.data['loss'][1], label='V_loss')
Example #15
ref_key = list(e_cache.keys())[0]
ref_state = e_cache[ref_key][2]
print("reference state: ", ref_state, env.oneD2twoD(ref_state))

entry = np.asarray(ref_key)
mem_cache = np.asarray(list(e_cache.keys()))
sts_ids = [e_cache[x][2] for x in list(e_cache.keys())]

distance_cheb = cdist([entry], mem_cache, metric='chebyshev')[0]
print(min(distance_cheb), "closest state is",
      e_cache[tuple(mem_cache[np.argmin(distance_cheb)])][2])

dist = np.zeros(env.nstates)
dist[:] = np.nan
for ind, item in enumerate(sts_ids):
    dist[item] = distance_cheb[ind]

fig, ax = plt.subplots(1, 1)
ax.imshow(dist.reshape(env.shape))
plt.show()

memory = Memory(cache_limit=env.nstates,
                entry_size=env.action_space.n,
                distance='euclidean')
memory.cache_list = e_cache

closest_key, min_dist = memory.similarity_measure(ref_key)
print(
    f'ep recall closest dist = {min_dist}, state = {memory.cache_list[closest_key][2]}'
)
Example #16
    'gridworld:gridworld-v41': {
        100: 384,
        75: 288,
        50: 192,
        25: 96
    },
    'gridworld:gridworld-v51': {
        100: 286,
        75: 214,
        50: 143,
        25: 71
    }
}
cache_size_for_env = cache_limits[test_env_name][cache_size]
print(cache_size_for_env)
memory = Memory(entry_size=test_env.action_space.n,
                cache_limit=cache_size_for_env)

agent = Agent(AC_head_agent, memory=memory, state_representations=state_reps)

ex = expt(agent, test_env)
ex.run(100, 250, printfreq=1)
for i in ex.trajectories:
    print(i)

fig, ax = plt.subplots(1, 1)
ax.pcolor(test_env.grid, cmap='bone_r', edgecolors='k', linewidths=0.1)
ax.set_aspect('equal')
ax.invert_yaxis()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
Example #17
from modules.Agents import Agent
from modules.Experiments import flat_expt
sys.path.append('../../../')

write_to_file = 'flat_ac_training.csv'
version = 1
env_name = f'gridworld:gridworld-v{version}'
representation_type = 'latent'
num_trials = 5000
num_events = 250

# make gym environment
env = gym.make(env_name)
plt.close()

state_reps, representation_name, input_dims, _ = rep_types[
    representation_type](env)

for _ in range(1):
    empty_net = head_AC(input_dims, env.action_space.n, lr=0.0005)
    memory = Memory(entry_size=4, cache_limit=400)
    agent = Agent(empty_net, memory, state_representations=state_reps)

    ex = flat_expt(agent, env)
    ex.run(num_trials, num_events, snapshot_logging=False)
    ex.record_log(env_name=env_name,
                  representation_type=representation_name,
                  n_trials=num_trials,
                  n_steps=num_events,
                  dir='./Data/',
                  file=write_to_file)