def simulate_LIF_neuron(input_current,
                        simulation_time=5. * b2.ms,
                        dt=0.01,
                        v_rest=-70 * b2.mV,
                        v_reset=-65 * b2.mV,
                        firing_threshold=-50 * b2.mV,
                        membrane_resistance=10. * b2.Mohm,
                        membrane_time_scale=8. * b2.ms,
                        abs_refractory_period=2.0 * b2.ms):
    """Simulate a single leaky integrate-and-fire neuron with profiling on.

    The neuron is placed in an explicit ``Network`` together with a spike
    monitor, run for ``simulation_time`` with ``profile=True``, and the
    device is then built into the ``output`` directory.  A profiling
    summary is printed before returning.

    Args:
        input_current: Input current referenced by name in the model
            equation (it is multiplied by ``membrane_resistance``).
            NOTE(review): used as a plain identifier, not as
            ``input_current(t, i)``, so presumably a constant current
            quantity rather than a TimedArray -- confirm with callers.
        simulation_time: Duration handed to ``Network.run``.
        dt: Time step as a plain number; it is multiplied by ``b2.ms``
            and assigned to ``b2.defaultclock.dt``.
        v_rest: Resting potential; also the initial membrane potential.
        v_reset: Value ``v`` is set to after a threshold crossing.
        firing_threshold: Threshold used in the condition
            ``v>firing_threshold``.
        membrane_resistance: Factor applied to ``input_current``.
        membrane_time_scale: Membrane time constant of the equation.
        abs_refractory_period: Passed as ``refractory`` to the group.

    Returns:
        tuple: ``(spike_monitor, network.get_profiling_info())``.
    """
    b2.defaultclock.dt = dt * b2.ms

    # Leaky integrate-and-fire dynamics.  The identifiers in the strings
    # below are not passed in explicitly; presumably Brian 2 resolves them
    # from this function's local variables when the network is run, so the
    # parameter names above must not be renamed.
    eqs = """
    dv/dt = ( -(v-v_rest) + membrane_resistance * input_current ) / membrane_time_scale : volt (unless refractory)
    """

    neuron = b2.NeuronGroup(1,
                            model=eqs,
                            reset="v=v_reset",
                            threshold="v>firing_threshold",
                            refractory=abs_refractory_period,
                            method="exact")  # "euler" / "exact"
    neuron.v = v_rest  # set initial value

    network = b2.core.network.Network(neuron)
    spike_monitor = b2.SpikeMonitor(neuron)
    network.add(spike_monitor)

    network.run(simulation_time, profile=True)

    b2.device.build(directory='output',
                    clean=True,
                    compile=True,
                    run=True,
                    debug=False)

    print("\n")
    print("brian2 profiling summary (listed by time consumption):\n")
    # The simulation ran in an explicit Network, so that network must be
    # passed here; without an argument the summary would be taken from the
    # magic network, which was never run in this function.
    print(b2.profiling_summary(network))
    return spike_monitor, network.get_profiling_info()
def test_profile_ipython_html():
    """Check that the HTML form of a profiling summary is non-empty.

    A ten-neuron group is run for 1 ms with profiling enabled inside an
    explicit network; the summary built for that network must produce a
    non-empty result from ``_repr_html_()``.
    """
    group = NeuronGroup(
        10,
        'dv/dt = -v / (10*ms) : 1',
        threshold='v>1',
        reset='v=0',
        name='profile_test',
    )
    # Start above the threshold of 1.
    group.v = 1.1

    network = Network(group)
    network.run(1*ms, profile=True)

    html = profiling_summary(network)._repr_html_()
    assert len(html)
def run_net(tr):
    """Build, run and post-process one network simulation (C++ standalone).

    The function assembles excitatory and inhibitory neuron groups,
    optional Poisson input, the E<-E plastic synapses plus three simple
    synapse groups, and a set of monitors.  It then schedules five
    consecutive runs (T1..T5), builds and executes the standalone project
    in ``builds/<v_idx>``, pickles the monitor contents into
    ``builds/<v_idx>/raw/``, optionally computes cross-correlations,
    removes ``builds/<v_idx>/results/`` and calls the plotting routines.

    Args:
        tr: Parameter container for this run.  Every setting is read as
            an attribute (``tr.N_e``, ``tr.T1``, ``tr.external_mode`` ...),
            connection lists are stored back via ``tr.f_add_result`` and
            ``tr.v_idx`` selects the build directory.
            NOTE(review): presumably a pypet trajectory -- confirm.

    Raises:
        ValueError: If ``tr.external_mode`` is neither ``'memnoise'`` nor
            ``'poisson'``.
        AssertionError: If any sampled ``ATotalMax`` value is not
            positive.
    """
    # prefs.codegen.target = 'numpy'
    # prefs.codegen.target = 'cython'
    if tr.n_threads > 1:
        prefs.devices.cpp_standalone.openmp_threads = tr.n_threads

    # build_on_run=False: the runs below are only scheduled; the project
    # is built and executed by the explicit device.build() call further
    # down.
    set_device('cpp_standalone',
               directory='./builds/%.4d' % (tr.v_idx),
               build_on_run=False)

    print("Started process with id ", str(tr.v_idx))

    T = tr.T1 + tr.T2 + tr.T3 + tr.T4 + tr.T5

    namespace = tr.netw.f_to_dict(short_names=True, fast_access=True)
    namespace['idx'] = tr.v_idx

    defaultclock.dt = tr.netw.sim.dt

    # collect all network components dependent on configuration
    # (e.g. poisson vs. memnoise) and add them to the Brian 2
    # network object later
    netw_objects = []

    if tr.external_mode == 'memnoise':
        neuron_model = tr.condlif_memnoise
    elif tr.external_mode == 'poisson':
        neuron_model = tr.condlif_poisson
    else:
        # Previously this fell through and failed later with an unbound
        # ``neuron_model``; fail early with a clear message instead.
        raise ValueError(
            "unknown external_mode %r (expected 'memnoise' or 'poisson')"
            % (tr.external_mode,))

    GExc = NeuronGroup(
        N=tr.N_e,
        model=neuron_model,
        threshold=tr.nrnEE_thrshld,
        reset=tr.nrnEE_reset,  #method=tr.neuron_method,
        namespace=namespace)
    GInh = NeuronGroup(
        N=tr.N_i,
        model=neuron_model,
        threshold='V > Vt',
        reset='V=Vr_i',  #method=tr.neuron_method,
        namespace=namespace)

    if tr.external_mode == 'memnoise':
        GExc.mu, GInh.mu = tr.mu_e, tr.mu_i
        GExc.sigma, GInh.sigma = tr.sigma_e, tr.sigma_i

    GExc.Vt, GInh.Vt = tr.Vt_e, tr.Vt_i

    # Initial membrane potentials are drawn uniformly between the
    # population's Vr and Vt values.
    GExc.V , GInh.V = np.random.uniform(tr.Vr_e/mV, tr.Vt_e/mV,
                                        size=tr.N_e)*mV, \
                      np.random.uniform(tr.Vr_i/mV, tr.Vt_i/mV,
                                        size=tr.N_i)*mV

    netw_objects.extend([GExc, GInh])

    synEE_pre_mod = mod.synEE_pre
    synEE_post_mod = mod.synEE_post

    if tr.external_mode == 'poisson':

        # --- Poisson input onto the excitatory population ---
        if tr.PInp_mode == 'pool':
            PInp = PoissonGroup(tr.NPInp,
                                rates=tr.PInp_rate,
                                namespace=namespace,
                                name='poissongroup_exc')
            sPN = Synapses(target=GExc,
                           source=PInp,
                           model=tr.poisson_mod,
                           on_pre='gfwd_post += a_EPoi',
                           namespace=namespace,
                           name='synPInpExc')
            sPN_src, sPN_tar = generate_N_connections(N_tar=tr.N_e,
                                                      N_src=tr.NPInp,
                                                      N=tr.NPInp_1n)
        elif tr.PInp_mode == 'indep':
            PInp = PoissonGroup(tr.N_e,
                                rates=tr.PInp_rate,
                                namespace=namespace)
            sPN = Synapses(target=GExc,
                           source=PInp,
                           model=tr.poisson_mod,
                           on_pre='gfwd_post += a_EPoi',
                           namespace=namespace,
                           name='synPInp_inhInh')
            # one-to-one: source k projects to target k
            sPN_src, sPN_tar = range(tr.N_e), range(tr.N_e)

        sPN.connect(i=sPN_src, j=sPN_tar)

        # --- Poisson input onto the inhibitory population ---
        if tr.PInp_mode == 'pool':
            PInp_inh = PoissonGroup(tr.NPInp_inh,
                                    rates=tr.PInp_inh_rate,
                                    namespace=namespace,
                                    name='poissongroup_inh')
            sPNInh = Synapses(target=GInh,
                              source=PInp_inh,
                              model=tr.poisson_mod,
                              on_pre='gfwd_post += a_EPoi',
                              namespace=namespace)
            sPNInh_src, sPNInh_tar = generate_N_connections(
                N_tar=tr.N_i, N_src=tr.NPInp_inh, N=tr.NPInp_inh_1n)
        elif tr.PInp_mode == 'indep':
            PInp_inh = PoissonGroup(tr.N_i,
                                    rates=tr.PInp_inh_rate,
                                    namespace=namespace)
            sPNInh = Synapses(target=GInh,
                              source=PInp_inh,
                              model=tr.poisson_mod,
                              on_pre='gfwd_post += a_EPoi',
                              namespace=namespace)
            sPNInh_src, sPNInh_tar = range(tr.N_i), range(tr.N_i)

        sPNInh.connect(i=sPNInh_src, j=sPNInh_tar)

        netw_objects.extend([PInp, sPN, PInp_inh, sPNInh])

    # The E<-E model and its pre/post code are assembled from fragments.
    # Fragments are joined with a newline so that the last statement of
    # one fragment and the first statement of the next stay on separate
    # lines.
    if tr.syn_noise:
        synEE_mod = '%s\n%s' % (tr.synEE_noise, tr.synEE_mod)
    else:
        synEE_mod = '%s\n%s' % (tr.synEE_static, tr.synEE_mod)

    if tr.stdp_active:
        synEE_pre_mod = '%s\n%s' % (synEE_pre_mod, mod.synEE_pre_STDP)
        synEE_post_mod = '%s\n%s' % (synEE_post_mod, mod.synEE_post_STDP)

    if tr.synEE_rec:
        synEE_pre_mod = '%s\n%s' % (synEE_pre_mod, mod.synEE_pre_rec)
        synEE_post_mod = '%s\n%s' % (synEE_post_mod, mod.synEE_post_rec)

    # E<-E advanced synapse model, rest simple
    SynEE = Synapses(target=GExc,
                     source=GExc,
                     model=synEE_mod,
                     on_pre=synEE_pre_mod,
                     on_post=synEE_post_mod,
                     namespace=namespace,
                     dt=tr.synEE_mod_dt)
    SynIE = Synapses(target=GInh,
                     source=GExc,
                     on_pre='ge_post += a_ie',
                     namespace=namespace)
    SynEI = Synapses(target=GExc,
                     source=GInh,
                     on_pre='gi_post += a_ei',
                     namespace=namespace)
    SynII = Synapses(target=GInh,
                     source=GInh,
                     on_pre='gi_post += a_ii',
                     namespace=namespace)

    # The former ``if tr.strct_active`` / ``else`` branches performed the
    # same steps, so they are merged: connect E<-E using the index lists
    # from generate_full_connectivity and start with every synapse
    # inactive.
    sEE_src, sEE_tar = generate_full_connectivity(tr.N_e, same=True)
    SynEE.connect(i=sEE_src, j=sEE_tar)
    SynEE.syn_active = 0

    sIE_src, sIE_tar = generate_connections(tr.N_i, tr.N_e, tr.p_ie)
    sEI_src, sEI_tar = generate_connections(tr.N_e, tr.N_i, tr.p_ei)
    sII_src, sII_tar = generate_connections(tr.N_i, tr.N_i, tr.p_ii,
                                            same=True)

    SynIE.connect(i=sIE_src, j=sIE_tar)
    SynEI.connect(i=sEI_src, j=sEI_tar)
    SynII.connect(i=sII_src, j=sII_tar)

    # store the generated connectivity with the run's results
    tr.f_add_result('sIE_src', sIE_src)
    tr.f_add_result('sIE_tar', sIE_tar)
    tr.f_add_result('sEI_src', sEI_src)
    tr.f_add_result('sEI_tar', sEI_tar)
    tr.f_add_result('sII_src', sII_src)
    tr.f_add_result('sII_tar', sII_tar)

    if tr.syn_noise:
        SynEE.syn_sigma = tr.syn_sigma

    SynEE.insert_P = tr.insert_P
    SynEE.p_inactivate = tr.p_inactivate
    SynEE.stdp_active = 1

    # one value per E<-E synapse: N_e * (N_e - 1) entries
    ATM_vals = np.random.normal(loc=tr.ATotalMax,
                                scale=tr.ATotalMax_sd,
                                size=tr.N_e * (tr.N_e - 1))
    assert np.min(ATM_vals) > 0.
    SynEE.ATotalMax = ATM_vals

    # make randomly chosen synapses active at beginning
    rs = np.random.uniform(size=tr.N_e * (tr.N_e - 1))
    initial_active = (rs < tr.p_ee).astype('int')
    initial_a = initial_active * tr.a_ee
    SynEE.syn_active = initial_active
    SynEE.a = initial_a

    # recording of stdp in T4
    SynEE.stdp_rec_start = tr.T1 + tr.T2 + tr.T3
    SynEE.stdp_rec_max = tr.T1 + tr.T2 + tr.T3 + tr.stdp_rec_T

    # synaptic scaling
    # ``synscaling`` stays None when scaling is disabled so that the
    # "freeze" step before T5 can skip it instead of raising NameError.
    synscaling = None
    if tr.netw.config.scl_active:

        if tr.syn_scl_rec:
            SynEE.scl_rec_start = tr.T1 + tr.T2 + tr.T3
            SynEE.scl_rec_max = tr.T1 + tr.T2 + tr.T3 + tr.scl_rec_T
        else:
            # start lies beyond the total run time T (and after the max);
            # presumably this keeps scaling recording from ever
            # triggering -- confirm against the synEE_scaling code
            SynEE.scl_rec_start = T + 10 * second
            SynEE.scl_rec_max = T

        # give the 'Asum_post' summed-variable updater a clock with the
        # same dt as the scaling operation (sets a private attribute)
        SynEE.summed_updaters['Asum_post']._clock = Clock(
            dt=tr.dt_synEE_scaling)
        synscaling = SynEE.run_regularly(tr.synEE_scaling,
                                         dt=tr.dt_synEE_scaling,
                                         when='end',
                                         name='syn_scaling')

    # # intrinsic plasticity
    # if tr.netw.config.it_active:
    #     GExc.h_ip = tr.h_ip
    #     GExc.run_regularly(tr.intrinsic_mod, dt = tr.it_dt, when='end')

    # structural plasticity
    # ``strctplst`` stays None when structural plasticity is disabled (or
    # the mode is neither 'zero' nor 'thrs'); see the guard before T5.
    strctplst = None
    if tr.netw.config.strct_active:
        if tr.strct_mode == 'zero':
            if tr.turnover_rec:
                strct_mod = '%s\n%s' % (tr.strct_mod, tr.turnover_rec_mod)
            else:
                strct_mod = tr.strct_mod

            strctplst = SynEE.run_regularly(strct_mod,
                                            dt=tr.strct_dt,
                                            when='end',
                                            name='strct_plst_zero')

        elif tr.strct_mode == 'thrs':
            if tr.turnover_rec:
                strct_mod_thrs = '%s\n%s' % (tr.strct_mod_thrs,
                                             tr.turnover_rec_mod)
            else:
                strct_mod_thrs = tr.strct_mod_thrs

            strctplst = SynEE.run_regularly(strct_mod_thrs,
                                            dt=tr.strct_dt,
                                            when='end',
                                            name='strct_plst_thrs')

    netw_objects.extend([SynEE, SynEI, SynIE, SynII])

    # keep track of the number of active synapses
    # ``c`` on the single-neuron target group is a summed variable: each
    # SynEE synapse contributes syn_active / NSyn.
    sum_target = NeuronGroup(1, 'c : 1 (shared)', dt=tr.csample_dt)

    sum_model = '''NSyn : 1 (constant)
                   c_post = (1.0*syn_active_pre)/NSyn : 1 (summed)'''
    sum_connection = Synapses(target=sum_target,
                              source=SynEE,
                              model=sum_model,
                              dt=tr.csample_dt,
                              name='get_active_synapse_count')
    sum_connection.connect()
    sum_connection.NSyn = tr.N_e * (tr.N_e - 1)

    if tr.adjust_insertP:
        # homeostatically adjust growth rate
        growth_updater = Synapses(sum_target, SynEE)
        growth_updater.run_regularly('insert_P_post *= 0.1/c_pre',
                                     when='after_groups',
                                     dt=tr.csample_dt,
                                     name='update_insP')
        growth_updater.connect(j='0')

        netw_objects.extend([sum_target, sum_connection, growth_updater])

    # -------------- recording ------------------

    GExc_recvars = []
    if tr.memtraces_rec:
        GExc_recvars.append('V')
    if tr.vttraces_rec:
        GExc_recvars.append('Vt')
    if tr.getraces_rec:
        GExc_recvars.append('ge')
    if tr.gitraces_rec:
        GExc_recvars.append('gi')
    if tr.gfwdtraces_rec and tr.external_mode == 'poisson':
        GExc_recvars.append('gfwd')

    # same list object: both populations record the same variables
    GInh_recvars = GExc_recvars

    GExc_stat = StateMonitor(GExc,
                             GExc_recvars,
                             record=[0, 1, 2],
                             dt=tr.GExc_stat_dt)
    GInh_stat = StateMonitor(GInh,
                             GInh_recvars,
                             record=[0, 1, 2],
                             dt=tr.GInh_stat_dt)

    SynEE_recvars = []
    if tr.synee_atraces_rec:
        SynEE_recvars.append('a')
    if tr.synee_activetraces_rec:
        SynEE_recvars.append('syn_active')
    if tr.synee_Apretraces_rec:
        SynEE_recvars.append('Apre')
    if tr.synee_Aposttraces_rec:
        SynEE_recvars.append('Apost')

    SynEE_stat = StateMonitor(SynEE,
                              SynEE_recvars,
                              record=range(tr.n_synee_traces_rec),
                              when='end',
                              dt=tr.synEE_stat_dt)

    if tr.adjust_insertP:
        C_stat = StateMonitor(sum_target,
                              'c',
                              dt=tr.csample_dt,
                              record=[0],
                              when='end')
        insP_stat = StateMonitor(SynEE,
                                 'insert_P',
                                 dt=tr.csample_dt,
                                 record=[0],
                                 when='end')
        netw_objects.extend([C_stat, insP_stat])

    GExc_spks = SpikeMonitor(GExc)
    GInh_spks = SpikeMonitor(GInh)

    GExc_rate = PopulationRateMonitor(GExc)
    GInh_rate = PopulationRateMonitor(GInh)

    if tr.external_mode == 'poisson':
        PInp_spks = SpikeMonitor(PInp)
        PInp_rate = PopulationRateMonitor(PInp)
        netw_objects.extend([PInp_spks, PInp_rate])

    if tr.synee_a_nrecpoints == 0:
        # sampling interval longer than T2 itself
        SynEE_a_dt = 10 * tr.sim.T2
    else:
        SynEE_a_dt = tr.sim.T2 / tr.synee_a_nrecpoints

    SynEE_a = StateMonitor(SynEE, ['a', 'syn_active'],
                           record=range(tr.N_e * (tr.N_e - 1)),
                           dt=SynEE_a_dt,
                           when='end',
                           order=100)

    netw_objects.extend([
        GExc_stat, GInh_stat, SynEE_stat, SynEE_a, GExc_spks, GInh_spks,
        GExc_rate, GInh_rate
    ])

    net = Network(*netw_objects)

    def set_active(*argv):
        for net_object in argv:
            net_object.active = True

    def set_inactive(*argv):
        for net_object in argv:
            net_object.active = False

    ### Simulation periods

    # --------- T1 ---------
    # initial recording period,
    # all recorders active

    set_active(GExc_spks, GInh_spks, SynEE_stat, GExc_stat, GInh_stat,
               GExc_rate, GInh_rate)

    if tr.external_mode == 'poisson':
        set_active(PInp_spks, PInp_rate)

    net.run(tr.sim.T1,
            report='text',
            report_period=300 * second,
            profile=True)

    # --------- T2 ---------
    # main simulation period
    # only active recordings are:
    #   1) turnover 2) C_stat 3) SynEE_a

    set_inactive(GExc_spks, GInh_spks, SynEE_stat, GExc_stat, GInh_stat,
                 GExc_rate, GInh_rate)

    if tr.external_mode == 'poisson':
        set_inactive(PInp_spks, PInp_rate)

    net.run(tr.sim.T2,
            report='text',
            report_period=300 * second,
            profile=True)

    # --------- T3 ---------
    # second recording period,
    # all recorders active

    set_active(GExc_spks, GInh_spks, SynEE_stat, GExc_stat, GInh_stat,
               GExc_rate, GInh_rate)

    if tr.external_mode == 'poisson':
        set_active(PInp_spks, PInp_rate)

    net.run(tr.sim.T3,
            report='text',
            report_period=300 * second,
            profile=True)

    # --------- T4 ---------
    # record STDP and scaling weight changes to file
    # through the cpp models

    set_inactive(GExc_spks, GInh_spks, SynEE_stat, GExc_stat, GInh_stat,
                 GExc_rate, GInh_rate)

    if tr.external_mode == 'poisson':
        set_inactive(PInp_spks, PInp_rate)

    net.run(tr.sim.T4,
            report='text',
            report_period=300 * second,
            profile=True)

    # --------- T5 ---------
    # freeze network and record Exc spikes
    # for cross correlations

    # Only switch off the plasticity operations that were created above.
    if synscaling is not None:
        synscaling.active = False
    if strctplst is not None:
        strctplst.active = False
    SynEE.stdp_active = 0

    set_active(GExc_spks)

    net.run(tr.sim.T5,
            report='text',
            report_period=300 * second,
            profile=True)

    SynEE_a.record_single_timestep()

    device.build(directory='builds/%.4d' % (tr.v_idx),
                 clean=True,
                 compile=True,
                 run=True,
                 debug=False)

    # -----------------------------------------

    # save monitors as raws in build directory
    raw_dir = 'builds/%.4d/raw/' % (tr.v_idx)

    if not os.path.exists(raw_dir):
        os.makedirs(raw_dir)

    with open(raw_dir + 'namespace.p', 'wb') as pfile:
        pickle.dump(namespace, pfile)

    with open(raw_dir + 'gexc_stat.p', 'wb') as pfile:
        pickle.dump(GExc_stat.get_states(), pfile)
    with open(raw_dir + 'ginh_stat.p', 'wb') as pfile:
        pickle.dump(GInh_stat.get_states(), pfile)

    with open(raw_dir + 'synee_stat.p', 'wb') as pfile:
        pickle.dump(SynEE_stat.get_states(), pfile)

    with open(raw_dir + 'synee_a.p', 'wb') as pfile:
        SynEE_a_states = SynEE_a.get_states()
        if tr.crs_crrs_rec:
            # the cross-correlation step below needs the pre/post index
            # of every synapse
            SynEE_a_states['i'] = list(SynEE.i)
            SynEE_a_states['j'] = list(SynEE.j)
        pickle.dump(SynEE_a_states, pfile)

    if tr.adjust_insertP:
        with open(raw_dir + 'c_stat.p', 'wb') as pfile:
            pickle.dump(C_stat.get_states(), pfile)

        with open(raw_dir + 'insP_stat.p', 'wb') as pfile:
            pickle.dump(insP_stat.get_states(), pfile)

    with open(raw_dir + 'gexc_spks.p', 'wb') as pfile:
        pickle.dump(GExc_spks.get_states(), pfile)
    with open(raw_dir + 'ginh_spks.p', 'wb') as pfile:
        pickle.dump(GInh_spks.get_states(), pfile)

    if tr.external_mode == 'poisson':
        with open(raw_dir + 'pinp_spks.p', 'wb') as pfile:
            pickle.dump(PInp_spks.get_states(), pfile)

    # Rate files hold the monitor states and, when rates_rec is set, a
    # second pickled object (the smoothed rate) in the same file.
    with open(raw_dir + 'gexc_rate.p', 'wb') as pfile:
        pickle.dump(GExc_rate.get_states(), pfile)
        if tr.rates_rec:
            pickle.dump(GExc_rate.smooth_rate(width=25 * ms), pfile)

    with open(raw_dir + 'ginh_rate.p', 'wb') as pfile:
        pickle.dump(GInh_rate.get_states(), pfile)
        if tr.rates_rec:
            pickle.dump(GInh_rate.smooth_rate(width=25 * ms), pfile)

    if tr.external_mode == 'poisson':
        with open(raw_dir + 'pinp_rate.p', 'wb') as pfile:
            pickle.dump(PInp_rate.get_states(), pfile)
            if tr.rates_rec:
                pickle.dump(PInp_rate.smooth_rate(width=25 * ms), pfile)

    # ----------------- add raw data ------------------------
    fpath = 'builds/%.4d/' % (tr.v_idx)

    from pathlib import Path

    # Each comma-separated text file in the build directory is converted
    # to a pickle in raw/ and then deleted.  touch() creates the file
    # first if it is missing, so genfromtxt always has something to open.
    for raw_name in ('turnover', 'spk_register', 'scaling_deltas'):
        Path(fpath + raw_name).touch()
        raw_table = np.genfromtxt(fpath + raw_name, delimiter=',')
        os.remove(fpath + raw_name)

        with open(raw_dir + raw_name + '.p', 'wb') as pfile:
            pickle.dump(raw_table, pfile)

    with open(raw_dir + 'profiling_summary.txt', 'w+') as tfile:
        tfile.write(str(profiling_summary(net)))

    # --------------- cross-correlations ---------------------

    if tr.crs_crrs_rec:

        GExc_spks = GExc_spks.get_states()
        synee_a = SynEE_a_states
        wsize = 100 * pq.ms

        for binsize in [1 * pq.ms, 2 * pq.ms, 5 * pq.ms]:

            wlen = int(wsize / binsize)

            # keep only spikes from the T5 period and shift them so that
            # T5 starts at time zero
            ts, idxs = GExc_spks['t'], GExc_spks['i']
            idxs = idxs[ts > tr.T1 + tr.T2 + tr.T3 + tr.T4]
            ts = ts[ts > tr.T1 + tr.T2 + tr.T3 + tr.T4]
            ts = ts - (tr.T1 + tr.T2 + tr.T3 + tr.T4)

            sts = [
                neo.SpikeTrain(ts[idxs == i] / second * pq.s,
                               t_stop=tr.T5 / second * pq.s)
                for i in range(tr.N_e)
            ]

            crs_crrs, syn_a = [], []
            # Stays None if no synapse is active at the last recorded
            # step; previously ``cbin`` was unbound in that case.
            cbin = None

            for f, (i, j) in enumerate(zip(synee_a['i'], synee_a['j'])):
                # only synapses active at the last recorded time step
                if synee_a['syn_active'][-1][f] == 1:

                    crs_crr, cbin = cch(BinnedSpikeTrain(sts[i],
                                                         binsize=binsize),
                                        BinnedSpikeTrain(sts[j],
                                                         binsize=binsize),
                                        cross_corr_coef=True,
                                        border_correction=True,
                                        window=(-1 * wlen, wlen))

                    crs_crrs.append(list(np.array(crs_crr).T[0]))
                    syn_a.append(synee_a['a'][-1][f])

            fname = 'crs_crrs_wsize%dms_binsize%fms_full' % (wsize / pq.ms,
                                                             binsize / pq.ms)

            df = {
                'cbin': cbin,
                'crs_crrs': np.array(crs_crrs),
                'syn_a': np.array(syn_a),
                'binsize': binsize,
                'wsize': wsize,
                'wlen': wlen
            }

            with open('builds/%.4d/raw/' % (tr.v_idx) + fname + '.p',
                      'wb') as pfile:
                pickle.dump(df, pfile)

    # ----------------- clean up ---------------------------
    shutil.rmtree('builds/%.4d/results/' % (tr.v_idx))

    # ---------------- plot results --------------------------

    #os.chdir('./analysis/file_based/')

    from analysis.overview_fb import overview_figure
    overview_figure('builds/%.4d' % (tr.v_idx), namespace)

    from analysis.synw_fb import synw_figure
    synw_figure('builds/%.4d' % (tr.v_idx), namespace)

    from analysis.synw_log_fb import synw_log_figure
    synw_log_figure('builds/%.4d' % (tr.v_idx), namespace)
def simulate_balanced_network(
        input_current,
        simulation_time=1000. * b2.ms,
        dt=0.01,
        n_of_neurons=1000,
        f=3 / 4,
        connection_probability=0.1,
        w0=0.2 * b2.mV,  # J in julia EventBasedIntegrator
        g=5.,
        membrane_resistance=10. * b2.Mohm,
        #poisson_input_rate=13. * b2.Hz,
        v_rest=0. * b2.mV,
        v_reset=10. * b2.mV,
        firing_threshold=20. * b2.mV,
        membrane_time_scale=20. * b2.ms,
        abs_refractory_period=2.0 * b2.ms,
        random_vm_init=False):
    """Simulate a network of excitatory and inhibitory LIF neurons.

    One neuron group is split into an excitatory part (the first
    ``floor(f * n_of_neurons)`` neurons) and an inhibitory part (the
    remainder).  Both parts project onto the whole group with probability
    ``connection_probability``; a presynaptic spike adds ``w0``
    (excitatory) or ``-g * w0`` (inhibitory) to the target's ``v`` after a
    delay of one time step.  The network is run with ``profile=True``,
    the device is built into ``output`` and a profiling summary is
    printed.

    Args:
        input_current: Input current referenced by name in the model
            equation (it is multiplied by ``membrane_resistance``).
            NOTE(review): presumably a constant current quantity --
            confirm with callers.
        simulation_time: Duration handed to ``Network.run``.
        dt: Time step as a plain number; it is multiplied by ``b2.ms``
            and used both for ``b2.defaultclock.dt`` and as the synaptic
            delay.
        n_of_neurons: Total number of neurons.
        f: Fraction of neurons that are excitatory.
        connection_probability: ``p`` argument of both ``connect`` calls.
        w0: Excitatory weight added to ``v`` per presynaptic spike.
        g: Inhibitory weight is ``-g * w0``.
        membrane_resistance: Factor applied to ``input_current``.
        v_rest: Resting potential; initial ``v`` unless
            ``random_vm_init`` is set.
        v_reset: Value ``v`` is set to after a threshold crossing.
        firing_threshold: Threshold used in ``v>firing_threshold``.
        membrane_time_scale: Membrane time constant of the equation.
        abs_refractory_period: Passed as ``refractory`` to the group.
        random_vm_init: If true, initial ``v`` values are drawn uniformly
            between ``v_rest`` and ``firing_threshold``.

    Returns:
        tuple: ``(spike_monitor, network.get_profiling_info())``.
    """
    b2.defaultclock.dt = dt * b2.ms

    N_Excit = int(floor(f * n_of_neurons))
    # The inhibitory population is the remainder.  Flooring
    # ``(1 - f) * n_of_neurons`` separately can drop a neuron through
    # floating-point rounding (f=0.8, n=1000 gave 800 + 199).
    N_Inhib = int(floor(n_of_neurons)) - N_Excit

    # NOTE: J_excit, J_inhib and the parameter names are referenced by
    # name inside the equation/code strings below; presumably Brian 2
    # resolves them from this function's locals at run time, so they must
    # not be renamed.
    J_excit = w0
    J_inhib = -g * w0

    lif_dynamics = """
    dv/dt = ( -(v-v_rest) + membrane_resistance * input_current ) / membrane_time_scale : volt (unless refractory)
    """

    neurons = NeuronGroup(N_Excit + N_Inhib,
                          model=lif_dynamics,
                          threshold="v>firing_threshold",
                          reset="v=v_reset",
                          refractory=abs_refractory_period,
                          method="exact")  # "exact"/"linear"

    if random_vm_init:
        neurons.v = random.uniform(v_rest / b2.mV,
                                   high=firing_threshold / b2.mV,
                                   size=(N_Excit + N_Inhib)) * b2.mV
    else:
        neurons.v = v_rest

    excitatory_population = neurons[:N_Excit]
    inhibitory_population = neurons[N_Excit:]

    exc_synapses = Synapses(excitatory_population,
                            target=neurons,
                            on_pre="v += J_excit",
                            delay=dt * b2.ms)
    exc_synapses.connect(p=connection_probability)

    inhib_synapses = Synapses(inhibitory_population,
                              target=neurons,
                              on_pre="v += J_inhib",
                              delay=dt * b2.ms)
    inhib_synapses.connect(p=connection_probability)

    spike_monitor = b2.SpikeMonitor(neurons)

    # collect() gathers the Brian objects created above into an explicit
    # network.
    network = b2.core.network.Network(b2.core.magic.collect())

    network.run(simulation_time, profile=True)

    b2.device.build(directory='output',
                    clean=True,
                    compile=True,
                    run=True,
                    debug=False)

    print("\n")
    print("brian2 profiling summary (listed by time consumption):\n")
    # The run used the explicit network, so it must be passed here;
    # without an argument the summary would be taken from the magic
    # network, which was never run in this function.
    print(b2.profiling_summary(network))
    return spike_monitor, network.get_profiling_info()
def simulation( test_mode=True, runname=None, num_epochs=None, progress_interval=None, progress_assignments_window=None, progress_accuracy_window=None, record_spikes=False, monitoring=False, permute_data=False, size=400, resume=False, stdp_rule="original", custom_namespace=None, timer=None, tc_theta=None, total_input_weight=None, use_premade_weights=False, supervised=False, feedback=False, profile=False, clock=None, store=None, **kwargs, ): metadata = get_metadata(store) if not resume: metadata.nseen = 0 metadata.nprogress = 0 if test_mode: random_weights = False use_premade_weights = True ee_STDP_on = False if num_epochs is None: num_epochs = 1 if progress_interval is None: progress_interval = 1000 if progress_assignments_window is None: progress_assignments_window = 0 if progress_accuracy_window is None: progress_accuracy_window = 1000000 else: random_weights = not resume ee_STDP_on = True if num_epochs is None: num_epochs = 3 if progress_interval is None: progress_interval = 1000 if progress_assignments_window is None: progress_assignments_window = 1000 if progress_accuracy_window is None: progress_accuracy_window = 1000 log.info("Brian2STDPMNIST/simulation.py") log.info("Arguments =============") metadata["args"] = record_arguments(currentframe(), locals()) log.info("=======================") # load MNIST training, testing = get_labeled_data() config.classes = np.unique(training["y"]) config.num_classes = len(config.classes) # configuration np.random.seed(0) modulefilename = getframeinfo(currentframe()).filename config.data_path = os.path.dirname(os.path.abspath(modulefilename)) config.random_weight_path = os.path.join(config.data_path, "random/") runpath = os.path.join("runs", runname) config.weight_path = os.path.join(runpath, "weights/") os.makedirs(config.weight_path, exist_ok=True) if test_mode: log.info("Testing run {}".format(runname)) elif resume: log.info("Resuming training run {}".format(runname)) else: log.info("Training run {}".format(runname)) if 
test_mode: config.output_path = os.path.join(runpath, "output_test/") else: config.output_path = os.path.join(runpath, "output/") os.makedirs(config.output_path, exist_ok=True) if test_mode: data = testing else: data = training if permute_data: sample = np.random.permutation(len(data["y"])) data["x"] = data["x"][sample] data["y"] = data["y"][sample] num_examples = int(len(data["y"]) * num_epochs) n_input = data["x"][0].size n_data = data["y"].size if num_epochs < 1: n_data = int(np.ceil(n_data * num_epochs)) data["x"] = data["x"][:n_data] data["y"] = data["y"][:n_data] # ------------------------------------------------------------------------- # set parameters and equations # ------------------------------------------------------------------------- # log.info('Original defaultclock.dt = {}'.format(str(b2.defaultclock.dt))) if clock is None: clock = 0.5 b2.defaultclock.dt = clock * b2.ms metadata["dt"] = b2.defaultclock.dt log.info("defaultclock.dt = {}".format(str(b2.defaultclock.dt))) n_neurons = { "Ae": size, "Ai": size, "Oe": config.num_classes, "Oi": config.num_classes, "Xe": n_input, "Ye": config.num_classes, } metadata["n_neurons"] = n_neurons single_example_time = 0.35 * b2.second resting_time = 0.15 * b2.second total_example_time = single_example_time + resting_time runtime = num_examples * total_example_time metadata["total_example_time"] = total_example_time input_population_names = ["X"] population_names = ["A"] connection_names = ["XA"] config.save_conns = ["XeAe"] config.plot_conns = ["XeAe"] forward_conntype_names = ["ee"] recurrent_conntype_names = ["ei_rec", "ie_rec"] stdp_conn_names = ["XeAe"] # TODO: add --dc15 option total_weight = {} if total_input_weight is None: total_weight[ "XeAe"] = n_neurons["Xe"] / 10.0 # standard dc15 value was 78.0 else: total_weight["XeAe"] = total_input_weight theta_init = {} if supervised: input_population_names += ["Y"] population_names += ["O"] connection_names += ["YO", "AO"] config.save_conns += ["YeOe", "AeOe"] 
config.plot_conns += ["AeOe"] stdp_conn_names += ["AeOe"] total_weight["AeOe"] = n_neurons["Ae"] / 5.0 # TODO: refine? theta_init["O"] = 15.0 * b2.mV if feedback: connection_names += ["OA"] config.save_conns += ["OeAe"] config.plot_conns += ["OeAe"] stdp_conn_names += ["OeAe"] total_weight["OeAe"] = n_neurons["Oe"] / 5.0 # TODO: refine? delay = {} # TODO: potentially specify by connName? delay["ee"] = (0 * b2.ms, 10 * b2.ms) delay["ei"] = (0 * b2.ms, 5 * b2.ms) delay["ei_rec"] = (0 * b2.ms, 0 * b2.ms) delay["ie_rec"] = (0 * b2.ms, 0 * b2.ms) input_intensity = 2.0 if test_mode: input_label_intensity = 0.0 else: input_label_intensity = 10.0 initial_weight_matrices = get_initial_weights(n_neurons) # TODO: put all configuration/setup variables in config object # and save to the store for future reference # metadata["config"] = config neuron_groups = {} connections = {} spike_monitors = {} state_monitors = {} network_operations = [] # ------------------------------------------------------------------------- # create network population and recurrent connections # ------------------------------------------------------------------------- for subgroup_n, name in enumerate(population_names): log.info(f"Creating neuron group {name}") subpop_e = name + "e" subpop_i = name + "i" const_theta = False neuron_namespace = {} if name == "A" and tc_theta is not None: neuron_namespace["tc_theta"] = tc_theta * b2.ms if name == "O": neuron_namespace["tc_theta"] = 1e6 * b2.ms if test_mode: const_theta = True if name == "O": # TODO: move to a config variable neuron_namespace["tc_theta"] = 1e5 * b2.ms const_theta = False nge = neuron_groups[subpop_e] = DiehlAndCookExcitatoryNeuronGroup( n_neurons[subpop_e], const_theta=const_theta, timer=timer, custom_namespace=neuron_namespace, ) ngi = neuron_groups[subpop_i] = DiehlAndCookInhibitoryNeuronGroup( n_neurons[subpop_i]) if not random_weights: theta_saved = load_theta(name) if len(theta_saved) != n_neurons[subpop_e]: raise ValueError( 
f"Requested size of neuron population {subpop_e} " f"({n_neurons[subpop_e]}) does not match size of " f"saved data ({len(theta_saved)})") neuron_groups[subpop_e].theta = theta_saved elif name in theta_init: neuron_groups[subpop_e].theta = theta_init[name] for connType in recurrent_conntype_names: log.info(f"Creating recurrent connections for {connType}") preName = name + connType[0] postName = name + connType[1] connName = preName + postName conn = connections[connName] = DiehlAndCookSynapses( neuron_groups[preName], neuron_groups[postName], conn_type=connType) conn.connect() # all-to-all connection minDelay, maxDelay = delay[connType] if maxDelay > 0: deltaDelay = maxDelay - minDelay conn.delay = "minDelay + rand() * deltaDelay" # TODO: the use of connections with fixed zero weights is inefficient # "random" connections for AeAi is matrix with zero everywhere # except the diagonal, which contains 10.4 # "random" connections for AiAe is matrix with 17.0 everywhere # except the diagonal, which contains zero # TODO: these weights appear to have been tuned, # we may need different values for the O layer weightMatrix = None if use_premade_weights: try: weightMatrix = load_connections(connName, random=random_weights) except FileNotFoundError: log.info( f"Requested premade {'random' if random_weights else ''} " f"weights, but none found for {connName}") if weightMatrix is None: log.info("Using generated initial weight matrices") weightMatrix = initial_weight_matrices[connName] conn.w = weightMatrix.flatten() log.debug(f"Creating spike monitors for {name}") spike_monitors[subpop_e] = b2.SpikeMonitor(nge, record=record_spikes) spike_monitors[subpop_i] = b2.SpikeMonitor(ngi, record=record_spikes) if monitoring: log.debug(f"Creating state monitors for {name}") state_monitors[subpop_e] = b2.StateMonitor( nge, variables=True, record=range(0, n_neurons[subpop_e], 10), dt=0.5 * b2.ms, ) if test_mode and supervised: # make output neurons more sensitive neuron_groups["Oe"].theta = 
5.0 * b2.mV # TODO: refine # ------------------------------------------------------------------------- # create TimedArray of rates for input examples # ------------------------------------------------------------------------- input_dt = 50 * b2.ms n_dt_example = int(round(single_example_time / input_dt)) n_dt_rest = int(round(resting_time / input_dt)) n_dt_total = int(n_dt_example + n_dt_rest) input_rates = np.zeros((n_data * n_dt_total, n_neurons["Xe"]), dtype=np.float16) log.info("Preparing input rate stream {}".format(input_rates.shape)) for j in range(n_data): spike_rates = data["x"][j].reshape(n_neurons["Xe"]) / 8 spike_rates *= input_intensity start = j * n_dt_total input_rates[start:start + n_dt_example] = spike_rates input_rates = input_rates * b2.Hz stimulus_X = b2.TimedArray(input_rates, dt=input_dt) total_data_time = n_data * n_dt_total * input_dt # ------------------------------------------------------------------------- # create TimedArray of rates for input labels # ------------------------------------------------------------------------- if "Y" in input_population_names: input_label_rates = np.zeros((n_data * n_dt_total, n_neurons["Ye"]), dtype=np.float16) log.info("Preparing input label rate stream {}".format( input_label_rates.shape)) if not test_mode: label_spike_rates = to_categorical(data["y"], dtype=np.float16) else: label_spike_rates = np.ones(n_data) label_spike_rates *= input_label_intensity for j in range(n_data): start = j * n_dt_total input_label_rates[start:start + n_dt_example] = label_spike_rates[j] input_label_rates = input_label_rates * b2.Hz stimulus_Y = b2.TimedArray(input_label_rates, dt=input_dt) # ------------------------------------------------------------------------- # create input population and connections from input populations # ------------------------------------------------------------------------- for k, name in enumerate(input_population_names): subpop_e = name + "e" # stimulus is repeated for duration of simulation 
# (i.e. if there are multiple epochs) neuron_groups[subpop_e] = b2.PoissonGroup( n_neurons[subpop_e], rates=f"stimulus_{name}(t % total_data_time, i)") log.debug(f"Creating spike monitors for {name}") spike_monitors[subpop_e] = b2.SpikeMonitor(neuron_groups[subpop_e], record=record_spikes) for name in connection_names: log.info(f"Creating connections between {name[0]} and {name[1]}") for connType in forward_conntype_names: log.debug(f"connType {connType}") preName = name[0] + connType[0] postName = name[1] + connType[1] connName = preName + postName stdp_on = ee_STDP_on and connName in stdp_conn_names nu_factor = 10.0 if name in ["AO"] else None conn = connections[connName] = DiehlAndCookSynapses( neuron_groups[preName], neuron_groups[postName], conn_type=connType, stdp_on=stdp_on, stdp_rule=stdp_rule, custom_namespace=custom_namespace, nu_factor=nu_factor, ) conn.connect() # all-to-all connection minDelay, maxDelay = delay[connType] if maxDelay > 0: deltaDelay = maxDelay - minDelay conn.delay = "minDelay + rand() * deltaDelay" weightMatrix = None if use_premade_weights: try: weightMatrix = load_connections(connName, random=random_weights) except FileNotFoundError: log.info( f"Requested premade {'random' if random_weights else ''} " f"weights, but none found for {connName}") if weightMatrix is None: log.info("Using generated initial weight matrices") weightMatrix = initial_weight_matrices[connName] conn.w = weightMatrix.flatten() if monitoring: log.debug(f"Creating state monitors for {connName}") state_monitors[connName] = b2.StateMonitor( conn, variables=True, record=range(0, n_neurons[preName] * n_neurons[postName], 1000), dt=5 * b2.ms, ) if ee_STDP_on: @b2.network_operation(dt=total_example_time, order=1) def normalize_weights(t): for connName in connections: if connName in stdp_conn_names: # log.debug( # "Normalizing weights for {} " "at time {}".format(connName, t) # ) conn = connections[connName] connweights = np.reshape( conn.w, (len(conn.source), 
len(conn.target))) colSums = connweights.sum(axis=0) ok = colSums > 0 colFactors = np.ones_like(colSums) colFactors[ok] = total_weight[connName] / colSums[ok] connweights *= colFactors conn.w = connweights.flatten() network_operations.append(normalize_weights) def record_cumulative_spike_counts(t=None): if t is None or t > 0: metadata.nseen += 1 for name in population_names + input_population_names: subpop_e = name + "e" count = pd.DataFrame(spike_monitors[subpop_e].count[:][None, :], index=[metadata.nseen]) count = count.rename_axis("tbin") count = count.rename_axis("neuron", axis="columns") store.append(f"cumulative_spike_counts/{subpop_e}", count) @b2.network_operation(dt=total_example_time, order=0) def record_cumulative_spike_counts_net_op(t): record_cumulative_spike_counts(t) network_operations.append(record_cumulative_spike_counts_net_op) def progress(): log.debug("Starting progress") starttime = time.process_time() labels = get_labels(data) log.info("So far seen {} examples".format(metadata.nseen)) store.append( f"nseen", pd.Series(data=[metadata.nseen], index=[metadata.nprogress])) metadata.nprogress += 1 assignments_window, accuracy_window = get_windows( metadata.nseen, progress_assignments_window, progress_accuracy_window) for name in population_names + input_population_names: log.debug(f"Progress for population {name}") subpop_e = name + "e" csc = store.select(f"cumulative_spike_counts/{subpop_e}") spikecounts_present = spike_counts_from_cumulative( csc, n_data, metadata.nseen, n_neurons[subpop_e], start=-accuracy_window) n_spikes_present = spikecounts_present["count"].sum() if n_spikes_present > 0: spikerates = ( spikecounts_present.groupby("i")["count"].mean().astype( np.float32)) # this reindex no longer necessary? 
spikerates = spikerates.reindex(np.arange(n_neurons[subpop_e]), fill_value=0) spikerates = add_nseen_index(spikerates, metadata.nseen) store.append(f"rates/{subpop_e}", spikerates) store.flush() fn = os.path.join(config.output_path, "spikerates-summary-{}.pdf".format(subpop_e)) plot_rates_summary(store.select(f"rates/{subpop_e}"), filename=fn, label=subpop_e) if name in population_names: if not test_mode: spikecounts_past = spike_counts_from_cumulative( csc, n_data, metadata.nseen, n_neurons[subpop_e], end=-accuracy_window, atmost=assignments_window, ) n_spikes_past = spikecounts_past["count"].sum() log.debug( "Assignments based on {} spikes".format(n_spikes_past)) if name == "O": assignments = pd.DataFrame({ "label": np.arange(n_neurons[subpop_e], dtype=np.int32) }) else: assignments = get_assignments(spikecounts_past, labels) assignments = add_nseen_index(assignments, metadata.nseen) store.append(f"assignments/{subpop_e}", assignments) else: assignments = store.select(f"assignments/{subpop_e}") if n_spikes_present == 0: log.debug( "No spikes in present interval - skipping accuracy estimate" ) else: log.debug( "Accuracy based on {} spikes".format(n_spikes_present)) predictions = get_predictions(spikecounts_present, assignments, labels) accuracy = get_accuracy(predictions, metadata.nseen) store.append(f"accuracy/{subpop_e}", accuracy) store.flush() accuracy_msg = ( "Accuracy [{}]: {:.1f}% ({:.1f}–{:.1f}% 1σ conf. int.)\n" "{:.1f}% of examples have no prediction\n" "Accuracy excluding non-predictions: " "{:.1f}% ({:.1f}–{:.1f}% 1σ conf. 
int.)") log.info( accuracy_msg.format(subpop_e, *accuracy.values.flat)) fn = os.path.join(config.output_path, "accuracy-{}.pdf".format(subpop_e)) plot_accuracy(store.select(f"accuracy/{subpop_e}"), filename=fn) fn = os.path.join(config.output_path, "spikerates-{}.pdf".format(subpop_e)) plot_quantity( spikerates, filename=fn, label=f"spike rate {subpop_e}", nseen=metadata.nseen, ) theta = theta_to_pandas(subpop_e, neuron_groups, metadata.nseen) store.append(f"theta/{subpop_e}", theta) fn = os.path.join(config.output_path, "theta-{}.pdf".format(subpop_e)) plot_quantity( theta, filename=fn, label=f"theta {subpop_e} (mV)", nseen=metadata.nseen, ) fn = os.path.join(config.output_path, "theta-summary-{}.pdf".format(subpop_e)) plot_theta_summary(store.select(f"theta/{subpop_e}"), filename=fn, label=subpop_e) if not test_mode or metadata.nseen == 0: for conn in config.save_conns: log.info(f"Saving connection {conn}") conn_df = connections_to_pandas(connections[conn], metadata.nseen) store.append(f"connections/{conn}", conn_df) for conn in config.plot_conns: log.info(f"Plotting connection {conn}") subpop = conn[-2:] if "O" in conn: assignments = None else: try: assignments = store.select( f"assignments/{subpop}", where="nseen == metadata.nseen") assignments = assignments.reset_index("nseen", drop=True) except KeyError: assignments = None fn = os.path.join(config.output_path, "weights-{}.pdf".format(conn)) plot_weights( connections[conn], assignments, theta=None, filename=fn, max_weight=None, nseen=metadata.nseen, output=("O" in conn), feedback=("O" in conn[:2]), label=conn, ) if monitoring: for km, vm in spike_monitors.items(): states = vm.get_states() with open( os.path.join(config.output_path, f"saved-spikemonitor-{km}.pickle"), "wb", ) as f: pickle.dump(states, f) for km, vm in state_monitors.items(): states = vm.get_states() with open( os.path.join(config.output_path, f"saved-statemonitor-{km}.pickle"), "wb", ) as f: pickle.dump(states, f) log.debug("progress took 
{:.3f} seconds".format(time.process_time() - starttime)) if progress_interval > 0: @b2.network_operation(dt=total_example_time * progress_interval, order=2) def progress_net_op(t): # if t < total_example_time: # return None progress() network_operations.append(progress_net_op) # ------------------------------------------------------------------------- # run the simulation and set inputs # ------------------------------------------------------------------------- log.info("Constructing the network") net = b2.Network() for obj_list in [ neuron_groups, connections, spike_monitors, state_monitors ]: for key in obj_list: net.add(obj_list[key]) for obj in network_operations: net.add(obj) log.info("Starting simulations") net.run(runtime, report="text", report_period=(60 * b2.second), profile=profile) b2.device.build(directory=os.path.join("build", runname), compile=True, run=True, debug=False) if profile: log.debug(b2.profiling_summary(net, 10)) # ------------------------------------------------------------------------- # save results # ------------------------------------------------------------------------- log.info("Saving results") progress() if not test_mode: record_cumulative_spike_counts() save_theta(population_names, neuron_groups) save_connections(connections)
def run_net(tr):
    """Build, run and post-process one network simulation (Brian 2 standalone).

    The function

    1. selects the ``cpp_standalone`` device with build directory
       ``./builds/<v_idx>`` and seeds Brian 2 and NumPy,
    2. creates the excitatory/inhibitory neuron groups, the four synapse
       groups (``SynEE``, ``SynEI``, ``SynIE``, ``SynII``) and all monitors,
    3. schedules the simulation periods T1..T5 (T1 via ``net.run``, T2..T5
       via the helper functions ``run_T2_syndynrec``, ``run_T3_split``,
       ``run_T4`` and ``run_T5``) and triggers ``device.build``,
    4. pickles monitor states and auxiliary text files into
       ``builds/<v_idx>/raw/``,
    5. optionally computes spike-train cross-correlations,
    6. removes intermediate build directories and calls the plotting
       routines.

    Parameters
    ----------
    tr
        Parameter container read via attribute access (``tr.N_e``,
        ``tr.T1``, ...). It must also provide ``tr.netw.f_to_dict`` and
        ``tr.f_add_result``.
        NOTE(review): presumably a pypet trajectory -- confirm with caller.

    Raises
    ------
    NotImplementedError
        If ``tr.external_mode == 'poisson'``, or if inhibitory structural
        plasticity is requested with ``tr.strct_mode == 'thrs'``.
    """
    # prefs.codegen.target = 'numpy'
    # prefs.codegen.target = 'cython'
    if tr.n_threads > 1:
        prefs.devices.cpp_standalone.openmp_threads = tr.n_threads

    set_device('cpp_standalone', directory='./builds/%.4d' % (tr.v_idx),
               build_on_run=False)

    # set brian 2 and numpy random seeds
    seed(tr.random_seed)
    np.random.seed(tr.random_seed + 11)

    print("Started process with id ", str(tr.v_idx))

    T = tr.T1 + tr.T2 + tr.T3 + tr.T4 + tr.T5

    namespace = tr.netw.f_to_dict(short_names=True, fast_access=True)
    namespace['idx'] = tr.v_idx

    defaultclock.dt = tr.netw.sim.dt

    # collect all network components dependent on configuration
    # (e.g. poisson vs. memnoise) and add them to the Brian 2
    # network object later
    netw_objects = []

    if tr.external_mode == 'memnoise':
        neuron_model = tr.condlif_memnoise
    elif tr.external_mode == 'poisson':
        raise NotImplementedError
        #neuron_model = tr.condlif_poisson

    if tr.syn_cond_mode == 'exp':
        neuron_model += tr.syn_cond_EE_exp
        print("Using EE exp mode")
    elif tr.syn_cond_mode == 'alpha':
        neuron_model += tr.syn_cond_EE_alpha
        print("Using EE alpha mode")
    elif tr.syn_cond_mode == 'biexp':
        neuron_model += tr.syn_cond_EE_biexp
        namespace['invpeakEE'] = (tr.tau_e / tr.tau_e_rise) ** \
            (tr.tau_e_rise / (tr.tau_e - tr.tau_e_rise))
        print("Using EE biexp mode")

    if tr.syn_cond_mode_EI == 'exp':
        neuron_model += tr.syn_cond_EI_exp
        print("Using EI exp mode")
    elif tr.syn_cond_mode_EI == 'alpha':
        neuron_model += tr.syn_cond_EI_alpha
        print("Using EI alpha mode")
    elif tr.syn_cond_mode_EI == 'biexp':
        neuron_model += tr.syn_cond_EI_biexp
        namespace['invpeakEI'] = (tr.tau_i / tr.tau_i_rise) ** \
            (tr.tau_i_rise / (tr.tau_i - tr.tau_i_rise))
        print("Using EI biexp mode")

    GExc = NeuronGroup(
        N=tr.N_e,
        model=neuron_model,
        threshold=tr.nrnEE_thrshld,
        reset=tr.nrnEE_reset,  #method=tr.neuron_method,
        name='GExc',
        namespace=namespace)
    GInh = NeuronGroup(
        N=tr.N_i,
        model=neuron_model,
        threshold='V > Vt',
        reset='V=Vr_i',  #method=tr.neuron_method,
        name='GInh',
        namespace=namespace)

    if tr.external_mode == 'memnoise':
        # GExc.mu, GInh.mu = [0.*mV] + (tr.N_e-1)*[tr.mu_e], tr.mu_i
        # GExc.sigma, GInh.sigma = [0.*mV] + (tr.N_e-1)*[tr.sigma_e], tr.sigma_i
        GExc.mu, GInh.mu = tr.mu_e, tr.mu_i
        GExc.sigma, GInh.sigma = tr.sigma_e, tr.sigma_i

    GExc.Vt, GInh.Vt = tr.Vt_e, tr.Vt_i
    # initial membrane potentials drawn uniformly between reset and threshold
    GExc.V , GInh.V = np.random.uniform(tr.Vr_e/mV, tr.Vt_e/mV,
                                        size=tr.N_e)*mV, \
                      np.random.uniform(tr.Vr_i/mV, tr.Vt_i/mV,
                                        size=tr.N_i)*mV

    netw_objects.extend([GExc, GInh])

    # NOTE: currently unreachable, 'poisson' raises NotImplementedError above
    if tr.external_mode == 'poisson':

        if tr.PInp_mode == 'pool':
            PInp = PoissonGroup(tr.NPInp, rates=tr.PInp_rate,
                                namespace=namespace,
                                name='poissongroup_exc')
            sPN = Synapses(target=GExc, source=PInp, model=tr.poisson_mod,
                           on_pre='gfwd_post += a_EPoi',
                           namespace=namespace, name='synPInpExc')
            sPN_src, sPN_tar = generate_N_connections(N_tar=tr.N_e,
                                                      N_src=tr.NPInp,
                                                      N=tr.NPInp_1n)
        elif tr.PInp_mode == 'indep':
            PInp = PoissonGroup(tr.N_e, rates=tr.PInp_rate,
                                namespace=namespace)
            sPN = Synapses(target=GExc, source=PInp, model=tr.poisson_mod,
                           on_pre='gfwd_post += a_EPoi',
                           namespace=namespace, name='synPInp_inhInh')
            sPN_src, sPN_tar = range(tr.N_e), range(tr.N_e)

        sPN.connect(i=sPN_src, j=sPN_tar)

        if tr.PInp_mode == 'pool':
            PInp_inh = PoissonGroup(tr.NPInp_inh, rates=tr.PInp_inh_rate,
                                    namespace=namespace,
                                    name='poissongroup_inh')
            sPNInh = Synapses(target=GInh, source=PInp_inh,
                              model=tr.poisson_mod,
                              on_pre='gfwd_post += a_EPoi',
                              namespace=namespace)
            sPNInh_src, sPNInh_tar = generate_N_connections(
                N_tar=tr.N_i, N_src=tr.NPInp_inh, N=tr.NPInp_inh_1n)
        elif tr.PInp_mode == 'indep':
            PInp_inh = PoissonGroup(tr.N_i, rates=tr.PInp_inh_rate,
                                    namespace=namespace)
            sPNInh = Synapses(target=GInh, source=PInp_inh,
                              model=tr.poisson_mod,
                              on_pre='gfwd_post += a_EPoi',
                              namespace=namespace)
            sPNInh_src, sPNInh_tar = range(tr.N_i), range(tr.N_i)

        sPNInh.connect(i=sPNInh_src, j=sPNInh_tar)

        netw_objects.extend([PInp, sPN, PInp_inh, sPNInh])

    # NOTE(review): the '''%s %s''' templates below join model/code
    # fragments with a single space. This presumably relies on the fragments
    # themselves ending in a newline -- confirm against their definitions.
    if tr.syn_noise:
        if tr.syn_noise_type == 'additive':
            synEE_mod = '''%s %s''' % (tr.synEE_noise_add, tr.synEE_mod)
            synEI_mod = '''%s %s''' % (tr.synEE_noise_add, tr.synEE_mod)
        elif tr.syn_noise_type == 'multiplicative':
            synEE_mod = '''%s %s''' % (tr.synEE_noise_mult, tr.synEE_mod)
            synEI_mod = '''%s %s''' % (tr.synEE_noise_mult, tr.synEE_mod)
    else:
        synEE_mod = '''%s %s''' % (tr.synEE_static, tr.synEE_mod)
        synEI_mod = '''%s %s''' % (tr.synEE_static, tr.synEE_mod)

    if tr.scl_active:
        synEE_mod = '''%s %s''' % (synEE_mod, tr.synEE_scl_mod)
        synEI_mod = '''%s %s''' % (synEI_mod, tr.synEI_scl_mod)

    if tr.syn_cond_mode == 'exp':
        synEE_pre_mod = mod.synEE_pre_exp
    elif tr.syn_cond_mode == 'alpha':
        synEE_pre_mod = mod.synEE_pre_alpha
    elif tr.syn_cond_mode == 'biexp':
        synEE_pre_mod = mod.synEE_pre_biexp

    synEE_post_mod = mod.syn_post

    if tr.stdp_active:
        synEE_pre_mod = '''%s %s''' % (synEE_pre_mod, mod.syn_pre_STDP)
        synEE_post_mod = '''%s %s''' % (synEE_post_mod, mod.syn_post_STDP)

    if tr.synEE_rec:
        synEE_pre_mod = '''%s %s''' % (synEE_pre_mod, mod.synEE_pre_rec)
        synEE_post_mod = '''%s %s''' % (synEE_post_mod, mod.synEE_post_rec)

    # E<-E advanced synapse model
    SynEE = Synapses(target=GExc, source=GExc, model=synEE_mod,
                     on_pre=synEE_pre_mod, on_post=synEE_post_mod,
                     namespace=namespace, dt=tr.synEE_mod_dt)

    if tr.istdp_active and tr.istdp_type == 'dbexp':

        if tr.syn_cond_mode_EI == 'exp':
            EI_pre_mod = mod.synEI_pre_exp
        elif tr.syn_cond_mode_EI == 'alpha':
            EI_pre_mod = mod.synEI_pre_alpha
        elif tr.syn_cond_mode_EI == 'biexp':
            EI_pre_mod = mod.synEI_pre_biexp

        synEI_pre_mod = '''%s %s''' % (EI_pre_mod, mod.syn_pre_STDP)
        synEI_post_mod = '''%s %s''' % (mod.syn_post, mod.syn_post_STDP)

    elif tr.istdp_active and tr.istdp_type == 'sym':

        if tr.syn_cond_mode_EI == 'exp':
            EI_pre_mod = mod.synEI_pre_sym_exp
        elif tr.syn_cond_mode_EI == 'alpha':
            EI_pre_mod = mod.synEI_pre_sym_alpha
        elif tr.syn_cond_mode_EI == 'biexp':
            EI_pre_mod = mod.synEI_pre_sym_biexp

        synEI_pre_mod = '''%s %s''' % (EI_pre_mod, mod.syn_pre_STDP)
        synEI_post_mod = '''%s %s''' % (mod.synEI_post_sym,
                                        mod.syn_post_STDP)

    if tr.istdp_active and tr.synEI_rec:
        synEI_pre_mod = '''%s %s''' % (synEI_pre_mod, mod.synEI_pre_rec)
        synEI_post_mod = '''%s %s''' % (synEI_post_mod, mod.synEI_post_rec)

    if tr.istdp_active:
        SynEI = Synapses(target=GExc, source=GInh, model=synEI_mod,
                         on_pre=synEI_pre_mod, on_post=synEI_post_mod,
                         namespace=namespace, dt=tr.synEE_mod_dt)
    else:
        # static inhibitory synapses; Brian 2 needs one variable
        # definition per line
        model = '''a : 1
                   syn_active : 1'''
        SynEI = Synapses(target=GExc, source=GInh, model=model,
                         on_pre='gi_post += a', namespace=namespace)

    #other simple
    SynIE = Synapses(target=GInh, source=GExc, on_pre='ge_post += a_ie',
                     namespace=namespace)
    SynII = Synapses(target=GInh, source=GInh, on_pre='gi_post += a_ii',
                     namespace=namespace)

    sEE_src, sEE_tar = generate_full_connectivity(tr.N_e, same=True)
    SynEE.connect(i=sEE_src, j=sEE_tar)
    SynEE.syn_active = 0
    SynEE.taupre, SynEE.taupost = tr.taupre, tr.taupost

    if tr.istdp_active and tr.istrct_active:
        print('istrct active')
        sEI_src, sEI_tar = generate_full_connectivity(Nsrc=tr.N_i,
                                                      Ntar=tr.N_e,
                                                      same=False)
        SynEI.connect(i=sEI_src, j=sEI_tar)
        SynEI.syn_active = 0

    else:
        print('istrct not active')
        if tr.weight_mode == 'init':
            sEI_src, sEI_tar = generate_connections(tr.N_e, tr.N_i, tr.p_ei)
            # print('Index Zero will not get inhibition')
            # sEI_src, sEI_tar = np.array(sEI_src), np.array(sEI_tar)
            # sEI_src, sEI_tar = sEI_src[sEI_tar > 0],sEI_tar[sEI_tar > 0]

        elif tr.weight_mode == 'load':
            fpath = os.path.join(tr.basepath, tr.weight_path)

            # NOTE: pickle.load can execute arbitrary code; the weight file
            # must come from a trusted source.
            with open(fpath + 'synei_a.p', 'rb') as pfile:
                synei_a_init = pickle.load(pfile)

            sEI_src, sEI_tar = synei_a_init['i'], synei_a_init['j']

        SynEI.connect(i=sEI_src, j=sEI_tar)

    if tr.istdp_active:
        SynEI.taupre, SynEI.taupost = tr.taupre_EI, tr.taupost_EI

    sIE_src, sIE_tar = generate_connections(tr.N_i, tr.N_e, tr.p_ie)
    sII_src, sII_tar = generate_connections(tr.N_i, tr.N_i, tr.p_ii,
                                            same=True)

    SynIE.connect(i=sIE_src, j=sIE_tar)
    SynII.connect(i=sII_src, j=sII_tar)

    tr.f_add_result('sEE_src', sEE_src)
    tr.f_add_result('sEE_tar', sEE_tar)
    tr.f_add_result('sIE_src', sIE_src)
    tr.f_add_result('sIE_tar', sIE_tar)
    tr.f_add_result('sEI_src', sEI_src)
    tr.f_add_result('sEI_tar', sEI_tar)
    tr.f_add_result('sII_src', sII_src)
    tr.f_add_result('sII_tar', sII_tar)

    if tr.syn_noise:
        SynEE.syn_sigma = tr.syn_sigma
        SynEE.run_regularly('a = clip(a,0,amax)', when='after_groups',
                            name='SynEE_noise_clipper')

    if tr.syn_noise and tr.istdp_active:
        SynEI.syn_sigma = tr.syn_sigma
        SynEI.run_regularly('a = clip(a,0,amax)', when='after_groups',
                            name='SynEI_noise_clipper')

    SynEE.insert_P = tr.insert_P
    SynEE.p_inactivate = tr.p_inactivate
    SynEE.stdp_active = 1
    print('Setting maximum EE weight threshold to ', tr.amax)
    SynEE.amax = tr.amax

    if tr.istdp_active:
        SynEI.insert_P = tr.insert_P_ei
        SynEI.p_inactivate = tr.p_inactivate_ei
        SynEI.stdp_active = 1
        SynEI.amax = tr.amax

    SynEE.syn_active, SynEE.a = init_synapses('EE', tr)
    SynEI.syn_active, SynEI.a = init_synapses('EI', tr)

    # recording of stdp in T4
    SynEE.stdp_rec_start = tr.T1 + tr.T2 + tr.T3
    SynEE.stdp_rec_max = tr.T1 + tr.T2 + tr.T3 + tr.stdp_rec_T

    if tr.istdp_active:
        SynEI.stdp_rec_start = tr.T1 + tr.T2 + tr.T3
        SynEI.stdp_rec_max = tr.T1 + tr.T2 + tr.T3 + tr.stdp_rec_T

    # synaptic scaling
    if tr.netw.config.scl_active:

        if tr.syn_scl_rec:
            SynEE.scl_rec_start = tr.T1 + tr.T2 + tr.T3
            SynEE.scl_rec_max = tr.T1 + tr.T2 + tr.T3 + tr.scl_rec_T
        else:
            # start lies after the end of the simulation
            SynEE.scl_rec_start = T + 10 * second
            SynEE.scl_rec_max = T

        if tr.sig_ATotalMax == 0.:
            GExc.ANormTar = tr.ATotalMax
        else:
            GExc.ANormTar = np.random.normal(loc=tr.ATotalMax,
                                             scale=tr.sig_ATotalMax,
                                             size=tr.N_e)

        SynEE.summed_updaters['AsumEE_post']._clock = Clock(
            dt=tr.dt_synEE_scaling)
        synee_scaling = SynEE.run_regularly(tr.synEE_scaling,
                                            dt=tr.dt_synEE_scaling,
                                            when='end',
                                            name='synEE_scaling')

    if tr.istdp_active and tr.netw.config.iscl_active:

        if tr.syn_iscl_rec:
            SynEI.scl_rec_start = tr.T1 + tr.T2 + tr.T3
            SynEI.scl_rec_max = tr.T1 + tr.T2 + tr.T3 + tr.scl_rec_T
        else:
            SynEI.scl_rec_start = T + 10 * second
            SynEI.scl_rec_max = T

        if tr.sig_iATotalMax == 0.:
            GExc.iANormTar = tr.iATotalMax
        else:
            GExc.iANormTar = np.random.normal(loc=tr.iATotalMax,
                                              scale=tr.sig_iATotalMax,
                                              size=tr.N_e)

        SynEI.summed_updaters['AsumEI_post']._clock = Clock(
            dt=tr.dt_synEE_scaling)
        synei_scaling = SynEI.run_regularly(tr.synEI_scaling,
                                            dt=tr.dt_synEE_scaling,
                                            when='end',
                                            name='synEI_scaling')

    # # intrinsic plasticity
    # if tr.netw.config.it_active:
    #     GExc.h_ip = tr.h_ip
    #     GExc.run_regularly(tr.intrinsic_mod, dt = tr.it_dt, when='end')

    # structural plasticity
    if tr.netw.config.strct_active:
        if tr.strct_mode == 'zero':
            if tr.turnover_rec:
                strct_mod = '''%s %s''' % (tr.strct_mod,
                                           tr.turnover_rec_mod)
            else:
                strct_mod = tr.strct_mod

            strctplst = SynEE.run_regularly(strct_mod, dt=tr.strct_dt,
                                            when='end',
                                            name='strct_plst_zero')

        elif tr.strct_mode == 'thrs':
            if tr.turnover_rec:
                strct_mod_thrs = '''%s %s''' % (tr.strct_mod_thrs,
                                                tr.turnover_rec_mod)
            else:
                strct_mod_thrs = tr.strct_mod_thrs

            strctplst = SynEE.run_regularly(strct_mod_thrs, dt=tr.strct_dt,
                                            when='end',
                                            name='strct_plst_thrs')

    if tr.istdp_active and tr.netw.config.istrct_active:
        if tr.strct_mode == 'zero':
            if tr.turnover_rec:
                strct_mod_EI = '''%s %s''' % (tr.strct_mod,
                                              tr.turnoverEI_rec_mod)
            else:
                strct_mod_EI = tr.strct_mod

            strctplst_EI = SynEI.run_regularly(strct_mod_EI,
                                               dt=tr.strct_dt,
                                               when='end',
                                               name='strct_plst_EI')

        elif tr.strct_mode == 'thrs':
            raise NotImplementedError

    netw_objects.extend([SynEE, SynEI, SynIE, SynII])

    # keep track of the number of active synapses
    sum_target = NeuronGroup(1, 'c : 1 (shared)', dt=tr.csample_dt)

    # one definition per line, as required by the Brian 2 equation parser
    sum_model = '''NSyn : 1 (constant)
                   c_post = (1.0*syn_active_pre)/NSyn : 1 (summed)'''
    sum_connection = Synapses(target=sum_target, source=SynEE,
                              model=sum_model, dt=tr.csample_dt,
                              name='get_active_synapse_count')
    sum_connection.connect()
    sum_connection.NSyn = tr.N_e * (tr.N_e - 1)

    if tr.adjust_insertP:
        # homeostatically adjust growth rate
        growth_updater = Synapses(sum_target, SynEE)
        growth_updater.run_regularly('insert_P_post *= 0.1/c_pre',
                                     when='after_groups',
                                     dt=tr.csample_dt,
                                     name='update_insP')
        growth_updater.connect(j='0')

        netw_objects.extend([sum_target, sum_connection, growth_updater])

    if tr.istdp_active and tr.istrct_active:

        # keep track of the number of active synapses
        sum_target_EI = NeuronGroup(1, 'c : 1 (shared)', dt=tr.csample_dt)

        sum_model_EI = '''NSyn : 1 (constant)
                          c_post = (1.0*syn_active_pre)/NSyn : 1 (summed)'''
        sum_connection_EI = Synapses(target=sum_target_EI, source=SynEI,
                                     model=sum_model_EI, dt=tr.csample_dt,
                                     name='get_active_synapse_count_EI')
        sum_connection_EI.connect()
        sum_connection_EI.NSyn = tr.N_e * tr.N_i

        if tr.adjust_EI_insertP:
            # homeostatically adjust growth rate
            growth_updater_EI = Synapses(sum_target_EI, SynEI)
            growth_updater_EI.run_regularly('insert_P_post *= 0.1/c_pre',
                                            when='after_groups',
                                            dt=tr.csample_dt,
                                            name='update_insP_EI')
            growth_updater_EI.connect(j='0')

            netw_objects.extend(
                [sum_target_EI, sum_connection_EI, growth_updater_EI])

    # -------------- recording ------------------

    GExc_recvars = []
    if tr.memtraces_rec:
        GExc_recvars.append('V')
    if tr.vttraces_rec:
        GExc_recvars.append('Vt')
    if tr.getraces_rec:
        GExc_recvars.append('ge')
    if tr.gitraces_rec:
        GExc_recvars.append('gi')
    if tr.gfwdtraces_rec and tr.external_mode == 'poisson':
        GExc_recvars.append('gfwd')

    GInh_recvars = GExc_recvars

    GExc_stat = StateMonitor(GExc, GExc_recvars,
                             record=list(range(tr.nrec_GExc_stat)),
                             dt=tr.GExc_stat_dt)
    GInh_stat = StateMonitor(GInh, GInh_recvars,
                             record=list(range(tr.nrec_GInh_stat)),
                             dt=tr.GInh_stat_dt)

    # SynEE stat
    SynEE_recvars = []
    if tr.synee_atraces_rec:
        SynEE_recvars.append('a')
    if tr.synee_activetraces_rec:
        SynEE_recvars.append('syn_active')
    if tr.synee_Apretraces_rec:
        SynEE_recvars.append('Apre')
    if tr.synee_Aposttraces_rec:
        SynEE_recvars.append('Apost')

    SynEE_stat = StateMonitor(SynEE, SynEE_recvars,
                              record=range(tr.n_synee_traces_rec),
                              when='end', dt=tr.synEE_stat_dt)

    if tr.istdp_active:
        # SynEI stat
        SynEI_recvars = []
        if tr.synei_atraces_rec:
            SynEI_recvars.append('a')
        if tr.synei_activetraces_rec:
            SynEI_recvars.append('syn_active')
        if tr.synei_Apretraces_rec:
            SynEI_recvars.append('Apre')
        if tr.synei_Aposttraces_rec:
            SynEI_recvars.append('Apost')

        SynEI_stat = StateMonitor(SynEI, SynEI_recvars,
                                  record=range(tr.n_synei_traces_rec),
                                  when='end', dt=tr.synEI_stat_dt)
        netw_objects.append(SynEI_stat)

    if tr.adjust_insertP:
        C_stat = StateMonitor(sum_target, 'c', dt=tr.csample_dt,
                              record=[0], when='end')
        insP_stat = StateMonitor(SynEE, 'insert_P', dt=tr.csample_dt,
                                 record=[0], when='end')
        netw_objects.extend([C_stat, insP_stat])

    if tr.istdp_active and tr.adjust_EI_insertP:
        C_EI_stat = StateMonitor(sum_target_EI, 'c', dt=tr.csample_dt,
                                 record=[0], when='end')
        insP_EI_stat = StateMonitor(SynEI, 'insert_P', dt=tr.csample_dt,
                                    record=[0], when='end')
        netw_objects.extend([C_EI_stat, insP_EI_stat])

    GExc_spks = SpikeMonitor(GExc)
    GInh_spks = SpikeMonitor(GInh)

    GExc_rate = PopulationRateMonitor(GExc)
    GInh_rate = PopulationRateMonitor(GInh)

    if tr.external_mode == 'poisson':
        PInp_spks = SpikeMonitor(PInp)
        PInp_rate = PopulationRateMonitor(PInp)
        netw_objects.extend([PInp_spks, PInp_rate])

    if tr.synee_a_nrecpoints == 0 or tr.sim.T2 == 0 * second:
        # interval longer than the whole simulation
        SynEE_a_dt = 2 * (tr.T1 + tr.T2 + tr.T3 + tr.T4 + tr.T5)
    else:
        SynEE_a_dt = tr.sim.T2 / tr.synee_a_nrecpoints

        # make sure that the choice of SynEE_a_dt does not lead
        # to excessively many recordings - this can
        # happen if t1 >> t2.
        estm_nrecs = int(T / SynEE_a_dt)
        if estm_nrecs > 3 * tr.synee_a_nrecpoints:
            print(('Estimated number of EE weight recordings (%d) '
                   'exceeds desired number (%d), increasing SynEE_a_dt')
                  % (estm_nrecs, tr.synee_a_nrecpoints))

            SynEE_a_dt = T / tr.synee_a_nrecpoints

    SynEE_a = StateMonitor(SynEE, ['a', 'syn_active'],
                           record=range(tr.N_e * (tr.N_e - 1)),
                           dt=SynEE_a_dt, when='end', order=100)

    if tr.istrct_active:
        record_range = range(tr.N_e * tr.N_i)
    else:
        record_range = range(len(sEI_src))

    # Whether the E<-I weight monitor `SynEI_a` exists. Evaluated once and
    # reused below so that creating, saving and analysing the monitor always
    # agree.
    synei_a_rec = tr.synei_a_nrecpoints > 0 and tr.sim.T2 > 0 * second

    if synei_a_rec:
        SynEI_a_dt = tr.sim.T2 / tr.synei_a_nrecpoints

        estm_nrecs = int(T / SynEI_a_dt)
        if estm_nrecs > 3 * tr.synei_a_nrecpoints:
            print(('Estimated number of EI weight recordings (%d) '
                   'exceeds desired number (%d), increasing SynEI_a_dt')
                  % (estm_nrecs, tr.synei_a_nrecpoints))

            SynEI_a_dt = T / tr.synei_a_nrecpoints

        SynEI_a = StateMonitor(SynEI, ['a', 'syn_active'],
                               record=record_range,
                               dt=SynEI_a_dt, when='end', order=100)

        netw_objects.append(SynEI_a)

    netw_objects.extend([
        GExc_stat, GInh_stat, SynEE_stat, SynEE_a, GExc_spks, GInh_spks,
        GExc_rate, GInh_rate
    ])

    if (tr.synEEdynrec
            and (2 * tr.syndynrec_npts * tr.syndynrec_dt < tr.sim.T2)):
        SynEE_dynrec = StateMonitor(SynEE, ['a'],
                                    record=range(tr.N_e * (tr.N_e - 1)),
                                    dt=tr.syndynrec_dt,
                                    name='SynEE_dynrec',
                                    when='end', order=100)
        SynEE_dynrec.active = False
        netw_objects.extend([SynEE_dynrec])

    if (tr.synEIdynrec
            and (2 * tr.syndynrec_npts * tr.syndynrec_dt < tr.sim.T2)):
        SynEI_dynrec = StateMonitor(SynEI, ['a'],
                                    record=record_range,
                                    dt=tr.syndynrec_dt,
                                    name='SynEI_dynrec',
                                    when='end', order=100)
        SynEI_dynrec.active = False
        netw_objects.extend([SynEI_dynrec])

    net = Network(*netw_objects)

    def set_active(*argv):
        for net_object in argv:
            net_object.active = True

    def set_inactive(*argv):
        for net_object in argv:
            net_object.active = False

    ### Simulation periods

    # --------- T1 ---------
    # initial recording period,
    # all recorders active

    T1T3_recorders = [
        GExc_spks, GInh_spks, SynEE_stat, GExc_stat, GInh_stat, GExc_rate,
        GInh_rate
    ]

    if tr.istdp_active:
        T1T3_recorders.append(SynEI_stat)

    set_active(*T1T3_recorders)

    if tr.external_mode == 'poisson':
        set_active(PInp_spks, PInp_rate)

    net.run(tr.sim.T1, report='text', report_period=300 * second,
            profile=True)

    # --------- T2 ---------
    # main simulation period
    # only active recordings are:
    #   1) turnover 2) C_stat 3) SynEE_a

    set_inactive(*T1T3_recorders)

    if tr.T2_spks_rec:
        set_active(GExc_spks, GInh_spks)

    if tr.external_mode == 'poisson':
        set_inactive(PInp_spks, PInp_rate)

    run_T2_syndynrec(net, tr, netw_objects)

    # --------- T3 ---------
    # second recording period,
    # all recorders active

    set_active(*T1T3_recorders)

    if tr.external_mode == 'poisson':
        set_active(PInp_spks, PInp_rate)

    run_T3_split(net, tr)

    # --------- T4 ---------
    # record STDP and scaling weight changes to file
    # through the cpp models

    set_inactive(*T1T3_recorders)

    if tr.external_mode == 'poisson':
        set_inactive(PInp_spks, PInp_rate)

    run_T4(net, tr)

    # --------- T5 ---------
    # freeze network and record Exc spikes
    # for cross correlations

    if tr.scl_active:
        synee_scaling.active = False
    if tr.istdp_active and tr.netw.config.iscl_active:
        synei_scaling.active = False
    if tr.strct_active:
        strctplst.active = False
    if tr.istdp_active and tr.istrct_active:
        strctplst_EI.active = False
    SynEE.stdp_active = 0
    if tr.istdp_active:
        SynEI.stdp_active = 0

    set_active(GExc_rate, GInh_rate)
    set_active(GExc_spks, GInh_spks)

    run_T5(net, tr)

    SynEE_a.record_single_timestep()
    if synei_a_rec:
        SynEI_a.record_single_timestep()

    device.build(directory='builds/%.4d' % (tr.v_idx), clean=True,
                 compile=True, run=True, debug=False)

    # -----------------------------------------

    # save monitors as raws in build directory
    raw_dir = 'builds/%.4d/raw/' % (tr.v_idx)

    os.makedirs(raw_dir, exist_ok=True)

    with open(raw_dir + 'namespace.p', 'wb') as pfile:
        pickle.dump(namespace, pfile)

    with open(raw_dir + 'gexc_stat.p', 'wb') as pfile:
        pickle.dump(GExc_stat.get_states(), pfile)
    with open(raw_dir + 'ginh_stat.p', 'wb') as pfile:
        pickle.dump(GInh_stat.get_states(), pfile)

    with open(raw_dir + 'synee_stat.p', 'wb') as pfile:
        pickle.dump(SynEE_stat.get_states(), pfile)

    if tr.istdp_active:
        with open(raw_dir + 'synei_stat.p', 'wb') as pfile:
            pickle.dump(SynEI_stat.get_states(), pfile)

    if ((tr.synEEdynrec or tr.synEIdynrec)
            and (2 * tr.syndynrec_npts * tr.syndynrec_dt < tr.sim.T2)):

        if tr.synEEdynrec:
            with open(raw_dir + 'syneedynrec.p', 'wb') as pfile:
                pickle.dump(SynEE_dynrec.get_states(), pfile)

        if tr.synEIdynrec:
            with open(raw_dir + 'syneidynrec.p', 'wb') as pfile:
                pickle.dump(SynEI_dynrec.get_states(), pfile)

    with open(raw_dir + 'synee_a.p', 'wb') as pfile:
        SynEE_a_states = SynEE_a.get_states()
        if tr.crs_crrs_rec:
            # synapse indices are needed for the cross-correlations below
            SynEE_a_states['i'] = list(SynEE.i)
            SynEE_a_states['j'] = list(SynEE.j)
        pickle.dump(SynEE_a_states, pfile)

    if synei_a_rec:
        with open(raw_dir + 'synei_a.p', 'wb') as pfile:
            SynEI_a_states = SynEI_a.get_states()
            if tr.crs_crrs_rec:
                SynEI_a_states['i'] = list(SynEI.i)
                SynEI_a_states['j'] = list(SynEI.j)
            pickle.dump(SynEI_a_states, pfile)

    if tr.adjust_insertP:
        with open(raw_dir + 'c_stat.p', 'wb') as pfile:
            pickle.dump(C_stat.get_states(), pfile)

        with open(raw_dir + 'insP_stat.p', 'wb') as pfile:
            pickle.dump(insP_stat.get_states(), pfile)

    if tr.istdp_active and tr.adjust_EI_insertP:
        with open(raw_dir + 'c_EI_stat.p', 'wb') as pfile:
            pickle.dump(C_EI_stat.get_states(), pfile)

        with open(raw_dir + 'insP_EI_stat.p', 'wb') as pfile:
            pickle.dump(insP_EI_stat.get_states(), pfile)

    with open(raw_dir + 'gexc_spks.p', 'wb') as pfile:
        pickle.dump(GExc_spks.get_states(), pfile)
    with open(raw_dir + 'ginh_spks.p', 'wb') as pfile:
        pickle.dump(GInh_spks.get_states(), pfile)

    if tr.external_mode == 'poisson':
        with open(raw_dir + 'pinp_spks.p', 'wb') as pfile:
            pickle.dump(PInp_spks.get_states(), pfile)

    # rate files hold the monitor states and, if `rates_rec` is set,
    # a second pickled object with the smoothed rate
    with open(raw_dir + 'gexc_rate.p', 'wb') as pfile:
        pickle.dump(GExc_rate.get_states(), pfile)
        if tr.rates_rec:
            pickle.dump(GExc_rate.smooth_rate(width=25 * ms), pfile)
    with open(raw_dir + 'ginh_rate.p', 'wb') as pfile:
        pickle.dump(GInh_rate.get_states(), pfile)
        if tr.rates_rec:
            pickle.dump(GInh_rate.smooth_rate(width=25 * ms), pfile)

    if tr.external_mode == 'poisson':
        with open(raw_dir + 'pinp_rate.p', 'wb') as pfile:
            pickle.dump(PInp_rate.get_states(), pfile)
            if tr.rates_rec:
                pickle.dump(PInp_rate.smooth_rate(width=25 * ms), pfile)

    # ----------------- add raw data ------------------------
    fpath = 'builds/%.4d/' % (tr.v_idx)

    from pathlib import Path

    # Convert each comma-separated text file in the build directory into a
    # pickle in the raw directory. `touch` creates an empty file when none
    # exists so that every pickle is always produced.
    for raw_name in ('turnover', 'turnover_EI',
                     'spk_register', 'spk_register_EI',
                     'scaling_deltas', 'scaling_deltas_EI'):
        raw_file = fpath + raw_name
        Path(raw_file).touch()
        raw_data = np.genfromtxt(raw_file, delimiter=',')
        os.remove(raw_file)

        with open(raw_dir + raw_name + '.p', 'wb') as pfile:
            pickle.dump(raw_data, pfile)

    with open(raw_dir + 'profiling_summary.txt', 'w+') as tfile:
        tfile.write(str(profiling_summary(net)))

    # --------------- cross-correlations ---------------------

    if tr.crs_crrs_rec:

        GExc_spks = GExc_spks.get_states()
        synee_a = SynEE_a_states
        wsize = 100 * pq.ms

        for binsize in [1 * pq.ms, 2 * pq.ms, 5 * pq.ms]:

            wlen = int(wsize / binsize)

            # keep only spikes from T5 and shift them to start at zero
            ts, idxs = GExc_spks['t'], GExc_spks['i']
            idxs = idxs[ts > tr.T1 + tr.T2 + tr.T3 + tr.T4]
            ts = ts[ts > tr.T1 + tr.T2 + tr.T3 + tr.T4]
            ts = ts - (tr.T1 + tr.T2 + tr.T3 + tr.T4)

            sts = [
                neo.SpikeTrain(ts[idxs == i] / second * pq.s,
                               t_stop=tr.T5 / second * pq.s)
                for i in range(tr.N_e)
            ]

            crs_crrs, syn_a = [], []

            # NOTE(review): `cbin` is only bound inside this loop; if no
            # synapse is active the dict below raises NameError -- confirm
            # whether that case can occur.
            for f, (i, j) in enumerate(zip(synee_a['i'], synee_a['j'])):
                if synee_a['syn_active'][-1][f] == 1:

                    crs_crr, cbin = cch(
                        BinnedSpikeTrain(sts[i], binsize=binsize),
                        BinnedSpikeTrain(sts[j], binsize=binsize),
                        cross_corr_coef=True,
                        border_correction=True,
                        window=(-1 * wlen, wlen))

                    crs_crrs.append(list(np.array(crs_crr).T[0]))
                    syn_a.append(synee_a['a'][-1][f])

            fname = 'crs_crrs_wsize%dms_binsize%fms_full' % (
                wsize / pq.ms, binsize / pq.ms)

            df = {
                'cbin': cbin,
                'crs_crrs': np.array(crs_crrs),
                'syn_a': np.array(syn_a),
                'binsize': binsize,
                'wsize': wsize,
                'wlen': wlen
            }

            with open('builds/%.4d/raw/' % (tr.v_idx) + fname + '.p',
                      'wb') as pfile:
                pickle.dump(df, pfile)

        # The E<-I analysis needs the states of `SynEI_a`, which only exist
        # when that monitor was created; skip it otherwise instead of
        # failing on an unbound name.
        if synei_a_rec:

            GInh_spks = GInh_spks.get_states()
            synei_a = SynEI_a_states
            wsize = 100 * pq.ms

            for binsize in [1 * pq.ms, 2 * pq.ms, 5 * pq.ms]:

                wlen = int(wsize / binsize)

                ts_E, idxs_E = GExc_spks['t'], GExc_spks['i']
                idxs_E = idxs_E[ts_E > tr.T1 + tr.T2 + tr.T3 + tr.T4]
                ts_E = ts_E[ts_E > tr.T1 + tr.T2 + tr.T3 + tr.T4]
                ts_E = ts_E - (tr.T1 + tr.T2 + tr.T3 + tr.T4)

                ts_I, idxs_I = GInh_spks['t'], GInh_spks['i']
                idxs_I = idxs_I[ts_I > tr.T1 + tr.T2 + tr.T3 + tr.T4]
                ts_I = ts_I[ts_I > tr.T1 + tr.T2 + tr.T3 + tr.T4]
                ts_I = ts_I - (tr.T1 + tr.T2 + tr.T3 + tr.T4)

                sts_E = [
                    neo.SpikeTrain(ts_E[idxs_E == i] / second * pq.s,
                                   t_stop=tr.T5 / second * pq.s)
                    for i in range(tr.N_e)
                ]

                sts_I = [
                    neo.SpikeTrain(ts_I[idxs_I == i] / second * pq.s,
                                   t_stop=tr.T5 / second * pq.s)
                    for i in range(tr.N_i)
                ]

                crs_crrs, syn_a = [], []

                for f, (i, j) in enumerate(zip(synei_a['i'],
                                               synei_a['j'])):
                    if synei_a['syn_active'][-1][f] == 1:

                        crs_crr, cbin = cch(
                            BinnedSpikeTrain(sts_I[i], binsize=binsize),
                            BinnedSpikeTrain(sts_E[j], binsize=binsize),
                            cross_corr_coef=True,
                            border_correction=True,
                            window=(-1 * wlen, wlen))

                        crs_crrs.append(list(np.array(crs_crr).T[0]))
                        syn_a.append(synei_a['a'][-1][f])

                fname = 'EI_crrs_wsize%dms_binsize%fms_full' % (
                    wsize / pq.ms, binsize / pq.ms)

                df = {
                    'cbin': cbin,
                    'crs_crrs': np.array(crs_crrs),
                    'syn_a': np.array(syn_a),
                    'binsize': binsize,
                    'wsize': wsize,
                    'wlen': wlen
                }

                with open('builds/%.4d/raw/' % (tr.v_idx) + fname + '.p',
                          'wb') as pfile:
                    pickle.dump(df, pfile)

        else:
            print('Skipping EI cross-correlations, '
                  'EI weights were not recorded')

    # ----------------- clean up ---------------------------

    shutil.rmtree('builds/%.4d/results/' % (tr.v_idx))
    shutil.rmtree('builds/%.4d/static_arrays/' % (tr.v_idx))
    shutil.rmtree('builds/%.4d/brianlib/' % (tr.v_idx))
    shutil.rmtree('builds/%.4d/code_objects/' % (tr.v_idx))

    # ---------------- plot results --------------------------

    #os.chdir('./analysis/file_based/')

    if tr.istdp_active:
        from src.analysis.overview_winh import overview_figure
        overview_figure('builds/%.4d' % (tr.v_idx), namespace)
    else:
        from src.analysis.overview import overview_figure
        overview_figure('builds/%.4d' % (tr.v_idx), namespace)

    from src.analysis.synw_fb import synw_figure
    synw_figure('builds/%.4d' % (tr.v_idx), namespace)
    if tr.istdp_active:
        synw_figure('builds/%.4d' % (tr.v_idx), namespace,
                    connections='EI')

    from src.analysis.synw_log_fb import synw_log_figure
    synw_log_figure('builds/%.4d' % (tr.v_idx), namespace)
    if tr.istdp_active:
        synw_log_figure('builds/%.4d' % (tr.v_idx), namespace,
                        connections='EI')
def run_task_hierarchical(task_info, taskdir, tempdir):
    """Simulate one trial of the hierarchical sensory/decision network.

    The function builds a decision circuit and a sensory circuit, couples
    them with feedforward and feedback synapses, drives the sensory
    populations with a stimulus generated by ``generate_stim``, runs the
    network with Brian2's ``cpp_standalone`` device and returns the
    downsampled population rates together with optional burst statistics.

    Parameters
    ----------
    task_info : dict
        Nested configuration. Keys read here: ``'simulation'`` (``seedcon``,
        ``runtime``, ``settletime``, ``stimon``, ``stimoff``, ``smoothwin``,
        ``nummethod``, ``2cmodel``, ``plasticdend``, ``pltfig1``,
        ``burstanalysis``), ``'stimulus'`` (``replicate``, ``I0`` or ``I0s``,
        ``mu1``, ``mu2``, ``sigma``, ``taustim``), ``'sen'['2c']``
        (``validburst``, and for plastic dendrites ``eta0``, ``tauB``,
        ``tau_update``), plus top-level ``'c'``, ``'bfb'``, ``'seed'`` (only
        when stimuli are replicated) and ``'targetB'`` (plastic dendrites).
        It is also forwarded unchanged to the circuit constructors.
    taskdir : str
        Output directory. It is created with ``os.mkdir`` and handed to the
        plotting helpers.
    tempdir : str
        Build directory passed to the ``cpp_standalone`` device.

    Returns
    -------
    dict
        ``'raw_data'`` holds ``poprates_dec`` and ``poprates_sen`` (row 0 is
        the population whose decision rate won, row 1 the other one, both
        cut at the settle time), ``pref_msk`` and ``last_muOUd``;
        ``'sim_state'`` is ``np.zeros(1)``; ``'computed'`` holds ``events``,
        ``bursts``, ``singles``, ``spikes`` and ``isis``, which stay at the
        placeholder ``np.zeros(1)`` when burst analysis is disabled.

    Raises
    ------
    FileExistsError
        From ``os.mkdir`` when ``taskdir`` already exists.
    """
    # imports
    from brian2 import defaultclock, set_device, seed, TimedArray, Network, profiling_summary
    from brian2.monitors import SpikeMonitor, PopulationRateMonitor, StateMonitor
    from brian2.synapses import Synapses
    from brian2.core.magic import start_scope
    from brian2.units import second, ms, amp
    from integration_circuit import mk_intcircuit
    from sensory_circuit import mk_sencircuit, mk_sencircuit_2c, mk_sencircuit_2cplastic
    from burstQuant import spks2neurometric
    from scipy import interpolate

    sim_cfg = task_info['simulation']
    stim_cfg = task_info['stimulus']

    # if you want to put something in the taskdir, you must create it first
    os.mkdir(taskdir)
    print(taskdir)

    # parallel code and flag to start
    set_device('cpp_standalone', directory=tempdir)
    # prefs.devices.cpp_standalone.openmp_threads = max_tasks
    start_scope()

    # simulation parameters (names with a trailing underscore are the same
    # quantity divided by `second`)
    seedcon = sim_cfg['seedcon']
    runtime = sim_cfg['runtime']
    runtime_ = runtime / second
    settletime = sim_cfg['settletime']
    settletime_ = settletime / second
    stimon = sim_cfg['stimon']
    stimoff = sim_cfg['stimoff']
    stimoff_ = stimoff / second
    stimdur = stimoff - stimon
    smoothwin = sim_cfg['smoothwin']
    nummethod = sim_cfg['nummethod']

    # -------------------------------------
    # Construct hierarchical network
    # -------------------------------------
    # set connection seed: a specific seed reproduces the same network,
    # this way we also have the same synapses!
    seed(seedcon)

    # decision circuit
    Dgroups, Dsynapses, Dsubgroups = mk_intcircuit(task_info)
    decE = Dgroups['DE']
    decI = Dgroups['DI']
    decE1 = Dsubgroups['DE1']
    decE2 = Dsubgroups['DE2']

    # sensory circuit, ff and fb connections
    # NOTE: `eps`, `wSD` and `wDS` are referenced *by name* in the string
    # expressions passed to connect() / .w below, so they must remain
    # locals of this function with exactly these names.
    eps = 0.2  # connection probability
    d = 1 * ms  # transmission delays of E synapses
    if sim_cfg['2cmodel']:
        if sim_cfg['plasticdend']:
            # plasticity rule in dendrites --> FB synapses will be removed from the network!
            Sgroups, Ssynapses, Ssubgroups = mk_sencircuit_2cplastic(task_info)
        else:
            # 2c model (Naud)
            Sgroups, Ssynapses, Ssubgroups = mk_sencircuit_2c(task_info)

        senE = Sgroups['soma']
        dend = Sgroups['dend']
        senI = Sgroups['SI']
        senE1 = Ssubgroups['soma1']
        senE2 = Ssubgroups['soma2']
        dend1 = Ssubgroups['dend1']
        dend2 = Ssubgroups['dend2']

        # FB targets the dendritic compartments
        wDS = 0.003  # synaptic weight of FB synapses, 0.0668 nS when scaled by gleakE of sencircuit_2c
        fb_post1, fb_post2 = dend1, dend2
    else:
        # normal sensory circuit (Wimmer)
        Sgroups, Ssynapses, Ssubgroups = mk_sencircuit(task_info)
        senE = Sgroups['SE']
        senI = Sgroups['SI']
        senE1 = Ssubgroups['SE1']
        senE2 = Ssubgroups['SE2']

        # FB targets the sensory excitatory populations directly
        wDS = 0.004  # synaptic weight of FB synapses, 0.0668 nS when scaled by gleakE of sencircuit
        fb_post1, fb_post2 = senE1, senE2

    synDE1SE1 = Synapses(decE1, fb_post1, model='w : 1', method=nummethod,
                         on_pre='x_ea += w', delay=d)
    synDE2SE2 = Synapses(decE2, fb_post2, model='w : 1', method=nummethod,
                         on_pre='x_ea += w', delay=d)

    # feedforward synapses from sensory to integration
    wSD = 0.0036  # synaptic weight of FF synapses, 0.09 nS when scaled by gleakE of intcircuit
    synSE1DE1 = Synapses(senE1, decE1, model='w : 1', method=nummethod,
                         on_pre='g_ea += w', delay=d)
    synSE1DE1.connect(p='eps')
    synSE1DE1.w = 'wSD'
    synSE2DE2 = Synapses(senE2, decE2, model='w : 1', method=nummethod,
                         on_pre='g_ea += w', delay=d)
    synSE2DE2.connect(p='eps')
    synSE2DE2.w = 'wSD'

    # feedback synapses from integration to sensory
    # (connected after the FF synapses; the order of the random connect()
    # calls is kept as in the original implementation)
    b_fb = task_info['bfb']  # feedback strength, between 0 and 6
    wDS *= b_fb  # synaptic weight of FB synapses, 0.0668 nS when scaled by gleakE of sencircuit
    synDE1SE1.connect(p='eps')
    synDE1SE1.w = 'wDS'
    synDE2SE2.connect(p='eps')
    synDE2SE2.w = 'wDS'

    # -------------------------------------
    # Create stimuli
    # -------------------------------------
    # Note that in standalone we need to specify np seed because it's not
    # taken care with Brian's seed() function!
    if stim_cfg['replicate']:
        # replicated stimuli across iters()
        np.random.seed(task_info['seed'])  # numpy seed for OU process
    else:
        # every trial has different stimuli
        np.random.seed()

    if sim_cfg['2cmodel']:
        I0 = stim_cfg['I0s']
        # read from the current working directory; presumably written by an
        # earlier run of this task -- TODO confirm
        last_muOUd = np.loadtxt("last_muOUd.csv")
    else:
        I0 = stim_cfg['I0']  # mean input current for zero-coherence stim
    c = task_info['c']  # stim coherence (between 0 and 1)
    mu1 = stim_cfg['mu1']  # av. additional input current to senE1 at highest coherence (c=1)
    mu2 = stim_cfg['mu2']  # av. additional input current to senE2 at highest coherence (c=1)
    sigma = stim_cfg['sigma']  # amplitude of temporal modulations of stim
    sigmastim = 0.212 * sigma  # std of modulation of stim inputs
    sigmaind = 0.212 * sigma  # std of modulations in individual inputs
    taustim = stim_cfg['taustim']  # correlation time constant of Ornstein-Uhlenbeck process

    # generate stim from OU process
    N_stim = len(senE1)
    z1, z2, zk1, zk2 = generate_stim(N_stim, stimdur, taustim)

    # stim2exc: zero-pad the stimulus before onset and after offset so the
    # array spans the whole run at `stim_dt` resolution
    i1 = I0 * (1 + c * mu1 + sigmastim * z1 + sigmaind * zk1)
    i2 = I0 * (1 + c * mu2 + sigmastim * z2 + sigmaind * zk2)
    stim_dt = 1 * ms
    pre_pad = np.zeros((int(stimon / stim_dt), N_stim))
    post_pad = np.zeros((int((runtime - stimoff) / stim_dt), N_stim))
    i1t = np.concatenate((pre_pad, i1.T, post_pad), axis=0)
    i2t = np.concatenate((pre_pad, i2.T, post_pad), axis=0)
    # NOTE(review): `Irec` is not used again in this function; presumably the
    # sensory-circuit equations refer to it by name, so it must stay a local
    # called `Irec` in the frame that calls net.run() -- confirm.
    Irec = TimedArray(np.concatenate((i1t, i2t), axis=1) * amp, dt=stim_dt)

    # Plain-array views of the stimulus. They are needed by plot_fig1b *and*
    # by plot_neurometric, so they are computed unconditionally instead of
    # only when 'pltfig1' is set.
    stim1 = i1t.T
    stim2 = i2t.T
    stimtime = np.linspace(0, runtime_, stim1.shape[1])

    # -------------------------------------
    # Simulation
    # -------------------------------------
    # set initial conditions (different for every trial)
    seed()
    decE.g_ea = '0.2 * rand()'
    decI.g_ea = '0.2 * rand()'
    decE.V = '-52*mV + 2*mV * rand()'
    decI.V = '-52*mV + 2*mV * rand()'  # random initialization near 0, prevent an early decision!
    senE.g_ea = '0.05 * (1 + 0.2*rand())'
    senI.g_ea = '0.05 * (1 + 0.2*rand())'
    senE.V = '-52*mV + 2*mV*rand()'  # random initialization near Vt, avoid initial bump!
    senI.V = '-52*mV + 2*mV*rand()'

    if sim_cfg['2cmodel']:
        dend.g_ea = '0.05 * (1 + 0.2*rand())'
        dend.V_d = '-72*mV + 2*mV*rand()'
        dend.muOUd = np.tile(last_muOUd, 2) * amp

    # create monitors
    rateDE1 = PopulationRateMonitor(decE1)
    rateDE2 = PopulationRateMonitor(decE2)
    rateSE1 = PopulationRateMonitor(senE1)
    rateSE2 = PopulationRateMonitor(senE2)

    if sim_cfg['pltfig1']:
        # more monitors for the plot
        # inh
        rateDI = PopulationRateMonitor(decI)
        rateSI = PopulationRateMonitor(senI)

        # spk monitors
        subDE = len(decE1) * 2
        spksDE = SpikeMonitor(decE[:subDE])
        spksSE = SpikeMonitor(senE)

        monitors = [spksDE, rateDE1, rateDE2, rateDI,
                    spksSE, rateSE1, rateSE2, rateSI]
    else:
        subSE = len(senE1)
        # last 100 of SE1 and first 100 of SE2
        spksSE = SpikeMonitor(senE[subSE - 100:subSE + 100])

        monitors = [rateDE1, rateDE2, rateSE1, rateSE2, spksSE]

    # construct network (built once; the object order matches what the
    # original code passed for each of the two configurations)
    net = Network(Dgroups.values(), Dsynapses.values(),
                  Sgroups.values(), Ssynapses.values(),
                  synSE1DE1, synSE2DE2, synDE1SE1, synDE2SE2,
                  *monitors, name='hierarchicalnet')

    dend_mon = None  # only created for plastic dendrites
    if sim_cfg['plasticdend']:
        # create state monitor to follow muOUd and add it to the networks
        dend_mon = StateMonitor(dend1, variables=['muOUd', 'Ibg', 'g_ea', 'B'],
                                record=True, dt=1 * ms)
        net.add(dend_mon)

        # remove FB synapses!
        net.remove([synDE1SE1, synDE2SE2, Dsynapses.values()])
        print(" FB synapses and synapses of decision circuit are ignored in this simulation!")

    # run hierarchical net
    net.run(runtime, report='stdout', profile=True)
    print(profiling_summary(net=net, show=10))

    # nice plots on cluster
    if sim_cfg['pltfig1']:
        plot_fig1b([rateDE1, rateDE2, rateDI, spksDE,
                    rateSE1, rateSE2, rateSI, spksSE,
                    stim1, stim2, stimtime],
                   smoothwin, taskdir)

    # -------------------------------------
    # Burst quantification
    # -------------------------------------
    # placeholders returned when burst analysis is disabled; `isis` is
    # included so that building `results` never hits an unbound name
    events = np.zeros(1)
    bursts = np.zeros(1)
    singles = np.zeros(1)
    spikes = np.zeros(1)
    isis = np.zeros(1)
    last_muOUd = np.zeros(1)

    # neurometric params
    validburst = task_info['sen']['2c']['validburst']
    smoothwin_ = smoothwin / second

    if sim_cfg['burstanalysis']:
        # The monitor exists only for plastic dendrites; without it the
        # placeholder value of last_muOUd is kept.
        if sim_cfg['2cmodel'] and dend_mon is not None:
            # per-neuron mean over the last 1000 recorded samples
            last_muOUd = np.array(dend_mon.muOUd[:, -int(1e3):].mean(axis=1))

        if sim_cfg['plasticdend']:
            # calculate neurometric info per population
            events, bursts, singles, spikes, isis = spks2neurometric(
                spksSE, runtime, settletime, validburst,
                smoothwin=smoothwin_, raster=False)

            # plot & save weights after convergence
            eta0 = task_info['sen']['2c']['eta0']
            tauB = task_info['sen']['2c']['tauB']
            targetB = task_info['targetB']
            B0 = tauB * targetB
            tau_update = task_info['sen']['2c']['tau_update']
            eta = eta0 * tau_update / tauB
            plot_weights(dend_mon, events, bursts, spikes,
                         [targetB, B0, eta, tauB, tau_update, smoothwin_],
                         taskdir)
            plot_rasters(spksSE, bursts, targetB, isis, runtime_, taskdir)
        else:
            # calculate neurometric per neuron
            events, bursts, singles, spikes, isis = spks2neurometric(
                spksSE, runtime, settletime, validburst,
                smoothwin=smoothwin_, raster=True)
            plot_neurometric(events, bursts, spikes, stim1, stim2, stimtime,
                             (settletime_, runtime_), taskdir, smoothwin_)
            plot_isis(isis, bursts, events, (settletime_, runtime_), taskdir)

    # -------------------------------------
    # Choice selection
    # -------------------------------------
    # population rates and downsample
    originaltime = rateDE1.t / second
    # np.linspace requires an integer sample count; truncate explicitly
    n_interp = int(originaltime[-1] * 100)  # every 10 ms
    interptime = np.linspace(0, originaltime[-1], n_interp)

    rateDE = np.array([
        interpolate.interp1d(
            originaltime, mon.smooth_rate(window='flat', width=smoothwin)
        )(interptime)
        for mon in (rateDE1, rateDE2)
    ])
    rateSE = np.array([
        interpolate.interp1d(
            originaltime, mon.smooth_rate(window='flat', width=smoothwin)
        )(interptime)
        for mon in (rateSE1, rateSE2)
    ])

    # select the last half second of the stimulus
    newdt = runtime_ / rateDE.shape[1]
    settletimeidx = int(settletime_ / newdt)
    dec_ival = np.array([(stimoff_ - 0.5) / newdt, stimoff_ / newdt], dtype=int)
    who_wins = rateDE[:, dec_ival[0]:dec_ival[1]].mean(axis=1)

    # divide trls into preferred and non-preferred
    pref_msk = np.argmax(who_wins)
    # index of the losing population; with exactly two rows this selects the
    # same row that `~pref_msk` (negative indexing) used to select
    npref_msk = 1 - pref_msk
    poprates_dec = np.array([rateDE[pref_msk], rateDE[npref_msk]])  # 0: pref, 1: npref
    poprates_sen = np.array([rateSE[pref_msk], rateSE[npref_msk]])

    results = {
        'raw_data': {
            'poprates_dec': poprates_dec[:, settletimeidx:],
            'poprates_sen': poprates_sen[:, settletimeidx:],
            'pref_msk': np.array([pref_msk]),
            'last_muOUd': last_muOUd,
        },
        'sim_state': np.zeros(1),
        'computed': {
            'events': events,
            'bursts': bursts,
            'singles': singles,
            'spikes': spikes,
            'isis': np.array(isis),
        },
    }

    return results