Exemplo n.º 1
0
    def optimizetimescales(self, myExp):
        myExp.plotTrainingSet()
        myExp.plotTestSet()

        myGIF_rect = GIF(0.1)

        myGIF_rect.Tref = self.T_ref
        myGIF_rect.eta = Filter_Rect_LogSpaced()
        myGIF_rect.eta.setMetaParameters(length=500.0,
                                         binsize_lb=2.0,
                                         binsize_ub=100.0,
                                         slope=4.5)
        myGIF_rect.fitVoltageReset(myExp, myGIF_rect.Tref, do_plot=False)
        myGIF_rect.fitSubthresholdDynamics(myExp,
                                           DT_beforeSpike=self.DT_beforespike)

        myGIF_rect.eta.fitSumOfExponentials(3, [1.0, 0.5, 0.1],
                                            self.tau_gamma,
                                            ROI=None,
                                            dt=0.1)
        print "Optimal timescales: ", myGIF_rect.eta.tau0

        self.tau_opt = [t for t in myGIF_rect.eta.tau0 if t < self.eta_tau_max]

        self.fitmodel(self, myExp)
Exemplo n.º 2
0
    def __init__(self, dt=0.1):
        """
        Create a model populated with default parameter values.

        dt: time step used in simulations (eta and gamma are interpolated
        according to this value).
        """

        def decaying_exponential(amplitude):
            # Build the kernel x -> amplitude * exp(-x / 100)
            return lambda x: amplitude * np.exp(-x / 100.0)

        self.dt = dt

        # Subthreshold parameters
        self.gl = 1.0 / 100.0      # nS, leak conductance
        self.C = 20.0 * self.gl    # nF, capacitance
        self.El = -65.0            # mV, reversal potential

        # Reset parameters
        self.Vr = -50.0            # mV, voltage reset
        self.Tref = 4.0            # ms, absolute refractory period

        # Firing parameters
        self.Vt_star = -48.0       # mV, steady state voltage threshold VT*
        self.DV = 0.5              # mV, threshold sharpness
        self.lambda0 = 1.0         # Hz, kept at 1.0 by default

        # Spike-triggered kernels (must be instances of class Filter):
        # eta is the spike-triggered current (nA), gamma the spike-triggered
        # movement of the firing threshold (mV)
        self.eta = Filter_Rect_LogSpaced()
        self.gamma = Filter_Rect_LogSpaced()

        # Both kernels start as decaying exponentials
        self.eta.setFilter_Function(decaying_exponential(0.2))
        self.gamma.setFilter_Function(decaying_exponential(10.0))

        # Placeholders filled in by the fitting procedure
        self.avg_spike_shape = 0
        self.avg_spike_shape_support = 0
Exemplo n.º 3
0
    def __init__(self, dt=0.1):
        """
        Initialise the model with its default parameters.

        dt is the time step used in simulations; eta and gamma are
        interpolated according to this value.
        """

        # Quantities stored by the fitting procedure (placeholders for now)
        self.avg_spike_shape = 0
        self.avg_spike_shape_support = 0

        # Simulation time step
        self.dt = dt

        # Membrane parameters
        leak_conductance = 1.0 / 100.0
        self.gl = leak_conductance  # nS, leak conductance
        self.C = 20.0 * leak_conductance  # nF, capacitance
        self.El = -65.0  # mV, reversal potential

        # Spike reset
        self.Vr = -50.0  # mV, voltage reset
        self.Tref = 4.0  # ms, absolute refractory period

        # Stochastic spike emission
        self.Vt_star = -48.0  # mV, steady state voltage threshold VT*
        self.DV = 0.5  # mV, threshold sharpness
        self.lambda0 = 1.0  # by default this parameter is always set to 1.0 Hz

        # nA, spike-triggered current (must be instance of class Filter)
        self.eta = Filter_Rect_LogSpaced()
        # mV, spike-triggered movement of the firing threshold (must be instance of class Filter)
        self.gamma = Filter_Rect_LogSpaced()

        # Start both kernels from an exponential decay
        self.eta.setFilter_Function(lambda x: 0.2 * np.exp(-x / 100.0))
        self.gamma.setFilter_Function(lambda x: 10.0 * np.exp(-x / 100.0))
Exemplo n.º 4
0
        experiment.setAEC(myAEC)
        experiment.performAEC()

        # Determine the refractory period

        #################################################################################################
        # FIT STANDARD GIF
        #################################################################################################

        # Create a new object GIF
        GIF_fit = GIF(sampling_time)

        # Define parameters
        GIF_fit.Tref = 6.0

        GIF_fit.eta = Filter_Rect_LogSpaced()
        GIF_fit.eta.setMetaParameters(length=2000.0,
                                      binsize_lb=0.5,
                                      binsize_ub=500.0,
                                      slope=10.0)

        GIF_fit.gamma = Filter_Rect_LogSpaced()
        GIF_fit.gamma.setMetaParameters(length=2000.0,
                                        binsize_lb=2.0,
                                        binsize_ub=500.0,
                                        slope=5.0)

        for tr in experiment.trainingset_traces:
            tr.setROI(
                [[2000., sampling_time * (len(voltage_trace) - 1) - 2000.]])
Exemplo n.º 5
0
class GIF(ThresholdModel):
    """
    Generalized Integrate and Fire model defined in Pozzorini et al. PLOS Comp. Biol. 2015
    
    Spikes are produced stochastically with firing intensity:
    
    lambda(t) = lambda0 * exp( (V(t)-V_T(t))/DV ),
    
    where the membrane potential dynamics is given by:
    
    C dV/dt = -gl(V-El) + I - sum_j eta(t-\hat t_j)
    
    and the firing threshold V_T is given by:
    
    V_T = Vt_star + sum_j gamma(t-\hat t_j)
    
    and \hat t_j denote the spike times.
    """
    def __init__(self, dt=0.1):
        """
        Set up a GIF model with default parameters.

        dt: time step used in simulations (eta and gamma are interpolated
        according to this value).
        """

        self.dt = dt

        # Default model parameters (units as annotated in the original code)
        self.gl = 1.0 / 100.0  # nS, leak conductance
        self.C = 20.0 * self.gl  # nF, capacitance
        self.El = -65.0  # mV, reversal potential
        self.Vr = -50.0  # mV, voltage reset
        self.Tref = 4.0  # ms, absolute refractory period
        self.Vt_star = -48.0  # mV, steady state voltage threshold VT*
        self.DV = 0.5  # mV, threshold sharpness
        self.lambda0 = 1.0  # Hz, by default always set to 1.0

        # eta: nA, spike-triggered current
        # gamma: mV, spike-triggered movement of the firing threshold
        # (both must be instances of class Filter)
        self.eta = Filter_Rect_LogSpaced()
        self.gamma = Filter_Rect_LogSpaced()

        # Initialise each kernel with amplitude * exp(-x / 100); the
        # amplitude is bound as a default argument so every lambda keeps
        # its own value
        for kernel, amplitude in ((self.eta, 0.2), (self.gamma, 10.0)):
            kernel.setFilter_Function(
                lambda x, scale=amplitude: scale * np.exp(-x / 100.0))

        # Variables related to the fitting procedure
        self.avg_spike_shape = 0
        self.avg_spike_shape_support = 0

    def setDt(self, dt):
        """
        Store the time step used for numerical simulations.

        Only self.dt is updated here; the filters eta and gamma are
        interpolated with this value when a simulation is run.
        """
        self.dt = dt

    ########################################################################################################
    # IMPLEMENT ABSTRACT METHODS OF Spiking model
    ########################################################################################################

    def simulateSpikingResponse(self, I, dt):
        """
        Run the GIF model on the input current I (nA) using time step dt.

        The time step is stored with setDt and the simulation starts from
        V(0)=El. Only the spike times (in ms) are returned.
        """
        self.setDt(dt)

        # simulate() yields (time, V, eta_sum, V_T, spike times)
        _, _, _, _, spike_times = self.simulate(I, self.El)
        return spike_times

    ########################################################################################################
    # IMPLEMENT ABSTRACT METHODS OF Threshold Model
    ########################################################################################################

    def simulateVoltageResponse(self, I, dt):
        """
        Run the GIF model on the input current I using time step dt.

        The time step is stored with setDt and the simulation starts from
        V(0)=El. Returns the tuple (spike times, V, V_T) taken from the
        output of simulate().
        """
        self.setDt(dt)

        # simulate() yields (time, V, eta_sum, V_T, spike times)
        _, voltage, _, threshold, spike_times = self.simulate(I, self.El)
        return (spike_times, voltage, threshold)

    ########################################################################################################
    # METHODS FOR NUMERICAL SIMULATIONS
    ########################################################################################################

    def simulate(self, I, V0):
        """
        Simulate the spiking response of the GIF model to an input current I (nA) with time step self.dt.
        V0 indicates the initial condition V(0)=V0.

        The integration loop is written in C and executed through weave.inline; the Python
        locals listed in `vars` are made available to that code under the same names.

        The function returns:
        - time     : ms, support for V, eta_sum, V_T, spks
        - V        : mV, membrane potential
        - eta_sum  : nA, adaptation current
        - V_T      : mV, firing threshold
        - spks     : ms, list of spike times
        """

        # Input parameters
        p_T = len(I)
        p_dt = self.dt

        # Model parameters
        p_gl = self.gl
        p_C = self.C
        p_El = self.El
        p_Vr = self.Vr
        p_Tref = self.Tref
        p_Vt_star = self.Vt_star
        p_DV = self.DV
        p_lambda0 = self.lambda0

        # Model kernels
        # Both filters are interpolated at the simulation time step; only the filter
        # values are used below, the returned supports are not.
        (p_eta_support, p_eta) = self.eta.getInterpolatedFilter(self.dt)
        p_eta = p_eta.astype('double')
        p_eta_l = len(p_eta)

        (p_gamma_support, p_gamma) = self.gamma.getInterpolatedFilter(self.dt)
        p_gamma = p_gamma.astype('double')
        p_gamma_l = len(p_gamma)

        # Define arrays
        # spks is a 0/1 indicator per time step (converted to spike times after the loop).
        # eta_sum and gamma_sum are allocated longer than the input (twice the kernel
        # length of padding) because the C loop adds a full kernel after each spike.
        V = np.array(np.zeros(p_T), dtype="double")
        I = np.array(I, dtype="double")
        spks = np.array(np.zeros(p_T), dtype="double")
        eta_sum = np.array(np.zeros(p_T + 2 * p_eta_l), dtype="double")
        gamma_sum = np.array(np.zeros(p_T + 2 * p_gamma_l), dtype="double")

        # Set initial condition
        V[0] = V0

        # C loop: one voltage integration step, then a random draw against
        # exp(-lambda*dt) to decide whether a spike is emitted. On a spike the time
        # index jumps ahead by Tref_ind steps, the voltage is set to Vr and the eta and
        # gamma kernels are added to the running sums from that point on.
        # NOTE(review): the samples of V skipped by that jump are never written by the
        # loop, so they keep the zeros they were allocated with - confirm this is intended.
        # NOTE(review): rand() is not seeded anywhere in this block.
        code = """
                #include <math.h>
                
                int   T_ind      = int(p_T);                
                float dt         = float(p_dt); 
                
                float gl         = float(p_gl);
                float C          = float(p_C);
                float El         = float(p_El);
                float Vr         = float(p_Vr);
                int   Tref_ind   = int(float(p_Tref)/dt);
                float Vt_star    = float(p_Vt_star);
                float DeltaV     = float(p_DV);
                float lambda0    = float(p_lambda0);
           
                int eta_l        = int(p_eta_l);
                int gamma_l      = int(p_gamma_l);
                
                                                  
                float rand_max  = float(RAND_MAX); 
                float p_dontspike = 0.0 ;
                float lambda = 0.0 ;            
                float r = 0.0;
                
                                                
                for (int t=0; t<T_ind-1; t++) {
    
    
                    // INTEGRATE VOLTAGE
                    V[t+1] = V[t] + dt/C*( -gl*(V[t] - El) + I[t] - eta_sum[t] );
               
               
                    // COMPUTE PROBABILITY OF EMITTING ACTION POTENTIAL
                    lambda = lambda0*exp( (V[t+1]-Vt_star-gamma_sum[t])/DeltaV );
                    p_dontspike = exp(-lambda*(dt/1000.0));                                  // since lambda0 is in Hz, dt must also be in Hz (this is why dt/1000.0)
                          
                          
                    // PRODUCE SPIKE STOCHASTICALLY
                    r = rand()/rand_max;
                    if (r > p_dontspike) {
                                        
                        if (t+1 < T_ind-1)                
                            spks[t+1] = 1.0; 
                        
                        t = t + Tref_ind;    
                        
                        if (t+1 < T_ind-1) 
                            V[t+1] = Vr;
                        
                        
                        // UPDATE ADAPTATION PROCESSES     
                        for(int j=0; j<eta_l; j++) 
                            eta_sum[t+1+j] += p_eta[j]; 
                        
                        for(int j=0; j<gamma_l; j++) 
                            gamma_sum[t+1+j] += p_gamma[j] ;  
                        
                    }
               
                }
                
                """

        # Names of the Python locals handed to the C code above
        vars = [
            'p_T', 'p_dt', 'p_gl', 'p_C', 'p_El', 'p_Vr', 'p_Tref',
            'p_Vt_star', 'p_DV', 'p_lambda0', 'V', 'I', 'p_eta', 'p_eta_l',
            'eta_sum', 'p_gamma', 'gamma_sum', 'p_gamma_l', 'spks'
        ]

        # The arrays V, spks, eta_sum and gamma_sum are read back after this call;
        # the return value v itself is not used.
        v = weave.inline(code, vars)

        time = np.arange(p_T) * self.dt

        # Drop the padding so that the outputs have the same length as the input
        eta_sum = eta_sum[:p_T]
        V_T = gamma_sum[:p_T] + p_Vt_star

        # Convert the 0/1 spike indicator into spike times
        spks = (np.where(spks == 1)[0]) * self.dt

        return (time, V, eta_sum, V_T, spks)

    def simulateDeterministic_forceSpikes(self, I, V0, spks):
        """
        Simulate the subthreshold response of the GIF model to an input current I (nA) with time step self.dt.
        Adaptation currents are forced to occur at the times specified in the list spks (in ms) given
        as an argument to the function.
        V0 indicates the initial condition V(t=0)=V0.

        The function returns:

        - time     : ms, support for V and eta_sum
        - V        : mV, membrane potential
        - eta_sum  : nA, adaptation current
        """

        # Input parameters
        p_T = len(I)
        p_dt = self.dt

        # Model parameters
        p_gl = self.gl
        p_C = self.C
        p_El = self.El
        p_Vr = self.Vr
        p_Tref = self.Tref
        p_Tref_i = int(self.Tref / self.dt)

        # Model kernel (only the filter values are needed, not their support)
        (p_eta_support, p_eta) = self.eta.getInterpolatedFilter(self.dt)
        p_eta = p_eta.astype('double')
        p_eta_l = len(p_eta)

        # Define arrays
        V = np.zeros(p_T, dtype="double")
        I = np.array(I, dtype="double")
        spks = np.array(spks, dtype="double")
        spks_i = Tools.timeToIndex(spks, self.dt)

        # Number of forced spikes, handed to the C code so that it never reads
        # spks_i outside of its bounds
        p_spks_n = len(spks_i)

        # Compute adaptation current (sum of eta triggered at spike times in spks).
        # The buffer is longer than the input so that a kernel triggered near the end
        # still fits; its length must be an integer, hence the explicit int().
        eta_sum = np.zeros(int(p_T + 1.1 * p_eta_l + p_Tref_i), dtype="double")

        for s in spks_i:
            # Each kernel starts one step after the end of the refractory period
            eta_sum[s + 1 + p_Tref_i:s + 1 + p_Tref_i + p_eta_l] += p_eta

        eta_sum = eta_sum[:p_T]

        # Set initial condition
        V[0] = V0

        # C loop: integrate the voltage and, Tref_ind steps after each forced spike,
        # set the voltage to Vr (the sample just before it is set to 0) and step back
        # so that integration resumes from the reset value.
        # NOTE(review): a reset falling on t == 0 would write V[-1]; not handled here.
        code = """
                #include <math.h>

                int   T_ind      = int(p_T);
                float dt         = float(p_dt);

                float gl         = float(p_gl);
                float C          = float(p_C);
                float El         = float(p_El);
                float Vr         = float(p_Vr);
                int   Tref_ind   = int(float(p_Tref)/dt);
                int   spks_n     = int(p_spks_n);

                // next_spike == -1 means that no forced spike is left
                // (t is never negative when it is compared with next_spike)
                int spks_cnt = 0;
                int next_spike = (spks_n > 0) ? spks_i[0] + Tref_ind : -1;


                for (int t=0; t<T_ind-1; t++) {


                    // INTEGRATE VOLTAGE
                    V[t+1] = V[t] + dt/C*( -gl*(V[t] - El) + I[t] - eta_sum[t] );


                    if ( t == next_spike ) {
                        spks_cnt = spks_cnt + 1;
                        next_spike = (spks_cnt < spks_n) ? spks_i[spks_cnt] + Tref_ind : -1;
                        V[t-1] = 0 ;
                        V[t] = Vr ;
                        t=t-1;
                    }

                }

                """

        # Names of the Python locals handed to the C code above
        vars = [
            'p_T', 'p_dt', 'p_gl', 'p_C', 'p_El', 'p_Vr', 'p_Tref', 'V', 'I',
            'eta_sum', 'spks_i', 'p_spks_n'
        ]

        # V is filled in place by the C code
        weave.inline(code, vars)

        time = np.arange(p_T) * self.dt

        return (time, V, eta_sum)

    ########################################################################################################
    # METHODS FOR MODEL FITTING
    ########################################################################################################

    def fit(self, experiment, DT_beforeSpike=5.0):
        """
        Fit the GIF model on the experimental data held by the object experiment.

        DT_beforeSpike (in ms) defines the region that is cut before each spike when
        fitting the subthreshold dynamics of the membrane potential.
        The fit chains four calls: voltage reset (with the current self.Tref),
        subthreshold dynamics, static threshold and threshold dynamics.
        """

        banner = ("\n################################",
                  "# Fit GIF",
                  "################################\n")
        for banner_line in banner:
            print(banner_line)

        # Voltage reset, estimated for the current refractory period
        self.fitVoltageReset(experiment, self.Tref, do_plot=False)

        # Passive parameters and spike-triggered current
        self.fitSubthresholdDynamics(experiment, DT_beforeSpike=DT_beforeSpike)

        # Constant threshold first, then the full dynamic threshold
        self.fitStaticThreshold(experiment)
        self.fitThresholdDynamics(experiment)

    ########################################################################################################
    # FIT VOLTAGE RESET GIVEN ABSOLUTE REFRACOTORY PERIOD (step 1)
    ########################################################################################################

    def fitVoltageReset(self, experiment, Tref, do_plot=False):
        """
        Implement Step 1 of the fitting procedure introduced in Pozzorini et al. PLOS Comp. Biol. 2015

        experiment: Experiment object on which the model is fitted.
        Tref: ms, absolute refractory period.
        do_plot: if True, plot the average spike shape and the estimated reset.

        The voltage reset is estimated by computing the spike-triggered average of the voltage:
        self.Vr is the value of that average at the first sample whose support is >= Tref.
        The method also sets self.dt (from experiment.dt), self.Tref, self.avg_spike_shape and
        self.avg_spike_shape_support.

        Raises ValueError if no selected training set trace contains spikes.
        """

        print("Estimate voltage reset (Tref = %0.1f ms)..." % (Tref))

        # Fix absolute refractory period
        self.dt = experiment.dt
        self.Tref = Tref

        support = None
        all_spike_average = []
        all_spike_nb = 0

        for tr in experiment.trainingset_traces:

            # Only selected traces that actually contain spikes contribute
            if not tr.useTrace or len(tr.spks) == 0:
                continue

            (support, spike_average, spike_nb) = tr.computeAverageSpikeShape()
            all_spike_average.append(spike_average)
            all_spike_nb += spike_nb

        if not all_spike_average:
            # Without this check the code would fail further down on the
            # undefined support with an unhelpful error
            raise ValueError(
                "Cannot estimate the voltage reset: no selected training set "
                "trace contains spikes.")

        # Each trace contributes its own average with equal weight
        spike_average = np.mean(all_spike_average, axis=0)

        # Estimate voltage reset (the support of the last processed trace is used)
        Tref_ind = np.where(support >= self.Tref)[0][0]
        self.Vr = spike_average[Tref_ind]

        # Save average spike shape
        self.avg_spike_shape = spike_average
        self.avg_spike_shape_support = support

        if do_plot:
            plt.figure()
            plt.plot(support, spike_average, 'black')
            plt.plot([support[Tref_ind]], [self.Vr], '.', color='red')
            plt.show()

        print("Done! Vr = %0.2f mV (computed on %d spikes)" % (self.Vr,
                                                               all_spike_nb))

    ########################################################################################################
    # FUNCTIONS RELATED TO FIT OF SUBTHRESHOLD DYNAMICS (step 2)
    ########################################################################################################

    def fitSubthresholdDynamics(self, experiment, DT_beforeSpike=5.0):
        """
        Implement Step 2 of the fitting procedure introduced in Pozzorini et al. PLOS Comp. Biol. 2015

        The passive parameters (C, gl, El) and the coefficients of the spike-triggered current eta
        are extracted with a single linear regression on the voltage derivative, performed over all
        selected training set traces.

        experiment: Experiment object on which the model is fitted.
        DT_beforeSpike: in ms, data right before spikes are excluded from the fit. This parameter can
        be used to define that time interval.

        Raises ValueError if no training set trace is selected.
        """

        print("\nGIF MODEL - Fit subthreshold dynamics...")

        # The fit is performed at the sampling step of the experiment
        self.dt = experiment.dt

        # Build X matrix and Y vector to perform linear regression (use all traces in training set)
        # For each selected trace an X matrix and a Y vector is built.
        ####################################################################################################
        X = []
        Y = []

        cnt = 0

        for tr in experiment.trainingset_traces:

            if tr.useTrace:

                cnt += 1
                reprint("Compute X matrix for repetition %d" % (cnt))

                # Compute the X matrix and Y=\dot_V_data vector used to perform the multilinear
                # regression (see Eq. 17-18 in Pozzorini et al. PLOS Comp. Biol. 2015)
                (X_tmp,
                 Y_tmp) = self.fitSubthresholdDynamics_Build_Xmatrix_Yvector(
                     tr, DT_beforeSpike=DT_beforeSpike)

                X.append(X_tmp)
                Y.append(Y_tmp)

        if cnt == 0:
            # Nothing to regress on: stop here instead of failing inside the matrix inversion
            raise ValueError(
                "Error, at least one training set trace should be selected to perform fit."
            )

        # Concatenate matrices associated with different traces to perform a single multilinear regression
        ####################################################################################################
        X = np.concatenate(X, axis=0)
        Y = np.concatenate(Y, axis=0)

        # Perform linear Regression defined in Eq. 17 of Pozzorini et al. PLOS Comp. Biol. 2015
        ####################################################################################################

        print("\nPerform linear regression...")
        XTX = np.dot(np.transpose(X), X)
        XTX_inv = inv(XTX)
        XTY = np.dot(np.transpose(X), Y)
        b = np.dot(XTX_inv, XTY)
        b = b.flatten()

        # Extract explicit model parameters from regression result b
        # b[0] multiplies V, b[1] multiplies I, b[2] is the constant term and b[3:] are the
        # coefficients of the eta basis functions.
        ####################################################################################################

        self.C = 1. / b[1]
        self.gl = -b[0] * self.C
        self.El = b[2] * self.C / self.gl
        self.eta.setFilter_Coefficients(-b[3:] * self.C)

        self.printParameters()

        # Compute percentage of variance explained on dV/dt
        ####################################################################################################

        var_explained_dV = 1.0 - np.mean((Y - np.dot(X, b))**2) / np.var(Y)
        print("Percentage of variance explained (on dV/dt): %0.2f" % (
            var_explained_dV * 100.0))

        # Compute percentage of variance explained on V (see Eq. 26 in Pozzorini et al. PLOS Comp. Biol. 2015)
        ####################################################################################################

        SSE = 0  # sum of squared errors
        VAR = 0  # variance of data

        for tr in experiment.trainingset_traces:

            if tr.useTrace:

                # Simulate subthreshold dynamics with spikes forced at the recorded times
                (time, V_est,
                 eta_sum_est) = self.simulateDeterministic_forceSpikes(
                     tr.I, tr.V[0], tr.getSpikeTimes())

                indices_tmp = tr.getROI_FarFromSpikes(0.0, self.Tref)

                SSE += np.sum((V_est[indices_tmp] - tr.V[indices_tmp])**2)
                VAR += len(indices_tmp) * np.var(tr.V[indices_tmp])

        var_explained_V = 1.0 - SSE / VAR

        print("Percentage of variance explained (on V): %0.2f" % (
            var_explained_V * 100.0))

    def fitSubthresholdDynamics_Build_Xmatrix_Yvector(self,
                                                      trace,
                                                      DT_beforeSpike=5.0):
        """
        Compute the X matrix and the Y vector (i.e. \\dot_V_data) used to perform the linear regression
        defined in Eq. 17-18 of Pozzorini et al. 2015 for an individual experimental trace provided as parameter.
        The input parameter trace is an object of class Trace.

        Returns the tuple (X, Y) restricted to the samples selected by
        trace.getROI_FarFromSpikes(DT_beforeSpike, self.Tref):
        - X: one row per selected sample; columns are V, I, a constant 1 and the convolution of the
          spike train (shifted by Tref) with the basis functions of eta.
        - Y: voltage derivative at the selected samples.
        """

        # Select region where to perform linear regression (specified in the ROI of individual traces)
        ####################################################################################################
        selection = trace.getROI_FarFromSpikes(DT_beforeSpike, self.Tref)
        selection_l = len(selection)

        # Build X matrix for linear regression (see Eq. 18 in Pozzorini et al. PLOS Comp. Biol. 2015)
        ####################################################################################################

        # Columns associated with the spike-triggered current eta
        X_eta = self.eta.convolution_Spiketrain_basisfunctions(
            trace.getSpikeTimes() + self.Tref, trace.T, trace.dt)

        # First three columns (voltage, input current, constant) followed by the eta columns
        X = np.column_stack((trace.V[selection],
                             trace.I[selection],
                             np.ones(selection_l),
                             X_eta[selection, :]))

        # Build Y vector (voltage derivative \dot_V_data)
        # A trailing 0 is appended so that the derivative has as many samples as the voltage.
        ####################################################################################################
        Y = np.array(np.concatenate(
            (np.diff(trace.V) / trace.dt, [0])))[selection]

        return (X, Y)

    ########################################################################################################
    # FUNCTIONS RELATED TO FIT FIRING THRESHOLD PARAMETERS (step 3)
    ########################################################################################################

    def fitStaticThreshold(self, experiment):
        """
        Implement Step 3 of the fitting procedure introduced in Pozzorini et al. PLOS Comp. Biol. 2015
        Instead of directly fitting a dynamic threshold, this function just fits a constant threshold.
        The output of this fit can be used as a smart initial condition to fit the full GIF model (i.e.,
        a model featuring a spike-triggered threshold movement gamma). See Pozzorini et al. PLOS Comp. Biol. 2015

        experiment: Experiment object on which the model is fitted.

        On return self.DV and self.Vt_star hold the fitted values, self.lambda0 is 1.0 and the filter
        gamma is set to zero.

        Raises ValueError if the selected training set traces contain no spikes or have zero
        duration inside their ROI.
        """

        print("\nGIF MODEL - Fit static threshold...\n")

        self.setDt(experiment.dt)

        # Define initial conditions (based on the average firing rate in the training set)
        ###############################################################################################

        nbSpikes = 0
        duration = 0

        for tr in experiment.trainingset_traces:

            if tr.useTrace:

                nbSpikes += tr.getSpikeNb_inROI()
                duration += tr.getTraceLength_inROI()

        if nbSpikes <= 0 or duration <= 0:
            # The initial condition takes the log of the mean firing rate, which is
            # undefined without spikes or without data
            raise ValueError(
                "Cannot fit the static threshold: the selected training set traces "
                "contain no spikes or no data inside their ROI.")

        # Factor 1000.0: the duration is accumulated in ms, the rate is expressed in Hz
        mean_firingrate = 1000.0 * nbSpikes / duration

        self.lambda0 = 1.0
        self.DV = 50.0
        self.Vt_star = -np.log(mean_firingrate) * self.DV

        # Perform maximum likelihood fit (Newton method)
        # The parameters are passed as beta = [1/DV, -Vt_star/DV]
        ###############################################################################################

        beta0_staticThreshold = [1 / self.DV, -self.Vt_star / self.DV]
        beta_opt = self.maximizeLikelihood(experiment, beta0_staticThreshold,
                                           self.buildXmatrix_staticThreshold)

        # Store result of constant threshold fitting
        ###############################################################################################

        self.DV = 1.0 / beta_opt[0]
        self.Vt_star = -beta_opt[1] * self.DV
        self.gamma.setFilter_toZero()

        self.printParameters()

    def fitThresholdDynamics(self, experiment):
        """
        Implement Step 3 of the fitting procedure introduced in Pozzorini et al. PLOS Comp. Biol. 2015
        The firing threshold dynamics are fitted by solving Eq. 20 with the Newton method.

        experiment: Experiment object on which the model is fitted.

        The current values of DV, Vt_star and of the gamma coefficients are used as the initial
        condition; all three are overwritten with the result of the fit.
        """

        print("\nGIF MODEL - Fit dynamic threshold...\n")

        self.setDt(experiment.dt)

        # Initial condition, expressed as beta = [1/DV, -Vt_star/DV, gamma coefficients / DV]
        threshold_terms = [1 / self.DV, -self.Vt_star / self.DV]
        gamma_terms = self.gamma.getCoefficients() / self.DV
        beta0_dynamicThreshold = np.concatenate((threshold_terms, gamma_terms))

        # Maximum likelihood fit (Newton method)
        beta_opt = self.maximizeLikelihood(experiment,
                                           beta0_dynamicThreshold,
                                           self.buildXmatrix_dynamicThreshold)

        # Convert the optimal beta back into model parameters
        self.DV = 1.0 / beta_opt[0]
        self.Vt_star = -beta_opt[1] * self.DV
        self.gamma.setFilter_Coefficients(-beta_opt[2:] * self.DV)

        self.printParameters()

    def maximizeLikelihood(self,
                           experiment,
                           beta0,
                           buildXmatrix,
                           maxIter=10**3,
                           stopCond=10**-6):

        ###
        ### THIS IMPLEMENTATION IS NOT SO COOL :(
        ### IN NEW VERSION OF THE CODE I SHOULD IMPLEMENT A NEW CLASS THAT TAKES CARE OF MAXLIKELIHOOD ON lambda=exp(Xbeta) model
        ###
        """
        Maximize likelihood. This function can be used to fit any model of the form lambda=exp(Xbeta).
        This function is used to fit both:
        - static threshold
        - dynamic threshold
        The difference between the two functions is in the size of beta0 and the returned beta, as well
        as the function buildXmatrix.
        """

        # Precompute all the matrices used in the gradient ascent (see Eq. 20 in Pozzorini et al. 2015)
        ################################################################################################

        # here X refer to the matrix made of y vectors defined in Eq. 21 (Pozzorini et al. 2015)
        # since the fit can be perfomed on multiple traces, we need lists
        all_X = []

        # similar to X but only contains temporal samples where experimental spikes have been observed
        # storing this matrix is useful to improve speed when computing the likelihood as well as its derivative
        all_X_spikes = []

        # sum X_spikes over spikes. Precomputing this quantity improve speed when the gradient is evaluated
        all_sum_X_spikes = []

        # variables used to compute the loglikelihood of a Poisson process spiking at the experimental firing rate
        T_tot = 0.0
        N_spikes_tot = 0.0

        traces_nb = 0

        for tr in experiment.trainingset_traces:

            if tr.useTrace:

                traces_nb += 1

                # Simulate subthreshold dynamics
                (time, V_est,
                 eta_sum_est) = self.simulateDeterministic_forceSpikes(
                     tr.I, tr.V[0], tr.getSpikeTimes())

                # Precomputes matrices to compute gradient ascent on log-likelihood
                # depeinding on the model being fitted (static vs dynamic threshodl) different buildXmatrix functions can be used
                (X_tmp, X_spikes_tmp, sum_X_spikes_tmp, N_spikes,
                 T) = buildXmatrix(tr, V_est)

                T_tot += T
                N_spikes_tot += N_spikes

                all_X.append(X_tmp)
                all_X_spikes.append(X_spikes_tmp)
                all_sum_X_spikes.append(sum_X_spikes_tmp)

        # Compute log-likelihood of a poisson process (this quantity is used to normalize the model log-likelihood)
        ################################################################################################

        logL_poisson = N_spikes_tot * (np.log(N_spikes_tot / T_tot) - 1)

        # Perform gradient ascent
        ################################################################################################

        print "Maximize log-likelihood (bit/spks)..."

        beta = beta0
        old_L = 1

        for i in range(maxIter):

            learning_rate = 1.0

            # In the first iterations using a small learning rate makes things somehow more stable
            if i <= 10:
                learning_rate = 0.1

            L = 0
            G = 0
            H = 0

            for trace_i in np.arange(traces_nb):

                # compute log-likelihood, gradient and hessian on a specific trace (note that the fit is performed on multiple traces)
                (L_tmp, G_tmp, H_tmp) = self.computeLikelihoodGradientHessian(
                    beta, all_X[trace_i], all_X_spikes[trace_i],
                    all_sum_X_spikes[trace_i])

                # note that since differentiation is linear: gradient of sum = sum of gradient ; hessian of sum = sum of hessian
                L += L_tmp
                G += G_tmp
                H += H_tmp

            # Update optimal parametes (ie, implement Newton step) by tacking into account multiple traces

            beta = beta - learning_rate * np.dot(inv(H), G)

            if (i > 0 and abs((L - old_L) / old_L) < stopCond):  # If converged
                print "\nConverged after %d iterations!\n" % (i + 1)
                break

            old_L = L

            # Compute normalized likelihood (for print)
            # The likelihood is normalized with respect to a poisson process and units are in bit/spks
            L_norm = (L - logL_poisson) / np.log(2) / N_spikes_tot
            reprint(L_norm)

            if math.isnan(L_norm):
                print "Problem during gradient ascent. Optimizatino stopped."
                break

        if (i == maxIter - 1):  # If too many iterations

            print "\nNot converged after %d iterations.\n" % (maxIter)

        return beta

    def computeLikelihoodGradientHessian(self, beta, X, X_spikes,
                                         sum_X_spikes):
        """
        Compute the log-likelihood, its gradient and its Hessian for a model whose
        log-likelihood has the same form as the one defined in Eq. 20
        (Pozzorini et al. PLOS Comp. Biol. 2015).

        Parameters
        ----------
        beta : parameter vector at which the three quantities are evaluated.
        X : matrix with one row per selected time step (see the buildXmatrix_* methods).
        X_spikes : rows of X at the time steps in which a spike was observed.
        sum_X_spikes : X_spikes summed over spikes (precomputed by the caller
            because it does not depend on beta).

        Returns
        -------
        (L, G, H) : log-likelihood, gradient and Hessian evaluated at beta.
        """

        # IMPORTANT: in general we assume that lambda_0 = 1 Hz.
        # The parameter lambda0 is redundant with Vt_star, so only one of those has to be fitted.
        # We generally fix lambda_0 and fit Vt_star.

        dt = self.dt / 1000.0  # put dt in units of seconds (to be consistent with lambda_0)
        rate_scale = self.lambda0 * dt

        X_spikesbeta = np.dot(X_spikes, beta)
        expXbeta = np.exp(np.dot(X, beta))

        # Log-likelihood defined in Eq. 20 (Pozzorini et al. 2015).
        # np.sum(..., axis=0) reduces over the same axis as the builtin sum() but
        # runs in compiled code; this matters because the arrays have one entry
        # per time step and this method is called at every Newton iteration.
        L = np.sum(X_spikesbeta, axis=0) - rate_scale * np.sum(expXbeta, axis=0)

        # Gradient with respect to beta
        G = sum_X_spikes - rate_scale * np.dot(np.transpose(X), expXbeta)

        # Hessian with respect to beta
        H = -rate_scale * np.dot(np.transpose(X) * expXbeta, X)

        return (L, G, H)

    def buildXmatrix_staticThreshold(self, tr, V_est):
        """
        Build the matrices used to fit a model with a static firing threshold:
        V_T(t) = Vt_star (i.e., no spike-triggered movement of the firing threshold).

        The matrix X is made of vectors y similar to the ones defined in Eq. 21
        (Pozzorini et al. 2015), but without the columns related to the
        spike-triggered threshold movement: column 0 is the estimated voltage and
        column 1 is a constant 1.

        Returns (X, X_spikes, sum_X_spikes, N_spikes, T_l) where X_spikes are the
        rows of X at spike times, sum_X_spikes is their sum over spikes, N_spikes
        is the number of spikes kept and T_l is the duration (in s) of the data kept.
        """

        # Indices left after removing absolute refractory periods
        # (-self.dt is to not include the time of spike).
        # NOTE(review): the dynamic-threshold variant passes -tr.dt here; the two
        # only agree when self.dt == tr.dt -- confirm which one is intended.
        roi = tr.getROI_FarFromSpikes(-self.dt, self.Tref)
        nb_samples = len(roi)

        # Spike positions expressed as row indices of the selected samples
        spike_rows = np.where(tr.getSpikeTrain()[roi] == 1)[0]

        # Duration (in s) and number of spikes of the data used for the fit
        duration_s = nb_samples * tr.dt / 1000.0
        nb_spikes = len(spike_rows)

        # Start from a matrix of ones (constant column) and overwrite column 0 with the voltage
        X = np.ones((nb_samples, 2))
        X[:, 0] = V_est[roi]

        # Rows in which the neuron has emitted a spike, and their sum over spikes
        X_spikes = X[spike_rows, :]
        sum_X_spikes = np.sum(X_spikes, axis=0)

        return (X, X_spikes, sum_X_spikes, nb_spikes, duration_s)

    def buildXmatrix_dynamicThreshold(self, tr, V_est):
        r"""
        Build the matrices used to fit a model with a dynamic firing threshold:
        V_T(t) = Vt_star + sum_i gamma(t-\hat t_i) (i.e., model with spike-triggered movement of the threshold).

        The matrix X is made of vectors y defined as in Eq. 21 (Pozzorini et al. 2015):
        column 0 is the estimated voltage, column 1 is a constant 1 and the
        remaining columns come from the basis functions of the filter gamma.

        Returns (X, X_spikes, sum_X_spikes, N_spikes, T_l) where X_spikes are the
        rows of X at spike times, sum_X_spikes is their sum over spikes, N_spikes
        is the number of spikes kept and T_l is the duration (in s) of the data kept.
        """

        # Indices left after removing absolute refractory periods
        # (-tr.dt is to not include the time of spike)
        roi = tr.getROI_FarFromSpikes(-tr.dt, self.Tref)
        nb_samples = len(roi)

        # Spike positions expressed as row indices of the selected samples
        spike_rows = np.where(tr.getSpikeTrain()[roi] == 1)[0]

        # Duration (in s) and number of spikes of the data used for the fit
        duration_s = nb_samples * tr.dt / 1000.0
        nb_spikes = len(spike_rows)

        # First two columns: estimated voltage and constant term
        X_base = np.ones((nb_samples, 2))
        X_base[:, 0] = V_est[roi]

        # Remaining columns: spike train convolved with the basis functions of the
        # spike-triggered threshold movement gamma (spike times shifted by Tref)
        shifted_spike_times = tr.getSpikeTimes() + self.Tref
        X_gamma = self.gamma.convolution_Spiketrain_basisfunctions(
            shifted_spike_times, tr.T, tr.dt)
        X = np.concatenate((X_base, X_gamma[roi, :]), axis=1)

        # Precompute other quantities to speed up fitting
        X_spikes = X[spike_rows, :]
        sum_X_spikes = np.sum(X_spikes, axis=0)

        return (X, X_spikes, sum_X_spikes, nb_spikes, duration_s)

    ########################################################################################################
    # PLOT AND PRINT FUNCTIONS
    ########################################################################################################

    def plotParameters(self):
        """
        Generate a figure with the model filters.

        Three panels are drawn side by side: the membrane filter (exponential
        with time constant C/gl, scaled by 1/C), the filter eta and the filter
        gamma, both interpolated at the model time step self.dt.
        The figure is displayed with plt.show().
        """

        def draw_filter(support, values, ylabel):
            # Filter in red on top of a dotted zero baseline spanning the support
            plt.plot(support, values, color='red', lw=2)
            plt.plot([support[0], support[-1]], [0, 0],
                     ls=':',
                     color='black',
                     lw=2)

            plt.xlim([support[0], support[-1]])
            plt.xlabel("Time (ms)")
            plt.ylabel(ylabel)

        plt.figure(facecolor='white', figsize=(14, 4))

        # Panel 1: kappa
        plt.subplot(1, 3, 1)
        kappa_t = np.linspace(0, 150.0, 300)
        kappa = 1. / self.C * np.exp(-kappa_t / (self.C / self.gl))
        draw_filter(kappa_t, kappa, "Membrane filter (MOhm/ms)")

        # Panel 2: eta
        plt.subplot(1, 3, 2)
        (eta_t, eta_values) = self.eta.getInterpolatedFilter(self.dt)
        draw_filter(eta_t, eta_values, "Eta (nA)")

        # Panel 3: gamma
        plt.subplot(1, 3, 3)
        (gamma_t, gamma_values) = self.gamma.getInterpolatedFilter(self.dt)
        draw_filter(gamma_t, gamma_values, "Gamma (mV)")

        plt.subplots_adjust(left=0.05,
                            bottom=0.15,
                            right=0.95,
                            top=0.92,
                            wspace=0.35,
                            hspace=0.25)

        plt.show()

    def printParameters(self):
        """
        Print model parameters on terminal.

        The membrane time constant (C/gl) and the resistance (1/gl) are derived
        from the stored capacitance and leak conductance; the other values are
        printed as stored.
        """

        # Each print takes one parenthesised expression, so the output is the
        # same under Python 2 and the code also runs under Python 3.
        print("\n-------------------------")
        print("GIF model parameters:")
        print("-------------------------")
        print("tau_m (ms):\t%0.3f" % (self.C / self.gl))
        print("R (MOhm):\t%0.3f" % (1.0 / self.gl))
        print("C (nF):\t\t%0.3f" % (self.C))
        print("gl (nS):\t%0.6f" % (self.gl))
        print("El (mV):\t%0.3f" % (self.El))
        print("Tref (ms):\t%0.3f" % (self.Tref))
        print("Vr (mV):\t%0.3f" % (self.Vr))
        print("Vt* (mV):\t%0.3f" % (self.Vt_star))
        print("DV (mV):\t%0.3f" % (self.DV))
        print("-------------------------\n")

    @classmethod
    def compareModels(cls, GIFs, labels=None):
        """
        Given a list of GIF models, GIFs, print their parameters and produce a
        2x2 figure in which the model parameters are compared.

        Panels: membrane filter, spike-triggered current eta, escape rate and
        spike-triggered threshold movement gamma. One color per model.

        Parameters
        ----------
        GIFs : sequence of models (len() is taken and it is iterated several times).
        labels : optional sequence with one label per model. When given, the
            labels are attached to the eta curves and a legend is drawn in
            that panel.
        """

        def plot_zero_line(support):
            # Dotted horizontal baseline at 0 spanning the support, drawn behind the curves
            plt.plot([support[0], support[-1]], [0, 0],
                     ls=':',
                     color='black',
                     lw=2,
                     zorder=-1)

        # PRINT PARAMETERS

        print("\n#####################################")
        print("GIF model comparison")
        print("#####################################\n")

        for model in GIFs:
            model.printParameters()

        print("#####################################\n")

        # PLOT PARAMETERS
        plt.figure(facecolor='white', figsize=(9, 8))

        colors = plt.cm.jet(np.linspace(0.7, 1.0, len(GIFs)))

        # Membrane filter
        plt.subplot(2, 2, 1)

        # The support is the same for every model, so it is built only once
        K_support = np.linspace(0, 150.0, 1500)
        for (model, color) in zip(GIFs, colors):
            K = 1. / model.C * np.exp(-K_support / (model.C / model.gl))
            plt.plot(K_support, K, color=color, lw=2)

        plot_zero_line(K_support)

        plt.xlim([K_support[0], K_support[-1]])
        plt.xlabel('Time (ms)')
        plt.ylabel('Membrane filter (MOhm/ms)')

        # Spike triggered current
        plt.subplot(2, 2, 2)

        for (idx, model) in enumerate(GIFs):

            # "is None" rather than "== None": the latter is evaluated
            # elementwise when labels is a numpy array
            label_tmp = "" if labels is None else labels[idx]

            (eta_support, eta) = model.eta.getInterpolatedFilter(0.1)
            plt.plot(eta_support,
                     eta,
                     color=colors[idx],
                     lw=2,
                     label=label_tmp)

        # The baseline and the axis limits use the support of the last model
        plot_zero_line(eta_support)

        if labels is not None:
            plt.legend()

        plt.xlim([eta_support[0], eta_support[-1]])
        plt.xlabel('Time (ms)')
        plt.ylabel('Eta (nA)')

        # Escape rate
        plt.subplot(2, 2, 3)

        for (model, color) in zip(GIFs, colors):

            V_support = np.linspace(model.Vt_star - 5 * model.DV,
                                    model.Vt_star + 10 * model.DV, 1000)
            escape_rate = model.lambda0 * np.exp(
                (V_support - model.Vt_star) / model.DV)
            plt.plot(V_support, escape_rate, color=color, lw=2)

        plt.ylim([0, 100])
        # The baseline and the axis limits use the voltage range of the last model
        plot_zero_line(V_support)

        plt.xlim([V_support[0], V_support[-1]])
        plt.xlabel('Membrane potential (mV)')
        plt.ylabel('Escape rate (Hz)')

        # Spike triggered threshold movement
        plt.subplot(2, 2, 4)

        for (model, color) in zip(GIFs, colors):

            (gamma_support, gamma) = model.gamma.getInterpolatedFilter(0.1)
            plt.plot(gamma_support, gamma, color=color, lw=2)

        plot_zero_line(gamma_support)

        plt.xlim([gamma_support[0] + 0.1, gamma_support[-1]])
        plt.ylim([-100, 100])
        plt.xlabel('Time (ms)')
        plt.ylabel('Gamma (mV)')

        plt.subplots_adjust(left=0.08,
                            bottom=0.10,
                            right=0.95,
                            top=0.93,
                            wspace=0.25,
                            hspace=0.25)

        plt.show()

    @classmethod
    def plotAverageModel(cls, GIFs):
        """
        Average model parameters and plot summary data.

        Top row: for the membrane filter, the spike-triggered current eta and the
        spike-triggered threshold movement gamma, every model is drawn as a thin
        dark line together with the mean across models (red) and a
        mean +/- standard deviation band (gray).
        Bottom rows: histograms across models of R, tau_m, El, Vr, Vt_star and DV.
        The mean Vr is also printed on the terminal.

        Parameters
        ----------
        GIFs : sequence of models; it is iterated several times.
        """

        def plot_filter_summary(position, curves, ylabel, x_zoom=None):
            # One panel of the top row. curves yields one (support, values) pair per
            # model. The support of the last model is used for the mean curve, the
            # band, the zero line and the x-limits. When x_zoom is given, the upper
            # x-limit is the end of the support divided by x_zoom.
            plt.subplot(2, 4, position)

            values_all = []
            for (support, values) in curves:
                plt.plot(support, values, color='0.3', lw=1, zorder=5)
                values_all.append(values)

            values_mean = np.mean(values_all, axis=0)
            values_std = np.std(values_all, axis=0)

            plt.fill_between(support,
                             values_mean + values_std,
                             y2=values_mean - values_std,
                             color='gray',
                             zorder=0)
            plt.plot(support, values_mean, color='red', lw=2, zorder=10)
            plt.plot([support[0], support[-1]], [0, 0],
                     ls=':',
                     color='black',
                     lw=2,
                     zorder=-1)

            Tools.removeAxis(plt.gca(), ['top', 'right'])
            if x_zoom is None:
                x_max = support[-1]
            else:
                x_max = support[-1] / x_zoom
            plt.xlim([support[0], x_max])
            plt.xlabel('Time (ms)')
            plt.ylabel(ylabel)

        def plot_histogram(position, values, xlabel):
            # One histogram of a scalar parameter across models (4x6 grid position)
            plt.subplot(4, 6, position)
            plt.hist(values, histtype='bar', color='red', ec='white', lw=2)
            plt.xlabel(xlabel)
            Tools.removeAxis(plt.gca(), ['top', 'left', 'right'])
            plt.yticks([])

        #######################################################################################################
        # PLOT PARAMETERS
        #######################################################################################################

        fig = plt.figure(facecolor='white', figsize=(16, 7))
        fig.subplots_adjust(left=0.07,
                            bottom=0.08,
                            right=0.95,
                            top=0.90,
                            wspace=0.35,
                            hspace=0.5)
        rcParams['xtick.direction'] = 'out'
        rcParams['ytick.direction'] = 'out'

        # MEMBRANE FILTER
        #######################################################################################################

        K_support = np.linspace(0, 150.0, 300)
        plot_filter_summary(
            1,
            ((K_support,
              1. / model.C * np.exp(-K_support / (model.C / model.gl)))
             for model in GIFs),
            'Membrane filter (MOhm/ms)')

        # SPIKE-TRIGGERED CURRENT (only the first tenth of the support is shown)
        #######################################################################################################

        plot_filter_summary(
            2,
            (model.eta.getInterpolatedFilter(0.1) for model in GIFs),
            'Spike-triggered current (nA)',
            x_zoom=10.0)

        # SPIKE-TRIGGERED MOVEMENT OF THE FIRING THRESHOLD
        #######################################################################################################

        plot_filter_summary(
            3,
            (model.gamma.getInterpolatedFilter(0.1) for model in GIFs),
            'Spike-triggered threshold (mV)')

        # R
        #######################################################################################################

        plot_histogram(12 + 1, [1. / model.gl for model in GIFs], 'R (MOhm)')

        # tau_m
        #######################################################################################################

        plot_histogram(18 + 1, [model.C / model.gl for model in GIFs],
                       'tau_m (ms)')

        # El
        #######################################################################################################

        plot_histogram(12 + 2, [model.El for model in GIFs], 'El (mV)')

        # V reset
        #######################################################################################################

        Vr_all = [model.Vr for model in GIFs]
        print("Mean Vr (mV): %0.1f" % (np.mean(Vr_all)))
        plot_histogram(18 + 2, Vr_all, 'Vr (mV)')

        # Vt*
        #######################################################################################################

        plot_histogram(12 + 3, [model.Vt_star for model in GIFs],
                       'Vt_star (mV)')

        # DV
        #######################################################################################################

        plot_histogram(18 + 3, [model.DV for model in GIFs], 'DV (mV)')
Exemplo n.º 6
0
def process_all_files_for_iGIF_Ca_NP(is_E_Ca_fixed=True):
    """
    Fit an iGIF_Ca_NP model to every '_5HT' cell found in the experiment
    folders, evaluate it on the test set and save the results.

    For each cell the function loads the AEC, training and test recordings
    (.abf or .mat), performs active electrode compensation, fits the model
    while scanning theta_tau over a range that is moved until the optimum is
    not on its edge, and then computes epsilon_V, Md* and PVar on the test set.

    Files written under './Results/': for each cell the model parameters
    ('...ModelParams.pck'), a figure comparing model and data
    ('...simulate.png') and the file name passed to plotRaster
    ('...Raster.png'); plus one summary table '...FitPerformance.dat'.

    Parameters
    ----------
    is_E_Ca_fixed : forwarded to iGIF_Ca_NP.fit(); also selects the
        'ECa_fixed_' / 'ECa_free_' tag used in the output file names.

    Raises
    ------
    IOError : if a cell folder contains neither a .abf nor a .mat file.
    """
    if is_E_Ca_fixed:
        spec_GIF_Ca = 'ECa_fixed_'
    else:
        spec_GIF_Ca = 'ECa_free_'
    Md_star = {}
    epsilon_V_test = {}
    PVar = {}

    # List separate experiments in separate folder
    data_folders_for_separate_experiments = [
        'seventh_set', 'eighth_set', 'ninth_set'
    ]

    # Problematic cells, excluded from the analysis. They are filtered out
    # instead of removed with list.remove(), so that a missing folder does
    # not abort the whole run with a ValueError.
    problematic_cells = {'eighth_set': ('DRN157_5HT', 'DRN164_5HT')}

    # For all experiments, extract the cell names
    CellNames = {}
    for experiment_folder in data_folders_for_separate_experiments:
        folder_path = './' + experiment_folder + '/'
        excluded = problematic_cells.get(experiment_folder, ())
        CellNames[experiment_folder] = [
            name for name in os.listdir(folder_path)
            if os.path.isdir(folder_path + name) and '_5HT' in name
            and name not in excluded
        ]

    for experiment_folder in data_folders_for_separate_experiments:
        for cell_name in CellNames[experiment_folder]:
            print('\n\n#############################################')
            print('##########     process cell %s    ###' % cell_name)
            print('#############################################')

            #################################################################################################
            # Load data
            #################################################################################################

            path_data = './' + experiment_folder + '/' + cell_name + '/'
            path_results = './Results/' + cell_name + '/'

            # Find extension of data files. ext is reset for every cell so that a
            # folder without data files cannot silently reuse the extension found
            # for the previous cell.
            ext = None
            for file_name in os.listdir(path_data):
                if '.abf' in file_name:
                    ext = '.abf'
                    break
                elif '.mat' in file_name:
                    ext = '.mat'
                    break
            if ext is None:
                raise IOError('No .abf or .mat data file found in %s' %
                              path_data)

            # Load AEC data
            filename_AEC = path_data + cell_name + '_aec' + ext
            (sampling_timeAEC, voltage_traceAEC,
             current_traceAEC) = load_AEC_data(filename_AEC)

            # Create experiment
            experiment = Experiment('Experiment 1', sampling_timeAEC)
            experiment.setAECTrace(voltage_traceAEC,
                                   10.**-3,
                                   current_traceAEC,
                                   10.**-12,
                                   len(voltage_traceAEC) * sampling_timeAEC,
                                   FILETYPE='Array')

            # Load training set data and add to experiment object
            filename_training = path_data + cell_name + '_training' + ext
            (sampling_time, voltage_trace, current_trace,
             time) = load_training_data(filename_training)
            experiment.addTrainingSetTrace(voltage_trace,
                                           10**-3,
                                           current_trace,
                                           10**-12,
                                           len(voltage_trace) * sampling_time,
                                           FILETYPE='Array')
            # Note: once added to experiment, current is converted to nA.

            # Load test set data (every repetition becomes one test set trace;
            # a constant 5. is subtracted from the recorded current)
            filename_test = path_data + cell_name + '_test' + ext
            if ext == '.mat':
                mat_contents = sio.loadmat(filename_test)
                analogSignals = mat_contents['analogSignals']
                times_test = mat_contents['times']
                times_test = times_test.reshape(times_test.size)
                times_test = times_test * 10.**3
                sampling_time_test = times_test[1] - times_test[0]
                for testnum in range(analogSignals.shape[1]):
                    voltage_test = analogSignals[0, testnum, :]
                    current_test = analogSignals[1, testnum, :] - 5.
                    experiment.addTestSetTrace(voltage_test,
                                               10.**-3,
                                               current_test,
                                               10.**-12,
                                               len(voltage_test) *
                                               sampling_time_test,
                                               FILETYPE='Array')
            elif ext == '.abf':
                r = neo.io.AxonIO(filename=filename_test)
                bl = r.read_block()
                times_test = bl.segments[0].analogsignals[0].times.rescale(
                    'ms').magnitude
                sampling_time_test = times_test[1] - times_test[0]
                for segment in bl.segments:
                    voltage_test = segment.analogsignals[0].magnitude
                    current_test = segment.analogsignals[1].magnitude - 5.
                    experiment.addTestSetTrace(voltage_test,
                                               10.**-3,
                                               current_test,
                                               10.**-12,
                                               len(voltage_test) *
                                               sampling_time_test,
                                               FILETYPE='Array')

            #################################################################################################
            # PERFORM ACTIVE ELECTRODE COMPENSATION
            #################################################################################################

            # Create new object to perform AEC
            myAEC = AEC_Badel(experiment.dt)

            # Define metaparameters
            myAEC.K_opt.setMetaParameters(length=150.0,
                                          binsize_lb=experiment.dt,
                                          binsize_ub=2.0,
                                          slope=30.0,
                                          clamp_period=1.0)
            myAEC.p_expFitRange = [3.0, 150.0]
            myAEC.p_nbRep = 15

            # Assign myAEC to experiment and compensate the voltage recordings
            experiment.setAEC(myAEC)
            experiment.performAEC()

            #################################################################################################
            # FIT GIF-Ca
            #################################################################################################

            # Create a new object GIF
            iGIF_Ca_NP_fit = iGIF_Ca_NP(experiment.dt)

            # Define parameters
            iGIF_Ca_NP_fit.Tref = 6.0
            iGIF_Ca_NP_fit.eta = Filter_Rect_LogSpaced()
            iGIF_Ca_NP_fit.eta.setMetaParameters(length=2000.0,
                                                 binsize_lb=0.5,
                                                 binsize_ub=500.0,
                                                 slope=10.0)
            iGIF_Ca_NP_fit.gamma = Filter_Rect_LogSpaced()
            iGIF_Ca_NP_fit.gamma.setMetaParameters(length=2000.0,
                                                   binsize_lb=2.0,
                                                   binsize_ub=500.0,
                                                   slope=5.0)

            # Exclude the first and the last 2000 (same units as sampling_time)
            # of the training recording from the region of interest
            for tr in experiment.trainingset_traces:
                tr.setROI(
                    [[2000.,
                      sampling_time * (len(voltage_trace) - 1) - 2000.]])

            # Define metaparameters used during the fit
            theta_inf_nbbins = 10  # passed to fit() as theta_inf_nbbins; also the number of theta_tau values tested
            theta_range_min = 10.
            theta_range_max = 20.
            # The optimum is considered to be on the edge of the tested range when
            # it is closer to it than a tenth of the initial grid spacing
            theta_edge_margin = 0.1 * (theta_range_max - theta_range_min) / (
                theta_inf_nbbins - 1)
            theta_tau_all = np.linspace(
                theta_range_min, theta_range_max, theta_inf_nbbins
            )  # tau_theta is the timescale of the threshold-voltage coupling
            likelihoods = iGIF_Ca_NP_fit.fit(experiment,
                                             theta_inf_nbbins=theta_inf_nbbins,
                                             theta_tau_all=theta_tau_all,
                                             DT_beforeSpike=5.0,
                                             is_E_Ca_fixed=is_E_Ca_fixed)

            # Optimum on the lower edge: fit again on the range [1, 10]
            if iGIF_Ca_NP_fit.theta_tau < theta_tau_all[0] + theta_edge_margin:
                theta_tau_all = np.linspace(theta_range_min - 9.,
                                            theta_range_max - 10.,
                                            theta_inf_nbbins)
                likelihoods = iGIF_Ca_NP_fit.fit(
                    experiment,
                    theta_inf_nbbins=theta_inf_nbbins,
                    theta_tau_all=theta_tau_all,
                    DT_beforeSpike=5.0,
                    is_E_Ca_fixed=is_E_Ca_fixed)

            # Optimum on the upper edge: shift the range upwards by 10 until it is not.
            # NOTE(review): theta_range_min/max are not updated by the lower-edge
            # refit above, so if that refit ends on its upper edge the next range
            # tested is [20, 30] -- confirm this is intended.
            while iGIF_Ca_NP_fit.theta_tau > theta_tau_all[-1] - theta_edge_margin:
                theta_range_min = theta_range_min + 10.
                theta_range_max = theta_range_max + 10.
                theta_tau_all = np.linspace(theta_range_min, theta_range_max,
                                            theta_inf_nbbins)
                print('testing range for theta_tau = [%f, %f]...' %
                      (theta_range_min, theta_range_max))
                likelihoods = iGIF_Ca_NP_fit.fit(
                    experiment,
                    theta_inf_nbbins=theta_inf_nbbins,
                    theta_tau_all=theta_tau_all,
                    DT_beforeSpike=5.0,
                    is_E_Ca_fixed=is_E_Ca_fixed)
                print('max likelihood = %f' % np.max(np.array(likelihoods)))
            iGIF_Ca_NP_fit.save(path_results + cell_name + '_iGIF_Ca_NP_' +
                                spec_GIF_Ca + 'ModelParams.pck')

            ###################################################################################################
            # EVALUATE MODEL PERFORMANCES ON THE TEST SET DATA
            ###################################################################################################

            # predict spike times in test set
            prediction = experiment.predictSpikes(iGIF_Ca_NP_fit, nb_rep=500)

            # Compute epsilon_V: 1 - SSE/VAR of the subthreshold voltage far from
            # spikes, averaged over the test set traces
            epsilon_V = 0.
            local_counter = 0.
            for tr in experiment.testset_traces:
                (time, V_est, eta_sum_est
                 ) = iGIF_Ca_NP_fit.simulateDeterministic_forceSpikes(
                     tr.I, tr.V[0], tr.getSpikeTimes())
                indices_tmp = tr.getROI_FarFromSpikes(5., iGIF_Ca_NP_fit.Tref)

                SSE = 0. + sum((V_est[indices_tmp] - tr.V[indices_tmp])**2)
                VAR = 0. + len(indices_tmp) * np.var(tr.V[indices_tmp])
                epsilon_V += 1.0 - SSE / VAR
                local_counter += 1
            epsilon_V = epsilon_V / local_counter
            epsilon_V_test[cell_name] = epsilon_V

            # Compute Md*
            Md_star[cell_name] = prediction.computeMD_Kistler(
                8.0, iGIF_Ca_NP_fit.dt * 2.)
            fname = path_results + cell_name + '_iGIF_Ca_NP_' + spec_GIF_Ca + 'Raster.png'
            kernelForPSTH = 50.0
            PVar[cell_name] = prediction.plotRaster(fname, delta=kernelForPSTH)

            #################################################################################################
            #  PLOT TRAINING AND TEST TRACES, MODEL VS EXPERIMENT
            #################################################################################################

            # Comparison for training and test sets w/o inactivation
            V_training = experiment.trainingset_traces[0].V
            I_training = experiment.trainingset_traces[0].I
            (time, V, eta_sum, V_t,
             S) = iGIF_Ca_NP_fit.simulate(I_training, V_training[0])
            fig = plt.figure(figsize=(10, 6), facecolor='white')
            plt.subplot(2, 1, 1)
            plt.plot(time / 1000, V, '--r', lw=0.5, label='iGIF-Ca-NP')
            plt.plot(time / 1000, V_training, 'black', lw=0.5, label='Data')
            plt.xlim(18, 20)
            plt.ylim(-80, 20)
            plt.ylabel('Voltage [mV]')
            plt.title('Training')

            V_test = experiment.testset_traces[0].V
            I_test = experiment.testset_traces[0].I
            (time, V, eta_sum, V_t,
             S) = iGIF_Ca_NP_fit.simulate(I_test, V_test[0])
            plt.subplot(2, 1, 2)
            plt.plot(time / 1000, V, '--r', lw=0.5, label='iGIF-Ca-NP')
            plt.plot(time / 1000, V_test, 'black', lw=0.5, label='Data')
            plt.xlim(5, 7)
            plt.ylim(-80, 20)
            plt.xlabel('Times [s]')
            plt.ylabel('Voltage [mV]')
            plt.title('Test')
            plt.legend()
            plt.tight_layout()
            plt.savefig(path_results + cell_name + '_iGIF_Ca_NP_' +
                        spec_GIF_Ca + 'simulate.png',
                        format='png')
            plt.close()

    # Summary table with one row per cell. The with-statement closes the file
    # even if writing a row fails.
    with open('./Results/' + 'iGIF_Ca_NP_' + spec_GIF_Ca +
              'FitPerformance.dat', 'w') as output_file:
        output_file.write('#Cell name\tMd*\tEpsilonV\tPVar\n')

        for experiment_folder in data_folders_for_separate_experiments:
            for cell_name in CellNames[experiment_folder]:
                output_file.write(cell_name + '\t' + str(Md_star[cell_name]) +
                                  '\t' + str(epsilon_V_test[cell_name]) +
                                  '\t' + str(PVar[cell_name]) + '\n')
Exemplo n.º 7
0
        myAEC.p_expFitRange = [3.0, 150.0]
        myAEC.p_nbRep = 15

        # Assign myAEC to experiment and compensate the voltage recordings
        experiment.setAEC(myAEC)
        experiment.performAEC()

        #################################################################################################
        # FIT iGIF-NP model
        #################################################################################################
        # Create a new object GIF
        iGIF_NP_fit = iGIF_NP(experiment.dt)

        # Define parameters
        iGIF_NP_fit.Tref = 6.0
        iGIF_NP_fit.eta = Filter_Rect_LogSpaced()
        iGIF_NP_fit.eta.setMetaParameters(length=2000.0,
                                          binsize_lb=0.5,
                                          binsize_ub=500.0,
                                          slope=10.0)
        iGIF_NP_fit.gamma = Filter_Rect_LogSpaced()
        iGIF_NP_fit.gamma.setMetaParameters(length=2000.0,
                                            binsize_lb=2.0,
                                            binsize_ub=500.0,
                                            slope=5.0)

        for tr in experiment.trainingset_traces:
            tr.setROI(
                [[2000., sampling_time * (len(voltage_trace) - 1) - 2000.]])

        # Define metaparameters used during the fit
Exemplo n.º 8
0
class GIF(ThresholdModel):

    """
    Generalized Integrate and Fire model
    Spikes are produced stochastically with firing intensity:
    
    lambda(t) = lambda0 * exp( (V(t)-V_T(t))/DV ),
    
    
    where the membrane potential dynamics is given by:
    
    C dV/dt = -gl(V-El) + I - sum_j eta(t-\hat t_j)
    
    
    and the firing threshold V_T is given by:
    
    V_T = Vt_star + sum_j gamma(t-\hat t_j)
    
    
    and \hat t_j denote the spike times.
    """

    def __init__(self, dt=0.1):
        """
        Create a GIF model with default parameters and exponentially decaying kernels.

        dt : time step stored in self.dt (default 0.1); it is the value handed to the
             filters eta and gamma when they are interpolated (see simulate).
        """

        # Default kernel shapes: both decay exponentially with a constant of 100.0
        def initial_eta(x):
            return 0.2 * np.exp(-x / 100.0)

        def initial_gamma(x):
            return 10.0 * np.exp(-x / 100.0)

        # Time step used in simulations
        self.dt = dt

        # Passive membrane properties
        leak = 1.0 / 100.0
        self.gl = leak  # nS, leak conductance
        self.C = 20.0 * leak  # nF, capacitance
        self.El = -65.0  # mV, reversal potential

        # Reset applied after each spike
        self.Vr = -50.0  # mV, voltage reset
        self.Tref = 4.0  # ms, absolute refractory period

        # Stochastic spike emission
        self.Vt_star = -48.0  # mV, steady state voltage threshold VT*
        self.DV = 0.5  # mV, threshold sharpness
        self.lambda0 = 1.0  # Hz, kept at 1.0 by default

        # Spike-triggered kernels (must be instances of class Filter)
        self.eta = Filter_Rect_LogSpaced()  # nA, spike-triggered current
        self.gamma = Filter_Rect_LogSpaced()  # mV, spike-triggered movement of the firing threshold

        # Variables related to the fit (overwritten by fitVoltageReset)
        self.avg_spike_shape = 0
        self.avg_spike_shape_support = 0

        # Give both kernels their initial exponential shape
        self.eta.setFilter_Function(initial_eta)
        self.gamma.setFilter_Function(initial_gamma)

    ########################################################################################################
    # SET DT FOR NUMERICAL SIMULATIONS (KERNELS ARE REINTERPOLATED EACH TIME DT IS CHANGED)
    ########################################################################################################
    def setDt(self, dt):

        """
        Define the time step used for numerical simulations. The filters eta and gamma are interpolated accordingly.
        """

        self.dt = dt

    ########################################################################################################
    # FUNCTIONS FOR SIMULATIONS
    ########################################################################################################
    def simulateSpikingResponse(self, I, dt):

        """
        Simulate the spiking response of the GIF model to an input current I (nA) with time step dt.
        Return a list of spike times (in ms).
        The initial conditions for the simulation is V(0)=El.
        """
        self.setDt(dt)

        (time, V, eta_sum, V_T, spks_times) = self.simulate(I, self.El)

        return spks_times

    def simulateVoltageResponse(self, I, dt):

        self.setDt(dt)

        (time, V, eta_sum, V_T, spks_times) = self.simulate(I, self.El)

        return (spks_times, V, V_T)

    def simulate(self, I, V0):

        """
        Simulate the spiking response of the GIF model to an input current I (nA) with time step self.dt.
        V0 indicates the initial condition V(0)=V0.
        The function returns:
        - time     : ms, support for V, eta_sum, V_T, spks
        - V        : mV, membrane potential
        - eta_sum  : nA, adaptation current
        - V_T      : mV, firing threshold
        - spks     : ms, list of spike times

        The time loop is written in C and run through weave.inline. Spikes are drawn with the C
        function rand(), which is not seeded in this method. After each spike the time index jumps
        ahead by Tref_ind steps, so the samples of V skipped by the jump keep their initial value 0.
        """

        # Input parameters
        p_T = len(I)
        p_dt = self.dt

        # Model parameters
        # (the p_* locals below are listed by name in `vars` and referenced under the same names in the C code)
        p_gl = self.gl
        p_C = self.C
        p_El = self.El
        p_Vr = self.Vr
        p_Tref = self.Tref
        p_Vt_star = self.Vt_star
        p_DV = self.DV
        p_lambda0 = self.lambda0

        # Model kernels (interpolated with the current time step)
        (p_eta_support, p_eta) = self.eta.getInterpolatedFilter(self.dt)
        p_eta = p_eta.astype("double")
        p_eta_l = len(p_eta)

        (p_gamma_support, p_gamma) = self.gamma.getInterpolatedFilter(self.dt)
        p_gamma = p_gamma.astype("double")
        p_gamma_l = len(p_gamma)

        # Define arrays
        # eta_sum and gamma_sum are longer than the trace (by twice the kernel length); presumably so that
        # kernels added after a late spike stay inside the buffer.
        # NOTE(review): the C loop writes up to index t + Tref_ind + kernel length, so this assumes
        # Tref_ind does not exceed the kernel length - confirm.
        V = np.array(np.zeros(p_T), dtype="double")
        I = np.array(I, dtype="double")
        spks = np.array(np.zeros(p_T), dtype="double")
        eta_sum = np.array(np.zeros(p_T + 2 * p_eta_l), dtype="double")
        gamma_sum = np.array(np.zeros(p_T + 2 * p_gamma_l), dtype="double")

        # Set initial condition
        V[0] = V0

        # C loop: forward-Euler voltage update, escape-rate spike draw, then reset and kernel accumulation
        code = """
                #include <math.h>
                
                int   T_ind      = int(p_T);                
                float dt         = float(p_dt); 
                
                float gl         = float(p_gl);
                float C          = float(p_C);
                float El         = float(p_El);
                float Vr         = float(p_Vr);
                int   Tref_ind   = int(float(p_Tref)/dt);
                float Vt_star    = float(p_Vt_star);
                float DeltaV     = float(p_DV);
                float lambda0    = float(p_lambda0);
           
                int eta_l        = int(p_eta_l);
                int gamma_l      = int(p_gamma_l);
                
                                                  
                float rand_max  = float(RAND_MAX); 
                float p_dontspike = 0.0 ;
                float lambda = 0.0 ;            
                float r = 0.0;
                
                                                
                for (int t=0; t<T_ind-1; t++) {
    
    
                    // INTEGRATE VOLTAGE
                    V[t+1] = V[t] + dt/C*( -gl*(V[t] - El) + I[t] - eta_sum[t] );
               
               
                    // COMPUTE PROBABILITY OF EMITTING ACTION POTENTIAL
                    lambda = lambda0*exp( (V[t+1]-Vt_star-gamma_sum[t])/DeltaV );
                    p_dontspike = exp(-lambda*(dt/1000.0));                                  // since lambda0 is in Hz, dt must also be in Hz (this is why dt/1000.0)
                          
                          
                    // PRODUCE SPIKE STOCHASTICALLY
                    r = rand()/rand_max;
                    if (r > p_dontspike) {
                                        
                        if (t+1 < T_ind-1)                
                            spks[t+1] = 1.0; 
                        
                        t = t + Tref_ind;    
                        
                        if (t+1 < T_ind-1) 
                            V[t+1] = Vr;
                        
                        
                        // UPDATE ADAPTATION PROCESSES     
                        for(int j=0; j<eta_l; j++) 
                            eta_sum[t+1+j] += p_eta[j]; 
                        
                        for(int j=0; j<gamma_l; j++) 
                            gamma_sum[t+1+j] += p_gamma[j] ;  
                        
                    }
               
                }
                
                """

        # Names of the Python variables handed to the C code (note: this local shadows the builtin vars)
        vars = [
            "p_T",
            "p_dt",
            "p_gl",
            "p_C",
            "p_El",
            "p_Vr",
            "p_Tref",
            "p_Vt_star",
            "p_DV",
            "p_lambda0",
            "V",
            "I",
            "p_eta",
            "p_eta_l",
            "eta_sum",
            "p_gamma",
            "gamma_sum",
            "p_gamma_l",
            "spks",
        ]

        # The return value of weave.inline is not used; results are read from V, spks, eta_sum and gamma_sum
        v = weave.inline(code, vars)

        time = np.arange(p_T) * self.dt

        # Drop the padding so that the outputs have the same length as the input current
        eta_sum = eta_sum[:p_T]
        V_T = gamma_sum[:p_T] + p_Vt_star

        # Convert the 0/1 spike train into spike times
        spks = (np.where(spks == 1)[0]) * self.dt

        return (time, V, eta_sum, V_T, spks)

    def simulateDeterministic_forceSpikes(self, I, V0, spks):

        """
        Simulate the subthreshold response of the GIF model to an input current I (nA) with time step self.dt.
        Adaptation currents are enforced at times specified in the list spks (in ms) given as an argument to the function.
        V0 indicates the initial condition V(0)=V0.
        The function returns:
        - time     : ms, support for V, eta_sum
        - V        : mV, membrane potential
        - eta_sum  : nA, adaptation current

        For a spike at index s (as returned by Tools.timeToIndex), eta is added to the adaptation
        current starting at index s + 1 + Tref/dt, and in the C loop V is set to Vr at index
        s + Tref_ind while the sample just before it is set to 0.
        An empty spks is supported: no reset is applied and eta_sum stays at zero.
        """

        # Input parameters
        p_T = len(I)
        p_dt = self.dt

        # Model parameters
        # (the p_* locals are listed by name in arg_names and referenced under the same names in the C code)
        p_gl = self.gl
        p_C = self.C
        p_El = self.El
        p_Vr = self.Vr
        p_Tref = self.Tref
        p_Tref_i = int(self.Tref / self.dt)

        # Model kernel
        (p_eta_support, p_eta) = self.eta.getInterpolatedFilter(self.dt)
        p_eta = p_eta.astype("double")
        p_eta_l = len(p_eta)

        # Define arrays
        V = np.zeros(p_T, dtype="double")
        I = np.array(I, dtype="double")
        spks = np.array(spks, dtype="double")
        spks_i = Tools.timeToIndex(spks, self.dt)

        # Number of enforced spikes, passed to the C code so that it never reads past the end of spks_i
        p_spks_l = len(spks_i)

        # Compute adaptation current (sum of eta triggered at spike times in spks).
        # The buffer is longer than the trace so that kernels of late spikes fit; the size has to be
        # converted to int explicitly because 1.1 * p_eta_l is a float.
        eta_sum = np.zeros(int(p_T + 1.1 * p_eta_l + p_Tref_i), dtype="double")

        for s in spks_i:
            eta_sum[s + 1 + p_Tref_i : s + 1 + p_Tref_i + p_eta_l] += p_eta

        eta_sum = eta_sum[:p_T]

        # Set initial condition
        V[0] = V0

        code = """
                #include <math.h>

                int   T_ind      = int(p_T);
                float dt         = float(p_dt);

                float gl         = float(p_gl);
                float C          = float(p_C);
                float El         = float(p_El);
                float Vr         = float(p_Vr);
                int   Tref_ind   = int(float(p_Tref)/dt);
                int   spks_l     = int(p_spks_l);


                // -1 never matches a time index: it means that no (further) spike has to be enforced
                int next_spike = -1;
                if (spks_l > 0)
                    next_spike = spks_i[0] + Tref_ind;
                int spks_cnt = 0;


                for (int t=0; t<T_ind-1; t++) {


                    // INTEGRATE VOLTAGE
                    V[t+1] = V[t] + dt/C*( -gl*(V[t] - El) + I[t] - eta_sum[t] );


                    if ( t == next_spike ) {
                        spks_cnt = spks_cnt + 1;

                        if (spks_cnt < spks_l)
                            next_spike = spks_i[spks_cnt] + Tref_ind;
                        else
                            next_spike = -1;

                        if (t > 0)
                            V[t-1] = 0 ;
                        V[t] = Vr ;
                        t=t-1;
                    }

                }

                """

        # Names of the Python variables handed to the C code
        arg_names = [
            "p_T",
            "p_dt",
            "p_gl",
            "p_C",
            "p_El",
            "p_Vr",
            "p_Tref",
            "p_spks_l",
            "V",
            "I",
            "eta_sum",
            "spks_i",
        ]

        # Results are written in place into V
        weave.inline(code, arg_names)

        time = np.arange(p_T) * self.dt

        return (time, V, eta_sum)

    def fit(self, experiment, DT_beforeSpike=5.0):

        """
        Fit the GIF model on experimental data.
        The experimental data are stored in the object experiment.
        The parameter DT_beforeSpike (in ms) defines the region that is cut before each spike when fitting the subthreshold dynamics of the membrane potential.
        Only training set traces in experiment are used to perform the fit.
        """

        # Three step procedure used for parameters extraction

        print "\n################################"
        print "# Fit GIF"
        print "################################\n"

        self.fitVoltageReset(experiment, self.Tref, do_plot=False)

        self.fitSubthresholdDynamics(experiment, DT_beforeSpike=DT_beforeSpike)

        self.fitStaticThreshold(experiment)

        self.fitThresholdDynamics(experiment)

    ########################################################################################################
    # FIT VOLTAGE RESET GIVEN ABSOLUTE REFRACTORY PERIOD (step 1)
    ########################################################################################################
    def fitVoltageReset(self, experiment, Tref, do_plot=False):

        """
        Tref: ms, absolute refractory period. 
        The voltage reset is estimated by computing the spike-triggered average of the voltage.
        """

        print "Estimate voltage reset (Tref = %0.1f ms)..." % (Tref)

        # Fix absolute refractory period
        self.dt = experiment.dt
        self.Tref = Tref

        all_spike_average = []
        all_spike_nb = 0
        for tr in experiment.trainingset_traces:

            if tr.useTrace:
                if len(tr.spks) > 0:
                    (support, spike_average, spike_nb) = tr.computeAverageSpikeShape()
                    all_spike_average.append(spike_average)
                    all_spike_nb += spike_nb

        spike_average = np.mean(all_spike_average, axis=0)

        # Estimate voltage reset
        Tref_ind = np.where(support >= self.Tref)[0][0]
        self.Vr = spike_average[Tref_ind]

        # Save average spike shape
        self.avg_spike_shape = spike_average
        self.avg_spike_shape_support = support

        if do_plot:
            plt.figure()
            plt.plot(support, spike_average, "black")
            plt.plot([support[Tref_ind]], [self.Vr], ".", color="red")
            plt.show()

        print "Done! Vr = %0.2f mV (computed on %d spikes)" % (self.Vr, all_spike_nb)

    ########################################################################################################
    # FUNCTIONS RELATED TO FIT OF SUBTHRESHOLD DYNAMICS (step 2)
    ########################################################################################################

    def fitSubthresholdDynamics(self, experiment, DT_beforeSpike=5.0):

        print "\nGIF MODEL - Fit subthreshold dynamics..."

        # Expand eta in basis functions
        self.dt = experiment.dt
        self.eta.computeBins()

        # Build X matrix and Y vector to perform linear regression (use all traces in training set)
        X = []
        Y = []

        cnt = 0

        for tr in experiment.trainingset_traces:

            if tr.useTrace:

                cnt += 1
                reprint("Compute X matrix for repetition %d" % (cnt))

                (X_tmp, Y_tmp) = self.fitSubthresholdDynamics_Build_Xmatrix_Yvector(tr, DT_beforeSpike=DT_beforeSpike)

                X.append(X_tmp)
                Y.append(Y_tmp)

        # Concatenate matrixes associated with different traces to perform a single multilinear regression
        if cnt == 1:
            X = X[0]
            Y = Y[0]

        elif cnt > 1:
            X = np.concatenate(X, axis=0)
            Y = np.concatenate(Y, axis=0)

        else:
            print "\nError, at least one training set trace should be selected to perform fit."

        # Linear Regression
        print "\nPerform linear regression..."
        XTX = np.dot(np.transpose(X), X)
        XTX_inv = inv(XTX)
        XTY = np.dot(np.transpose(X), Y)
        b = np.dot(XTX_inv, XTY)
        b = b.flatten()

        # Update and print model parameters
        self.C = 1.0 / b[1]
        self.gl = -b[0] * self.C
        self.El = b[2] * self.C / self.gl
        self.eta.setFilter_Coefficients(-b[3:] * self.C)

        self.printParameters()

        # Compute percentage of variance explained on dV/dt

        var_explained_dV = 1.0 - np.mean((Y - np.dot(X, b)) ** 2) / np.var(Y)
        print "Percentage of variance explained (on dV/dt): %0.2f" % (var_explained_dV * 100.0)

        # Compute percentage of variance explained on V

        SSE = 0  # sum of squared errors
        VAR = 0  # variance of data

        for tr in experiment.trainingset_traces:

            if tr.useTrace:

                # Simulate subthreshold dynamics
                (time, V_est, eta_sum_est) = self.simulateDeterministic_forceSpikes(tr.I, tr.V[0], tr.getSpikeTimes())

                indices_tmp = tr.getROI_FarFromSpikes(0.0, self.Tref)

                SSE += sum((V_est[indices_tmp] - tr.V[indices_tmp]) ** 2)
                VAR += len(indices_tmp) * np.var(tr.V[indices_tmp])

        var_explained_V = 1.0 - SSE / VAR

        print "Percentage of variance explained (on V): %0.2f" % (var_explained_V * 100.0)

    def fitSubthresholdDynamics_Build_Xmatrix_Yvector(self, trace, DT_beforeSpike=5.0):

        # Length of the voltage trace
        Tref_ind = int(self.Tref / trace.dt)

        # Select region where to perform linear regression
        selection = trace.getROI_FarFromSpikes(DT_beforeSpike, self.Tref)
        selection_l = len(selection)

        # Build X matrix for linear regression
        X = np.zeros((selection_l, 3))

        # Fill first two columns of X matrix
        X[:, 0] = trace.V[selection]
        X[:, 1] = trace.I[selection]
        X[:, 2] = np.ones(selection_l)

        # Compute and fill the remaining columns associated with the spike-triggered current eta
        X_eta = self.eta.convolution_Spiketrain_basisfunctions(trace.getSpikeTimes() + self.Tref, trace.T, trace.dt)
        X = np.concatenate((X, X_eta[selection, :]), axis=1)

        # Build Y vector (voltage derivative)

        # COULD BE A BETTER SOLUTION IN CASE OF EXPERIMENTAL DATA (NOT CLEAR WHY)
        # Y = np.array( np.concatenate( ([0], np.diff(trace.V)/trace.dt) ) )[selection]

        # CORRECT SOLUTION TO FIT ARTIFICIAL DATA
        Y = np.array(np.concatenate((np.diff(trace.V) / trace.dt, [0])))[selection]

        return (X, Y)

    ########################################################################################################
    # FUNCTIONS RELATED TO FIT FIRING THRESHOLD PARAMETERS (step 3)
    ########################################################################################################

    def fitStaticThreshold(self, experiment):

        self.setDt(experiment.dt)

        # Start by fitting a constant firing threshold, the result is used as initial condition to fit dynamic threshold

        print "\nGIF MODEL - Fit static threshold...\n"

        # Define initial conditions (based on the average firing rate in the training set)

        nbSpikes = 0
        duration = 0

        for tr in experiment.trainingset_traces:

            if tr.useTrace:

                nbSpikes += tr.getSpikeNb_inROI()
                duration += tr.getTraceLength_inROI()

        mean_firingrate = 1000.0 * nbSpikes / duration

        self.lambda0 = 1.0
        self.DV = 50.0
        self.Vt_star = -np.log(mean_firingrate) * self.DV

        # Perform fit
        beta0_staticThreshold = [1 / self.DV, -self.Vt_star / self.DV]
        beta_opt = self.maximizeLikelihood(experiment, beta0_staticThreshold, self.buildXmatrix_staticThreshold)

        # Store result
        self.DV = 1.0 / beta_opt[0]
        self.Vt_star = -beta_opt[1] * self.DV
        self.gamma.setFilter_toZero()

        self.printParameters()

    def fitThresholdDynamics(self, experiment):

        self.setDt(experiment.dt)

        # Fit a dynamic threshold using a initial condition the result obtained by fitting a static threshold

        print "\nGIF MODEL - Fit dynamic threshold...\n"

        # Perform fit
        beta0_dynamicThreshold = np.concatenate(
            ([1 / self.DV], [-self.Vt_star / self.DV], self.gamma.getCoefficients() / self.DV)
        )
        beta_opt = self.maximizeLikelihood(experiment, beta0_dynamicThreshold, self.buildXmatrix_dynamicThreshold)

        # Store result
        self.DV = 1.0 / beta_opt[0]
        self.Vt_star = -beta_opt[1] * self.DV
        self.gamma.setFilter_Coefficients(-beta_opt[2:] * self.DV)

        self.printParameters()

    def maximizeLikelihood(self, experiment, beta0, buildXmatrix, maxIter=10 ** 3, stopCond=10 ** -6):

        """
        Maximize likelihood. This function can be used to fit any model of the form lambda=exp(Xbeta).
        Here this function is used to fit both:
        - static threshold
        - dynamic threshold
        The difference between the two functions is in the size of beta0 and the returned beta, as well
        as the function buildXmatrix.
        """

        # Precompute all the matrices used in the gradient ascent

        all_X = []
        all_X_spikes = []
        all_sum_X_spikes = []

        T_tot = 0.0
        N_spikes_tot = 0.0

        traces_nb = 0

        for tr in experiment.trainingset_traces:

            if tr.useTrace:

                traces_nb += 1

                # Simulate subthreshold dynamics
                (time, V_est, eta_sum_est) = self.simulateDeterministic_forceSpikes(tr.I, tr.V[0], tr.getSpikeTimes())

                # Precomputes matrices to perform gradient ascent on log-likelihood
                (X_tmp, X_spikes_tmp, sum_X_spikes_tmp, N_spikes, T) = buildXmatrix(tr, V_est)

                T_tot += T
                N_spikes_tot += N_spikes

                all_X.append(X_tmp)
                all_X_spikes.append(X_spikes_tmp)
                all_sum_X_spikes.append(sum_X_spikes_tmp)

        logL_poisson = N_spikes_tot * (np.log(N_spikes_tot / T_tot) - 1)

        # Perform gradient ascent

        print "Maximize log-likelihood (bit/spks)..."

        beta = beta0
        old_L = 1

        for i in range(maxIter):

            learning_rate = 1.0

            if (
                i <= 10
            ):  # be careful in the first iterations (using a small learning rate in the first step makes the fit more stable)
                learning_rate = 0.1

            L = 0
            G = 0
            H = 0

            for trace_i in np.arange(traces_nb):
                (L_tmp, G_tmp, H_tmp) = self.computeLikelihoodGradientHessian(
                    beta, all_X[trace_i], all_X_spikes[trace_i], all_sum_X_spikes[trace_i]
                )
                L += L_tmp
                G += G_tmp
                H += H_tmp

            beta = beta - learning_rate * np.dot(inv(H), G)

            if i > 0 and abs((L - old_L) / old_L) < stopCond:  # If converged
                print "\nConverged after %d iterations!\n" % (i + 1)
                break

            old_L = L

            # Compute normalized likelihood (for print)
            # The likelihood is normalized with respect to a poisson process and units are in bit/spks
            L_norm = (L - logL_poisson) / np.log(2) / N_spikes_tot
            reprint(L_norm)

        if i == maxIter - 1:  # If too many iterations
            print "\nNot converged after %d iterations.\n" % (maxIter)

        return beta

    def computeLikelihoodGradientHessian(self, beta, X, X_spikes, sum_X_spikes):

        # IMPORTANT: in general we assume that the lambda_0 = 1 Hz
        # The parameter lambda0 is redundant with Vt_star, so only one of those has to be fitted.
        # We genearlly fix lambda_0 adn fit Vt_star

        dt = self.dt / 1000.0  # put dt in units of seconds (to be consistent with lambda_0)

        X_spikesbeta = np.dot(X_spikes, beta)
        Xbeta = np.dot(X, beta)
        expXbeta = np.exp(Xbeta)

        # Compute likelihood (would be nice to improve this to get a normalized likelihood)
        L = sum(X_spikesbeta) - self.lambda0 * dt * sum(expXbeta)

        # Compute gradient
        G = sum_X_spikes - self.lambda0 * dt * np.dot(np.transpose(X), expXbeta)

        # Compute Hessian
        H = -self.lambda0 * dt * np.dot(np.transpose(X) * expXbeta, X)

        return (L, G, H)

    def buildXmatrix_staticThreshold(self, tr, V_est):

        """
        Use this function to fit a model in which the firing threshold dynamics is defined as:
        V_T(t) = Vt_star (i.e., no spike-triggered movement of the firing threshold)
        """
        # Get indices be removing absolute refractory periods (-self.dt is to not include the time of spike)
        selection = tr.getROI_FarFromSpikes(-self.dt, self.Tref)
        T_l_selection = len(selection)

        # Get spike indices in coordinates of selection
        spk_train = tr.getSpikeTrain()
        spks_i_afterselection = np.where(spk_train[selection] == 1)[0]

        # Compute average firing rate used in the fit
        T_l = T_l_selection * tr.dt / 1000.0  # Total duration of trace used for fit (in s)
        N_spikes = len(spks_i_afterselection)  # Nb of spikes in the trace used for fit

        # Define X matrix
        X = np.zeros((T_l_selection, 2))
        X[:, 0] = V_est[selection]
        X[:, 1] = np.ones(T_l_selection)

        X_spikes = X[spks_i_afterselection, :]

        sum_X_spikes = np.sum(X_spikes, axis=0)

        return (X, X_spikes, sum_X_spikes, N_spikes, T_l)

    def buildXmatrix_dynamicThreshold(self, tr, V_est):

        """
        Use this function to fit a model in which the firing threshold dynamics is defined as:
        V_T(t) = Vt_star + sum_i gamma(t-\hat t_i) (i.e., model with spike-triggered movement of the threshold)
        """

        # Get indices be removing absolute refractory periods (-self.dt is to not include the time of spike)
        selection = tr.getROI_FarFromSpikes(-tr.dt, self.Tref)
        T_l_selection = len(selection)

        # Get spike indices in coordinates of selection
        spk_train = tr.getSpikeTrain()
        spks_i_afterselection = np.where(spk_train[selection] == 1)[0]

        # Compute average firing rate used in the fit
        T_l = T_l_selection * tr.dt / 1000.0  # Total duration of trace used for fit (in s)
        N_spikes = len(spks_i_afterselection)  # Nb of spikes in the trace used for fit

        # Define X matrix
        X = np.zeros((T_l_selection, 2))
        X[:, 0] = V_est[selection]
        X[:, 1] = np.ones(T_l_selection)

        # Compute and fill the remaining columns associated with the spike-triggered current gamma
        X_gamma = self.gamma.convolution_Spiketrain_basisfunctions(tr.getSpikeTimes() + self.Tref, tr.T, tr.dt)
        X = np.concatenate((X, X_gamma[selection, :]), axis=1)

        # Precompute other quantities
        X_spikes = X[spks_i_afterselection, :]
        sum_X_spikes = np.sum(X_spikes, axis=0)

        return (X, X_spikes, sum_X_spikes, N_spikes, T_l)

    ########################################################################################################
    # PLOT AND PRINT FUNCTIONS
    ########################################################################################################

    def plotParameters(self):
        """
        Plot the membrane filter, the kernel eta and the kernel gamma of the model in three panels.

        The membrane filter is computed from C and gl; eta and gamma are interpolated with self.dt.
        The figure is displayed with plt.show().
        """

        def draw_panel(support, values, ylabel):
            # Curve in red, dotted line at zero, x-axis restricted to the support
            plt.plot(support, values, color="red", lw=2)
            plt.plot([support[0], support[-1]], [0, 0], ls=":", color="black", lw=2)

            plt.xlim([support[0], support[-1]])
            plt.xlabel("Time (ms)")
            plt.ylabel(ylabel)

        plt.figure(facecolor="white", figsize=(14, 4))

        # Left panel: membrane filter kappa
        plt.subplot(1, 3, 1)
        kappa_support = np.linspace(0, 150.0, 300)
        kappa = 1.0 / self.C * np.exp(-kappa_support / (self.C / self.gl))
        draw_panel(kappa_support, kappa, "Membrane filter (MOhm/ms)")

        # Middle panel: eta
        plt.subplot(1, 3, 2)
        (eta_support, eta) = self.eta.getInterpolatedFilter(self.dt)
        draw_panel(eta_support, eta, "Eta (nA)")

        # Right panel: gamma
        plt.subplot(1, 3, 3)
        (gamma_support, gamma) = self.gamma.getInterpolatedFilter(self.dt)
        draw_panel(gamma_support, gamma, "Gamma (mV)")

        plt.subplots_adjust(left=0.05, bottom=0.15, right=0.95, top=0.92, wspace=0.35, hspace=0.25)

        plt.show()

    def printParameters(self):

        print "\n-------------------------"
        print "GIF model parameters:"
        print "-------------------------"
        print "tau_m (ms):\t%0.3f" % (self.C / self.gl)
        print "R (MOhm):\t%0.3f" % (1.0 / self.gl)
        print "C (nF):\t\t%0.3f" % (self.C)
        print "gl (nS):\t%0.3f" % (self.gl)
        print "El (mV):\t%0.3f" % (self.El)
        print "Tref (ms):\t%0.3f" % (self.Tref)
        print "Vr (mV):\t%0.3f" % (self.Vr)
        print "Vt* (mV):\t%0.3f" % (self.Vt_star)
        print "DV (mV):\t%0.3f" % (self.DV)
        print "-------------------------\n"

    @classmethod
    def compareModels(cls, GIFs, labels=None):

        """
        Given a list of GIF models, GIFs, print their parameters and produce a plot in which the
        model parameters are compared.

        GIFs   : list of GIF models
        labels : optional list with one label per model, used in the printout and in the legend;
                 when omitted, models are printed with an empty label and no legend is drawn

        Four panels are drawn: membrane filter, spike-triggered current eta, escape rate and
        spike-triggered threshold movement gamma. The figure is displayed with plt.show().
        """

        # PRINT PARAMETERS

        print("\n#####################################")
        print("GIF model comparison")
        print("#####################################\n")

        for cnt, model in enumerate(GIFs):

            # labels is optional: fall back to an empty label, as done for the legend below
            print("Model: " + ("" if labels is None else labels[cnt]))
            model.printParameters()

        print("#####################################\n")

        # PLOT PARAMETERS
        plt.figure(facecolor="white", figsize=(9, 8))

        # One color per model
        colors = plt.cm.jet(np.linspace(0.7, 1.0, len(GIFs)))

        # Membrane filter
        plt.subplot(2, 2, 1)

        # The support does not depend on the model
        K_support = np.linspace(0, 150.0, 1500)

        for cnt, model in enumerate(GIFs):

            K = 1.0 / model.C * np.exp(-K_support / (model.C / model.gl))
            plt.plot(K_support, K, color=colors[cnt], lw=2)

        plt.plot([K_support[0], K_support[-1]], [0, 0], ls=":", color="black", lw=2, zorder=-1)

        plt.xlim([K_support[0], K_support[-1]])
        plt.xlabel("Time (ms)")
        plt.ylabel("Membrane filter (MOhm/ms)")

        # Spike triggered current
        plt.subplot(2, 2, 2)

        for cnt, model in enumerate(GIFs):

            label_tmp = "" if labels is None else labels[cnt]

            (eta_support, eta) = model.eta.getInterpolatedFilter(0.1)
            plt.plot(eta_support, eta, color=colors[cnt], lw=2, label=label_tmp)

        # The axis limits below use the support of the last model plotted
        plt.plot([eta_support[0], eta_support[-1]], [0, 0], ls=":", color="black", lw=2, zorder=-1)

        if labels is not None:
            plt.legend()

        plt.xlim([eta_support[0], eta_support[-1]])
        plt.xlabel("Time (ms)")
        plt.ylabel("Eta (nA)")

        # Escape rate
        plt.subplot(2, 2, 3)

        for cnt, model in enumerate(GIFs):

            V_support = np.linspace(model.Vt_star - 5 * model.DV, model.Vt_star + 10 * model.DV, 1000)
            escape_rate = model.lambda0 * np.exp((V_support - model.Vt_star) / model.DV)
            plt.plot(V_support, escape_rate, color=colors[cnt], lw=2)

        plt.ylim([0, 100])
        plt.plot([V_support[0], V_support[-1]], [0, 0], ls=":", color="black", lw=2, zorder=-1)

        plt.xlim([V_support[0], V_support[-1]])
        plt.xlabel("Membrane potential (mV)")
        plt.ylabel("Escape rate (Hz)")

        # Spike triggered threshold movement
        plt.subplot(2, 2, 4)

        for cnt, model in enumerate(GIFs):

            (gamma_support, gamma) = model.gamma.getInterpolatedFilter(0.1)
            plt.plot(gamma_support, gamma, color=colors[cnt], lw=2)

        plt.plot([gamma_support[0], gamma_support[-1]], [0, 0], ls=":", color="black", lw=2, zorder=-1)

        plt.xlim([gamma_support[0] + 0.1, gamma_support[-1]])
        plt.xlabel("Time (ms)")
        plt.ylabel("Gamma (mV)")

        plt.subplots_adjust(left=0.08, bottom=0.10, right=0.95, top=0.93, wspace=0.25, hspace=0.25)

        plt.show()

    @classmethod
    def plotAverageModel(cls, GIFs):

        """
        Average model parameters and plot summary data.

        GIFs: iterable of fitted models. Each model must expose the attributes gl, C, El, Vr,
        Vt_star and DV, and the filters eta and gamma (objects providing getInterpolatedFilter).

        Top row of the figure: membrane filter, spike-triggered current (eta) and spike-triggered
        threshold movement (gamma). Individual models are drawn in dark gray, the across-model mean
        in red and the mean +/- one standard deviation as a gray band.
        Bottom half of the figure: histograms across models of R=1/gl, tau_m=C/gl, El, Vr, Vt_star and DV.

        The mean Vr is also printed. The figure is created but plt.show() is not called here.

        Raises ValueError if GIFs is empty.
        """

        # Materialize the input once: it is iterated many times below.
        models = list(GIFs)

        if len(models) == 0:
            raise ValueError("plotAverageModel requires at least one model.")

        #######################################################################################################
        # PLOT PARAMETERS
        #######################################################################################################

        fig = plt.figure(facecolor="white", figsize=(16, 7))
        fig.subplots_adjust(left=0.07, bottom=0.08, right=0.95, top=0.90, wspace=0.35, hspace=0.5)
        rcParams["xtick.direction"] = "out"
        rcParams["ytick.direction"] = "out"

        def membrane_filter(model):
            # Exponential membrane filter 1/C * exp(-t/tau_m), with tau_m = C/gl, on a fixed 0-150 support
            support = np.linspace(0, 150.0, 300)
            return (support, 1.0 / model.C * np.exp(-support / (model.C / model.gl)))

        def plot_filter_panel(panel, get_filter, ylabel, xmax_divisor=None):
            # Draw one panel of the top row: every model's curve plus mean and mean +/- std.
            # NOTE: averaging assumes that all models return curves of the same length.
            plt.subplot(2, 4, panel)

            curves = []
            for model in models:
                (support, curve) = get_filter(model)
                plt.plot(support, curve, color="0.3", lw=1, zorder=5)
                curves.append(curve)

            curve_mean = np.mean(curves, axis=0)
            curve_std = np.std(curves, axis=0)

            # The support of the last model is used as x axis for the summary curves
            plt.fill_between(support, curve_mean + curve_std, y2=curve_mean - curve_std, color="gray", zorder=0)
            plt.plot(support, curve_mean, color="red", lw=2, zorder=10)
            plt.plot([support[0], support[-1]], [0, 0], ls=":", color="black", lw=2, zorder=-1)

            Tools.removeAxis(plt.gca(), ["top", "right"])

            x_max = support[-1]
            if xmax_divisor is not None:
                x_max = x_max / xmax_divisor
            plt.xlim([support[0], x_max])

            plt.xlabel("Time (ms)")
            plt.ylabel(ylabel)

        def plot_histogram_panel(panel, values, xlabel):
            # Draw one histogram of a scalar parameter across models (bottom half of the figure)
            plt.subplot(4, 6, panel)
            plt.hist(values, histtype="bar", color="red", ec="white", lw=2)
            plt.xlabel(xlabel)
            Tools.removeAxis(plt.gca(), ["top", "left", "right"])
            plt.yticks([])

        # FILTERS (membrane filter, eta, gamma)
        #######################################################################################################

        plot_filter_panel(1, membrane_filter, "Membrane filter (MOhm/ms)")

        # Only the first tenth of the eta support is displayed
        plot_filter_panel(2, lambda model: model.eta.getInterpolatedFilter(0.1),
                          "Spike-triggered current (nA)", xmax_divisor=10.0)

        plot_filter_panel(3, lambda model: model.gamma.getInterpolatedFilter(0.1),
                          "Spike-triggered threshold (mV)")

        # SCALAR PARAMETERS (R, tau_m, El, Vr, Vt*, DV)
        #######################################################################################################

        print("Mean Vr (mV): %0.1f" % (np.mean([model.Vr for model in models])))

        # (subplot index in the 4x6 grid, x label, function extracting the parameter from a model)
        scalar_panels = [
            (12 + 1, "R (MOhm)", lambda model: 1.0 / model.gl),
            (18 + 1, "tau_m (ms)", lambda model: model.C / model.gl),
            (12 + 2, "El (mV)", lambda model: model.El),
            (18 + 2, "Vr (mV)", lambda model: model.Vr),
            (12 + 3, "Vt_star (mV)", lambda model: model.Vt_star),
            (18 + 3, "DV (mV)", lambda model: model.DV),
        ]

        for (panel, xlabel, get_value) in scalar_panels:
            plot_histogram_panel(panel, [get_value(model) for model in models], xlabel)
myExp.setAEC(myAEC_Dummy)  
myExp.performAEC()  
"""

############################################################################################################
# STEP 3A: FIT GIF WITH RECT BASIS FUNCTIONS TO DATA
############################################################################################################

# Instantiate the model (the constructor argument is the time step dt)
myGIF_rect = GIF(0.1)

# Absolute refractory period
myGIF_rect.Tref = 4.0

# Spike-triggered current eta: sum of log-spaced rectangular basis functions
myGIF_rect.eta = Filter_Rect_LogSpaced()
myGIF_rect.eta.setMetaParameters(
    length=5000.0,
    binsize_lb=2.0,
    binsize_ub=1000.0,
    slope=4.5
)

# Spike-triggered threshold movement gamma: same family of basis functions, different binning
myGIF_rect.gamma = Filter_Rect_LogSpaced()
myGIF_rect.gamma.setMetaParameters(
    length=5000.0,
    binsize_lb=5.0,
    binsize_ub=1000.0,
    slope=5.0
)

# Run the fitting procedure on the experiment
myGIF_rect.fit(myExp, DT_beforeSpike=5.0)


############################################################################################################
# STEP 3B: FIT GIF WITH EXP BASIS FUNCTIONS TO DATA
############################################################################################################

# Instantiate a second model for the exponential-basis fit
myGIF_exp = GIF(0.1)
Exemplo n.º 10
0
class GIF(ThresholdModel) :

    """
    Generalized Integrate and Fire model defined in Pozzorini et al. PLOS Comp. Biol. 2015
    
    Spikes are produced stochastically with firing intensity:
    
    lambda(t) = lambda0 * exp( (V(t)-V_T(t))/DV ),
    
    where the membrane potential dynamics is given by:
    
    C dV/dt = -gl(V-El) + I - sum_j eta(t-\hat t_j)
    
    and the firing threshold V_T is given by:
    
    V_T = Vt_star + sum_j gamma(t-\hat t_j)
    
    and \hat t_j denote the spike times.
    """

    def __init__(self, dt=0.1):

        """
        Build a GIF model with default parameter values.
        dt: time step used in simulations (eta and gamma are interpolated according to this value).
        """

        self.dt = dt

        # Passive membrane parameters
        # NOTE(review): the original comment gives gl in nS; uS would be consistent with C in nF
        # and C/gl in ms - confirm.
        self.gl = 1.0/100.0                 # leak conductance
        self.C = 20.0*self.gl               # nF, capacitance
        self.El = -65.0                     # mV, reversal potential

        # Spike emission and reset
        self.Vr = -50.0                     # mV, voltage reset
        self.Tref = 4.0                     # ms, absolute refractory period

        # Escape-rate parameters
        self.Vt_star = -48.0                # mV, steady state voltage threshold VT*
        self.DV = 0.5                       # mV, threshold sharpness
        self.lambda0 = 1.0                  # Hz, by default this parameter is always set to 1.0

        # Spike-triggered kernels (must be instances of class Filter)
        self.eta = Filter_Rect_LogSpaced()      # nA, spike-triggered current
        self.gamma = Filter_Rect_LogSpaced()    # mV, spike-triggered movement of the firing threshold

        # Default kernels: decaying exponentials with a 100 ms time constant
        self.eta.setFilter_Function(lambda x: 0.2*np.exp(-x/100.0))
        self.gamma.setFilter_Function(lambda x: 10.0*np.exp(-x/100.0))

        # Average spike shape and its support, filled in by fitVoltageReset
        self.avg_spike_shape = 0
        self.avg_spike_shape_support = 0
        
    
    
    def setDt(self, dt):

        """
        Store dt as the time step of the model.
        Simulations read self.dt both as integration step and as the resolution at which
        the filters eta and gamma are interpolated.
        """

        self.dt = dt

    
    ########################################################################################################
    # IMPLEMENT ABSTRACT METHODS OF Spiking model
    ########################################################################################################
    
    
    def simulateSpikingResponse(self, I, dt):

        """
        Run the GIF model on the input current I (nA) using time step dt and
        return only the spike times (in ms).
        The simulation starts from V(0)=El.
        """

        self.setDt(dt)

        simulation = self.simulate(I, self.El)
        (_time, _V, _eta_sum, _V_T, spike_times) = simulation

        return spike_times


    ########################################################################################################
    # IMPLEMENT ABSTRACT METHODS OF Threshold Model
    ########################################################################################################
    
    
    def simulateVoltageResponse(self, I, dt) :

        """
        Run the GIF model on the input current I using time step dt, starting from V(0)=El.
        Return the tuple (spike times, membrane potential V, firing threshold V_T)
        as produced by self.simulate.
        """

        self.setDt(dt)

        (_time, voltage, _eta_sum, threshold, spike_times) = self.simulate(I, self.El)

        return (spike_times, voltage, threshold)


    ########################################################################################################
    # METHODS FOR NUMERICAL SIMULATIONS
    ########################################################################################################  
      
    def simulate(self, I, V0):
 
        """
        Simulate the spiking response of the GIF model to an input current I (nA) with time step self.dt.
        V0 indicate the initial condition V(0)=V0.

        The integration loop is written in C and compiled at runtime with weave.inline.
        Spikes are drawn with the C function rand(); no seed is set in this method.

        The function returns:
        - time     : ms, support for V, eta_sum, V_T, spks
        - V        : mV, membrane potential
        - eta_sum  : nA, adaptation current
        - V_T      : mV, firing threshold
        - spks     : ms, array of spike times (sample indices multiplied by self.dt)
        """
 
        # Input parameters
        # NOTE: the names of the p_* locals below are referenced by name from the C code (see 'vars'),
        # they must not be renamed without updating the C source.
        p_T         = len(I)
        p_dt        = self.dt
        
        # Model parameters
        p_gl        = self.gl
        p_C         = self.C 
        p_El        = self.El
        p_Vr        = self.Vr
        p_Tref      = self.Tref
        p_Vt_star   = self.Vt_star
        p_DV        = self.DV
        p_lambda0   = self.lambda0
        
        # Model kernels   
        # Both kernels are sampled at the simulation time step; the supports are not used below.
        (p_eta_support, p_eta) = self.eta.getInterpolatedFilter(self.dt)   
        p_eta       = p_eta.astype('double')
        p_eta_l     = len(p_eta)

        (p_gamma_support, p_gamma) = self.gamma.getInterpolatedFilter(self.dt)   
        p_gamma     = p_gamma.astype('double')
        p_gamma_l   = len(p_gamma)
      
        # Define arrays
        # eta_sum and gamma_sum are padded with twice the kernel length so that a kernel added after
        # a late spike still fits in the buffer.
        # NOTE(review): the kernel is written starting at (spike index + Tref_ind + 1); the padding
        # looks sufficient only if Tref_ind does not exceed the kernel length - confirm.
        V = np.array(np.zeros(p_T), dtype="double")
        I = np.array(I, dtype="double")
        spks = np.array(np.zeros(p_T), dtype="double")                      
        eta_sum = np.array(np.zeros(p_T + 2*p_eta_l), dtype="double")
        gamma_sum = np.array(np.zeros(p_T + 2*p_gamma_l), dtype="double")            
 
        # Set initial condition
        V[0] = V0
         
        # C kernel: forward-Euler integration of V, stochastic spike emission, and after each spike
        # a jump of Tref_ind samples, a reset of V to Vr and the addition of the eta and gamma kernels.
        # The scalar parameters are converted to single precision floats inside the C code.
        # Samples of V skipped during the refractory jump are not written and keep their initial value 0.
        code =  """
                #include <math.h>
                
                int   T_ind      = int(p_T);                
                float dt         = float(p_dt); 
                
                float gl         = float(p_gl);
                float C          = float(p_C);
                float El         = float(p_El);
                float Vr         = float(p_Vr);
                int   Tref_ind   = int(float(p_Tref)/dt);
                float Vt_star    = float(p_Vt_star);
                float DeltaV     = float(p_DV);
                float lambda0    = float(p_lambda0);
           
                int eta_l        = int(p_eta_l);
                int gamma_l      = int(p_gamma_l);
                
                                                  
                float rand_max  = float(RAND_MAX); 
                float p_dontspike = 0.0 ;
                float lambda = 0.0 ;            
                float r = 0.0;
                
                                                
                for (int t=0; t<T_ind-1; t++) {
    
    
                    // INTEGRATE VOLTAGE
                    V[t+1] = V[t] + dt/C*( -gl*(V[t] - El) + I[t] - eta_sum[t] );
               
               
                    // COMPUTE PROBABILITY OF EMITTING ACTION POTENTIAL
                    lambda = lambda0*exp( (V[t+1]-Vt_star-gamma_sum[t])/DeltaV );
                    p_dontspike = exp(-lambda*(dt/1000.0));                                  // since lambda0 is in Hz, dt must also be in Hz (this is why dt/1000.0)
                          
                          
                    // PRODUCE SPIKE STOCHASTICALLY
                    r = rand()/rand_max;
                    if (r > p_dontspike) {
                                        
                        if (t+1 < T_ind-1)                
                            spks[t+1] = 1.0; 
                        
                        t = t + Tref_ind;    
                        
                        if (t+1 < T_ind-1) 
                            V[t+1] = Vr;
                        
                        
                        // UPDATE ADAPTATION PROCESSES     
                        for(int j=0; j<eta_l; j++) 
                            eta_sum[t+1+j] += p_eta[j]; 
                        
                        for(int j=0; j<gamma_l; j++) 
                            gamma_sum[t+1+j] += p_gamma[j] ;  
                        
                    }
               
                }
                
                """
 
        # Names of the Python variables made available to the C code
        vars = [ 'p_T','p_dt','p_gl','p_C','p_El','p_Vr','p_Tref','p_Vt_star','p_DV','p_lambda0','V','I','p_eta','p_eta_l','eta_sum','p_gamma','gamma_sum','p_gamma_l','spks' ]
        
        # The arrays V, spks, eta_sum and gamma_sum are filled in place; the return value v is not used
        v = weave.inline(code, vars)

        time = np.arange(p_T)*self.dt
        
        # Drop the padding and turn the threshold movement into the absolute threshold
        eta_sum   = eta_sum[:p_T]     
        V_T = gamma_sum[:p_T] + p_Vt_star
     
        # Convert the 0/1 spike indicator into spike times in ms
        spks = (np.where(spks==1)[0])*self.dt
    
        return (time, V, eta_sum, V_T, spks)

        
    def simulateDeterministic_forceSpikes(self, I, V0, spks):

        """
        Simulate the subthreshold response of the GIF model to an input current I (nA) with time step self.dt.
        Adaptation currents are forced to occur at the times specified in the list spks (in ms) given as
        an argument to the function. An empty list of spikes is allowed.
        V0 indicate the initial condition V(t=0)=V0.

        The integration loop is written in C and compiled at runtime with weave.inline.

        The function returns:

        - time     : ms, support for V, eta_sum
        - V        : mV, membrane potential
        - eta_sum  : nA, adaptation current
        """

        # Input parameters
        # NOTE: the p_* names listed in 'vars' below are referenced by name from the C code.
        p_T          = len(I)
        p_dt         = self.dt

        # Model parameters
        p_gl        = self.gl
        p_C         = self.C
        p_El        = self.El
        p_Vr        = self.Vr
        p_Tref      = self.Tref
        p_Tref_i    = int(self.Tref/self.dt)

        # Model kernel (the support is not needed)
        (_, p_eta)  = self.eta.getInterpolatedFilter(self.dt)
        p_eta       = p_eta.astype('double')
        p_eta_l     = len(p_eta)

        # Define arrays
        V        = np.array(np.zeros(p_T), dtype="double")
        I        = np.array(I, dtype="double")
        spks     = np.array(spks, dtype="double")
        spks_i   = Tools.timeToIndex(spks, self.dt)

        # Number of forced spikes: passed to the C code so that spks_i is never read out of bounds
        p_spks_n = len(spks_i)

        # Compute adaptation current (sum of eta triggered at spike times in spks, shifted by Tref).
        # The buffer is padded so that a kernel added after a late spike still fits; the size must be
        # an integer (a float size is rejected by recent versions of numpy).
        eta_sum  = np.zeros(int(p_T + 1.1*p_eta_l + p_Tref_i), dtype="double")

        for s in spks_i :
            eta_sum[s + 1 + p_Tref_i  : s + 1 + p_Tref_i + p_eta_l] += p_eta

        eta_sum  = eta_sum[:p_T]

        # Set initial condition
        V[0] = V0

        # C kernel: forward-Euler integration of V; Tref after each forced spike the voltage is
        # reset to Vr and the integration restarts from there.
        code = """
                #include <math.h>

                int   T_ind      = int(p_T);
                float dt         = float(p_dt);

                float gl         = float(p_gl);
                float C          = float(p_C);
                float El         = float(p_El);
                float Vr         = float(p_Vr);
                int   Tref_ind   = int(float(p_Tref)/dt);
                int   spks_n     = int(p_spks_n);

                // Index at which the next reset is applied (-1 when no spike is left)
                int spks_cnt   = 0;
                int next_spike = (spks_n > 0) ? spks_i[0] + Tref_ind : -1;

                for (int t=0; t<T_ind-1; t++) {

                    // INTEGRATE VOLTAGE
                    V[t+1] = V[t] + dt/C*( -gl*(V[t] - El) + I[t] - eta_sum[t] );

                    if ( t == next_spike ) {
                        spks_cnt = spks_cnt + 1;
                        next_spike = (spks_cnt < spks_n) ? spks_i[spks_cnt] + Tref_ind : -1;

                        if (t > 0)
                            V[t-1] = 0 ;

                        V[t] = Vr ;
                        t=t-1;
                    }

                }

                """

        # Names of the Python variables made available to the C code
        vars = [ 'p_T','p_dt','p_gl','p_C','p_El','p_Vr','p_Tref','p_spks_n','V','I','eta_sum','spks_i' ]

        # V is filled in place by the C code
        weave.inline(code, vars)

        time = np.arange(p_T)*self.dt

        return (time, V, eta_sum)

           
    ########################################################################################################
    # METHODS FOR MODEL FITTING
    ########################################################################################################  
      
         
    def fit(self, experiment, DT_beforeSpike = 5.0):

        """
        Extract all the GIF model parameters from the experimental data stored in the object experiment.
        DT_beforeSpike (in ms) is forwarded to the fit of the subthreshold dynamics, where it defines
        the region that is cut before each spike.
        The fit of the voltage reset uses the current value of self.Tref as absolute refractory period.
        Only training set traces in experiment are used to perform the fit.
        """

        for banner_line in ("\n################################",
                            "# Fit GIF",
                            "################################\n"):
            print(banner_line)

        # Voltage reset, given the absolute refractory period
        self.fitVoltageReset(experiment, self.Tref, do_plot=False)

        # Subthreshold dynamics of the membrane potential
        self.fitSubthresholdDynamics(experiment, DT_beforeSpike=DT_beforeSpike)

        # Firing threshold: first a constant threshold, then its spike-triggered dynamics
        self.fitStaticThreshold(experiment)

        self.fitThresholdDynamics(experiment)



    ########################################################################################################
    # FIT VOLTAGE RESET GIVEN ABSOLUTE REFRACTORY PERIOD (step 1)
    ########################################################################################################


    def fitVoltageReset(self, experiment, Tref, do_plot=False):

        """
        Implement Step 1 of the fitting procedure introduced in Pozzorini et al. PLOS Comp. Biol. 2015
        experiment: Experiment object on which the model is fitted.
        Tref: ms, absolute refractory period.
        do_plot: if True, plot the average spike shape together with the estimated reset.

        The voltage reset Vr is read from the average spike shape at the first sample of its
        support that is >= Tref. The average spike shape is the unweighted mean, across the selected
        training set traces that contain spikes, of the shapes returned by computeAverageSpikeShape().

        Side effects: sets self.dt (to experiment.dt), self.Tref, self.Vr, self.avg_spike_shape and
        self.avg_spike_shape_support.

        Raises ValueError if no selected training set trace contains spikes.
        """

        print("Estimate voltage reset (Tref = %0.1f ms)..." % (Tref))

        # Fix absolute refractory period
        self.dt = experiment.dt
        self.Tref = Tref

        all_spike_average = []
        all_spike_nb = 0
        support = None

        for tr in experiment.trainingset_traces :

            # Traces that are not selected or that do not contain any spike are ignored
            if tr.useTrace and len(tr.spks) > 0 :
                (support, spike_average, spike_nb) = tr.computeAverageSpikeShape()
                all_spike_average.append(spike_average)
                all_spike_nb += spike_nb

        # Without any spike there is no spike shape (and no support) to read the reset from
        if support is None :
            raise ValueError("Cannot estimate the voltage reset: no selected training set trace contains spikes.")

        spike_average = np.mean(all_spike_average, axis=0)

        # Estimate voltage reset (first sample of the support located at or after Tref)
        Tref_ind = np.where(support >= self.Tref)[0][0]
        self.Vr = spike_average[Tref_ind]

        # Save average spike shape
        self.avg_spike_shape = spike_average
        self.avg_spike_shape_support = support

        if do_plot :
            plt.figure()
            plt.plot(support, spike_average, 'black')
            plt.plot([support[Tref_ind]], [self.Vr], '.', color='red')
            plt.show()

        print("Done! Vr = %0.2f mV (computed on %d spikes)" % (self.Vr, all_spike_nb))
        


    ########################################################################################################
    # FUNCTIONS RELATED TO FIT OF SUBTHRESHOLD DYNAMICS (step 2)
    ########################################################################################################


    def fitSubthresholdDynamics(self, experiment, DT_beforeSpike=5.0):

        """
        Implement Step 2 of the fitting procedure introduced in Pozzorini et al. PLOS Comp. Biol. 2015
        The parameters C, gl, El and the coefficients of the spike-triggered current eta are estimated
        with a single multilinear regression on the voltage derivative, using all the selected
        training set traces.
        experiment: Experiment object on which the model is fitted.
        DT_beforeSpike: in ms, data right before spikes are excluded from the fit. This parameter can be used to define that time interval.

        Side effects: sets self.dt (to experiment.dt), self.C, self.gl, self.El and the coefficients of
        self.eta; prints the fitted parameters and the percentage of variance explained on dV/dt and on V.

        Raises ValueError if no training set trace is selected.
        """

        print("\nGIF MODEL - Fit subthreshold dynamics...")

        # Expand eta in basis functions
        self.dt = experiment.dt


        # Build X matrix and Y vector to perform linear regression (use all traces in training set)
        # For each training set an X matrix and a Y vector is built.
        ####################################################################################################
        X = []
        Y = []

        cnt = 0

        for tr in experiment.trainingset_traces :

            if tr.useTrace :

                cnt += 1
                reprint( "Compute X matrix for repetition %d" % (cnt) )

                # Compute the X matrix and the Y=dV/dt vector used to perform the multilinear regression (see Eq. 17-18 in Pozzorini et al. PLOS Comp. Biol. 2015)
                (X_tmp, Y_tmp) = self.fitSubthresholdDynamics_Build_Xmatrix_Yvector(tr, DT_beforeSpike=DT_beforeSpike)

                X.append(X_tmp)
                Y.append(Y_tmp)


        # Concatenate matrixes associated with different traces to perform a single multilinear regression
        ####################################################################################################
        if cnt == 0 :
            # Fail here with a clear message instead of crashing later inside the regression
            raise ValueError("Error, at least one training set trace should be selected to perform fit.")

        X = np.concatenate(X, axis=0)
        Y = np.concatenate(Y, axis=0)


        # Perform linear Regression defined in Eq. 17 of Pozzorini et al. PLOS Comp. Biol. 2015
        ####################################################################################################

        print("\nPerform linear regression...")
        XTX     = np.dot(np.transpose(X), X)
        XTX_inv = inv(XTX)
        XTY     = np.dot(np.transpose(X), Y)
        b       = np.dot(XTX_inv, XTY)
        b       = b.flatten()


        # Extract explicit model parameters from regression result b
        # (columns of X: V, I, constant, then one column per basis function of eta)
        ####################################################################################################

        self.C  = 1./b[1]
        self.gl = -b[0]*self.C
        self.El = b[2]*self.C/self.gl
        self.eta.setFilter_Coefficients(-b[3:]*self.C)


        self.printParameters()


        # Compute percentage of variance explained on dV/dt
        ####################################################################################################

        var_explained_dV = 1.0 - np.mean((Y - np.dot(X,b))**2)/np.var(Y)
        print("Percentage of variance explained (on dV/dt): %0.2f" % (var_explained_dV*100.0))


        # Compute percentage of variance explained on V (see Eq. 26 in Pozzorini et al. PLOS Comp. Biol. 2015)
        ####################################################################################################

        SSE = 0     # sum of squared errors
        VAR = 0     # variance of data

        for tr in experiment.trainingset_traces :

            if tr.useTrace :

                # Simulate subthreshold dynamics (spikes are forced at the experimental spike times)
                (time, V_est, eta_sum_est) = self.simulateDeterministic_forceSpikes(tr.I, tr.V[0], tr.getSpikeTimes())

                # Only the samples returned by getROI_FarFromSpikes(0.0, Tref) enter the comparison
                indices_tmp = tr.getROI_FarFromSpikes(0.0, self.Tref)

                SSE += np.sum((V_est[indices_tmp] - tr.V[indices_tmp])**2)
                VAR += len(indices_tmp)*np.var(tr.V[indices_tmp])

        var_explained_V = 1.0 - SSE / VAR

        print("Percentage of variance explained (on V): %0.2f" % (var_explained_V*100.0))
                
                    
    def fitSubthresholdDynamics_Build_Xmatrix_Yvector(self, trace, DT_beforeSpike=5.0):

        """
        Compute the X matrix and the Y vector (i.e. the voltage derivative dV/dt of the data) used to
        perform the linear regression defined in Eq. 17-18 of Pozzorini et al. 2015 for an individual
        experimental trace provided as parameter.
        The input parameter trace is an object of class Trace.
        DT_beforeSpike: in ms, forwarded to trace.getROI_FarFromSpikes to select the samples used.

        Returns (X, Y), where each row of X corresponds to one selected sample and its columns are:
        V, I, a constant 1, followed by the columns returned by the convolution of the spike train
        (shifted by Tref) with the basis functions of eta.
        """

        # Select region where to perform linear regression (specified in the ROI of individual traces)
        ####################################################################################################
        selection = trace.getROI_FarFromSpikes(DT_beforeSpike, self.Tref)
        selection_l = len(selection)


        # Compute the columns associated with the spike-triggered current eta
        ####################################################################################################
        X_eta = self.eta.convolution_Spiketrain_basisfunctions(trace.getSpikeTimes() + self.Tref, trace.T, trace.dt)


        # Build X matrix for linear regression (see Eq. 18 in Pozzorini et al. PLOS Comp. Biol. 2015)
        ####################################################################################################
        X = np.column_stack( (trace.V[selection],
                              trace.I[selection],
                              np.ones(selection_l),
                              X_eta[selection,:]) )


        # Build Y vector (voltage derivative dV/dt of the data)
        # A zero is appended so that the derivative has the same length as the trace
        ####################################################################################################
        dV_dt = np.concatenate( (np.diff(trace.V)/trace.dt, [0]) )
        Y = np.array(dV_dt)[selection]

        return (X, Y)
        
        
        
    ########################################################################################################
    # FUNCTIONS RELATED TO FIT FIRING THRESHOLD PARAMETERS (step 3)
    ########################################################################################################        
 
         
    def fitStaticThreshold(self, experiment):

        """
        Implement Step 3 of the fitting procedure introduced in Pozzorini et al. PLOS Comp. Biol. 2015
        with a constant firing threshold: only DV and Vt_star are fitted and the filter gamma is set to zero.
        The result can be used as a smart initial condition to fit the full GIF model (i.e., a model
        featuring a spike-triggered threshold movement gamma).
        experiment: Experiment object on which the model is fitted.
        """

        print("\nGIF MODEL - Fit static threshold...\n")

        self.setDt(experiment.dt)


        # Initial conditions: lambda0=1, DV=50 and Vt_star derived from the mean firing rate of the
        # selected training set traces
        ###############################################################################################

        total_spikes = 0
        total_duration = 0

        for trace in experiment.trainingset_traces :

            if not trace.useTrace :
                continue

            total_spikes += trace.getSpikeNb_inROI()
            total_duration += trace.getTraceLength_inROI()

        # The factor 1000 converts a rate per ms into Hz
        rate_hz = 1000.0*total_spikes/total_duration

        self.lambda0 = 1.0
        self.DV = 50.0
        self.Vt_star = -np.log(rate_hz)*self.DV


        # Maximum likelihood fit, parametrized as beta = [1/DV, -Vt_star/DV]
        ###############################################################################################

        initial_beta = [1/self.DV, -self.Vt_star/self.DV]
        fitted_beta = self.maximizeLikelihood(experiment, initial_beta, self.buildXmatrix_staticThreshold)


        # Convert the optimal beta back into model parameters
        ###############################################################################################

        fitted_DV = 1.0/fitted_beta[0]

        self.DV = fitted_DV
        self.Vt_star = -fitted_beta[1]*fitted_DV
        self.gamma.setFilter_toZero()

        self.printParameters()

   
    def fitThresholdDynamics(self, experiment):

        """
        Implement Step 3 of the fitting procedure introduced in Pozzorini et al. PLOS Comp. Biol. 2015
        Fit firing threshold dynamics by solving Eq. 20 using Newton method.

        experiment: Experiment object on which the model is fitted.

        The likelihood is maximized over beta = [1/DV, -Vt_star/DV, -gamma/DV], starting from the
        current values of self.DV, self.Vt_star and of the coefficients of self.gamma.
        Side effects: sets self.dt (through setDt), self.DV, self.Vt_star and the coefficients of
        self.gamma, then prints the parameters.
        """

        print("\nGIF MODEL - Fit dynamic threshold...\n")


        self.setDt(experiment.dt)


        # Perform maximum likelihood fit (Newton method)
        ###############################################################################################

        # Define initial conditions
        # The sign of the gamma terms must match the one used below to store the result
        # (gamma = -beta*DV), otherwise the optimization would start from the mirrored filter.
        beta0_dynamicThreshold = np.concatenate( ( [1/self.DV], [-self.Vt_star/self.DV], -self.gamma.getCoefficients()/self.DV))
        beta_opt = self.maximizeLikelihood(experiment, beta0_dynamicThreshold, self.buildXmatrix_dynamicThreshold)


        # Store result
        ###############################################################################################

        self.DV      = 1.0/beta_opt[0]
        self.Vt_star = -beta_opt[1]*self.DV
        self.gamma.setFilter_Coefficients(-beta_opt[2:]*self.DV)

        self.printParameters()
          
      
    def maximizeLikelihood(self, experiment, beta0, buildXmatrix, maxIter=10**3, stopCond=10**-6) :

        """
        Maximize likelihood. This function can be used to fit any model of the form lambda=exp(Xbeta).
        This function is used to fit both:
        - static threshold
        - dynamic threshold
        The difference between the two cases is in the size of beta0 and the returned beta, as well
        as the function buildXmatrix.

        experiment   : object whose attribute trainingset_traces is iterated; only traces with
                       useTrace set are included in the fit.
        beta0        : initial parameter vector.
        buildXmatrix : callable (tr, V_est) -> (X, X_spikes, sum_X_spikes, N_spikes, T).
        maxIter      : maximal number of Newton iterations.
        stopCond     : the iteration stops when the relative change of the log-likelihood between
                       two consecutive iterations is smaller than this value.

        Returns the parameter vector beta after the last Newton update (beta0 itself if maxIter is 0).
        """

        # NOTE (original author): this implementation is not ideal. A future version should move the
        # maximum-likelihood fit of lambda=exp(Xbeta) models into a dedicated class.

        # Precompute all the matrices used in the gradient ascent (see Eq. 20 in Pozzorini et al. 2015)
        ################################################################################################

        # Here X refers to the matrix made of the y vectors defined in Eq. 21 (Pozzorini et al. 2015).
        # Since the fit can be performed on multiple traces, one matrix per trace is stored.
        all_X        = []

        # Similar to X but only contains the temporal samples where experimental spikes have been observed
        all_X_spikes = []

        # Sum of X_spikes over spikes (precomputed because it is needed at each gradient evaluation)
        all_sum_X_spikes = []

        # Variables used to compute the log-likelihood of a Poisson process spiking at the experimental firing rate
        T_tot = 0.0
        N_spikes_tot = 0.0

        for tr in experiment.trainingset_traces:

            if not tr.useTrace :
                continue

            # Simulate subthreshold dynamics (only the voltage estimate is needed here)
            (_, V_est, _) = self.simulateDeterministic_forceSpikes(tr.I, tr.V[0], tr.getSpikeTimes())

            # Precompute the matrices used in the gradient ascent on the log-likelihood.
            # Depending on the model being fitted (static vs dynamic threshold) different buildXmatrix functions are used.
            (X_tmp, X_spikes_tmp, sum_X_spikes_tmp, N_spikes, T) = buildXmatrix(tr, V_est)

            T_tot        += T
            N_spikes_tot += N_spikes

            all_X.append(X_tmp)
            all_X_spikes.append(X_spikes_tmp)
            all_sum_X_spikes.append(sum_X_spikes_tmp)

        # Compute log-likelihood of a Poisson process (this quantity is used to normalize the model log-likelihood)
        ################################################################################################

        logL_poisson = N_spikes_tot*(np.log(N_spikes_tot/T_tot)-1)


        # Perform gradient ascent
        ################################################################################################

        print("Maximize log-likelihood (bit/spks)...")

        beta = beta0
        old_L = 1

        for i in range(maxIter) :

            # In the first iterations using a small learning rate makes things somehow more stable
            learning_rate = 0.1 if i<=10 else 1.0

            L=0; G=0; H=0

            for (X, X_spikes, sum_X_spikes) in zip(all_X, all_X_spikes, all_sum_X_spikes) :

                # Compute log-likelihood, gradient and Hessian on a specific trace
                (L_tmp,G_tmp,H_tmp) = self.computeLikelihoodGradientHessian(beta, X, X_spikes, sum_X_spikes)

                # Since differentiation is linear: gradient of sum = sum of gradients ; Hessian of sum = sum of Hessians
                L+=L_tmp
                G+=G_tmp
                H+=H_tmp

            # Newton step taking into account all traces.
            # Solving H*step = G avoids forming the explicit inverse of H.
            beta = beta - learning_rate*np.linalg.solve(H, G)

            if (i>0 and abs((L-old_L)/old_L) < stopCond) :              # If converged
                print("\nConverged after %d iterations!\n" % (i+1))
                break

            old_L = L

            # Compute normalized likelihood (for print)
            # The likelihood is normalized with respect to a Poisson process and units are in bit/spks
            L_norm = (L-logL_poisson)/np.log(2)/N_spikes_tot
            reprint(L_norm)

            if math.isnan(L_norm):
                print("Problem during gradient ascent. Optimizatino stopped.")
                break

        else :
            # Reached only when the loop was not left with break, ie all iterations were used
            # without meeting the stopping condition (or maxIter is 0).
            print("\nNot converged after %d iterations.\n" % (maxIter))


        return beta
     
        
    def computeLikelihoodGradientHessian(self, beta, X, X_spikes, sum_X_spikes) :

        """
        Compute the log-likelihood, its gradient and Hessian for a model whose
        log-likelihood has the same form as the one defined in Eq. 20 (Pozzorini et al. PLOS Comp. Biol. 2015).

        beta         : parameter vector.
        X            : matrix with one row per time sample used in the fit.
        X_spikes     : rows of X at which a spike was observed.
        sum_X_spikes : sum of X_spikes over its rows (precomputed by the caller).

        Returns the tuple (L, G, H): log-likelihood (scalar), gradient (one entry per parameter)
        and Hessian (square matrix).
        """

        # IMPORTANT: in general we assume that lambda_0 = 1 Hz.
        # The parameter lambda0 is redundant with Vt_star, so only one of those has to be fitted.
        # We generally fix lambda_0 and fit Vt_star.

        dt = self.dt/1000.0     # put dt in units of seconds (to be consistent with lambda_0)
        rate_scale = self.lambda0*dt

        X_spikesbeta    = np.dot(X_spikes,beta)
        expXbeta        = np.exp(np.dot(X,beta))

        # Compute log-likelihood defined in Eq. 20 Pozzorini et al. 2015.
        # np.sum is used instead of the builtin sum, which would loop over every time sample in Python.
        L = np.sum(X_spikesbeta) - rate_scale*np.sum(expXbeta)

        # Compute its gradient
        G = sum_X_spikes - rate_scale*np.dot(np.transpose(X), expXbeta)

        # Compute its Hessian (each row of X is weighted by the corresponding exp(X beta))
        H = -rate_scale*np.dot(np.transpose(X)*expXbeta, X)

        return (L,G,H)


    def buildXmatrix_staticThreshold(self, tr, V_est) :

        """
        Use this function to fit a model in which the firing threshold dynamics is defined as:
        V_T(t) = Vt_star (i.e., no spike-triggered movement of the firing threshold).
        This function computes the matrix X made of vectors y similar to the ones defined in Eq. 21 (Pozzorini et al. 2015).
        In contrast to Eq. 21, the X matrix computed here does not include the columns related to the spike-triggered threshold movement.

        tr    : trace on which the matrix is built (its dt, getROI_FarFromSpikes and getSpikeTrain are used).
        V_est : array with the estimated subthreshold voltage, indexed with the selected samples of tr.

        Returns the tuple (X, X_spikes, sum_X_spikes, N_spikes, T_l):
        X            : matrix with columns [V_est, 1] restricted to the selected samples,
        X_spikes     : rows of X at which a spike was emitted,
        sum_X_spikes : sum of X_spikes over spikes,
        N_spikes     : number of spikes in the selected samples,
        T_l          : total duration of the selected samples (in s).
        """

        # Get indices by removing absolute refractory periods (-tr.dt is to not include the time of spike).
        # The time step of the trace is used, consistently with buildXmatrix_dynamicThreshold and
        # with the duration computed below.
        selection = tr.getROI_FarFromSpikes(-tr.dt, self.Tref)
        T_l_selection  = len(selection)


        # Get spike indices in coordinates of selection
        spk_train = tr.getSpikeTrain()
        spks_i_afterselection = np.where(spk_train[selection]==1)[0]


        # Compute average firing rate used in the fit
        T_l = T_l_selection*tr.dt/1000.0                # Total duration of trace used for fit (in s)
        N_spikes = len(spks_i_afterselection)           # Nb of spikes in the trace used for fit


        # Define X matrix: first column is the voltage, second column is constant
        X       = np.ones((T_l_selection, 2))
        X[:,0]  = V_est[selection]

        # Select time steps in which the neuron has emitted a spike
        X_spikes = X[spks_i_afterselection,:]

        # Sum X_spike over spikes
        sum_X_spikes = np.sum( X_spikes, axis=0)

        return (X, X_spikes, sum_X_spikes, N_spikes, T_l)
        
            
    def buildXmatrix_dynamicThreshold(self, tr, V_est) :

        """
        Build the matrices used to fit a model with a spike-triggered movement of the firing threshold:
        V_T(t) = Vt_star + sum_i gamma(t - t_i), where t_i are the spike times.
        The matrix X is made of the vectors y defined as in Eq. 21 (Pozzorini et al. 2015).

        tr    : trace on which the matrix is built.
        V_est : array with the estimated subthreshold voltage, indexed with the selected samples of tr.

        Returns the tuple (X, X_spikes, sum_X_spikes, N_spikes, T_l), where X has the columns
        [V_est, 1, basis functions of gamma convolved with the spike train], X_spikes are the rows of X
        at which a spike was emitted, sum_X_spikes is their sum, N_spikes is the number of spikes
        and T_l is the duration of the selected samples (in s).
        """

        # Samples used for the fit: absolute refractory periods are removed
        # (-tr.dt is to not include the time of spike)
        selection = tr.getROI_FarFromSpikes(-tr.dt, self.Tref)
        nb_samples = len(selection)

        # Position of the spikes in the coordinates of the selection
        selected_spk_train = tr.getSpikeTrain()[selection]
        spks_i_afterselection = np.where(selected_spk_train==1)[0]

        # First two columns: estimated voltage and constant term
        X_static = np.zeros((nb_samples, 2))
        X_static[:,1] = np.ones(nb_samples)
        X_static[:,0] = V_est[selection]

        # Remaining columns: basis functions of the threshold kernel gamma convolved with the spike train.
        # Spike times are shifted by Tref before the convolution
        # (presumably so that the kernel starts at the end of the absolute refractory period).
        X_gamma = self.gamma.convolution_Spiketrain_basisfunctions(tr.getSpikeTimes() + self.Tref, tr.T, tr.dt)
        X = np.concatenate( (X_static, X_gamma[selection,:]), axis=1 )

        # Quantities precomputed to speed up the fit
        X_spikes = X[spks_i_afterselection,:]
        sum_X_spikes = np.sum( X_spikes, axis=0)

        # Number of spikes and duration (in s) of the part of the trace used for the fit
        N_spikes = len(spks_i_afterselection)
        T_l = nb_samples*tr.dt/1000.0

        return (X, X_spikes, sum_X_spikes,  N_spikes, T_l)
 
 
        
    ########################################################################################################
    # PLOT AND PRINT FUNCTIONS
    ########################################################################################################     
        
        
    def plotParameters(self) :

        """
        Show a figure with three panels: the membrane filter (computed from C and gl),
        the spike-triggered current eta and the spike-triggered threshold movement gamma.
        """

        def draw_filter(support, values, ylabel):
            # Draw one filter in red on top of a dotted zero line; the x-axis spans the filter support
            plt.plot(support, values, color='red', lw=2)
            plt.plot([support[0], support[-1]], [0,0], ls=':', color='black', lw=2)

            plt.xlim([support[0], support[-1]])
            plt.xlabel("Time (ms)")
            plt.ylabel(ylabel)

        plt.figure(facecolor='white', figsize=(14,4))

        # Left panel: membrane filter kappa, exponential decay with time constant C/gl and amplitude 1/C
        plt.subplot(1,3,1)
        kappa_support = np.linspace(0,150.0, 300)
        kappa = 1./self.C*np.exp(-kappa_support/(self.C/self.gl))
        draw_filter(kappa_support, kappa, "Membrane filter (MOhm/ms)")

        # Central panel: eta, interpolated at the model time step
        plt.subplot(1,3,2)
        (eta_support, eta) = self.eta.getInterpolatedFilter(self.dt)
        draw_filter(eta_support, eta, "Eta (nA)")

        # Right panel: gamma, interpolated at the model time step
        plt.subplot(1,3,3)
        (gamma_support, gamma) = self.gamma.getInterpolatedFilter(self.dt)
        draw_filter(gamma_support, gamma, "Gamma (mV)")

        plt.subplots_adjust(left=0.05, bottom=0.15, right=0.95, top=0.92, wspace=0.35, hspace=0.25)

        plt.show()
      
      
    def printParameters(self):

        """
        Print a table with the scalar parameters of the model on the terminal.
        The membrane time constant (C/gl) and the membrane resistance (1/gl) are derived from C and gl.
        """

        separator = "-------------------------"

        # One (format, value) pair per printed row, in the order in which rows are printed
        rows = [
            ("tau_m (ms):\t%0.3f", self.C/self.gl),
            ("R (MOhm):\t%0.3f",   1.0/self.gl),
            ("C (nF):\t\t%0.3f",   self.C),
            ("gl (nS):\t%0.6f",    self.gl),
            ("El (mV):\t%0.3f",    self.El),
            ("Tref (ms):\t%0.3f",  self.Tref),
            ("Vr (mV):\t%0.3f",    self.Vr),
            ("Vt* (mV):\t%0.3f",   self.Vt_star),
            ("DV (mV):\t%0.3f",    self.DV),
        ]

        print("\n" + separator)
        print("GIF model parameters:")
        print(separator)

        for (row_format, value) in rows :
            print(row_format % (value))

        print(separator + "\n")
                  

    @classmethod
    def compareModels(cls, GIFs, labels=None):

        """
        Given a list of GIF models, GIFs, print their parameters and produce a figure in which
        the model parameters are compared.

        GIFs   : sequence of GIF models (at least one).
        labels : optional sequence with one legend label per model, in the same order as GIFs.
                 If None, no legend is drawn.

        Raises ValueError if GIFs is empty.

        The figure has four panels: membrane filter, spike-triggered current eta, escape rate
        as a function of the membrane potential and spike-triggered threshold movement gamma.
        """

        if len(GIFs) == 0 :
            raise ValueError("compareModels requires at least one GIF model.")

        # PRINT PARAMETERS

        print("\n#####################################")
        print("GIF model comparison")
        print("#####################################\n")

        for model in GIFs :
            model.printParameters()

        print("#####################################\n")

        # PLOT PARAMETERS
        plt.figure(facecolor='white', figsize=(9,8))

        # One color per model, taken from the upper part of the jet colormap
        colors = plt.cm.jet( np.linspace(0.7, 1.0, len(GIFs) ) )

        # Membrane filter
        plt.subplot(2,2,1)

        # The support is the same for all models, so it is computed only once
        K_support = np.linspace(0,150.0, 1500)

        for (model, color) in zip(GIFs, colors) :

            K = 1./model.C*np.exp(-K_support/(model.C/model.gl))
            plt.plot(K_support, K, color=color, lw=2)

        plt.plot([K_support[0], K_support[-1]], [0,0], ls=':', color='black', lw=2, zorder=-1)

        plt.xlim([K_support[0], K_support[-1]])
        plt.xlabel('Time (ms)')
        plt.ylabel('Membrane filter (MOhm/ms)')


        # Spike triggered current
        plt.subplot(2,2,2)

        for (model_i, (model, color)) in enumerate(zip(GIFs, colors)) :

            # Identity test: labels may be any sequence (eg, a numpy array)
            label_tmp = "" if labels is None else labels[model_i]

            (eta_support, eta) = model.eta.getInterpolatedFilter(0.1)
            plt.plot(eta_support, eta, color=color, lw=2, label=label_tmp)

        # Zero line and x-axis limits are defined by the support of the last model
        plt.plot([eta_support[0], eta_support[-1]], [0,0], ls=':', color='black', lw=2, zorder=-1)

        if labels is not None :
            plt.legend()

        plt.xlim([eta_support[0], eta_support[-1]])
        plt.xlabel('Time (ms)')
        plt.ylabel('Eta (nA)')


        # Escape rate
        plt.subplot(2,2,3)

        for (model, color) in zip(GIFs, colors) :

            # Voltage range going from 5 DV below to 10 DV above the threshold Vt_star of the model
            V_support = np.linspace(model.Vt_star-5*model.DV, model.Vt_star+10*model.DV, 1000)
            escape_rate = model.lambda0*np.exp((V_support-model.Vt_star)/model.DV)
            plt.plot(V_support, escape_rate, color=color, lw=2)

        plt.ylim([0, 100])

        # Zero line and x-axis limits are defined by the voltage range of the last model
        plt.plot([V_support[0], V_support[-1]], [0,0], ls=':', color='black', lw=2, zorder=-1)

        plt.xlim([V_support[0], V_support[-1]])
        plt.xlabel('Membrane potential (mV)')
        plt.ylabel('Escape rate (Hz)')


        # Spike triggered threshold movement
        plt.subplot(2,2,4)

        for (model, color) in zip(GIFs, colors) :

            (gamma_support, gamma) = model.gamma.getInterpolatedFilter(0.1)
            plt.plot(gamma_support, gamma, color=color, lw=2)

        # Zero line and x-axis limits are defined by the support of the last model
        plt.plot([gamma_support[0], gamma_support[-1]], [0,0], ls=':', color='black', lw=2, zorder=-1)

        plt.xlim([gamma_support[0]+0.1, gamma_support[-1]])
        plt.ylim([-100,100])
        plt.xlabel('Time (ms)')
        plt.ylabel('Gamma (mV)')

        plt.subplots_adjust(left=0.08, bottom=0.10, right=0.95, top=0.93, wspace=0.25, hspace=0.25)

        plt.show()
    
 
    @classmethod
    def plotAverageModel(cls, GIFs):

        """
        Plot a summary of a population of GIF models.

        GIFs : list of GIF models.

        The upper part of the figure shows, for the membrane filter, the spike-triggered current eta and
        the spike-triggered threshold movement gamma, the curve of each model together with the
        population mean and a band of one standard deviation around it.
        The lower part shows the histograms of R, tau_m, El, Vr, Vt_star and DV across models.
        The mean of Vr is also printed on the terminal.

        Note that the tick direction is changed in the global matplotlib rcParams.
        """

        def draw_population(curve_of):
            # Plot the curve returned by curve_of for each model (thin dark gray lines), then the
            # mean (red) with a gray band of +/- one standard deviation and a dotted zero line.
            # The support of the last model is used for the population plots and is returned.
            curves = []

            for model in GIFs :
                (support, curve) = curve_of(model)
                plt.plot(support, curve, color='0.3', lw=1, zorder=5)
                curves.append(curve)

            curve_mean = np.mean(curves, axis=0)
            curve_std  = np.std(curves, axis=0)

            plt.fill_between(support, curve_mean+curve_std,y2=curve_mean-curve_std, color='gray', zorder=0)
            plt.plot(support, np.mean(curves, axis=0), color='red', lw=2, zorder=10)
            plt.plot([support[0], support[-1]], [0,0], ls=':', color='black', lw=2, zorder=-1)

            return support

        def draw_histogram(slot, value_of, xlabel, report=None):
            # Histogram across models of the scalar returned by value_of, drawn in position slot of a 4x6 grid.
            # If report is given, it is a format string used to print the mean of the values.
            plt.subplot(4,6,slot)

            values = [value_of(model) for model in GIFs]

            if report is not None :
                print(report % (np.mean(values)))

            plt.hist(values, histtype='bar', color='red', ec='white', lw=2)
            plt.xlabel(xlabel)
            Tools.removeAxis(plt.gca(), ['top', 'left', 'right'])
            plt.yticks([])

        def membrane_filter(model):
            # Exponential membrane filter with time constant C/gl and amplitude 1/C
            support = np.linspace(0,150.0, 300)
            return (support, 1./model.C*np.exp(-support/(model.C/model.gl)))


        # FIGURE
        #######################################################################################################

        fig = plt.figure(facecolor='white', figsize=(16,7))
        fig.subplots_adjust(left=0.07, bottom=0.08, right=0.95, top=0.90, wspace=0.35, hspace=0.5)
        rcParams['xtick.direction'] = 'out'
        rcParams['ytick.direction'] = 'out'


        # MEMBRANE FILTER
        #######################################################################################################

        plt.subplot(2,4,1)
        support = draw_population(membrane_filter)

        Tools.removeAxis(plt.gca(), ['top', 'right'])
        plt.xlim([support[0], support[-1]])
        plt.xlabel('Time (ms)')
        plt.ylabel('Membrane filter (MOhm/ms)')


        # SPIKE-TRIGGERED CURRENT
        #######################################################################################################

        plt.subplot(2,4,2)
        support = draw_population(lambda model: model.eta.getInterpolatedFilter(0.1))

        Tools.removeAxis(plt.gca(), ['top', 'right'])
        # Only the first tenth of the filter support is shown
        plt.xlim([support[0], support[-1]/10.0])
        plt.xlabel('Time (ms)')
        plt.ylabel('Spike-triggered current (nA)')


        # SPIKE-TRIGGERED MOVEMENT OF THE FIRING THRESHOLD
        #######################################################################################################

        plt.subplot(2,4,3)
        support = draw_population(lambda model: model.gamma.getInterpolatedFilter(0.1))

        plt.xlim([support[0], support[-1]])
        Tools.removeAxis(plt.gca(), ['top', 'right'])
        plt.xlabel('Time (ms)')
        plt.ylabel('Spike-triggered threshold (mV)')


        # HISTOGRAMS OF SCALAR PARAMETERS (two lower rows of a 4x6 grid)
        #######################################################################################################

        # R
        draw_histogram(12+1, lambda model: 1./model.gl, 'R (MOhm)')

        # tau_m
        draw_histogram(18+1, lambda model: model.C/model.gl, 'tau_m (ms)')

        # El
        draw_histogram(12+2, lambda model: model.El, 'El (mV)')

        # V reset (its mean is also printed)
        draw_histogram(18+2, lambda model: model.Vr, 'Vr (mV)', report="Mean Vr (mV): %0.1f")

        # Vt*
        draw_histogram(12+3, lambda model: model.Vt_star, 'Vt_star (mV)')

        # DV
        draw_histogram(18+3, lambda model: model.DV, 'DV (mV)')
Exemplo n.º 11
0
# Plot training and test set
# (myExp is created earlier in the script, outside of this excerpt)
myExp.plotTrainingSet()
myExp.plotTestSet()

############################################################################################################
# STEP 3: FIT GIF MODEL TO DATA
############################################################################################################

# Create a new GIF object (the argument is dt, the time step used in simulations)
myGIF = GIF(0.1)

# Define parameters
# Absolute refractory period (ms)
myGIF.Tref = 4.0

# Spike-triggered current eta: filter made of rectangular basis functions
# (presumably log-spaced in time, as the class name suggests; the meta-parameters
# length, binsize_lb, binsize_ub and slope are interpreted by Filter_Rect_LogSpaced)
myGIF.eta = Filter_Rect_LogSpaced()
myGIF.eta.setMetaParameters(length=500.0,
                            binsize_lb=2.0,
                            binsize_ub=1000.0,
                            slope=4.5)

# Spike-triggered movement of the firing threshold gamma: same kind of filter as eta,
# with a larger lower bound for the bin size and a different slope
myGIF.gamma = Filter_Rect_LogSpaced()
myGIF.gamma.setMetaParameters(length=500.0,
                              binsize_lb=5.0,
                              binsize_ub=1000.0,
                              slope=5.0)

# Define the ROI of the training set to be used for the fit (in this example we will use only the first 100 s)
# NOTE(review): the interval is presumably given in ms (100000.0 ms = 100 s) - confirm against setROI
myExp.trainingset_traces[0].setROI([[0, 100000.0]])

# To visualize the training set and the ROI call again
Exemplo n.º 12
0
def _find_data_extension(path_data):
    """Return the data-file extension ('.abf' or '.mat') used in path_data.

    The directory is scanned in os.listdir order; the first file name that
    contains '.abf' or '.mat' decides the result.

    Raises:
        IOError: if no file name in path_data contains either marker.
            (Previously the extension was left unbound, or silently reused
            from the previously processed cell.)
    """
    for file_name in os.listdir(path_data):
        if '.abf' in file_name:
            return '.abf'
        if '.mat' in file_name:
            return '.mat'
    raise IOError('No .abf or .mat data file found in %s' % path_data)


def _add_test_set_traces(experiment, filename_test):
    """Read every test sweep stored in filename_test and add it to experiment.

    '.mat' files are read with scipy.io.loadmat (keys 'analogSignals' and
    'times'); '.abf' files are read with neo.io.AxonIO. Any other file name
    adds nothing. Voltage and current are passed with scaling factors 1e-3
    and 1e-12 respectively.
    """
    if '.mat' in filename_test:
        mat_contents = sio.loadmat(filename_test)
        analogSignals = mat_contents['analogSignals']
        # Flatten the time vector and scale it by 1e3
        # (presumably seconds -> ms; confirm against the recording format)
        times_test = mat_contents['times']
        times_test = times_test.reshape(times_test.size)
        times_test = times_test*10.**3
        sampling_time_test = times_test[1] - times_test[0]
        for testnum in range(analogSignals.shape[1]):
            voltage_test = analogSignals[0, testnum, :]
            # NOTE(review): a constant 5 is subtracted from the raw current;
            # presumably an offset correction -- confirm.
            current_test = analogSignals[1, testnum, :] - 5.
            experiment.addTestSetTrace(voltage_test, 10. ** -3, current_test, 10. ** -12,
                                       len(voltage_test) * sampling_time_test, FILETYPE='Array')
    elif '.abf' in filename_test:
        r = neo.io.AxonIO(filename=filename_test)
        bl = r.read_block()
        # The sampling step is taken from the first segment's time axis, in ms
        times_test = bl.segments[0].analogsignals[0].times.rescale('ms').magnitude
        sampling_time_test = times_test[1] - times_test[0]
        for segment in bl.segments:
            voltage_test = segment.analogsignals[0].magnitude
            # NOTE(review): same constant offset of 5 as in the .mat branch -- confirm.
            current_test = segment.analogsignals[1].magnitude - 5.
            experiment.addTestSetTrace(voltage_test, 10. ** -3, current_test, 10. ** -12,
                                       len(voltage_test) * sampling_time_test, FILETYPE='Array')


def _mean_epsilon_V(experiment, model):
    """Return 1 - SSE/VAR averaged over all test traces of experiment.

    For each trace the voltage is simulated with spikes forced at the
    recorded spike times, and the error is evaluated only on the samples
    returned by getROI_FarFromSpikes(5., model.Tref).
    """
    epsilon_V = 0.
    local_counter = 0.
    for tr in experiment.testset_traces:
        (time, V_est, eta_sum_est) = model.simulateDeterministic_forceSpikes(tr.I, tr.V[0], tr.getSpikeTimes())
        indices_tmp = tr.getROI_FarFromSpikes(5., model.Tref)

        SSE = sum((V_est[indices_tmp] - tr.V[indices_tmp]) ** 2)
        VAR = len(indices_tmp) * np.var(tr.V[indices_tmp])
        epsilon_V += 1.0 - SSE / VAR
        local_counter += 1
    return epsilon_V / local_counter


def _plot_simulation_vs_data(experiment, model, fname):
    """Save to fname a two-panel figure: free simulation vs. data.

    The upper panel uses the first training trace, the lower one the first
    test trace; both show a fixed, hard-coded time window.
    """
    V_training = experiment.trainingset_traces[0].V
    I_training = experiment.trainingset_traces[0].I
    (time, V, eta_sum, V_t, S) = model.simulate(I_training, V_training[0])
    fig = plt.figure(figsize=(10,6), facecolor='white')
    plt.subplot(2,1,1)
    # time is divided by 1000 to match the axis label in seconds
    plt.plot(time/1000, V,'--r', lw=0.5, label='GIF-K')
    plt.plot(time/1000, V_training,'black', lw=0.5, label='Data')
    plt.xlim(18,20)
    plt.ylim(-80,20)
    plt.ylabel('Voltage [mV]')
    plt.title('Training')

    V_test = experiment.testset_traces[0].V
    I_test = experiment.testset_traces[0].I
    (time, V, eta_sum, V_t, S) = model.simulate(I_test, V_test[0])
    plt.subplot(2,1,2)
    plt.plot(time/1000, V,'--r', lw=0.5, label='GIF-K')
    plt.plot(time/1000, V_test,'black', lw=0.5, label='Data')
    plt.xlim(5,7)
    plt.ylim(-80,20)
    plt.xlabel('Times [s]')
    plt.ylabel('Voltage [mV]')
    plt.title('Test')
    plt.legend()
    plt.tight_layout()
    plt.savefig(fname, format='png')
    plt.close(fig)


def _plot_forced_spikes_training(experiment, model, fname):
    """Save to fname a two-panel figure of the first training trace with forced spikes.

    The upper panel shows the input current and the model's I_K, the lower
    one the model voltage against the recorded voltage.
    """
    training_trace = experiment.trainingset_traces[0]
    V_training = training_trace.V
    I_training = training_trace.I
    (time, V, eta_sum) = model.simulateDeterministic_forceSpikes(I_training, V_training[0],
                                                                  training_trace.getSpikeTimes())
    I_K = model.I_K_with_Deterministic_forceSpikes(I_training, V_training[0],
                                                   training_trace.getSpikeTimes())
    fig = plt.figure(figsize=(10,6), facecolor='white')
    plt.subplot(2,1,1)
    plt.plot(time/1000, I_training,'-b', lw=0.5, label='$I$')
    # Raw string: same text as before, without relying on an invalid '\m' escape
    plt.plot(time / 1000, I_K, '--r', lw=0.5, label=r'$I_\mathrm{K}$')
    plt.xlim(17,20)
    plt.ylabel('Current [nA]')
    plt.title('Training')
    plt.subplot(2,1,2)
    plt.plot(time/1000, V,'-b', lw=0.5, label='GIF')
    plt.plot(time/1000, V_training,'black', lw=0.5, label='Data')
    plt.xlim(17,20)
    plt.ylim(-75,0)
    # Fixed: this used to be a second ylabel call, leaving the x axis unlabelled
    plt.xlabel('Time [s]')
    plt.ylabel('Voltage [mV]')
    plt.legend(loc='best')
    plt.savefig(fname, format='png')
    plt.close(fig)


def process_all_files_for_GIF_K(is_E_K_fixed=True):
    """Fit a GIF-K model to every '_5HT' cell and write a performance summary.

    Cells are the sub-directories whose name contains '_5HT' inside the
    hard-coded experiment folders under the current working directory. For
    each cell the AEC, training and test recordings are loaded, active
    electrode compensation is performed, a GIF_K model is fitted and saved,
    the model is evaluated on the test set and two diagnostic figures are
    written to './Results/<cell_name>/'. Finally one line per cell
    (Md*, epsilon_V, PVar) is written to
    './Results/GIF_K_<EK_fixed_|EK_free_>FitPerformance.dat'.

    Args:
        is_E_K_fixed: forwarded to GIF_K.fit; also selects the 'EK_fixed_'
            or 'EK_free_' tag used in all output file names.

    Returns:
        None. All results are written to disk.

    Raises:
        OSError: if one of the experiment folders does not exist.
        IOError: if a cell folder contains no '.abf' or '.mat' file.
    """
    spec_GIF_K = 'EK_fixed_' if is_E_K_fixed else 'EK_free_'
    Md_star = {}
    epsilon_V_test = {}
    PVar = {}

    # List separate experiments in separate folder
    data_folders_for_separate_experiments = ['seventh_set', 'eighth_set', 'ninth_set', 'tenth_set']

    # For all experiments, extract the cell names
    CellNames = {}
    for experiment_folder in data_folders_for_separate_experiments:
        folder_path = './' + experiment_folder + '/'
        CellNames[experiment_folder] = [name for name in os.listdir(folder_path)
                                        if os.path.isdir(folder_path + name) and '_5HT' in name]

    for experiment_folder in data_folders_for_separate_experiments:
        for cell_name in CellNames[experiment_folder]:
            # Single-argument print(...) behaves the same on Python 2 and 3
            print('\n\n#############################################')
            print('##########     process cell %s    ###' % cell_name)
            print('#############################################')

            #################################################################################################
            # Load data
            #################################################################################################

            path_data = './' + experiment_folder + '/' + cell_name + '/'
            path_results = './Results/' + cell_name + '/'
            # Common prefix of every per-cell output file
            output_prefix = path_results + cell_name + '_GIF_K_' + spec_GIF_K

            if not os.path.exists(path_results):
                os.makedirs(path_results)

            # Find extension of data files
            ext = _find_data_extension(path_data)

            # Load AEC data
            filename_AEC = path_data + cell_name + '_aec' + ext
            (sampling_timeAEC, voltage_traceAEC, current_traceAEC) = load_AEC_data(filename_AEC)

            # Create experiment
            experiment = Experiment('Experiment 1', sampling_timeAEC)
            experiment.setAECTrace(voltage_traceAEC, 10.**-3, current_traceAEC, 10.**-12,
                                   len(voltage_traceAEC)*sampling_timeAEC, FILETYPE='Array')

            # Load training set data and add to experiment object
            filename_training = path_data + cell_name + '_training' + ext
            (sampling_time, voltage_trace, current_trace, _time_training) = load_training_data(filename_training)
            experiment.addTrainingSetTrace(voltage_trace, 10**-3, current_trace, 10**-12,
                                           len(voltage_trace)*sampling_time, FILETYPE='Array')
            #Note: once added to experiment, current is converted to nA.

            # Load test set data
            _add_test_set_traces(experiment, path_data + cell_name + '_test' + ext)

            #################################################################################################
            # PERFORM ACTIVE ELECTRODE COMPENSATION
            #################################################################################################

            # Create new object to perform AEC
            myAEC = AEC_Badel(experiment.dt)

            # Define metaparameters
            myAEC.K_opt.setMetaParameters(length=150.0, binsize_lb=experiment.dt, binsize_ub=2.0,
                                          slope=30.0, clamp_period=1.0)
            myAEC.p_expFitRange = [3.0,150.0]
            myAEC.p_nbRep = 15

            # Assign myAEC to experiment and compensate the voltage recordings
            experiment.setAEC(myAEC)
            experiment.performAEC()

            #################################################################################################
            # FIT GIF-K
            #################################################################################################

            # Create a new object GIF
            GIF_K_fit = GIF_K(sampling_time)

            # Define parameters and filter characteristics
            GIF_K_fit.Tref = 6.0
            GIF_K_fit.eta = Filter_Rect_LogSpaced()
            GIF_K_fit.eta.setMetaParameters(length=2000.0, binsize_lb=0.5, binsize_ub=500.0, slope=10.0)
            GIF_K_fit.gamma = Filter_Rect_LogSpaced()
            GIF_K_fit.gamma.setMetaParameters(length=2000.0, binsize_lb=2.0, binsize_ub=500.0, slope=5.0)

            # Define the ROI of the training set to be used for the fit:
            # the whole training recording minus 2000 at each end
            for tr in experiment.trainingset_traces:
                tr.setROI([[2000., sampling_time * (len(voltage_trace) - 1) - 2000.]])

            # Perform the fit
            (var_explained_dV, var_explained_V_GIF_K_train) = GIF_K_fit.fit(experiment, DT_beforeSpike=5.0,
                                                                              is_E_K_fixed=is_E_K_fixed)
            # Save the model
            GIF_K_fit.save(output_prefix + 'ModelParams' + '.pck')

            ###################################################################################################
            # EVALUATE MODEL PERFORMANCES ON THE TEST SET DATA
            ###################################################################################################

            # predict spike times in test set
            prediction = experiment.predictSpikes(GIF_K_fit, nb_rep=500)

            # Compute epsilon_V
            epsilon_V_test[cell_name] = _mean_epsilon_V(experiment, GIF_K_fit)

            # Compute Md*
            Md_star[cell_name] = prediction.computeMD_Kistler(8.0, GIF_K_fit.dt*2.)
            kernelForPSTH = 50.0
            PVar[cell_name] = prediction.plotRaster(output_prefix + 'Raster.png', delta=kernelForPSTH)

            #################################################################################################
            #  PLOT TRAINING AND TEST TRACES, MODEL VS EXPERIMENT
            #################################################################################################

            # Comparison for training and test sets w/o inactivation
            _plot_simulation_vs_data(experiment, GIF_K_fit, output_prefix + 'simulate.png')

            # Figure comparing V_model, V_data and I during training with forced spikes
            _plot_forced_spikes_training(experiment, GIF_K_fit,
                                         output_prefix + 'simulateForcedSpikes_Training.png')

    # The results folder does not exist yet when no cell was processed
    if not os.path.exists('./Results/'):
        os.makedirs('./Results/')

    with open('./Results/' + 'GIF_K_' + spec_GIF_K + 'FitPerformance.dat', 'w') as output_file:
        output_file.write('#Cell name\tMd*\tEpsilonV\tPVar\n')
        for experiment_folder in data_folders_for_separate_experiments:
            for cell_name in CellNames[experiment_folder]:
                output_file.write(cell_name + '\t' + str(Md_star[cell_name]) + '\t'
                                  + str(epsilon_V_test[cell_name]) + '\t' + str(PVar[cell_name]) + '\n')