def run_watermaze(num_agents, params, lesion_hpc=False, **kwargs):
    """Simulate ``num_agents`` SpatialAgents in a HexWaterMaze(6) and save each
    agent's trial-by-trial results as a CSV file in ``res_dir``.

    Each agent runs 11 sessions of 4 trials. At the first trial of every
    session the platform is moved to the next entry of a sequence produced by
    ``determine_platform_seq``. Agents whose output file already exists are
    skipped, so an interrupted run can be resumed.

    Args:
        num_agents: number of agents to simulate (agent indices 0..num_agents-1).
        params: mapping whose 'A_alpha', 'alpha1', 'A_beta' and 'beta1' entries
            are indexable by agent index; one value per agent is passed to
            SpatialAgent.
        lesion_hpc: forwarded to SpatialAgent; also selects the output filename.
        **kwargs: optional 'inact_hpc', indexable by agent index. When present,
            the agent's entry is forwarded to SpatialAgent as ``inact_hpc`` and
            the "partial lesion" filename is used. When absent, SpatialAgent's
            own default for ``inact_hpc`` applies.
    """
    g = HexWaterMaze(6)

    for i_a in tqdm(range(num_agents), desc='Agent'):

        if lesion_hpc:
            filename = 'spatial_agent{}_lesion'.format(i_a)
        elif 'inact_hpc' in kwargs:
            filename = 'spatial_partial_lesion_agent{}'.format(i_a)
        else:
            filename = 'spatial_agent{}'.format(i_a)

        # Resume support: skip agents that already have saved results.
        if os.path.exists(os.path.join(res_dir, filename)):
            tqdm.write('Already done')
            continue

        # Candidate platform states for this maze size.
        # (For the r = 10 maze the states are 192, 185, 181, 174, 216, 210, 203, 197.)
        possible_platform_states = np.array([48, 45, 42, 39, 60, 57, 54, 51])

        platform_sequence = determine_platform_seq(possible_platform_states, g)

        # Only forward inact_hpc when the caller supplied it; reading
        # kwargs['inact_hpc'] unconditionally raised KeyError for plain runs.
        agent_kwargs = {}
        if 'inact_hpc' in kwargs:
            agent_kwargs['inact_hpc'] = kwargs['inact_hpc'][i_a]

        # initialise agent
        agent = SpatialAgent(g,
                             init_sr='rw',
                             A_alpha=params['A_alpha'][i_a],
                             alpha1=params['alpha1'][i_a],
                             A_beta=params['A_beta'][i_a],
                             beta1=params['beta1'][i_a],
                             lesion_hpc=lesion_hpc,
                             **agent_kwargs)
        agent_results = []

        total_trial_count = 0

        for ses in tqdm(range(11), desc='Session', leave=False):
            for trial in tqdm(range(4), leave=False, desc='Trial'):
                # every first trial of a session, change the platform location
                if trial == 0:
                    g.set_platform_state(platform_sequence[ses])

                res = agent.one_episode(random_policy=False)
                res['trial'] = trial
                res['escape time'] = res.time.max()
                res['session'] = ses
                res['total trial'] = total_trial_count
                agent_results.append(res)
                total_trial_count += 1

        agent_df = pd.concat(agent_results)
        # Running time-step counter across all of this agent's episodes.
        agent_df['total time'] = np.arange(len(agent_df))

        agent_df.to_csv(os.path.join(res_dir, filename))
Beispiel #2
0
import numpy as np
import matplotlib.pyplot as plt

from hippocampus.agents import LandmarkCells
from hippocampus.utils import angle_to_landmark
from hippocampus.environments import HexWaterMaze

g = HexWaterMaze(6)
location = g.grid.cart_coords[g.platform_state]
my_location = (0, 0)
my_orientation = 60
# Angle from my_location (facing my_orientation) to the platform location;
# computed for inspection, not used in the plot below.
angle = angle_to_landmark(my_location, location, my_orientation)
LC = LandmarkCells()
# Sweep of input values for the tuning curves (presumably angles in degrees).
xs = np.linspace(-180, 180, 1000)
responses = np.zeros((len(xs), LC.n_cells))

for i, x in enumerate(xs):
    # Evaluate the cells at the swept value x (not the row index i), so that
    # row i of `responses` corresponds to the x-axis value xs[i] plotted below.
    responses[i, :] = LC.compute_response(x)

# One tuning curve per landmark cell.
for col in responses.T:
    plt.plot(xs, col)
plt.show()
Beispiel #3
0
figure_folder = os.path.join(results_folder, 'figures')
# Creates results_folder and its 'figures' subfolder in one call. exist_ok also
# covers the case where results_folder already exists but 'figures' does not,
# which the previous "if not exists(results_folder)" guard left uncreated.
os.makedirs(figure_folder, exist_ok=True)

# Record the run configuration next to the results.
# NOTE(review): n_agents, inv_temp, gamma, lesion_hippocampus and
# lesion_striatum are expected to be defined earlier in the script.
params = pd.DataFrame({
    'n_agents': [n_agents],
    'inv_temp': [inv_temp],
    'gamma': [gamma],
    'lesion HPC': [lesion_hippocampus],
    'lesion DLS': [lesion_striatum]
})
params.to_csv(os.path.join(results_folder, 'params.csv'))

# initialise environment
g = HexWaterMaze(10)

# determine platform sequence
possible_platform_states = np.array([192, 185, 181, 174, 216, 210, 203,
                                     197])  # for the r = 10 case


def determine_platform_seq(platform_states):
    indices = np.arange(len(platform_states))
    usage = np.zeros(len(platform_states))

    plat_seq = [np.random.choice(platform_states)]
    for sess in range(1, 11):
        distances = np.array(
            [g.grid.distance(plat_seq[sess - 1], s) for s in platform_states])
        candidates = indices[np.logical_and(usage < 2,
Beispiel #4
0
                'Control - trial 4'])

    if not op.exists(figure_location):
        os.makedirs(figure_location)
    plt.savefig(os.path.join(figure_location, 'pearce_escapetime_firstlast'), format='pdf')
    plt.show()
    plt.close()


if __name__ == '__main__':
    from hippocampus.environments import HexWaterMaze

    # create_summary_file(data_directories['lesion'])
    plot_escape_time()

    maze = HexWaterMaze(10)

    # Example agent (by index) to display for each group, and the session
    # whose first trial is plotted.
    agents = {'control': 89, 'lesion': 12}
    session = 6

    for group in ['control', 'lesion']:
        df = pd.read_csv(op.join(data_directories[group],
                                 'agent{}.csv'.format(agents[group])))

        # Rows of trial 0 of the chosen session. .copy() makes this an
        # independent frame, so adding a column below modifies it directly
        # instead of a possible view of df (pandas SettingWithCopyWarning).
        s6t0 = df[(df.trial == 0) & (df.session == session)].copy()
        # Platform used in the preceding session (first row of that session).
        previous_platform = df[df.session == session - 1].platform.iloc[0]
        s6t0['previous platform'] = previous_platform
        maze.plot_occupancy_on_grid(s6t0, alpha=1., show_state_idx=False)
# Run settings.
inv_temp = 5.0
gamma = 0.99
lesion_hippocampus = True
lesion_striatum = False

# Single-row table of the run configuration, saved alongside the results.
# NOTE(review): n_agents and results_folder must be defined earlier in the script.
params = pd.DataFrame(
    {
        'n_agents': [n_agents],
        'inv_temp': [inv_temp],
        'gamma': [gamma],
        'lesion HPC': [lesion_hippocampus],
        'lesion DLS': [lesion_striatum],
    }
)
params.to_csv(os.path.join(results_folder, 'params.csv'))

# Environment: hexagonal water maze with r = 10.
g = HexWaterMaze(10)

# Candidate platform states (r = 10 maze) used to build the platform sequence.
possible_platform_states = np.array(
    [192, 185, 181, 174, 216, 210, 203, 197]
)


def determine_platform_seq(platform_states):
    indices = np.arange(len(platform_states))
    usage = np.zeros(len(platform_states))

    plat_seq = [np.random.choice(platform_states)]
    for sess in range(1, 11):
        distances = np.array(
            [g.grid.distance(plat_seq[sess - 1], s) for s in platform_states])
        candidates = indices[np.logical_and(usage < 2,
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from hippocampus.agents import LandmarkLearningAgent, CombinedAgent
from hippocampus.environments import HexWaterMaze

if __name__ == '__main__':
    from tqdm import tqdm
    from hippocampus.plotting import tsplot_boot

    n_agents = 5
    n_episodes = 60

    g = HexWaterMaze(5)
    g.set_platform_state(30)

    # all_ets[a][e] is the escape time of agent a on episode e, taken as the
    # largest 'time' value in that episode's result table.
    all_ets = []
    for _ in tqdm(range(n_agents)):
        agent = LandmarkLearningAgent(g)
        # Only the escape times are used. The per-episode tables (and the
        # columns previously added to them) were collected but never read.
        all_ets.append(
            [agent.one_episode().time.max() for _ in range(n_episodes)])

    fig, ax = plt.subplots()
    tsplot_boot(ax, np.array(all_ets))
    plt.show()