def __init__(self, size, ENa=50., gNa=120., EK=-77., gK=36., EL=-54.387,
             gL=0.03, V_th=20., C=1.0, method='exp_auto', name=None):
    """Store the model parameters and allocate the gating variables.

    Parameters
    ----------
    size : int or sequence of int
        Population size, forwarded to the base class.
    ENa, EK, EL : float
        Sodium, potassium and leak parameters (presumably reversal
        potentials, judging by the ``E`` prefix).
    gNa, gK, gL : float
        Presumably the maximal sodium, potassium and leak conductances.
    V_th : float
        Stored as ``self.V_th``.
    C : float
        Stored as ``self.C`` (presumably the membrane capacitance).
    method : str
        Integration method name, forwarded to the base class.
    name : str, optional
        Instance name, forwarded to the base class.
    """
    super(HH, self).__init__(size=size, method=method, name=name)

    # Model parameters (tuple assignment keeps the original left-to-right order).
    self.ENa, self.EK, self.EL = ENa, EK, EL
    self.gNa, self.gK, self.gL = gNa, gK, gL
    self.C, self.V_th = C, V_th

    # Gating variables: one value per neuron, each with a fixed initial value.
    n_units = self.num
    self.m = bm.Variable(0.5 * bm.ones(n_units))
    self.h = bm.Variable(0.6 * bm.ones(n_units))
    self.n = bm.Variable(0.32 * bm.ones(n_units))
def run_integrator(method, show=False):
    """Integrate the Lorenz system with a fixed-step integrator.

    Relies on the module-level names ``f_lorenz``, ``dt`` and ``duration``.

    Parameters
    ----------
    method : callable
        Integrator factory, called as ``method(f_lorenz, dt=dt)``. The result
        is JIT-compiled with ``bm.jit``.
    show : bool
        If True, draw the 3-D trajectory with Matplotlib.

    Returns
    -------
    tuple of numpy.ndarray
        Flattened x, y and z values collected by ``bm.make_loop`` over the
        time grid ``np.arange(0, duration, dt)``.
    """
    f_integral = bm.jit(method(f_lorenz, dt=dt), auto_infer=False)
    x, y, z = bm.ones(1), bm.ones(1), bm.ones(1)

    def f(t):
        x.value, y.value, z.value = f_integral(x, y, z, t)

    f_scan = bm.make_loop(f, dyn_vars=[x, y, z], out_vars=[x, y, z])
    times = np.arange(0, duration, dt)
    mon_x, mon_y, mon_z = f_scan(times)
    mon_x = np.array(mon_x).flatten()
    mon_y = np.array(mon_y).flatten()
    mon_z = np.array(mon_z).flatten()

    if show:
        fig = plt.figure()
        # Figure.gca(projection=...) was deprecated in Matplotlib 3.4 and
        # removed in 3.6; add_subplot is the supported way to get 3-D axes.
        ax = fig.add_subplot(projection='3d')
        ax.plot(mon_x, mon_y, mon_z)
        # Label each axis (previously set_xlabel was called three times).
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')
        plt.show()
    return mon_x, mon_y, mon_z
def __init__(self, size, ENa=50., gNa=120., EK=-77., gK=36., EL=-54.387,
             gL=0.03, V_th=20., C=1.0, name=None):
    """Store parameters, allocate state and build per-variable integrators.

    Parameters
    ----------
    size : int or sequence of int
        Population size, forwarded to the base class and used directly for
        ``spike`` and ``input``.
    ENa, EK, EL : float
        Presumably sodium, potassium and leak reversal potentials.
    gNa, gK, gL : float
        Presumably maximal sodium, potassium and leak conductances.
    V_th : float
        Stored as ``self.V_th``.
    C : float
        Stored as ``self.C`` (presumably membrane capacitance).
    name : str, optional
        Instance name, forwarded to the base class.
    """
    super(HH, self).__init__(size=size, name=name)

    # parameters
    self.ENa, self.EK, self.EL = ENa, EK, EL
    self.C = C
    self.gNa, self.gK, self.gL = gNa, gK, gL
    self.V_th = V_th

    # state variables
    n_units = self.num
    self.V = bm.Variable(bm.ones(n_units) * -65.)
    self.m = bm.Variable(0.5 * bm.ones(n_units))
    self.h = bm.Variable(0.6 * bm.ones(n_units))
    self.n = bm.Variable(0.32 * bm.ones(n_units))
    self.spike = bm.Variable(bm.zeros(size, dtype=bool))
    self.input = bm.Variable(bm.zeros(size))

    # One ExpEulerAuto integrator per derivative method.
    self.int_h = bp.ode.ExpEulerAuto(self.dh)
    self.int_n = bp.ode.ExpEulerAuto(self.dn)
    self.int_m = bp.ode.ExpEulerAuto(self.dm)
    self.int_V = bp.ode.ExpEulerAuto(self.dV)
def run_integrator(method, show=False, tol=0.001, adaptive=True):
    """Integrate the Lorenz system with an (optionally adaptive) integrator.

    Relies on the module-level names ``f_lorenz``, ``duration`` and ``_dt``.

    Parameters
    ----------
    method : callable
        Integrator factory, called as
        ``method(f_lorenz, adaptive=adaptive, tol=tol, show_code=True)``.
    show : bool
        If True, plot the 3-D trajectory, then the step-size history.
    tol : float
        Tolerance forwarded to ``method``.
    adaptive : bool
        Forwarded to ``method``.

    Returns
    -------
    tuple of numpy.ndarray
        Flattened x, y, z and dt values collected by ``bm.make_loop``.
    """
    f_integral = method(f_lorenz, adaptive=adaptive, tol=tol, show_code=True)
    x, y, z = bm.ones(1), bm.ones(1), bm.ones(1)
    # Current step size; overwritten in place with the value returned by
    # the integrator on every iteration.
    dt = bm.ones(1) * 0.01

    def f(t):
        x.value, y.value, z.value, dt[:] = f_integral(x, y, z, t, dt=dt.value)

    f_scan = bm.make_loop(f, dyn_vars=[x, y, z, dt], out_vars=[x, y, z, dt])
    # The loop is driven by the module-level grid step ``_dt``, not by ``dt``.
    times = bm.arange(0, duration, _dt)
    mon_x, mon_y, mon_z, mon_dt = f_scan(times.value)
    mon_x = np.array(mon_x).flatten()
    mon_y = np.array(mon_y).flatten()
    mon_z = np.array(mon_z).flatten()
    mon_dt = np.array(mon_dt).flatten()

    if show:
        fig = plt.figure()
        # Figure.gca(projection=...) was deprecated in Matplotlib 3.4 and
        # removed in 3.6; add_subplot is the supported way to get 3-D axes.
        ax = fig.add_subplot(projection='3d')
        ax.plot(mon_x, mon_y, mon_z)
        # Label each axis (previously set_xlabel was called three times).
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')
        plt.show()
        plt.plot(mon_dt)
        plt.show()
    return mon_x, mon_y, mon_z, mon_dt
def __init__(self, size, method, d=1., F=96.489, C_rest=0.05, tau=5.,
             C_0=2., T=36., R=8.31441, name=None):
    """Store calcium-dynamics parameters and allocate the state variables.

    Parameters
    ----------
    size : int or sequence of int
        Population size, forwarded to the base class.
    method : str
        Integration method, forwarded positionally to the base class.
    d, F, tau, C_0, T : float
        Model constants stored as attributes of the same name. ``F`` is
        presumably the Faraday constant (scaled) and ``T`` a temperature;
        TODO confirm units.
    C_rest : float
        Initial value of the concentration ``C``.
    R : float
        Gas constant, J*mol-1*K-1.
    name : str, optional
        Instance name, forwarded to the base class.
    """
    super(CaDyn, self).__init__(size, method, name=name)
    self.R = R  # gas constant, J*mol-1*K-1
    self.T, self.d, self.F = T, d, F
    self.tau, self.C_rest, self.C_0 = tau, C_rest, C_0

    def _filled(value):
        # Float array of length self.num, every entry equal to ``value``.
        return bm.ones(self.num, dtype=bm.float_) * value

    # Concentration of the Calcium, starting at the resting level.
    self.C = bm.Variable(_filled(self.C_rest))
    # The dynamical reversal potential, starting at 120.
    self.E = bm.Variable(_filled(120.))
    # Accumulator that receives all Calcium currents.
    self.I_Ca = bm.Variable(bm.zeros(self.num, dtype=bm.float_))
def test_return1(self):
    """``vector_grad(..., return_value=True)`` yields (gradient, value)."""
    def f(x, y):
        return x ** 2 + y ** 2 + 10

    _x = bm.ones(5)
    _y = bm.ones(5)
    grad_fn = bm.vector_grad(f, return_value=True)
    g, value = grad_fn(_x, _y)
    pprint(g)
    pprint(value)

    # Default argnums differentiates only w.r.t. the first argument.
    self.assertTrue(bm.array_equal(g, 2 * _x))
    self.assertTrue(bm.array_equal(value, _x ** 2 + _y ** 2 + 10))
def test_aux1(self):
    """With ``has_aux=True`` the second output is returned, not differentiated."""
    def f(x, y):
        primary = x ** 2 + y ** 2 + 10
        auxiliary = x ** 3 + y ** 3 - 10
        return primary, auxiliary

    _x = bm.ones(5)
    _y = bm.ones(5)
    grad_fn = bm.vector_grad(f, has_aux=True)
    g, aux = grad_fn(_x, _y)
    pprint(g)
    pprint(aux)

    self.assertTrue(bm.array_equal(g, 2 * _x))
    self.assertTrue(bm.array_equal(aux, _x ** 3 + _y ** 3 - 10))
def __init__(self, size, a=1., b=3., c=1., d=5., r=0.01, s=4., V_rest=-1.6,
             V_th=1.0, method='exp_auto', name=None):
    """Store the Hindmarsh-Rose parameters and allocate ``z`` and ``y``.

    Parameters
    ----------
    size : int or sequence of int
        Population size, forwarded to the base class.
    a, b, c, d, r, s : float
        Model coefficients, stored as attributes of the same name.
    V_rest, V_th : float
        Stored as ``self.V_rest`` and ``self.V_th``.
    method : str
        Integration method, forwarded to the base class.
    name : str, optional
        Instance name, forwarded to the base class.
    """
    super(HindmarshRose, self).__init__(size=size, method=method, name=name)

    # parameters
    self.a, self.b, self.c, self.d = a, b, c, d
    self.r, self.s = r, s
    self.V_th, self.V_rest = V_th, V_rest

    # variables: z starts at 0, y starts at -10
    self.z = bm.Variable(bm.zeros(self.num))
    self.y = bm.Variable(bm.ones(self.num) * -10.)
def test_return_aux1(self):
    """With both flags, vector_grad returns (gradient, value, aux)."""
    def f(x, y):
        primary = x ** 2 + y ** 2 + 10
        auxiliary = x ** 3 + y ** 3 - 10
        return primary, auxiliary

    _x = bm.ones(5)
    _y = bm.ones(5)
    grad_fn = bm.vector_grad(f, has_aux=True, return_value=True)
    g, value, aux = grad_fn(_x, _y)
    print('grad', g)
    print('value', value)
    print('aux', aux)

    expected = (
        (g, 2 * _x),
        (value, _x ** 2 + _y ** 2 + 10),
        (aux, _x ** 3 + _y ** 3 - 10),
    )
    for actual, target in expected:
        self.assertTrue(bm.array_equal(actual, target))
def __init__(
    self,
    size,
    tau_neu=10.,
    tau_syn=0.5,
    tau_ref=2.,
    V_reset=-65.,
    V_th=-50.,
    Cm=0.25,
):
    """Store LIF parameters, allocate state and build the joint integrator.

    The membrane potential is initialised to -65 mV plus Gaussian noise with
    a standard deviation of 5 (drawn from ``bm.random.randn``).
    """
    super(LIF, self).__init__(size=size)

    # parameters
    self.tau_neu = tau_neu  # membrane time constant [ms]
    self.tau_syn = tau_syn  # Post-synaptic current time constant [ms]
    self.tau_ref = tau_ref  # absolute refractory period [ms]
    self.Cm = Cm  # membrane capacity [nF]
    self.V_reset = V_reset  # reset potential [mV]
    self.V_th = V_th  # fixed firing threshold [mV]
    self.Iext = 0.  # constant external current [nA]

    # variables
    n_units = self.num
    self.V = bm.Variable(-65. + 5.0 * bm.random.randn(n_units))  # [mV]
    self.I = bm.Variable(bm.zeros(n_units))  # synaptic currents [nA]
    self.spike = bm.Variable(bm.zeros(n_units, dtype=bool))
    self.t_last_spike = bm.Variable(bm.ones(n_units) * -1e7)

    # V and I are integrated together as one joint system.
    joint_eq = bp.JointEq([self.dV, self.dI])
    self.integral = bp.odeint(joint_eq, method='exp_auto')
def __init__(self, size, V_rest=-65., V_reset=-68., V_th=-30., V_T=-59.9,
             delta_T=3.48, a=1., b=1., tau=10., tau_w=30., R=1.,
             method='exp_auto', name=None):
    """Store AdEx parameters, allocate state and build the joint integrator.

    Parameters
    ----------
    size : int or sequence of int
        Population size, forwarded to the base class.
    V_rest, V_reset, V_th, V_T, delta_T, a, b, tau, tau_w, R : float
        Model parameters, stored as attributes of the same name.
    method : str
        Integration method used for the joint (V, w) system; not passed to
        the base class.
    name : str, optional
        Instance name, forwarded to the base class.
    """
    super(AdExIF, self).__init__(size=size, name=name)

    # parameters
    self.V_rest, self.V_reset, self.V_th = V_rest, V_reset, V_th
    self.V_T, self.delta_T = V_T, delta_T
    self.a, self.b = a, b
    self.tau, self.tau_w = tau, tau_w
    self.R = R

    # variables
    n_units = self.num
    self.w = bm.Variable(bm.zeros(n_units))
    self.refractory = bm.Variable(bm.zeros(n_units, dtype=bool))
    self.V = bm.Variable(bm.zeros(n_units))
    self.input = bm.Variable(bm.zeros(n_units))
    self.spike = bm.Variable(bm.zeros(n_units, dtype=bool))
    self.t_last_spike = bm.Variable(bm.ones(n_units) * -1e7)

    # functions: V and w are integrated together
    joint_eq = JointEq([self.dV, self.dw])
    self.integral = odeint(method=method, f=joint_eq)
def test_grad_pure_func_aux1(self):
    """A tuple-returning function without ``has_aux`` must raise TypeError."""
    def call(a, b, c):
        return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1))

    bm.random.seed(1)
    grad_fn = bm.grad(call, argnums=[0, 1, 2])
    with pytest.raises(TypeError):
        grad_fn(bm.ones(10),
                bm.random.randn(10),
                bm.random.uniform(size=10))
def __init__(self, size, freqs, seed=None, name=None):
    """Store firing frequencies and allocate spike state and an RNG.

    Parameters
    ----------
    size : int or sequence of int
        Population size; also normalised into a tuple stored as ``self.size``.
    freqs : float or array-like
        Stored as ``self.freqs`` (presumably firing rates; TODO confirm units).
    seed : int, optional
        Seed for the private ``bm.random.RandomState``.
    name : str, optional
        Instance name, forwarded to the base class.
    """
    super(PoissonInput, self).__init__(size=size, name=name)
    self.freqs = freqs
    # Global step size divided by 1000 (presumably a ms -> s conversion).
    self.dt = bm.get_dt() / 1000.
    self.size = tuple(size) if not isinstance(size, int) else (size,)
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
    self.t_last_spike = bm.Variable(bm.ones(self.num) * -1e7)
    self.rng = bm.random.RandomState(seed=seed)
def test_grad_pure_func_2(self):
    """Default ``bm.grad`` differentiates a sum w.r.t. the first argument."""
    def call(a, b, c):
        return bm.sum(a + b + c)

    bm.random.seed(1)
    a = bm.ones(10)
    b = bm.random.randn(10)
    c = bm.random.uniform(size=10)
    grad_a = bm.grad(call)(a, b, c)
    # d/da sum(a + b + c) is 1 everywhere.
    assert (grad_a == 1.).all()
def __init__(self, size, V_L=-70., V_reset=-55., V_th=-50., Cm=0.5,
             gL=0.025, t_refractory=2., **kwargs):
    """Store LIF parameters, allocate state and build the integrator.

    Extra keyword arguments are forwarded unchanged to the base class. The
    membrane potential starts at ``V_L``.
    """
    super(LIF, self).__init__(size=size, **kwargs)

    # parameters
    self.V_L, self.V_reset, self.V_th = V_L, V_reset, V_th
    self.Cm, self.gL = Cm, gL
    self.t_refractory = t_refractory

    # variables
    n_units = self.num
    self.V = bm.Variable(bm.ones(n_units) * V_L)
    self.input = bm.Variable(bm.zeros(n_units))
    self.spike = bm.Variable(bm.zeros(n_units, dtype=bool))
    self.refractory = bm.Variable(bm.zeros(n_units, dtype=bool))
    self.t_last_spike = bm.Variable(bm.ones(n_units) * -1e7)

    # integrator with bp.odeint's default method
    self.integral = bp.odeint(self.derivative)
def test2(self):
    """Check vector_grad with an int, a 1-tuple and a 2-tuple ``argnums``."""
    def f(x, y):
        return x ** 2 + y ** 2 + 10

    _x = bm.ones(5)
    _y = bm.ones(5)

    # argnums as a bare int -> a single gradient array
    grad_x = bm.vector_grad(f, argnums=0)(_x, _y)
    pprint(grad_x)
    self.assertTrue(bm.array_equal(grad_x, 2 * _x))

    # argnums as a 1-tuple -> an indexable result with one gradient
    grads = bm.vector_grad(f, argnums=(0,))(_x, _y)
    self.assertTrue(bm.array_equal(grads[0], 2 * _x))

    # argnums as a 2-tuple -> one gradient per argument
    grads = bm.vector_grad(f, argnums=(0, 1))(_x, _y)
    pprint(grads)
    self.assertTrue(bm.array_equal(grads[0], 2 * _x))
    self.assertTrue(bm.array_equal(grads[1], 2 * _y))
def test_grad_pure_func_aux2(self):
    """With ``has_aux=True`` the auxiliary tuple is returned alongside grads."""
    def call(a, b, c):
        return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1))

    bm.random.seed(1)
    grad_fn = bm.grad(call, argnums=[0, 1, 2], has_aux=True)
    grads, aux = grad_fn(bm.ones(10),
                         bm.random.randn(10),
                         bm.random.uniform(size=10))
    # Every partial derivative of a plain sum is 1.
    assert all((g == 1.).all() for g in grads)
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
def test_grad_pure_func_1(self):
    """Gradients w.r.t. all three arguments of a sum are all ones."""
    def call(a, b, c):
        return bm.sum(a + b + c)

    bm.random.seed(1)
    a = bm.ones(10)
    b = bm.random.randn(10)
    c = bm.random.uniform(size=10)
    grads = bm.grad(call, argnums=[0, 1, 2])(a, b, c)
    assert all((g == 1.).all() for g in grads)
def test3(self):
    """vector_grad of a tuple-returning function, for several ``argnums``.

    The assertions expect the gradients of both outputs to be summed:
    d(x**2 + ...)/dx + d(x**3 + ...)/dx = 2x + 3x**2.
    """
    def f(x, y):
        first = x ** 2 + y ** 2 + 10
        second = x ** 3 + y ** 3 - 10
        return first, second

    _x = bm.ones(5)
    _y = bm.ones(5)
    want_x = 2 * _x + 3 * _x ** 2
    want_y = 2 * _y + 3 * _y ** 2

    grad_x = bm.vector_grad(f, argnums=0)(_x, _y)
    self.assertTrue(bm.array_equal(grad_x, want_x))

    grads = bm.vector_grad(f, argnums=(0,))(_x, _y)
    self.assertTrue(bm.array_equal(grads[0], want_x))

    grads = bm.vector_grad(f, argnums=(0, 1))(_x, _y)
    self.assertTrue(bm.array_equal(grads[0], want_x))
    self.assertTrue(bm.array_equal(grads[1], want_y))
def test_grad_pure_func_return1(self):
    """``return_value=True`` yields (gradient, function value)."""
    def call(a, b, c):
        return bm.sum(a + b + c)

    bm.random.seed(1)
    a = bm.ones(10)
    b = bm.random.randn(10)
    c = bm.random.uniform(size=10)
    grad_fn = bm.grad(call, return_value=True)
    grads, returns = grad_fn(a, b, c)
    assert (grads == 1.).all()
    assert returns == bm.sum(a + b + c)
def __init__(self, size, method='exp_euler_auto', name=None):
    """Allocate neuron state and build the integrator for ``self.derivative``.

    Parameters
    ----------
    size : int or sequence of int
        Population size, forwarded to the base class.
    method : str
        Integration method passed to ``bp.odeint``.
    name : str, optional
        Instance name, forwarded to the base class.
    """
    super(Neuron, self).__init__(size=size, name=name)

    # variables
    n_units = self.num
    self.V = bm.Variable(bm.zeros(n_units))
    self.input = bm.Variable(bm.zeros(n_units))
    self.spike = bm.Variable(bm.zeros(n_units, dtype=bool))
    self.t_last_spike = bm.Variable(bm.ones(n_units) * -1e7)

    # integral
    self.integral = bp.odeint(method=method, f=self.derivative)
def __init__(self, num, method='exp_auto'):
    """Build a two-population (excitatory ``e`` / inhibitory ``i``) rate model.

    Parameters
    ----------
    num : int
        Length of the ``e``, ``i`` and ``Iext`` state arrays.
    method : str
        Integration method passed to ``bp.odeint`` for both equations.
    """
    super(WilsonCowanModel, self).__init__()

    # Connection weights. In the equations below wEI scales i in the e
    # equation and wIE scales e in the i equation.
    self.wEE, self.wEI, self.wIE, self.wII = 12, 4, 13, 11
    # Refractory parameter
    self.r = 1
    # Excitatory parameters: timescale, gain, threshold
    self.E_tau, self.E_a, self.E_theta = 1, 1.2, 2.8
    # Inhibitory parameters: timescale, gain, threshold
    self.I_tau, self.I_a, self.I_theta = 1, 1, 4

    # variables
    self.i = bm.Variable(bm.ones(num))
    self.e = bm.Variable(bm.ones(num))
    self.Iext = bm.Variable(bm.zeros(num))

    # Sigmoid with gain ``a`` and threshold ``theta``, shifted so F(0) == 0.
    def F(x, a, theta):
        return 1 / (1 + bm.exp(-a * (x - theta))) - 1 / (1 + bm.exp(a * theta))

    # NOTE: argument names of de/di are kept as-is; odeint reads them.
    def de(e, t, i, Iext=0.):
        x = self.wEE * e - self.wEI * i + Iext
        return (-e + (1 - self.r * e) * F(x, self.E_a, self.E_theta)) / self.E_tau

    def di(i, t, e):
        x = self.wIE * e - self.wII * i
        return (-i + (1 - self.r * i) * F(x, self.I_a, self.I_theta)) / self.I_tau

    self.int_e = bp.odeint(de, method=method)
    self.int_i = bp.odeint(di, method=method)
def test_iter_type_array(self):
    """An 'iter'-type array input drives ``o`` identically for every runner.

    Runs both runner classes with and without JIT and compares the monitored
    ``o`` trace against the same expected (length, 2) array.
    """
    duration = 10.
    dt = 0.1
    length = int(duration / dt)
    expected = bm.repeat(bm.arange(length) + 1, 2).reshape((length, 2))

    for use_jit in (True, False):
        for runner_cls in (bp.ReportRunner, bp.StructRunner):
            ds = ExampleDS()
            runner = runner_cls(ds,
                                inputs=('o', bm.ones(length), 'iter'),
                                monitors=['o'],
                                dyn_vars=ds.vars(),
                                jit=use_jit,
                                dt=dt)
            runner(duration)
            assert bm.array_equal(runner.mon.o, expected)
def __init__(self, size, freq_mean, freq_var, t_interval, **kwargs):
    """Store stimulus statistics and allocate frequency/spike state.

    Parameters
    ----------
    size : int or sequence of int
        Population size, forwarded to the base class.
    freq_mean, freq_var : float
        Stored as ``self.freq_mean`` / ``self.freq_var``.
    t_interval : float
        Stored as ``self.t_interval``.
    **kwargs
        Forwarded unchanged to the base class.
    """
    super(PoissonStim, self).__init__(size=size, **kwargs)
    self.freq_mean, self.freq_var = freq_mean, freq_var
    self.t_interval = t_interval
    # Global step size divided by 1000 (presumably a ms -> s conversion).
    self.dt = bm.get_dt() / 1000.

    # A single shared frequency value and the time it was last changed.
    self.freq = bm.Variable(bm.zeros(1))
    self.freq_t_last_change = bm.Variable(bm.ones(1) * -1e7)
    self.spike = bm.Variable(bm.zeros(self.num, dtype=bool))
    self.rng = bm.random.RandomState()
def test_grad_func_return_aux1(self):
    """``return_value`` plus ``has_aux`` yields (grads, value, aux)."""
    def call(a, b, c):
        return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1))

    bm.random.seed(1)
    a = bm.ones(10)
    b = bm.random.randn(10)
    c = bm.random.uniform(size=10)
    grad_fn = bm.grad(call, return_value=True, has_aux=True)
    grads, returns, aux = grad_fn(a, b, c)

    assert (grads == 1.).all()
    expected_value = bm.sum(a + b + c)
    assert returns == expected_value
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
def __init__(self, size, a=0.02, b=0.20, c=-65., d=8., tau_ref=0., V_th=30.,
             method='exp_auto', name=None):
    """Store Izhikevich parameters, allocate state and build the integrator.

    Parameters
    ----------
    size : int or sequence of int
        Population size, forwarded to the base class.
    a, b, c, d : float
        Model coefficients, stored as attributes of the same name.
    tau_ref : float
        Stored as ``self.tau_ref`` (presumably a refractory period).
    V_th : float
        Stored as ``self.V_th``.
    method : str
        Integration method for the joint (V, u) system; not passed to the
        base class.
    name : str, optional
        Instance name, forwarded to the base class.
    """
    super(Izhikevich, self).__init__(size=size, name=name)

    # params
    self.a, self.b, self.c, self.d = a, b, c, d
    self.V_th, self.tau_ref = V_th, tau_ref

    # variables
    n_units = self.num
    self.u = bm.Variable(bm.ones(n_units))
    self.refractory = bm.Variable(bm.zeros(n_units, dtype=bool))
    self.V = bm.Variable(bm.zeros(n_units))
    self.input = bm.Variable(bm.zeros(n_units))
    self.spike = bm.Variable(bm.zeros(n_units, dtype=bool))
    self.t_last_spike = bm.Variable(bm.ones(n_units) * -1e7)

    # functions: V and u are integrated together
    joint_eq = JointEq([self.dV, self.du])
    self.integral = odeint(method=method, f=joint_eq)
def __init__(self, size, ENa=55., EK=-90., EL=-65, C=1.0, gNa=35., gK=9.,
             gL=0.1, V_th=20., phi=5.0, name=None, method='exponential_euler'):
    """Store parameters, allocate state and build per-variable integrators.

    State arrays are shaped directly by ``size`` (not ``self.num``). The
    integrators are created with ``show_code=True``.

    Parameters
    ----------
    size : int or sequence of int
        Population size / shape of every state array.
    ENa, EK, EL : float
        Presumably sodium, potassium and leak reversal potentials.
    C : float
        Stored as ``self.C`` (presumably membrane capacitance).
    gNa, gK, gL : float
        Presumably maximal sodium, potassium and leak conductances.
    V_th, phi : float
        Stored as ``self.V_th`` and ``self.phi``.
    name : str, optional
        Instance name, forwarded to the base class.
    method : str
        Integration method passed to each ``bp.odeint`` call.
    """
    super(HH, self).__init__(size=size, name=name)

    # parameters
    self.ENa, self.EK, self.EL = ENa, EK, EL
    self.C = C
    self.gNa, self.gK, self.gL = gNa, gK, gL
    self.V_th, self.phi = V_th, phi

    # variables
    self.V = bm.Variable(bm.ones(size) * -65.)
    self.h = bm.Variable(bm.ones(size) * 0.6)
    self.n = bm.Variable(bm.ones(size) * 0.32)
    self.spike = bm.Variable(bm.zeros(size, dtype=bool))
    self.input = bm.Variable(bm.zeros(size))

    # integrators (show_code=True is kept as in the original)
    self.int_h = bp.odeint(self.dh, method=method, show_code=True)
    self.int_n = bp.odeint(self.dn, method=method, show_code=True)
    self.int_V = bp.odeint(self.dV, method=method, show_code=True)
def __init__(self, method='exp_auto'):
    """Build a mean-field model with firing-rate ``r`` and mean voltage ``v``.

    Parameters
    ----------
    method : str
        Integration method passed to ``bp.odeint`` for both equations.
    """
    super(MeanFieldQIF, self).__init__()

    # parameters
    self.tau = 1.  # the population time constant
    # the mean of a Lorentzian distribution over the neural excitability in the population
    self.eta = -5.0
    # the half-width at half maximum of the Lorentzian distribution over the neural excitability
    self.delta = 1.0
    self.J = 15.  # the strength of the recurrent coupling inside the population

    # variables
    self.r = bm.Variable(bm.ones(1))
    self.v = bm.Variable(bm.ones(1))
    self.Iext = bm.Variable(bm.zeros(1))

    # NOTE(review): the literal defaults delta=1.0 / eta=-5.0 below duplicate
    # self.delta / self.eta; changing those attributes has no effect unless the
    # caller passes the values explicitly. Signatures kept as-is because
    # odeint reads argument names.
    def dr(r, t, v, delta=1.0):
        return (delta / (bm.pi * self.tau) + 2. * r * v) / self.tau

    def dv(v, t, r, Iext=0., eta=-5.0):
        return (v ** 2 + eta + Iext + self.J * r * self.tau - (bm.pi * r * self.tau) ** 2) / self.tau

    self.int_r = bp.odeint(dr, method=method)
    self.int_v = bp.odeint(dv, method=method)
def test_constant_delay_uniform_no_batch1():
    """Fill a uniform ConstantDelay on each backend and print what it pulls."""
    print()
    for backend in ('jax', 'numpy'):
        bm.use_backend(backend)
        cd = ConstantDelay(size=10, delay=2, dt=0.1)
        # Push one constant array per delay step, valued by the step index.
        for step in range(cd.num_step):
            cd.push(bm.ones(cd.shape) * step)
        # Two update/pull rounds that only print...
        for _ in range(2):
            cd.update(0, 0)
            print(cd.pull())
        # ...then a final round that also reports the pulled value's type.
        cd.update(0, 0)
        a = cd.pull()
        print(a)
        print(type(a))
def test_constant_delay_nonuniform_batch1():
    """Fill a batched, per-neuron-delay ConstantDelay and print its output."""
    print()
    rng = np.random.RandomState(1234)
    # Ten random delays in [0.2, 3.2).
    delays = rng.random(10) * 3 + 0.2
    for backend in ('jax', 'numpy'):
        bm.use_backend(backend)
        cd = ConstantDelay(size=10, delay=delays, dt=0.1, num_batch=2)
        # Push enough steps to cover the longest delay.
        for step in range(cd.num_step.max()):
            cd.push(bm.ones(cd.shape) * step)
        # Two update/pull rounds that only print...
        for _ in range(2):
            cd.update(0, 0)
            print(cd.pull())
        # ...then a final round that also reports the pulled value's type.
        cd.update(0, 0)
        a = cd.pull()
        print(a)
        print(type(a))