# Assumed imports for these scraped snippets (not part of the original examples):
# Brian2, pylab, numpy, scipy and pandas are used throughout, and `correlograms`
# is assumed to come from phylib.stats (phy's cross-correlogram helper).
from pylab import *
from brian2 import *
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy import interpolate
from phylib.stats import correlograms


def generate(Ntrial, duration):
    '''
    Inputs:
        Ntrial: Number of trials
        duration: Trial duration in (ms)
    Outputs:
        parameters: Simulation parameter values
        weight_value: Synaptic weight used in each phase of the trial
        train_ref: Reference spike trains with synapse on
        train_targ: Target spike trains with synapse on
    '''

    #--------------------------------------------------------------------------
    # Define model parameters
    #--------------------------------------------------------------------------

    # Simulation parameters
    time_step = 0.1  #-in (ms)
    defaultclock.dt = time_step * ms
    Fs = 1 / (time_step * .001)  #-in (Hz)

    # Neuron parameters
    cm = 250 * pF  # membrane capacitance
    gm = 25 * nS  # membrane conductance
    tau = cm / gm  # membrane time constant
    El = -70 * mV  # resting potential
    Vt = El + 20 * mV  # spike threshold
    Vr = El + 10 * mV  # reset value
    refractory_period = 0 * ms  # refractory period

    # Background input parameters
    tauI = 10 * ms  # Auto-correlation time constant
    sigmaI = 1. * mvolt  # Noise standard-deviation
    muI = Vt - .5 * mV
    xmin = muI - .5 * mV  # Minimal amplitude of the nonstationary input
    xmax = muI + .5 * mV  # Maximal amplitude
    period = 50.  # Duration of the piecewise constant intervals in (ms)
    print("background input time constant: ", tauI / ms, "(ms)",
          "Input average amplitude: ", muI / mV, "(mV)",
          "Input amplitude range:", .1 * floor(
              (xmax - xmin) / mV / .1), "(mV)", "Input standard-deviation",
          sigmaI / mV, "(mV)", "Interval duration: ", period, "(ms)")

    # Monosynapse parameters
    tauS = 3 * ms  # Synaptic time constant
    Esyn = 0 * mV  # Synaptic reversal potential
    PSC = 25 * pA  # Postsynaptic current amplitude
    g0 = PSC / (Esyn - muI)
    latency = 1.5 * ms  # Spike transmission delay
    Nphase = 10
    phase = duration / Nphase  # Duration of session with fixed synaptic weight
    wmin = .5  # Minimal synaptic weight
    wmax = 4.  # Maximal synaptic weight
    print("Monosynaptic peak conductance: ", g0 / nsiemens, "(siemens)")

    #--------------------------------------------------------------------------
    # Define model
    #--------------------------------------------------------------------------

    # Define neurons equations
    # -- Reference neuron (synapse turned on)
    eqs_ref = Equations('''
    dV/dt = (-V+mu+sigmaI*I)/tau : volt 
    I : 1 (linked)
    mu : volt
    ''')
    # -- Reference neuron (synapse turned off)
    eqs_ref0 = Equations('''
    dV/dt = (-V+mu+sigmaI*I)/tau : volt 
    I : 1 (linked)
    mu : volt (linked)
    ''')
    # -- Input noise to reference neuron (same for synapse on/off)
    eqs_refnoise = Equations('''
    dx/dt = -x/tauI+(2/tauI)**.5*xi : 1
    ''')
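    # (The noise variable x is an Ornstein-Uhlenbeck process with correlation
    # time tauI and unit stationary variance, so sigmaI sets the standard
    # deviation of the fluctuating input drive.)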
    # -- Target neuron (synapse turned on)
    eqs_targ = Equations('''
    dV/dt = (-V+mu+sigmaI*I-g0/gm*gsyn*(V-Esyn))/tau : volt 
    I : 1 (linked)
    mu : volt (linked)
    #-Monosynaptic input
    dgsyn/dt = -gsyn/tauS : 1
    ''')
    # -- Input noise to target neuron (same for synapse on/off)
    eqs_targnoise = Equations('''
    dx/dt = -x/tauI+(2/tauI)**.5*xi : 1
    ''')

    # 'Synapse on' model
    # -- Constrain the model
    reference = NeuronGroup(Ntrial,
                            model=eqs_ref,
                            threshold='V>Vt',
                            reset='V=Vr',
                            refractory=refractory_period,
                            method='euler')
    target = NeuronGroup(Ntrial,
                         model=eqs_targ,
                         threshold='V>Vt',
                         reset='V=Vr',
                         refractory=refractory_period,
                         method='euler')
    reference.run_regularly('''mu = xmin+(xmax-xmin)*rand()''', dt=period * ms)
    target.mu = linked_var(reference, 'mu')
    ref_noise = NeuronGroup(Ntrial,
                            model=eqs_refnoise,
                            threshold='x>10**6',
                            reset='x=0',
                            method='euler')
    targ_noise = NeuronGroup(Ntrial,
                             model=eqs_targnoise,
                             threshold='x>10**6',
                             reset='x=0',
                             method='euler')
    reference.I = linked_var(ref_noise, 'x')
    target.I = linked_var(targ_noise, 'x')
    # -- Parameter initialization
    reference.V = (Vt - .1 * mV - Vr) * rand(Ntrial) + Vr
    target.V = (Vt - .1 * mV - Vr) * rand(Ntrial) + Vr
    target.gsyn = 0
    ref_noise.x = 2 * rand(Ntrial) - 1
    targ_noise.x = 2 * rand(Ntrial) - 1
    # -- Synaptic connection
    weight_value = np.random.permutation(linspace(wmin, wmax, Nphase))
    weight = TimedArray(weight_value, dt=phase * ms)
    synaptic = Synapses(reference,
                        target,
                        '''w = weight(t) : 1''',
                        on_pre='''
                 gsyn += w
                 ''')
    synaptic.connect(i=arange(Ntrial), j=arange(Ntrial))
    synaptic.delay = latency
    #--Record variables
    Sref = SpikeMonitor(reference)
    Starg = SpikeMonitor(target)
    Msyn = StateMonitor(synaptic, 'w', record=0)

    # 'Synapse off' model
    # -- Constrain the model
    reference0 = NeuronGroup(Ntrial,
                             model=eqs_ref0,
                             threshold='V>Vt',
                             reset='V=Vr',
                             refractory=refractory_period,
                             method='euler')
    target0 = NeuronGroup(Ntrial,
                          model=eqs_targ,
                          threshold='V>Vt',
                          reset='V=Vr',
                          refractory=refractory_period,
                          method='euler')
    reference0.mu = linked_var(reference, 'mu')
    target0.mu = linked_var(reference, 'mu')
    reference0.I = linked_var(ref_noise, 'x')
    target0.I = linked_var(targ_noise, 'x')
    # -- Parameter initialization
    reference0.V = reference.V
    target0.V = target.V
    target0.gsyn = 0
    #--Record variables
    Sref0 = SpikeMonitor(reference0)
    Starg0 = SpikeMonitor(target0)

    run(duration * ms)

    #--------------------------------------------------------------------------
    # Check the resulting spike trains
    #--------------------------------------------------------------------------

    # Represent some of the recorded variables
    FigW = figure()
    xlabel('Time (ms)')
    ylabel('Synaptic weight')
    title('Target cell')
    plot(Msyn.t / ms, Msyn.w[0], 'k')

    # Organize the spike times into two long spike trains
    # -- Synapse on
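    # Each spike at time t in trial i is remapped to
    #   (t mod phase) + (floor(t/phase)*Ntrial + i)*phase,
    # so the long train concatenates, phase by phase, the segments of every
    # trial recorded under the same synaptic weight.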
    train_ref = sort(Sref.t / ms + floor(Sref.t / (ms * phase)) *
                     (-1 + Ntrial) * phase + Sref.i * phase)
    train_targ = sort(Starg.t / ms + floor(Starg.t / (ms * phase)) *
                      (-1 + Ntrial) * phase + Starg.i * phase)
    train = append(train_ref, train_targ)
    cell = int64(append(zeros(len(train_ref)), ones(len(train_targ))))
    # -- Synapse off
    train_ref0 = sort(Sref0.t / ms + floor(Sref0.t / (ms * phase)) *
                      (-1 + Ntrial) * phase + Sref0.i * phase)
    train_targ0 = sort(Starg0.t / ms + floor(Starg0.t / (ms * phase)) *
                       (-1 + Ntrial) * phase + Starg0.i * phase)
    train0 = append(train_ref0, train_targ0)
    cell0 = int64(append(zeros(len(train_ref0)), ones(len(train_targ0))))

    # Basic firing parameters
    print("SYNAPSE ON")
    print("-- Reference train:")
    print("# spikes: ", len(train_ref), "Average firing rate",
          len(train_ref) / (Ntrial * duration * 1.), "CV",
          std(diff(train_ref)) / mean(diff(train_ref)))
    print("-- Target train:")
    print("# spikes: ", len(train_targ), "Average firing rate",
          len(train_targ) / (Ntrial * duration * 1.), "CV",
          std(diff(train_targ)) / mean(diff(train_targ)))
    print("SYNAPSE OFF")
    print("-- Target train:")
    print("# spikes: ", len(train_targ0), "Average firing rate",
          len(train_targ0) / (Ntrial * duration * 1.), "CV",
          std(diff(train_targ0)) / mean(diff(train_targ0)))

    # Compute the correlogram matrix between the two long trains
    lagmax = 100.  # Correlogram window in (ms)
    bine = 1.  # Correlogram time bin in (ms)
    #--WITH SYNAPSE--
    ind_sort = np.argsort(train)
    st = train[ind_sort] * .001
    sc = cell[ind_sort]
    Craw = correlograms(st,
                        sc,
                        sample_rate=Fs,
                        bin_size=bine / 1000.,
                        window_size=lagmax / 1000.)
    lag = (np.arange(len(Craw[0, 1])) - len(Craw[0, 1]) / 2) * bine
    #--WITHOUT SYNAPSE--
    ind_sort = np.argsort(train0)
    st = train0[ind_sort] * .001
    sc = cell0[ind_sort]
    Craw0 = correlograms(st,
                         sc,
                         sample_rate=Fs,
                         bin_size=bine / 1000.,
                         window_size=lagmax / 1000.)

    # Represent the auto- and the cross-correlograms
    FigACG = figure()
    title('Auto-correlograms', fontsize=18)
    xlim(-lagmax / 2., lagmax / 2.)
    xlabel('Time lag  (ms)', fontsize=18)
    ylabel('Firing rate (Hz)', fontsize=18)
    xticks(fontsize=18)
    yticks(fontsize=18)
    plot(lag, Craw[0, 0] / (len(train_ref) * bine * .001), '.-k')
    plot(lag, Craw[1, 1] / (len(train_targ) * bine * .001), '.-b')
    plot(lag, Craw0[1, 1] / (len(train_targ0) * bine * .001), '.-c')
    FigCCG = figure()
    xlim(-lagmax / 2., lagmax / 2.)
    title('Cross-correlograms', fontsize=18)
    xlabel('Time lag  (ms)', fontsize=18)
    ylabel('Firing rate (Hz)', fontsize=18)
    xticks(fontsize=18)
    yticks(fontsize=18)
    plot(lag, Craw[0, 1] / (len(train_ref) * bine * .001), '.-k')
    plot(lag, Craw0[0, 1] / (len(train_ref0) * bine * .001), '.-c')
    #show()

    # Save the relevant model parameters and the resulting spike trains
    parameters = np.array([Ntrial, duration, period, Fs, Nphase])

    return parameters, weight_value, train_ref, train_targ
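
# Minimal usage sketch (added for illustration; the argument values below are
# assumptions, not from the original example): simulate a small batch of trials
# and inspect the returned long spike trains and per-phase synaptic weights.
params, weights, ref_train, targ_train = generate(Ntrial=20, duration=5000)
print("Spikes (ref/targ):", len(ref_train), len(targ_train))
print("Synaptic weight per phase:", weights)
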
Example No. 2
# Define the target train
train = sort(S.i*duration+S.t/ms)

# Basic firing parameters
print("# spikes: ",len(train),
      "Average firing rate",len(train)/(Ntrial*duration*.001),"(Hz)",
      "CV",std(diff(train))/mean(diff(train)))

# Compute the ACG of the spike train
lagmax = 50.
bine = .1
ind_sort = np.argsort(train)
st = train[ind_sort]*.001
cell = int64(zeros(len(train)))
sc = cell[ind_sort]
Araw = correlograms(st,sc,sample_rate=Fs,bin_size=bine/1000.,window_size=lagmax/1000.)
lag = (np.arange(len(Araw[0,0]))-len(Araw[0,0])/2)*bine
FigACG = figure()
xlim(-lagmax/2.,lagmax/2.)
title('Auto-correlogram of Pyramidal Cell',fontsize=18)
xlabel('Time lag  (ms)',fontsize=18)
ylabel('Firing rate (Hz)',fontsize=18)
xticks(fontsize=18)
yticks(fontsize=18)
bar(lag,Araw[0,0]/(len(train)*bine*.001),bine,align='center',color='k',edgecolor='k')

# Show Vm STA
def STA(T,Vm,lag_max,duration,time_step):
    #--cut the part of the spike train that cannot be used for the STA
    i = 0
    while T[i] < lag_max:
        i += 1
    # (the remainder of the STA computation is missing from the scraped source)
def generate(Ntrial,duration,period):
    '''
    Inputs:
        Ntrial: Number of pairs
        duration: Trial duration in (ms)
        period: Duration of the piecewise constant intervals in (ms)
    Outputs:
        parameters: Simulation parameter values
        train_ref0: Collection of reference spike trains
        train_targ0: Collection of target spike trains
    '''

    #--------------------------------------------------------------------------
    # Define model parameters 
    #--------------------------------------------------------------------------

    # Simulation parameters
    time_step = 0.1                  
    defaultclock.dt = time_step*ms  # Time step of equations integration 
    Fs = 1/(time_step*.001)          

    # Neuron parameters
    cm = 250*pF               # Membrane capacitance
    gm = 25*nS                # Membrane conductance
    tau = cm/gm               # Membrane time constant
    El = -70*mV               # Resting potential
    Vt = El+20*mV             # Spike threshold
    Vr = El+10*mV             # Reset voltage
    refractory_period = 0*ms  # Refractory period
    print("Spike threshold: ",Vt/mV,"(mV)",
          "Refractory period: ",refractory_period/ms,"(ms)")

    # Background input parameters
    tauI = 10*ms       # Auto-correlation time constant
    sigmaI = 1.*mvolt  # Noise standard-deviation 
    muI = Vt-.5*mV
    xmin = muI-.5*mV   # Minimal amplitude of the nonstationary input 
    xmax = muI+.5*mV   # Maximal amplitude
    print("background input time constant: ",tauI/ms,"(ms)",
          "Input average amplitude: ",muI/mV,"(mV)",
          "Input amplitude range:",.1*floor((xmax-xmin)/mV/.1),"(mV)",
          "Input standard-deviation",sigmaI/mV,"(mV)")

    #--------------------------------------------------------------------------
    # Define model  
    #--------------------------------------------------------------------------

    # Define neurons equations
    # -- Reference neuron
    eqs_ref = Equations('''                
    dV/dt = (-V+mu+sigmaI*I)/tau : volt 
    dI/dt = -I/tauI+(2/tauI)**.5*xi : 1
    mu : volt
    ''')
    # -- Target neuron
    eqs_targ = Equations('''
    dV/dt = (-V+mu+sigmaI*I)/tau : volt 
    dI/dt = -I/tauI+(2/tauI)**.5*xi : 1
    mu : volt (linked)
    ''')

    # Constrain the model
    reference = NeuronGroup(Ntrial,model=eqs_ref,threshold='V>Vt',reset='V=Vr',
                            refractory=refractory_period,method='euler')
    target = NeuronGroup(Ntrial,model=eqs_targ,threshold='V>Vt',reset='V=Vr',
                         refractory=refractory_period,method='euler')
    reference.run_regularly('''mu = xmin+(xmax-xmin)*rand()''',dt=period*ms)
    target.mu = linked_var(reference,'mu')

    # Initialize variables
    reference.V = (Vt-.1*mV-Vr)*rand(Ntrial)+Vr
    reference.I = 2*rand(Ntrial)-1
    reference.mu = xmin+(xmax-xmin)*rand(Ntrial)
    target.V = (Vt-.1*mV-Vr)*rand(Ntrial)+Vr
    target.I = 2*rand(Ntrial)-1

    # Record variables
    Sref = SpikeMonitor(reference) 
    Starg = SpikeMonitor(target)

    # Integrate equations
    run(duration*ms)

    #--------------------------------------------------------------------------
    # Check the resulting spike trains 
    #--------------------------------------------------------------------------

    # Organize the collection of spike train pairs into two long spike trains
    train_ref0 = unique(Sref.i*duration+Sref.t/ms)
    train_targ0 = unique(Starg.i*duration+Starg.t/ms)
    train0 = append(train_ref0,train_targ0)
    cell0 = int64(append(zeros(len(train_ref0)),ones(len(train_targ0))))

    # Basic statistical measure of firing 
    print("Reference train: # spikes/trial",len(train_ref0)/Ntrial*1.,
          "firing rate",len(train_ref0)/(Ntrial*duration*.001),"(Hz)",
          "CV",std(diff(train_ref0))/mean(diff(train_ref0)))
    print("Target train: # spikes/trial",len(train_targ0)/Ntrial*1.,
          "firing rate",len(train_targ0)/(Ntrial*duration*.001),"(Hz)",
          "CV",std(diff(train_targ0))/mean(diff(train_targ0)))

    # Compute the correlogram matrix between the two long trains
    lagmax = 100.  # Correlogram window in (ms)
    bine = 1.      # Correlogram time bin in (ms)
    ind_sort = np.argsort(train0)
    st = train0[ind_sort]*.001
    sc = cell0[ind_sort]
    Craw = correlograms(st,sc,sample_rate=Fs,bin_size=bine/1000.,
                        window_size=lagmax/1000.)
    lag = (np.arange(len(Craw[0,1]))-len(Craw[0,1])/2)*bine

    # Represent the auto- and the cross-correlograms
    FigACG = figure()
    title('Auto-correlograms',fontsize=18)
    xlim(-lagmax/2.,lagmax/2.)
    xlabel('Time lag  (ms)',fontsize=18)
    ylabel('Firing rate (Hz)',fontsize=18)
    xticks(fontsize=18)
    yticks(fontsize=18)
    plot(lag,Craw[0,0]/(len(train_ref0)*bine*.001),'.-k')
    plot(lag,Craw[1,1]/(len(train_targ0)*bine*.001),'.-b')
    FigCCG = figure()
    xlim(-lagmax/2.,lagmax/2.)
    title('Cross-correlogram',fontsize=18)
    xlabel('Time lag  (ms)',fontsize=18)
    ylabel('Firing rate (Hz)',fontsize=18)
    xticks(fontsize=18)
    yticks(fontsize=18)
    plot(lag,Craw[0,1]/(len(train_ref0)*bine*.001),'.-k')
    #show()

    # Save the relevant model parameters and the resulting spike trains
    parameters = np.array([Ntrial,duration,period,Fs])
    
    return parameters,train_ref0,train_targ0
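
# Helper sketch (added; `split_trials` is a hypothetical name, not from the
# source): the long trains returned by generate() encode each spike as
# trial_index*duration + spike_time, so the per-trial spike times can be
# recovered with an integer division and a remainder.
def split_trials(long_train, Ntrial, duration):
    """Split a concatenated spike train back into a list of per-trial arrays."""
    trial_index = np.int64(np.floor(long_train / duration))
    spike_time = long_train - trial_index * duration
    return [spike_time[trial_index == k] for k in range(Ntrial)]
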
def analyze(parameters, train_ref0, train_targ0):
    '''
    Inputs:
        parameters: Generative simulation parameters
        train_ref0: Reference train set 
        train_targ0: Target train set 
    Outputs:
        Figure 1: ROC curve for monosynapse detection
        Figure 2: Area under the curve for various jitter interval lengths
    '''

    #--------------------------------------------------------------------------
    # Extract the spike data
    #--------------------------------------------------------------------------

    Ntrial = int(parameters[0])  # Number of pairs
    duration = parameters[1]  # Trial duration in (ms)
    interval_true = parameters[2]  # Nonstationarity timescale in (ms)
    Fs = parameters[3]  # Sampling frequency

    train0 = np.append(train_ref0, train_targ0)
    cell0 = np.int64(
        np.append(np.zeros(len(train_ref0)), np.ones(len(train_targ0))))

    #--------------------------------------------------------------------------
    # Define analysis parameters
    #--------------------------------------------------------------------------

    inject_count = 10  # Number of injected synchronies per injected pair
    Ninjectedtrial = int(Ntrial / 2.)  # Number of pairs to be injected
    synch_width = 1.  # Width of the synchrony window in (ms)
    latency = 0.  # Synaptic delay (for visual purposes)
    Njitter = 110  # Number of jitter surrogates
    Ndelta = 20  # Number of tested jitter timescales
    Ntest = 100  # Number of detection thresholds

    #--------------------------------------------------------------------------
    # Inject synchronous spikes
    #--------------------------------------------------------------------------

    # Inject spikes at random times avoiding present spikes
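    # Strategy: tile the recording into synch_width-wide bins, flag the bins
    # that already contain a spike, and for each of the first Ninjectedtrial
    # pairs draw inject_count of the remaining empty bins as injection times.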
    Nwidth = int(duration / synch_width)
    allwidths = np.arange(int(Ntrial * duration / synch_width))
    include_index = np.int64(np.floor(train0 / synch_width))
    include_idx = list(set(include_index))
    mask = np.zeros(allwidths.shape, dtype=bool)
    mask[include_idx] = True
    wheretoinject = synch_width * allwidths[~mask]
    alreadythere = synch_width * allwidths[mask]
    widths = np.append(wheretoinject, alreadythere)
    tags = np.append(np.zeros(len(wheretoinject)), np.ones(len(alreadythere)))
    ind_sort = np.argsort(widths)
    widths = widths[ind_sort]
    tags = tags[ind_sort]
    widths = widths[:Ninjectedtrial * Nwidth]
    tags = tags[:Ninjectedtrial * Nwidth]
    widths = np.reshape(widths, (Ninjectedtrial, Nwidth))
    tags = np.reshape(tags, (Ninjectedtrial, Nwidth))
    ind_perm = np.transpose(
        np.random.permutation(np.mgrid[:Nwidth, :Ninjectedtrial][0]))
    widths = widths[np.arange(np.shape(widths)[0])[:, np.newaxis], ind_perm]
    tags = tags[np.arange(np.shape(tags)[0])[:, np.newaxis], ind_perm]
    ind_sort = np.argsort(tags, axis=1)
    widths = widths[np.arange(np.shape(widths)[0])[:, np.newaxis], ind_sort]
    tags = tags[np.arange(np.shape(tags)[0])[:, np.newaxis], ind_sort]
    train_inject = np.ravel(widths[:, :inject_count])  # Injected spike trains
    train_ref = np.sort(np.append(train_ref0, train_inject))
    train_targ = np.sort(np.append(train_targ0, train_inject + latency))
    train = np.append(train_ref, train_targ)
    cell = np.int64(
        np.append(np.zeros(len(train_ref)), np.ones(len(train_targ))))

    #--------------------------------------------------------------------------
    # Check the impact of injection
    #--------------------------------------------------------------------------

    lagmax = 100.  # Correlogram window in (ms)
    bine = 1.  # Correlogram time bin in (ms)

    # Select one pair of reference-target trains (without and with injection)
    ind_sort = np.argsort(train0)
    T0 = train0[ind_sort]
    G0 = cell0[ind_sort]
    ind_sort = np.argsort(train)
    T = train[ind_sort]
    G = cell[ind_sort]
    i = 0
    j = 0
    while T0[i] < duration or T[j] < duration:
        if T0[i] < duration:
            i += 1
        if T[j] < duration:
            j += 1
    T0 = T0[:i]
    G0 = G0[:i]
    T = T[:j]
    G = G[:j]

    # Compute the correlogram matrix for the chosen pairs
    ind_sort = np.argsort(T0)
    st = T0[ind_sort] * .001
    sc = G0[ind_sort]
    C0 = correlograms(st,
                      sc,
                      sample_rate=Fs,
                      bin_size=bine / 1000.,
                      window_size=lagmax / 1000.)
    lag = (np.arange(len(C0[0, 1])) - len(C0[0, 1]) / 2.) * bine
    ind_sort = np.argsort(T)
    st = T[ind_sort] * .001
    sc = G[ind_sort]
    C = correlograms(st,
                     sc,
                     sample_rate=Fs,
                     bin_size=bine / 1000.,
                     window_size=lagmax / 1000.)

    # Represent the cross-correlograms
    FigCCG = plt.figure()
    plt.xlim(-lagmax / 2., lagmax / 2.)
    plt.title('Cross-Correlogram', fontsize=18)
    plt.xlabel('Time lag  (ms)', fontsize=18)
    plt.ylabel('Firing rate (Hz)', fontsize=18)
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)
    plt.plot(lag, C[0, 1] / (len(train_ref) * bine * .001), '.-k')
    plt.plot(lag, C0[0, 1] / (len(train_ref) * bine * .001), '--c')

    #--------------------------------------------------------------------------
    # Compute Receiver Operating Characteristic curve
    #--------------------------------------------------------------------------

    # Remove the synaptic delay for synchrony computation
    train_targ = np.sort(np.append(train_targ0, train_inject))
    train = np.append(train_ref, train_targ)
    cell = np.int64(
        np.append(np.zeros(len(train_ref)), np.ones(len(train_targ))))

    # Count the number of total observed synchronies per pair
    Tref = synch_width * np.floor(train_ref / synch_width)
    Ttarg = synch_width * np.floor(train_targ / synch_width)
    Tsynch = np.array(list(set(Tref) & set(Ttarg)))
    synch_count = np.bincount(np.int64(np.floor(Tsynch / duration)),
                              minlength=Ntrial)

    # Compute the true positive and false positive rates for a range of detection thresholds
    delta_range = np.concatenate((np.linspace(5, 100, int(Ndelta / 2.)),
                                  np.linspace(100, 500, int(Ndelta / 2.))))
    pvalue_inj = np.zeros((Ndelta, Ninjectedtrial))
    pvalue_noinj = np.zeros((Ndelta, Ntrial - Ninjectedtrial))
    threshold_range = np.linspace(0, 1., Ntest)
    proba_truepositive = np.zeros((Ndelta, Ntest))
    proba_falsepositive = np.zeros((Ndelta, Ntest))
    for k in range(Ndelta):
        print('Tested timescale no: ', k + 1)
        interval = delta_range[k]
        # Jitter the target trains
        Tjitter = (np.tile(train_targ, Njitter) +
                   np.sort(np.tile(np.arange(Njitter), len(train_targ))) *
                   Ntrial * duration)
        Tjitter = (interval * np.floor(Tjitter / interval) +
                   np.random.uniform(0, interval, len(Tjitter)))
        Tjitter = synch_width * np.floor(Tjitter / synch_width)
        # Compute the p-values under the jitter null
        Tref_jitter = (np.tile(Tref, Njitter) +
                       np.sort(np.tile(np.arange(Njitter), len(Tref))) *
                       Ntrial * duration)
        Tsynch_jitter = np.array(list(set(Tref_jitter) & set(Tjitter)))
        jitter_synchrony = np.bincount(np.int64(
            np.floor(Tsynch_jitter / duration)),
                                       minlength=Ntrial * Njitter)
        observed_synchrony = np.tile(synch_count, Njitter)
        comparison = np.reshape(
            np.sign(np.sign(jitter_synchrony - observed_synchrony) + 1),
            (Njitter, Ntrial))
        pvalue = (1 + np.sum(comparison, axis=0)) / (Njitter + 1.)
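        # This is the standard add-one jitter p-value,
        #   p = (1 + #{surrogates with synchrony count >= observed}) / (Njitter + 1);
        # sign(sign(jitter - observed) + 1) equals 1 where the surrogate count
        # is at least the observed count and 0 otherwise.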
        pvalue_inj[k, :] = pvalue[:Ninjectedtrial]  # Correspond to injected trains
        pvalue_noinj[k, :] = pvalue[Ninjectedtrial:]  # Correspond to non-injected trains
        # Compute the detection probabilities for each threshold value
        th = np.reshape(np.tile(threshold_range, Ninjectedtrial),
                        (Ninjectedtrial, Ntest))
        pval = np.tile(np.reshape(pvalue_inj[k, :], (Ninjectedtrial, 1)),
                       Ntest)
        proba_truepositive[k, :] = np.sum(np.sign(np.sign(th - pval) + 1),
                                          axis=0)
        pval = np.tile(np.reshape(pvalue_noinj[k, :], (Ninjectedtrial, 1)),
                       Ntest)
        proba_falsepositive[k, :] = np.sum(np.sign(np.sign(th - pval) + 1),
                                           axis=0)
    proba_truepositive = proba_truepositive / Ninjectedtrial
    proba_falsepositive = proba_falsepositive / (Ntrial - Ninjectedtrial)

    #--------------------------------------------------------------------------
    # Represent the classification result
    #--------------------------------------------------------------------------

    cm = plt.get_cmap('gist_rainbow')
    FigROC = plt.figure()
    ax = FigROC.add_subplot(111)
    ax.set_prop_cycle(color=[cm(1. * i / Ndelta) for i in range(Ndelta)])
    plt.title('ROC curve', fontsize=18)
    plt.xlabel('False positive rate', fontsize=18)
    plt.ylabel('True positive rate', fontsize=18)
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)
    plt.xlim(-.1, 1.1)
    plt.ylim(-.1, 1.1)
    clr = np.linspace(0, .5, Ndelta)
    x = np.arange(0, 1, .01)
    y = np.zeros((Ndelta, Ntest))
    area = np.zeros(Ndelta)
    for i in range(Ndelta):
        ax.plot(proba_falsepositive[i, :], proba_truepositive[i, :], '.-')
        # Compute the area under the ROC curve
        Fx = interpolate.interp1d(proba_falsepositive[i, :],
                                  proba_truepositive[i, :],
                                  bounds_error=False,
                                  fill_value=0.,
                                  kind='nearest')
        y_interp = Fx(x)
        area[i] = np.sum(y_interp) * (x[1] - x[0])
    plt.plot([0, 1], [0, 1], '--b')
    plt.plot(.5 * np.ones(2), [0, 1], '--b')
    plt.plot([0, 1], .5 * np.ones(2), '--b')

    FigDroc = plt.figure()
    plt.title('Injection Classifier Quantification', fontsize=18)
    plt.xlabel('Jitter interval length (ms)', fontsize=18)
    plt.ylabel('Area under the ROC curve', fontsize=18)
    plt.xticks(fontsize=18)
    plt.yticks(fontsize=18)
    plt.plot(interval_true * np.ones(2), [np.amin(area), np.amax(area)], '--r')
    plt.plot(delta_range, area, 'o-k')
    plt.show()
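
# End-to-end usage sketch (added for illustration; the argument values are
# assumptions): generate a surrogate data set with no synapse, then run the
# injected-synchrony ROC analysis on it.
params, ref_trains, targ_trains = generate(Ntrial=100, duration=1000, period=50.)
analyze(params, ref_trains, targ_trains)
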
Example No. 5
    def compute(
        self,
        ref,
        event,
        quantparam=None,
        binsize=0.01,
        window=1,
        fs=1250,
        nQuantiles=10,
        period=None,
    ):
        """psth of 'event' with respect to 'ref'

        Args:
            ref (array): 1-D array of timings of reference event in seconds
            event (1D array): timings of events whose psth will be calculated
            quantparam (1D array): values used to divide 'ref' into quantiles
            binsize (float, optional): [description]. Defaults to 0.01.
            window (int, optional): [description]. Defaults to 1.
            nQuantiles (int, optional): [description]. Defaults to 10.

        Returns:
            [type]: [description]
        """

        # --- parameters----------
        if period is not None:
            event = event[(event > period[0]) & (event < period[1])]
            if quantparam is not None:
                quantparam = quantparam[(ref > period[0]) & (ref < period[1])]
            ref = ref[(ref > period[0]) & (ref < period[1])]

        if quantparam is not None:
            assert len(ref) == len(quantparam), "length must be same"
            quantiles = pd.qcut(quantparam, nQuantiles, labels=False)

            quants, eventid = [], []
            for category in range(nQuantiles):
                indx = np.where(quantiles == category)[0]
                quants.append(ref[indx])
                eventid.append(category * np.ones(len(indx)).astype(int))

            quants.append(event)
            eventid.append(
                ((nQuantiles + 1) * np.ones(len(event))).astype(int))

            quants = np.concatenate(quants)
            eventid = np.concatenate(eventid)
        else:
            quants = np.concatenate((ref, event))
            eventid = np.concatenate(
                [np.ones(len(ref)), 2 * np.ones(len(event))]).astype(int)

        sort_ind = np.argsort(quants)

        ccg = correlograms(
            quants[sort_ind],
            eventid[sort_ind],
            sample_rate=fs,
            bin_size=binsize,
            window_size=window,
        )

        self.psth = ccg[:-1, -1, :]

        return self.psth
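
# Usage sketch (added; `Psth` is a hypothetical container class, since the
# scraped example only shows its `compute` method), with hypothetical inputs:
# psth_calc = Psth()
# psth = psth_calc.compute(ref=stim_onsets, event=spike_times,
#                          quantparam=stim_amplitudes, binsize=0.01,
#                          window=1, fs=1250, nQuantiles=10)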