Ejemplo n.º 1
0
 def update(self, _t, _dt):
     """Maintain a piecewise-constant rate during the stimulus window and draw spikes.

     While ``pre_stimulus_period < _t < pre_stimulus_period + stimulus_period``
     the rate ``self.freq`` is held, and re-drawn from
     ``self.rng.normal(self.freq_mean, self.freq_var)`` once at least
     ``self.t_interval`` has elapsed since the last change (the change time is
     then recorded in ``self.freq_t_last_change``). Outside the window the rate
     is set to 0. Each of the ``self.num`` units then spikes when its
     ``self.rng.random`` draw is below ``freq * self.dt``.

     ``_dt`` is not used; the spike probability uses ``self.dt``.
     """
     stim_start = pre_stimulus_period
     stim_end = pre_stimulus_period + stimulus_period
     within_stimulus = bm.logical_and(stim_start < _t, _t < stim_end)
     # Keep the current rate inside the window, drop to zero outside it.
     held_rate = bm.where(within_stimulus, self.freq[0], 0.)
     last_change = self.freq_t_last_change[0]
     resample_now = bm.logical_and(within_stimulus,
                                   (_t - last_change) >= self.t_interval)
     # NOTE(review): `freq_var` is passed where `normal` usually expects a
     # standard deviation (scale) — confirm which one is intended.
     # The normal draw happens every step, whether or not it is used.
     candidate_rate = self.rng.normal(self.freq_mean, self.freq_var)
     self.freq[:] = bm.where(resample_now, candidate_rate, held_rate)
     self.freq_t_last_change[:] = bm.where(resample_now, _t, last_change)
     fire_prob = self.freq[0] * self.dt
     self.spike.value = self.rng.random(self.num) < fire_prob
Ejemplo n.º 2
0
 def update(self, _t, _dt):
     """Advance V and u one step with separate integrators, then apply the reset.

     Wherever the new V reaches 0, V is reset to -65 and u is incremented by 8.
     The input buffer is cleared afterwards.
     """
     # Both integrators read the pre-step state: u is driven by the old V.
     V_next = self.int_V(self.V, _t, self.u, self.input, dt=_dt)
     u_next = self.int_u(self.u, _t, self.V, dt=_dt)
     fired = V_next >= 0.
     self.V.value = bm.where(fired, -65., V_next)
     self.u.value = bm.where(fired, u_next + 8., u_next)
     # NOTE(review): `fired` is not stored (no self.spike update here) —
     # confirm whether that is intentional.
     self.input[:] = 0.
Ejemplo n.º 3
0
 def update(self, _t, _dt):
     """Integrate (V, w) one step, then apply spike reset and the w jump.

     A spike is registered wherever the integrated V reaches ``self.V_th``.
     At those entries:
       * the spike time is recorded;
       * V is reset to ``self.V_reset``;
       * w is increased by ``self.b``.
     The input buffer is cleared at the end.
     """
     V_next, w_next = self.integral(self.V, self.w, _t, self.input, dt=_dt)
     fired = V_next >= self.V_th
     self.t_last_spike.value = bm.where(fired, _t, self.t_last_spike)
     self.V.value = bm.where(fired, self.V_reset, V_next)
     self.w.value = bm.where(fired, w_next + self.b, w_next)
     self.spike.value = fired
     # Presumably other components add into `input` each step; reset it here.
     self.input[:] = 0.
 def update(self, _t, _dt):
     """Advance (V, I) one step with a refractory clamp and threshold reset.

     Units whose last spike is at most ``self.tau_ref`` ago keep their old V.
     Units whose (clamped) V reaches ``self.V_th`` spike:
       * their spike time is recorded;
       * V is set to ``self.V_reset``.
     I always takes the integrated value.
     """
     in_refractory = (_t - self.t_last_spike) <= self.tau_ref
     # NOTE(review): `_dt` is passed positionally here, while sibling models
     # pass `dt=_dt`; confirm the integrator treats it as the step size.
     V_new, I_new = self.integral(self.V, self.I, _t, _dt)
     V_new = bm.where(in_refractory, self.V, V_new)
     fired = V_new >= self.V_th
     self.t_last_spike.value = bm.where(fired, _t, self.t_last_spike)
     self.V.value = bm.where(fired, self.V_reset, V_new)
     self.spike.value = fired
     self.I.value = I_new
Ejemplo n.º 5
0
 def update(self, _t, _dt):
     """Advance the membrane potential one step, honouring the refractory period.

     Units whose last spike is at most ``self.t_refractory`` ago keep their
     previous V. Units whose resulting V reaches ``self.V_th`` spike:
       * V is reset to ``self.V_reset``;
       * their spike time is recorded.
     ``self.refractory`` marks units that either just spiked or were already
     refractory. The input buffer is cleared at the end.

     Args:
       _t: current simulation time.
       _dt: step size supplied by the runner; forwarded to the integrator.
     """
     ref = (_t - self.t_last_spike) <= self.t_refractory
     # Pass the runner-supplied step explicitly, as the sibling models do,
     # instead of silently falling back to whatever default the integrator uses.
     V = self.integral(self.V, _t, self.input, dt=_dt)
     V = bm.where(ref, self.V, V)
     spike = (V >= self.V_th)
     self.V.value = bm.where(spike, self.V_reset, V)
     self.spike.value = spike
     self.t_last_spike.value = bm.where(spike, _t, self.t_last_spike)
     self.refractory.value = bm.logical_or(spike, ref)
     self.input[:] = 0.
Ejemplo n.º 6
0
 def update(self, _t, _dt):
     """Integrate V one step with a refractory clamp and threshold reset.

     Refractory units (last spike at most ``self.tau_ref`` ago) keep their old
     V. Units reaching ``self.V_th`` spike:
       * their spike time is recorded;
       * V is reset to ``self.V_reset``.
     ``self.refractory`` flags units that were refractory or just spiked.
     The input buffer is cleared.
     """
     still_refractory = (_t - self.t_last_spike) <= self.tau_ref
     candidate_V = self.integral(self.V, _t, self.input, dt=_dt)
     candidate_V = bm.where(still_refractory, self.V, candidate_V)
     fired = self.V_th <= candidate_V
     self.t_last_spike.value = bm.where(fired, _t, self.t_last_spike)
     self.V.value = bm.where(fired, self.V_reset, candidate_V)
     self.refractory.value = bm.logical_or(still_refractory, fired)
     self.spike.value = fired
     self.input[:] = 0.
Ejemplo n.º 7
0
 def update(self, _t, _dt):
     """Integrate (V, w) one step and flag upward threshold crossings as spikes.

     A spike is V moving from below ``self.Vth`` to at-or-above it during this
     step. No reset is applied; the spike time is recorded.
     """
     V_next, w_next = self.integral(self.V, self.w, _t, self.input, dt=_dt)
     was_below = self.V < self.Vth
     now_above = V_next >= self.Vth
     self.spike.value = bm.logical_and(now_above, was_below)
     self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike)
     self.V.value = V_next
     self.w.value = w_next
     self.input[:] = 0.
Ejemplo n.º 8
0
 def update(self, _t, _dt):
     """Integrate (V, W) one step and detect upward crossings of ``self.V_th``.

     W is stored immediately. A spike is recorded where the pre-step V was
     below threshold and the new V is at or above it. No reset is applied.
     """
     V_next, W_next = self.integral(self.V, self.W, _t, self.input, dt=_dt)
     self.W.value = W_next
     # Compare against the old V before it is overwritten below.
     crossed = bm.logical_and(self.V < self.V_th, V_next >= self.V_th)
     self.t_last_spike.value = bm.where(crossed, _t, self.t_last_spike)
     self.V.value = V_next
     self.spike.value = crossed
     self.input[:] = 0.
Ejemplo n.º 9
0
 def integral(*args, **kwargs):
     """Advance ``args[0]`` by one exponential-Euler step.

     ``value_and_grad(*args, **kwargs)`` returns a pair unpacked as
     ``(linear, derivative)``; presumably the derivative's linear coefficient
     (Jacobian diagonal) and the derivative value — confirm against its
     definition. The update is::

         x_new = x + dt * phi(dt * linear) * derivative
         phi(z) = (exp(z) - 1) / z,  phi(0) = 1

     Keyword args:
       dt: step size; defaults to ``math.get_dt()``. It is removed from
         ``kwargs`` before they are forwarded to ``value_and_grad``.

     ``math`` here is the array module in scope (it provides ``get_dt``), not
     the stdlib ``math``.
     """
     assert len(args) > 0
     dt = kwargs.pop('dt', math.get_dt())
     linear, derivative = value_and_grad(*args, **kwargs)
     z = dt * linear
     # Test z rather than `linear`, so the limit phi = 1 is also used when
     # dt * linear underflows to 0 (or dt == 0).
     is_zero = z == 0.
     # Substitute 1 for zeros so the branch discarded by `where` never divides
     # by zero; otherwise inf/nan can leak into gradients.
     safe_z = math.where(is_zero, math.ones_like(z), z)
     # expm1 avoids the cancellation in exp(z) - 1 for small |z|.
     phi = math.where(is_zero, math.ones_like(z), math.expm1(safe_z) / safe_z)
     return args[0] + dt * phi * derivative
Ejemplo n.º 10
0
 def update(self, _t, _dt):
     """Integrate (V, y, z) one step and flag upward crossings of ``self.V_th``.

     A spike is V going from below threshold (pre-step) to at-or-above it
     (post-step). The spike time is recorded. No reset is applied and the
     input buffer is cleared.
     """
     V_next, y_next, z_next = self.integral(
         self.V, self.y, self.z, _t, self.input, dt=_dt)
     now_above = V_next >= self.V_th
     was_below = self.V < self.V_th
     self.spike.value = bm.logical_and(now_above, was_below)
     self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike)
     for var, new in ((self.V, V_next), (self.y, y_next), (self.z, z_next)):
         var.value = new
     self.input[:] = 0.
Ejemplo n.º 11
0
 def update(self, _t, _dt):
     """Update the gating variable g from delayed presynaptic spikes.

     Steps:
       1. Record arrival times of the delayed spikes.
       2. Form a square transmitter pulse of height ``self.T`` lasting
          ``self.T_duration`` after the last arrival at each synapse.
       3. Integrate g with that pulse.
       4. Subtract ``g_max * g_post * (post.V - E)`` from the postsynaptic
          input.
     """
     # Presumably a delay buffer: push this step's spikes, read delayed ones.
     self.pre_spike.push(self.pre.spike)
     arrived = self.pre_spike.pull()
     self.spike_arrival_time.value = bm.where(arrived, _t,
                                              self.spike_arrival_time)
     # Map per-presynaptic arrival times onto individual synapses.
     per_syn_times = bm.pre2syn(self.spike_arrival_time, self.pre_ids)
     transmitter = ((_t - per_syn_times) < self.T_duration) * self.T
     self.g.value = self.integral(self.g, _t, transmitter, dt=_dt)
     # Sum synaptic conductances per postsynaptic unit.
     g_on_post = bm.syn2post(self.g, self.post_ids, self.post.num)
     self.post.input -= self.g_max * g_on_post * (self.post.V - self.E)
Ejemplo n.º 12
0
 def update(self, _t, _dt):
     """Integrate (I1, I2, V_th, V) one step and apply the spike resets.

     Where a spike occurs:
       * V is set to ``self.V_reset``;
       * each current becomes ``R_k * I_k + A_k`` (k = 1, 2);
       * the threshold is raised to ``self.V_th_reset`` if it is below it.
     All four state variables are then stored and the input buffer is cleared.
     """
     I1_new, I2_new, thr_new, V_new = self.integral(
         self.I1, self.I2, self.V_th, self.V, _t, self.input, dt=_dt)
     # NOTE(review): the spike test uses the pre-step threshold self.V_th,
     # not the freshly integrated thr_new — confirm this is intended.
     fired = self.V_th <= V_new
     V_new = bm.where(fired, self.V_reset, V_new)
     I1_new = bm.where(fired, self.R1 * I1_new + self.A1, I1_new)
     I2_new = bm.where(fired, self.R2 * I2_new + self.A2, I2_new)
     raise_thr = bm.logical_and(thr_new < self.V_th_reset, fired)
     thr_new = bm.where(raise_thr, self.V_th_reset, thr_new)
     self.spike.value = fired
     self.I1.value = I1_new
     self.I2.value = I2_new
     self.V_th.value = thr_new
     self.V.value = V_new
     self.input[:] = 0.
Ejemplo n.º 13
0
 def update(self, _t, _dt):
     """Integrate (V, m, h, n) one step and record upward threshold crossings.

     A spike is V moving from below ``self.V_th`` (pre-step) to at-or-above it
     (post-step). No voltage reset is applied. The input buffer is cleared.
     """
     states = (self.V, self.m, self.h, self.n)
     V_new, m_new, h_new, n_new = self.integral(*states, _t, self.input,
                                                dt=_dt)
     # Evaluate the crossing before V is overwritten.
     self.spike.value = bm.logical_and(self.V < self.V_th,
                                       V_new >= self.V_th)
     self.t_last_spike.value = bm.where(self.spike, _t, self.t_last_spike)
     for var, new in zip(states, (V_new, m_new, h_new, n_new)):
         var.value = new
     self.input[:] = 0.
Ejemplo n.º 14
0
    def derivative(self, p, q, t, V):
        """Return ``(dp/dt, dq/dt)`` for the two gating variables p and q.

        Each gate relaxes toward a voltage-dependent steady state with a
        voltage-dependent time constant. The rate is scaled by
        ``T_base ** ((T - 24) / 10)``, presumably a Q10-style temperature
        factor. All voltages are shifted by ``self.V_sh``. ``t`` is unused.
        """
        # Scaling factors for each gate.
        p_scale = self.T_base_p**((self.T - 24) / 10)
        q_scale = self.T_base_q**((self.T - 24) / 10)

        # p gate: steady state and time constant.
        p_steady = 1. / (1. + bm.exp(-(V + 59. - self.V_sh) / 6.2))
        p_time = 1. / (bm.exp(-(V + 132. - self.V_sh) / 16.7) +
                       bm.exp((V + 16.8 - self.V_sh) / 18.2)) + 0.612

        # q gate: the time constant is piecewise, split at V = -80 + V_sh.
        q_steady = 1. / (1. + bm.exp((V + 83. - self.V_sh) / 4.))
        q_time = bm.where(V >= (-80. + self.V_sh),
                          bm.exp(-(V + 22. - self.V_sh) / 10.5) + 28.,
                          bm.exp((V + 467. - self.V_sh) / 66.6))

        p_rate = p_scale * (p_steady - p) / p_time
        q_rate = q_scale * (q_steady - q) / q_time
        return p_rate, q_rate
    def __init__(self, pre, post, conn_prob=0.1):
        """Connect ``pre`` to ``post`` with fixed-probability random connectivity.

        Parameters
        ----------
        pre, post :
          Source and target groups. ``pre`` must provide ``spike`` and
          ``post`` must provide ``I`` (checked below).
        conn_prob : float
          Probability passed to ``bp.conn.FixedProb``.
        """
        connector = bp.conn.FixedProb(conn_prob)
        super(ThalamusInput, self).__init__(pre=pre, post=post, conn=connector)
        self.check_pre_attrs('spike')
        self.check_post_attrs('I')

        # Connectivity; the first array of pre2post determines the synapse count.
        self.pre2post = self.conn.require('pre2post')
        self.syn_num = self.pre2post[0].size
        # Weights reuse ExpSyn's excitatory distribution; negatives are set to 0.
        raw_weights = bm.random.normal(*ExpSyn.exc_weight, size=self.syn_num)
        self.weights = bm.where(raw_weights < 0., 0., raw_weights)

        # One-element boolean flag, initially False.
        self.turn_on = bm.Variable(bm.asarray([False]))
Ejemplo n.º 16
0
    def filter_loss(self, tolerance=1e-5):
        """Keep only fixed points whose squared speed is below ``tolerance``.

        The fixed points, their losses and their selected ids are all filtered
        in place with the same selection.

        Parameters
        ----------
        tolerance: float
          Fixed points are kept only when their loss (squared speed) is
          strictly less than this value; all others are discarded.
        """
        verbose = self.verbose
        if verbose:
            print(f"Excluding fixed points with squared speed above "
                  f"tolerance {tolerance}:")
        count_before = self.fixed_points.shape[0]
        keep_mask = self._losses < tolerance
        keep_idx = bm.where(keep_mask)[0]
        self._fixed_points = self._fixed_points[keep_mask]
        self._losses = self._losses[keep_idx]
        self._selected_ids = self._selected_ids[keep_idx]
        if verbose:
            print(f"    "
                  f"Kept {self._fixed_points.shape[0]}/{count_before} "
                  f"fixed points with tolerance under {tolerance}.")
Ejemplo n.º 17
0
 def dist(self, d):
     """Fold ``d`` back across half a period of length ``self.z_range``.

     Entries greater than half the period are replaced by ``period - d``;
     others are returned unchanged. The period array holds two identical
     entries, so presumably ``d`` broadcasts against a trailing axis of
     size 2 — TODO confirm.
     """
     period = bm.asarray([self.z_range, self.z_range])
     half_period = period / 2
     return bm.where(d > half_period, period - d, d)
    def __init__(self, pre, post, prob, syn_type='e', conn_type=0):
        """Build a randomly connected, delayed synapse group from ``pre`` to ``post``.

        Parameters
        ----------
        pre, post :
          Pre-/post-synaptic groups. ``pre`` must expose ``spike`` and
          ``post`` must expose ``I`` (checked below); ``.num`` and ``.size``
          are read from both.
        prob : float
          Connection probability; must lie strictly in (0, 1).
        syn_type : {'e', 'i'}
          Excitatory or inhibitory. Both use the excitatory weight
          distribution; inhibitory weights are then multiplied by
          ``self.inh_weight_scale``. Delays use ``exc_delay``/``inh_delay``.
        conn_type : int
          Connectivity construction scheme:
            0 -- random (pre, post) pairs, count from equation 3 of the
                 reference article, stored as CSR ``pre2post``;
            1 -- ``bp.conn.FixedProb(prob)`` (equation 5), ``pre2post``;
            2 -- ``int(prob * pre.num * post.num)`` random pairs stored as
                 flat ``pre_ids``/``post_ids`` (no ``pre2post``);
            3, 4 -- like 1, also recording ``max_post_conn`` (largest
                 ``bm.diff(pre2post[1])``; presumably the max fan-out per
                 presynaptic unit).

        Raises
        ------
        AssertionError
          If ``syn_type`` is not 'e'/'i' or ``prob`` is outside (0, 1).
        ValueError
          If ``conn_type`` is not one of 0-4.
        """
        super(ExpSyn, self).__init__(pre=pre, post=post, conn=None)
        self.check_pre_attrs('spike')
        self.check_post_attrs('I')
        assert syn_type in ['e', 'i'], \
            f"syn_type must be 'e' or 'i', got {syn_type!r}"
        assert 0. < prob < 1., f"prob must lie strictly in (0, 1), got {prob!r}"

        # parameters
        self.syn_type = syn_type
        self.conn_type = conn_type

        # connection
        if conn_type == 0:
            # number of synapses calculated with equation 3 from the article
            num = int(
                np.log(1.0 - prob) / np.log(1.0 -
                                            (1.0 / float(pre.num * post.num))))
            self.pre2post = bp.conn.ij2csr(
                pre_ids=np.random.randint(0, pre.num, num),
                post_ids=np.random.randint(0, post.num, num),
                num_pre=pre.num)
            self.num = self.pre2post[0].size
        elif conn_type == 1:
            # number of synapses calculated with equation 5 from the article
            self.pre2post = bp.conn.FixedProb(prob)(
                pre.size, post.size).require('pre2post')
            self.num = self.pre2post[0].size
        elif conn_type == 2:
            # Flat (pre_ids, post_ids) pair lists; no CSR structure is built.
            self.num = int(prob * pre.num * post.num)
            self.pre_ids = bm.random.randint(0,
                                             pre.num,
                                             size=self.num,
                                             dtype=bm.uint32)
            self.post_ids = bm.random.randint(0,
                                              post.num,
                                              size=self.num,
                                              dtype=bm.uint32)
        elif conn_type in [3, 4]:
            self.pre2post = bp.conn.FixedProb(prob)(
                pre.size, post.size).require('pre2post')
            self.num = self.pre2post[0].size
            self.max_post_conn = bm.diff(self.pre2post[1]).max()
        else:
            raise ValueError(
                f'Unsupported conn_type {conn_type!r}; expected 0, 1, 2, 3 or 4.')

        # delay
        if syn_type == 'e':
            self.delay = bm.random.normal(*self.exc_delay, size=pre.num)
        elif syn_type == 'i':
            self.delay = bm.random.normal(*self.inh_delay, size=pre.num)
        else:
            # Only reachable when asserts are stripped (python -O).
            raise ValueError(
                f"Unsupported syn_type {syn_type!r}; expected 'e' or 'i'.")
        # Clamp delays from below to one simulation step.
        dt = bm.get_dt()
        self.delay = bm.where(self.delay < dt, dt, self.delay)

        # weights: excitatory distribution with negatives clipped to 0,
        # rescaled for inhibitory synapses
        self.weights = bm.random.normal(*self.exc_weight, size=self.num)
        self.weights = bm.where(self.weights < 0, 0., self.weights)
        if syn_type == 'i':
            self.weights *= self.inh_weight_scale

        # variables
        self.pre_sps = bp.ConstantDelay(pre.num, self.delay, bool)
Ejemplo n.º 19
0
 def update(self, _t, _dt):
     """Jointly integrate (V, u) one step and apply the reset where V >= 0.

     Where a spike occurs:
       * V is set to -65;
       * u is increased by 8.
     The input buffer is cleared.
     """
     V_next, u_next = self.integral(self.V, self.u, _t, self.input, dt=_dt)
     fired = V_next >= 0.
     self.V.value = bm.where(fired, -65., V_next)
     self.u.value = bm.where(fired, u_next + 8., u_next)
     # NOTE(review): `fired` is not stored (no self.spike update here) —
     # confirm whether that is intentional.
     self.input[:] = 0.
Ejemplo n.º 20
0
 def update(self, x, **kwargs):
     """Apply dropout to ``x`` during training; pass it through otherwise.

     Training mode is read from ``kwargs['train']`` (default True). In
     training, a mask is drawn with ``self.rng.bernoulli(self.prob, x.shape)``.
     Kept entries are divided by ``self.prob`` and the rest are zeroed, so
     ``self.prob`` acts as the keep probability.
     """
     if not kwargs.get('train', True):
         return x
     keep_mask = self.rng.bernoulli(self.prob, x.shape)
     return bm.where(keep_mask, x / self.prob, 0.)
 def dist(self, d):
     """Wrap ``d`` periodically with period ``self.z_range``.

     ``d`` is first reduced with ``bm.remainder``. Values above half the
     period are then shifted down by one period. For a positive
     ``z_range`` the result lies in (-z_range/2, z_range/2].
     """
     span = self.z_range
     wrapped = bm.remainder(d, span)
     return bm.where(wrapped > 0.5 * span, wrapped - span, wrapped)
Ejemplo n.º 22
0
 def update(self, _t, _i):
     """Draw this step's spikes and record the time of each new spike.

     Each of the ``self.num`` units fires when its ``self.rng.random`` draw is
     at most ``self.freqs * self.dt``. ``_i`` is unused.
     """
     fired = self.rng.random(self.num) <= self.freqs * self.dt
     self.spike.update(fired)
     self.t_last_spike.update(bm.where(self.spike, _t, self.t_last_spike))