Пример #1
0
    def run(self, inp_ext, inp_g):
        """
        Run the network for one step per element of inp_ext.

        At each step, the total input to every neuron is the sum of three terms:
        recurrent drive from the previous step's spikes (self.w.dot(...)), the
        global input inp_g scaled by (1 + self.ltp_ie), and that step's external
        input. Neurons in their refractory period get input 0 and cannot spike.
        Only the self.max_active neurons with the highest input (refractory
        neurons counted at 0) are eligible. An eligible neuron spikes if its
        input reaches self.v_th, and a spiking neuron is then refractory for
        self.rp steps.

        :param inp_ext: sequence of per-step external inputs (each presumably a
            scalar or a length-self.n array -- TODO confirm against callers)
        :param inp_g: global input added at every step, scaled by (1 + self.ltp_ie)
        :return: Generic with t (step indices), v (n_t x n total inputs) and
            spks (n_t x n booleans)
        """
        n = self.n
        n_t = len(inp_ext)

        v = np.zeros((n_t, n), dtype=float)
        spks = np.zeros((n_t, n), dtype=bool)

        # Remaining refractory steps per neuron. Kept as float so that a
        # non-integer self.rp counts down exactly as it did before.
        rp = np.zeros(n, dtype=float)

        spks_prev = np.zeros(n, dtype=bool)

        # The global input does not change between steps, so scale it once.
        g_total = inp_g * (1 + self.ltp_ie)

        # Start of the top-max_active block in an ascending argsort. Clamped so
        # that max_active == 0 selects nobody (a bare [-0:] slice would select
        # everyone) and max_active > n selects everyone.
        top_start = max(n - self.max_active, 0)

        for t_ctr, inp_ext_ in enumerate(inp_ext):

            # total input: recurrent (from last step's spikes) + global + external
            v_ = self.w.dot(spks_prev) + g_total + inp_ext_

            # neurons in their refractory period receive no input and cannot spike
            refrac = rp > 0
            v_[refrac] = 0

            # decrement refractory period (floored at 0)
            rp = np.maximum(rp - 1, 0)

            # Only the max_active most-driven neurons are eligible to spike.
            # Testing the threshold on them directly, instead of on a zero-filled
            # copy, keeps every other neuron silent even when self.v_th <= 0.
            idxs_most_active = np.argsort(v_)[top_start:]
            spks_ = np.zeros(n, dtype=bool)
            spks_[idxs_most_active] = v_[idxs_most_active] >= self.v_th
            spks_[refrac] = False

            # reset refractory period for spiking nrns
            rp[spks_] = self.rp

            # store everything (row assignment copies the data)
            v[t_ctr] = v_
            spks[t_ctr] = spks_
            spks_prev = spks_

        return Generic(t=np.arange(n_t), v=v, spks=spks)
Пример #2
0
    def run(self, dt, clamp, i_ext, output_dir_name, dropouts, m, repairs=None, spks_u=None):
        """
        Run simulation.

        :param dt: integration timestep (s)
        :param clamp: object with `v` and `spk` dicts mapping times (s) to the
            voltage / spike rows to impose at those times (e.g. to initialize)
        :param i_ext: external current inputs (either 1D or 2D array, length = num timesteps for smln)
        :param output_dir_name: subdirectory of ./data that snapshots are written
            to when self.output is set (must not already exist)
        :param dropouts: iterable of (t, dropout) pairs. At time t (s),
            dropout_on_mat is applied to the E columns
            m.PROJECTION_NUM:m.N_EXC with dropout['E'], and to the I columns
            m.N_EXC: with dropout['I'].
        :param m: parameter object providing N_EXC and PROJECTION_NUM
        :param repairs: iterable of (t, repair_setpoint) pairs. At time t (s),
            the E->E block is iteratively rescaled toward repair_setpoint times
            the initial mean E->E input per cell. Defaults to no repairs.
        :param spks_u: upstream spike inputs, one row per timestep
        :return: Generic with dt, t, gs, vs, spks, spks_t, spks_c, i_ext, ntwk
        """
        if repairs is None:
            repairs = []

        n = self.n
        n_t = len(i_ext)
        syns = self.syns
        c_m = self.c_m
        g_l = self.g_l
        e_l = self.e_l
        v_th = self.v_th
        v_r = self.v_r
        t_r = self.t_r
        t_r_int = np.round(t_r / dt).astype(int)
        e_s = self.e_s
        t_s = self.t_s
        w_r = self.w_r
        w_u = self.w_u

        # spike-time history handed to self.update_w / self.update_spk_hist
        spk_time_hist = []

        if self.output:
            output_dir = f'./data/{output_dir_name}'
            # no exist_ok: raises FileExistsError instead of mixing with an old run
            os.makedirs(output_dir)

        # make data storage arrays (NaN marks steps that were never simulated)
        gs = {syn: np.nan * np.zeros((n_t, n)) for syn in syns}
        vs = np.nan * np.zeros((n_t, n))
        spks = np.zeros((n_t, n), dtype=bool)

        rp_ctr = np.zeros(n, dtype=int)

        # convert float times in clamp dict to time idxs
        ## convert to list of tuples sorted by time
        tmp_v = sorted(clamp.v.items(), key=lambda x: x[0])
        tmp_spk = sorted(clamp.spk.items(), key=lambda x: x[0])
        clamp = Generic(
            v={int(round(t_ / dt)): f_v for t_, f_v in tmp_v},
            spk={int(round(t_ / dt)): f_spk for t_, f_spk in tmp_spk})

        # Convert dropout/repair times (s) to timestep indices once, rounding the
        # same way as the clamp times above. Plain int() truncation can fire an
        # event one step early when t / dt lands just below an integer.
        dropout_events = [(int(round(t_ / dt)), dropout) for t_, dropout in dropouts]
        repair_events = [
            (i_r, int(round(t_ / dt)), repair_setpoint)
            for i_r, (t_, repair_setpoint) in enumerate(repairs)]

        avg_initial_input_per_cell = np.mean(self.w_r['E'][:m.N_EXC, :m.N_EXC].sum(axis=1))

        # Loop over timesteps. The last 4 * t_r_int[0] steps are not simulated
        # (their vs/gs rows stay NaN).
        # NOTE(review): presumably headroom carried over from a burst variant -- confirm.
        for t_ctr in range(len(i_ext) - 4 * t_r_int[0]):

            if self.output and (t_ctr == 0):
                sio.savemat(output_dir + '/' + f'{zero_pad(0, 6)}', {
                    'w_r_e': self.w_r['E'],
                    'w_r_i': self.w_r['I'],
                    'avg_input_per_cell': avg_initial_input_per_cell,
                })

            if t_ctr % 5000 == 0:
                print(f'{t_ctr / len(i_ext) * 100}% finished' )
                print(f'completed {dt * t_ctr * 1000} ms of {len(i_ext) * dt} s')

            for t_idx, dropout in dropout_events:
                if t_idx == t_ctr:
                    self.w_r['E'][:, m.PROJECTION_NUM:m.N_EXC] = dropout_on_mat(self.w_r['E'][:, m.PROJECTION_NUM:m.N_EXC], dropout['E'])
                    self.w_r['I'][:, m.N_EXC:] = dropout_on_mat(self.w_r['I'][:, m.N_EXC:], dropout['I'])

                    avg_input_per_cell = np.mean(self.w_r['E'][:m.N_EXC, :m.N_EXC].sum(axis=1))
                    if self.output:
                        sio.savemat(output_dir + '/' + f'{zero_pad(1, 6)}', {
                            'avg_input_per_cell': avg_input_per_cell,
                        })

            for i_r, t_idx, repair_setpoint in repair_events:
                if t_idx == t_ctr:
                    target = repair_setpoint * avg_initial_input_per_cell
                    # basic slicing gives a view, so in-place edits update self.w_r['E']
                    w_ee = self.w_r['E'][:m.N_EXC, :m.N_EXC]
                    # fixed-point iteration: scale weights multiplicatively toward the
                    # target mean input, capping each weight at self.w_max
                    for _ in range(1000):
                        avg_input_per_cell = np.mean(w_ee.sum(axis=1))
                        w_ee += 100. * w_ee * (target - avg_input_per_cell)
                        w_ee[w_ee > self.w_max] = self.w_max
                    # recompute so the saved value reflects the final (capped) weights
                    avg_input_per_cell = np.mean(w_ee.sum(axis=1))

                    if self.output:
                        sio.savemat(output_dir + '/' + f'{zero_pad(i_r + 2, 6)}', {
                            'avg_input_per_cell': avg_input_per_cell,
                        })

            # update conductances
            for syn in syns:
                if t_ctr == 0:
                    gs[syn][t_ctr, :] = 0
                else:
                    g = gs[syn][t_ctr - 1, :]
                    # get weighted spike inputs
                    ## recurrent
                    inp = w_r[syn].dot(spks[t_ctr - 1, :])
                    ## upstream
                    if spks_u is not None:
                        if syn in w_u:
                            inp += w_u[syn].dot(spks_u[t_ctr - 1, :])

                    # update conductances from weighted spks (forward-Euler decay + jump)
                    gs[syn][t_ctr, :] = g + (dt / t_s[syn]) * (-gs[syn][t_ctr - 1, :]) + inp

            # update voltages
            if t_ctr in clamp.v:  # check for clamped voltages
                vs[t_ctr, :] = clamp.v[t_ctr]
            else:  # update as per diff eq
                # at t_ctr == 0 this reads vs[-1] (NaN), so step 0 is expected to be clamped
                v = vs[t_ctr - 1, :]
                # get total current input
                i_total = -g_l * (v - e_l)  # leak
                i_total += np.sum([-gs[syn][t_ctr, :] * (v - e_s[syn]) for syn in syns], axis=0)  # synaptic
                i_total += i_ext[t_ctr]  # external

                # update v
                vs[t_ctr, :] = v + (dt / c_m) * i_total

                # clamp v for cells still in refrac period
                vs[t_ctr, rp_ctr > 0] = v_r[rp_ctr > 0]

            # update spks
            if t_ctr in clamp.spk:  # check for clamped spikes
                spks[t_ctr, :] = clamp.spk[t_ctr]
            else:  # check for threshold crossings
                spks_for_t_ctr = vs[t_ctr, :] >= v_th
                spks[t_ctr, spks_for_t_ctr] = 1

            if self.weight_update:
                # spike history window (at most cut_idx_tau_pair steps back) for plasticity cells
                stdp_start = max(t_ctr - self.cut_idx_tau_pair, 0)
                stdp_spk_hist = spks[stdp_start:(t_ctr + 1), self.plasticity_indices]
                curr_spks = stdp_spk_hist[-1, :]
                self.update_w(t_ctr, stdp_spk_hist, dt, spk_time_hist)
                self.update_spk_hist(spk_time_hist, curr_spks, t_ctr)

            # reset v and update refrac periods for nrns that spiked
            vs[t_ctr, spks[t_ctr, :]] = v_r[spks[t_ctr, :]]
            rp_ctr[spks[t_ctr, :]] = t_r_int[spks[t_ctr, :]] + 1

            # decrement refrac periods
            rp_ctr[rp_ctr > 0] -= 1

        t = dt * np.arange(n_t, dtype=float)

        # convert spks to spk times and cell idxs (for easy access l8r)
        tmp = spks.nonzero()
        spks_t = dt * tmp[0]
        spks_c = tmp[1]

        return Generic(dt=dt, t=t, gs=gs, vs=vs, spks=spks, spks_t=spks_t, spks_c=spks_c, i_ext=i_ext, ntwk=self)
Пример #3
0
    def run(self, dt, clamp, i_ext, output_dir_name, spks_u=None):
        """
        Run simulation with conductance-based synapses and burst spiking.

        A cell that crosses threshold (outside clamped steps) is given spikes at
        offsets burst_t = 0, k, 2k, ... (< t_r_int[0]) steps from the crossing,
        where k = t_r_int[0] // 3. After every step, self.update_w is called
        with the recent spike history of self.plasticity_indices. When
        self.output is set, the E weights are saved every self.output_freq steps.

        :param dt: integration timestep (s)
        :param clamp: object with `v` and `spk` dicts mapping times (s) to the
            voltage / spike rows to impose at those times (e.g. to initialize)
        :param i_ext: external current inputs (either 1D or 2D array, length = num timesteps for smln)
        :param output_dir_name: subdirectory of ./data for weight snapshots
            (used only when self.output is set; must not already exist)
        :param spks_u: upstream spike inputs, one row per timestep
        :raises ValueError: if the refractory period is shorter than 3 timesteps,
            which leaves no room to space out the burst spikes
        :return: Generic with dt, t, gs, vs, spks, spks_t, spks_c, i_ext, ntwk
        """
        n = self.n
        n_t = len(i_ext)
        syns = self.syns
        c_m = self.c_m
        g_l = self.g_l
        e_l = self.e_l
        v_th = self.v_th
        v_r = self.v_r
        t_r = self.t_r
        t_r_int = np.round(t_r / dt).astype(int)
        e_s = self.e_s
        t_s = self.t_s
        w_r = self.w_r
        w_u = self.w_u

        # Spacing between burst spikes. Validate it before doing any work: a zero
        # step would otherwise make np.arange fail with an opaque error.
        burst_step = t_r_int[0] // 3
        if burst_step < 1:
            raise ValueError(
                f'refractory period spans {t_r_int[0]} timestep(s); '
                'at least 3 are needed to space out burst spikes')

        if self.output:
            output_dir = f'./data/{output_dir_name}'
            # no exist_ok: raises FileExistsError instead of mixing with an old run
            os.makedirs(output_dir)

        # make data storage arrays (NaN marks steps that were never simulated)
        gs = {syn: np.nan * np.zeros((n_t, n)) for syn in syns}
        vs = np.nan * np.zeros((n_t, n))
        spks = np.zeros((n_t, n), dtype=bool)

        rp_ctr = np.zeros(n, dtype=int)

        # convert float times in clamp dict to time idxs
        ## convert to list of tuples sorted by time
        tmp_v = sorted(clamp.v.items(), key=lambda x: x[0])
        tmp_spk = sorted(clamp.spk.items(), key=lambda x: x[0])
        clamp = Generic(
            v={int(round(t_ / dt)): f_v
               for t_, f_v in tmp_v},
            spk={int(round(t_ / dt)): f_spk
                 for t_, f_spk in tmp_spk})

        # offsets (in steps) of the spikes emitted after a threshold crossing
        burst_t = np.arange(0, t_r_int[0], burst_step, dtype=int)

        # Loop over timesteps. Stopping t_r_int[0] steps early keeps every
        # burst write (offset < t_r_int[0]) inside the spks array.
        for t_ctr in range(len(i_ext) - t_r_int[0]):

            # update conductances
            for syn in syns:
                if t_ctr == 0:
                    gs[syn][t_ctr, :] = 0
                else:
                    g = gs[syn][t_ctr - 1, :]
                    # get weighted spike inputs
                    ## recurrent
                    inp = w_r[syn].dot(spks[t_ctr - 1, :])
                    ## upstream
                    if spks_u is not None:
                        if syn in w_u:
                            inp += w_u[syn].dot(spks_u[t_ctr - 1, :])

                    # update conductances from weighted spks
                    gs[syn][t_ctr, :] = g + (dt / t_s[syn]) * (
                        -gs[syn][t_ctr - 1, :]) + inp

            # update voltages
            if t_ctr in clamp.v:  # check for clamped voltages
                vs[t_ctr, :] = clamp.v[t_ctr]
            else:  # update as per diff eq
                # at t_ctr == 0 this reads vs[-1] (NaN), so step 0 is expected to be clamped
                v = vs[t_ctr - 1, :]
                # get total current input
                i_total = -g_l * (v - e_l)  # leak
                i_total += np.sum(
                    [-gs[syn][t_ctr, :] * (v - e_s[syn]) for syn in syns],
                    axis=0)  # synaptic
                i_total += i_ext[t_ctr]  # external

                # update v
                vs[t_ctr, :] = v + (dt / c_m) * i_total

                # clamp v for cells still in refrac period
                vs[t_ctr, rp_ctr > 0] = v_r[rp_ctr > 0]

            # update spks
            if t_ctr in clamp.spk:  # check for clamped spikes
                spks[t_ctr, :] = clamp.spk[t_ctr]
            else:  # check for threshold crossings
                spks_for_t_ctr = vs[t_ctr, :] >= v_th

                # schedule the burst: the current step plus later offsets
                for b_t in burst_t:
                    spks[b_t + t_ctr, spks_for_t_ctr] = 1

            # plasticity on the spike history window of the plastic cells
            stdp_start = max(t_ctr - self.tau_cut, 0)
            self.update_w(
                t_ctr, spks[stdp_start:(t_ctr + 1), self.plasticity_indices],
                dt)

            if self.output and (t_ctr % self.output_freq == 0):
                sio.savemat(
                    output_dir + '/' +
                    f'{zero_pad(int(t_ctr / self.output_freq), 6)}',
                    {'w_r_e': self.w_r['E']})

            # reset v and update refrac periods for nrns that spiked
            vs[t_ctr, spks[t_ctr, :]] = v_r[spks[t_ctr, :]]
            rp_ctr[spks[t_ctr, :]] = t_r_int[spks[t_ctr, :]] + 1

            # decrement refrac periods
            rp_ctr[rp_ctr > 0] -= 1

        t = dt * np.arange(n_t, dtype=float)

        # convert spks to spk times and cell idxs (for easy access l8r)
        tmp = spks.nonzero()
        spks_t = dt * tmp[0]
        spks_c = tmp[1]

        return Generic(dt=dt,
                       t=t,
                       gs=gs,
                       vs=vs,
                       spks=spks,
                       spks_t=spks_t,
                       spks_c=spks_c,
                       i_ext=i_ext,
                       ntwk=self)
Пример #4
0
    def run(self, dt, clamp, i_ext, spks_u=None):
        """
        Run simulation with current-based (instantaneous) synapses.

        At each unclamped step, the voltage is advanced with a forward-Euler
        step. The current is the sum of three terms: leak, upstream and
        recurrent spikes from the previous step weighted by w_u / w_r (injected
        directly as current), and the external input.

        :param dt: integration timestep (s)
        :param clamp: object with `v` and `spk` dicts mapping times (s) to the
            voltage / spike rows to impose at those times (e.g. to initialize)
        :param i_ext: external current inputs (either 1D or 2D array, length = num timesteps for smln)
        :param spks_u: upstream spike inputs, one row per timestep; must have the
            same length as i_ext
        :raises ValueError: if spks_u is given and its length differs from i_ext
        :return: Generic with dt, t, vs, spks, spks_t, spks_c, i_ext, ntwk
        """
        n = self.n
        n_t = len(i_ext)
        c_m = self.c_m
        g_l = self.g_l
        e_l = self.e_l
        v_th = self.v_th
        v_r = self.v_r
        t_r = self.t_r
        t_r_int = np.round(t_r / dt).astype(int)
        w_r = self.w_r
        w_u = self.w_u

        # explicit check: an assert would be stripped under `python -O`
        if spks_u is not None and len(i_ext) != len(spks_u):
            raise ValueError(
                f'spks_u has {len(spks_u)} timesteps but i_ext has {len(i_ext)}')

        # make data storage arrays (NaN marks steps that were never simulated)
        vs = np.nan * np.zeros((n_t, n))
        spks = np.zeros((n_t, n), dtype=bool)

        rp_ctr = np.zeros(n, dtype=int)

        # convert float times in clamp dict to time idxs
        ## convert to list of tuples sorted by time
        tmp_v = sorted(clamp.v.items(), key=lambda x: x[0])
        tmp_spk = sorted(clamp.spk.items(), key=lambda x: x[0])
        clamp = Generic(
            v={int(round(t_ / dt)): f_v
               for t_, f_v in tmp_v},
            spk={int(round(t_ / dt)): f_spk
                 for t_, f_spk in tmp_spk})

        # loop over timesteps
        for t_ctr in range(n_t):

            # update voltages
            if t_ctr in clamp.v:  # check for clamped voltages
                vs[t_ctr, :] = clamp.v[t_ctr]
            else:  # update as per diff eq
                # at t_ctr == 0 this reads vs[-1] (NaN), so step 0 is expected to be clamped
                v = vs[t_ctr - 1, :]

                # get total current input
                i_total = -g_l * (v - e_l)  # leak

                if t_ctr >= 1:  # synaptic

                    if spks_u is not None:  # upstream
                        i_total += w_u.dot(spks_u[t_ctr - 1, :])
                    i_total += w_r.dot(spks[t_ctr - 1, :])  # recurrent

                i_total += i_ext[t_ctr]  # external

                # update v
                vs[t_ctr, :] = v + (dt / c_m) * i_total

                # clamp v for cells still in refrac period
                vs[t_ctr, rp_ctr > 0] = v_r[rp_ctr > 0]

            # update spks
            if t_ctr in clamp.spk:  # check for clamped spikes
                spks[t_ctr, :] = clamp.spk[t_ctr]
            else:  # check for threshold crossings
                spks[t_ctr, :] = vs[t_ctr, :] >= v_th

            # reset v and update refrac periods for nrns that spiked
            vs[t_ctr, spks[t_ctr]] = v_r[spks[t_ctr]]
            rp_ctr[spks[t_ctr]] = t_r_int[spks[t_ctr]] + 1

            # decrement refrac periods
            rp_ctr[rp_ctr > 0] -= 1

            # TODO: update aux variables and weights (not implemented yet)

        t = dt * np.arange(n_t, dtype=float)

        # convert spks to spk times and cell idxs (for easy access l8r)
        tmp = spks.nonzero()
        spks_t = dt * tmp[0]
        spks_c = tmp[1]

        return Generic(dt=dt,
                       t=t,
                       vs=vs,
                       spks=spks,
                       spks_t=spks_t,
                       spks_c=spks_c,
                       i_ext=i_ext,
                       ntwk=self)
Пример #5
0
    def run(self, dt, clamp, i_ext, spks_u=None, sigma_b=0):
        """
        Run simulation with kernel-filtered synaptic currents and bursting cells.

        Synaptic current is the weighted (w_u / w_r) sum of presynaptic spike
        trains, filtered by a kernel from alpha_func (presumably an alpha
        function with time constant self.t_a). The kernel is truncated at
        10 * t_a. Cells flagged in self.i_b respond to a threshold crossing with
        a burst lasting self.t_b (s), with one spike every 1 / self.f_b s.
        Other cells emit a single spike.

        :param dt: integration timestep (s)
        :param clamp: object with `v` and `spk` dicts keyed by time (s).
            clamp.v values overwrite the whole voltage row; clamp.spk values
            index the cells forced to spike at that time.
        :param i_ext: external current inputs (either 1D or 2D array, length = num timesteps for smln)
        :param spks_u: upstream spike inputs, one row per timestep; must have the
            same length as i_ext
        :param sigma_b: standard deviation (in timesteps) of the Gaussian jitter
            added to the position of each burst spike
        :raises ValueError: if spks_u is given and its length differs from i_ext
        :return: Generic with dt, t, vs, spks, spks_t, spks_c, i_ext, ntwk
        """
        n = self.n
        n_t = len(i_ext)
        c_m = self.c_m
        g_l = self.g_l
        e_l = self.e_l
        v_th = self.v_th
        v_r = self.v_r
        t_r = self.t_r
        t_r_int = np.round(t_r / dt).astype(int)
        w_r = self.w_r
        w_u = self.w_u
        t_a_int = np.round(self.t_a / dt).astype(int)

        # synaptic kernel, reversed so its last entry applies to the most recent step
        i_kernel = alpha_func(np.flip(np.arange(0, 10 * t_a_int)), dt, self.t_a)
        n_kernel = len(i_kernel)

        # burst length in timesteps
        l_b = int(self.t_b / dt)

        # boolean view of the bursting-cell flags (also accepts 0/1 integer flags)
        is_burster = np.asarray(self.i_b, dtype=bool)

        def draw_burst():
            """Return a length-l_b 0/1 template with one jittered spike per 1 / f_b s."""
            template = np.zeros(l_b)
            for t_spk in np.arange(0, self.t_b, 1. / self.f_b):
                noise = np.random.normal(0, sigma_b)
                if l_b > 0:
                    # Clip so that jitter can neither wrap to the end through a
                    # negative index nor run past the end of the template.
                    idx = min(max(int(t_spk / dt + noise), 0), l_b - 1)
                    template[idx] = 1
            return template

        # explicit check: an assert would be stripped under `python -O`
        if spks_u is not None and len(i_ext) != len(spks_u):
            raise ValueError(
                f'spks_u has {len(spks_u)} timesteps but i_ext has {len(i_ext)}')

        # make data storage arrays (NaN marks steps that were never simulated)
        vs = np.nan * np.zeros((n_t, n))
        spks = np.zeros((n_t, n), dtype=bool)

        rp_ctr = np.zeros(n, dtype=int)

        # convert float times in clamp dict to time idxs
        ## convert to list of tuples sorted by time
        tmp_v = sorted(clamp.v.items(), key=lambda x: x[0])
        tmp_spk = sorted(clamp.spk.items(), key=lambda x: x[0])
        clamp = Generic(
            v={int(round(t_ / dt)): f_v for t_, f_v in tmp_v},
            spk={int(round(t_ / dt)): f_spk for t_, f_spk in tmp_spk})

        # Loop over timesteps. Stopping l_b steps early keeps every burst write
        # inside the spks array.
        for t_ctr in range(n_t - l_b):

            # update voltages
            if t_ctr in clamp.v:  # check for clamped voltages
                vs[t_ctr, :] = clamp.v[t_ctr]
            else:  # update as per diff eq
                # at t_ctr == 0 this reads vs[-1] (NaN), so step 0 is expected to be clamped
                v = vs[t_ctr - 1, :]

                # get total current input
                i_total = -g_l * (v - e_l)  # leak

                # spike-history window covered by the kernel
                t_start = max(t_ctr - 10 * t_a_int, 0)
                len_spk_hist = t_ctr - t_start

                # Slice from the front: i_kernel[-0:] would be the whole kernel
                # (not an empty one) when len_spk_hist == 0, which breaks the
                # dot products below.
                trimmed_i_kernel = i_kernel[n_kernel - len_spk_hist:]

                if trimmed_i_kernel.shape[0] >= 1:  # synaptic

                    if spks_u is not None:  # upstream
                        input_current = trimmed_i_kernel.dot(spks_u[t_start:t_ctr, :])
                        i_total += w_u.dot(input_current)

                    rec_current = trimmed_i_kernel.dot(spks[t_start:t_ctr, :])
                    i_total += w_r.dot(rec_current)  # recurrent

                i_total += i_ext[t_ctr]  # external

                # update v
                vs[t_ctr, :] = v + (dt / c_m) * i_total

                # clamp v for cells still in refrac period
                vs[t_ctr, rp_ctr > 0] = v_r[rp_ctr > 0]

            # update spks: forced spikes first, then threshold crossings of the rest
            unclamped = np.ones(n, dtype=bool)
            if t_ctr in clamp.spk:  # check for clamped spikes
                clamped = clamp.spk[t_ctr]
                spks[t_ctr, clamped] = 1
                unclamped[clamped] = False

            # Full-length masks, so the cells picked below are real cell indices
            # rather than positions within the unclamped subset.
            crossed_thresh = unclamped & (vs[t_ctr, :] >= v_th)
            crossed_thresh_bursting = crossed_thresh & is_burster
            crossed_thresh_spiking = crossed_thresh & ~is_burster

            spks[t_ctr, crossed_thresh_spiking] = 1
            # Drawn every step, even when no cell bursts, so the random stream is
            # consumed exactly as before. Each burster's rows t_ctr .. t_ctr+l_b-1
            # are overwritten with the burst template.
            b = draw_burst()
            spks[t_ctr:t_ctr + l_b, crossed_thresh_bursting] = b[:, np.newaxis]

            # reset v and update refrac periods for nrns that spiked
            vs[t_ctr, spks[t_ctr]] = v_r[spks[t_ctr]]
            rp_ctr[spks[t_ctr]] = t_r_int[spks[t_ctr]] + 1

            # decrement refrac periods
            rp_ctr[rp_ctr > 0] -= 1

            # TODO: update aux variables and weights (not implemented yet)

        t = dt * np.arange(n_t, dtype=float)

        # convert spks to spk times and cell idxs (for easy access l8r)
        tmp = spks.nonzero()
        spks_t = dt * tmp[0]
        spks_c = tmp[1]

        return Generic(dt=dt, t=t, vs=vs, spks=spks, spks_t=spks_t, spks_c=spks_c, i_ext=i_ext, ntwk=self)