def run_watermaze(num_agents, params, lesion_hpc=False, **kwargs):
    """Simulate spatial agents in a HexWaterMaze(6) and save results as CSV.

    For each agent, a platform sequence is drawn with ``determine_platform_seq``.
    The agent then runs 11 sessions of 4 trials each, and the platform is moved
    on the first trial of every session. The concatenated per-episode results
    are written to ``res_dir``. Agents whose output file already exists are
    skipped.

    Args:
        num_agents: number of agents to simulate (agent indices 0..num_agents-1).
        params: mapping holding per-agent values under the keys 'A_alpha',
            'alpha1', 'A_beta' and 'beta1', each indexed by agent number.
        lesion_hpc: forwarded to ``SpatialAgent``. When true, results are saved
            as ``spatial_agent{i}_lesion``.
        **kwargs: optional ``inact_hpc``, a per-agent sequence. Its i-th entry
            is forwarded to ``SpatialAgent`` as ``inact_hpc``, and (unless
            ``lesion_hpc`` is set) results are saved as
            ``spatial_partial_lesion_agent{i}``. When it is omitted (or None),
            ``SpatialAgent``'s own default is used.
    """
    inact_hpc = kwargs.get('inact_hpc')
    g = HexWaterMaze(6)
    for i_a in tqdm(range(num_agents), desc='Agent'):
        if lesion_hpc:
            filename = 'spatial_agent{}_lesion'.format(i_a)
        elif inact_hpc is not None:
            filename = 'spatial_partial_lesion_agent{}'.format(i_a)
        else:
            filename = 'spatial_agent{}'.format(i_a)

        out_path = os.path.join(res_dir, filename)
        if os.path.exists(out_path):
            tqdm.write('Already done')
            continue

        # Candidate platform states for HexWaterMaze(6).
        # The r = 10 maze used [192, 185, 181, 174, 216, 210, 203, 197] instead.
        possible_platform_states = np.array([48, 45, 42, 39, 60, 57, 54, 51])
        platform_sequence = determine_platform_seq(possible_platform_states, g)

        # Forward inact_hpc only when the caller supplied it. Indexing
        # kwargs['inact_hpc'] unconditionally raised KeyError for plain
        # and lesion runs.
        agent_kwargs = {}
        if inact_hpc is not None:
            agent_kwargs['inact_hpc'] = inact_hpc[i_a]

        # initialise agent
        agent = SpatialAgent(g, init_sr='rw',
                             A_alpha=params['A_alpha'][i_a],
                             alpha1=params['alpha1'][i_a],
                             A_beta=params['A_beta'][i_a],
                             beta1=params['beta1'][i_a],
                             lesion_hpc=lesion_hpc,
                             **agent_kwargs)

        agent_results = []
        total_trial_count = 0
        for ses in tqdm(range(11), desc='Session', leave=False):
            for trial in tqdm(range(4), leave=False, desc='Trial'):
                # every first trial of a session, change the platform location
                if trial == 0:
                    g.set_platform_state(platform_sequence[ses])
                res = agent.one_episode(random_policy=False)
                res['trial'] = trial
                res['escape time'] = res.time.max()
                res['session'] = ses
                res['total trial'] = total_trial_count
                agent_results.append(res)
                total_trial_count += 1

        agent_df = pd.concat(agent_results)
        agent_df['total time'] = np.arange(len(agent_df))
        agent_df.to_csv(out_path)
import numpy as np
import matplotlib.pyplot as plt
from hippocampus.agents import LandmarkCells
from hippocampus.utils import angle_to_landmark
from hippocampus.environments import HexWaterMaze

# Plot the response of every landmark cell over a sweep of xs in [-180, 180]
# (presumably angles in degrees -- confirm against LandmarkCells.compute_response).
g = HexWaterMaze(6)
location = g.grid.cart_coords[g.platform_state]
my_location = (0, 0)
my_orientation = 60
# NOTE(review): `angle` is computed but not used by the plot below.
angle = angle_to_landmark(my_location, location, my_orientation)

LC = LandmarkCells()
xs = np.linspace(-180, 180, 1000)
responses = np.zeros((len(xs), LC.n_cells))
for i, x in enumerate(xs):
    # Evaluate the cells at the sampled value x, not at the row index i,
    # so that row i of `responses` corresponds to xs[i] on the plot's x-axis.
    responses[i, :] = LC.compute_response(x)

# One curve per cell (columns of `responses`).
for col in responses.T:
    plt.plot(xs, col)
plt.show()
'Control - trial 4']) if not op.exists(figure_location): os.makedirs(figure_location) plt.savefig(os.path.join(figure_location, 'pearce_escapetime_firstlast'), format='pdf') plt.show() plt.close() if __name__ == '__main__': from hippocampus.environments import HexWaterMaze # create_summary_file(data_directories['lesion']) plot_escape_time() maze = HexWaterMaze(10) # pick example agent for group in ['control', 'lesion']: agents = {'control': 89, 'lesion': 12} session = 6 df = pd.read_csv(op.join(data_directories[group], 'agent{}.csv'.format(agents[group]))) s6t0 = df[np.logical_and(df.trial == 0, df.session == session)] previous_platform = df[df.session == session - 1].platform.iloc[0] s6t0['previous platform'] = previous_platform current_platform = s6t0.platform.iloc[0] maze.plot_occupancy_on_grid(s6t0, alpha=1., show_state_idx=False)
figure_folder = os.path.join(results_folder, 'figures') if not os.path.exists(results_folder): os.makedirs(results_folder) os.makedirs(figure_folder) params = pd.DataFrame({ 'n_agents': [n_agents], 'inv_temp': [inv_temp], 'gamma': [gamma], 'lesion HPC': [lesion_hippocampus], 'lesion DLS': [lesion_striatum] }) params.to_csv(os.path.join(results_folder, 'params.csv')) # initialise environment g = HexWaterMaze(10) # determine platform sequence possible_platform_states = np.array([192, 185, 181, 174, 216, 210, 203, 197]) # for the r = 10 case def determine_platform_seq(platform_states): indices = np.arange(len(platform_states)) usage = np.zeros(len(platform_states)) plat_seq = [np.random.choice(platform_states)] for sess in range(1, 11): distances = np.array( [g.grid.distance(plat_seq[sess - 1], s) for s in platform_states]) candidates = indices[np.logical_and(usage < 2,
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from hippocampus.agents import LandmarkLearningAgent, CombinedAgent
from hippocampus.environments import HexWaterMaze

if __name__ == '__main__':
    from tqdm import tqdm
    from hippocampus.plotting import tsplot_boot

    # Learning curves for LandmarkLearningAgent with a fixed platform:
    # 5 agents, 60 episodes each, in HexWaterMaze(5) with the platform at state 30.
    maze = HexWaterMaze(5)
    maze.set_platform_state(30)

    escape_times_per_agent = []
    for _agent_idx in tqdm(range(5)):
        learner = LandmarkLearningAgent(maze)
        episode_frames = []  # per-episode result frames (kept for inspection)
        escape_times = []
        for episode in range(60):
            episode_df = learner.one_episode()
            escape = episode_df.time.max()
            episode_df['trial'] = episode
            episode_df['escape time'] = escape
            episode_frames.append(episode_df)
            escape_times.append(escape)
        escape_times_per_agent.append(escape_times)

    # Bootstrapped mean escape time per episode across agents.
    fig, ax = plt.subplots()
    tsplot_boot(ax, np.array(escape_times_per_agent))
    plt.show()