def update_env(village, new_paths=None, del_paths=None, new_rwds=None):
    """
    Apply path and reward changes to the village, build a file prefix.

    Returns a tuple (r_orig, f_pref): a copy of village.R taken after the
    path changes but before any reward change, and a prefix string naming
    the kind of alteration (used for saving results).
    """
    add_any = new_paths is not None
    del_any = del_paths is not None
    rwd_any = new_rwds is not None

    # Shortcuts: connect each listed pair of states.
    if add_any:
        for src, dst in new_paths:
            village.add_path(src, dst)

    # Obstacles: disconnect each listed pair of states.
    if del_any:
        for src, dst in del_paths:
            village.remove_path(src, dst)

    # Keep the rewards as they were before revaluation.
    r_orig = utils.get_copy(village.R)
    if rwd_any:
        for loc, rwd in new_rwds.items():
            village.change_reward(loc, rwd)

    # File prefix: environment part, then reward part.
    env_tags = {(False, False): 'learned_env',
                (True, False): 'shortcuts',
                (False, True): 'obstacles',
                (True, True): 'shortcuts_obstacles'}
    rwd_tag = 'revalued' if rwd_any else 'learned_rwd'
    f_pref = env_tags[(add_any, del_any)] + '__' + rwd_tag

    return r_orig, f_pref
def get_rec_stats(village, u, s, r, stim, GS, HC, hc_ro, vS, dlS, data,
                  co_occ_pars, idx=None):
    """
    Record trial data and model state after a navigation step.

    Only the statistics whose key is already present in `data` are
    collected. Once all of them are gathered they are written into `data`:
    under data[key][idx] when `idx` is given, otherwise directly as
    data[key]. Nothing is returned.
    """
    snapshot = {}

    # Externally observable variables: action, state, reward, x-y position.
    if 'anim_data' in data:
        pos_x, pos_y = village.animal_coords()
        snapshot['anim_data'] = dict(u=u, s=s, r=r, x=pos_x, y=pos_y)

    # Stimuli, stored as received (no copy is made).
    if 'stim_data' in data:
        snapshot['stim_data'] = stim

    # Instantaneous GS activity.
    if 'gs_state' in data:
        snapshot['gs_state'] = GS.P.copy()

    # Instantaneous HC state estimate (full).
    if 'hc_state' in data:
        snapshot['hc_state'] = utils.get_copy(HC.s)

    # Instantaneous HC state estimate (read-out only).
    if 'hc_ro' in data:
        snapshot['hc_ro'] = hc_ro.copy()

    # vS reward estimate.
    if 'vs_state' in data:
        snapshot['vs_state'] = vS.r.copy()

    # dlS habit matrix.
    if 'dls_state' in data:
        snapshot['dls_state'] = dlS.Q.copy()

    # Real location - HC state co-occurrence, updated with the current step.
    if 'co_occs' in data:
        loc_idx = village.S.index(s)
        co_occ_pars['idx'] = update_co_occ_mat(hc_ro, loc_idx, **co_occ_pars)
        snapshot['co_occs'] = co_occ_pars['co_occs'].copy()

    # Mean position of GS - HC connectivity.
    if 'gs_hc_pos' in data:
        snapshot['gs_hc_pos'] = analysis.GS_HC_conn_mean(
            HC.GS_HC, GS.xvec, GS.yvec, GS.circular)

    # Max of GS - HC connectivity over axes 1 and 2, used as a cheap proxy
    # of spread in place of a slow entropy calculation.
    if 'gs_hc_max' in data:
        snapshot['gs_hc_max'] = HC.GS_HC.max(axis=(1, 2))

    # Put collected results into data object.
    for key, value in snapshot.items():
        if idx is None:
            data[key] = value
        else:
            data[key][idx] = value
def store_activity(VC, GS, HC):
    """
    Snapshot the current VC, GS and HC activity.

    Returns a dict with keys 'vc_v', 'gs_p' and 'hc_s', each holding a copy
    of the corresponding model attribute (VC.v, GS.P, HC.s).
    """
    return {
        'vc_v': VC.v.copy(),
        'gs_p': GS.P.copy(),
        'hc_s': utils.get_copy(HC.s),
    }
def restore_activity(VC, GS, HC, vc_v, gs_p, hc_s):
    """
    Write stored activity back into VC, GS and HC.

    Copies of the stored values are assigned (to VC.v, GS.P and HC.s), so
    the passed-in snapshot objects are not shared with the model.
    """
    vc_copy = vc_v.copy()
    VC.v = vc_copy

    gs_copy = gs_p.copy()
    GS.P = gs_copy

    hc_copy = utils.get_copy(hc_s)
    HC.s = hc_copy
def IBP_vs_HAS(village, mod_par_kws, ibp_pars, lrn_pars, step_stats,
               trl_stats, new_paths=None, del_paths=None, new_rwds=None,
               goals=None, starts=None, lrn_conns=None, lrn_rwds=None,
               n_reset=100, max_steps=100, nav_types=None, fdir_dm=None,
               restore_vill=True):
    """
    Run a batch of trials per navigation type, then plot and save results.

    The model is created on the original (learned) environment. The
    environment is then optionally altered (adding shortcuts, adding
    obstacles, revaluing states) before the trials are run.

    Parameters
    ----------
    village : environment object, altered in place by setup.update_env and
        restored at the end when `restore_vill` is true.
    mod_par_kws : keyword arguments for setup.init_model; its 'stim_pars'
        and 'conn_pars' entries are also forwarded to the trial batch.
    ibp_pars : forwarded to run_trial_batch.
    lrn_pars, lrn_conns, lrn_rwds : forwarded to setup.init_lrn_kws.
    step_stats, trl_stats : containers of statistic names to record;
        'co_occs' in either one enables co-occurrence tracking.
    new_paths, del_paths, new_rwds : environment changes, passed to
        setup.update_env.
    goals, starts, n_reset : passed to setup.init_locs.
    max_steps : forwarded to the trial batch.
    nav_types : navigation types to run; defaults to
        ['optimal', 'HAS', 'IBP', 'random'].
    fdir_dm : string prefix of the output location; results go under
        fdir_dm + <file prefix> + '/'. None is treated as ''.
    restore_vill : if true, undo the environment changes before leaving,
        also when an error is raised after the environment was altered.

    Returns
    -------
    dict with the simulation inputs and results, as written to
    'sim_data.pickle'.
    """
    # Create model using original (learned) environment.
    mod_kws = setup.init_model(village, **mod_par_kws)

    # Update environment with new paths and rewards.
    r_orig, fpref = setup.update_env(village, new_paths, del_paths, new_rwds)

    # From here on the village is altered: make sure it gets restored even
    # if a trial, a plot or the save step fails.
    try:
        # Set up learning params.
        lrn_kws = setup.init_lrn_kws(mod_par_kws, new_paths, del_paths,
                                     new_rwds, lrn_conns, lrn_rwds, lrn_pars)

        # Params and objects for world - HC state co-occurrence calculation.
        co_pars = None
        if 'co_occs' in step_stats or 'co_occs' in trl_stats:
            co_pars = setup.init_co_occ_pars(len(mod_kws['HC'].s_names),
                                             len(village.S), nsteps=500)

        # Types of navigation to use.
        if nav_types is None:
            nav_types = ['optimal', 'HAS', 'IBP', 'random']

        # Set up goal states and start states.
        goals, start_locs = setup.init_locs(village, n_reset, goals, starts,
                                            goal_type='max')

        # Report config.
        process.report_sim_setup(fpref, lrn_kws, goals)

        dres, dmodels, dcoocs = {}, {}, {}
        for nav in nav_types:
            print(nav + '\n')

            # Each navigation type gets its own copy of the model and of
            # the co-occurrence state.
            model_kws = utils.get_copy(mod_kws)
            co_occ_pars = utils.get_copy(co_pars)

            nav_kws = dict(village=village,
                           stim_pars=mod_par_kws['stim_pars'],
                           conn_pars=mod_par_kws['conn_pars'],
                           nav=nav, max_steps=max_steps, goals=goals)
            trial_kws = dict(step_stats=step_stats, trl_stats=trl_stats,
                             co_occ_pars=co_occ_pars, ibp_pars=ibp_pars)
            trial_kws.update(utils.merge([model_kws, nav_kws, lrn_kws]))

            dres[nav] = run_trial_batch(start_locs, **trial_kws)
            dmodels[nav] = model_kws
            # Co-occurrence tracking may be off, in which case there is no
            # parameter dict to read the matrix from.
            dcoocs[nav] = (None if co_occ_pars is None
                           else co_occ_pars['co_occs'])

        # Format results.
        res = pd.concat({nav: pd.DataFrame(dres[nav]).T for nav in dres})
        # NOTE(review): errors='ignore' is deprecated in recent pandas
        # releases; kept as is to preserve the current conversion behaviour.
        res = res.apply(pd.to_numeric, errors='ignore')
        res = utils.create_long_DF(res, ['nav type', 'trial'])

        # Plot results. A missing base folder is treated as an empty prefix.
        fdir_dm = (fdir_dm or '') + fpref + '/'
        ttl = ''
        plotting.plot_trajs_cooc_habits_by_type(village, res, dmodels, dcoocs,
                                                ttl=ttl, fdir=fdir_dm)
        plotting.plot_rolling_stats_by_type(res, ttl=ttl, fdir=fdir_dm)
        for per_loc in [True, False]:
            plotting.plot_learned_navig_by_type(res, village, per_loc=per_loc,
                                                ttl=ttl, fdir=fdir_dm)

        # Save model and recorded results.
        sim_data = {
            'village': village, 'res': res, 'dmodels': dmodels,
            'dcoocs': dcoocs, 'mod_kws': mod_kws, 'lrn_kws': lrn_kws,
            'ibp_pars': ibp_pars, 'co_pars': co_pars, 'goals': goals,
            'start_locs': start_locs, 'new_paths': new_paths,
            'del_paths': del_paths, 'new_rwds': new_rwds,
            'n_reset': n_reset, 'max_steps': max_steps,
        }
        utils.write_objects(sim_data, fdir_dm + 'sim_data.pickle')
    finally:
        # Restore environment.
        if restore_vill:
            setup.restore_env(village, new_paths, del_paths, r_orig)

    return sim_data
def _format_stim_frame(stim_data, keys, level_labels):
    """Build a per-step frame of the given stimulus keys, relabelling the
    second column level with `level_labels`."""
    per_step = {i: pd.DataFrame({k: d[k] for k in keys}).unstack()
                for i, d in stim_data.items() if d is not None}
    frame = pd.DataFrame(per_step).T
    # set_levels returns a new index; its in-place form no longer exists in
    # current pandas, so the result is assigned back.
    frame.columns = frame.columns.set_levels(level_labels, level=1)
    return frame


def format_rec_data(tr_data, village, HC, gs_pars, vfeatures, idx_pars=None):
    """
    Format recorded simulation data into pandas objects.

    Parameters
    ----------
    tr_data : mapping from recorded statistic name ('anim_data',
        'stim_data', 'gs_state', 'hc_state', 'vs_state', 'dls_state') to a
        mapping of per-step records. Only the names present are formatted.
    village : environment object; village.U labels the rows of the dlS
        matrices.
    HC : model object; HC.s_names labels states, HC.s_types gives the
        column order of the HC state frames.
    gs_pars : mapping with 'xvec' and 'yvec', used as column and row labels
        of the GS frames. Only read when 'gs_state' was recorded.
    vfeatures : labels for the second column level of the visual input.
    idx_pars : names of the index levels preceding the final 'step' level;
        defaults to no extra levels.

    Returns
    -------
    list of pandas objects in a fixed order (animal data, motor input,
    visual input, GS, HC, vS, dlS), skipping those not recorded.
    """
    print('\nFormatting recorded data...')

    idx_pars = [] if idx_pars is None else list(idx_pars)
    ret_list = []

    if 'anim_data' in tr_data:
        anim_data = pd.DataFrame(tr_data['anim_data']).T
        ret_list.append(anim_data)

    if 'stim_data' in tr_data:
        stim_data = tr_data['stim_data']
        # Motor input.
        ret_list.append(
            _format_stim_frame(stim_data, ['umot', 'vmot'], ['x', 'y']))
        # Visual input.
        ret_list.append(
            _format_stim_frame(stim_data, ['ovis', 'vvis'], vfeatures))

    if 'gs_state' in tr_data:
        gs_xvec, gs_yvec = [utils.get_copy(gs_pars[v])
                            for v in ['xvec', 'yvec']]
        gs_state = {k: pd.DataFrame(gs, columns=gs_xvec, index=gs_yvec)
                    for k, gs in tr_data['gs_state'].items()}
        gs_state = pd.concat(gs_state)
        gs_state.columns = gs_state.columns.rename('x')
        gs_state.index = gs_state.index.set_names('y', level=-1)
        ret_list.append(gs_state)

    if 'hc_state' in tr_data:
        hc_state = {k: pd.DataFrame(hc, index=HC.s_names)
                    for k, hc in tr_data['hc_state'].items()}
        hc_state = pd.concat(hc_state)
        hc_state.index = hc_state.index.set_names('loc', level=-1)
        hc_state = hc_state[HC.s_types]  # reorder columns
        ret_list.append(hc_state)

    if 'vs_state' in tr_data:
        vs_state = {k: pd.Series(vs, index=HC.s_names)
                    for k, vs in tr_data['vs_state'].items()}
        vs_state = pd.concat(vs_state)
        ret_list.append(vs_state)

    if 'dls_state' in tr_data:
        dls_state = {k: pd.DataFrame(dls, columns=HC.s_names, index=village.U)
                     for k, dls in tr_data['dls_state'].items()}
        dls_state = pd.concat(dls_state)
        dls_state.columns = dls_state.columns.rename('s')
        dls_state.index = dls_state.index.set_names('u', level=-1)
        ret_list.append(dls_state)

    # Set index level names: the leading levels of a multi-level index, or
    # the single level of a flat one.
    idx_lvl_names = idx_pars + ['step']
    for df in ret_list:
        levels = (list(range(len(idx_lvl_names)))
                  if isinstance(df.index, pd.MultiIndex) else None)
        df.index = df.index.set_names(idx_lvl_names, level=levels)

    return ret_list