Example #1
0
    def run(self, dt, clamp, i_ext, output_dir_name, dropouts, m, repairs=None, spks_u=None):
        """
        Run simulation.

        Integrates conductances and voltages with a forward-Euler step, applies
        scheduled weight dropouts / repairs to ``self.w_r`` in place and, when
        ``self.weight_update`` is set, calls ``self.update_w`` and
        ``self.update_spk_hist`` on every timestep.

        :param dt: integration timestep (s)
        :param clamp: object with ``v`` and ``spk`` dicts mapping float times to the value
            that voltages / spikes are clamped to at that time (e.g. to initialize)
        :param i_ext: external current inputs (either 1D or 2D array, length = num timesteps for smln)
        :param output_dir_name: directory created under ``./data`` for ``.mat`` snapshots
            (only used when ``self.output`` is truthy; creation fails if it already exists)
        :param dropouts: iterable of ``(time, dropout)`` pairs; ``dropout['E']`` and
            ``dropout['I']`` are forwarded to ``dropout_on_mat``
        :param m: parameter object providing ``N_EXC`` and ``PROJECTION_NUM``
        :param repairs: optional iterable of ``(time, repair_setpoint)`` pairs; the setpoint
            scales the initial average input per cell to give the rescaling target
        :param spks_u: upstream inputs
        :return: ``Generic`` with fields ``dt``, ``t``, ``gs``, ``vs``, ``spks``, ``spks_t``,
            ``spks_c``, ``i_ext`` and ``ntwk`` (this network)
        """
        # a list literal as default would be shared between calls; None means "no repairs"
        if repairs is None:
            repairs = []

        n = self.n
        n_t = len(i_ext)
        syns = self.syns
        c_m = self.c_m
        g_l = self.g_l
        e_l = self.e_l
        v_th = self.v_th
        v_r = self.v_r
        t_r = self.t_r
        t_r_int = np.round(t_r/dt).astype(int)
        e_s = self.e_s
        t_s = self.t_s
        w_r = self.w_r
        w_u = self.w_u

        spk_time_hist = []

        if self.output:
            output_dir = f'./data/{output_dir_name}'
            # NOTE: raises FileExistsError if the directory already exists
            os.makedirs(output_dir)

        # make data storage arrays (NaN marks entries that were never simulated)
        gs = {syn: np.nan * np.zeros((n_t, n)) for syn in syns}
        vs = np.nan * np.zeros((n_t, n))
        spks = np.zeros((n_t, n), dtype=bool)

        # remaining refractory timesteps per cell
        rp_ctr = np.zeros(n, dtype=int)

        # convert float times in clamp dict to time idxs
        ## iterate sorted by time so that, if two times round to the same idx, the later one wins
        tmp_v = sorted(clamp.v.items(), key=lambda x: x[0])
        tmp_spk = sorted(clamp.spk.items(), key=lambda x: x[0])
        clamp = Generic(
            v={int(round(t_/dt)): f_v for t_, f_v in tmp_v},
            spk={int(round(t_/dt)): f_spk for t_, f_spk in tmp_spk})

        # convert event times to time idxs once instead of on every timestep
        # (truncating conversion, unlike the rounding used for the clamp times)
        dropout_steps = [(int(t_ / dt), dropout) for t_, dropout in dropouts]
        repair_steps = [(int(t_ / dt), repair_setpoint) for t_, repair_setpoint in repairs]

        # mean over the first N_EXC rows of the summed E weights in the N_EXC x N_EXC block
        avg_initial_input_per_cell = np.mean(self.w_r['E'][:m.N_EXC, :m.N_EXC].sum(axis=1))

        # loop over timesteps
        # NOTE: the last 4 * t_r_int[0] timesteps are not simulated; their rows keep
        # their initial values (NaN in gs / vs, False in spks)
        for t_ctr in range(n_t - 4 * t_r_int[0]):

            # snapshot of the initial weights
            if self.output and (t_ctr == 0):
                sio.savemat(output_dir + '/' + f'{zero_pad(0, 6)}', {
                    'w_r_e': self.w_r['E'],
                    'w_r_i': self.w_r['I'],
                    'avg_input_per_cell': avg_initial_input_per_cell,
                })

            if t_ctr % 5000 == 0:
                print(f'{t_ctr / n_t * 100}% finished' )
                print(f'completed {dt * t_ctr * 1000} ms of {n_t * dt} s')

            # apply dropouts scheduled for this timestep (weights are modified in place)
            for step, dropout in dropout_steps:
                if step == t_ctr:
                    self.w_r['E'][:, m.PROJECTION_NUM:m.N_EXC] = dropout_on_mat(self.w_r['E'][:, m.PROJECTION_NUM:m.N_EXC], dropout['E'])
                    self.w_r['I'][:, m.N_EXC:] = dropout_on_mat(self.w_r['I'][:, m.N_EXC:], dropout['I'])

                    avg_input_per_cell = np.mean(self.w_r['E'][:m.N_EXC, :m.N_EXC].sum(axis=1))
                    if self.output:
                        sio.savemat(output_dir + '/' + f'{zero_pad(1, 6)}', {
                            'avg_input_per_cell': avg_input_per_cell,
                        })

            # apply repairs scheduled for this timestep
            for i_r, (step, repair_setpoint) in enumerate(repair_steps):
                if step == t_ctr:
                    target = repair_setpoint * avg_initial_input_per_cell
                    # multiplicatively rescale the E->E block towards the target average
                    # input, capping individual weights at w_max after every iteration
                    for _ in range(1000):
                        avg_input_per_cell = np.mean(self.w_r['E'][:m.N_EXC, :m.N_EXC].sum(axis=1))
                        self.w_r['E'][:m.N_EXC, :m.N_EXC] += 100. * self.w_r['E'][:m.N_EXC, :m.N_EXC] * (target - avg_input_per_cell)
                        over_w_max = self.w_r['E'][:m.N_EXC, :m.N_EXC] > self.w_max
                        self.w_r['E'][:m.N_EXC, :m.N_EXC][over_w_max] = self.w_max

                    # NOTE: the saved average is the one measured at the start of the
                    # final iteration, i.e. before the last rescaling step
                    if self.output:
                        sio.savemat(output_dir + '/' + f'{zero_pad(i_r + 2, 6)}', {
                            'avg_input_per_cell': avg_input_per_cell,
                        })

            # update conductances
            for syn in syns:
                if t_ctr == 0:
                    gs[syn][t_ctr, :] = 0
                else:
                    g = gs[syn][t_ctr-1, :]
                    # get weighted spike inputs
                    ## recurrent
                    inp = w_r[syn].dot(spks[t_ctr-1, :])
                    ## upstream
                    if spks_u is not None:
                        if syn in w_u:
                            inp += w_u[syn].dot(spks_u[t_ctr-1, :])

                    # update conductances from weighted spks
                    gs[syn][t_ctr, :] = g + (dt/t_s[syn])*(-gs[syn][t_ctr-1, :]) + inp

            # update voltages
            if t_ctr in clamp.v:  # check for clamped voltages
                vs[t_ctr, :] = clamp.v[t_ctr]
            else:  # update as per diff eq
                # NOTE: at t_ctr == 0 this reads the (still NaN) last row, so the
                # voltages stay NaN unless step 0 is covered by a voltage clamp
                v = vs[t_ctr-1, :]
                # get total current input
                i_total = -g_l*(v - e_l)  # leak
                i_total += np.sum([-gs[syn][t_ctr, :]*(v - e_s[syn]) for syn in syns], axis=0)  # synaptic
                i_total += i_ext[t_ctr]  # external

                # update v
                vs[t_ctr, :] = v + (dt/c_m)*i_total

                # clamp v for cells still in refrac period
                vs[t_ctr, rp_ctr > 0] = v_r[rp_ctr > 0]

            # update spks
            if t_ctr in clamp.spk:  # check for clamped spikes
                spks[t_ctr, :] = clamp.spk[t_ctr]
            else:  # check for threshold crossings
                spks_for_t_ctr = vs[t_ctr, :] >= v_th
                spks[t_ctr, spks_for_t_ctr] = 1

            if self.weight_update:
                # hand the most recent spikes of the plastic cells to the weight update
                # (at most cut_idx_tau_pair steps back, plus the current step)
                stdp_start = t_ctr - self.cut_idx_tau_pair
                stdp_start = 0 if stdp_start < 0 else stdp_start
                stdp_spk_hist = spks[stdp_start:(t_ctr + 1), self.plasticity_indices]
                curr_spks = stdp_spk_hist[-1, :]
                self.update_w(t_ctr, stdp_spk_hist, dt, spk_time_hist)
                self.update_spk_hist(spk_time_hist, curr_spks, t_ctr)

            # reset v and update refrac periods for nrns that spiked
            vs[t_ctr, spks[t_ctr, :]] = v_r[spks[t_ctr, :]]
            rp_ctr[spks[t_ctr, :]] = t_r_int[spks[t_ctr, :]] + 1

            # decrement refrac periods
            rp_ctr[rp_ctr > 0] -= 1

        t = dt*np.arange(n_t, dtype=float)

        # convert spks to spk times and cell idxs (for easy access l8r)
        tmp = spks.nonzero()
        spks_t = dt * tmp[0]
        spks_c = tmp[1]

        return Generic(dt=dt, t=t, gs=gs, vs=vs, spks=spks, spks_t=spks_t, spks_c=spks_c, i_ext=i_ext, ntwk=self)
Example #2
0
    def run(self, dt, clamp, i_ext, output_dir_name, dropouts, spks_u=None):
        """
        Simulate the network for the length of ``i_ext``.

        Every threshold crossing is written as a burst of four spikes spaced by
        the first cell's refractory period, scheduled dropouts are applied to
        ``self.w_r`` in place, and ``self.update_w`` is called on every timestep.

        :param dt: integration timestep (s)
        :param clamp: object with ``v`` and ``spk`` dicts mapping float times to the value
            that voltages / spikes are clamped to at that time (e.g. to initialize)
        :param i_ext: external current inputs (either 1D or 2D array, length = num timesteps for smln)
        :param output_dir_name: directory created under ``./data`` for weight snapshots
            (only used when ``self.output`` is truthy)
        :param dropouts: iterable of ``(time, dropout)`` pairs; ``dropout['E']`` and
            ``dropout['I']`` are forwarded to ``dropout_on_mat``
        :param spks_u: upstream inputs
        :return: ``Generic`` with fields ``dt``, ``t``, ``gs``, ``vs``, ``spks``, ``spks_t``,
            ``spks_c``, ``i_ext`` and ``ntwk`` (this network)
        """
        # snapshot the network parameters once, before the loop starts
        n_cells = self.n
        n_total = len(i_ext)
        syn_types = self.syns
        capacitance = self.c_m
        leak_g = self.g_l
        leak_e = self.e_l
        threshold = self.v_th
        v_reset = self.v_r
        refrac_time = self.t_r
        refrac_steps = np.round(refrac_time/dt).astype(int)
        reversal = self.e_s
        syn_tau = self.t_s
        w_recurrent = self.w_r
        w_upstream = self.w_u

        if self.output:
            output_dir = f'./data/{output_dir_name}'
            os.makedirs(output_dir)

        # storage: conductances and voltages start as NaN, spikes as False
        gs = {}
        for syn in syn_types:
            gs[syn] = np.full((n_total, n_cells), np.nan)
        vs = np.full((n_total, n_cells), np.nan)
        spks = np.zeros((n_total, n_cells), dtype=bool)

        # remaining refractory timesteps per cell
        refrac_left = np.zeros(n_cells, dtype=int)

        # re-key the clamp dicts from float times to timestep indices
        ## entries are visited in time order
        v_items = sorted(list(clamp.v.items()), key=lambda item: item[0])
        spk_items = sorted(list(clamp.spk.items()), key=lambda item: item[0])
        v_by_step = {}
        for t_clamp, v_clamp in v_items:
            v_by_step[int(round(t_clamp/dt))] = v_clamp
        spk_by_step = {}
        for t_clamp, spk_clamp in spk_items:
            spk_by_step[int(round(t_clamp/dt))] = spk_clamp
        clamp = Generic(v=v_by_step, spk=spk_by_step)

        # offsets (in timesteps) of the spikes making up one burst
        burst_offsets = np.arange(0, 4 * refrac_steps[0], refrac_steps[0])

        # main loop; the trailing 4 * refrac_steps[0] timesteps are left out so
        # that burst spikes written ahead of time stay inside the arrays
        for t_ctr in range(n_total - 4 * refrac_steps[0]):

            # scheduled weight dropouts
            for t_drop, dropout in dropouts:
                if int(t_drop / dt) != t_ctr:
                    continue
                self.w_r['E'][:, :50] = dropout_on_mat(self.w_r['E'][:, :50], dropout['E'])
                self.w_r['I'][:, 50:] = dropout_on_mat(self.w_r['I'][:, 50:], dropout['I'])

            # conductances
            for syn in syn_types:
                g_trace = gs[syn]
                if t_ctr == 0:
                    g_trace[t_ctr, :] = 0
                    continue

                g_prev = g_trace[t_ctr-1, :]
                # spike input from the previous timestep: recurrent first ...
                drive = w_recurrent[syn].dot(spks[t_ctr-1, :])
                # ... then upstream, when provided for this synapse type
                if spks_u is not None and syn in w_upstream:
                    drive += w_upstream[syn].dot(spks_u[t_ctr-1, :])

                # decay plus the new input
                g_trace[t_ctr, :] = g_prev + (dt/syn_tau[syn])*(-g_prev) + drive

            # voltages
            if t_ctr not in clamp.v:
                v_prev = vs[t_ctr-1, :]
                # leak + synaptic + external current
                i_net = -leak_g*(v_prev - leak_e)
                syn_currents = [-gs[syn][t_ctr, :]*(v_prev - reversal[syn]) for syn in syn_types]
                i_net += np.sum(syn_currents, axis=0)
                i_net += i_ext[t_ctr]

                vs[t_ctr, :] = v_prev + (dt/capacitance)*i_net

                # refractory cells are held at their reset voltage
                refractory = refrac_left > 0
                vs[t_ctr, refractory] = v_reset[refractory]
            else:
                vs[t_ctr, :] = clamp.v[t_ctr]

            # spikes
            if t_ctr not in clamp.spk:
                crossed = vs[t_ctr, :] >= threshold
                # a crossing marks the current step and the later burst steps
                for offset in burst_offsets:
                    spks[offset + t_ctr, crossed] = True
            else:
                spks[t_ctr, :] = clamp.spk[t_ctr]

            # weight update on the recent spikes of the plastic cells
            window_start = t_ctr - self.tau_cut
            if window_start < 0:
                window_start = 0
            recent_spks = spks[window_start:(t_ctr + 1), self.plasticity_indices]
            self.update_w(t_ctr, recent_spks, dt)

            # periodic snapshot of the excitatory weights
            if self.output and (t_ctr % self.output_freq == 0):
                snapshot_idx = int(t_ctr / self.output_freq)
                snapshot_path = output_dir + '/' + f'{zero_pad(snapshot_idx, 6)}'
                sio.savemat(snapshot_path, {'w_r_e': self.w_r['E']})

            # cells that spiked: reset voltage and start the refractory countdown
            fired = spks[t_ctr, :]
            vs[t_ctr, fired] = v_reset[fired]
            refrac_left[fired] = refrac_steps[fired] + 1

            # one timestep of refractoriness has elapsed
            still_refractory = refrac_left > 0
            refrac_left[still_refractory] -= 1

        t = dt*np.arange(n_total, dtype=float)

        # spike raster -> spike times and cell indices
        spk_steps, spk_cells = spks.nonzero()
        spks_t = dt * spk_steps
        spks_c = spk_cells

        return Generic(dt=dt, t=t, gs=gs, vs=vs, spks=spks, spks_t=spks_t, spks_c=spks_c, i_ext=i_ext, ntwk=self)