def run_watermaze(num_agents, params, lesion_hpc=False, **kwargs):
    """Simulate ``num_agents`` SpatialAgents in a hex water maze and save each to CSV.

    Every agent runs 11 sessions of 4 trials. On the first trial of each
    session the platform is moved to the next entry of a per-agent platform
    sequence produced by ``determine_platform_seq``. Each agent's concatenated
    episode results are written with ``DataFrame.to_csv`` to
    ``os.path.join(res_dir, filename)``. Agents whose output file already
    exists are skipped, so an interrupted run can be resumed.

    Args:
        num_agents: number of agents to simulate (indices ``0..num_agents-1``).
        params: mapping holding per-agent sequences under the keys
            ``'A_alpha'``, ``'alpha1'``, ``'A_beta'`` and ``'beta1'``;
            element ``i`` is used for agent ``i``.
        lesion_hpc: forwarded to ``SpatialAgent``; when true the output file
            is named ``spatial_agent{i}_lesion``.
        **kwargs: optional ``inact_hpc`` -- a per-agent sequence whose element
            ``i`` is forwarded to ``SpatialAgent(inact_hpc=...)``. When given
            (and ``lesion_hpc`` is false) the output file is named
            ``spatial_partial_lesion_agent{i}``.
    """
    g = HexWaterMaze(6)
    has_inact = 'inact_hpc' in kwargs

    for i_a in tqdm(range(num_agents), desc='Agent'):

        if lesion_hpc:
            filename = 'spatial_agent{}_lesion'.format(i_a)
        elif has_inact:
            filename = 'spatial_partial_lesion_agent{}'.format(i_a)
        else:
            filename = 'spatial_agent{}'.format(i_a)

        out_path = os.path.join(res_dir, filename)
        if os.path.exists(out_path):
            tqdm.write('Already done')
            continue

        # Candidate platform states for this maze size.
        # (The r = 10 maze used [192, 185, 181, 174, 216, 210, 203, 197].)
        # Rebuilt per agent in case determine_platform_seq modifies it in place.
        possible_platform_states = np.array([48, 45, 42, 39, 60, 57, 54, 51])

        platform_sequence = determine_platform_seq(possible_platform_states, g)

        # Initialise the agent with this agent's parameter draw.
        agent_kwargs = dict(
            init_sr='rw',
            A_alpha=params['A_alpha'][i_a],
            alpha1=params['alpha1'][i_a],
            A_beta=params['A_beta'][i_a],
            beta1=params['beta1'][i_a],
            lesion_hpc=lesion_hpc,
        )
        if has_inact:
            # Only forward inact_hpc when the caller supplied it; previously
            # this was read unconditionally and raised KeyError otherwise.
            # NOTE(review): assumes SpatialAgent has a default for inact_hpc -- confirm.
            agent_kwargs['inact_hpc'] = kwargs['inact_hpc'][i_a]
        agent = SpatialAgent(g, **agent_kwargs)

        agent_results = []
        total_trial_count = 0

        for ses in tqdm(range(11), desc='Session', leave=False):
            for trial in tqdm(range(4), leave=False, desc='Trial'):
                # Every first trial of a session, move the platform.
                if trial == 0:
                    g.set_platform_state(platform_sequence[ses])

                res = agent.one_episode(random_policy=False)
                res['trial'] = trial
                res['escape time'] = res.time.max()
                res['session'] = ses
                res['total trial'] = total_trial_count
                agent_results.append(res)
                total_trial_count += 1

        agent_df = pd.concat(agent_results)
        # Running row index across all episodes of this agent.
        agent_df['total time'] = np.arange(len(agent_df))

        agent_df.to_csv(out_path)
Beispiel #2
0
                          lesion_dls=lesion_striatum,
                          lesion_hpc=lesion_hippocampus,
                          inv_temp=inv_temp,
                          gamma=gamma,
                          learning_rate=learning_rate)
    agent_results = []
    agent_ets = []
    session = 0

    total_trial_count = 0

    for ses in tqdm(range(11)):
        for trial in tqdm(range(4), leave=False):
            # every first trial of a session, change the platform location
            if trial == 0:
                g.set_platform_state(platform_sequence[ses])

            res = agent.one_episode(random_policy=False)
            res['trial'] = trial
            res['escape time'] = res.time.max()
            res['session'] = ses
            res['total trial'] = total_trial_count
            agent_results.append(res)
            agent_ets.append(res.time.max())

            total_trial_count += 1

    agent_df = pd.concat(agent_results)
    agent_df['total time'] = np.arange(len(agent_df))
    agent_df['agent'] = n_agent
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

from hippocampus.agents import LandmarkLearningAgent, CombinedAgent
from hippocampus.environments import HexWaterMaze

if __name__ == '__main__':
    from tqdm import tqdm
    from hippocampus.plotting import tsplot_boot

    n_agents = 5
    n_episodes = 60

    g = HexWaterMaze(5)
    g.set_platform_state(30)

    # Escape times, one row per agent: nested list of shape (n_agents, n_episodes).
    all_ets = []
    for _ in tqdm(range(n_agents)):
        agent = LandmarkLearningAgent(g)
        # The escape time of an episode is the maximum of its ``time`` column.
        agent_ets = [agent.one_episode().time.max() for _ in range(n_episodes)]
        all_ets.append(agent_ets)

    fig, ax = plt.subplots()
    tsplot_boot(ax, np.array(all_ets))
    plt.show()