return -s / tau_decay def update(ST, _t, pre): s = ints(ST['s'], _t) s += pre['spike'] ST['s'] = s ST['g'] = ST['w'] * s def output(ST, post): post['input'] += ST['g'] syn = bp.SynType(name='exponential_synapse', ST=bp.types.SynState(['s', 'g', 'w']), steps=(update, output), mode='scalar') # ------- # network # ------- group = bp.NeuGroup(neu, geometry=num_exc + num_inh, monitors=['spike']) group.ST['V'] = np.random.random(num_exc + num_inh) * (V_threshld - V_rest) + V_rest exc_conn = bp.SynConn(syn, pre_group=group[:num_exc], post_group=group, conn=bp.connect.FixedProb(prob=prob)) exc_conn.ST['w'] = JE
def get_voltage_jump(post_has_refractory=False, mode='vector'):
    """Voltage jump synapse.

    A pre-synaptic spike makes the post-synaptic membrane potential ``V``
    jump; the jump can optionally be masked while the post-synaptic neuron
    is refractory.

    .. math::

        I_{syn} = \\sum J \\delta(t-t_j)

    ST refers to synapse state, members of ST are listed below:

    =============== ================= =========================================================
    **Member name** **Initial Value** **Explanation**
    --------------- ----------------- ---------------------------------------------------------
    s               0.                Jump value that is added to the post-synaptic ``V``.
    =============== ================= =========================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Args:
        post_has_refractory (bool): whether the post-synaptic neuron has a
            ``refractory`` state member. If True, the jump is multiplied by
            ``(1 - refractory)``.
        mode (str): Data structure of ST members
            (``'scalar'``, ``'vector'`` or ``'matrix'``).

    Returns:
        bp.SynType.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """

    requires = dict(
        pre=bp.types.NeuState(['spike'])
    )

    if post_has_refractory:
        requires['post'] = bp.types.NeuState(['V', 'refractory'])

        @bp.delayed
        def output(ST, post):
            # No jump is delivered while the post-synaptic neuron is refractory.
            post['V'] += ST['s'] * (1. - post['refractory'])

    else:
        requires['post'] = bp.types.NeuState(['V'])

        @bp.delayed
        def output(ST, post):
            post['V'] += ST['s']

    if mode == 'vector':
        requires['pre2post'] = bp.types.ListConn()

        def update(ST, pre, post, pre2post):
            num_post = post['V'].shape[0]
            # One slot per post-synaptic neuron. ``np.zeros_like(num_post)``
            # would build a 0-d array (num_post is an int), which cannot be
            # indexed by ``post_ids`` below.
            # NOTE(review): ``s`` has one entry per post-synaptic neuron;
            # confirm that ST['s'] is expected to have that length here.
            s = np.zeros(num_post, dtype=np.float64)
            spike_idx = np.where(pre['spike'] > 0.)[0]
            for i in spike_idx:
                post_ids = pre2post[i]
                s[post_ids] = 1.
            ST['s'] = s

    elif mode == 'scalar':
        def update(ST, pre):
            ST['s'] = 0.
            if pre['spike'] > 0.:
                ST['s'] = 1.

    elif mode == 'matrix':
        # NOTE(review): this branch accumulates spikes into ST['s'] without
        # resetting it and does not use a connection matrix -- confirm that
        # this is the intended behaviour.
        def update(ST, pre):
            ST['s'] += pre['spike'].reshape((-1, 1))

    else:
        raise ValueError("BrainPy does not support mode '%s'." % (mode))

    return bp.SynType(name='voltage_jump_synapse',
                      ST=bp.types.SynState(['s']),
                      requires=requires,
                      steps=(update, output),
                      mode=mode)
def get_NMDA(g_max=0.15, E=0, alpha=0.062, beta=3.57,
             cc_Mg=1.2, tau_decay=100., a=0.5, tau_rise=2., mode='vector'):
    """NMDA conductance-based synapse.

    .. math::

        & I(t) = \\bar{g} s(t) (V-E) \\cdot g_{\\infty}

        & g_{\\infty}(V,[{Mg}^{2+}]) = (1+{e}^{-\\alpha V}
        \\frac{[{Mg}^{2+}]}{\\beta})^{-1}

        & \\frac{d s_{j}(t)}{dt} = -\\frac{s_{j}(t)}
        {\\tau_{decay}}+a x_{j}(t)(1-s_{j}(t))

        & \\frac{d x_{j}(t)}{dt} = -\\frac{x_{j}(t)}{\\tau_{rise}}+
        \\sum_{k} \\delta(t-t_{j}^{k})

    where the decay time of NMDA currents is taken to be :math:`\\tau_{decay}` =100 ms,
    :math:`a= 0.5 ms^{-1}`, and :math:`\\tau_{rise}` =2 ms

    ST refers to the synapse state, items in ST are listed below:

    =============== ================== =========================================================
    **Member name** **Initial values** **Explanation**
    --------------- ------------------ ---------------------------------------------------------
    s               0                  Gating variable.

    g               0                  Synapse conductance.

    x               0                  Gating variable.
    =============== ================== =========================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Args:
        g_max (float) : The maximum conductance.
        E (float) : The reversal potential.
        alpha (float) : Binding constant.
        beta (float) : Unbinding constant.
        cc_Mg (float) : concentration of Magnesium ion.
        tau_decay (float) : The time constant of decay.
        tau_rise (float) : The time constant of rise.
        a (float) : Coupling rate between ``x`` and ``s`` (see the equation
            for :math:`s_j` above).
        mode (str) : Data structure of ST members
            (``'scalar'``, ``'vector'`` or ``'matrix'``).

    Returns:
        bp.SynType.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.

    References:
        .. [1] Brunel N, Wang X J. Effects of neuromodulation in a
               cortical network model of object working memory dominated
               by recurrent inhibition[J].
               Journal of computational neuroscience, 2001, 11(1): 63-85.

    """

    @bp.integrate
    def int_x(x, _t):
        return -x / tau_rise

    @bp.integrate
    def int_s(s, _t, x):
        return -s / tau_decay + a * x * (1 - s)

    ST = bp.types.SynState({'s': 0., 'x': 0., 'g': 0.})

    requires = dict(
        pre=bp.types.NeuState(['spike']),
        post=bp.types.NeuState(['V', 'input'])
    )

    if mode == 'scalar':
        def update(ST, _t, pre):
            x = int_x(ST['x'], _t)
            x += pre['spike']
            s = int_s(ST['s'], _t, x)
            ST['x'] = x
            ST['s'] = s
            ST['g'] = g_max * s

        @bp.delayed
        def output(ST, post):
            I_syn = ST['g'] * (post['V'] - E)
            # Magnesium block: g_inf is the *inverse* of (1 + ...), so that
            # the current is suppressed at hyperpolarised potentials.
            g_inf = 1. / (1. + cc_Mg / beta * np.exp(-alpha * post['V']))
            post['input'] -= I_syn * g_inf

    elif mode == 'vector':
        requires['pre2syn'] = bp.types.ListConn(
            help='Pre-synaptic neuron index -> synapse index')
        requires['post2syn'] = bp.types.ListConn(
            help='Post-synaptic neuron index -> synapse index')

        def update(ST, _t, pre, pre2syn):
            for pre_id in range(len(pre2syn)):
                if pre['spike'][pre_id] > 0.:
                    syn_ids = pre2syn[pre_id]
                    ST['x'][syn_ids] += 1.
            x = int_x(ST['x'], _t)
            s = int_s(ST['s'], _t, x)
            ST['x'] = x
            ST['s'] = s
            ST['g'] = g_max * s

        @bp.delayed
        def output(ST, post, post2syn):
            # Sum the conductance of all synapses converging on each
            # post-synaptic neuron.
            g = np.zeros(len(post2syn), dtype=np.float64)
            for post_id, syn_ids in enumerate(post2syn):
                g[post_id] = np.sum(ST['g'][syn_ids])
            I_syn = g * (post['V'] - E)
            # Magnesium block (see the docstring definition of g_inf).
            g_inf = 1. / (1. + cc_Mg / beta * np.exp(-alpha * post['V']))
            post['input'] -= I_syn * g_inf

    elif mode == 'matrix':
        requires['conn_mat'] = bp.types.MatConn()

        def update(ST, _t, pre, conn_mat):
            x = int_x(ST['x'], _t)
            x += pre['spike'].reshape((-1, 1)) * conn_mat
            s = int_s(ST['s'], _t, x)
            ST['x'] = x
            ST['s'] = s
            ST['g'] = g_max * s

        @bp.delayed
        def output(ST, post):
            g = np.sum(ST['g'], axis=0)
            I_syn = g * (post['V'] - E)
            # Magnesium block (see the docstring definition of g_inf).
            g_inf = 1. / (1. + cc_Mg / beta * np.exp(-alpha * post['V']))
            post['input'] -= I_syn * g_inf

    else:
        raise ValueError("BrainPy does not support mode '%s'." % (mode))

    return bp.SynType(name='NMDA_synapse',
                      ST=ST,
                      requires=requires,
                      steps=(update, output),
                      mode=mode)
def get_alpha2(g_max=.2, E=0., tau_decay=2.):
    """Exponentially decaying synapse, registered as ``'alpha_synapse'``.

    The gating variable ``s`` decays with time constant ``tau_decay`` and is
    incremented by one for every pre-synaptic spike. The weighted value
    ``g = w s`` of all synapses converging on a post-synaptic neuron is
    summed and added to that neuron's ``input``.

    .. math::

        g_{syn} (t) &= w s

        \\frac{d s}{d t}&=-\\frac{s}{\\tau_{decay}}+\\sum_{k} \\delta(t-t_{j}^{k})

    ST refers to the synapse state, items in ST are listed below:

    ================ ================== =========================================================
    **Member name**  **Initial values** **Explanation**
    ---------------- ------------------ ---------------------------------------------------------
    g                0                  Synapse output delivered to the post-synaptic neuron
                                        (``w * s``).

    s                0                  Gating variable.

    w                1                  Synapse weight.
    ================ ================== =========================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Args:
        g_max (float): The peak conductance change in µmho (µS).
            NOTE(review): currently not used by the step functions.
        E (float): The reversal potential for the synaptic current.
            NOTE(review): currently not used by the step functions.
        tau_decay (float): The time constant of decay.

    Returns:
        bp.SynType
    """

    ST = bp.types.SynState({'g': 0., 's': 0., 'w': 1.},
                           help='The conductance defined by exponential function.')

    requires = {
        'pre': bp.types.NeuState(
            ['spike'],
            help='pre-synaptic neuron state must have "spike"'),
        'post': bp.types.NeuState(
            ['input', 'V'],
            help='post-synaptic neuron state must include "input" and "V"'),
        'pre2syn': bp.types.ListConn(
            help='Pre-synaptic neuron index -> synapse index'),
        'post2syn': bp.types.ListConn(
            help='Post-synaptic neuron index -> synapse index'),
    }

    @bp.integrate
    def ints(s, t):
        return - s / tau_decay

    def update(ST, _t, pre, pre2syn):
        s = ints(ST['s'], _t)
        # Every spiking pre-synaptic neuron increments all of its synapses.
        for i in np.where(pre['spike'] > 0.)[0]:
            syn_ids = pre2syn[i]
            s[syn_ids] += 1.
        ST['s'] = s
        ST['g'] = ST['w'] * s

    def output(ST, post, post2syn):
        for post_id, syn_id in enumerate(post2syn):
            post['input'][post_id] += np.sum(ST['g'][syn_id])

    return bp.SynType(name='alpha_synapse',
                      requires=requires,
                      ST=ST,
                      steps=(update, output),
                      mode='vector')
def get_GABAb1(g_max=0.02, E=-95., k1=0.18, k2=0.034, k3=0.09, k4=0.0012,
               kd=100., T=0.5, T_duration=0.3, mode='vector'):
    """GABAb conductance-based synapse model(type 1).

    .. math::

        &\\frac{d[R]}{dt} = k_3 [T](1-[R])- k_4 [R]

        &\\frac{d[G]}{dt} = k_1 [R]- k_2 [G]

        I_{GABA_{B}} &=\\overline{g}_{GABA_{B}} (\\frac{[G]^{4}} {[G]^{4}+K_{d}}) (V-E_{GABA_{B}})

    - [G] is the concentration of activated G protein.
    - [R] is the fraction of activated receptor.
    - [T] is the transmitter concentration.

    ST refers to synapse state, members of ST are listed below:

    ================ ================= =========================================================
    **Member name**  **Initial Value** **Explanation**
    ---------------- ----------------- ---------------------------------------------------------
    R                0.                The fraction of activated receptor.

    G                0.                The concentration of activated G protein.

    g                0.                Synapse conductance on post-synaptic neuron.

    t_last_pre_spike -1e7              Last spike time stamp of pre-synaptic neuron
                                       (stored as ``t_last_spike`` in ``'scalar'`` mode).
    ================ ================= =========================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Args:
        g_max (float): Maximum synapse conductance.
        E (float): Reversal potential of synapse.
        k1 (float): Activating rate constant of G protein catalyzed
                    by activated GABAb receptor.
        k2 (float): De-activating rate constant of G protein.
        k3 (float): Activating rate constant of GABAb receptor.
        k4 (float): De-activating rate constant of GABAb receptor.
        kd (float): Constant :math:`K_d` in the conductance expression
                    ``g_max * G**4 / (G**4 + kd)``.
        T (float): Transmitter concentration when synapse is
                   triggered by a pre-synaptic spike.
        T_duration (float): Transmitter concentration duration time
                            after being triggered.
        mode (str): Data structure of ST members
                    (``'scalar'`` or ``'vector'``).

    Returns:
        bp.SynType: return description of GABAb synapse model.

    Raises:
        ValueError: if ``mode`` is ``'matrix'`` or any unknown value.

    References:
        .. [1] Gerstner, Wulfram, et al. Neuronal dynamics: From single
               neurons to networks and models of cognition.
               Cambridge University Press, 2014.
    """

    ST_scalar = bp.types.SynState(
        {
            'R': 0.,
            'G': 0.,
            'g': 0.,
            't_last_spike': -1e7,
        },
        help="GABAb synapse state")

    ST_vector = bp.types.SynState(
        {
            'R': 0.,
            'G': 0.,
            'g': 0.,
            't_last_pre_spike': -1e7
        },
        help="GABAb synapse state")

    requires_scalar = {
        'pre': bp.types.NeuState(
            ['spike'],
            help="Pre-synaptic neuron state must have 'spike' item"),
        'post': bp.types.NeuState(
            ['V', 'input'],
            help="Post-synaptic neuron state must have 'V' and 'input' item"),
    }

    requires_vector = dict(
        pre=bp.types.NeuState(
            ['spike'],
            help="Pre-synaptic neuron state must have 'spike' item"),
        post=bp.types.NeuState(
            ['V', 'input'],
            help="Post-synaptic neuron state must have 'V' and 'input' item"),
        pre2syn=bp.types.ListConn(
            help="Pre-synaptic neuron index -> synapse index"),
        post2syn=bp.types.ListConn(
            help="Post-synaptic neuron index -> synapse index"),
    )

    @bp.integrate
    def int_R(R, t, TT):
        return k3 * TT * (1 - R) - k4 * R

    @bp.integrate
    def int_G(G, t, R):
        return k1 * R - k2 * G

    if mode == 'scalar':
        def update(ST, _t, pre):
            if pre['spike'] > 0.:
                ST['t_last_spike'] = _t
            # Transmitter is present for T_duration after the last spike.
            TT = ((_t - ST['t_last_spike']) < T_duration) * T
            R = int_R(ST['R'], _t, TT)
            G = int_G(ST['G'], _t, R)
            ST['R'] = R
            ST['G'] = G
            ST['g'] = g_max * G**4 / (G**4 + kd)

        @bp.delayed
        def output(ST, _t, post):
            I_syn = ST['g'] * (post['V'] - E)
            post['input'] -= I_syn

        return bp.SynType(name='GABAb1_synapse',
                          ST=ST_scalar,
                          requires=requires_scalar,
                          steps=(update, output),
                          mode=mode)

    elif mode == 'vector':
        def update(ST, _t, pre, pre2syn):
            for pre_id in np.where(pre['spike'] > 0.)[0]:
                syn_ids = pre2syn[pre_id]
                ST['t_last_pre_spike'][syn_ids] = _t
            # Transmitter is present for T_duration after the last spike.
            TT = ((_t - ST['t_last_pre_spike']) < T_duration) * T
            R = int_R(ST['R'], _t, TT)
            G = int_G(ST['G'], _t, R)
            ST['R'] = R
            ST['G'] = G
            ST['g'] = g_max * G**4 / (G**4 + kd)

        @bp.delayed
        def output(ST, post, post2syn):
            # Sum the conductance of all synapses converging on each
            # post-synaptic neuron.
            post_cond = np.zeros(len(post2syn), dtype=np.float64)
            for post_id, syn_ids in enumerate(post2syn):
                post_cond[post_id] = np.sum(ST['g'][syn_ids])
            post['input'] -= post_cond * (post['V'] - E)

        return bp.SynType(name='GABAb1_synapse',
                          ST=ST_vector,
                          requires=requires_vector,
                          steps=(update, output),
                          mode=mode)

    elif mode == 'matrix':
        raise ValueError("mode of function '%s' can not be '%s'." % (sys._getframe().f_code.co_name, mode))

    else:
        raise ValueError("BrainPy does not support mode '%s'." % (mode))
ST['V'] = V neuron = bp.NeuType(name='COBA', ST=neu_ST, steps=neu_update, mode='scalar') def update1(pre, post, pre2post): for pre_id in range(len(pre2post)): if pre['spike'][pre_id] > 0.: post_ids = pre2post[pre_id] for i in post_ids: post['ge'][i] += we exc_syn = bp.SynType('exc_syn', steps=update1, ST=bp.types.SynState([]), mode='vector') def update2(pre, post, pre2post): for pre_id in range(len(pre2post)): if pre['spike'][pre_id] > 0.: post_ids = pre2post[pre_id] for i in post_ids: post['gi'][i] += wi inh_syn = bp.SynType('inh_syn', steps=update2, ST=bp.types.SynState([]), mode='vector')
def run_brianpy(num_neu, duration, device='cpu'):
    """Build a ``'COBA'`` network with BrainPy, run it, and time the run.

    One fifth of the neurons (``int(num_neu / 5)``) form the inhibitory
    population and the rest the excitatory one. Both populations project
    onto the whole group with a fixed connection probability of 0.02.

    Args:
        num_neu: Total number of neurons in the network.
        duration: Value passed to ``net.run`` as the simulation duration.
        device (str): Device name forwarded to ``bp.profile.set``.

    Returns:
        float: Wall-clock time in seconds spent inside ``net.run``.
    """
    num_inh = int(num_neu / 5)
    num_exc = num_neu - num_inh

    # ``dt`` is not defined in this function; it is read from the enclosing
    # (module) scope.
    bp.profile.set(jit=True, device=device, dt=dt)

    # Parameters
    taum = 20
    taue = 5
    taui = 10
    Vt = -50
    Vr = -60
    El = -60
    Erev_exc = 0.
    Erev_inh = -80.
    I = 20.
    we = 0.6  # excitatory synaptic weight (voltage)
    wi = 6.7  # inhibitory synaptic weight
    ref = 5.0

    # Per-neuron state: last spike time, membrane potential, spike flag and
    # the excitatory / inhibitory conductances.
    neu_ST = bp.types.NeuState({
        'sp_t': -1e7,
        'V': Vr,
        'spike': 0.,
        'ge': 0.,
        'gi': 0.
    })

    @bp.integrate
    def int_ge(ge, t):
        return -ge / taue

    @bp.integrate
    def int_gi(gi, t):
        return -gi / taui

    @bp.integrate
    def int_V(V, t, ge, gi):
        return (ge * (Erev_exc - V) + gi * (Erev_inh - V) + (El - V) + I) / taum

    def neu_update(ST, _t):
        # The conductances decay on every step, even during refractoriness.
        ST['ge'] = int_ge(ST['ge'], _t)
        ST['gi'] = int_gi(ST['gi'], _t)

        ST['spike'] = 0.
        # V is only integrated once ``ref`` has elapsed since the last spike.
        if (_t - ST['sp_t']) > ref:
            V = int_V(ST['V'], _t, ST['ge'], ST['gi'])
            ST['spike'] = 0.
            if V >= Vt:
                # Threshold crossing: reset V and record the spike time.
                ST['V'] = Vr
                ST['spike'] = 1.
                ST['sp_t'] = _t
            else:
                ST['V'] = V

    neuron = bp.NeuType(name='COBA', ST=neu_ST, steps=neu_update, mode='scalar')

    def syn_update1(pre, post, pre2post):
        # Each pre-synaptic spike adds ``we`` to ``ge`` of every target.
        for pre_id in range(len(pre2post)):
            if pre['spike'][pre_id] > 0.:
                post_ids = pre2post[pre_id]
                for i in post_ids:
                    post['ge'][i] += we

    exc_syn = bp.SynType('exc_syn', steps=syn_update1, ST=bp.types.SynState([]), mode='vector')

    def syn_update2(pre, post, pre2post):
        # Each pre-synaptic spike adds ``wi`` to ``gi`` of every target.
        for pre_id in range(len(pre2post)):
            if pre['spike'][pre_id] > 0.:
                post_ids = pre2post[pre_id]
                for i in post_ids:
                    post['gi'][i] += wi

    inh_syn = bp.SynType('inh_syn', steps=syn_update2, ST=bp.types.SynState([]), mode='vector')

    group = bp.NeuGroup(neuron, geometry=num_exc + num_inh)
    # Random initial membrane potentials drawn from N(-55, 5^2).
    group.ST['V'] = np.random.randn(num_exc + num_inh) * 5. - 55.

    # The first ``num_exc`` neurons are excitatory, the remainder inhibitory.
    exc_conn = bp.SynConn(exc_syn, pre_group=group[:num_exc], post_group=group,
                          conn=bp.connect.FixedProb(prob=0.02))

    inh_conn = bp.SynConn(inh_syn, pre_group=group[num_exc:], post_group=group,
                          conn=bp.connect.FixedProb(prob=0.02))

    net = bp.Network(group, exc_conn, inh_conn)

    # Only the ``net.run`` call itself is timed.
    t0 = time.time()
    net.run(duration)
    t = time.time() - t0
    print(f'BrainPy ({device}) used time {t} s.')
    return t
def get_GABAa2(g_max=0.04, E=-80., alpha=0.53, beta=0.18, T=1., T_duration=1., mode='vector'):
    """
    GABAa conductance-based synapse model (markov form).

    .. math::

        I_{syn}&= - \\bar{g}_{max} s (V - E)

        \\frac{d s}{d t}&=\\alpha[T](1-s) - \\beta s

    ST refers to synapse state, members of ST are listed below:

    ================ ================= =========================================================
    **Member name**  **Initial Value** **Explanation**
    ---------------- ----------------- ---------------------------------------------------------
    s                0.                Gating variable.

    g                0.                Synapse conductance on post-synaptic neuron.

    t_last_pre_spike -1e7              Last spike time stamp of pre-synaptic neuron.
    ================ ================= =========================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Args:
        g_max (float): Maximum synapse conductance.
        E (float): Reversal potential of synapse.
        alpha (float): Opening rate constant of ion channel.
        beta (float): Closing rate constant of ion channel.
        T (float): Transmitter concentration when synapse is
                   triggered by a pre-synaptic spike.
        T_duration (float): Transmitter concentration duration time
                            after being triggered.
        mode (str): Data structure of ST members (only ``'vector'``
                    is supported).

    Returns:
        bp.SynType: return description of GABAa synapse model.

    Raises:
        ValueError: if ``mode`` is anything other than ``'vector'``.

    References:
        .. [1] Destexhe, Alain, and Denis Paré. "Impact of network activity
               on the integrative properties of neocortical pyramidal neurons
               in vivo." Journal of neurophysiology 81.4 (1999): 1531-1547.
    """

    ST = bp.types.SynState({'s': 0., 'g': 0., 't_last_pre_spike': -1e7},
                           help="GABAa synapse state")

    requires = dict(
        pre=bp.types.NeuState(
            ['spike'],
            help="Pre-synaptic neuron state must have 'spike' item"),
        post=bp.types.NeuState(
            ['V', 'input'],
            help="Post-synaptic neuron state must have 'V' and 'input' item"),
        pre2syn=bp.types.ListConn(
            help="Pre-synaptic neuron index -> synapse index"),
        post2syn=bp.types.ListConn(
            help="Post-synaptic neuron index -> synapse index")
    )

    @bp.integrate
    def int_s(s, t, TT):
        return alpha * TT * (1 - s) - beta * s

    def update(ST, pre, pre2syn, _t):
        for pre_id in np.where(pre['spike'] > 0.)[0]:
            syn_ids = pre2syn[pre_id]
            ST['t_last_pre_spike'][syn_ids] = _t
        # Transmitter is present for T_duration after the last spike.
        TT = ((_t - ST['t_last_pre_spike']) < T_duration) * T
        s = int_s(ST['s'], _t, TT)
        ST['s'] = s
        ST['g'] = g_max * s

    @bp.delayed
    def output(ST, post, post2syn):
        # Sum the conductance of all synapses converging on each
        # post-synaptic neuron.
        post_cond = np.zeros(len(post2syn), dtype=np.float64)
        for post_id, syn_ids in enumerate(post2syn):
            post_cond[post_id] = np.sum(ST['g'][syn_ids])
        post['input'] -= post_cond * (post['V'] - E)

    if mode == 'scalar':
        raise ValueError("mode of function '%s' can not be '%s'." % (sys._getframe().f_code.co_name, mode))
    elif mode == 'vector':
        return bp.SynType(name='GABAa_synapse',
                          ST=ST,
                          requires=requires,
                          steps=[update, output],
                          mode='vector')
    elif mode == 'matrix':
        raise ValueError("mode of function '%s' can not be '%s'." % (sys._getframe().f_code.co_name, mode))
    else:
        raise ValueError("BrainPy does not support mode '%s'." % (mode))
def get_GABAa1(g_max=0.4, E=-80., tau_decay=6., mode='vector'):
    """
    GABAa conductance-based synapse model (differential form).

    .. math::

        I_{syn}&= - \\bar{g}_{max} s (V - E)

        \\frac{d s}{d t}&=-\\frac{s}{\\tau_{decay}}+\\sum_{k}\\delta(t-t_{j}^{k})

    ST refers to synapse state, members of ST are listed below:

    =============== ================= =========================================================
    **Member name** **Initial Value** **Explanation**
    --------------- ----------------- ---------------------------------------------------------
    s               0.                Gating variable.

    g               0.                Synapse conductance on post-synaptic neuron
                                      (``'vector'`` mode only).
    =============== ================= =========================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Args:
        g_max (float): Maximum synapse conductance.
        E (float): Reversal potential of synapse.
        tau_decay (float): Time constant of gating variable decay.
        mode (str): Data structure of ST members
                    (``'scalar'`` or ``'vector'``).

    Returns:
        bp.SynType: return description of GABAa synapse model.

    Raises:
        ValueError: if ``mode`` is ``'matrix'`` or any unknown value.

    References:
        .. [1] Gerstner, Wulfram, et al. Neuronal dynamics: From single
               neurons to networks and models of cognition.
               Cambridge University Press, 2014.
    """

    ST_vector = bp.types.SynState({'s': 0., 'g': 0.}, help="GABAa synapse state")
    ST_scalar = bp.types.SynState(['s'], help='GABAa synapse state.')

    requires_vector = dict(
        pre=bp.types.NeuState(
            ['spike'],
            help="Pre-synaptic neuron state must have 'spike' item"),
        post=bp.types.NeuState(
            ['V', 'input'],
            help="Post-synaptic neuron state must have 'V' and 'input' item"),
        pre2syn=bp.types.ListConn(
            help="Pre-synaptic neuron index -> synapse index"),
        post2syn=bp.types.ListConn(
            help="Post-synaptic neuron index -> synapse index")
    )

    requires_scalar = {
        'pre': bp.types.NeuState(
            ['spike'],
            help='Pre-synaptic neuron state must have "spike"'),
        'post': bp.types.NeuState(
            ['V', 'input'],
            help='Post-synaptic neuron state must include "V" and "input"')
    }

    @bp.integrate
    def int_s(s, t):
        return - s / tau_decay

    if mode == 'scalar':
        def update(ST, _t, pre):
            s = int_s(ST['s'], _t)
            s += pre['spike']
            ST['s'] = s

        @bp.delayed
        def output(ST, _t, post):
            I_syn = - g_max * ST['s'] * (post['V'] - E)
            post['input'] += I_syn

        return bp.SynType(name='GABAa_synapse',
                          ST=ST_scalar,
                          requires=requires_scalar,
                          steps=(update, output),
                          mode=mode)

    elif mode == 'vector':
        def update(ST, _t, pre, pre2syn):
            # Pass the current time to the integrator, as the other synapse
            # models do (the right-hand side itself does not depend on t).
            s = int_s(ST['s'], _t)
            for pre_id in np.where(pre['spike'] > 0.)[0]:
                syn_ids = pre2syn[pre_id]
                s[syn_ids] += 1
            ST['s'] = s
            ST['g'] = g_max * s

        @bp.delayed
        def output(ST, post, post2syn):
            # Sum the conductance of all synapses converging on each
            # post-synaptic neuron.
            post_cond = np.zeros(len(post2syn), dtype=np.float64)
            for post_id, syn_ids in enumerate(post2syn):
                post_cond[post_id] = np.sum(ST['g'][syn_ids])
            post['input'] -= post_cond * (post['V'] - E)

        return bp.SynType(name='GABAa_synapse',
                          ST=ST_vector,
                          requires=requires_vector,
                          steps=(update, output),
                          mode=mode)

    elif mode == 'matrix':
        raise ValueError("mode of function '%s' can not be '%s'." % (sys._getframe().f_code.co_name, mode))

    else:
        raise ValueError("BrainPy does not support mode '%s'." % (mode))
def get_NMDA(g_max=0.15, E=0, alpha=0.062, beta=3.57,
             cc_Mg=1.2, tau_decay=100., a=0.5, tau_rise=2., mode='vector'):
    """NMDA conductance-based synapse.

    .. math::

        & I_{syn} = g(t) (V-E_{syn})

        & g(t) = \\bar{g} \\cdot g_{\\infty} \\cdot \\sum_j s_j(t)

        & g_{\\infty}(V,[{Mg}^{2+}]_{o}) = (1+{e}^{-\\alpha V}
        \\frac{[{Mg}^{2+}]_{o}} {\\beta})^{-1}

        & \\frac{d s_{j}(t)}{dt} = -\\frac{s_{j}(t)}
        {\\tau_{decay}}+a x_{j}(t)(1-s_{j}(t))

        & \\frac{d x_{j}(t)}{dt} = -\\frac{x_{j}(t)}{\\tau_{rise}}+
        \\sum_{k} \\delta(t-t_{j}^{k})

    where the decay time of NMDA currents is taken to be :math:`\\tau_{decay}` =100 ms,
    :math:`a= 0.5 ms^{-1}`, and :math:`\\tau_{rise}` =2 ms (Hestrin et al., 1990 [1]_;
    Spruston et al., 1995 [2]_).

    ST refers to the synapse state, items in ST are listed below:

    =============== ================== =========================================================
    **Member name** **Initial values** **Explanation**
    --------------- ------------------ ---------------------------------------------------------
    s               0                  Gating variable.

    g               0                  Synapse conductance.

    x               0                  Gating variable.
    =============== ================== =========================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Args:
        g_max (float) : The maximum conductance.
        E (float) : The reversal potential.
        alpha (float) : Binding constant.
        beta (float) : Unbinding constant.
        cc_Mg (float) : concentration of Magnesium ion.
        tau_decay (float) : The time constant of decay.
        tau_rise (float) : The time constant of rise.
        a (float) : Coupling rate between ``x`` and ``s`` (see the equation
            for :math:`s_j` above).
        mode (str) : Data structure of ST members
            (``'scalar'``, ``'vector'`` or ``'matrix'``).

    Returns:
        bp.SynType.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.

    References:
        .. [1] Hestrin, S., et al. "Analysis of excitatory synaptic action
               in pyramidal cells using whole‐cell recording from rat hippocampal
               slices." The Journal of Physiology 422.1 (1990): 203-225.
        .. [2] Spruston, Nelson, Peter Jonas, and Bert Sakmann. "Dendritic
               glutamate receptor channels in rat hippocampal CA3 and CA1
               pyramidal neurons." The Journal of physiology 482.2 (1995): 325-352.
    """

    @bp.integrate
    def int_x(x, _t):
        return -x / tau_rise

    @bp.integrate
    def int_s(s, _t, x):
        return -s / tau_decay + a * x * (1 - s)

    ST = bp.types.SynState({'s': 0., 'x': 0., 'g': 0.})

    requires = dict(
        pre=bp.types.NeuState(['spike']),
        post=bp.types.NeuState(['V', 'input'])
    )

    if mode == 'scalar':
        def update(ST, _t, pre):
            x = int_x(ST['x'], _t)
            x += pre['spike']
            s = int_s(ST['s'], _t, x)
            ST['x'] = x
            ST['s'] = s
            ST['g'] = g_max * s

        @bp.delayed
        def output(ST, post):
            I_syn = ST['g'] * (post['V'] - E)
            # Magnesium block: g_inf is the *inverse* of (1 + ...), so that
            # the current is suppressed at hyperpolarised potentials.
            g_inf = 1. / (1. + cc_Mg / beta * np.exp(-alpha * post['V']))
            post['input'] -= I_syn * g_inf

    elif mode == 'vector':
        requires['pre2syn'] = bp.types.ListConn(
            help='Pre-synaptic neuron index -> synapse index')
        requires['post2syn'] = bp.types.ListConn(
            help='Post-synaptic neuron index -> synapse index')

        def update(ST, _t, pre, pre2syn):
            for pre_id in range(len(pre2syn)):
                if pre['spike'][pre_id] > 0.:
                    syn_ids = pre2syn[pre_id]
                    ST['x'][syn_ids] += 1.
            x = int_x(ST['x'], _t)
            s = int_s(ST['s'], _t, x)
            ST['x'] = x
            ST['s'] = s
            ST['g'] = g_max * s

        @bp.delayed
        def output(ST, post, post2syn):
            # Sum the conductance of all synapses converging on each
            # post-synaptic neuron.
            g = np.zeros(len(post2syn), dtype=np.float64)
            for post_id, syn_ids in enumerate(post2syn):
                g[post_id] = np.sum(ST['g'][syn_ids])
            I_syn = g * (post['V'] - E)
            # Magnesium block (see the docstring definition of g_inf).
            g_inf = 1. / (1. + cc_Mg / beta * np.exp(-alpha * post['V']))
            post['input'] -= I_syn * g_inf

    elif mode == 'matrix':
        requires['conn_mat'] = bp.types.MatConn()

        def update(ST, _t, pre, conn_mat):
            x = int_x(ST['x'], _t)
            x += pre['spike'].reshape((-1, 1)) * conn_mat
            s = int_s(ST['s'], _t, x)
            ST['x'] = x
            ST['s'] = s
            ST['g'] = g_max * s

        @bp.delayed
        def output(ST, post):
            g = np.sum(ST['g'], axis=0)
            I_syn = g * (post['V'] - E)
            # Magnesium block (see the docstring definition of g_inf).
            g_inf = 1. / (1. + cc_Mg / beta * np.exp(-alpha * post['V']))
            post['input'] -= I_syn * g_inf

    else:
        raise ValueError("BrainPy does not support mode '%s'." % (mode))

    return bp.SynType(name='NMDA_synapse',
                      ST=ST,
                      requires=requires,
                      steps=(update, output),
                      mode=mode)
def get_Oja(gamma = 0.005, w_max = 1., w_min = 0., mode = 'vector'):
    """
    Oja's learning rule.

    .. math::

        \\frac{d w_{ij}}{dt} = \\gamma(\\upsilon_i \\upsilon_j - w_{ij}\\upsilon_i ^ 2)

    ST refers to synapse state (note that Oja learning rule can be implemented as synapses),
    members of ST are listed below:

    ================ ================= =========================================================
    **Member name**  **Initial Value** **Explanation**
    ---------------- ----------------- ---------------------------------------------------------
    w                0.05              Synapse weight.

    output_save      0.                Temporary save synapse output value until post-synaptic
                                       neuron get the value after delay time.
    ================ ================= =========================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Args:
        gamma(float): Learning rate.
        w_max (float): Maximal possible synapse weight.
            NOTE(review): not referenced by the step functions below, so
            weights are currently not clipped -- confirm whether intended.
        w_min (float): Minimal possible synapse weight.
            NOTE(review): not referenced by the step functions below.
        mode (str): Data structure of ST members (only ``'vector'``
            is supported).

    Returns:
        bp.Syntype: return description of synapse with Oja's rule.

    Raises:
        ValueError: if ``mode`` is anything other than ``'vector'``.

    References:
        .. [1] Gerstner, Wulfram, et al. Neuronal dynamics: From single
               neurons to networks and models of cognition.
               Cambridge University Press, 2014.
    """

    ST = bp.types.SynState({'w': 0.05, 'output_save': 0.})

    requires = dict(
        pre = bp.types.NeuState(['r']),
        post = bp.types.NeuState(['r']),
        post2syn=bp.types.ListConn(),
        post2pre=bp.types.ListConn(),
    )

    @bp.integrate
    def int_w(w, _t, r_pre, r_post):
        # Right-hand side of Oja's rule for the weights onto one post neuron.
        dw = gamma * (r_post * r_pre - np.square(r_post) * w)
        return dw

    def update(ST, _t, pre, post, post2syn, post2pre):
        for post_id, post_r in enumerate(post['r']):
            # Synapses onto this post neuron and their pre-synaptic sources;
            # the two lists are used as if they were in the same order.
            syn_ids = post2syn[post_id]
            pre_ids = post2pre[post_id]
            pre_r = pre['r'][pre_ids]
            w = ST['w'][syn_ids]
            # Weighted pre-synaptic rate plus the current post-synaptic rate.
            output = np.dot(w, pre_r)
            output += post_r
            w = int_w(w, _t, pre_r, output)
            ST['w'][syn_ids] = w
            # The same value is stored on every synapse of this post neuron.
            ST['output_save'][syn_ids] = output

    @bp.delayed
    def output(ST, pre, post, post2syn):
        for post_id, _ in enumerate(post['r']):
            syn_ids = post2syn[post_id]
            # All synapses of a post neuron hold the same saved value, so the
            # first one is read.
            # NOTE(review): indexing ``syn_ids[0]`` presumes every post neuron
            # has at least one incoming synapse -- confirm.
            post['r'][post_id] += ST['output_save'][syn_ids[0]]

    if mode == 'scalar':
        raise ValueError("mode of function '%s' can not be '%s'." % (sys._getframe().f_code.co_name, mode))
    elif mode == 'vector':
        return bp.SynType(name='Oja_synapse',
                          ST=ST,
                          requires=requires,
                          steps=(update, output),
                          mode=mode)
    elif mode == 'matrix':
        raise ValueError("mode of function '%s' can not be '%s'." % (sys._getframe().f_code.co_name, mode))
    else:
        raise ValueError("BrainPy does not support mode '%s'." % (mode))
def get_gap_junction_lif(weight, k_spikelet=0.1, post_has_refractory=False, mode='scalar'):
    """
    synapse with gap junction.

    .. math::

        I_{syn} = w (V_{pre} - V_{post})

    ST refers to synapse state, members of ST are listed below:

    =============== ================= =========================================================
    **Member name** **Initial Value** **Explanation**
    --------------- ----------------- ---------------------------------------------------------
    w                0.                Synapse weights.

    spikelet         0.                conductance for post-synaptic neuron
    =============== ================= =========================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Args:
        weight (float): Synapse weights. Used by the ``'vector'`` mode step
            function; the ``'scalar'`` mode reads ``ST['w']`` instead.
        k_spikelet (float): Scale factor of the supra-threshold (spikelet)
            contribution.
        post_has_refractory (bool): whether the post-synaptic neuron has a
            ``refractory`` state member (only consulted in ``'vector'``
            mode). If True, the spikelet is multiplied by
            ``(1 - refractory)``.
        mode (str): Data structure of ST members
            (``'scalar'`` or ``'vector'``).

    Returns:
        bp.SynType

    Raises:
        ValueError: if ``mode`` is not one of the supported values.

    References:
        .. [1] Chow, Carson C., and Nancy Kopell.
                "Dynamics of spiking neurons with electrical coupling."
                Neural computation 12.7 (2000): 1643-1678.

    """

    ST = bp.types.SynState('w', 'spikelet')

    requires = dict(
        pre=bp.types.NeuState(['V', 'spike']),
        post=bp.types.NeuState(['V', 'input'])
    )

    if mode == 'scalar':
        def update(ST, pre, post):
            # gap junction sub-threshold
            post['input'] += ST['w'] * (pre['V'] - post['V'])
            # gap junction supra-threshold
            ST['spikelet'] = ST['w'] * k_spikelet * pre['spike']

        @bp.delayed
        def output(ST, post):
            post['V'] += ST['spikelet']

        steps = (update, output)

    elif mode == 'vector':
        # No trailing commas here: they would store one-element tuples
        # instead of the requirement objects.
        requires['pre2post'] = bp.types.ListConn(help='pre-to-post connection.')
        requires['pre_ids'] = bp.types.Array(dim=1, help='Pre-synaptic neuron indices.')

        # NOTE(review): the spikelet below is scaled by the pre-synaptic
        # voltage, whereas the scalar mode scales it by pre['spike'] --
        # confirm which one is intended.
        if post_has_refractory:
            requires['post'] = bp.types.NeuState(['V', 'input', 'refractory'])

            def update(ST, pre, post, pre2post):
                num_pre = len(pre2post)
                g_post = np.zeros_like(post['V'], dtype=np.float64)
                spikelet = np.zeros_like(post['V'], dtype=np.float64)
                for pre_id in range(num_pre):
                    post_ids = pre2post[pre_id]
                    pre_V = pre['V'][pre_id]
                    # Each target receives w * (V_pre - V_post) for its own
                    # V_post, accumulated over all pre-synaptic partners.
                    g_post[post_ids] += weight * (pre_V - post['V'][post_ids])
                    if pre['spike'][pre_id] > 0.:
                        spikelet[post_ids] += weight * k_spikelet * pre_V
                post['V'] += spikelet * (1. - post['refractory'])
                post['input'] += g_post

        else:
            requires['post'] = bp.types.NeuState(['V', 'input'])

            def update(ST, pre, post, pre2post):
                num_pre = len(pre2post)
                g_post = np.zeros_like(post['V'], dtype=np.float64)
                spikelet = np.zeros_like(post['V'], dtype=np.float64)
                for pre_id in range(num_pre):
                    post_ids = pre2post[pre_id]
                    pre_V = pre['V'][pre_id]
                    # Each target receives w * (V_pre - V_post) for its own
                    # V_post, accumulated over all pre-synaptic partners.
                    g_post[post_ids] += weight * (pre_V - post['V'][post_ids])
                    if pre['spike'][pre_id] > 0.:
                        spikelet[post_ids] += weight * k_spikelet * pre_V
                post['V'] += spikelet
                post['input'] += g_post

        steps = update

    else:
        raise ValueError("BrainPy does not support mode '%s'." % (mode))

    return bp.SynType(name='gap_junctin_synapse_for_LIF',
                      ST=ST,
                      requires=requires,
                      steps=steps,
                      mode=mode)
def get_gap_junction(mode='scalar'):
    """
    synapse with gap junction.

    .. math::

        I_{syn} = w (V_{pre} - V_{post})

    ST refers to synapse state, members of ST are listed below:

    =============== ================= =========================================================
    **Member name** **Initial Value** **Explanation**
    --------------- ----------------- ---------------------------------------------------------
    w                0.                Synapse weights.
    =============== ================= =========================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Args:
        mode (string): data structure of ST members
            (``'scalar'``, ``'vector'`` or ``'matrix'``).

    Returns:
        bp.SynType

    Raises:
        ValueError: if ``mode`` is not one of the supported values.

    Reference:
        .. [1] Chow, Carson C., and Nancy Kopell.
                "Dynamics of spiking neurons with electrical coupling."
                Neural computation 12.7 (2000): 1643-1678.

    """

    ST = bp.types.SynState(['w'])

    requires = dict(
        pre=bp.types.NeuState(['V']),
        post=bp.types.NeuState(['V', 'input'])
    )

    if mode == 'scalar':
        def update(ST, pre, post):
            post['input'] += ST['w'] * (pre['V'] - post['V'])

    elif mode == 'vector':
        requires['post2pre'] = bp.types.ListConn(help='post-to-pre connection.')
        requires['post2syn'] = bp.types.ListConn(help='post-to-synapse connection.')

        def update(ST, pre, post, post2pre, post2syn):
            num_post = len(post2pre)
            for post_id in range(num_post):
                # Pre-synaptic partners of this post neuron and the synapses
                # that connect them.
                # NOTE(review): assumes post2pre[i] and post2syn[i] list the
                # same connections in the same order -- confirm.
                pre_ids = post2pre[post_id]
                syn_ids = post2syn[post_id]
                post['input'][post_id] += np.sum(
                    ST['w'][syn_ids] * (pre['V'][pre_ids] - post['V'][post_id]))

    elif mode == 'matrix':
        requires['conn_mat'] = bp.types.MatConn()

        def update(ST, pre, post, conn_mat):
            # reshape: tile both potentials to the shape of the weight matrix
            dim = np.shape(ST['w'])
            v_post = np.vstack((post['V'],) * dim[0])
            v_pre = np.vstack((pre['V'],) * dim[1]).T
            # update: reduce over the pre-synaptic axis so that every post
            # neuron receives the sum of its incoming currents
            post['input'] += np.sum(ST['w'] * (v_pre - v_post) * conn_mat, axis=0)

    else:
        raise ValueError("BrainPy does not support mode '%s'." % (mode))

    return bp.SynType(name='gap_junction_synapse',
                      ST=ST,
                      requires=requires,
                      steps=update,
                      mode=mode)
def get_BCM(learning_rate=0.01, w_max=2., w_min=0., r_0=0., mode='matrix'):
    """Bienenstock-Cooper-Munro (BCM) rule.

    .. math::

        r_i = \\sum_j w_{ij} r_j

        \\frac d{dt} w_{ij} = \\eta \\cdot r_i (r_i - r_{\\theta}) r_j

    where :math:`\\eta` is some learning rate, and :math:`r_{\\theta}` is the
    plasticity threshold, which is a function of the averaged postsynaptic
    rate, we take:

    .. math::

        r_{\\theta} = < r_i >

    ST refers to synapse state (note that BCM learning rule can be implemented
    as synapses), members of ST are listed below:

    ================ ================= =========================================================
    **Member name**  **Initial Value** **Explanation**
    ---------------- ----------------- ---------------------------------------------------------
    w                1.                Synapse weights.
    dwdt             0.                Most recent weight derivative.
    ================ ================= =========================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Args:
        learning_rate (float): learning rate of the synapse weights.
        w_max (float): Maximum of the synapse weights.
        w_min (float): Minimum of the synapse weights.
        r_0 (float): Minimal plasticity threshold.
        mode (str): Data structure of ST members (``'vector'`` or ``'matrix'``).

    Returns:
        bp.SynType: return description of BCM rule.

    Raises:
        ValueError: if ``mode`` is ``'scalar'`` or an unknown value.

    References:
        .. [1] Gerstner, Wulfram, et al. Neuronal dynamics: From single
               neurons to networks and models of cognition. Cambridge
               University Press, 2014.
    """
    ST = bp.types.SynState({'w': 1., 'dwdt': 0.}, help='BCM synapse state.')

    requires = dict(
        pre=bp.types.NeuState(
            ['r'], help='Pre-synaptic neuron state must have "r" item.'),
        post=bp.types.NeuState(
            ['r'], help='Post-synaptic neuron state must have "r" item.'),
        r_th=bp.types.Array(dim=1),
        post_r=bp.types.Array(dim=1),
        sum_post_r=bp.types.Array(dim=1)
    )

    @bp.integrate
    def int_w(w, t, r_pre, r_post, r_th):
        dwdt = learning_rate * r_post * (r_post - r_th) * r_pre
        # The second element is handed back to the caller next to the
        # integrated weight (unpacked as `w, dw` below).
        return (dwdt,), dwdt

    if mode == 'scalar':
        # The function name is spelled out so the message does not rely on
        # `sys` being imported or on the private frame-inspection API.
        raise ValueError("mode of function '%s' can not be '%s'."
                         % ('get_BCM', mode))

    elif mode == 'vector':
        requires['post2syn'] = bp.types.ListConn()
        requires['post2pre'] = bp.types.ListConn()

        def learn(ST, _t, pre, post, post2syn, post2pre, r_th, sum_post_r, post_r):
            for post_id, post_r_i in enumerate(post['r']):
                if post_r_i < r_0:
                    post_r[post_id] = post_r_i
                elif post2syn[post_id].size > 0 and post2pre[post_id].size > 0:
                    # mapping
                    syn_ids = post2syn[post_id]
                    pre_ids = post2pre[post_id]
                    pre_r = pre['r'][pre_ids]
                    w = ST['w'][syn_ids]

                    # threshold: running mean of the post-synaptic rate
                    sum_post_r[post_id] += post_r_i
                    r_threshold = sum_post_r[post_id] / (_t / bp.profile._dt + 1)
                    r_th[post_id] = r_threshold

                    # BCM rule
                    w, dw = int_w(w, _t, pre_r, post_r_i, r_th[post_id])
                    w = np.clip(w, w_min, w_max)
                    ST['w'][syn_ids] = w
                    ST['dwdt'][syn_ids] = dw

                    # output
                    next_post_r = np.dot(w, pre_r)
                    post_r[post_id] = next_post_r

        @bp.delayed
        def output(post, post_r):
            post['r'] = post_r

    elif mode == 'matrix':
        requires['conn_mat'] = bp.types.MatConn()

        def learn(ST, _t, pre, post, conn_mat, r_th, sum_post_r, post_r):
            post_r_i = post['r']
            w = ST['w'] * conn_mat
            pre_r = pre['r']

            # threshold: written in place (like the vector branch) so the
            # shared `r_th` array is actually updated, not a local name.
            sum_post_r += post_r_i
            r_th[:] = sum_post_r / (_t / bp.profile._dt + 1)

            # BCM rule on the (num_pre, num_post) weight grid
            dim = np.shape(w)
            reshape_th = np.vstack((r_th,) * dim[0])
            reshape_post = np.vstack((post_r_i,) * dim[0])
            reshape_pre = np.vstack((pre_r,) * dim[1]).T
            w, dw = int_w(w, _t, reshape_pre, reshape_post, reshape_th)
            w = np.clip(w, w_min, w_max)
            ST['w'] = w
            ST['dwdt'] = dw

            # output: in-place so `output` below sees the new rates.
            post_r[:] = np.dot(w.T, pre_r)

        @bp.delayed
        def output(post, post_r):
            post['r'] = post_r

    else:
        raise ValueError("BrainPy does not support mode '%s'." % (mode))

    return bp.SynType(name='BCM_synapse',
                      ST=ST,
                      requires=requires,
                      steps=[learn, output],
                      mode=mode)
def get_AMPA2(g_max=0.42, E=0., alpha=0.98, beta=0.18, T=0.5, T_duration=0.5, mode='vector'):
    """AMPA conductance-based synapse (type 2).

    .. math::

        I_{syn}&=\\bar{g}_{syn} s (V-E_{syn})

        \\frac{ds}{dt} &=\\alpha[T](1-s)-\\beta s

    ST refers to the synapse state, items in ST are listed below:

    ================ ================== =========================================================
    **Member name**  **Initial values** **Explanation**
    ---------------- ------------------ ---------------------------------------------------------
    s                 0                  Gating variable.

    g                 0                  Synapse conductance.

    t_last_pre_spike  -1e7               Last spike time stamp of the pre-synaptic neuron.
    ================ ================== =========================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Args:
        g_max (float): Maximum conductance in µmho (µS).
        E (float): The reversal potential for the synaptic current.
        alpha (float): Binding constant.
        beta (float): Unbinding constant.
        T (float): Neurotransmitter binding coefficient.
        T_duration (float): Duration of the binding of neurotransmitter.
        mode (str): Data structure of ST members
            (``'scalar'``, ``'vector'`` or ``'matrix'``).

    Returns:
        bp.SynType

    Raises:
        ValueError: if ``mode`` is not one of the supported values.

    References:
        .. [1] Vijayan S, Kopell N J. Thalamic model of awake alpha oscillations
                and implications for stimulus processing[J]. Proceedings of the
                National Academy of Sciences, 2012, 109(45): 18553-18558.
    """

    @bp.integrate
    def int_s(s, _t, TT):
        return alpha * TT * (1 - s) - beta * s

    ST = bp.types.SynState(
        {'s': 0., 't_last_pre_spike': -1e7, 'g': 0.},
        help='AMPA synapse state.\n'
             '"s": Synaptic state.\n'
             '"t_last_pre_spike": Pre-synaptic neuron spike time.')

    requires = dict(
        pre=bp.types.NeuState(
            ['spike'], help='Pre-synaptic neuron state must have "spike" item.'),
        post=bp.types.NeuState(
            ['V', 'input'], help='Post-synaptic neuron state must have "V" and "input" item.'))

    if mode == 'scalar':
        def update(ST, _t, pre):
            if pre['spike'] > 0.:
                ST['t_last_pre_spike'] = _t
            # Transmitter is present for T_duration after the last spike.
            TT = ((_t - ST['t_last_pre_spike']) < T_duration) * T
            s = np.clip(int_s(ST['s'], _t, TT), 0., 1.)
            ST['s'] = s
            ST['g'] = g_max * s

        @bp.delayed
        def output(ST, post):
            post_val = -ST['g'] * (post['V'] - E)
            post['input'] += post_val

    elif mode == 'vector':
        requires['pre2syn'] = bp.types.ListConn(
            help='Pre-synaptic neuron index -> synapse index')
        requires['post2syn'] = bp.types.ListConn(
            help='Post-synaptic neuron index -> synapse index')

        def update(ST, _t, pre, pre2syn):
            for i in np.where(pre['spike'] > 0.)[0]:
                syn_idx = pre2syn[i]
                ST['t_last_pre_spike'][syn_idx] = _t
            TT = ((_t - ST['t_last_pre_spike']) < T_duration) * T
            s = np.clip(int_s(ST['s'], _t, TT), 0., 1.)
            ST['s'] = s
            ST['g'] = g_max * s

        @bp.delayed
        def output(ST, post, post2syn):
            # np.float64 instead of np.float_, which NumPy 2.0 removed.
            g = np.zeros(len(post2syn), dtype=np.float64)
            for post_id, syn_ids in enumerate(post2syn):
                g[post_id] = np.sum(ST['g'][syn_ids])
            post['input'] -= g * (post['V'] - E)

    elif mode == 'matrix':
        requires['conn_mat'] = bp.types.MatConn()

        def update(ST, _t, pre, conn_mat):
            spike_idxs = np.where(pre['spike'] > 0.)[0]
            ST['t_last_pre_spike'][spike_idxs] = _t
            # Mask with conn_mat so that matrix entries without a synapse
            # never see transmitter and therefore keep s = g = 0.
            TT = ((_t - ST['t_last_pre_spike']) < T_duration) * T * conn_mat
            s = np.clip(int_s(ST['s'], _t, TT), 0., 1.)
            ST['s'] = s
            ST['g'] = g_max * s

        @bp.delayed
        def output(ST, post):
            g = np.sum(ST['g'], axis=0)
            post['input'] -= g * (post['V'] - E)

    else:
        raise ValueError("BrainPy does not support mode '%s'." % (mode))

    return bp.SynType(name='AMPA_synapse',
                      ST=ST,
                      requires=requires,
                      steps=(update, output),
                      mode=mode)
def get_STDP1(g_max=0.10, E=0., tau_decay=10., tau_s=10., tau_t=10.,
              w_min=0., w_max=20., delta_A_s=0.5, delta_A_t=0.5, mode='vector'):
    """Spike-time dependent plasticity (in differential form).

    .. math::

        \\frac{d A_{source}}{d t}&=-\\frac{A_{source}}{\\tau_{source}}

        \\frac{d A_{target}}{d t}&=-\\frac{A_{target}}{\\tau_{target}}

    After a pre-synaptic spike:

    .. math::

        g_{post}&= g_{post}+w

        A_{source}&= A_{source} + \\delta A_{source}

        w&= min([w-A_{target}]^+, w_{max})

    After a post-synaptic spike:

    .. math::

        A_{target}&= A_{target} + \\delta A_{target}

        w&= min([w+A_{source}]^+, w_{max})

    ST refers to synapse state (note that STDP learning rule can be implemented
    as synapses), members of ST are listed below:

    ================ ================= =========================================================
    **Member name**  **Initial Value** **Explanation**
    ---------------- ----------------- ---------------------------------------------------------
    A_s              0.                Source neuron trace.

    A_t              0.                Target neuron trace.

    g                0.                Synapse conductance on post-synaptic neuron.

    w                0.                Synapse weight.
    ================ ================= =========================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Args:
        g_max (float): Maximum conductance.
        E (float): Reversal potential.
        tau_decay (float): Time constant of decay.
        tau_s (float): Time constant of source neuron (i.e. pre-synaptic neuron)
        tau_t (float): Time constant of target neuron (i.e. post-synaptic neuron)
        w_min (float): Minimal possible synapse weight.
        w_max (float): Maximal possible synapse weight.
        delta_A_s (float): Change on source neuron traces elicited by
                           a source neuron spike.
        delta_A_t (float): Change on target neuron traces elicited by
                           a target neuron spike.
        mode (str): Data structure of ST members (``'scalar'`` or ``'vector'``).

    Returns:
        bp.SynType: return description of STDP.

    Raises:
        ValueError: if ``mode`` is ``'matrix'`` or an unknown value.

    References:
        .. [1] Stimberg, Marcel, et al. "Equation-oriented specification of neural models for
               simulations." Frontiers in neuroinformatics 8 (2014): 6.
    """
    # Reject unusable modes before building anything. The function name is
    # spelled out so the message needs neither `sys` nor frame inspection.
    if mode == 'matrix':
        raise ValueError("mode of function '%s' can not be '%s'."
                         % ('get_STDP1', mode))
    if mode != 'scalar' and mode != 'vector':
        raise ValueError("BrainPy does not support mode '%s'." % (mode))

    ST = bp.types.SynState({'A_s': 0., 'A_t': 0., 'g': 0., 'w': 0.},
                           help='STDP synapse state.')

    requires = dict(
        pre=bp.types.NeuState(
            ['spike'],
            help='Pre-synaptic neuron state must have "spike" item.'),
        post=bp.types.NeuState(
            ['V', 'input', 'spike'],
            help='Post-synaptic neuron state must have "V", "input" and "spike" item.'),
    )

    @bp.integrate
    def int_A_s(A_s, _t):
        return -A_s / tau_s

    @bp.integrate
    def int_A_t(A_t, _t):
        return -A_t / tau_t

    @bp.integrate
    def int_g(g, _t):
        return -g / tau_decay

    if mode == 'scalar':
        def my_relu(w):
            return w if w > 0 else 0

        def update(ST, _t, pre, post):
            A_s = int_A_s(ST['A_s'], _t)
            A_t = int_A_t(ST['A_t'], _t)
            g = int_g(ST['g'], _t)
            w = ST['w']
            if pre['spike']:
                g += ST['w']
                A_s = A_s + delta_A_s
                w = np.clip(my_relu(ST['w'] - A_t), w_min, w_max)
            if post['spike']:
                A_t = A_t + delta_A_t
                w = np.clip(my_relu(ST['w'] + A_s), w_min, w_max)
            ST['A_s'] = A_s
            ST['A_t'] = A_t
            ST['g'] = g
            ST['w'] = w

        @bp.delayed
        def output(ST, post):
            I_syn = -g_max * ST['g'] * (post['V'] - E)
            post['input'] += I_syn

    else:  # mode == 'vector'
        requires['pre2syn'] = bp.types.ListConn(
            help='Pre-synaptic neuron index -> synapse index')
        requires['post2syn'] = bp.types.ListConn(
            help='Post-synaptic neuron index -> synapse index')

        def update(ST, _t, pre, post, pre2syn, post2syn):
            A_s = int_A_s(ST['A_s'], _t)
            A_t = int_A_t(ST['A_t'], _t)
            g = int_g(ST['g'], _t)
            w = ST['w']
            for i in np.where(pre['spike'] > 0.)[0]:
                syn_ids = pre2syn[i]
                g[syn_ids] += ST['w'][syn_ids]
                A_s[syn_ids] = A_s[syn_ids] + delta_A_s
                # Use the freshly integrated target trace and rectify, as
                # the scalar branch and the docstring rule do.
                w[syn_ids] = np.clip(
                    np.maximum(ST['w'][syn_ids] - A_t[syn_ids], 0.),
                    w_min, w_max)
            for i in np.where(post['spike'] > 0.)[0]:
                syn_ids = post2syn[i]
                A_t[syn_ids] = A_t[syn_ids] + delta_A_t
                # Likewise use the current (possibly just-incremented)
                # source trace instead of the stale stored value.
                w[syn_ids] = np.clip(
                    np.maximum(ST['w'][syn_ids] + A_s[syn_ids], 0.),
                    w_min, w_max)
            ST['A_s'] = A_s
            ST['A_t'] = A_t
            ST['g'] = g
            ST['w'] = w

        @bp.delayed
        def output(ST, post, post2syn):
            # np.float64 instead of np.float_, which NumPy 2.0 removed.
            post_cond = np.zeros(len(post2syn), dtype=np.float64)
            for post_id, syn_ids in enumerate(post2syn):
                post_cond[post_id] = np.sum(
                    -g_max * ST['g'][syn_ids] * (post['V'][post_id] - E))
            post['input'] += post_cond

    return bp.SynType(name='STDP_synapse',
                      ST=ST,
                      requires=requires,
                      steps=(update, output),
                      mode=mode)
def get_AMPA1(g_max=0.10, E=0., tau_decay=2.0, mode='vector'):
    """AMPA conductance-based synapse (type 1).

    .. math::

        I(t)&=\\bar{g} s(t) (V-E_{syn})

        \\frac{d s}{d t}&=-\\frac{s}{\\tau_{decay}}+\\sum_{k} \\delta(t-t_{j}^{k})

    ST refers to the synapse state, items in ST are listed below:

    =============== ================== =========================================================
    **Member name** **Initial values** **Explanation**
    --------------- ------------------ ---------------------------------------------------------
    s                  0               Gating variable.

    g                  0               Synapse conductance.
    =============== ================== =========================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Args:
        g_max (float): Maximum conductance in µmho (µS).
        E (float): The reversal potential for the synaptic current.
        tau_decay (float): The time constant of decay.
        mode (str): Data structure of ST members
            (``'scalar'``, ``'vector'`` or ``'matrix'``).

    Returns:
        bp.SynType

    Raises:
        ValueError: if ``mode`` is not one of the supported values.

    References:
        .. [1] Brunel N, Wang X J. Effects of neuromodulation in a cortical network
                model of object working memory dominated by recurrent inhibition[J].
                Journal of computational neuroscience, 2001, 11(1): 63-85.
    """

    @bp.integrate
    def ints(s, _t):
        return -s / tau_decay

    ST = bp.types.SynState(['s', 'g'], help='AMPA synapse state.')

    requires = {
        'pre': bp.types.NeuState(
            ['spike'], help='Pre-synaptic neuron state must have "spike" item.'),
        'post': bp.types.NeuState(
            ['V', 'input'], help='Post-synaptic neuron state must have "V" and "input" item.')
    }

    if mode == 'scalar':
        def update(ST, _t, pre):
            s = ints(ST['s'], _t)
            s += pre['spike']
            ST['s'] = s
            ST['g'] = g_max * s

        @bp.delayed
        def output(ST, post):
            post_val = -ST['g'] * (post['V'] - E)
            post['input'] += post_val

    elif mode == 'vector':
        requires['pre2syn'] = bp.types.ListConn(
            help='Pre-synaptic neuron index -> synapse index')
        requires['post2syn'] = bp.types.ListConn(
            help='Post-synaptic neuron index -> synapse index')

        def update(ST, _t, pre, pre2syn):
            s = ints(ST['s'], _t)
            spike_idx = np.where(pre['spike'] > 0.)[0]
            for i in spike_idx:
                syn_idx = pre2syn[i]
                s[syn_idx] += 1.
            ST['s'] = s
            ST['g'] = g_max * s

        @bp.delayed
        def output(ST, post, post2syn):
            # np.float64 instead of np.float_, which NumPy 2.0 removed.
            g = np.zeros(len(post2syn), dtype=np.float64)
            for post_id, syn_ids in enumerate(post2syn):
                g[post_id] = np.sum(ST['g'][syn_ids])
            post['input'] -= g * (post['V'] - E)

    elif mode == 'matrix':
        requires['conn_mat'] = bp.types.MatConn()

        def update(ST, _t, pre, conn_mat):
            s = ints(ST['s'], _t)
            # Column vector of spikes broadcast across the post axis; the
            # connection mask keeps non-existent synapses at zero.
            s += pre['spike'].reshape((-1, 1)) * conn_mat
            ST['s'] = s
            ST['g'] = g_max * s

        @bp.delayed
        def output(ST, post):
            g = np.sum(ST['g'], axis=0)
            post['input'] -= g * (post['V'] - E)

    else:
        raise ValueError("BrainPy does not support mode '%s'." % (mode))

    return bp.SynType(name='AMPA_synapse',
                      ST=ST,
                      requires=requires,
                      steps=(update, output),
                      mode=mode)
def get_exponential(tau_decay=8., mode='scalar'):
    """Exponential decay synapse model.

    .. math::

        I_{syn}(t) &= w s (t)

        \\frac{d s}{d t}&=-\\frac{s}{\\tau_{decay}}+\\sum_{k} \\delta(t-t_{j}^{k})

    ST refers to synapse state, members of ST are listed below:

    ================ ================== =========================================================
    **Member name**  **Initial values** **Explanation**
    ---------------- ------------------ ---------------------------------------------------------
    s                  0                  Gating variable.

    w                  .1                 Synapse weights.

    g                  0                  Synapse conductance on the post-synaptic neuron.
    ================ ================== =========================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Args:
        tau_decay (float): The time constant of decay.
        mode (string): data structure of ST members
            (``'scalar'``, ``'vector'`` or ``'matrix'``).

    Returns:
        bp.SynType: return description of exponential synapse model.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.

    References:
        .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
                "The Synapse." Principles of Computational Modelling in Neuroscience.
                Cambridge: Cambridge UP, 2011. 172-95. Print.
    """

    @bp.integrate
    def ints(s, _t):
        return -s / tau_decay

    ST = bp.types.SynState({'s': 0., 'w': .1, 'g': 0.}, help='synapse state.')

    requires = {
        'pre': bp.types.NeuState(
            ['spike'], help='Pre-synaptic neuron state must have "spike" item.'),
        'post': bp.types.NeuState(
            ['V', 'input'], help='Post-synaptic neuron state must have "V" and "input" item.')
    }

    if mode == 'scalar':
        def update(ST, _t, pre):
            s = ints(ST['s'], _t)
            s += pre['spike']
            ST['s'] = s
            ST['g'] = ST['w'] * s

        @bp.delayed
        def output(ST, post):
            post['input'] += ST['g']

    elif mode == 'vector':
        requires['pre2syn'] = bp.types.ListConn(
            help='Pre-synaptic neuron index -> synapse index')
        requires['post2syn'] = bp.types.ListConn(
            help='Post-synaptic neuron index -> synapse index')

        def update(ST, _t, pre, pre2syn):
            s = ints(ST['s'], _t)
            spike_idx = np.where(pre['spike'] > 0.)[0]
            for i in spike_idx:
                syn_idx = pre2syn[i]
                s[syn_idx] += 1.
            ST['s'] = s
            ST['g'] = ST['w'] * s

        @bp.delayed
        def output(ST, post, post2syn):
            # np.float64 instead of np.float_, which NumPy 2.0 removed.
            g = np.zeros(len(post2syn), dtype=np.float64)
            for post_id, syn_ids in enumerate(post2syn):
                g[post_id] = np.sum(ST['g'][syn_ids])
            post['input'] += g

    elif mode == 'matrix':
        requires['conn_mat'] = bp.types.MatConn()

        def update(ST, _t, pre, conn_mat):
            s = ints(ST['s'], _t)
            # Column vector of spikes broadcast across the post axis; the
            # connection mask keeps non-existent synapses at zero.
            s += pre['spike'].reshape((-1, 1)) * conn_mat
            ST['s'] = s
            ST['g'] = ST['w'] * s

        @bp.delayed
        def output(ST, post):
            g = np.sum(ST['g'], axis=0)
            post['input'] += g

    else:
        raise ValueError("BrainPy does not support mode '%s'." % (mode))

    return bp.SynType(name='exponential_synapse',
                      ST=ST,
                      requires=requires,
                      steps=(update, output),
                      mode=mode)
def get_STP(U=0.15, tau_f=1500., tau_d=200., mode='vector'):
    """Short-term plasticity proposed by Tsodyks and Markram (Tsodyks 98) [1]_.

    The model is given by

    .. math::

        \\frac{du}{dt} = -\\frac{u}{\\tau_f}+U(1-u^-)\\delta(t-t_{spike})

        \\frac{dx}{dt} = \\frac{1-x}{\\tau_d}-u^+x^-\\delta(t-t_{spike})

    where :math:`t_{spike}` denotes the spike time and :math:`U` is the increment
    of :math:`u` produced by a spike.

    The synaptic current generated at the synapse by the spike arriving
    at :math:`t_{spike}` is then given by

    .. math::

        \\Delta I(t_{spike}) = Au^+x^-

    where :math:`A` denotes the response amplitude that would be produced
    by total release of all the neurotransmitter (:math:`u=x=1`), called
    absolute synaptic efficacy of the connections.

    ST refers to the synapse state, items in ST are listed below:

    =============== ================== =====================================================================
    **Member name** **Initial values** **Explanation**
    --------------- ------------------ ---------------------------------------------------------------------
    u                 0                 Release probability of the neurotransmitters.

    x                 1                 A Normalized variable denoting the fraction of remain neurotransmitters.

    w                 1                 Synapse weight.

    g                 0                 Synapse conductance.
    =============== ================== =====================================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Parameters
    ----------
    U : float
        The increment of :math:`u` produced by a spike.
    tau_f : float
        Time constant of short-term facilitation.
    tau_d : float
        Time constant of short-term depression.
    mode : str
        Data structure of ST members (``'scalar'``, ``'vector'`` or ``'matrix'``).

    Returns
    -------
    bp.SynType

    Raises
    ------
    ValueError
        If ``mode`` is not one of the supported values.

    References
    ----------

    .. [1] Tsodyks, Misha, Klaus Pawelzik, and Henry Markram. "Neural networks
           with dynamic synapses." Neural computation 10.4 (1998): 821-835.
    """

    @bp.integrate
    def int_u(u, _t):
        return -u / tau_f

    @bp.integrate
    def int_x(x, _t):
        return (1 - x) / tau_d

    ST = bp.types.SynState({'u': 0., 'x': 1., 'w': 1., 'g': 0.})

    requires = dict(pre=bp.types.NeuState(['spike']),
                    post=bp.types.NeuState(['V', 'input']))

    if mode == 'scalar':
        def update(ST, pre):
            # Neither right-hand side depends on time, so a constant 0 is
            # passed for `_t`.
            u = int_u(ST['u'], 0)
            x = int_x(ST['x'], 0)
            if pre['spike'] > 0.:
                u += U * (1 - ST['u'])
                x -= u * ST['x']
            ST['u'] = np.clip(u, 0., 1.)
            ST['x'] = np.clip(x, 0., 1.)
            ST['g'] = ST['w'] * ST['u'] * ST['x']

        @bp.delayed
        def output(ST, post):
            post['input'] += ST['g']

    elif mode == 'vector':
        requires['pre2syn'] = bp.types.ListConn(
            help='Pre-synaptic neuron index -> synapse index')
        requires['post2syn'] = bp.types.ListConn(
            help='Post-synaptic neuron index -> synapse index')

        def update(ST, pre, pre2syn):
            u = int_u(ST['u'], 0)
            x = int_x(ST['x'], 0)
            for pre_id in np.where(pre['spike'] > 0.)[0]:
                syn_ids = pre2syn[pre_id]
                u_syn = u[syn_ids] + U * (1 - ST['u'][syn_ids])
                u[syn_ids] = u_syn
                x[syn_ids] -= u_syn * ST['x'][syn_ids]
            ST['u'] = np.clip(u, 0., 1.)
            ST['x'] = np.clip(x, 0., 1.)
            ST['g'] = ST['w'] * ST['u'] * ST['x']

        @bp.delayed
        def output(ST, post, post2syn):
            # np.float64 instead of np.float_, which NumPy 2.0 removed.
            g = np.zeros(len(post2syn), dtype=np.float64)
            for post_id, syn_ids in enumerate(post2syn):
                g[post_id] = np.sum(ST['g'][syn_ids])
            post['input'] += g

    elif mode == 'matrix':
        requires['conn_mat'] = bp.types.MatConn()

        def update(ST, pre, conn_mat):
            u = int_u(ST['u'], 0)
            x = int_x(ST['x'], 0)
            spike_idxs = np.where(pre['spike'] > 0.)[0]
            u_syn = u[spike_idxs] + U * (1 - ST['u'][spike_idxs])
            u[spike_idxs] = u_syn
            x[spike_idxs] -= u_syn * ST['x'][spike_idxs]
            ST['u'] = np.clip(u, 0., 1.)
            ST['x'] = np.clip(x, 0., 1.)
            # Mask with conn_mat: whole rows are updated above, so entries
            # without a synapse would otherwise contribute conductance.
            ST['g'] = ST['w'] * ST['u'] * ST['x'] * conn_mat

        @bp.delayed
        def output(ST, post):
            g = np.sum(ST['g'], axis=0)
            post['input'] += g

    else:
        # Without this branch an unknown mode surfaced as an unbound
        # `update` name at the return statement.
        raise ValueError("BrainPy does not support mode '%s'." % (mode))

    return bp.SynType(name='STP_synapse',
                      ST=ST,
                      requires=requires,
                      steps=(update, output),
                      mode=mode)
else: ST['sp'] = 0. neuron = bp.NeuType(name='CUBA', ST=neu_ST, steps=neu_update, mode='scalar') def update1(pre, post, pre2post): for pre_id in range(len(pre2post)): if pre['sp'][pre_id] > 0.: post_ids = pre2post[pre_id] for i in post_ids: post['ge'][i] += we exc_syn = bp.SynType('exc_syn', steps=update1, ST=bp.types.SynState()) def update2(pre, post, pre2post): for pre_id in range(len(pre2post)): if pre['sp'][pre_id] > 0.: post_ids = pre2post[pre_id] for i in post_ids: post['gi'][i] += wi inh_syn = bp.SynType('inh_syn', steps=update2, ST=bp.types.SynState()) group = bp.NeuGroup(neuron, geometry=num_exc + num_inh, monitors=['sp'])
def get_alpha(g_max=.2, E=0., tau_decay=2.):
    """Alpha conductance-based synapse.

    .. math::

        I_{syn}(t) &= g_{syn} (t) (V(t)-E_{syn})

        g_{syn} (t) &= \\sum \\bar{g}_{syn} \\frac{t-t_f} {\\tau} exp(- \\frac{t-t_f}{\\tau})

    ST refers to the synapse state, items in ST are listed below:

    ================ ================== =========================================================
    **Member name**  **Initial values** **Explanation**
    ---------------- ------------------ ---------------------------------------------------------
    g                  0                  Synapse conductance on the post-synaptic neuron.

    t_last_pre_spike   -1e7               Last spike time stamp of the pre-synaptic neuron.
    ================ ================== =========================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Args:
        g_max (float): The peak conductance change in µmho (µS).
        E (float): The reversal potential for the synaptic current.
        tau_decay (float): The time constant of decay.

    Returns:
        bp.SynType: built with ``mode='vector'``.

    References:
        .. [1] Sterratt, David, Bruce Graham, Andrew Gillies, and David Willshaw.
                "The Synapse." Principles of Computational Modelling in Neuroscience.
                Cambridge: Cambridge UP, 2011. 172-95. Print.
    """
    ST = bp.types.SynState({'g': 0., 't_last_pre_spike': -1e7},
                           help='The conductance defined by exponential function.')

    requires = {
        'pre': bp.types.NeuState(
            ['spike'], help='pre-synaptic neuron state must have "spike"'),
        'post': bp.types.NeuState(
            ['input', 'V'], help='post-synaptic neuron state must include "input" and "V"'),
        'pre2syn': bp.types.ListConn(help='Pre-synaptic neuron index -> synapse index'),
        'post2syn': bp.types.ListConn(help='Post-synaptic neuron index -> synapse index'),
    }

    def update(ST, _t, pre, pre2syn):
        for pre_idx in np.where(pre['spike'] > 0.)[0]:
            syn_idx = pre2syn[pre_idx]
            ST['t_last_pre_spike'][syn_idx] = _t
        # Normalised time since the last pre-synaptic spike of each synapse.
        c = (_t - ST['t_last_pre_spike']) / tau_decay
        g = g_max * np.exp(-c) * c
        ST['g'] = g

    @bp.delayed
    def output(ST, post, post2syn):
        # np.float64 instead of np.float_, which NumPy 2.0 removed.
        I_syn = np.zeros(len(post2syn), dtype=np.float64)
        for post_id, syn_ids in enumerate(post2syn):
            # Use this neuron's own membrane potential; the full post['V']
            # array does not line up with the per-synapse conductances.
            I_syn[post_id] = np.sum(ST['g'][syn_ids] * (post['V'][post_id] - E))
        post['input'] -= I_syn

    return bp.SynType(name='alpha_synapse',
                      requires=requires,
                      ST=ST,
                      steps=(update, output),
                      mode='vector')
def test_hh(num, device):
    """Benchmark a conductance-based Hodgkin-Huxley network.

    Builds a network of ``num`` neurons (80% excitatory, 20% inhibitory)
    with fixed-probability connections, runs it for 1000 time units and
    prints the build and run wall-clock times.

    Args:
        num (int): Requested network size. The actual size is
            ``int(num * 0.8) + int(num * 0.2)``.
        device (str): Device name forwarded to ``bp.profile.set``.

    Returns:
        tuple: ``(run_time, build_time)`` in seconds.
    """
    print('Scale:{}, Model:HH, Device:{}, '.format(num, device), end='')
    st_build = time.time()

    bp.profile.set(jit=True, device=device, dt=0.1, numerical_method='exponential')

    num_exc = int(num * 0.8)
    num_inh = int(num * 0.2)
    num = num_exc + num_inh
    Cm = 200  # Membrane Capacitance [pF]

    gl = 10.  # Leak Conductance   [nS]
    El = -60.  # Resting Potential [mV]

    g_Na = 20. * 1000
    ENa = 50.  # reversal potential (Sodium) [mV]

    g_Kd = 6. * 1000  # K Conductance      [nS]
    EK = -90.  # reversal potential (Potassium) [mV]

    VT = -63.
    Vt = -20.
    # Time constants
    taue = 5.  # Excitatory synaptic time constant [ms]
    taui = 10.  # Inhibitory synaptic time constant [ms]
    # Reversal potentials
    Ee = 0.  # Excitatory reversal potential (mV)
    Ei = -80.  # Inhibitory reversal potential (Potassium) [mV]
    # excitatory synaptic weight
    we = 6.0 * np.sqrt(3200) / np.sqrt(num_exc)  # excitatory synaptic conductance [nS]
    # inhibitory synaptic weight
    wi = 67.0 * np.sqrt(800) / np.sqrt(num_inh)  # inhibitory synaptic conductance [nS]

    # Half-width of the voltage window around each removable singularity
    # of the rate functions below.
    inf = 0.05

    neu_ST = bp.types.NeuState('V', 'm', 'n', 'h', 'sp', 'ge', 'gi')

    @bp.integrate
    def int_ge(ge, t):
        return -ge / taue

    @bp.integrate
    def int_gi(gi, t):
        return -gi / taui

    @bp.integrate
    def int_m(m, t, V):
        a = 13.0 - V + VT
        b = V - VT - 40.0
        m_alpha = 0.32 * a / (exp(a / 4.) - 1.)
        m_beta = 0.28 * b / (exp(b / 5.) - 1.)
        dmdt = (m_alpha * (1. - m) - m_beta * m)
        return dmdt

    @bp.integrate
    def int_m_zeroa(m, t, V):
        b = V - VT - 40.0
        # Limit of 0.32 * a / (exp(a / 4) - 1) as a -> 0 is 0.32 * 4.
        m_alpha = 0.32 * 4.
        m_beta = 0.28 * b / (exp(b / 5.) - 1.)
        dmdt = (m_alpha * (1. - m) - m_beta * m)
        return dmdt

    @bp.integrate
    def int_m_zerob(m, t, V):
        a = 13.0 - V + VT
        m_alpha = 0.32 * a / (exp(a / 4.) - 1.)
        # Limit of 0.28 * b / (exp(b / 5) - 1) as b -> 0 is 0.28 * 5.
        m_beta = 0.28 * 5.
        dmdt = (m_alpha * (1. - m) - m_beta * m)
        return dmdt

    @bp.integrate
    def int_h(h, t, V):
        h_alpha = 0.128 * exp((17. - V + VT) / 18.)
        h_beta = 4. / (1. + exp(-(V - VT - 40.) / 5.))
        dhdt = (h_alpha * (1. - h) - h_beta * h)
        return dhdt

    @bp.integrate
    def int_n(n, t, V):
        c = 15. - V + VT
        n_alpha = 0.032 * c / (exp(c / 5.) - 1.)
        n_beta = .5 * exp((10. - V + VT) / 40.)
        dndt = (n_alpha * (1. - n) - n_beta * n)
        return dndt

    @bp.integrate
    def int_n_zero(n, t, V):
        # Limit of 0.032 * c / (exp(c / 5) - 1) as c -> 0 is 0.032 * 5.
        n_alpha = 0.032 * 5.
        n_beta = .5 * exp((10. - V + VT) / 40.)
        dndt = (n_alpha * (1. - n) - n_beta * n)
        return dndt

    @bp.integrate
    def int_V(V, t, m, h, n, ge, gi):
        g_na_ = g_Na * (m * m * m) * h
        g_kd_ = g_Kd * (n * n * n * n)
        dvdt = (gl * (El - V) + ge * (Ee - V) + gi * (Ei - V) -
                g_na_ * (V - ENa) - g_kd_ * (V - EK)) / Cm
        return dvdt

    def neu_update(ST, _t):
        ST['ge'] = int_ge(ST['ge'], _t)
        ST['gi'] = int_gi(ST['gi'], _t)
        # Near a 0/0 point of a rate function, switch to the integrator
        # that uses the analytic limit instead.
        if abs(ST['V'] - (40.0 + VT)) < inf:
            ST['m'] = int_m_zerob(ST['m'], _t, ST['V'])
        elif abs(ST['V'] - (13.0 + VT)) < inf:
            ST['m'] = int_m_zeroa(ST['m'], _t, ST['V'])
        else:
            ST['m'] = int_m(ST['m'], _t, ST['V'])
        ST['h'] = int_h(ST['h'], _t, ST['V'])
        if abs(ST['V'] - (15.0 + VT)) > inf:
            ST['n'] = int_n(ST['n'], _t, ST['V'])
        else:
            ST['n'] = int_n_zero(ST['n'], _t, ST['V'])
        V = int_V(ST['V'], _t, ST['m'], ST['h'], ST['n'], ST['ge'], ST['gi'])
        # A spike is an upward crossing of Vt within this step.
        sp = ST['V'] < Vt and V >= Vt
        ST['sp'] = sp
        ST['V'] = V

    neuron = bp.NeuType(name='CUBA-HH', ST=neu_ST, steps=neu_update, mode='scalar')

    requires_exc = {
        'pre': bp.types.NeuState(
            ['sp'], help='Pre-synaptic neuron state must have "sp" item.'),
        'post': bp.types.NeuState(
            ['ge'], help='Post-synaptic neuron state must have "ge" item.')
    }

    def update_syn_exc(ST, pre, post):
        if pre['sp']:
            post['ge'] += we

    exc_syn = bp.SynType(name='exc_syn',
                         ST=bp.types.SynState(),
                         requires=requires_exc,
                         steps=update_syn_exc,
                         mode='scalar')

    requires_inh = {
        'pre': bp.types.NeuState(
            ['sp'], help='Pre-synaptic neuron state must have "sp" item.'),
        'post': bp.types.NeuState(
            ['gi'], help='Post-synaptic neuron state must have "gi" item.')
    }

    def update_syn_inh(ST, pre, post):
        if pre['sp']:
            # gi multiplies (Ei - V) in int_V, so an inhibitory spike must
            # raise gi; subtracting would drive it negative and turn the
            # synapse excitatory.
            post['gi'] += wi

    inh_syn = bp.SynType(name='inh_syn',
                         ST=bp.types.SynState(),
                         requires=requires_inh,
                         steps=update_syn_inh,
                         mode='scalar')

    group = bp.NeuGroup(neuron, geometry=num)
    group.ST['V'] = El + (np.random.randn(num_exc + num_inh) * 5. - 5.)
    group.ST['ge'] = (np.random.randn(num_exc + num_inh) * 1.5 + 4.) * 10.
    group.ST['gi'] = (np.random.randn(num_exc + num_inh) * 12. + 20.) * 10.

    exc_conn = bp.SynConn(exc_syn,
                          pre_group=group[:num_exc],
                          post_group=group,
                          conn=bp.connect.FixedProb(prob=0.02))

    inh_conn = bp.SynConn(inh_syn,
                          pre_group=group[num_exc:],
                          post_group=group,
                          conn=bp.connect.FixedProb(prob=0.02))

    net = bp.Network(group, exc_conn, inh_conn)
    ed_build = time.time()

    st_run = time.time()
    net.run(duration=1000.0)
    ed_run = time.time()

    build_time = float(ed_build - st_build)
    run_time = float(ed_run - st_run)

    print('BuildT:{:.2f}s, RunT:{:.2f}s'.format(build_time, run_time))
    return run_time, build_time
post_ids = pre2post[pre_id] # post['ge'][post_ids] += we for p_id in post_ids: post['ge'][p_id] += we def inh_update(pre, post, pre2post): for pre_id in range(len(pre2post)): if pre['sp'][pre_id] > 0.: post_ids = pre2post[pre_id] # post['gi'][post_ids] += wi for p_id in post_ids: post['gi'][p_id] += wi exc_syn = bp.SynType('exc_syn', steps=exc_update, ST=bp.types.SynState()) inh_syn = bp.SynType('inh_syn', steps=inh_update, ST=bp.types.SynState()) group = bp.NeuGroup(neuron, size=num_exc + num_inh, monitors=['sp']) group.ST['V'] = El + (np.random.randn(num_exc + num_inh) * 5 - 5) group.ST['ge'] = (np.random.randn(num_exc + num_inh) * 1.5 + 4) * 10. group.ST['gi'] = (np.random.randn(num_exc + num_inh) * 12 + 20) * 10. exc_conn = bp.TwoEndConn(exc_syn, pre=group[:num_exc], post=group, conn=bp.connect.FixedProb(prob=0.02)) inh_conn = bp.TwoEndConn(inh_syn, pre=group[num_exc:],
def get_GABAb2(g_max=0.02, E=-95., k1=0.66, k2=0.02, k3=0.0053, k4=0.017,
               k5=8.3e-5, k6=7.9e-3, kd=100., T=0.5, T_duration=0.5,
               mode='vector'):
    """
    GABAb conductance-based synapse model (markov form).

    G-protein cascade occurs in the following steps:
    (i) the transmitter binds to the receptor, leading to its activated form;
    (ii) the activated receptor catalyzes the activation of G proteins;
    (iii) G proteins bind to open K+ channel, with n(=4) independent binding sites.

    .. math::

        &\\frac{d[D]}{dt}=K_{4}[R]-K_{3}[D]

        &\\frac{d[R]}{dt}=K_{1}[T](1-[R]-[D])-K_{2}[R]+K_{3}[D]

        &\\frac{d[G]}{dt}=K_{5}[R]-K_{6}[G]

        I_{GABA_{B}}&=\\bar{g}_{GABA_{B}} \\frac{[G]^{n}}{[G]^{n}+K_{d}}(V-E_{GABA_{B}})

    - [R] is the fraction of activated receptor.
    - [D] is the fraction of desensitized receptor.
    - [G] is the concentration of activated G-protein (μM).
    - [T] is the transmitter concentration.

    ST refers to synapse state, members of ST are listed below:

    ================ ================= =========================================================
    **Member name**  **Initial Value** **Explanation**
    ---------------- ----------------- ---------------------------------------------------------
    D                0.                The fraction of desensitized receptor.

    R                0.                The fraction of activated receptor.

    G                0.                The concentration of activated G protein.

    g                0.                Synapse conductance on post-synaptic neuron.

    t_last_pre_spike -1e7              Last spike time stamp of pre-synaptic neuron.
    ================ ================= =========================================================

    Note that all ST members are saved as floating point type in BrainPy,
    though some of them represent other data types (such as boolean).

    Args:
        g_max (float): Maximum synapse conductance.
        E (float): Reversal potential of synapse.
        k1 (float): Activating rate constant of GABAb receptor.
        k2 (float): De-activating rate constant of GABAb receptor.
        k3 (float): Activating rate constant of desensitized GABAb receptor.
        k4 (float): Desensitizing rate constant of activated GABAb receptor.
        k5 (float): Activating rate constant of G protein catalyzed by activated GABAb receptor.
        k6 (float): De-activating rate constant of activated G protein.
        kd (float): Dissociation constant of the binding of G protein on K+ channels.
        T (float): Transmitter concentration when synapse is triggered by a pre-synaptic spike.
        T_duration (float): Transmitter concentration duration time after being triggered.
        mode (str): Data structure of ST members. Only ``'vector'`` is implemented.

    Returns:
        bp.SynType: return description of GABAb synapse model.

    Raises:
        ValueError: if ``mode`` is ``'scalar'`` or ``'matrix'`` (not implemented
            for this model), or is any other value than ``'vector'``.

    References:
        .. [1] Destexhe, Alain, et al. "G-protein activation kinetics and spillover
               of GABA may account for differences between inhibitory responses in
               the hippocampus and thalamus." Proc. Natl. Acad. Sci. USA v92 (1995):
               9515-9519.
    """
    # Fail fast: validate `mode` before building any BrainPy object, so an
    # unsupported request costs nothing and cannot be masked by a later error.
    if mode in ('scalar', 'matrix'):
        raise ValueError("mode of function '%s' can not be '%s'."
                         % ('get_GABAb2', mode))
    if mode != 'vector':
        raise ValueError("BrainPy does not support mode '%s'." % (mode,))

    ST = bp.types.SynState(
        {
            'D': 0.,
            'R': 0.,
            'G': 0.,
            'g': 0.,
            't_last_pre_spike': -1e7
        },
        help="GABAb synapse state")

    requires = dict(
        pre=bp.types.NeuState(
            ['spike'],
            help="Pre-synaptic neuron state must have 'spike' item"),
        post=bp.types.NeuState(
            ['V', 'input'],
            help="Post-synaptic neuron state must have 'V' and 'input' item"),
        pre2syn=bp.types.ListConn(
            help="Pre-synaptic neuron index -> synapse index"),
        post2syn=bp.types.ListConn(
            help="Post-synaptic neuron index -> synapse index"),
    )

    @bp.integrate
    def int_D(D, t, R):
        # d[D]/dt = k4 [R] - k3 [D]
        return k4 * R - k3 * D

    @bp.integrate
    def int_R(R, t, TT, D):
        # d[R]/dt = k1 [T] (1 - [R] - [D]) - k2 [R] + k3 [D]
        return k1 * TT * (1 - R - D) - k2 * R + k3 * D

    @bp.integrate
    def int_G(G, t, R):
        # d[G]/dt = k5 [R] - k6 [G]
        return k5 * R - k6 * G

    def update(ST, _t, pre, pre2syn):
        # Stamp the current time on every synapse whose pre-synaptic neuron
        # spiked at this step.
        for pre_id in np.where(pre['spike'] > 0.)[0]:
            syn_ids = pre2syn[pre_id]
            ST['t_last_pre_spike'][syn_ids] = _t
        # Transmitter concentration is T while less than T_duration has
        # elapsed since the last pre-synaptic spike, and 0 otherwise.
        TT = ((_t - ST['t_last_pre_spike']) < T_duration) * T
        # R is integrated with the already-updated D, and G with the
        # already-updated R.
        D = int_D(ST['D'], _t, ST['R'])
        R = int_R(ST['R'], _t, TT, D)
        G = int_G(ST['G'], _t, R)
        ST['D'] = D
        ST['R'] = R
        ST['G'] = G
        # The stored value carries a negative sign; `output` multiplies it
        # by (V - E) before adding it to the post-synaptic input.
        ST['g'] = -g_max * (G**4 / (G**4 + kd))

    @bp.delayed
    def output(ST, post, post2syn):
        # np.float64 rather than the np.float_ alias, which NumPy 2.0 removed.
        post_cond = np.zeros(len(post2syn), dtype=np.float64)
        # Sum the values of all synapses that target each post-synaptic neuron.
        for post_id, syn_ids in enumerate(post2syn):
            post_cond[post_id] = np.sum(ST['g'][syn_ids])
        post['input'] += post_cond * (post['V'] - E)

    return bp.SynType(name='GABAb2_synapse',
                      ST=ST,
                      requires=requires,
                      steps=(update, output),
                      mode=mode)