Exemplo n.º 1
0
def run_net(tr):
    """Build, simulate and record the E/I spiking network described by ``tr``.

    The network is set up for Brian2's ``cpp_standalone`` device in the
    per-run build directory ``../builds/<v_idx>`` (zero-padded to four
    digits) and simulated in three phases:

    1. ``tr.sim.T1`` with all monitors active,
    2. ``tr.sim.T2`` with the spike monitors and state monitors switched off,
    3. ``tr.sim.T3`` with the state monitors switched back on (the excitatory
       and inhibitory spike monitors only if ``tr.spks_rec``).

    After ``device.build`` the parameter namespace, all monitor states and the
    ``turnover`` / ``spk_register`` text files found in the build directory
    are pickled into ``<build dir>/raw/``.

    Parameters
    ----------
    tr : trajectory-like object
        Run parameters. It exposes ``v_idx``, ``f_to_dict`` and
        ``f_add_result``; presumably a pypet trajectory — TODO confirm.

    Raises
    ------
    ValueError
        If ``tr.PInp_mode`` is neither ``'pool'`` nor ``'indep'``.
    """
    # A single build location is shared by set_device, device.build and the
    # post-processing at the end, so they cannot drift apart.
    build_dir = '../builds/%.4d' % (tr.v_idx)

    # build_on_run=False: code is generated across all three net.run() calls
    # and compiled/executed only by the explicit device.build() below.
    set_device('cpp_standalone', directory=build_dir, build_on_run=False)

    print("Started process with id ", str(tr.v_idx))

    # NOTE(review): T sums tr.T1..T3 while the runs below use tr.sim.T1..T3;
    # presumably these refer to the same parameters — confirm.
    T = tr.T1 + tr.T2 + tr.T3

    namespace = tr.netw.f_to_dict(short_names=True, fast_access=True)
    namespace['idx'] = tr.v_idx

    defaultclock.dt = tr.netw.sim.dt

    # ------------------ neuron groups ------------------
    GExc = NeuronGroup(N=tr.N_e,
                       model=tr.condlif_sig,
                       threshold=tr.nrnEE_thrshld,
                       reset=tr.nrnEE_reset,
                       namespace=namespace)
    GInh = NeuronGroup(N=tr.N_i,
                       model=tr.condlif_sig,
                       threshold='V > Vt',
                       reset='V=Vr_i',
                       namespace=namespace)

    # Fixed initial thresholds; initial potentials drawn uniformly between
    # Vr and Vt (excitatory draw first, then inhibitory).
    GExc.sigma, GInh.sigma = tr.sigma_e, tr.sigma_i
    GExc.Vt, GInh.Vt = tr.Vt_e, tr.Vt_i
    GExc.V = np.random.uniform(tr.Vr_e/mV, tr.Vt_e/mV, size=tr.N_e)*mV
    GInh.V = np.random.uniform(tr.Vr_i/mV, tr.Vt_i/mV, size=tr.N_i)*mV

    # ------------------ Poisson input ------------------
    # 'pool': one shared Poisson population randomly projecting (p=p_EPoi)
    #         onto both E and I cells.
    # 'indep': one private Poisson source per E cell and per I cell.
    if tr.PInp_mode == 'pool':
        PInp = PoissonGroup(tr.NPInp, rates=tr.PInp_rate, namespace=namespace)
        sPN = Synapses(target=GExc,
                       source=PInp,
                       model=tr.poisson_mod,
                       on_pre='ge_post += a_EPoi',
                       namespace=namespace)
        sPN_src, sPN_tar = generate_connections(N_tar=tr.N_e,
                                                N_src=tr.NPInp,
                                                p=tr.p_EPoi)
    elif tr.PInp_mode == 'indep':
        PInp = PoissonGroup(tr.N_e, rates=tr.PInp_rate, namespace=namespace)
        sPN = Synapses(target=GExc,
                       source=PInp,
                       model=tr.poisson_mod,
                       on_pre='ge_post += a_EPoi',
                       namespace=namespace)
        sPN_src, sPN_tar = range(tr.N_e), range(tr.N_e)
    else:
        # Previously this fell through and failed later with an unbound
        # local variable; fail early with a clear message instead.
        raise ValueError("Unknown PInp_mode %r; expected 'pool' or 'indep'"
                         % (tr.PInp_mode,))

    sPN.connect(i=sPN_src, j=sPN_tar)

    if tr.PInp_mode == 'pool':
        sPNInh = Synapses(target=GInh,
                          source=PInp,
                          model=tr.poisson_mod,
                          on_pre='ge_post += a_EPoi',
                          namespace=namespace)
        sPNInh_src, sPNInh_tar = generate_connections(N_tar=tr.N_i,
                                                      N_src=tr.NPInp,
                                                      p=tr.p_EPoi)
    else:  # 'indep' (validated above)
        PInp_inh = PoissonGroup(tr.N_i,
                                rates=tr.PInp_rate,
                                namespace=namespace)
        sPNInh = Synapses(target=GInh,
                          source=PInp_inh,
                          model=tr.poisson_mod,
                          on_pre='ge_post += a_EPoi',
                          namespace=namespace)
        sPNInh_src, sPNInh_tar = range(tr.N_i), range(tr.N_i)

    sPNInh.connect(i=sPNInh_src, j=sPNInh_tar)

    # ------------------ recurrent synapses ------------------
    # TODO(review): an earlier revision printed "need to fix?" here; the
    # open question about mod.synEE_pre / mod.synEE_post is not visible in
    # this file.
    synEE_pre_mod = mod.synEE_pre
    synEE_post_mod = mod.synEE_post

    # Append the optional STDP / recording statements to the E->E code.
    if tr.stdp_active:
        synEE_pre_mod = '''%s 
                            %s''' % (synEE_pre_mod, mod.synEE_pre_STDP)
        synEE_post_mod = '''%s 
                            %s''' % (synEE_post_mod, mod.synEE_post_STDP)

    if tr.synEE_rec:
        synEE_pre_mod = '''%s 
                            %s''' % (synEE_pre_mod, mod.synEE_pre_rec)
        synEE_post_mod = '''%s 
                            %s''' % (synEE_post_mod, mod.synEE_post_rec)

    # E<-E advanced synapse model, rest simple
    SynEE = Synapses(target=GExc,
                     source=GExc,
                     model=tr.synEE_mod,
                     on_pre=synEE_pre_mod,
                     on_post=synEE_post_mod,
                     namespace=namespace)
    SynIE = Synapses(target=GInh,
                     source=GExc,
                     on_pre='ge_post += a_ie',
                     namespace=namespace)
    SynEI = Synapses(target=GExc,
                     source=GInh,
                     on_pre='gi_post += a_ei',
                     namespace=namespace)
    SynII = Synapses(target=GInh,
                     source=GInh,
                     on_pre='gi_post += a_ii',
                     namespace=namespace)

    # All potential E->E synapses are created up front and start inactive
    # (syn_active = 0). This used to be duplicated in two identical branches
    # on tr.strct_active.
    sEE_src, sEE_tar = generate_full_connectivity(tr.N_e, same=True)
    SynEE.connect(i=sEE_src, j=sEE_tar)
    SynEE.syn_active = 0

    sIE_src, sIE_tar = generate_connections(tr.N_i, tr.N_e, tr.p_ie)
    sEI_src, sEI_tar = generate_connections(tr.N_e, tr.N_i, tr.p_ei)
    sII_src, sII_tar = generate_connections(tr.N_i, tr.N_i, tr.p_ii, same=True)

    SynIE.connect(i=sIE_src, j=sIE_tar)
    SynEI.connect(i=sEI_src, j=sEI_tar)
    SynII.connect(i=sII_src, j=sII_tar)

    tr.f_add_result('sIE_src', sIE_src)
    tr.f_add_result('sIE_tar', sIE_tar)
    tr.f_add_result('sEI_src', sEI_src)
    tr.f_add_result('sEI_tar', sEI_tar)
    tr.f_add_result('sII_src', sII_src)
    tr.f_add_result('sII_tar', sII_tar)

    SynEE.a = tr.a_ee

    SynEE.insert_P = tr.insert_P
    SynEE.p_inactivate = tr.p_inactivate

    # Activation code with dt=T (the total simulated time), scheduled at
    # 'start' with a very early order, i.e. intended to run once at the
    # beginning of the simulation.
    SynEE.run_regularly(tr.synEE_p_activate, dt=T, when='start', order=-100)

    # ------------------ plasticity ------------------
    # synaptic scaling
    if tr.netw.config.scl_active:
        # Relies on Brian2 internals (private _clock attribute): the summed
        # variable Asum_post is updated on the scaling clock.
        SynEE.summed_updaters['Asum_post']._clock = Clock(
            dt=tr.dt_synEE_scaling)
        SynEE.run_regularly(tr.synEE_scaling,
                            dt=tr.dt_synEE_scaling,
                            when='end')

    # intrinsic plasticity
    if tr.netw.config.it_active:
        GExc.h_ip = tr.h_ip
        GExc.run_regularly(tr.intrinsic_mod, dt=tr.it_dt, when='end')

    # structural plasticity
    if tr.netw.config.strct_active:
        if tr.strct_mode == 'zero':
            if tr.turnover_rec:
                strct_mod = '''%s 
                                %s''' % (tr.strct_mod, tr.turnover_rec_mod)
            else:
                strct_mod = tr.strct_mod

            SynEE.run_regularly(strct_mod, dt=tr.strct_dt, when='end')

        elif tr.strct_mode == 'thrs':
            if tr.turnover_rec:
                strct_mod_thrs = '''%s 
                                %s''' % (tr.strct_mod_thrs,
                                         tr.turnover_rec_mod)
            else:
                strct_mod_thrs = tr.strct_mod_thrs

            SynEE.run_regularly(strct_mod_thrs, dt=tr.strct_dt, when='end')

    # ------------------ recording ------------------
    trace_flags = [(tr.memtraces_rec, 'V'),
                   (tr.vttraces_rec, 'Vt'),
                   (tr.getraces_rec, 'ge'),
                   (tr.gitraces_rec, 'gi')]
    GExc_recvars = [var for enabled, var in trace_flags if enabled]
    GInh_recvars = list(GExc_recvars)

    GExc_stat = StateMonitor(GExc,
                             GExc_recvars,
                             record=[0, 1, 2],
                             dt=tr.GExc_stat_dt)
    GInh_stat = StateMonitor(GInh,
                             GInh_recvars,
                             record=[0, 1, 2],
                             dt=tr.GInh_stat_dt)

    synee_flags = [(tr.synee_atraces_rec, 'a'),
                   (tr.synee_Apretraces_rec, 'Apre'),
                   (tr.synee_Aposttraces_rec, 'Apost')]
    SynEE_recvars = [var for enabled, var in synee_flags if enabled]

    SynEE_stat = StateMonitor(SynEE,
                              SynEE_recvars,
                              record=range(tr.n_synee_traces_rec),
                              when='end',
                              dt=tr.synEE_stat_dt)

    GExc_spks = SpikeMonitor(GExc)
    GInh_spks = SpikeMonitor(GInh)
    PInp_spks = SpikeMonitor(PInp)

    GExc_rate = PopulationRateMonitor(GExc)
    GInh_rate = PopulationRateMonitor(GInh)
    PInp_rate = PopulationRateMonitor(PInp)

    # NOTE(review): N_e*(N_e-1) assumes generate_full_connectivity(same=True)
    # creates every E->E pair except self-connections — confirm.
    SynEE_a = StateMonitor(SynEE, ['a', 'syn_active'],
                           record=range(tr.N_e * (tr.N_e - 1)),
                           dt=T / tr.synee_a_nrecpoints,
                           when='end',
                           order=100)

    net_objects = [GExc, GInh, PInp, sPN, sPNInh, SynEE, SynEI, SynIE,
                   SynII, GExc_stat, GInh_stat, SynEE_stat, SynEE_a,
                   GExc_spks, GInh_spks, PInp_spks, GExc_rate, GInh_rate,
                   PInp_rate]
    if tr.PInp_mode == 'indep':
        net_objects.append(PInp_inh)
    net = Network(*net_objects)

    # ------------------ simulation phases ------------------
    net.run(tr.sim.T1, report='text')

    # NOTE(review): the population-rate monitors are not switched off for
    # T2, so they keep recording through it; PInp_spks stays off after T2.
    for monitor in (GExc_spks, GInh_spks, PInp_spks,
                    SynEE_stat, GExc_stat, GInh_stat):
        monitor.active = False

    net.run(tr.sim.T2, report='text')

    for monitor in (SynEE_stat, GExc_stat, GInh_stat,
                    GExc_rate, GInh_rate, PInp_rate):
        monitor.active = True

    if tr.spks_rec:
        GExc_spks.active = True
        GInh_spks.active = True

    net.run(tr.sim.T3, report='text')

    device.build(directory=build_dir, clean=True)

    # ------------------ save raw data ------------------
    raw_dir = build_dir + '/raw/'
    os.makedirs(raw_dir, exist_ok=True)

    _pickle_to(raw_dir + 'namespace.p', namespace)

    for fname, monitor in (('gexc_stat.p', GExc_stat),
                           ('ginh_stat.p', GInh_stat),
                           ('synee_stat.p', SynEE_stat),
                           ('synee_a.p', SynEE_a),
                           ('gexc_spks.p', GExc_spks),
                           ('ginh_spks.p', GInh_spks),
                           ('pinp_spks.p', PInp_spks)):
        _pickle_to(raw_dir + fname, monitor.get_states())

    # Rate files hold two consecutive pickles: raw states, then the rate
    # smoothed with a 25 ms window.
    for fname, rate_mon in (('gexc_rate.p', GExc_rate),
                            ('ginh_rate.p', GInh_rate),
                            ('pinp_rate.p', PInp_rate)):
        _pickle_to(raw_dir + fname,
                   rate_mon.get_states(),
                   rate_mon.smooth_rate(width=25 * ms))

    # Text output written into the build directory by the generated code.
    fpath = build_dir + '/'
    _collect_text_output(fpath + 'turnover', raw_dir + 'turnover.p')
    _collect_text_output(fpath + 'spk_register', raw_dir + 'spk_register.p')


def _pickle_to(path, *objs):
    """Write each object in ``objs`` as consecutive pickles to ``path``."""
    with open(path, 'wb') as pfile:
        for obj in objs:
            pickle.dump(obj, pfile)


def _collect_text_output(src_path, dst_path):
    """Load comma-separated ``src_path``, delete it, and pickle the array.

    The file is touched first so that a missing file is read as empty
    instead of raising.
    """
    from pathlib import Path

    Path(src_path).touch()
    data = np.genfromtxt(src_path, delimiter=',')
    os.remove(src_path)
    _pickle_to(dst_path, data)
Exemplo n.º 2
0
def _check_exporter_single_run():
    """Export one 100 ms run and check the resulting run dictionary.

    Checks the run duration, the inactive monitor, the exported component
    types, and a literal-valued initializer that carries no identifiers.
    """
    start_scope()
    set_device('exporter')

    grp = NeuronGroup(10,
                      'dv/dt = (1-v)/tau :1',
                      method='exact',
                      threshold='v > 0.5',
                      reset='v = 0',
                      refractory=2 * ms)
    # tau (and the rate expression using it) are resolved from this frame
    # when the network runs, so defining them after the group is fine.
    tau = 10 * ms
    rate = '1/tau'
    grp.v['i > 2 and i < 5'] = -0.2
    pgrp = PoissonGroup(10, rates=rate)
    smon = SpikeMonitor(pgrp)
    smon.active = False
    netobj = Network(grp, pgrp, smon)
    netobj.run(100 * ms)
    dev_dict = device.runs
    # check the structure and components in dev_dict
    assert dev_dict[0]['duration'] == 100 * ms
    assert dev_dict[0]['inactive'][0] == smon.name
    components = dev_dict[0]['components']
    assert components['spikemonitor'][0]
    assert components['poissongroup'][0]
    assert components['neurongroup'][0]
    initializers = dev_dict[0]['initializers_connectors']
    assert initializers[0]['source'] == grp.name
    assert initializers[0]['variable'] == 'v'
    assert initializers[0]['index'] == 'i > 2 and i < 5'
    # TODO: why not a Quantity type?
    assert initializers[0]['value'] == '-0.2'
    with pytest.raises(KeyError):
        initializers[0]['identifiers']


def _check_exporter_multiple_runs():
    """Export three consecutive runs and check each run's dictionary.

    Identifier values must be captured at assignment time. Rebinding
    ``v0`` / ``v_new`` after an assignment must not change what was recorded.
    """
    start_scope()
    set_device('exporter', build_on_run=False)
    tau = 10 * ms
    v0 = -70 * mV
    vth = 800 * mV
    grp = NeuronGroup(10,
                      'dv/dt = (v0-v)/tau :volt',
                      method='exact',
                      threshold='v > vth',
                      reset='v = v0',
                      refractory=2 * ms)
    v0 = -80 * mV
    grp.v[:] = 'v0 + 2 * mV'
    smon = StateMonitor(grp, 'v', record=True)
    smon.active = False
    net = Network(grp, smon)
    net.run(10 * ms)  # first run
    v0 = -75 * mV
    grp.v[3:8] = list(range(3, 8)) * mV
    smon.active = True
    net.run(20 * ms)  # second run
    v_new = -5 * mV
    grp.v['i >= 5'] = 'v0 + v_new'
    v_new = -10 * mV
    grp.v['i < 5'] = 'v0 - v_new'
    spikemon = SpikeMonitor(grp)
    net.add(spikemon)
    net.run(5 * ms)  # third run
    dev_dict = device.runs
    # check run1
    assert dev_dict[0]['duration'] == 10 * ms
    assert dev_dict[0]['inactive'][0] == smon.name
    components = dev_dict[0]['components']
    assert components['statemonitor'][0]
    assert components['neurongroup'][0]
    initializers = dev_dict[0]['initializers_connectors']
    assert initializers[0]['source'] == grp.name
    assert initializers[0]['variable'] == 'v'
    assert initializers[0]['index']
    assert initializers[0]['value'] == 'v0 + 2 * mV'
    assert initializers[0]['identifiers']['v0'] == -80 * mV
    with pytest.raises(KeyError):
        initializers[0]['identifiers']['mV']
    # check run2
    assert dev_dict[1]['duration'] == 20 * ms
    initializers = dev_dict[1]['initializers_connectors']
    assert initializers[0]['source'] == grp.name
    assert initializers[0]['variable'] == 'v'
    assert (initializers[0]['index'] == grp.indices[slice(3, 8, None)]).all()
    assert (initializers[0]['value'] == list(range(3, 8)) * mV).all()
    # Each expected failure gets its own raises-block; a second statement
    # inside a raises-block never runs once the first one has raised.
    with pytest.raises(KeyError):
        dev_dict[1]['inactive']
    # Only one assignment (grp.v[3:8]) preceded run2, so there is no
    # second initializer entry.
    with pytest.raises(IndexError):
        initializers[1]
    # check run3
    assert dev_dict[2]['duration'] == 5 * ms
    with pytest.raises(KeyError):
        dev_dict[2]['inactive']
    assert dev_dict[2]['components']['spikemonitor']
    initializers = dev_dict[2]['initializers_connectors']
    assert initializers[0]['source'] == grp.name
    assert initializers[0]['variable'] == 'v'
    assert initializers[0]['index'] == 'i >= 5'
    assert initializers[0]['value'] == 'v0 + v_new'
    assert initializers[0]['identifiers']['v0'] == -75 * mV
    assert initializers[0]['identifiers']['v_new'] == -5 * mV
    assert initializers[1]['index'] == 'i < 5'
    assert initializers[1]['value'] == 'v0 - v_new'
    assert initializers[1]['identifiers']['v_new'] == -10 * mV
    # Two assignments before run3 -> no third initializer; three runs in
    # total -> no fourth run entry.
    with pytest.raises(IndexError):
        initializers[2]
    with pytest.raises(IndexError):
        dev_dict[3]


def test_ExportDevice_basic():
    """
    Test the components and structure of the dictionary exported
    by ExportDevice
    """
    # Reinitialise the device even when a check fails, so exporter state
    # does not leak into subsequent tests.
    try:
        _check_exporter_single_run()
    finally:
        device.reinit()
    try:
        _check_exporter_multiple_runs()
    finally:
        device.reinit()