コード例 #1
0
ファイル: random_inits.py プロジェクト: PKU-NIP-Lab/BrainPy
 def __call__(self, shape, dtype=None):
     """Draw a variance-scaled random array of the requested ``shape``.

     The target variance is ``self.scale / n``. Here ``n`` is the fan-in,
     the fan-out, or their mean, selected by ``self.mode``. Samples come
     from ``self.rng`` according to ``self.distribution`` and are rescaled
     so that their variance equals the target.

     Parameters
     ----------
     shape : sequence
         Output shape; every entry is passed through ``size2len``.
     dtype : optional
         dtype used for the intermediate variance and the returned array.

     Returns
     -------
     The sampled array, converted with ``math.asarray(..., dtype=dtype)``.

     Raises
     ------
     ValueError
         If ``self.mode`` or ``self.distribution`` is not recognised.
     """
     dims = [size2len(dim) for dim in shape]
     fan_in, fan_out = _compute_fans(dims,
                                     in_axis=self.in_axis,
                                     out_axis=self.out_axis)

     mode = self.mode
     if mode == "fan_in":
         fan = fan_in
     elif mode == "fan_out":
         fan = fan_out
     elif mode == "fan_avg":
         fan = (fan_in + fan_out) / 2
     else:
         raise ValueError(
             "invalid mode for variance scaling initializer: {}".format(mode))
     variance = math.array(self.scale / fan, dtype=dtype)

     kind = self.distribution
     if kind == "truncated_normal":
         # .8796... is the stddev of a standard normal truncated to (-2, 2);
         # dividing by it restores the requested stddev after truncation.
         std = math.sqrt(variance) / math.array(.87962566103423978, dtype)
         samples = self.rng.truncated_normal(-2, 2, dims) * std
     elif kind == "normal":
         samples = self.rng.normal(size=dims) * math.sqrt(variance)
     elif kind == "uniform":
         # U(-1, 1) has variance 1/3, hence the factor of 3 under the sqrt.
         samples = self.rng.uniform(low=-1, high=1, size=dims) * math.sqrt(
             3 * variance)
     else:
         raise ValueError(
             "invalid distribution for variance scaling initializer")
     return math.asarray(samples, dtype=dtype)
コード例 #2
0
 def make_conn(self, x):
     """Build the dense Gaussian connection matrix over 1-D positions ``x``.

     With ``d = self.dist(x[i] - x[j])`` (assuming ``self.dist`` acts
     elementwise -- confirm), entry (i, j) is
     ``J0 * exp(-0.5 * (d / a)**2) / (sqrt(2 * pi) * a)``.

     Returns a ``(len(x), len(x))`` matrix.
     """
     assert bm.ndim(x) == 1
     n = len(x)
     # column[i, 0] = x[i]; rows[i, j] = x[j]  ->  pairwise differences
     column = bm.reshape(x, (-1, 1))
     rows = bm.repeat(x.reshape((1, -1)), n, axis=0)
     distance = self.dist(column - rows)
     kernel = bm.exp(-0.5 * bm.square(distance / self.a))
     return self.J0 * kernel / (bm.sqrt(2 * bm.pi) * self.a)
コード例 #3
0
 def make_conn(self):
     """Gaussian weights over the 2-D grid ``self.x`` x ``self.x``.

     Distances are measured from the FIRST grid point (``coords[0]``) to
     every grid point. Each difference vector is passed through
     ``self.dist``, reduced to its Euclidean norm, and reshaped to
     ``(self.length, self.length)``. NOTE(review): presumably this is a
     single kernel used for convolution rather than a full N^2 x N^2
     matrix -- confirm with the caller.

     Returns ``J0 * exp(-0.5 * (d / a)**2) / (sqrt(2 * pi) * a)``.
     """
     grid_a, grid_b = bm.meshgrid(self.x, self.x)
     coords = bm.stack([grid_a.flatten(), grid_b.flatten()]).T
     offsets = self.dist(bm.abs(coords[0] - coords))
     dists = bm.linalg.norm(offsets, axis=1).reshape((self.length, self.length))
     kernel = bm.exp(-0.5 * bm.square(dists / self.a))
     return self.J0 * kernel / (bm.sqrt(2 * bm.pi) * self.a)
コード例 #4
0
    def __init__(self,
                 num_input,
                 num_hidden,
                 num_output,
                 tau=1.0,
                 dt=0.1,
                 g=1.8,
                 alpha=1.0,
                 **kwargs):
        """Create an echo-state (reservoir) network.

        Args:
          num_input: number of input units.
          num_hidden: number of reservoir (hidden) units.
          num_output: number of readout units.
          tau: stored as ``self.tau`` (presumably a time constant; not used here).
          dt: stored as ``self.dt`` (presumably the step size; not used here).
          g: gain multiplying the recurrent weights ``w_rr``.
          alpha: scale of the initial ``P = alpha * I``.
          **kwargs: forwarded to the base-class constructor.
        """
        super(EchoStateNet, self).__init__(**kwargs)

        # --- hyper-parameters ---
        self.num_input, self.num_hidden, self.num_output = (
            num_input, num_hidden, num_output)
        self.tau, self.dt = tau, dt
        self.g, self.alpha = g, alpha

        # --- weights (draw order kept stable for reproducible seeding) ---
        self.w_ir = bm.random.normal(
            size=(num_input, num_hidden)) / bm.sqrt(num_input)
        self.w_rr = g * bm.random.normal(
            size=(num_hidden, num_hidden)) / bm.sqrt(num_hidden)
        self.w_or = bm.random.normal(size=(num_output, num_hidden))
        readout = bm.random.normal(
            size=(num_hidden, num_output)) / bm.sqrt(num_hidden)
        self.w_ro = bm.Variable(readout)  # trainable readout weights

        # --- state ---
        initial_h = bm.random.normal(size=num_hidden) * 0.5
        self.h = bm.Variable(initial_h)  # hidden state
        self.r = bm.Variable(bm.tanh(self.h))  # firing rate = tanh(h)
        self.o = bm.Variable(bm.dot(self.r, readout))  # output unit
        # inverse correlation matrix, initialised to alpha * identity
        self.P = bm.Variable(bm.eye(num_hidden) * self.alpha)
コード例 #5
0
 def make_conn(self):
     """Dense Gaussian connection matrix over the 1-D positions ``self.x``.

     With ``d = self.dist(x[i] - x[j])`` (assuming ``self.dist`` acts
     elementwise -- confirm), entry (i, j) is
     ``J0 * exp(-0.5 * (d / a)**2) / (sqrt(2 * pi) * a)``.
     """
     positions = self.x
     column = bm.reshape(positions, (-1, 1))
     rows = bm.repeat(positions.reshape((1, -1)), len(positions), axis=0)
     distance = self.dist(column - rows)
     gaussian = bm.exp(-0.5 * bm.square(distance / self.a))
     return self.J0 * gaussian / (bm.sqrt(2 * bm.pi) * self.a)
コード例 #6
0
    def fV(self, V, t, y, z, Isyn):
        """Compute dV/dt for the reduced TRN model.

        The m gate is taken at its steady state m_inf(V). The h, n and p
        gates use their steady states evaluated at the auxiliary variable
        ``y``, and the q gate uses its steady state at ``z``. The sum of
        ionic and external currents is scaled by ``rho_V`` and divided by
        the capacitance ``self.C``.

        Args:
          V: membrane potential.
          t: time; not used in this body (NOTE(review): presumably kept for
            the integrator's call signature -- confirm).
          y: auxiliary variable shared by the h, n and p gates.
          z: auxiliary variable for the q gate.
          Isyn: input current, multiplied by ``self.V_factor``.

        Returns:
          dVdt.
        """
        # m channel (rates and their V-derivatives)
        t1 = 13. - V + self.NaK_th
        t1_exp = bm.exp(t1 / 4.)
        m_alpha_by_V = 0.32 * t1 / (t1_exp - 1.)  # \alpha_m(V)
        m_alpha_by_V_diff = (-0.32 * (t1_exp - 1.) + 0.08 * t1 * t1_exp) / (
            t1_exp - 1.)**2  # \alpha_m'(V)
        t2 = V - 40. - self.NaK_th
        t2_exp = bm.exp(t2 / 5.)
        m_beta_by_V = 0.28 * t2 / (t2_exp - 1.)  # \beta_m(V)
        m_beta_by_V_diff = (0.28 * (t2_exp - 1) - 0.056 * t2 * t2_exp) / (
            t2_exp - 1)**2  # \beta_m'(V)
        m_tau_by_V = 1. / self.phi_m / (m_alpha_by_V + m_beta_by_V
                                        )  # \tau_m(V)
        m_inf_by_V = m_alpha_by_V / (m_alpha_by_V + m_beta_by_V
                                     )  # m_{\infty}(V)
        m_inf_by_V_diff = (m_alpha_by_V_diff * m_beta_by_V - m_alpha_by_V * m_beta_by_V_diff) / \
                          (m_alpha_by_V + m_beta_by_V) ** 2  # m_{\infty}'(V), quotient rule

        # h channel, evaluated at y
        h_alpha_by_y = 0.128 * bm.exp(
            (17. - y + self.NaK_th) / 18.)  # \alpha_h(y)
        t3 = bm.exp((40. - y + self.NaK_th) / 5.)
        h_beta_by_y = 4. / (t3 + 1.)  # \beta_h(y)
        h_inf_by_y = h_alpha_by_y / (h_alpha_by_y + h_beta_by_y
                                     )  # h_{\infty}(y)

        # n channel, evaluated at y
        t5 = (15. - y + self.NaK_th)
        t5_exp = bm.exp(t5 / 5.)
        n_alpha_by_y = 0.032 * t5 / (t5_exp - 1.)  # \alpha_n(y)
        t6 = bm.exp((10. - y + self.NaK_th) / 40.)
        n_beta_y = self.b * t6  # \beta_n(y)
        n_inf_by_y = n_alpha_by_y / (n_alpha_by_y + n_beta_y)  # n_{\infty}(y)

        # p channel at y, q channel at z
        t7 = bm.exp((self.p_half - y + self.IT_th) / self.p_k)
        p_inf_by_y = 1. / (1. + t7)  # p_{\infty}(y)
        t8 = bm.exp((self.q_half - z + self.IT_th) / self.q_k)
        q_inf_by_z = 1. / (1. + t8)  # q_{\infty}(z)

        # x: membrane-potential equation (conductances, rho_V, currents)
        gNa = self.g_Na * m_inf_by_V**3 * h_inf_by_y  # gNa = g_Na * m^3 * h
        gK = self.g_K * n_inf_by_y**4  # gK = g_K * n^4
        gT = self.g_T * p_inf_by_y * p_inf_by_y * q_inf_by_z  # gT = g_T * p^2 * q
        FV = gNa + gK + gT + self.g_L + self.g_KL  # dF/dV: total conductance
        Fm = 3 * self.g_Na * h_inf_by_y * (
            V -
            self.E_Na) * m_inf_by_V * m_inf_by_V * m_inf_by_V_diff  # dF/dvm: d(INa)/dm * m_inf'(V)
        t9 = self.C / m_tau_by_V
        t10 = FV + Fm
        t11 = t9 + FV
        # rho_V is the minus-branch root of t10*rho**2 - t11*rho + t9 = 0;
        # the discriminant is clamped at 0 so sqrt never gets a negative value.
        rho_V = (t11 - bm.sqrt(bm.maximum(t11**2 - 4 * t9 * t10,
                                          0.))) / 2 / t10  # rho_V
        INa = gNa * (V - self.E_Na)
        # NOTE(review): IK uses E_KL (the same reversal as IKL); no separate
        # E_K appears here -- confirm this is intended.
        IK = gK * (V - self.E_KL)
        IT = gT * (V - self.E_T)
        IL = self.g_L * (V - self.E_L)
        IKL = self.g_KL * (V - self.E_KL)
        Iext = self.V_factor * Isyn
        dVdt = rho_V * (-INa - IK - IT - IL - IKL + Iext) / self.C

        return dVdt
コード例 #7
0
        def reduced_trn_derivative(V, y, z, t, Isyn, b, rho_p, g_T, g_L, g_KL,
                                   E_L, E_KL, IT_th, NaK_th):
            """Right-hand side of the reduced TRN model: (dV/dt, dy/dt, dz/dt).

            Besides its parameters, the body reads phi_m, phi_h, phi_n, phi_p,
            phi_q, p_half, p_k, q_half, q_k, g_Na, g_K, E_Na, E_T, C and
            V_factor. These are not arguments; presumably they are captured
            from the enclosing scope (TODO confirm).

            Args:
              V: membrane potential.
              y: auxiliary variable shared by the h, n and p gates.
              z: auxiliary variable for the q gate.
              t: time; not used in this body.
              Isyn: input current, multiplied by V_factor.
              b: scale of the n-gate beta rate.
              rho_p: weight of the p-gate term in dy/dt. The h and n terms
                share the remaining (1 - rho_p) in proportion to Fvh and Fvn.
              g_T, g_L, g_KL: T-type, leak and K-leak conductances.
              E_L, E_KL: leak and K-leak reversal potentials.
              IT_th, NaK_th: voltage shifts for the T-type and Na/K gates.

            Returns:
              Tuple (dVdt, dydt, dzdt).
            """
            # m channel (rates and their V-derivatives)
            t1 = 13. - V + NaK_th
            t1_exp = bm.exp(t1 / 4.)
            m_alpha_by_V = 0.32 * t1 / (t1_exp - 1.)  # \alpha_m(V)
            m_alpha_by_V_diff = (-0.32 *
                                 (t1_exp - 1.) + 0.08 * t1 * t1_exp) / (
                                     t1_exp - 1.)**2  # \alpha_m'(V)
            t2 = V - 40. - NaK_th
            t2_exp = bm.exp(t2 / 5.)
            m_beta_by_V = 0.28 * t2 / (t2_exp - 1.)  # \beta_m(V)
            m_beta_by_V_diff = (0.28 * (t2_exp - 1) - 0.056 * t2 * t2_exp) / (
                t2_exp - 1)**2  # \beta_m'(V)
            m_tau_by_V = 1. / phi_m / (m_alpha_by_V + m_beta_by_V)  # \tau_m(V)
            m_inf_by_V = m_alpha_by_V / (m_alpha_by_V + m_beta_by_V
                                         )  # m_{\infty}(V)
            m_inf_by_V_diff = (m_alpha_by_V_diff * m_beta_by_V - m_alpha_by_V * m_beta_by_V_diff) / \
                              (m_alpha_by_V + m_beta_by_V) ** 2  # m_{\infty}'(V), quotient rule

            # h channel: steady state / time constant at V, steady state and slope at y
            h_alpha_by_V = 0.128 * bm.exp(
                (17. - V + NaK_th) / 18.)  # \alpha_h(V)
            h_beta_by_V = 4. / (bm.exp(
                (40. - V + NaK_th) / 5.) + 1.)  # \beta_h(V)
            h_inf_by_V = h_alpha_by_V / (h_alpha_by_V + h_beta_by_V
                                         )  # h_{\infty}(V)
            h_tau_by_V = 1. / phi_h / (h_alpha_by_V + h_beta_by_V)  # \tau_h(V)
            h_alpha_by_y = 0.128 * bm.exp(
                (17. - y + NaK_th) / 18.)  # \alpha_h(y)
            t3 = bm.exp((40. - y + NaK_th) / 5.)
            h_beta_by_y = 4. / (t3 + 1.)  # \beta_h(y)
            h_beta_by_y_diff = 0.8 * t3 / (1 + t3)**2  # \beta_h'(y)
            h_inf_by_y = h_alpha_by_y / (h_alpha_by_y + h_beta_by_y
                                         )  # h_{\infty}(y)
            h_alpha_by_y_diff = -h_alpha_by_y / 18.  # \alpha_h'(y)
            h_inf_by_y_diff = (h_alpha_by_y_diff * h_beta_by_y - h_alpha_by_y * h_beta_by_y_diff) / \
                              (h_beta_by_y + h_alpha_by_y) ** 2  # h_{\infty}'(y)

            # n channel: steady state / time constant at V, steady state and slope at y
            t4 = (15. - V + NaK_th)
            n_alpha_by_V = 0.032 * t4 / (bm.exp(t4 / 5.) - 1.)  # \alpha_n(V)
            n_beta_by_V = b * bm.exp((10. - V + NaK_th) / 40.)  # \beta_n(V)
            n_tau_by_V = 1. / (n_alpha_by_V + n_beta_by_V) / phi_n  # \tau_n(V)
            n_inf_by_V = n_alpha_by_V / (n_alpha_by_V + n_beta_by_V
                                         )  # n_{\infty}(V)
            t5 = (15. - y + NaK_th)
            t5_exp = bm.exp(t5 / 5.)
            n_alpha_by_y = 0.032 * t5 / (t5_exp - 1.)  # \alpha_n(y)
            t6 = bm.exp((10. - y + NaK_th) / 40.)
            n_beta_y = b * t6  # \beta_n(y)
            n_inf_by_y = n_alpha_by_y / (n_alpha_by_y + n_beta_y
                                         )  # n_{\infty}(y)
            n_alpha_by_y_diff = (0.0064 * t5 * t5_exp - 0.032 *
                                 (t5_exp - 1.)) / (t5_exp -
                                                   1.)**2  # \alpha_n'(y)
            n_beta_by_y_diff = -n_beta_y / 40  # \beta_n'(y)
            n_inf_by_y_diff = (n_alpha_by_y_diff * n_beta_y - n_alpha_by_y * n_beta_by_y_diff) / \
                              (n_alpha_by_y + n_beta_y) ** 2  # n_{\infty}'(y)

            # p channel (T-type activation): steady state / time constant at V, steady state and slope at y
            p_inf_by_V = 1. / (1. + bm.exp(
                (p_half - V + IT_th) / p_k))  # p_{\infty}(V)
            p_tau_by_V = (3 + 1. / (bm.exp(
                (V + 27. - IT_th) / 10.) + bm.exp(-(V + 102. - IT_th) / 15.))
                          ) / phi_p  # \tau_p(V)
            t7 = bm.exp((p_half - y + IT_th) / p_k)
            p_inf_by_y = 1. / (1. + t7)  # p_{\infty}(y)
            p_inf_by_y_diff = t7 / p_k / (1. + t7)**2  # p_{\infty}'(y)

            # q channel (T-type inactivation): steady state / time constant at V, steady state and slope at z
            q_inf_by_V = 1. / (1. + bm.exp(
                (q_half - V + IT_th) / q_k))  # q_{\infty}(V)
            t8 = bm.exp((q_half - z + IT_th) / q_k)
            q_inf_by_z = 1. / (1. + t8)  # q_{\infty}(z)
            q_inf_diff_z = t8 / q_k / (1. + t8)**2  # q_{\infty}'(z)
            q_tau_by_V = (85. + 1 / (bm.exp(
                (V + 48. - IT_th) / 4.) + bm.exp(-(V + 407. - IT_th) / 50.))
                          ) / phi_q  # \tau_q(V)

            # ----
            #  x: membrane-potential equation
            # ----

            gNa = g_Na * m_inf_by_V**3 * h_inf_by_y  # gNa = g_Na * m^3 * h
            gK = g_K * n_inf_by_y**4  # gK = g_K * n^4
            gT = g_T * p_inf_by_y * p_inf_by_y * q_inf_by_z  # gT = g_T * p^2 * q
            FV = gNa + gK + gT + g_L + g_KL  # dF/dV: total conductance
            Fm = 3 * g_Na * h_inf_by_y * (
                V - E_Na) * m_inf_by_V * m_inf_by_V * m_inf_by_V_diff  # dF/dvm: d(INa)/dm * m_inf'(V)
            t9 = C / m_tau_by_V
            t10 = FV + Fm
            t11 = t9 + FV
            # rho_V is the minus-branch root of t10*rho**2 - t11*rho + t9 = 0;
            # the discriminant is clamped at 0 so sqrt never gets a negative value.
            rho_V = (t11 - bm.sqrt(bm.maximum(t11**2 - 4 * t9 * t10,
                                              0.))) / 2 / t10  # rho_V
            INa = gNa * (V - E_Na)
            # NOTE(review): IK uses E_KL (the same reversal as IKL); no separate
            # E_K appears here -- confirm this is intended.
            IK = gK * (V - E_KL)
            IT = gT * (V - E_T)
            IL = g_L * (V - E_L)
            IKL = g_KL * (V - E_KL)
            Iext = V_factor * Isyn
            dVdt = rho_V * (-INa - IK - IT - IL - IKL + Iext) / C

            # ----
            #  y: weighted mix of the h, n and p relaxation terms
            # ----

            Fvh = g_Na * m_inf_by_V**3 * (V - E_Na) * h_inf_by_y_diff  # dF/dvh
            Fvn = 4 * g_K * (V -
                             E_KL) * n_inf_by_y**3 * n_inf_by_y_diff  # dF/dvn
            f4 = Fvh + Fvn
            # rho_h + rho_n == 1 - rho_p, so the three weights sum to 1.
            rho_h = (1 - rho_p) * Fvh / f4
            rho_n = (1 - rho_p) * Fvn / f4
            # Each f* is (X_inf(V) - X_inf(y)) / tau_X(V), divided by that gate's
            # slope X_inf'(y), which converts the gate's relaxation into a rate for y.
            fh = (h_inf_by_V - h_inf_by_y) / h_tau_by_V / h_inf_by_y_diff
            fn = (n_inf_by_V - n_inf_by_y) / n_tau_by_V / n_inf_by_y_diff
            fp = (p_inf_by_V - p_inf_by_y) / p_tau_by_V / p_inf_by_y_diff
            dydt = rho_h * fh + rho_n * fn + rho_p * fp

            # ----
            #  z: q_inf(z) relaxes toward q_inf(V) with time constant tau_q(V)
            # ----

            dzdt = (q_inf_by_V - q_inf_by_z) / q_tau_by_V / q_inf_diff_z

            return dVdt, dydt, dzdt