def derivative(self, p, t, V):
    """Return dp/dt for the state variable ``p`` given ``V``.

    The rate is ``alpha * (1 - p) - beta * p`` with both rate terms
    evaluated relative to ``self.V_sh``, multiplied by
    ``self.T_base ** ((self.T - 36) / 10)``. ``t`` is accepted but not used.

    NOTE: ``alpha`` evaluates 0/0 when ``V - self.V_sh - 15`` is exactly zero.
    """
    temperature_factor = self.T_base**((self.T - 36) / 10)
    shifted = V - self.V_sh - 15.
    alpha = 0.032 * shifted / (1. - bm.exp(-shifted / 5.))
    beta = 0.5 * bm.exp(-(V - self.V_sh - 10.) / 40.)
    return temperature_factor * (alpha * (1. - p) - beta * p)
def dp(self, p, t, V):
    """Return dp/dt: relaxation of ``p`` toward a V-dependent steady state.

    The steady state is a sigmoid of ``V`` and the time constant is
    ``3 + 1 / (exp(...) + exp(...))``; the result is scaled by
    ``self.T_base_p ** ((self.T - 24) / 10)``. ``t`` is accepted but not used.
    """
    steady_state = 1. / (1. + bm.exp(-(V + 52. - self.V_sh) / 7.4))
    rising = bm.exp((V + 27. - self.V_sh) / 10.)
    falling = bm.exp(-(V + 102. - self.V_sh) / 15.)
    time_constant = 3. + 1. / (rising + falling)
    temperature_factor = self.T_base_p**((self.T - 24) / 10)
    return temperature_factor * (steady_state - p) / time_constant
def dq(self, q, t, V):
    """Return dq/dt: relaxation of ``q`` toward a V-dependent steady state.

    The steady state is a decreasing sigmoid of ``V`` and the time constant
    is ``85 + 1 / (exp(...) + exp(...))``; the result is scaled by
    ``self.T_base_q ** ((self.T - 24) / 10)``. ``t`` is accepted but not used.
    """
    steady_state = 1. / (1. + bm.exp((V + 80. - self.V_sh) / 5.))
    rising = bm.exp((V + 48. - self.V_sh) / 4.)
    falling = bm.exp(-(V + 407. - self.V_sh) / 50.)
    time_constant = 85. + 1. / (rising + falling)
    temperature_factor = self.T_base_q**((self.T - 24) / 10)
    return temperature_factor * (steady_state - q) / time_constant
def dp(self, p, t, V):
    """Return dp/dt: relaxation of ``p`` toward a V-dependent steady state.

    The steady state is a sigmoid centred on ``self.V_sh - 10`` and the time
    constant is ``0.4 + 0.7 / (exp(-u) + exp(u))`` with
    ``u = (V + 5 - self.V_sh) / 15``; the result is scaled by
    ``self.T_base_p ** ((self.T - 24) / 10)``. ``t`` is accepted but not used.
    """
    steady_state = 1. / (1 + bm.exp(-(V + 10. - self.V_sh) / 4.))
    falling = bm.exp(-(V + 5. - self.V_sh) / 15.)
    rising = bm.exp((V + 5. - self.V_sh) / 15.)
    time_constant = 0.4 + .7 / (falling + rising)
    temperature_factor = self.T_base_p**((self.T - 24) / 10)
    return temperature_factor * (steady_state - p) / time_constant
def test_grad_ob_aux_return(self):
    """Check ``bm.grad`` on a callable object with ``has_aux`` and ``return_value``.

    Two cases are exercised: ``grad_vars`` given as a list of two variables
    and as a single variable. In both, every gradient entry must equal 1 and
    the returned value and auxiliary data must match a direct evaluation.
    """

    class Test(bp.Base):
        def __init__(self):
            super(Test, self).__init__()
            self.a = bm.TrainVar(bm.ones(10))
            self.b = bm.TrainVar(bm.random.randn(10))
            self.c = bm.TrainVar(bm.random.uniform(size=10))

        def __call__(self):
            total = bm.sum(self.a + self.b + self.c)
            return total, (bm.sin(100), bm.exp(0.1))

    def check_value_and_aux(model, value, aux):
        # Shared assertions on the returned value and the auxiliary tuple.
        assert value == bm.sum(model.a + model.b + model.c)
        assert aux[0] == bm.sin(100)
        assert aux[1] == bm.exp(0.1)

    bm.random.seed(0)

    # Case 1: gradients w.r.t. a list of variables.
    model = Test()
    grad_fn = bm.grad(model,
                      grad_vars=[model.a, model.b],
                      dyn_vars=model.vars(),
                      has_aux=True,
                      return_value=True)
    grads, returns, aux = grad_fn()
    for g in grads:
        assert (g == 1.).all()
    check_value_and_aux(model, returns, aux)

    # Case 2: gradient w.r.t. a single variable.
    model = Test()
    grad_fn = bm.grad(model,
                      grad_vars=model.a,
                      dyn_vars=model.vars(),
                      has_aux=True,
                      return_value=True)
    grads, returns, aux = grad_fn()
    assert (grads == 1.).all()
    check_value_and_aux(model, returns, aux)
def dq(self, q, t, V):
    """Return dq/dt: relaxation of ``q`` toward a V-dependent steady state.

    The steady state is a decreasing sigmoid of ``V`` and the time constant
    is ``300 + 100 / (exp(u) + exp(-u))`` with
    ``u = (V + 40 - self.V_sh) / 9.5``; the result is scaled by
    ``self.T_base_q ** ((self.T - 24) / 10)``. ``t`` is accepted but not used.
    """
    steady_state = 1. / (1. + bm.exp((V + 25. - self.V_sh) / 2.))
    offset = V + 40 - self.V_sh
    time_constant = 300. + 100. / (bm.exp(offset / 9.5) + bm.exp(-offset / 9.5))
    temperature_factor = self.T_base_q**((self.T - 24) / 10)
    return temperature_factor * (steady_state - q) / time_constant
def dm(m, t, V):
    """Return dm/dt as ``alpha * (1 - m) - beta * m``.

    Both rates are functions of ``V`` relative to ``VT``, which is read from
    the enclosing scope. ``t`` is accepted but not used.

    NOTE: ``alpha`` is 0/0 when ``13 - V + VT == 0`` and ``beta`` is 0/0 when
    ``V - VT - 40 == 0``.
    """
    rise = 13 - V + VT
    fall = V - VT - 40
    opening_rate = 0.32 * rise / (bm.exp(rise / 4) - 1.)
    closing_rate = 0.28 * fall / (bm.exp(fall / 5) - 1)
    return opening_rate * (1 - m) - closing_rate * m
def fz(self, z, t, V):
    """Return dz/dt for the auxiliary variable ``z``.

    With ``q_inf(u) = 1 / (1 + exp((q_half - u + IT_th) / q_k))`` the result
    is ``(q_inf(V) - q_inf(z)) / tau_q(V) / q_inf'(z)``, where ``tau_q(V)``
    is divided by ``self.phi_q``. ``t`` is accepted but not used.
    """
    # q_inf evaluated at V
    target = 1. / (1. + bm.exp((self.q_half - V + self.IT_th) / self.q_k))

    # q_inf evaluated at z, and its derivative with respect to z
    exp_z = bm.exp((self.q_half - z + self.IT_th) / self.q_k)
    current = 1. / (1. + exp_z)
    slope = exp_z / self.q_k / (1. + exp_z)**2

    # tau_q evaluated at V
    rising = bm.exp((V + 48. - self.IT_th) / 4.)
    falling = bm.exp(-(V + 407. - self.IT_th) / 50.)
    time_constant = (85. + 1 / (rising + falling)) / self.phi_q

    return (target - current) / time_constant / slope
def derivative(self, p, q, t, V):
    """Return ``(dp/dt, dq/dt)`` for the two state variables ``p`` and ``q``.

    Each derivative has the form ``phi * (alpha * (1 - x) - beta * x)`` with
    ``phi = 3 ** ((self.T - 36) / 10)`` and rates evaluated at
    ``V - self.V_sh``. ``t`` is accepted but not used.

    NOTE: ``alpha_p`` is 0/0 at ``V - self.V_sh == 13`` and ``beta_p`` is 0/0
    at ``V - self.V_sh == 40``.
    """
    temperature_factor = 3 ** ((self.T - 36) / 10)

    # p rates
    near = V - self.V_sh - 13.
    far = V - self.V_sh - 40.
    p_open = 0.32 * near / (1. - bm.exp(-near / 4.))
    p_close = -0.28 * far / (1. - bm.exp(far / 5.))
    dpdt = temperature_factor * (p_open * (1. - p) - p_close * p)

    # q rates
    q_open = 0.128 * bm.exp(-(V - self.V_sh - 17.) / 18.)
    q_close = 4. / (1. + bm.exp(-(V - self.V_sh - 40.) / 5.))
    dqdt = temperature_factor * (q_open * (1. - q) - q_close * q)

    return dpdt, dqdt
def dV(self, V, t, h, n, Iext):
    """Return dV/dt from three currents and the external input ``Iext``.

    ``m`` is not a state variable here: it is computed directly from ``V``
    as ``alpha / (alpha + beta)``. The currents use ``m**3 * h``, ``n**4``
    and a constant conductance respectively, and the sum is divided by
    ``self.C``. ``t`` is accepted but not used.

    NOTE: ``alpha`` is 0/0 when ``V == -35``.
    """
    shifted = V + 35
    alpha = -0.1 * shifted / (bm.exp(-0.1 * shifted) - 1)
    beta = 4 * bm.exp(-(V + 60) / 18)
    m = alpha / (alpha + beta)

    sodium = self.gNa * m**3 * h * (V - self.ENa)
    potassium = self.gK * n**4 * (V - self.EK)
    leak = self.gL * (V - self.EL)
    return (-sodium - potassium - leak + Iext) / self.C
def test_grad_ob_argnums_aux_return(self):
    """Check ``bm.grad`` on a callable object with ``argnums``, aux and value.

    Four cases: ``grad_vars`` together with ``argnums`` as an int or a list,
    and ``dyn_vars`` only with ``argnums`` as an int or a list. Variable
    gradients must equal 1, argument gradients must equal 2, and the loss and
    auxiliary data must match a direct call.
    """

    class Test(bp.Base):
        def __init__(self):
            super(Test, self).__init__()
            self.a = bm.TrainVar(bm.ones(10))
            self.b = bm.TrainVar(bm.random.randn(10))
            self.c = bm.TrainVar(bm.random.uniform(size=10))

        def __call__(self, d):
            total = bm.sum(self.a + self.b + self.c + 2 * d)
            return total, (bm.sin(100), bm.exp(0.1))

    def check_aux_and_loss(model, d, loss, aux):
        # Shared assertions on the auxiliary tuple and the returned loss.
        assert aux[0] == bm.sin(100)
        assert aux[1] == bm.exp(0.1)
        assert loss == model(d)[0]

    bm.random.seed(0)

    # Case 1: grad_vars + integer argnums.
    model = Test()
    grad_fn = bm.grad(model, grad_vars=model.vars(), argnums=0,
                      has_aux=True, return_value=True)
    d = bm.random.random(10)
    (var_grads, arg_grads), loss, aux = grad_fn(d)
    for g in var_grads.values():
        assert (g == 1.).all()
    assert (arg_grads == 2.).all()
    check_aux_and_loss(model, d, loss, aux)

    # Case 2: grad_vars + list argnums.
    model = Test()
    grad_fn = bm.grad(model, grad_vars=model.vars(), argnums=[0],
                      has_aux=True, return_value=True)
    d = bm.random.random(10)
    (var_grads, arg_grads), loss, aux = grad_fn(d)
    for g in var_grads.values():
        assert (g == 1.).all()
    assert (arg_grads[0] == 2.).all()
    check_aux_and_loss(model, d, loss, aux)

    # Case 3: dyn_vars only + integer argnums.
    model = Test()
    grad_fn = bm.grad(model, dyn_vars=model.vars(), argnums=0,
                      has_aux=True, return_value=True)
    d = bm.random.random(10)
    arg_grads, loss, aux = grad_fn(d)
    assert (arg_grads == 2.).all()
    check_aux_and_loss(model, d, loss, aux)

    # Case 4: dyn_vars only + list argnums.
    model = Test()
    grad_fn = bm.grad(model, dyn_vars=model.vars(), argnums=[0],
                      has_aux=True, return_value=True)
    d = bm.random.random(10)
    arg_grads, loss, aux = grad_fn(d)
    assert (arg_grads[0] == 2.).all()
    check_aux_and_loss(model, d, loss, aux)
def integral(*args, **kwargs):
    """Advance ``args[0]`` by one step of size ``dt``.

    ``value_and_grad`` (from the enclosing scope) is called with the remaining
    arguments and yields a pair ``(linear, derivative)``. The update is
    ``args[0] + dt * phi * derivative`` where
    ``phi = (exp(dt * linear) - 1) / (dt * linear)``, replaced by 1 wherever
    ``linear == 0``.

    ``dt`` may be passed as a keyword; otherwise ``math.get_dt()`` is used.
    ``math`` here must provide ``get_dt``, ``where``, ``ones_like`` and
    ``exp``, so it is not the standard-library module.
    """
    assert len(args) > 0
    step = kwargs.pop('dt', math.get_dt())
    slope, rhs = value_and_grad(*args, **kwargs)
    scaled_slope = step * slope
    phi = math.where(slope == 0.,
                     math.ones_like(slope),
                     (math.exp(scaled_slope) - 1) / scaled_slope)
    return args[0] + step * phi * rhs
def get_stimulus_by_pos(self, pos):
    """Return a Gaussian-shaped bump over the 2-D grid built from ``self.x``.

    ``pos`` must contain exactly two values. For every grid point the
    per-axis absolute offset to ``pos`` is passed through ``self.dist``, its
    Euclidean norm ``d`` is taken, and the result
    ``self.A * exp(-0.25 * (d / self.a) ** 2)`` is returned with shape
    ``(self.length, self.length)``.
    """
    assert bm.size(pos) == 2
    grid_a, grid_b = bm.meshgrid(self.x, self.x)
    points = bm.stack([grid_a.flatten(), grid_b.flatten()]).T
    offsets = self.dist(bm.abs(bm.asarray(pos) - points))
    radius = bm.linalg.norm(offsets, axis=1)
    radius = radius.reshape((self.length, self.length))
    return self.A * bm.exp(-0.25 * bm.square(radius / self.a))
def update(self, _t, _dt):
    """Advance the synapse state by one step and add its current to the target.

    Steps, in order: push the current presynaptic spikes into the delay and
    pull the delayed ones; integrate ``(s, x)`` with ``self.integral``; add
    the delayed spikes (as a column) to ``x``; compute the V-dependent factor
    ``g_inf``; add the resulting current to ``self.post.input``.

    ``_dt`` is accepted but not forwarded to ``self.integral``.
    """
    self.pre_spike.push(self.pre.spike)
    delayed_spikes = self.pre_spike.pull()

    new_s, new_x = self.integral(self.s, self.x, _t)
    self.s.value = new_s
    self.x.value = new_x
    self.x += delayed_spikes.reshape((-1, 1))

    # V-dependent factor applied to the summed synaptic variable.
    voltage_gate = 1 / (1 + self.cc_Mg * bm.exp(-0.062 * self.post.V) / 3.57)
    current = bm.dot(self.pre_one, self.s) * (self.post.V - self.E) * voltage_gate
    self.post.input += current * self.g_max
def make_conn(self, x):
    """Return the pairwise connection matrix for the 1-D coordinates ``x``.

    The pairwise differences of ``x`` are passed through ``self.dist`` to
    give ``d``; the result is
    ``self.J0 * exp(-0.5 * (d / self.a) ** 2) / (sqrt(2 * pi) * self.a)``.
    """
    assert bm.ndim(x) == 1
    column = bm.reshape(x, (-1, 1))
    rows = bm.repeat(x.reshape((1, -1)), len(x), axis=0)
    distance = self.dist(column - rows)
    gaussian = bm.exp(-0.5 * bm.square(distance / self.a))
    return self.J0 * gaussian / (bm.sqrt(2 * bm.pi) * self.a)
def update(self, _t, _dt):
    """Advance the synapse state by one step and subtract its current from the target.

    Steps, in order: push the current presynaptic spikes into the delay and
    pull the delayed ones; integrate ``(g, x)`` with ``self.integral`` using
    step ``_dt``; map the delayed spikes onto synapses and add them to ``x``;
    sum ``g`` per postsynaptic index; divide by the V-dependent factor
    ``g_inf`` and subtract the current from ``self.post.input``.
    """
    self.pre_spike.push(self.pre.spike)
    arrived = self.pre_spike.pull()

    new_g, new_x = self.integral(self.g, self.x, _t, dt=_dt)
    self.g.value = new_g
    self.x.value = new_x
    self.x += bm.pre2syn(arrived, self.pre_ids)

    summed_g = bm.syn2post(self.g, self.post_ids, self.post.num)
    # V-dependent divisor applied to the summed conductance.
    voltage_divisor = 1 + self.cc_Mg / self.beta * bm.exp(-self.alpha * self.post.V)
    current = self.g_max * summed_g * (self.post.V - self.E) / voltage_divisor
    self.post.input -= current
def make_conn(self):
    """Return the connection profile of the first grid point on a 2-D grid.

    The grid is the Cartesian product of ``self.x`` with itself. Per-axis
    absolute offsets from the first point are passed through ``self.dist``,
    reduced to a Euclidean distance ``d`` and reshaped to
    ``(self.length, self.length)``; the result is
    ``self.J0 * exp(-0.5 * (d / self.a) ** 2) / (sqrt(2 * pi) * self.a)``.
    """
    grid_a, grid_b = bm.meshgrid(self.x, self.x)
    points = bm.stack([grid_a.flatten(), grid_b.flatten()]).T
    offsets = self.dist(bm.abs(points[0] - points))
    radius = bm.linalg.norm(offsets, axis=1)
    radius = radius.reshape((self.length, self.length))
    gaussian = bm.exp(-0.5 * bm.square(radius / self.a))
    return self.J0 * gaussian / (bm.sqrt(2 * bm.pi) * self.a)
def test_grad_pure_func_aux2(self):
    """Check ``bm.grad`` on a pure function with three ``argnums`` and aux data.

    Every gradient entry must equal 1 and the auxiliary tuple must be passed
    through unchanged.
    """

    def call(a, b, c):
        return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1))

    bm.random.seed(1)
    grad_fn = bm.grad(call, argnums=[0, 1, 2], has_aux=True)

    # Inputs are created in the same order as the positional arguments.
    ones = bm.ones(10)
    normal = bm.random.randn(10)
    uniform = bm.random.uniform(size=10)
    grads, aux = grad_fn(ones, normal, uniform)

    for g in grads:
        assert (g == 1.).all()
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
def test_grad_func_return_aux1(self):
    """Check ``bm.grad`` on a pure function with ``return_value`` and ``has_aux``.

    With default ``argnums`` the gradient must be all ones, the returned
    value must equal a direct evaluation, and the auxiliary tuple must be
    passed through unchanged.
    """

    def call(a, b, c):
        return bm.sum(a + b + c), (bm.sin(100), bm.exp(0.1))

    bm.random.seed(1)
    a = bm.ones(10)
    b = bm.random.randn(10)
    c = bm.random.uniform(size=10)

    grad_fn = bm.grad(call, return_value=True, has_aux=True)
    grads, value, aux = grad_fn(a, b, c)

    assert (grads == 1.).all()
    assert value == bm.sum(a + b + c)
    assert aux[0] == bm.sin(100)
    assert aux[1] == bm.exp(0.1)
def derivative(self, p, q, t, V):
    """Return ``(dp/dt, dq/dt)``; both relax toward V-dependent steady states.

    Each derivative is ``phi * (x_inf - x) / x_tau`` with
    ``phi = T_base ** ((self.T - 24) / 10)`` taken from ``self.T_base_p`` or
    ``self.T_base_q``. The time constant of ``q`` switches between two
    expressions at ``V == -80 + self.V_sh``. ``t`` is accepted but not used.
    """
    # p: steady state and time constant
    p_target = 1. / (1. + bm.exp(-(V + 59. - self.V_sh) / 6.2))
    p_timescale = 1. / (bm.exp(-(V + 132. - self.V_sh) / 16.7) +
                        bm.exp((V + 16.8 - self.V_sh) / 18.2)) + 0.612
    p_factor = self.T_base_p**((self.T - 24) / 10)
    dpdt = p_factor * (p_target - p) / p_timescale

    # q: steady state and piecewise time constant
    q_target = 1. / (1. + bm.exp((V + 83. - self.V_sh) / 4.))
    above = bm.exp(-(V + 22. - self.V_sh) / 10.5) + 28.
    below = bm.exp((V + 467. - self.V_sh) / 66.6)
    q_timescale = bm.where(V >= (-80. + self.V_sh), above, below)
    q_factor = self.T_base_q**((self.T - 24) / 10)
    dqdt = q_factor * (q_target - q) / q_timescale

    return dpdt, dqdt
def _guarded_rate(x, scale, coef):
    """Return ``coef * x / (1 - exp(-x / scale))`` with the ``x == 0`` limit filled in.

    The raw expression is 0/0 at ``x == 0``; its limit there is
    ``coef * scale``. ``x`` is replaced by 1 before the division so that no
    NaN is produced in the branch that is discarded.
    """
    at_zero = x == 0
    safe_x = bm.where(at_zero, 1., x)
    return bm.where(at_zero,
                    coef * scale,
                    coef * safe_x / (1 - bm.exp(-safe_x / scale)))


def drivative(V, m, h, n, t, Iext, gNa, ENa, gK, EK, gL, EL, C):
    """Return ``(dV/dt, dm/dt, dh/dt, dn/dt)`` for a four-variable model.

    ``m``, ``h`` and ``n`` each follow ``alpha * (1 - x) - beta * x`` with
    V-dependent rates. ``V`` is driven by three currents with conductances
    ``gNa * m**3 * h``, ``gK * n**4`` and ``gL`` plus the external input
    ``Iext``, divided by ``C``. ``t`` is accepted but not used.

    The opening rates of ``m`` and ``n`` have removable 0/0 points at
    ``V == -40`` and ``V == -55``; the analytic limits (1.0 and 0.1) are
    returned there instead of NaN. All other values are unchanged.
    """
    # m
    alpha = _guarded_rate(V + 40, 10, 0.1)
    beta = 4.0 * bm.exp(-(V + 65) / 18)
    dmdt = alpha * (1 - m) - beta * m

    # h
    alpha = 0.07 * bm.exp(-(V + 65) / 20.)
    beta = 1 / (1 + bm.exp(-(V + 35) / 10))
    dhdt = alpha * (1 - h) - beta * h

    # n
    alpha = _guarded_rate(V + 55, 10, 0.01)
    beta = 0.125 * bm.exp(-(V + 65) / 80)
    dndt = alpha * (1 - n) - beta * n

    # V
    I_Na = (gNa * m**3.0 * h) * (V - ENa)
    I_K = (gK * n**4.0) * (V - EK)
    I_leak = gL * (V - EL)
    dVdt = (-I_Na - I_K - I_leak + Iext) / C

    return dVdt, dmdt, dhdt, dndt
def dV(self, V, t, w, Iext):
    """Return dV/dt with a leak, an exponential term, ``w`` and ``Iext``.

    The exponential term is
    ``self.delta_T * exp((V - self.V_T) / self.delta_T)``; ``w`` and ``Iext``
    are both scaled by ``self.R`` and the sum is divided by ``self.tau``.
    ``t`` is accepted but not used.
    """
    exponential_term = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
    adaptation = self.R * w
    drive = self.R * Iext
    return (-V + self.V_rest + exponential_term - adaptation + drive) / self.tau
def fV(self, V, t, y, z, Isyn):
    """Return dV/dt for the reduced model with state ``(V, y, z)``.

    No gating variable is integrated here. ``m_inf`` is evaluated at ``V``,
    ``h_inf``, ``n_inf`` and ``p_inf`` at ``y``, and ``q_inf`` at ``z``.
    These give the conductances ``gNa``, ``gK`` and ``gT``; the summed
    currents plus ``self.V_factor * Isyn`` are divided by ``self.C`` and
    multiplied by the factor ``rho_V``. ``t`` is accepted but not used.

    NOTE: the rate expressions divide by ``exp(...) - 1`` and are 0/0 where
    the exponent is exactly zero (``t1 == 0``, ``t2 == 0``, ``t5 == 0``).
    """
    # m channel: rates at V, their derivatives w.r.t. V, tau_m and m_inf
    t1 = 13. - V + self.NaK_th
    t1_exp = bm.exp(t1 / 4.)
    m_alpha_by_V = 0.32 * t1 / (t1_exp - 1.)  # \alpha_m(V)
    m_alpha_by_V_diff = (-0.32 * (t1_exp - 1.) + 0.08 * t1 * t1_exp) / (
        t1_exp - 1.)**2  # \alpha_m'(V)
    t2 = V - 40. - self.NaK_th
    t2_exp = bm.exp(t2 / 5.)
    m_beta_by_V = 0.28 * t2 / (t2_exp - 1.)  # \beta_m(V)
    m_beta_by_V_diff = (0.28 * (t2_exp - 1) - 0.056 * t2 * t2_exp) / (
        t2_exp - 1)**2  # \beta_m'(V)
    m_tau_by_V = 1. / self.phi_m / (m_alpha_by_V + m_beta_by_V
                                    )  # \tau_m(V)
    m_inf_by_V = m_alpha_by_V / (m_alpha_by_V + m_beta_by_V
                                 )  # m_{\infty}(V)
    # Quotient rule applied to alpha / (alpha + beta).
    m_inf_by_V_diff = (m_alpha_by_V_diff * m_beta_by_V - m_alpha_by_V * m_beta_by_V_diff) / \
                      (m_alpha_by_V + m_beta_by_V) ** 2  # m_{\infty}'(V)

    # h channel: steady state evaluated at y
    h_alpha_by_y = 0.128 * bm.exp(
        (17. - y + self.NaK_th) / 18.)  # \alpha_h(y)
    t3 = bm.exp((40. - y + self.NaK_th) / 5.)
    h_beta_by_y = 4. / (t3 + 1.)  # \beta_h(y)
    h_inf_by_y = h_alpha_by_y / (h_alpha_by_y + h_beta_by_y
                                 )  # h_{\infty}(y)

    # n channel: steady state evaluated at y
    t5 = (15. - y + self.NaK_th)
    t5_exp = bm.exp(t5 / 5.)
    n_alpha_by_y = 0.032 * t5 / (t5_exp - 1.)  # \alpha_n(y)
    t6 = bm.exp((10. - y + self.NaK_th) / 40.)
    n_beta_y = self.b * t6  # \beta_n(y)
    n_inf_by_y = n_alpha_by_y / (n_alpha_by_y + n_beta_y)  # n_{\infty}(y)

    # p channel (evaluated at y) and q channel (evaluated at z)
    t7 = bm.exp((self.p_half - y + self.IT_th) / self.p_k)
    p_inf_by_y = 1. / (1. + t7)  # p_{\infty}(y)
    t8 = bm.exp((self.q_half - z + self.IT_th) / self.q_k)
    q_inf_by_z = 1. / (1. + t8)  # q_{\infty}(z)

    # conductances and the scaling factor rho_V
    gNa = self.g_Na * m_inf_by_V**3 * h_inf_by_y  # gNa
    gK = self.g_K * n_inf_by_y**4  # gK
    gT = self.g_T * p_inf_by_y * p_inf_by_y * q_inf_by_z  # gT
    FV = gNa + gK + gT + self.g_L + self.g_KL  # dF/dV
    Fm = 3 * self.g_Na * h_inf_by_y * (
        V - self.E_Na) * m_inf_by_V * m_inf_by_V * m_inf_by_V_diff  # dF/dvm
    t9 = self.C / m_tau_by_V
    t10 = FV + Fm
    t11 = t9 + FV
    # The argument of sqrt is clamped at 0 so it is never negative.
    rho_V = (t11 - bm.sqrt(bm.maximum(t11**2 - 4 * t9 * t10, 0.))) / 2 / t10  # rho_V

    # currents; IK and IKL both use the reversal potential self.E_KL
    INa = gNa * (V - self.E_Na)
    IK = gK * (V - self.E_KL)
    IT = gT * (V - self.E_T)
    IL = self.g_L * (V - self.E_L)
    IKL = self.g_KL * (V - self.E_KL)
    Iext = self.V_factor * Isyn
    dVdt = rho_V * (-INa - IK - IT - IL - IKL + Iext) / self.C
    return dVdt
def derivative(self, p, t, V):
    """Return dp/dt: relaxation of ``p`` toward a V-dependent steady state.

    The steady state is ``1 / (1 + exp(-(V + 43) / 5.2))`` and the time
    constant is ``2.7 / (exp(-u) + exp(u)) + 1.6`` with ``u = (V + 55) / 15``;
    the result is scaled by ``self.phi``. ``t`` is accepted but not used.
    """
    steady_state = 1.0 / (1 + bm.exp(-(V + 43.) / 5.2))
    falling = bm.exp(-(V + 55.) / 15.)
    rising = bm.exp((V + 55.) / 15.)
    time_constant = 2.7 / (falling + rising) + 1.6
    return self.phi * (steady_state - p) / time_constant
def dh(h, t, V):
    """Return dh/dt as ``alpha * (1 - h) - beta * h``.

    Both rates are functions of ``V`` relative to ``VT``, which is read from
    the enclosing scope. ``t`` is accepted but not used.
    """
    opening_rate = 0.128 * bm.exp((17 - V + VT) / 18)
    closing_rate = 4. / (1 + bm.exp(-(V - VT - 40) / 5))
    return opening_rate * (1 - h) - closing_rate * h
def derivative(self, V, t, Iext):
    """Return dV/dt with a leak toward ``self.V_rest`` and an exponential term.

    The exponential term is
    ``self.delta_T * exp((V - self.V_T) / self.delta_T)``; ``Iext`` is scaled
    by ``self.R`` and the sum is divided by ``self.tau``. ``t`` is accepted
    but not used.
    """
    exponential_term = self.delta_T * bm.exp((V - self.V_T) / self.delta_T)
    leak = -(V - self.V_rest)
    drive = self.R * Iext
    return (leak + exponential_term + drive) / self.tau
def sigmoid(self, x):
    """Return ``self.v_max / (1 + exp(self.r * (self.v0 - x)))``."""
    exponent = self.r * (self.v0 - x)
    denominator = 1. + bm.exp(exponent)
    return self.v_max / denominator
def dn(n, t, V):
    """Return dn/dt as ``alpha * (1 - n) - beta * n``.

    Both rates are functions of ``V`` relative to ``VT``, which is read from
    the enclosing scope. ``t`` is accepted but not used.

    NOTE: ``alpha`` is 0/0 when ``15 - V + VT == 0``.
    """
    shifted = 15 - V + VT
    opening_rate = 0.032 * shifted / (bm.exp(shifted / 5) - 1.)
    closing_rate = .5 * bm.exp((10 - V + VT) / 40)
    return opening_rate * (1 - n) - closing_rate * n
def get_stimulus_by_pos(self, pos):
    """Return a Gaussian-shaped bump over ``self.x`` centred on ``pos``.

    The offset ``self.x - pos`` is passed through ``self.dist`` to give
    ``d``; the result is ``self.A * exp(-0.25 * (d / self.a) ** 2)``.
    """
    distance = self.dist(self.x - pos)
    scaled = distance / self.a
    return self.A * bm.exp(-0.25 * bm.square(scaled))
def make_conn(self):
    """Return the pairwise connection matrix for the coordinates ``self.x``.

    The pairwise differences of ``self.x`` are passed through ``self.dist``
    to give ``d``; the result is
    ``self.J0 * exp(-0.5 * (d / self.a) ** 2) / (sqrt(2 * pi) * self.a)``.
    """
    column = bm.reshape(self.x, (-1, 1))
    rows = bm.repeat(self.x.reshape((1, -1)), len(self.x), axis=0)
    distance = self.dist(column - rows)
    gaussian = bm.exp(-0.5 * bm.square(distance / self.a))
    return self.J0 * gaussian / (bm.sqrt(2 * bm.pi) * self.a)