Class for simulating input networks from Akam et al., Neuron 2010. Network can be set to either asynchronous, gamma oscillating, or beta oscillating state when it is initialised. The input networks were simulated in the Nest simulator for the original paper, and when I built them in Brian they showed substantially lower firing rates for reasons that I do not fully understand but which may be due to the connection function working somewhat differently. I have adjusted a couple of synaptic weights (where indicated) to bring the firing rates back to approximately where they were in the Nest implementation. The resulting networks behave very similarly to those in the paper but have slightly lower oscillation frequencies. # Copyright (c) Thomas Akam 2012. Licensed under the GNU General Public License v3. ''' import os import time import pylab as plt import numpy as np from numpy.random import randn import brian2 as b2 from brian2.equations.equations import Equations import logging logger = logging.getLogger('ftpuploader') b2.seed(1) np.random.seed(1) def simulate(): common_params = { # Parameters common to all neurons. 'C': 100 * b2.pfarad, 'tau_m': 10 * b2.ms, 'EL': -60 * b2.mV, 'DeltaT': 2 * b2.mV, 'Vreset': -65, # *b2.mV 'VTmean': -50 * b2.mV, 'VTsd': 2 * b2.mV }
import os import time import pylab as plt import numpy as np from numpy.random import randn import brian2 as b2 from brian2.equations.equations import Equations import logging # logger = logging.getLogger('ftpuploader') ''' Connect three inhibitory exponential LIF cells in a feed-forward loop with alpha-shaped conductance-based synapses. Edges: ([0,1], [0,2], [1,2]) ''' b2.seed(2) np.random.seed(2) def simulate(): common_params = { # Parameters common to all neurons. 'C': 100 * b2.pfarad, 'tau_m': 10 * b2.ms, 'EL': -60 * b2.mV, 'DeltaT': 2 * b2.mV, 'Vreset': -65, # *b2.mV 'VTmean': -50 * b2.mV, 'VTsd': 2 * b2.mV } common_params['gL'] = common_params['C'] / common_params['tau_m']
import brian2 as bs
from matplotlib import pyplot as plt
import numpy as np
from bayesian_test.utils import visualise_connectivity

# ---------------------------------------------------------------------------
# Model constants and Brian 2 code strings for a leaky integrate-and-fire
# (LIF) setup.  The strings below refer to the module-level names
# `v_rest`, `v_reset`, `w_input_to_place` and `e` by name.
# ---------------------------------------------------------------------------

# Voltage jump added to the postsynaptic neuron by `synaptic_update`;
# also enters the spike condition in `threshold_eq`.
w_input_to_place = 26.0 * bs.mV

# NOTE(review): TAU_M / TAU_S are not referenced by any string in this
# section; `lif` declares its own per-neuron `tau_m` parameter.  Presumably
# TAU_M is assigned to it further down the file -- confirm.
TAU_M = 17.0 * bs.ms
TAU_S = 5.0 * bs.ms

# NOTE(review): `threshold` is not used by `threshold_eq` below, which
# derives the spike condition from `w_input_to_place` and `v_rest` instead.
threshold = -55.0 * bs.mV
v_rest = -80.0 * bs.mV
v_reset = -80.0 * bs.mV

# Euler's number, referenced by name inside `threshold_eq`.
e = 2.71828182846

bs.seed(seed=19971124)

# NOTE(review): plain float without a Brian 2 unit -- confirm how it is
# applied (e.g. multiplied by a time unit) where it is used.
dt = 0.1

# One equation per line: Brian 2 parses each line of the model string as a
# separate equation / parameter declaration.
lif = """
    dv/dt = (v_rest - v)/tau_m : volt (unless refractory)
    tau_m : second
"""

# Statement executed on the postsynaptic side of a synapse.
synaptic_update = """
    v_post += w_input_to_place
"""

threshold_eq = "v > (1+1/e)*(w_input_to_place + v_rest) + w_input_to_place"
reset_eq = "v = v_reset"

# NOTE(review): unitless; presumably seconds -- confirm at the call site.
SIM_TIME = 5
def _dump_pickle(path, obj):
    """Pickle ``obj`` into the file at ``path`` (binary mode, overwriting)."""
    with open(path, 'wb') as pfile:
        pickle.dump(obj, pfile)


def _dump_rate_monitor(path, monitor, with_smoothed):
    """Pickle a rate monitor's states and, optionally, its smoothed rate.

    Both objects go into the same file as two consecutive pickles, so a
    reader has to call ``pickle.load`` twice when ``with_smoothed`` is set.
    """
    with open(path, 'wb') as pfile:
        pickle.dump(monitor.get_states(), pfile)
        if with_smoothed:
            pickle.dump(monitor.smooth_rate(width=25 * ms), pfile)


def _collect_raw_file(build_dir, raw_dir, name):
    """Convert the comma-separated file ``build_dir + name`` into a pickle.

    The file is touched first so that a missing file is read as empty data
    instead of raising; it is deleted after reading and its content is
    stored as ``raw_dir + name + '.p'``.
    """
    from pathlib import Path

    src = build_dir + name
    Path(src).touch()
    data = np.genfromtxt(src, delimiter=',')
    os.remove(src)
    _dump_pickle(raw_dir + name + '.p', data)


def _spike_trains_after(spk_states, t_start, t_len, n_units):
    """Return one ``neo.SpikeTrain`` per unit for spikes later than ``t_start``.

    ``spk_states`` is the dict returned by ``SpikeMonitor.get_states()``;
    spike times are shifted so that ``t_start`` becomes time zero and every
    train stops at ``t_len``.
    """
    ts, idxs = spk_states['t'], spk_states['i']
    late = ts > t_start
    idxs = idxs[late]
    ts = ts[late] - t_start
    return [
        neo.SpikeTrain(ts[idxs == i] / second * pq.s,
                       t_stop=t_len / second * pq.s)
        for i in range(n_units)
    ]


def _save_cross_correlations(syn_states, src_trains, tar_trains, raw_dir,
                             fname_template):
    """Compute and pickle cross-correlograms for all finally-active synapses.

    For bin sizes of 1, 2 and 5 ms and a window of 100 ms, every synapse
    whose ``syn_active`` flag equals 1 in the last recorded sample gets the
    cross-correlation histogram between its source and target spike train.
    One pickle per bin size is written to ``raw_dir``; ``fname_template``
    receives the window size and bin size in ms.
    """
    wsize = 100 * pq.ms

    for binsize in [1 * pq.ms, 2 * pq.ms, 5 * pq.ms]:
        wlen = int(wsize / binsize)

        # Reset per bin size: without any active synapse the original code
        # hit an unbound (or stale) `cbin`; None is stored in that case.
        cbin = None
        crs_crrs, syn_a = [], []

        for f, (i, j) in enumerate(zip(syn_states['i'], syn_states['j'])):
            if syn_states['syn_active'][-1][f] == 1:
                crs_crr, cbin = cch(
                    BinnedSpikeTrain(src_trains[i], binsize=binsize),
                    BinnedSpikeTrain(tar_trains[j], binsize=binsize),
                    cross_corr_coef=True,
                    border_correction=True,
                    window=(-1 * wlen, wlen))

                crs_crrs.append(list(np.array(crs_crr).T[0]))
                syn_a.append(syn_states['a'][-1][f])

        fname = fname_template % (wsize / pq.ms, binsize / pq.ms)

        df = {
            'cbin': cbin,
            'crs_crrs': np.array(crs_crrs),
            'syn_a': np.array(syn_a),
            'binsize': binsize,
            'wsize': wsize,
            'wlen': wlen
        }

        _dump_pickle(raw_dir + fname + '.p', df)


def run_net(tr):
    """Build, simulate and post-process one network for parameter set ``tr``.

    The function
      1. selects the Brian 2 ``cpp_standalone`` device and seeds the Brian 2
         and NumPy random generators from ``tr.random_seed``;
      2. creates the excitatory/inhibitory neuron groups and the EE, EI, IE
         and II synapse groups, adding noise, STDP, scaling and structural
         plasticity code according to the flags in ``tr``;
      3. creates state/spike/rate monitors;
      4. schedules five consecutive simulation periods (T1..T5), switching
         recorders and plasticity on and off between them;
      5. builds and runs the generated code, pickles all monitor data to
         ``builds/<v_idx>/raw/``, optionally computes cross-correlations,
         deletes intermediate build directories and draws overview figures.

    Args:
        tr: parameter/result container providing attribute access to all
            simulation parameters (``tr.N_e``, ``tr.T1``, ...), ``tr.v_idx``
            and ``tr.f_add_result``.  Presumably a pypet trajectory --
            TODO confirm.

    Raises:
        NotImplementedError: for ``external_mode == 'poisson'`` and for
            inhibitory structural plasticity with ``strct_mode == 'thrs'``.
        ValueError: for any other unknown ``external_mode``.
    """
    # prefs.codegen.target = 'numpy'
    # prefs.codegen.target = 'cython'
    if tr.n_threads > 1:
        prefs.devices.cpp_standalone.openmp_threads = tr.n_threads

    build_path = 'builds/%.4d' % (tr.v_idx)
    build_dir = build_path + '/'
    raw_dir = build_path + '/raw/'

    set_device('cpp_standalone',
               directory='./builds/%.4d' % (tr.v_idx),
               build_on_run=False)

    # set brian 2 and numpy random seeds
    seed(tr.random_seed)
    np.random.seed(tr.random_seed + 11)

    print("Started process with id ", str(tr.v_idx))

    T = tr.T1 + tr.T2 + tr.T3 + tr.T4 + tr.T5
    # start of the T4 / T5 periods, used for recording windows below
    t4_start = tr.T1 + tr.T2 + tr.T3
    t5_start = t4_start + tr.T4

    namespace = tr.netw.f_to_dict(short_names=True, fast_access=True)
    namespace['idx'] = tr.v_idx

    defaultclock.dt = tr.netw.sim.dt

    # collect all network components dependent on configuration
    # (e.g. poisson vs. memnoise) and add them to the Brian 2
    # network object later
    netw_objects = []

    if tr.external_mode == 'memnoise':
        neuron_model = tr.condlif_memnoise
    elif tr.external_mode == 'poisson':
        # All `poisson` branches further down are unreachable until this
        # is implemented.
        raise NotImplementedError
        #neuron_model = tr.condlif_poisson
    else:
        # Previously fell through and failed later on an unbound
        # `neuron_model`.
        raise ValueError('unknown external_mode: %r' % (tr.external_mode,))

    if tr.syn_cond_mode == 'exp':
        neuron_model += tr.syn_cond_EE_exp
        print("Using EE exp mode")
    elif tr.syn_cond_mode == 'alpha':
        neuron_model += tr.syn_cond_EE_alpha
        print("Using EE alpha mode")
    elif tr.syn_cond_mode == 'biexp':
        neuron_model += tr.syn_cond_EE_biexp
        namespace['invpeakEE'] = (tr.tau_e / tr.tau_e_rise) ** \
            (tr.tau_e_rise / (tr.tau_e - tr.tau_e_rise))
        print("Using EE biexp mode")

    if tr.syn_cond_mode_EI == 'exp':
        neuron_model += tr.syn_cond_EI_exp
        print("Using EI exp mode")
    elif tr.syn_cond_mode_EI == 'alpha':
        neuron_model += tr.syn_cond_EI_alpha
        print("Using EI alpha mode")
    elif tr.syn_cond_mode_EI == 'biexp':
        neuron_model += tr.syn_cond_EI_biexp
        namespace['invpeakEI'] = (tr.tau_i / tr.tau_i_rise) ** \
            (tr.tau_i_rise / (tr.tau_i - tr.tau_i_rise))
        print("Using EI biexp mode")

    GExc = NeuronGroup(
        N=tr.N_e,
        model=neuron_model,
        threshold=tr.nrnEE_thrshld,
        reset=tr.nrnEE_reset,
        #method=tr.neuron_method,
        name='GExc',
        namespace=namespace)
    GInh = NeuronGroup(
        N=tr.N_i,
        model=neuron_model,
        threshold='V > Vt',
        reset='V=Vr_i',
        #method=tr.neuron_method,
        name='GInh',
        namespace=namespace)

    if tr.external_mode == 'memnoise':
        # GExc.mu, GInh.mu = [0.*mV] + (tr.N_e-1)*[tr.mu_e], tr.mu_i
        # GExc.sigma, GInh.sigma = [0.*mV] + (tr.N_e-1)*[tr.sigma_e], tr.sigma_i
        GExc.mu, GInh.mu = tr.mu_e, tr.mu_i
        GExc.sigma, GInh.sigma = tr.sigma_e, tr.sigma_i

    GExc.Vt, GInh.Vt = tr.Vt_e, tr.Vt_i

    # initial membrane potentials drawn uniformly between reset and threshold
    GExc.V, GInh.V = \
        np.random.uniform(tr.Vr_e / mV, tr.Vt_e / mV, size=tr.N_e) * mV, \
        np.random.uniform(tr.Vr_i / mV, tr.Vt_i / mV, size=tr.N_i) * mV

    netw_objects.extend([GExc, GInh])

    if tr.external_mode == 'poisson':

        if tr.PInp_mode == 'pool':
            PInp = PoissonGroup(tr.NPInp, rates=tr.PInp_rate,
                                namespace=namespace, name='poissongroup_exc')
            sPN = Synapses(target=GExc, source=PInp, model=tr.poisson_mod,
                           on_pre='gfwd_post += a_EPoi',
                           namespace=namespace, name='synPInpExc')

            sPN_src, sPN_tar = generate_N_connections(N_tar=tr.N_e,
                                                      N_src=tr.NPInp,
                                                      N=tr.NPInp_1n)

        elif tr.PInp_mode == 'indep':
            PInp = PoissonGroup(tr.N_e, rates=tr.PInp_rate,
                                namespace=namespace)
            sPN = Synapses(target=GExc, source=PInp, model=tr.poisson_mod,
                           on_pre='gfwd_post += a_EPoi',
                           namespace=namespace, name='synPInp_inhInh')
            sPN_src, sPN_tar = range(tr.N_e), range(tr.N_e)

        sPN.connect(i=sPN_src, j=sPN_tar)

        if tr.PInp_mode == 'pool':
            PInp_inh = PoissonGroup(tr.NPInp_inh, rates=tr.PInp_inh_rate,
                                    namespace=namespace,
                                    name='poissongroup_inh')
            sPNInh = Synapses(target=GInh, source=PInp_inh,
                              model=tr.poisson_mod,
                              on_pre='gfwd_post += a_EPoi',
                              namespace=namespace)
            sPNInh_src, sPNInh_tar = generate_N_connections(
                N_tar=tr.N_i, N_src=tr.NPInp_inh, N=tr.NPInp_inh_1n)

        elif tr.PInp_mode == 'indep':
            PInp_inh = PoissonGroup(tr.N_i, rates=tr.PInp_inh_rate,
                                    namespace=namespace)
            sPNInh = Synapses(target=GInh, source=PInp_inh,
                              model=tr.poisson_mod,
                              on_pre='gfwd_post += a_EPoi',
                              namespace=namespace)
            sPNInh_src, sPNInh_tar = range(tr.N_i), range(tr.N_i)

        sPNInh.connect(i=sPNInh_src, j=sPNInh_tar)

        netw_objects.extend([PInp, sPN, PInp_inh, sPNInh])

    # Model strings are joined with a newline: Brian 2 expects one
    # equation / statement per line.
    # NOTE(review): the EI synapse model is assembled from the *EE* building
    # blocks (tr.synEE_noise_*, tr.synEE_static, tr.synEE_mod).  This may be
    # intentional (shared model) -- confirm.
    if tr.syn_noise:

        if tr.syn_noise_type == 'additive':
            synEE_mod = '%s\n%s' % (tr.synEE_noise_add, tr.synEE_mod)
            synEI_mod = '%s\n%s' % (tr.synEE_noise_add, tr.synEE_mod)

        elif tr.syn_noise_type == 'multiplicative':
            synEE_mod = '%s\n%s' % (tr.synEE_noise_mult, tr.synEE_mod)
            synEI_mod = '%s\n%s' % (tr.synEE_noise_mult, tr.synEE_mod)

    else:
        synEE_mod = '%s\n%s' % (tr.synEE_static, tr.synEE_mod)
        synEI_mod = '%s\n%s' % (tr.synEE_static, tr.synEE_mod)

    if tr.scl_active:
        synEE_mod = '%s\n%s' % (synEE_mod, tr.synEE_scl_mod)
        synEI_mod = '%s\n%s' % (synEI_mod, tr.synEI_scl_mod)

    if tr.syn_cond_mode == 'exp':
        synEE_pre_mod = mod.synEE_pre_exp
    elif tr.syn_cond_mode == 'alpha':
        synEE_pre_mod = mod.synEE_pre_alpha
    elif tr.syn_cond_mode == 'biexp':
        synEE_pre_mod = mod.synEE_pre_biexp

    synEE_post_mod = mod.syn_post

    if tr.stdp_active:
        synEE_pre_mod = '%s\n%s' % (synEE_pre_mod, mod.syn_pre_STDP)
        synEE_post_mod = '%s\n%s' % (synEE_post_mod, mod.syn_post_STDP)

    if tr.synEE_rec:
        synEE_pre_mod = '%s\n%s' % (synEE_pre_mod, mod.synEE_pre_rec)
        synEE_post_mod = '%s\n%s' % (synEE_post_mod, mod.synEE_post_rec)

    # E<-E advanced synapse model
    SynEE = Synapses(target=GExc, source=GExc, model=synEE_mod,
                     on_pre=synEE_pre_mod, on_post=synEE_post_mod,
                     namespace=namespace, dt=tr.synEE_mod_dt)

    if tr.istdp_active and tr.istdp_type == 'dbexp':

        if tr.syn_cond_mode_EI == 'exp':
            EI_pre_mod = mod.synEI_pre_exp
        elif tr.syn_cond_mode_EI == 'alpha':
            EI_pre_mod = mod.synEI_pre_alpha
        elif tr.syn_cond_mode_EI == 'biexp':
            EI_pre_mod = mod.synEI_pre_biexp

        synEI_pre_mod = '%s\n%s' % (EI_pre_mod, mod.syn_pre_STDP)
        synEI_post_mod = '%s\n%s' % (mod.syn_post, mod.syn_post_STDP)

    elif tr.istdp_active and tr.istdp_type == 'sym':

        if tr.syn_cond_mode_EI == 'exp':
            EI_pre_mod = mod.synEI_pre_sym_exp
        elif tr.syn_cond_mode_EI == 'alpha':
            EI_pre_mod = mod.synEI_pre_sym_alpha
        elif tr.syn_cond_mode_EI == 'biexp':
            EI_pre_mod = mod.synEI_pre_sym_biexp

        synEI_pre_mod = '%s\n%s' % (EI_pre_mod, mod.syn_pre_STDP)
        synEI_post_mod = '%s\n%s' % (mod.synEI_post_sym, mod.syn_post_STDP)

    if tr.istdp_active and tr.synEI_rec:
        synEI_pre_mod = '%s\n%s' % (synEI_pre_mod, mod.synEI_pre_rec)
        synEI_post_mod = '%s\n%s' % (synEI_post_mod, mod.synEI_post_rec)

    if tr.istdp_active:
        SynEI = Synapses(target=GExc, source=GInh, model=synEI_mod,
                         on_pre=synEI_pre_mod, on_post=synEI_post_mod,
                         namespace=namespace, dt=tr.synEE_mod_dt)
    else:
        # static inhibitory synapses: only a weight and an activity flag
        model = '''a : 1
                   syn_active : 1'''
        SynEI = Synapses(target=GExc, source=GInh, model=model,
                         on_pre='gi_post += a', namespace=namespace)

    #other simple
    SynIE = Synapses(target=GInh, source=GExc, on_pre='ge_post += a_ie',
                     namespace=namespace)

    SynII = Synapses(target=GInh, source=GInh, on_pre='gi_post += a_ii',
                     namespace=namespace)

    sEE_src, sEE_tar = generate_full_connectivity(tr.N_e, same=True)
    SynEE.connect(i=sEE_src, j=sEE_tar)
    SynEE.syn_active = 0
    SynEE.taupre, SynEE.taupost = tr.taupre, tr.taupost

    if tr.istdp_active and tr.istrct_active:
        print('istrct active')
        sEI_src, sEI_tar = generate_full_connectivity(Nsrc=tr.N_i,
                                                      Ntar=tr.N_e,
                                                      same=False)
        SynEI.connect(i=sEI_src, j=sEI_tar)
        SynEI.syn_active = 0

    else:
        print('istrct not active')
        if tr.weight_mode == 'init':
            sEI_src, sEI_tar = generate_connections(tr.N_e, tr.N_i, tr.p_ei)
            # print('Index Zero will not get inhibition')
            # sEI_src, sEI_tar = np.array(sEI_src), np.array(sEI_tar)
            # sEI_src, sEI_tar = sEI_src[sEI_tar > 0],sEI_tar[sEI_tar > 0]

        elif tr.weight_mode == 'load':
            fpath = os.path.join(tr.basepath, tr.weight_path)

            # SECURITY NOTE: pickle.load executes arbitrary code from the
            # file; only point tr.weight_path at trusted data.
            with open(fpath + 'synei_a.p', 'rb') as pfile:
                synei_a_init = pickle.load(pfile)

            sEI_src, sEI_tar = synei_a_init['i'], synei_a_init['j']

        SynEI.connect(i=sEI_src, j=sEI_tar)

    if tr.istdp_active:
        SynEI.taupre, SynEI.taupost = tr.taupre_EI, tr.taupost_EI

    sIE_src, sIE_tar = generate_connections(tr.N_i, tr.N_e, tr.p_ie)
    sII_src, sII_tar = generate_connections(tr.N_i, tr.N_i, tr.p_ii,
                                            same=True)

    SynIE.connect(i=sIE_src, j=sIE_tar)
    SynII.connect(i=sII_src, j=sII_tar)

    tr.f_add_result('sEE_src', sEE_src)
    tr.f_add_result('sEE_tar', sEE_tar)
    tr.f_add_result('sIE_src', sIE_src)
    tr.f_add_result('sIE_tar', sIE_tar)
    tr.f_add_result('sEI_src', sEI_src)
    tr.f_add_result('sEI_tar', sEI_tar)
    tr.f_add_result('sII_src', sII_src)
    tr.f_add_result('sII_tar', sII_tar)

    if tr.syn_noise:
        SynEE.syn_sigma = tr.syn_sigma
        SynEE.run_regularly('a = clip(a,0,amax)', when='after_groups',
                            name='SynEE_noise_clipper')

    if tr.syn_noise and tr.istdp_active:
        SynEI.syn_sigma = tr.syn_sigma
        SynEI.run_regularly('a = clip(a,0,amax)', when='after_groups',
                            name='SynEI_noise_clipper')

    SynEE.insert_P = tr.insert_P
    SynEE.p_inactivate = tr.p_inactivate
    SynEE.stdp_active = 1
    print('Setting maximum EE weight threshold to ', tr.amax)
    SynEE.amax = tr.amax

    if tr.istdp_active:
        SynEI.insert_P = tr.insert_P_ei
        SynEI.p_inactivate = tr.p_inactivate_ei
        SynEI.stdp_active = 1
        SynEI.amax = tr.amax

    SynEE.syn_active, SynEE.a = init_synapses('EE', tr)
    SynEI.syn_active, SynEI.a = init_synapses('EI', tr)

    # recording of stdp in T4
    SynEE.stdp_rec_start = t4_start
    SynEE.stdp_rec_max = t4_start + tr.stdp_rec_T

    if tr.istdp_active:
        SynEI.stdp_rec_start = t4_start
        SynEI.stdp_rec_max = t4_start + tr.stdp_rec_T

    # synaptic scaling
    if tr.netw.config.scl_active:

        if tr.syn_scl_rec:
            SynEE.scl_rec_start = t4_start
            SynEE.scl_rec_max = t4_start + tr.scl_rec_T
        else:
            # start lies after the end of the run and after `scl_rec_max`
            SynEE.scl_rec_start = T + 10 * second
            SynEE.scl_rec_max = T

        if tr.sig_ATotalMax == 0.:
            GExc.ANormTar = tr.ATotalMax
        else:
            GExc.ANormTar = np.random.normal(loc=tr.ATotalMax,
                                             scale=tr.sig_ATotalMax,
                                             size=tr.N_e)

        # run the summed-variable updater on the scaling clock
        SynEE.summed_updaters['AsumEE_post']._clock = Clock(
            dt=tr.dt_synEE_scaling)
        synee_scaling = SynEE.run_regularly(tr.synEE_scaling,
                                            dt=tr.dt_synEE_scaling,
                                            when='end',
                                            name='synEE_scaling')

    if tr.istdp_active and tr.netw.config.iscl_active:

        if tr.syn_iscl_rec:
            SynEI.scl_rec_start = t4_start
            SynEI.scl_rec_max = t4_start + tr.scl_rec_T
        else:
            SynEI.scl_rec_start = T + 10 * second
            SynEI.scl_rec_max = T

        if tr.sig_iATotalMax == 0.:
            GExc.iANormTar = tr.iATotalMax
        else:
            GExc.iANormTar = np.random.normal(loc=tr.iATotalMax,
                                              scale=tr.sig_iATotalMax,
                                              size=tr.N_e)

        SynEI.summed_updaters['AsumEI_post']._clock = Clock(
            dt=tr.dt_synEE_scaling)

        synei_scaling = SynEI.run_regularly(tr.synEI_scaling,
                                            dt=tr.dt_synEE_scaling,
                                            when='end',
                                            name='synEI_scaling')

    # # intrinsic plasticity
    # if tr.netw.config.it_active:
    #     GExc.h_ip = tr.h_ip
    #     GExc.run_regularly(tr.intrinsic_mod, dt = tr.it_dt, when='end')

    # structural plasticity
    if tr.netw.config.strct_active:
        if tr.strct_mode == 'zero':
            if tr.turnover_rec:
                strct_mod = '%s\n%s' % (tr.strct_mod, tr.turnover_rec_mod)
            else:
                strct_mod = tr.strct_mod

            strctplst = SynEE.run_regularly(strct_mod, dt=tr.strct_dt,
                                            when='end',
                                            name='strct_plst_zero')

        elif tr.strct_mode == 'thrs':
            if tr.turnover_rec:
                strct_mod_thrs = '%s\n%s' % (tr.strct_mod_thrs,
                                             tr.turnover_rec_mod)
            else:
                strct_mod_thrs = tr.strct_mod_thrs

            strctplst = SynEE.run_regularly(strct_mod_thrs, dt=tr.strct_dt,
                                            when='end',
                                            name='strct_plst_thrs')

    if tr.istdp_active and tr.netw.config.istrct_active:
        if tr.strct_mode == 'zero':
            if tr.turnover_rec:
                strct_mod_EI = '%s\n%s' % (tr.strct_mod,
                                           tr.turnoverEI_rec_mod)
            else:
                strct_mod_EI = tr.strct_mod

            strctplst_EI = SynEI.run_regularly(strct_mod_EI, dt=tr.strct_dt,
                                               when='end',
                                               name='strct_plst_EI')

        elif tr.strct_mode == 'thrs':
            raise NotImplementedError

    netw_objects.extend([SynEE, SynEI, SynIE, SynII])

    # keep track of the number of active synapses
    sum_target = NeuronGroup(1, 'c : 1 (shared)', dt=tr.csample_dt)

    sum_model = '''NSyn : 1 (constant)
                   c_post = (1.0*syn_active_pre)/NSyn : 1 (summed)'''
    sum_connection = Synapses(target=sum_target, source=SynEE,
                              model=sum_model, dt=tr.csample_dt,
                              name='get_active_synapse_count')
    sum_connection.connect()
    sum_connection.NSyn = tr.N_e * (tr.N_e - 1)

    if tr.adjust_insertP:
        # homeostatically adjust growth rate
        growth_updater = Synapses(sum_target, SynEE)
        growth_updater.run_regularly('insert_P_post *= 0.1/c_pre',
                                     when='after_groups', dt=tr.csample_dt,
                                     name='update_insP')
        growth_updater.connect(j='0')

        # the counter objects only join the network together with the
        # updater that consumes them
        netw_objects.extend([sum_target, sum_connection, growth_updater])

    if tr.istdp_active and tr.istrct_active:

        # keep track of the number of active synapses
        sum_target_EI = NeuronGroup(1, 'c : 1 (shared)', dt=tr.csample_dt)

        sum_model_EI = '''NSyn : 1 (constant)
                          c_post = (1.0*syn_active_pre)/NSyn : 1 (summed)'''
        sum_connection_EI = Synapses(target=sum_target_EI, source=SynEI,
                                     model=sum_model_EI, dt=tr.csample_dt,
                                     name='get_active_synapse_count_EI')
        sum_connection_EI.connect()
        sum_connection_EI.NSyn = tr.N_e * tr.N_i

        if tr.adjust_EI_insertP:
            # homeostatically adjust growth rate
            growth_updater_EI = Synapses(sum_target_EI, SynEI)
            growth_updater_EI.run_regularly('insert_P_post *= 0.1/c_pre',
                                            when='after_groups',
                                            dt=tr.csample_dt,
                                            name='update_insP_EI')
            growth_updater_EI.connect(j='0')

            netw_objects.extend(
                [sum_target_EI, sum_connection_EI, growth_updater_EI])

    # -------------- recording ------------------

    GExc_recvars = []
    if tr.memtraces_rec:
        GExc_recvars.append('V')
    if tr.vttraces_rec:
        GExc_recvars.append('Vt')
    if tr.getraces_rec:
        GExc_recvars.append('ge')
    if tr.gitraces_rec:
        GExc_recvars.append('gi')
    if tr.gfwdtraces_rec and tr.external_mode == 'poisson':
        GExc_recvars.append('gfwd')

    # independent copy so later edits to one list cannot leak into the other
    GInh_recvars = list(GExc_recvars)

    GExc_stat = StateMonitor(GExc, GExc_recvars,
                             record=list(range(tr.nrec_GExc_stat)),
                             dt=tr.GExc_stat_dt)
    GInh_stat = StateMonitor(GInh, GInh_recvars,
                             record=list(range(tr.nrec_GInh_stat)),
                             dt=tr.GInh_stat_dt)

    # SynEE stat
    SynEE_recvars = []
    if tr.synee_atraces_rec:
        SynEE_recvars.append('a')
    if tr.synee_activetraces_rec:
        SynEE_recvars.append('syn_active')
    if tr.synee_Apretraces_rec:
        SynEE_recvars.append('Apre')
    if tr.synee_Aposttraces_rec:
        SynEE_recvars.append('Apost')

    SynEE_stat = StateMonitor(SynEE, SynEE_recvars,
                              record=range(tr.n_synee_traces_rec),
                              when='end', dt=tr.synEE_stat_dt)

    if tr.istdp_active:
        # SynEI stat
        SynEI_recvars = []
        if tr.synei_atraces_rec:
            SynEI_recvars.append('a')
        if tr.synei_activetraces_rec:
            SynEI_recvars.append('syn_active')
        if tr.synei_Apretraces_rec:
            SynEI_recvars.append('Apre')
        if tr.synei_Aposttraces_rec:
            SynEI_recvars.append('Apost')

        SynEI_stat = StateMonitor(SynEI, SynEI_recvars,
                                  record=range(tr.n_synei_traces_rec),
                                  when='end', dt=tr.synEI_stat_dt)
        netw_objects.append(SynEI_stat)

    if tr.adjust_insertP:

        C_stat = StateMonitor(sum_target, 'c', dt=tr.csample_dt,
                              record=[0], when='end')
        insP_stat = StateMonitor(SynEE, 'insert_P', dt=tr.csample_dt,
                                 record=[0], when='end')
        netw_objects.extend([C_stat, insP_stat])

    # NOTE(review): `sum_target_EI` only exists when tr.istrct_active is
    # set; this condition does not check that flag -- confirm the
    # parameter combinations that are actually used.
    if tr.istdp_active and tr.adjust_EI_insertP:

        C_EI_stat = StateMonitor(sum_target_EI, 'c', dt=tr.csample_dt,
                                 record=[0], when='end')
        insP_EI_stat = StateMonitor(SynEI, 'insert_P', dt=tr.csample_dt,
                                    record=[0], when='end')
        netw_objects.extend([C_EI_stat, insP_EI_stat])

    GExc_spks = SpikeMonitor(GExc)
    GInh_spks = SpikeMonitor(GInh)

    GExc_rate = PopulationRateMonitor(GExc)
    GInh_rate = PopulationRateMonitor(GInh)

    if tr.external_mode == 'poisson':
        PInp_spks = SpikeMonitor(PInp)
        PInp_rate = PopulationRateMonitor(PInp)
        netw_objects.extend([PInp_spks, PInp_rate])

    if tr.synee_a_nrecpoints == 0 or tr.sim.T2 == 0 * second:
        # interval longer than the whole run: no periodic recording
        SynEE_a_dt = 2 * (tr.T1 + tr.T2 + tr.T3 + tr.T4 + tr.T5)
    else:
        SynEE_a_dt = tr.sim.T2 / tr.synee_a_nrecpoints

        # make sure that choice of SynEE_a_dt does lead
        # to execessively many recordings - this can
        # happen if t1 >> t2.
        estm_nrecs = int(T / SynEE_a_dt)
        if estm_nrecs > 3 * tr.synee_a_nrecpoints:
            print('''Estimated number of EE weight recordings (%d) exceeds desired number (%d), increasing SynEE_a_dt''' % (estm_nrecs, tr.synee_a_nrecpoints))

            SynEE_a_dt = T / tr.synee_a_nrecpoints

    SynEE_a = StateMonitor(SynEE, ['a', 'syn_active'],
                           record=range(tr.N_e * (tr.N_e - 1)),
                           dt=SynEE_a_dt, when='end', order=100)

    if tr.istrct_active:
        record_range = range(tr.N_e * tr.N_i)
    else:
        record_range = range(len(sEI_src))

    # the SynEI weight monitor (and everything derived from it) only
    # exists when this holds
    record_synei_a = tr.synei_a_nrecpoints > 0 and tr.sim.T2 > 0 * second

    if record_synei_a:
        SynEI_a_dt = tr.sim.T2 / tr.synei_a_nrecpoints

        estm_nrecs = int(T / SynEI_a_dt)
        if estm_nrecs > 3 * tr.synei_a_nrecpoints:
            print('''Estimated number of EI weight recordings (%d) exceeds desired number (%d), increasing SynEI_a_dt''' % (estm_nrecs, tr.synei_a_nrecpoints))

            SynEI_a_dt = T / tr.synei_a_nrecpoints

        SynEI_a = StateMonitor(SynEI, ['a', 'syn_active'],
                               record=record_range,
                               dt=SynEI_a_dt, when='end', order=100)

        netw_objects.append(SynEI_a)

    netw_objects.extend([
        GExc_stat, GInh_stat, SynEE_stat, SynEE_a, GExc_spks, GInh_spks,
        GExc_rate, GInh_rate
    ])

    if (tr.synEEdynrec and
            (2 * tr.syndynrec_npts * tr.syndynrec_dt < tr.sim.T2)):
        SynEE_dynrec = StateMonitor(SynEE, ['a'],
                                    record=range(tr.N_e * (tr.N_e - 1)),
                                    dt=tr.syndynrec_dt,
                                    name='SynEE_dynrec',
                                    when='end', order=100)
        SynEE_dynrec.active = False
        netw_objects.extend([SynEE_dynrec])

    if (tr.synEIdynrec and
            (2 * tr.syndynrec_npts * tr.syndynrec_dt < tr.sim.T2)):
        SynEI_dynrec = StateMonitor(SynEI, ['a'],
                                    record=record_range,
                                    dt=tr.syndynrec_dt,
                                    name='SynEI_dynrec',
                                    when='end', order=100)
        SynEI_dynrec.active = False
        netw_objects.extend([SynEI_dynrec])

    net = Network(*netw_objects)

    def set_active(*argv):
        for net_object in argv:
            net_object.active = True

    def set_inactive(*argv):
        for net_object in argv:
            net_object.active = False

    ### Simulation periods

    # --------- T1 ---------
    # initial recording period,
    # all recorders active

    T1T3_recorders = [
        GExc_spks, GInh_spks, SynEE_stat, GExc_stat, GInh_stat, GExc_rate,
        GInh_rate
    ]

    if tr.istdp_active:
        T1T3_recorders.append(SynEI_stat)

    set_active(*T1T3_recorders)

    if tr.external_mode == 'poisson':
        set_active(PInp_spks, PInp_rate)

    net.run(tr.sim.T1, report='text', report_period=300 * second,
            profile=True)

    # --------- T2 ---------
    # main simulation period
    # only active recordings are:
    #   1) turnover 2) C_stat 3) SynEE_a

    set_inactive(*T1T3_recorders)

    if tr.T2_spks_rec:
        set_active(GExc_spks, GInh_spks)

    if tr.external_mode == 'poisson':
        set_inactive(PInp_spks, PInp_rate)

    run_T2_syndynrec(net, tr, netw_objects)

    # --------- T3 ---------
    # second recording period,
    # all recorders active

    set_active(*T1T3_recorders)

    if tr.external_mode == 'poisson':
        set_active(PInp_spks, PInp_rate)

    run_T3_split(net, tr)

    # --------- T4 ---------
    # record STDP and scaling weight changes to file
    # through the cpp models

    set_inactive(*T1T3_recorders)

    if tr.external_mode == 'poisson':
        set_inactive(PInp_spks, PInp_rate)

    run_T4(net, tr)

    # --------- T5 ---------
    # freeze network and record Exc spikes
    # for cross correlations

    if tr.scl_active:
        synee_scaling.active = False
    if tr.istdp_active and tr.netw.config.iscl_active:
        synei_scaling.active = False
    if tr.strct_active:
        strctplst.active = False
    if tr.istdp_active and tr.istrct_active:
        strctplst_EI.active = False
    SynEE.stdp_active = 0
    if tr.istdp_active:
        SynEI.stdp_active = 0

    set_active(GExc_rate, GInh_rate)
    set_active(GExc_spks, GInh_spks)

    run_T5(net, tr)

    # one final weight sample at the very end of the run
    SynEE_a.record_single_timestep()
    if record_synei_a:
        SynEI_a.record_single_timestep()

    # build_on_run=False above: code generation, compilation and execution
    # all happen here
    device.build(directory=build_path, clean=True, compile=True, run=True,
                 debug=False)

    # -----------------------------------------

    # save monitors as raws in build directory
    os.makedirs(raw_dir, exist_ok=True)

    _dump_pickle(raw_dir + 'namespace.p', namespace)

    _dump_pickle(raw_dir + 'gexc_stat.p', GExc_stat.get_states())
    _dump_pickle(raw_dir + 'ginh_stat.p', GInh_stat.get_states())

    _dump_pickle(raw_dir + 'synee_stat.p', SynEE_stat.get_states())

    if tr.istdp_active:
        _dump_pickle(raw_dir + 'synei_stat.p', SynEI_stat.get_states())

    if ((tr.synEEdynrec or tr.synEIdynrec) and
            (2 * tr.syndynrec_npts * tr.syndynrec_dt < tr.sim.T2)):

        if tr.synEEdynrec:
            _dump_pickle(raw_dir + 'syneedynrec.p',
                         SynEE_dynrec.get_states())

        if tr.synEIdynrec:
            _dump_pickle(raw_dir + 'syneidynrec.p',
                         SynEI_dynrec.get_states())

    SynEE_a_states = SynEE_a.get_states()
    if tr.crs_crrs_rec:
        # synapse indices are needed by the cross-correlation step below
        SynEE_a_states['i'] = list(SynEE.i)
        SynEE_a_states['j'] = list(SynEE.j)
    _dump_pickle(raw_dir + 'synee_a.p', SynEE_a_states)

    if record_synei_a:
        SynEI_a_states = SynEI_a.get_states()
        if tr.crs_crrs_rec:
            SynEI_a_states['i'] = list(SynEI.i)
            SynEI_a_states['j'] = list(SynEI.j)
        _dump_pickle(raw_dir + 'synei_a.p', SynEI_a_states)

    if tr.adjust_insertP:
        _dump_pickle(raw_dir + 'c_stat.p', C_stat.get_states())
        _dump_pickle(raw_dir + 'insP_stat.p', insP_stat.get_states())

    if tr.istdp_active and tr.adjust_EI_insertP:
        _dump_pickle(raw_dir + 'c_EI_stat.p', C_EI_stat.get_states())
        _dump_pickle(raw_dir + 'insP_EI_stat.p', insP_EI_stat.get_states())

    _dump_pickle(raw_dir + 'gexc_spks.p', GExc_spks.get_states())
    _dump_pickle(raw_dir + 'ginh_spks.p', GInh_spks.get_states())

    if tr.external_mode == 'poisson':
        _dump_pickle(raw_dir + 'pinp_spks.p', PInp_spks.get_states())

    _dump_rate_monitor(raw_dir + 'gexc_rate.p', GExc_rate, tr.rates_rec)
    _dump_rate_monitor(raw_dir + 'ginh_rate.p', GInh_rate, tr.rates_rec)

    if tr.external_mode == 'poisson':
        _dump_rate_monitor(raw_dir + 'pinp_rate.p', PInp_rate, tr.rates_rec)

    # ----------------- add raw data ------------------------
    # text files found in the build directory are converted to pickles
    # and removed
    for raw_name in ('turnover', 'turnover_EI',
                     'spk_register', 'spk_register_EI',
                     'scaling_deltas', 'scaling_deltas_EI'):
        _collect_raw_file(build_dir, raw_dir, raw_name)

    with open(raw_dir + 'profiling_summary.txt', 'w+') as tfile:
        tfile.write(str(profiling_summary(net)))

    # --------------- cross-correlations ---------------------

    if tr.crs_crrs_rec:

        gexc_spk_states = GExc_spks.get_states()
        # only spikes from the T5 period enter the correlations
        sts_E = _spike_trains_after(gexc_spk_states, t5_start, tr.T5, tr.N_e)

        _save_cross_correlations(SynEE_a_states, sts_E, sts_E, raw_dir,
                                 'crs_crrs_wsize%dms_binsize%fms_full')

        if record_synei_a:
            ginh_spk_states = GInh_spks.get_states()
            sts_I = _spike_trains_after(ginh_spk_states, t5_start, tr.T5,
                                        tr.N_i)

            _save_cross_correlations(SynEI_a_states, sts_I, sts_E, raw_dir,
                                     'EI_crrs_wsize%dms_binsize%fms_full')
        else:
            # Previously a NameError on `SynEI_a_states`.
            print('Skipping EI cross-correlations: '
                  'SynEI weights were not recorded')

    # ----------------- clean up ---------------------------

    shutil.rmtree(build_dir + 'results/')
    shutil.rmtree(build_dir + 'static_arrays/')
    shutil.rmtree(build_dir + 'brianlib/')
    shutil.rmtree(build_dir + 'code_objects/')

    # ---------------- plot results --------------------------

    #os.chdir('./analysis/file_based/')

    if tr.istdp_active:
        from src.analysis.overview_winh import overview_figure
    else:
        from src.analysis.overview import overview_figure
    overview_figure(build_path, namespace)

    from src.analysis.synw_fb import synw_figure
    synw_figure(build_path, namespace)
    if tr.istdp_active:
        synw_figure(build_path, namespace, connections='EI')

    from src.analysis.synw_log_fb import synw_log_figure
    synw_log_figure(build_path, namespace)
    if tr.istdp_active:
        synw_log_figure(build_path, namespace, connections='EI')
def run_task_hierarchical(task_info, taskdir, tempdir):
    """Build, simulate and analyse one trial of the hierarchical network.

    A decision (integration) circuit and a sensory circuit are coupled by
    feedforward and feedback synapses, driven by an Ornstein-Uhlenbeck
    stimulus, run on Brian2's ``cpp_standalone`` device, and reduced to
    down-sampled population rates.

    Parameters
    ----------
    task_info : dict
        Nested configuration. Keys read here:
        ``['simulation']`` -> ``seedcon``, ``runtime``, ``settletime``,
        ``stimon``, ``stimoff``, ``smoothwin``, ``nummethod``, ``2cmodel``,
        ``plasticdend``, ``pltfig1``, ``burstanalysis``;
        ``['stimulus']`` -> ``replicate``, ``I0`` or ``I0s``, ``mu1``,
        ``mu2``, ``sigma``, ``taustim``;
        ``['sen']['2c']`` -> ``validburst`` (plus ``eta0``, ``tauB``,
        ``tau_update`` for the plastic-dendrite analysis);
        top level -> ``bfb``, ``c``, ``seed``, ``targetB``.
        The dict is also forwarded unchanged to the circuit builders.
    taskdir : str
        Directory created by this function (``os.mkdir``) and passed to the
        plotting helpers.
    tempdir : str
        Build directory handed to ``set_device('cpp_standalone', ...)``.

    Returns
    -------
    dict
        ``'raw_data'`` holds ``poprates_dec`` and ``poprates_sen`` (row 0 is
        the winning population, row 1 the other one, both cut at the settle
        time), ``pref_msk`` (index of the winning decision population) and
        ``last_muOUd``. ``'sim_state'`` is a placeholder array. ``'computed'``
        holds ``events``, ``bursts``, ``singles``, ``spikes`` and ``isis``,
        which stay as ``np.zeros(1)`` placeholders when burst analysis is off.

    Raises
    ------
    FileExistsError
        If ``taskdir`` already exists (raised by ``os.mkdir``).
    """
    # imports
    from brian2 import defaultclock, set_device, seed, TimedArray, Network, profiling_summary
    from brian2.monitors import SpikeMonitor, PopulationRateMonitor, StateMonitor
    from brian2.synapses import Synapses
    from brian2.core.magic import start_scope
    from brian2.units import second, ms, amp
    from integration_circuit import mk_intcircuit
    from sensory_circuit import mk_sencircuit, mk_sencircuit_2c, mk_sencircuit_2cplastic
    from burstQuant import spks2neurometric
    from scipy import interpolate

    sim_cfg = task_info['simulation']
    stim_cfg = task_info['stimulus']

    # if you want to put something in the taskdir, you must create it first
    os.mkdir(taskdir)
    print(taskdir)

    # parallel code and flag to start
    set_device('cpp_standalone', directory=tempdir)
    #prefs.devices.cpp_standalone.openmp_threads = max_tasks
    start_scope()

    # simulation parameters (names with a trailing underscore are plain
    # numbers in seconds)
    seedcon = sim_cfg['seedcon']
    runtime = sim_cfg['runtime']
    runtime_ = runtime / second
    settletime = sim_cfg['settletime']
    settletime_ = settletime / second
    stimon = sim_cfg['stimon']
    stimoff = sim_cfg['stimoff']
    stimoff_ = stimoff / second
    stimdur = stimoff - stimon
    smoothwin = sim_cfg['smoothwin']
    nummethod = sim_cfg['nummethod']

    # -------------------------------------
    # Construct hierarchical network
    # -------------------------------------
    # set connection seed: a specific seed reproduces the same network,
    # i.e. the same synapses
    seed(seedcon)

    # decision circuit
    Dgroups, Dsynapses, Dsubgroups = mk_intcircuit(task_info)
    decE = Dgroups['DE']
    decI = Dgroups['DI']
    decE1 = Dsubgroups['DE1']
    decE2 = Dsubgroups['DE2']

    # sensory circuit, ff and fb connections
    # NOTE: 'eps', 'wSD' and 'wDS' are referenced by name in the string
    # expressions passed to Brian2 below, so they must stay local variables
    # of this function.
    eps = 0.2  # connection probability
    d = 1 * ms  # transmission delays of E synapses
    if sim_cfg['2cmodel']:
        if sim_cfg['plasticdend']:
            # plasticity rule in dendrites --> FB synapses will be removed
            # from the network!
            Sgroups, Ssynapses, Ssubgroups = mk_sencircuit_2cplastic(task_info)
        else:
            # 2c model (Naud)
            Sgroups, Ssynapses, Ssubgroups = mk_sencircuit_2c(task_info)
        senE = Sgroups['soma']
        dend = Sgroups['dend']
        senI = Sgroups['SI']
        senE1 = Ssubgroups['soma1']
        senE2 = Ssubgroups['soma2']
        dend1 = Ssubgroups['dend1']
        dend2 = Ssubgroups['dend2']

        # FB: feedback targets the dendritic compartments
        wDS = 0.003  # synaptic weight of FB synapses, 0.0668 nS when scaled by gleakE of sencircuit_2c
        synDE1SE1 = Synapses(decE1, dend1, model='w : 1', method=nummethod,
                             on_pre='x_ea += w', delay=d)
        synDE2SE2 = Synapses(decE2, dend2, model='w : 1', method=nummethod,
                             on_pre='x_ea += w', delay=d)
    else:
        # normal sensory circuit (Wimmer)
        Sgroups, Ssynapses, Ssubgroups = mk_sencircuit(task_info)
        senE = Sgroups['SE']
        senI = Sgroups['SI']
        senE1 = Ssubgroups['SE1']
        senE2 = Ssubgroups['SE2']

        # FB: feedback targets the sensory excitatory cells directly
        wDS = 0.004  # synaptic weight of FB synapses, 0.0668 nS when scaled by gleakE of sencircuit
        synDE1SE1 = Synapses(decE1, senE1, model='w : 1', method=nummethod,
                             on_pre='x_ea += w', delay=d)
        synDE2SE2 = Synapses(decE2, senE2, model='w : 1', method=nummethod,
                             on_pre='x_ea += w', delay=d)

    # feedforward synapses from sensory to integration
    wSD = 0.0036  # synaptic weight of FF synapses, 0.09 nS when scaled by gleakE of intcircuit
    synSE1DE1 = Synapses(senE1, decE1, model='w : 1', method=nummethod,
                         on_pre='g_ea += w', delay=d)
    synSE1DE1.connect(p='eps')
    synSE1DE1.w = 'wSD'
    synSE2DE2 = Synapses(senE2, decE2, model='w : 1', method=nummethod,
                         on_pre='g_ea += w', delay=d)
    synSE2DE2.connect(p='eps')
    synSE2DE2.w = 'wSD'

    # feedback synapses from integration to sensory
    b_fb = task_info['bfb']  # feedback strength, between 0 and 6
    wDS *= b_fb  # scale the FB weight before it is assigned below
    synDE1SE1.connect(p='eps')
    synDE1SE1.w = 'wDS'
    synDE2SE2.connect(p='eps')
    synDE2SE2.w = 'wDS'

    # -------------------------------------
    # Create stimuli
    # -------------------------------------
    # Note that in standalone we need to specify the numpy seed because it is
    # not taken care of by Brian's seed() function!
    if stim_cfg['replicate']:
        # replicated stimuli across iters()
        np.random.seed(task_info['seed'])  # numpy seed for OU process
    else:
        # every trial has different stimuli
        np.random.seed()

    if sim_cfg['2cmodel']:
        I0 = stim_cfg['I0s']
        last_muOUd = np.loadtxt("last_muOUd.csv")  # saved mean, used to initialise dend.muOUd
    else:
        I0 = stim_cfg['I0']  # mean input current for zero-coherence stim
    c = task_info['c']  # stim coherence (between 0 and 1)
    mu1 = stim_cfg['mu1']  # av. additional input current to senE1 at highest coherence (c=1)
    mu2 = stim_cfg['mu2']  # av. additional input current to senE2 at highest coherence (c=1)
    sigma = stim_cfg['sigma']  # amplitude of temporal modulations of stim
    sigmastim = 0.212 * sigma  # std of modulation of stim inputs
    sigmaind = 0.212 * sigma  # std of modulations in individual inputs
    taustim = stim_cfg['taustim']  # correlation time constant of Ornstein-Uhlenbeck process

    # generate stim from OU process
    N_stim = len(senE1)
    z1, z2, zk1, zk2 = generate_stim(N_stim, stimdur, taustim)

    # stim2exc
    i1 = I0 * (1 + c * mu1 + sigmastim * z1 + sigmaind * zk1)
    i2 = I0 * (1 + c * mu2 + sigmastim * z2 + sigmaind * zk2)

    # zero-pad before stimulus onset and after stimulus offset, in units of
    # the stimulus time step
    stim_dt = 1 * ms
    pre_pad = np.zeros((int(stimon / stim_dt), N_stim))
    post_pad = np.zeros((int((runtime - stimoff) / stim_dt), N_stim))
    i1t = np.concatenate((pre_pad, i1.T, post_pad), axis=0)
    i2t = np.concatenate((pre_pad, i2.T, post_pad), axis=0)
    # NOTE(review): Irec is not used again in this function; presumably the
    # sensory-circuit equations reference it by name and Brian2 resolves it
    # from this frame at run time -- keep the local name.
    Irec = TimedArray(np.concatenate((i1t, i2t), axis=1) * amp, dt=stim_dt)

    # Plain-array views of the stimulus for the plotting helpers. Computed
    # unconditionally because the burst analysis needs them even when
    # 'pltfig1' is off.
    stim1 = i1t.T
    stim2 = i2t.T
    stimtime = np.linspace(0, runtime_, stim1.shape[1])

    # -------------------------------------
    # Simulation
    # -------------------------------------
    # set initial conditions (different for every trial)
    seed()
    decE.g_ea = '0.2 * rand()'
    decI.g_ea = '0.2 * rand()'
    # random initialization near 0, prevent an early decision!
    decE.V = '-52*mV + 2*mV * rand()'
    decI.V = '-52*mV + 2*mV * rand()'
    senE.g_ea = '0.05 * (1 + 0.2*rand())'
    senI.g_ea = '0.05 * (1 + 0.2*rand())'
    # random initialization near Vt, avoid initial bump!
    senE.V = '-52*mV + 2*mV*rand()'
    senI.V = '-52*mV + 2*mV*rand()'

    if sim_cfg['2cmodel']:
        dend.g_ea = '0.05 * (1 + 0.2*rand())'
        dend.V_d = '-72*mV + 2*mV*rand()'
        dend.muOUd = np.tile(last_muOUd, 2) * amp

    # create monitors
    rateDE1 = PopulationRateMonitor(decE1)
    rateDE2 = PopulationRateMonitor(decE2)
    rateSE1 = PopulationRateMonitor(senE1)
    rateSE2 = PopulationRateMonitor(senE2)
    subSE = len(senE1)
    # last 100 of SE1 and first 100 of SE2
    spksSE = SpikeMonitor(senE[subSE - 100:subSE + 100])

    # construct network
    net = Network(Dgroups.values(), Dsynapses.values(),
                  Sgroups.values(), Ssynapses.values(),
                  synSE1DE1, synSE2DE2, synDE1SE1, synDE2SE2,
                  rateDE1, rateDE2, rateSE1, rateSE2, spksSE,
                  name='hierarchicalnet')

    # create more monitors for plot
    if sim_cfg['pltfig1']:
        # inh
        rateDI = PopulationRateMonitor(decI)
        rateSI = PopulationRateMonitor(senI)

        # spk monitors (spksSE now records the whole sensory E population)
        subDE = 2 * len(decE1)
        spksDE = SpikeMonitor(decE[:subDE])
        spksSE = SpikeMonitor(senE)

        # rebuild the network with the extended set of monitors
        net = Network(Dgroups.values(), Dsynapses.values(),
                      Sgroups.values(), Ssynapses.values(),
                      synSE1DE1, synSE2DE2, synDE1SE1, synDE2SE2,
                      spksDE, rateDE1, rateDE2, rateDI,
                      spksSE, rateSE1, rateSE2, rateSI,
                      name='hierarchicalnet')

    # The dendritic state monitor only exists for the plastic-dendrite model;
    # None marks "not recorded" for the analysis below.
    dend_mon = None
    if sim_cfg['plasticdend']:
        # create state monitor to follow muOUd and add it to the network
        dend_mon = StateMonitor(dend1, variables=['muOUd', 'Ibg', 'g_ea', 'B'],
                                record=True, dt=1 * ms)
        net.add(dend_mon)
        # remove FB synapses!
        net.remove([synDE1SE1, synDE2SE2, Dsynapses.values()])
        print(
            " FB synapses and synapses of decision circuit are ignored in this simulation!"
        )

    # run hierarchical net
    net.run(runtime, report='stdout', profile=True)
    print(profiling_summary(net=net, show=10))

    # nice plots on cluster
    if sim_cfg['pltfig1']:
        plot_fig1b([
            rateDE1, rateDE2, rateDI, spksDE, rateSE1, rateSE2, rateSI, spksSE,
            stim1, stim2, stimtime
        ], smoothwin, taskdir)

    # -------------------------------------
    # Burst quantification
    # -------------------------------------
    # Placeholders returned when the burst analysis is skipped. 'isis' is
    # included because the results dict reads it unconditionally.
    events = np.zeros(1)
    bursts = np.zeros(1)
    singles = np.zeros(1)
    spikes = np.zeros(1)
    isis = np.zeros(1)
    last_muOUd = np.zeros(1)

    # neurometric params
    validburst = task_info['sen']['2c']['validburst']
    smoothwin_ = smoothwin / second

    if sim_cfg['burstanalysis']:
        if sim_cfg['2cmodel'] and dend_mon is not None:
            # mean of muOUd over the last 1000 recorded samples, per neuron
            last_muOUd = np.array(dend_mon.muOUd[:, -int(1e3):].mean(axis=1))

        if sim_cfg['plasticdend']:
            # calculate neurometric info per population
            events, bursts, singles, spikes, isis = spks2neurometric(
                spksSE, runtime, settletime, validburst,
                smoothwin=smoothwin_, raster=False)

            # plot & save weights after convergence
            eta0 = task_info['sen']['2c']['eta0']
            tauB = task_info['sen']['2c']['tauB']
            targetB = task_info['targetB']
            B0 = tauB * targetB
            tau_update = task_info['sen']['2c']['tau_update']
            eta = eta0 * tau_update / tauB
            plot_weights(dend_mon, events, bursts, spikes,
                         [targetB, B0, eta, tauB, tau_update, smoothwin_],
                         taskdir)
            plot_rasters(spksSE, bursts, targetB, isis, runtime_, taskdir)
        else:
            # calculate neurometric per neuron
            events, bursts, singles, spikes, isis = spks2neurometric(
                spksSE, runtime, settletime, validburst,
                smoothwin=smoothwin_, raster=True)
            plot_neurometric(events, bursts, spikes, stim1, stim2, stimtime,
                             (settletime_, runtime_), taskdir, smoothwin_)
            plot_isis(isis, bursts, events, (settletime_, runtime_), taskdir)

    # -------------------------------------
    # Choice selection
    # -------------------------------------
    # population rates and downsample
    originaltime = rateDE1.t / second
    # one sample every 10 ms; np.linspace requires an integer sample count
    interptime = np.linspace(0, originaltime[-1], int(originaltime[-1] * 100))
    monitors = (rateDE1, rateDE2, rateSE1, rateSE2)
    fDE1, fDE2, fSE1, fSE2 = [
        interpolate.interp1d(originaltime,
                             mon.smooth_rate(window='flat', width=smoothwin))
        for mon in monitors
    ]
    rateDE = np.array([f(interptime) for f in (fDE1, fDE2)])
    rateSE = np.array([f(interptime) for f in (fSE1, fSE2)])

    # select the last half second of the stimulus
    newdt = runtime_ / rateDE.shape[1]
    settletimeidx = int(settletime_ / newdt)
    dec_ival = np.array([(stimoff_ - 0.5) / newdt, stimoff_ / newdt], dtype=int)
    who_wins = rateDE[:, dec_ival[0]:dec_ival[1]].mean(axis=1)

    # divide trls into preferred and non-preferred
    pref_msk = np.argmax(who_wins)
    # index of the other population (there are exactly two rows)
    npref_idx = 1 - pref_msk
    poprates_dec = np.array([rateDE[pref_msk], rateDE[npref_idx]])  # 0: pref, 1: npref
    poprates_sen = np.array([rateSE[pref_msk], rateSE[npref_idx]])

    results = {
        'raw_data': {
            'poprates_dec': poprates_dec[:, settletimeidx:],
            'poprates_sen': poprates_sen[:, settletimeidx:],
            'pref_msk': np.array([pref_msk]),
            'last_muOUd': last_muOUd,
        },
        'sim_state': np.zeros(1),
        'computed': {
            'events': events,
            'bursts': bursts,
            'singles': singles,
            'spikes': spikes,
            'isis': np.array(isis),
        },
    }

    return results