def bsub_mean(self):
    """Return the mean evoked response with the median baseline level removed.

    Every response and baseline trace is re-timed to ``t0=0`` and averaged with
    ``TraceList(...).mean()``. The median of the averaged baseline is then
    subtracted from the averaged response. The result is cached on
    ``self._bsub_mean`` so later calls return the same object.

    Returns
    -------
    Trace or None
        ``None`` when this collection is empty (``len(self) == 0``). Otherwise
        the baseline-subtracted average, whose ``meta`` gains:

        * ``'baseline'``: data of the averaged baseline trace
        * ``'baseline_med'``: the median value that was subtracted
        * ``'baseline_std'``: std of the linearly detrended averaged baseline,
          or ``None`` when the baseline has no samples

    NOTE(review): earlier docs said traces are downsampled to the minimum sample
    rate in the set. No resampling happens here, so presumably
    ``TraceList.mean()`` does it. TODO confirm.
    """
    if len(self) == 0:
        return None

    if self._bsub_mean is not None:
        return self._bsub_mean

    response_list, baseline_list = self.responses, self.baselines

    # Re-time everything to start at zero so the traces line up when averaged.
    mean_response = TraceList([resp.copy(t0=0) for resp in response_list]).mean()
    mean_baseline = TraceList([base.copy(t0=0) for base in baseline_list]).mean().data

    baseline_level = np.median(mean_baseline)
    result = mean_response.copy(data=mean_response.data - baseline_level)
    assert len(result.time_values) == len(result)

    result.meta['baseline'] = mean_baseline
    result.meta['baseline_med'] = baseline_level
    result.meta['baseline_std'] = (
        scipy.signal.detrend(mean_baseline).std() if len(mean_baseline) > 0 else None
    )

    self._bsub_mean = result
    return result
Exemple #2
0
    def bsub_mean(self):
        """Return the mean evoked response with the median baseline level removed.

        Every response and baseline trace is re-timed to ``t0=0`` and averaged
        with ``TraceList(...).mean()``. The median of the averaged baseline is
        then subtracted from the averaged response. The result is cached on
        ``self._bsub_mean`` so later calls return the same object.

        Returns
        -------
        Trace or None
            ``None`` when this collection is empty (``len(self) == 0``).
            Otherwise the baseline-subtracted average, whose ``meta`` gains:

            * ``'baseline'``: data of the averaged baseline trace
            * ``'baseline_med'``: the median value that was subtracted
            * ``'baseline_std'``: std of the linearly detrended averaged
              baseline, or ``None`` when the baseline has no samples

        NOTE(review): earlier docs said traces are downsampled to the minimum
        sample rate in the set. No resampling happens here, so presumably
        ``TraceList.mean()`` does it. TODO confirm.
        """
        if len(self) == 0:
            return None

        if self._bsub_mean is not None:
            return self._bsub_mean

        response_list, baseline_list = self.responses, self.baselines

        # Re-time everything to start at zero so the traces line up when averaged.
        mean_response = TraceList([resp.copy(t0=0) for resp in response_list]).mean()
        mean_baseline = TraceList([base.copy(t0=0) for base in baseline_list]).mean().data

        baseline_level = np.median(mean_baseline)
        result = mean_response.copy(data=mean_response.data - baseline_level)
        assert len(result.time_values) == len(result)

        result.meta['baseline'] = mean_baseline
        result.meta['baseline_med'] = baseline_level
        result.meta['baseline_std'] = (
            scipy.signal.detrend(mean_baseline).std() if len(mean_baseline) > 0 else None
        )

        self._bsub_mean = result
        return result
def train_response_plot(expt_list, name=None, summary_plots=None, color=None):
    """Plot 50 Hz / 250 ms-recovery train responses and their grand averages.

    Connections are selected from *expt_list* when their pre/post cre types match
    the module-level ``cre_type`` pair. Each connection's 50 Hz trains with a
    250 ms recovery delay are collected through ``DynamicsAnalyzer``.
    Connections whose cross-talk artifact exceeds 0.03e-3 (presumably volts) are
    skipped.

    A connection contributes only when it has more than 5 induction responses.
    Its induction and recovery averages are baseline-subtracted using the
    ``float_mode`` of the first 10 ms of the induction average, then plotted.

    Parameters
    ----------
    expt_list : iterable of experiments
    name : str, optional
        Label for the grand-mean curves and the printed count.
    summary_plots : list of two plots, optional
        Targets for ``summary_plot_train``. Defaults to ``[None, None]``.
        A fresh list is created per call instead of sharing a mutable default.
    color : optional
        Passed through to ``summary_plot_train``.

    Returns
    -------
    (train_plots, train_plots2, train_amps) when any connection qualified,
    otherwise ``None``.
    """
    if summary_plots is None:
        summary_plots = [None, None]
    ind_base_subtract = []
    rec_base_subtract = []
    pulse_offsets = None
    train_plots = pg.plot()
    train_plots.setLabels(left=('Vm', 'V'))
    tau = 15e-3  # time constant passed to exp_deconvolve
    lp = 1000    # cutoff passed to bessel_filter (presumably Hz)

    for expt in expt_list:
        for pre, post in expt.connections:
            if expt.cells[pre].cre_type != cre_type[0] or expt.cells[post].cre_type != cre_type[1]:
                continue
            print('Processing experiment: %s' % (expt.nwb_file))
            ind = []
            rec = []
            analyzer = DynamicsAnalyzer(expt, pre, post)
            train_responses = analyzer.train_responses
            artifact = analyzer.cross_talk()
            if artifact > 0.03e-3:
                continue

            for stim_params in train_responses.keys():
                # stim_params[0] is compared to the 50 Hz induction frequency;
                # stim_params[1] is converted to ms (rounded to 10 ms) and
                # compared to the 250 ms recovery delay.
                rec_t = int(np.round(stim_params[1] * 1e3, -1))
                if stim_params[0] != 50 or rec_t != 250:
                    continue
                pulse_offsets = analyzer.pulse_offsets
                ind_group = train_responses[stim_params][0]
                if len(ind_group) == 0:
                    continue
                rec_group = train_responses[stim_params][1]
                for j in range(len(ind_group)):
                    ind.append(ind_group.responses[j])
                    rec.append(rec_group.responses[j])

            if len(ind) > 5:
                ind_avg = TraceList(ind).mean()
                rec_avg = TraceList(rec).mean()
                # Shift the recovery average to start at 0.3 s, presumably so it
                # sits after the induction train on a shared time axis.
                rec_avg.t0 = 0.3
                base = float_mode(ind_avg.data[:int(10e-3 / ind_avg.dt)])
                ind_bsub = ind_avg.data - base
                rec_bsub = rec_avg.data - base
                ind_base_subtract.append(ind_avg.copy(data=ind_bsub))
                rec_base_subtract.append(rec_avg.copy(data=rec_bsub))
                train_plots.plot(ind_avg.time_values, ind_bsub)
                train_plots.plot(rec_avg.time_values, rec_bsub)
                app.processEvents()

    if len(ind_base_subtract) == 0:
        print("No Traces")
        return None

    # '%s' formatting tolerates the default name=None (string concat did not).
    print('%s n = %d' % (name, len(ind_base_subtract)))
    ind_grand_mean = TraceList(ind_base_subtract).mean()
    rec_grand_mean = TraceList(rec_base_subtract).mean()
    ind_grand_mean_dec = bessel_filter(exp_deconvolve(ind_grand_mean, tau), lp)
    train_plots.addLegend()
    train_plots.plot(ind_grand_mean.time_values, ind_grand_mean.data, pen={'color': 'g', 'width': 3}, name=name)
    train_plots.plot(rec_grand_mean.time_values, rec_grand_mean.data, pen={'color': 'g', 'width': 3}, name=name)
    # pulse_offsets comes from the last analyzer that had a 50 Hz / 250 ms train.
    train_amps = train_amp([ind_base_subtract, rec_base_subtract], pulse_offsets, '+')
    if ind_grand_mean is None:
        return None

    # NOTE(review): `legend` is not defined in this function; it must be a
    # module-level global set by the calling script.
    train_plots = summary_plot_train(ind_grand_mean, plot=summary_plots[0], color=color,
                                     name=(legend + ' 50 Hz induction'))
    train_plots = summary_plot_train(rec_grand_mean, plot=summary_plots[0], color=color)
    train_plots2 = summary_plot_train(ind_grand_mean_dec, plot=summary_plots[1], color=color,
                                      name=(legend + ' 50 Hz induction'))
    return train_plots, train_plots2, train_amps
def get_amplitude(response_list):
    """Average responses, subtract the baseline, and locate the largest deflection.

    Parameters
    ----------
    response_list : sequence
        Traces accepted by ``TraceList``; they are averaged together.

    Returns
    -------
    bsub_mean : Trace
        The average minus the ``float_mode`` of its first 10 ms.
    avg_amp : float
        The sample with the largest absolute value at or after 13 ms. When the
        positive and negative extremes have equal magnitude, the positive one
        is chosen.
    amp_sign : str
        ``'-'`` if ``avg_amp < 0`` else ``'+'``.
    peak_t : float
        Time of the ``avg_amp`` sample.

    Raises
    ------
    ValueError
        If the trace has no samples at or after 13 ms.
    """
    avg_trace = TraceList(response_list).mean()
    data = avg_trace.data
    dt = avg_trace.dt
    base = float_mode(data[:int(10e-3 / dt)])
    bsub_mean = avg_trace.copy(data=data - base)

    # Find the peak inside the search window and keep its index relative to the
    # window. The previous list(...).index(avg_amp) scanned the whole trace, so an
    # identical earlier sample (e.g. in the baseline) could give the wrong peak_t.
    search_start = int(13e-3 / dt)
    window = np.asarray(bsub_mean.data[search_start:])
    neg_ind = int(np.argmin(window))
    pos_ind = int(np.argmax(window))
    neg = window[neg_ind]
    pos = window[pos_ind]
    if abs(neg) > abs(pos):
        avg_amp, peak_offset = neg, neg_ind
    else:
        avg_amp, peak_offset = pos, pos_ind
    amp_sign = '-' if avg_amp < 0 else '+'
    peak_t = bsub_mean.time_values[search_start + peak_offset]
    return bsub_mean, avg_amp, amp_sign, peak_t
def train_response_plot(expt_list, name=None, summary_plots=None, color=None):
    """Plot 50 Hz / 250 ms-recovery train responses and their grand averages.

    Connections are selected from *expt_list* when their pre/post cre types match
    the module-level ``cre_type`` pair. Train responses come from
    ``get_response(..., analysis_type='train')``; connections whose artifact
    exceeds 0.03e-3 (presumably volts) are skipped. ``response_filter`` keeps the
    50 Hz trains with a 250 ms recovery delay.

    A connection contributes only when it has more than 5 induction responses.
    Its induction and recovery averages are plotted as-is (no baseline
    subtraction) and collected for the grand mean.

    Parameters
    ----------
    expt_list : iterable of experiments
    name : str, optional
        Label for the grand-mean curves and the printed count.
    summary_plots : list of two plots, optional
        Targets for ``summary_plot_train``. Defaults to ``[None, None]``.
        A fresh list is created per call instead of sharing a mutable default.
    color : optional
        Passed through to ``summary_plot_train``.

    Returns
    -------
    (train_plots, train_plots2, train_amps) when any connection qualified,
    otherwise ``None``.
    """
    if summary_plots is None:
        summary_plots = [None, None]
    grand_train = [[], []]  # [induction averages, recovery averages]
    pulse_offsets = None
    train_plots = pg.plot()
    train_plots.setLabels(left=('Vm', 'V'))
    tau = 15e-3  # time constant passed to exp_deconvolve
    lp = 1000    # cutoff passed to bessel_filter (presumably Hz)

    for expt in expt_list:
        for pre, post in expt.connections:
            if expt.cells[pre].cre_type != cre_type[0] or expt.cells[post].cre_type != cre_type[1]:
                continue
            print('Processing experiment: %s' % (expt.nwb_file))

            train_responses, artifact = get_response(expt, pre, post, analysis_type='train')
            if artifact > 0.03e-3:
                continue

            train_filter = response_filter(train_responses['responses'], freq_range=[50, 50], train=0, delta_t=250)
            # NOTE(review): pulse_offsets is overwritten for every connection that
            # passes the artifact check, even ones later dropped by the n > 5 test.
            pulse_offsets = response_filter(train_responses['pulse_offsets'], freq_range=[50, 50], train=0, delta_t=250)

            if len(train_filter[0]) > 5:
                ind_avg = TraceList(train_filter[0]).mean()
                rec_avg = TraceList(train_filter[1]).mean()
                # Shift the recovery average to start at 0.3 s, presumably so it
                # sits after the induction train on a shared time axis.
                rec_avg.t0 = 0.3
                grand_train[0].append(ind_avg)
                grand_train[1].append(rec_avg)
                train_plots.plot(ind_avg.time_values, ind_avg.data)
                train_plots.plot(rec_avg.time_values, rec_avg.data)
                app.processEvents()

    if len(grand_train[0]) == 0:
        print("No Traces")
        return None

    # '%s' formatting tolerates the default name=None (string concat did not).
    print('%s n = %d' % (name, len(grand_train[0])))
    ind_grand_mean = TraceList(grand_train[0]).mean()
    rec_grand_mean = TraceList(grand_train[1]).mean()
    ind_grand_mean_dec = bessel_filter(exp_deconvolve(ind_grand_mean, tau), lp)
    train_plots.addLegend()
    train_plots.plot(ind_grand_mean.time_values, ind_grand_mean.data, pen={'color': 'g', 'width': 3}, name=name)
    train_plots.plot(rec_grand_mean.time_values, rec_grand_mean.data, pen={'color': 'g', 'width': 3}, name=name)
    train_amps = train_amp([grand_train[0], grand_train[1]], pulse_offsets, '+')
    if ind_grand_mean is None:
        return None

    # NOTE(review): `legend` is not defined in this function; it must be a
    # module-level global set by the calling script.
    train_plots = summary_plot_train(ind_grand_mean, plot=summary_plots[0], color=color,
                                     name=(legend + ' 50 Hz induction'))
    train_plots = summary_plot_train(rec_grand_mean, plot=summary_plots[0], color=color)
    train_plots2 = summary_plot_train(ind_grand_mean_dec, plot=summary_plots[1], color=color,
                                      name=(legend + ' 50 Hz induction'))
    return train_plots, train_plots2, train_amps
Exemple #6
0
def first_pulse_features(pair, pulse_responses, pulse_response_amps):
    """Fit a PSP model to the average first-pulse response of *pair*.

    Parameters
    ----------
    pair :
        Database pair record. ``pair.connection_strength`` supplies the fit seed
        (``ic_fit_xoffset``) and ``synapse_type`` ('ex' or 'in').
    pulse_responses :
        Traces accepted by ``TraceList``; they are averaged before fitting. The
        constants below assume the stimulus sits about 10 ms into each trace
        (the first 10 ms is the baseline). TODO confirm with the caller.
    pulse_response_amps : sequence of float
        Per-response amplitudes used for the coefficient of variation.

    Returns
    -------
    dict
        ``ic_fit_amp``, ``ic_fit_latency`` (fit xoffset minus 10 ms),
        ``ic_fit_rise_time``, ``ic_fit_decay_tau``, ``ic_amp_cv``
        (std / mean of *pulse_response_amps*) and ``avg_psp`` (data of the
        baseline-subtracted average).

    Raises
    ------
    ValueError
        If the synapse type is neither 'ex' nor 'in'. This subclasses the plain
        ``Exception`` raised previously, so existing handlers still match.
    """
    avg_psp = TraceList(pulse_responses).mean()
    dt = avg_psp.dt
    # Baseline: the most frequent value over the first 10 ms.
    avg_psp_baseline = float_mode(avg_psp.data[:int(10e-3 / dt)])
    avg_psp_bsub = avg_psp.copy(data=avg_psp.data - avg_psp_baseline)
    lower_bound = -float('inf')
    upper_bound = float('inf')

    xoffset = pair.connection_strength.ic_fit_xoffset
    if xoffset is None:
        # Fallback seed of 14 ms. The previous `14 * 10e-3` evaluated to 0.14 s,
        # far outside the 12-19 ms rise window weighted below. 14 ms (4 ms after
        # the 10 ms stimulus point) is presumably what was intended.
        xoffset = 14e-3

    synapse_type = pair.connection_strength.synapse_type
    if synapse_type == 'ex':
        amp_sign = '+'
    elif synapse_type == 'in':
        amp_sign = '-'
    else:
        raise ValueError(
            'Synapse type is not defined, reconsider fitting this pair %s %d->%d'
            % (pair.expt_id, pair.pre_cell_id, pair.post_cell_id))

    # Fit weights: 10 everywhere, 0 over the stimulus artifact (10-12 ms) and
    # 30 over the steep PSP rise (12-19 ms).
    weight = np.full(len(avg_psp.data), 10.)
    weight[int(10e-3 / dt):int(12e-3 / dt)] = 0.
    weight[int(12e-3 / dt):int(19e-3 / dt)] = 30.

    psp_fits = fit_psp(avg_psp,
                       xoffset=(xoffset, lower_bound, upper_bound),
                       yoffset=(avg_psp_baseline, lower_bound, upper_bound),
                       sign=amp_sign,
                       weight=weight)

    # NOTE(review): for inhibitory pairs the mean amplitude is presumably
    # negative, which makes this CV negative. Confirm whether abs() is wanted.
    amp_cv = np.std(pulse_response_amps) / np.mean(pulse_response_amps)

    features = {
        'ic_fit_amp': psp_fits.best_values['amp'],
        'ic_fit_latency': psp_fits.best_values['xoffset'] - 10e-3,
        'ic_fit_rise_time': psp_fits.best_values['rise_time'],
        'ic_fit_decay_tau': psp_fits.best_values['decay_tau'],
        'ic_amp_cv': amp_cv,
        'avg_psp': avg_psp_bsub.data,
    }
    # TODO: add 'ic_fit_NRMSE' once psp_fits exposes nrmse().

    return features
 def clicked(sp, pts):
     """Show the traces attached to the first clicked point in a new plot window.

     Prints each trace's ``amp``. Then, for every trace, subtracts the median of
     its 0-1 ms slice, draws the result in translucent black, and overlays the
     mean of those traces in green. *sp* is unused; presumably it is the emitting
     scatter item passed by the click signal.
     """
     point_traces = pts[0].data()['traces']
     print([tr.amp for tr in point_traces])
     plot_widget = pg.plot()
     aligned = []
     for tr in point_traces:
         offset = np.median(tr.time_slice(0, 1e-3).data)
         aligned.append(tr.copy(data=tr.data - offset))
     for tr in aligned:
         plot_widget.plot(tr.time_values, tr.data, pen=(0, 0, 0, 50))
     average = TraceList(aligned).mean()
     plot_widget.plot(average.time_values, average.data, pen='g')
Exemple #8
0
 def plot_element_data(self,
                       pre_class,
                       post_class,
                       element,
                       field_name,
                       color='g',
                       trace_plt=None):
     """Build a mean line and a scatter of *field_name* values for one element.

     Parameters
     ----------
     pre_class, post_class :
         Unused here; kept for interface compatibility.
     element :
         ``element[field_name]`` must support ``.mean()`` and ``.items()``
         (presumably a pandas Series keyed by pair; TODO confirm).
     field_name : str
         Column to summarize. A leading 'ic' selects current-clamp data.
     color :
         Pen color. For a tuple/list color, alpha 150 is appended for the
         scatter brush. A string color (the default 'g') is used as-is; it
         previously crashed in ``color + (150,)``.
     trace_plt :
         Currently ignored (see the note below).

     Returns
     -------
     (pg.InfiniteLine, pg.ScatterPlotItem or None)
         A line at the mean of ``element[field_name]``, and a scatter of the
         non-NaN values (``None`` when there are none).
     """
     # NOTE(review): the caller's trace_plt is deliberately discarded, so the
     # trace-plotting branches below never run. Kept as in the original.
     trace_plt = None
     val = element[field_name].mean()
     line = pg.InfiniteLine(val,
                            pen={
                                'color': color,
                                'width': 2
                            },
                            movable=False)
     scatter = None
     baseline_window = int(db.default_sample_rate * 5e-3)
     is_ic = field_name.startswith('ic')
     values = []
     traces = []
     # Series.iteritems() was removed in pandas 2.0; items() works in every version.
     for pair, value in element[field_name].items():
         if np.isnan(value):
             continue
         if trace_plt is not None:
             cs = pair.connection_strength
             trace = cs.ic_average_response if is_ic else cs.vc_average_response
             x_offset = cs.ic_fit_latency if is_ic else cs.vc_fit_latency
             trace = format_trace(trace,
                                  baseline_window,
                                  x_offset,
                                  align='psp')
             trace_plt.plot(trace.time_values, trace.data)
             traces.append(trace)
         values.append(value)

     # Build the scatter once from all values instead of on every iteration.
     if values:
         y_values = pg.pseudoScatter(np.asarray(values, dtype=float),
                                     spacing=1)
         brush = color if isinstance(color, str) else tuple(color) + (150, )
         scatter = pg.ScatterPlotItem(symbol='o',
                                      brush=brush,
                                      pen='w',
                                      size=12)
         scatter.setData(values, y_values + 10.)

     if trace_plt is not None and traces:
         grand_trace = TraceList(traces).mean()
         trace_plt.plot(grand_trace.time_values,
                        grand_trace.data,
                        pen={
                            'color': color,
                            'width': 3
                        })
         units = 'V' if is_ic else 'A'
         trace_plt.setXRange(0, 20e-3)
         trace_plt.setLabels(left=('', units),
                             bottom=('Time from stimulus', 's'))
     return line, scatter
def first_pulse_features(pair, pulse_responses, pulse_response_amps):
    """Fit a PSP model to the average first-pulse response of *pair*.

    Parameters
    ----------
    pair :
        Database pair record. ``pair.connection_strength`` supplies the fit seed
        (``ic_fit_xoffset``) and ``synapse_type`` ('ex' or 'in').
    pulse_responses :
        Traces accepted by ``TraceList``; they are averaged before fitting. The
        constants below assume the stimulus sits about 10 ms into each trace
        (the first 10 ms is the baseline). TODO confirm with the caller.
    pulse_response_amps : sequence of float
        Per-response amplitudes used for the coefficient of variation.

    Returns
    -------
    dict
        ``ic_fit_amp``, ``ic_fit_latency`` (fit xoffset minus 10 ms),
        ``ic_fit_rise_time``, ``ic_fit_decay_tau``, ``ic_amp_cv``
        (std / mean of *pulse_response_amps*) and ``avg_psp`` (data of the
        baseline-subtracted average).

    Raises
    ------
    ValueError
        If the synapse type is neither 'ex' nor 'in'. This subclasses the plain
        ``Exception`` raised previously, so existing handlers still match.
    """
    avg_psp = TraceList(pulse_responses).mean()
    dt = avg_psp.dt
    # Baseline: the most frequent value over the first 10 ms.
    avg_psp_baseline = float_mode(avg_psp.data[:int(10e-3 / dt)])
    avg_psp_bsub = avg_psp.copy(data=avg_psp.data - avg_psp_baseline)
    lower_bound = -float('inf')
    upper_bound = float('inf')

    xoffset = pair.connection_strength.ic_fit_xoffset
    if xoffset is None:
        # Fallback seed of 14 ms. The previous `14*10e-3` evaluated to 0.14 s,
        # far outside the 12-19 ms rise window weighted below. 14 ms (4 ms after
        # the 10 ms stimulus point) is presumably what was intended.
        xoffset = 14e-3

    synapse_type = pair.connection_strength.synapse_type
    if synapse_type == 'ex':
        amp_sign = '+'
    elif synapse_type == 'in':
        amp_sign = '-'
    else:
        raise ValueError('Synapse type is not defined, reconsider fitting this pair %s %d->%d' %
                         (pair.expt_id, pair.pre_cell_id, pair.post_cell_id))

    # Fit weights: 10 everywhere, 0 over the stimulus artifact (10-12 ms) and
    # 30 over the steep PSP rise (12-19 ms).
    weight = np.full(len(avg_psp.data), 10.)
    weight[int(10e-3 / dt):int(12e-3 / dt)] = 0.
    weight[int(12e-3 / dt):int(19e-3 / dt)] = 30.

    psp_fits = fit_psp(avg_psp,
                       xoffset=(xoffset, lower_bound, upper_bound),
                       yoffset=(avg_psp_baseline, lower_bound, upper_bound),
                       sign=amp_sign,
                       weight=weight)

    # NOTE(review): for inhibitory pairs the mean amplitude is presumably
    # negative, which makes this CV negative. Confirm whether abs() is wanted.
    amp_cv = np.std(pulse_response_amps) / np.mean(pulse_response_amps)

    features = {'ic_fit_amp': psp_fits.best_values['amp'],
                'ic_fit_latency': psp_fits.best_values['xoffset'] - 10e-3,
                'ic_fit_rise_time': psp_fits.best_values['rise_time'],
                'ic_fit_decay_tau': psp_fits.best_values['decay_tau'],
                'ic_amp_cv': amp_cv,
                'avg_psp': avg_psp_bsub.data}
    # TODO: add 'ic_fit_NRMSE' once psp_fits exposes nrmse().

    return features
 def clicked(sp, pts):
     """Report classifier details for the first clicked point and plot its traces.

     Prints the point's rise time, amplitude, prediction and confidence, then
     the ``classifier.features`` values of each entry in ``'results'``. The
     point's traces each have the median of their 0-1 ms slice subtracted,
     are drawn in translucent black, and their mean is overlaid in green.
     *sp* is unused; presumably it is the emitting scatter item.
     """
     point_info = pts[0].data()
     print("-----------------------\nclicked:", point_info['rise_time'], point_info['amp'],
           point_info['prediction'], point_info['confidence'])
     for result in point_info['results']:
         print({feat: result[feat] for feat in classifier.features})
     point_traces = point_info['traces']
     plot_widget = pg.plot()
     aligned = []
     for tr in point_traces:
         offset = np.median(tr.time_slice(0, 1e-3).data)
         aligned.append(tr.copy(data=tr.data - offset))
     for tr in aligned:
         plot_widget.plot(tr.time_values, tr.data, pen=(0, 0, 0, 50))
     average = TraceList(aligned).mean()
     plot_widget.plot(average.time_values, average.data, pen='g')
def get_average_pulse_response(pair, desired_clamp='ic'):
    """Average the QC-passing first-pulse responses of *pair* and measure the PSP.

    Parameters
    ----------
    pair : multipatch_analysis.database.database.Pair
        Pair whose first-pulse responses are extracted.
    desired_clamp : str
        ``'ic'`` for current-clamp or ``'vc'`` for voltage-clamp sweeps.

    Returns
    -------
    tuple of 7
        ``(pulse_responses, pulse_ids, psp_amps_measured, freq, avg_psp,
        measured_relative_amp, measured_baseline)``. Every element is ``None``
        when no QC-passing sweeps are found.

        - pulse_responses: TraceList of responses, each starting
          ``time_before_spike`` before the spike (10 ms per earlier docs)
        - pulse_ids: pulse ids of *pulse_responses*
        - psp_amps_measured: amplitudes from the pulse_response table
        - freq: stimulation frequency of each response
        - avg_psp: average of *pulse_responses*
        - measured_relative_amp: amplitude relative to the baseline
        - measured_baseline: baseline value

        ``measure_amp`` uses a baseline window from the trace start to 1 ms
        before the spike, and a response window from 0.5 ms after the spike
        to the end.
    """
    pulse_responses, pulse_ids, psp_amps_measured, freq = extract_first_pulse_info_from_Pair_object(
        pair, desired_clamp=desired_clamp)

    if len(pulse_responses) == 0:
        return (None,) * 7

    avg_psp = TraceList(pulse_responses).mean()
    dt = avg_psp.dt
    baseline_window = [0, int((time_before_spike - 1.e-3) / dt)]
    response_window = [int((time_before_spike + .5e-3) / dt), -1]
    measured_relative_amp, measured_baseline = measure_amp(
        avg_psp.data, baseline_window, response_window)

    return (pulse_responses, pulse_ids, psp_amps_measured, freq, avg_psp,
            measured_relative_amp, measured_baseline)
def first_pulse_plot(expt_list, name=None, summary_plot=None, color=None, scatter=0):
    """Plot per-connection first-pulse averages and their grand mean.

    Connections are selected from *expt_list* when their pre/post cre types match
    the module-level ``cre_type`` pair. Connections are skipped when the
    amplitude estimate has the wrong sign for the presynaptic class (negative
    for ``EXCITATORY_CRE_TYPES``, positive for ``INHIBITORY_CRE_TYPES``) or when
    fewer than 10 sweeps are available.

    Each kept average is re-timed to ``t0=0``, baseline-subtracted with the
    ``float_mode`` of its first 10 ms, and plotted. When ``args.recip`` is set,
    reciprocally connected pairs are drawn in red.

    Returns
    -------
    * ``(avg_ests, summary_plots)`` on success;
    * ``(None, avg_ests, None, None)`` when no connection qualified. This
      shape differs from the success case and is kept for existing callers;
    * ``None`` if the grand mean is ``None``.
    """
    amp_plots = pg.plot()
    amp_plots.setLabels(left=('Vm', 'V'))
    amp_base_subtract = []
    avg_ests = []
    for expt in expt_list:
        for pre, post in expt.connections:
            pre_type = expt.cells[pre].cre_type
            if pre_type != cre_type[0] or expt.cells[post].cre_type != cre_type[1]:
                continue
            avg_est, avg_amp, n_sweeps = responses(expt, pre, post)
            if pre_type in EXCITATORY_CRE_TYPES and avg_est < 0:
                continue
            if pre_type in INHIBITORY_CRE_TYPES and avg_est > 0:
                continue
            if n_sweeps < 10:
                continue

            avg_amp.t0 = 0
            avg_ests.append(avg_est)
            base = float_mode(avg_amp.data[:int(10e-3 / avg_amp.dt)])
            bsub_data = avg_amp.data - base
            amp_base_subtract.append(avg_amp.copy(data=bsub_data))

            # The pair is reciprocal when the reverse connection (post -> pre) is
            # also in expt.connections. This membership test replaces the
            # equivalent manual search loop.
            if (len(expt.connections) > 1 and args.recip is True
                    and (post, pre) in expt.connections):
                amp_plots.plot(avg_amp.time_values, bsub_data, pen={'color': 'r', 'width': 1})
            else:
                amp_plots.plot(avg_amp.time_values, bsub_data)

            app.processEvents()

    if len(amp_base_subtract) == 0:
        print("No Traces")
        return None, avg_ests, None, None

    # '%s' formatting tolerates the default name=None (string concat did not).
    print('%s n = %d' % (name, len(amp_base_subtract)))
    grand_mean = TraceList(amp_base_subtract).mean()
    grand_est = np.mean(np.array(avg_ests))
    amp_plots.addLegend()
    amp_plots.plot(grand_mean.time_values, grand_mean.data, pen={'color': 'g', 'width': 3}, name=name)
    amp_plots.addLine(y=grand_est, pen={'color': 'g'})
    if grand_mean is None:
        return None

    # NOTE(review): `legend` is not defined in this function; it must be a
    # module-level global set by the calling script.
    print('%s Grand mean amplitude = %f' % (legend, grand_est))
    summary_plots = summary_plot_pulse(grand_mean, avg_ests, grand_est, labels=['Vm', 'V'], titles='Amplitude', i=scatter, plot=summary_plot, color=color, name=legend)
    return avg_ests, summary_plots
Exemple #13
0
 def plot_element_data(self,
                       pre_class,
                       post_class,
                       element,
                       field_name,
                       color='g',
                       trace_plt=None):
     """Summarize *field_name* for one element and plot its connected pairs' PSPs.

     ``element.agg(self.summary_stat)`` supplies the summary value, drawn as a
     horizontal line. When *trace_plt* is given, the current-clamp average
     response of every connected pair (``element['connected'] == True``) is
     formatted with ``format_trace(..., align='psp')`` and plotted. Their mean
     is then overlaid, labelled ``'<pre>-><post>, n=<count>'``.

     Returns
     -------
     (pg.InfiniteLine, None)
         The summary line. No scatter item is built here.
     """
     summary = element.agg(self.summary_stat)
     val = summary[field_name]['metric_summary']
     line = pg.InfiniteLine(val,
                            pen={
                                'color': color,
                                'width': 2
                            },
                            movable=False)
     scatter = None
     if trace_plt is None:
         # The default trace_plt=None used to fail with AttributeError on the
         # first .plot() call. With nothing to draw into, return the line only.
         return line, scatter

     baseline_window = int(db.default_sample_rate * 5e-3)
     traces = []
     # `== True` (not `is True`) is required for an element-wise pandas mask.
     connections = element[element['connected'] == True].index.tolist()  # noqa: E712
     for pair in connections:
         cs = pair.connection_strength
         trace = cs.ic_average_response
         if trace is None:
             continue
         trace = format_trace(trace,
                              baseline_window,
                              cs.ic_fit_xoffset,
                              align='psp')
         trace_plt.plot(trace.time_values, trace.data)
         traces.append(trace)

     # Only average when at least one trace was collected; an empty TraceList
     # has nothing meaningful to average.
     if traces:
         grand_trace = TraceList(traces).mean()
         name = ('%s->%s, n=%d' % (pre_class, post_class, len(traces)))
         trace_plt.plot(grand_trace.time_values,
                        grand_trace.data,
                        pen={
                            'color': color,
                            'width': 3
                        },
                        name=name)
     trace_plt.setXRange(0, 20e-3)
     trace_plt.setLabels(left=('', 'V'), bottom=('Time from stimulus', 's'))
     return line, scatter
def trace_avg(response_list):
    # Deliberately no docstring: this helper overwrites ``t0`` on the caller's
    # own trace objects, so it is not meant for general reuse.
    #
    # Sets ``t0 = 0`` in place on every trace in *response_list* (so
    # TraceList().mean() lines them up), averages them, and returns
    # ``bsub(average)``: a copy of the average with the baseline-subtracted
    # waveform in ``.data``.
    for resp in response_list:
        resp.t0 = 0  # in-place alignment; the caller's traces are modified
    return bsub(TraceList(response_list).mean())
Exemple #15
0
 def get_tseries(self,
                 series,
                 bsub=True,
                 align='stim',
                 bsub_window=(-3e-3, 0)):
     """Collect one time series per entry of ``self.srs`` as a TraceList.

     Parameters
     ----------
     series : str
         "stim", "pre", or "post". Selects the ``<series>_tseries`` attribute
         of each entry.
     bsub : bool
         If true, subtract the median of the samples in *bsub_window*. The
         window is given relative to the stimulus pulse onset.
     align : str or None
         Sets the new time origin:
         ``'stim'`` uses the stimulus pulse onset; ``'pre'`` uses the max dV/dt
         time of the first presynaptic spike; ``'post'`` raises
         NotImplementedError; ``None`` leaves timing unchanged. A falsy
         alignment time is treated as 0. Any other value raises ValueError
         (only when ``self.srs`` is non-empty).
     bsub_window : (float, float)
         Start/stop offsets from the stimulus onset for the baseline window.

     Returns
     -------
     TraceList
     """
     assert series in (
         'stim', 'pre',
         'post'), "series must be one of 'stim', 'pre', or 'post'"
     attr_name = series + '_tseries'
     collected = []
     for resp in self.srs:
         ts = getattr(resp, attr_name)

         if bsub:
             onset = resp.stim_pulse.onset_time
             window = ts.time_slice(onset + bsub_window[0], onset + bsub_window[1])
             ts = ts - np.median(window.data)

         if align is not None:
             if align == 'stim':
                 ref_time = resp.stim_pulse.onset_time
             elif align == 'pre':
                 ref_time = resp.stim_pulse.spikes[0].max_dvdt_time
             elif align == 'post':
                 raise NotImplementedError()
             else:
                 raise ValueError("invalid time alignment mode %r" % align)
             ts = ts.copy(t0=ts.t0 - (ref_time or 0))

         collected.append(ts)
     return TraceList(collected)
def analyze_pair_connectivity(amps, sign=None):
    """Given response strength records for a single pair, generate summary
    statistics characterizing strength, latency, and connectivity.

    Parameters
    ----------
    amps : dict
        Contains foreground and background strength analysis records
        (see input format below).
    sign : None, -1, or +1
        If None, automatically decide whether to treat this connection as
        inhibitory or excitatory. Otherwise a value > 0 means excitatory and
        anything else means inhibitory.

    Input must have the following structure::

        amps = {
            ('ic', 'fg'): recs,
            ('ic', 'bg'): recs,
            ('vc', 'fg'): recs,
            ('vc', 'bg'): recs,
        }

    Each *recs* must be a structured array containing the fields returned by
    get_amps() and get_baseline_amps().

    The overall strategy is:

    1. Decide whether to treat this pair as excitatory or inhibitory, based on
       differences between foreground and background amplitude measurements.
    2. Generate mean and stdev for amplitudes, deconvolved amplitudes, and
       deconvolved latencies.
    3. Generate t-test and KS-test p values describing the differences between
       foreground and background distributions for amplitude, deconvolved
       amplitude, and deconvolved latency.

    Returns
    -------
    dict
        Fields for the new DB record. Always contains ``'synapse_type'``
        ('ex'/'in') and ``'<mode>_n_samples'`` for mode in ic/vc.

        When both foreground and background records passed QC for a mode, it
        also contains:

        - crosstalk means;
        - ``_mean`` / ``_stdev`` / ``_ttest`` / ``_ks2samp`` statistics for
          amp, deconv_amp and latency (p values passed through
          ``norm_pvalue``);
        - the average response and its t0;
        - ``'<mode>_fit_*'`` parameters, when the PSP fit succeeds.
    """
    requested_sign = sign
    fields = {}  # used to fill the new DB record

    # Which QC flag applies depends on clamp mode and deflection direction: a
    # positive deflection counts as excitatory in current clamp but inhibitory
    # in voltage clamp.
    qc_fields = {
        'vc': {'pos': 'in_qc_pass', 'neg': 'ex_qc_pass'},
        'ic': {'pos': 'ex_qc_pass', 'neg': 'in_qc_pass'},
    }

    # Use KS p value to check for differences between foreground and background
    qc_amps = {}
    ks_pvals = {}
    amp_diffs = {}
    for clamp_mode in ('ic', 'vc'):
        clamp_mode_fg = amps[clamp_mode, 'fg']
        clamp_mode_bg = amps[clamp_mode, 'bg']
        if len(clamp_mode_fg) == 0 or len(clamp_mode_bg) == 0:
            continue
        for direction in ('pos', 'neg'):
            # Separate into positive/negative tests and filter out responses that failed qc
            qc_field = qc_fields[clamp_mode][direction]
            fg = clamp_mode_fg[clamp_mode_fg[qc_field]]
            bg = clamp_mode_bg[clamp_mode_bg[qc_field]]
            qc_amps[direction, clamp_mode, 'fg'] = fg
            qc_amps[direction, clamp_mode, 'bg'] = bg
            if len(fg) == 0 or len(bg) == 0:
                continue

            # Compare deconvolved amplitude distributions, fg vs bg
            fg_dec = fg[direction + '_dec_amp']
            bg_dec = bg[direction + '_dec_amp']
            ks_pvals[direction, clamp_mode] = scipy.stats.ks_2samp(fg_dec, bg_dec).pvalue
            amp_diffs[direction, clamp_mode] = np.mean(fg_dec) - np.mean(bg_dec)

    if requested_sign is None:
        # Decide whether to treat this connection as excitatory or inhibitory.
        #   strategy: accumulate evidence for either possibility by checking
        #   the ks p-values for each sign/clamp mode and the direction of the deflection
        is_exc = 0
        for direction in ('pos', 'neg'):
            for mode in ('ic', 'vc'):
                ks = ks_pvals.get((direction, mode), None)
                if ks is None:
                    continue
                dif_sign = 1 if amp_diffs[direction, mode] > 0 else -1
                if mode == 'vc':
                    # Voltage-clamp currents deflect opposite to current-clamp potentials.
                    dif_sign *= -1
                # norm_pvalue turns the p value into a usable scale factor.
                is_exc += dif_sign * norm_pvalue(ks)
    else:
        is_exc = requested_sign

    if is_exc > 0:
        fields['synapse_type'] = 'ex'
        signs = {'ic': 'pos', 'vc': 'neg'}
    else:
        fields['synapse_type'] = 'in'
        signs = {'ic': 'neg', 'vc': 'pos'}

    # compute the rest of statistics for only positive or negative deflections
    for clamp_mode in ('ic', 'vc'):
        direction = signs[clamp_mode]
        fg = qc_amps.get((direction, clamp_mode, 'fg'))
        bg = qc_amps.get((direction, clamp_mode, 'bg'))
        if fg is None or bg is None or len(fg) == 0 or len(bg) == 0:
            fields[clamp_mode + '_n_samples'] = 0
            continue

        fields[clamp_mode + '_n_samples'] = len(fg)
        fields[clamp_mode + '_crosstalk_mean'] = np.mean(fg['crosstalk'])
        fields[clamp_mode + '_base_crosstalk_mean'] = np.mean(bg['crosstalk'])

        # measure mean, stdev, and statistical differences between
        # fg and bg for each measurement
        for val, field in [('amp', 'amp'), ('deconv_amp', 'dec_amp'), ('latency', 'dec_latency')]:
            f = fg[direction + '_' + field]
            b = bg[direction + '_' + field]
            fields[clamp_mode + '_' + val + '_mean'] = np.mean(f)
            fields[clamp_mode + '_' + val + '_stdev'] = np.std(f)
            fields[clamp_mode + '_base_' + val + '_mean'] = np.mean(b)
            fields[clamp_mode + '_base_' + val + '_stdev'] = np.std(b)
            # statistical tests comparing fg vs bg, normalized with norm_pvalue
            # so they are easier to plot and to use as classifier inputs
            tt_pval = scipy.stats.ttest_ind(f, b, equal_var=False).pvalue
            ks_pval = scipy.stats.ks_2samp(f, b).pvalue
            fields[clamp_mode + '_' + val + '_ttest'] = norm_pvalue(tt_pval)
            fields[clamp_mode + '_' + val + '_ks2samp'] = norm_pvalue(ks_pval)

        ### generate the average response and psp fit
        fg_traces = TraceList()
        for rec in fg:
            t0 = rec['response_start_time'] - rec['max_dvdt_time']   # time-align to presynaptic spike
            fg_traces.append(Trace(rec['data'], sample_rate=db.default_sample_rate, t0=t0))

        fg_avg = fg_traces.mean()
        base_rgn = fg_avg.time_slice(-6e-3, 0)
        base = float_mode(base_rgn.data)
        fields[clamp_mode + '_average_response'] = fg_avg.data
        fields[clamp_mode + '_average_response_t0'] = fg_avg.t0
        fields[clamp_mode + '_average_base_stdev'] = base_rgn.std()

        fit_sign = {'pos': '+', 'neg': '-'}[direction]
        fg_bsub = fg_avg.copy(data=fg_avg.data - base)  # remove base to help fitting
        try:
            fit = fit_psp(fg_bsub, mode=clamp_mode, sign=fit_sign, xoffset=(1e-3, 0, 6e-3), yoffset=(0, None, None), rise_time_mult_factor=4)
            for param, value in fit.best_values.items():
                fields['%s_fit_%s' % (clamp_mode, param)] = value
            fields[clamp_mode + '_fit_yoffset'] = fit.best_values['yoffset'] + base
            fields[clamp_mode + '_fit_nrmse'] = fit.nrmse()
        except Exception:
            # Best-effort: report the failure and keep the statistics computed so
            # far. `except Exception` (not a bare except) lets KeyboardInterrupt
            # and SystemExit propagate.
            print("Error in PSP fit:")
            sys.excepthook(*sys.exc_info())

    return fields
Exemple #17
0
                                train_response,
                                freqs,
                                holding,
                                thresh=sweep_threshold,
                                ind_dict=grand_induction,
                                offset_dict=offset_ind)
                            grand_recovery, offset_rec = recovery_summary(
                                train_response,
                                t_rec,
                                holding,
                                thresh=sweep_threshold,
                                rec_dict=grand_recovery,
                                offset_dict=offset_rec)

        if len(grand_pulse_response) > 0:
            grand_pulse_trace = TraceList(grand_pulse_response).mean()
            p2 = trace_plot(grand_pulse_trace,
                            color=avg_color,
                            plot=p2,
                            x_range=[0, 27e-3],
                            name=('n = %d' % len(grand_pulse_response)))
            if len(grand_induction) > 0:
                for f, freq in enumerate(freqs):
                    if freq in grand_induction:
                        offset = offset_ind[freq]
                        ind_pass_qc = train_qc(grand_induction[freq],
                                               offset,
                                               amp=amp_thresh,
                                               sign=sign)
                        n = len(ind_pass_qc[0])
                        if n > 0:
Exemple #18
0
def trace_avg(response_list):
    """Return the baseline-subtracted average of a list of traces.

    Each trace is re-referenced to ``t0=0`` before averaging so traces taken
    at different absolute times line up sample-for-sample. Copies are
    averaged, so the caller's traces are left unmodified (the previous
    version overwrote ``t0`` on every input trace as a side effect).

    Parameters
    ----------
    response_list : list of Trace
        Traces to average.

    Returns
    -------
    The result of ``bsub()`` applied to the averaged trace.
    """
    aligned = [trace.copy(t0=0) for trace in response_list]
    avg_trace = TraceList(aligned).mean()
    return bsub(avg_trace)
                    offset_dict=pulse_offset_rec,
                    uid=(expt.uid, pre, post))
    for f, freq in enumerate(freqs):
        if freq not in induction_grand.keys():
            print("%d Hz not represented in data set for %s" % (freq, c_type))
            continue
        ind_offsets = pulse_offset_ind[freq]
        qc_plot.clear()
        ind_pass_qc = train_qc(induction_grand[freq],
                               ind_offsets,
                               amp=qc_params[1][c],
                               sign=qc_params[0],
                               plot=qc_plot)
        n_synapses = len(ind_pass_qc[0])
        if n_synapses > 0:
            induction_grand_trace = TraceList(ind_pass_qc[0]).mean()
            ind_rec_grand_trace = TraceList(ind_pass_qc[1]).mean()
            ind_amp = train_amp(ind_pass_qc, ind_offsets, '+')
            ind_amp_grand = np.nanmean(ind_amp, 0)

            if f == 0:
                ind_plot[f, c].setTitle(connection_types[c])
                type = pg.LabelItem('%s -> %s' % connection_types[c])
                type.setParentItem(summary_plot[c, 0])
                type.setPos(50, 0)
            if c == 0:
                label = pg.LabelItem('%d Hz Induction' % freq)
                label.setParentItem(ind_plot[f, c].vb)
                label.setPos(50, 0)
                summary_plot[c, 0].setTitle('Induction')
            ind_plot[f, c].addLegend()
Exemple #20
0
 def plot_element_data(self,
                       pre_class,
                       post_class,
                       element,
                       field_name,
                       color='g',
                       trace_plt=None):
     """Plot per-pair traces and a value scatter for one matrix element.

     For each pair in ``element[field_name]`` that is a synapse with a non-NaN
     value, the pair's average response trace is aligned with
     ``format_trace`` and drawn on ``trace_plt``. The grand average of those
     traces is drawn on top. Each trace item and scatter point is registered
     in ``self.pair_items`` under the pair id and wired to the click handlers.

     Parameters
     ----------
     pre_class, post_class
         Used only to build the legend label of the grand-average trace.
     element
         ``element[field_name]`` must support ``.mean()`` and ``.items()``
         yielding ``(pair, value)``. Presumably a pandas Series indexed by
         pair; confirm against the caller.
     field_name : str
         Must end with ``'all'`` (connection-strength average response) or
         ``'first_pulse'`` (average first-pulse fit data). An ``'ic'`` prefix
         selects current-clamp data; anything else selects voltage clamp.
         Pairs are skipped for any other suffix.
     color
         Any color accepted by ``pg.mkColor``. The scatter brush uses the
         same color with alpha 150.
     trace_plt
         Plot item that receives the traces (required despite the ``None``
         default).

     Returns
     -------
     (line, scatter)
         An ``InfiniteLine`` at the element mean and the ``ScatterPlotItem``
         of per-pair values. Neither is added to a plot here.
     """
     is_ic = field_name.startswith('ic')
     mean_value = element[field_name].mean()
     line = pg.InfiniteLine(mean_value,
                            pen={
                                'color': color,
                                'width': 2
                            },
                            movable=False)
     baseline_window = int(db.default_sample_rate * 5e-3)
     values = []
     traces = []
     point_data = []
     # Series.items() exists in every pandas version; iteritems() was removed in pandas 2.0.
     for pair, value in element[field_name].items():
         if pair.synapse is not True or np.isnan(value):
             continue
         if field_name.endswith('all'):
             cs = pair.connection_strength
             trace = cs.ic_average_response if is_ic else cs.vc_average_response
             x_offset = cs.ic_fit_xoffset if is_ic else cs.vc_fit_xoffset
         elif field_name.endswith('first_pulse'):
             fpf = pair.avg_first_pulse_fit
             if fpf is None:
                 continue
             trace = fpf.ic_avg_psp_data if is_ic else fpf.vc_avg_psp_data
             x_offset = fpf.ic_latency if is_ic else fpf.vc_latency
         else:
             # Unsupported suffix: nothing to plot. Previously this fell
             # through to an unbound ``trace`` and raised UnboundLocalError.
             continue
         if trace is None:
             continue
         values.append(value)
         trace = format_trace(trace, baseline_window, x_offset, align='psp')
         trace_item = trace_plt.plot(trace.time_values, trace.data)
         point_data.append(pair)
         trace_item.pair = pair
         trace_item.curve.setClickable(True)
         trace_item.sigClicked.connect(self.trace_plot_clicked)
         traces.append(trace)
         self.pair_items[pair.id] = [trace_item]

     # Copy the color before changing alpha so no shared/named color is mutated.
     # This also works for string colors like the default 'g', for which the
     # old ``color + (150,)`` raised TypeError.
     brush_color = pg.QtGui.QColor(pg.mkColor(color))
     brush_color.setAlpha(150)
     y_values = pg.pseudoScatter(np.asarray(values, dtype=float), spacing=1)
     scatter = pg.ScatterPlotItem(symbol='o',
                                  brush=brush_color,
                                  pen='w',
                                  size=12)
     scatter.setData(values, y_values + 10., data=point_data)
     for point in scatter.points():
         self.pair_items[point.data().id].append(point)
     scatter.sigClicked.connect(self.scatter_plot_clicked)

     # Averaging an empty TraceList would fail; only draw the grand trace when
     # at least one pair contributed.
     if traces:
         grand_trace = TraceList(traces).mean()
         name = ('%s->%s, n=%d' % (pre_class, post_class, len(traces)))
         trace_plt.plot(grand_trace.time_values,
                        grand_trace.data,
                        pen={
                            'color': color,
                            'width': 3
                        },
                        name=name)
     units = 'V' if is_ic else 'A'
     trace_plt.setXRange(0, 20e-3)
     trace_plt.setLabels(left=('', units),
                         bottom=('Time from stimulus', 's'))
     return line, scatter
Exemple #21
0
                        else:
                            trace_color = (0, 0, 0, 30)
                        trace_plot(avg_trace, trace_color, plot=synapse_plot[c, 0], x_range=[0, 27e-3])
                        app.processEvents()
#                    decay_response = response_filter(pulse_response, freq_range=[0, 20], holding_range=holding)
#                    qc_list = pulse_qc(response_subset, baseline=2, pulse=None, plot=qc_plot)
#                    if len(qc_list) >= sweep_threshold:
#                        avg_trace, avg_amp, amp_sign, peak_t = get_amplitude(qc_list)
#                        if amp_sign is '-':
#                            continue
#                        psp_fits = fit_psp(avg_trace, sign=amp_sign, yoffset=0, amp=avg_amp, method='leastsq', stacked = False,  fit_kws={})
#                        grand_response[type[0]]['decay'].append(psp_fits.best_values['decay_tau'])
    if len(grand_response[type[0]]['trace']) == 0:
        continue
    if len(grand_response[type[0]]['trace']) > 1:
        grand_trace = TraceList(grand_response[type[0]]['trace']).mean()
        grand_trace.t0 = 0
    else:
        grand_trace = grand_response[type[0]]['trace'][0]
    n_synapses = len(grand_response[type[0]]['trace'])
    trace_plot(grand_trace, color={'color': color, 'width': 2}, plot=synapse_plot[c, 0], x_range=[0, 27e-3],
               name=('%s, n = %d' % (connection_types[c], n_synapses)))
    synapse_plot[c, 0].hideAxis('bottom')
    # all_amps = np.hstack(np.asarray(grand_response[cre_type[0]]['fail_rate']))
    # y, x = np.histogram(all_amps, bins=np.linspace(0, 2e-3, 40))
    # synapse_plot[c, 1].plot(x, y, stepMode=True, fillLevel=0, brush='k')
    # synapse_plot[c, 1].setLabels(bottom=('Vm', 'V'))
    # synapse_plot[c, 1].setXRange(0, 2e-3)
    print ('%s kinetics n = %d' % (type[0], len(grand_response[type[0]]['latency'])))
    feature_list = (grand_response[type[0]]['amp'], grand_response[type[0]]['CV'], grand_response[type[0]]['latency'],
                    grand_response[type[0]]['rise'])
     label = pg.LabelItem('%d ms Recovery' % delta)
     label.setParentItem(rec_plot[t, c].vb)
     label.setPos(50, 0)
     summary_plot[c, 1].setTitle('Recovery')
 rec_plot[t, c].addLegend()
 [rec_plot[t, c].plot(ind.time_values, ind.data, pen=trace_color) for ind in rec_pass_qc[0]]
 [rec_plot[t, c].plot(rec.time_values, rec.data, pen=trace_color) for rec in rec_pass_qc[1]]
 rec_plot[t, c].plot(rec_ind_grand_trace.time_values, rec_ind_grand_trace.data, pen={'color': color, 'width': 2},
                     name=("n = %d" % n_synapses))
 rec_plot[t, c].plot(recovery_grand_trace.time_values, recovery_grand_trace.data, pen={'color': color, 'width': 2})
 rec_plot[t, c].setLabels(left=('Vm', 'V'))
 rec_plot[t, c].setLabels(bottom=('t', 's'))
 if deconv is True:
     rec_deconv = deconv_train(rec_pass_qc[:2])
     rec_deconv_grand = TraceList(rec_deconv[0]).mean()
     rec_ind_deconv_grand = TraceList(rec_deconv[1]).mean()
     #log_rec_plt.plot(rec_deconv_grand.time_values, rec_deconv_grand.data,
     #                    pen={'color': color, 'width': 2})
     rec_deconv_ind_grand2 = rec_ind_deconv_grand.copy(t0=delta + 0.2)
     #log_rec_plt.plot(rec_deconv_ind_grand2.time_values, rec_deconv_ind_grand2.data,
     #                    pen={'color': color, 'width': 2})
     [deconv_rec_plot[t, c].plot(ind.time_values, ind.data, pen=trace_color) for ind in rec_deconv[0]]
     [deconv_rec_plot[t, c].plot(rec.time_values, rec.data, pen=trace_color) for rec in rec_deconv[1]]
     deconv_rec_plot[t, c].plot(rec_ind_deconv_grand.time_values, rec_ind_deconv_grand.data,
                         pen={'color': color, 'width': 2}, name=("n = %d" % n_synapses))
     deconv_rec_plot[t, c].plot(rec_deconv_grand.time_values, rec_deconv_grand.data,
                         pen={'color': color, 'width': 2})
 summary_plot[c, 1].setLabels(left=('Norm Amp', ''))
 summary_plot[c, 1].setLabels(bottom=('Pulse Number', ''))
 f_color = pg.hsvColor(hue=hue, sat=float(t+0.5) / len(rec_t), val=1)
 summary_plot[c, 1].plot(rec_amp_grand/rec_amp_grand[0], name=('  %d ms' % delta), pen=f_color, symbol=symbols[t],
def train_response_plot(expt_list,
                        name=None,
                        summary_plots=None,
                        color=None):
    """Plot 50 Hz train responses for one connection type and summarize them.

    Only connections in ``expt_list`` whose pre/post cre types match the
    module-level ``cre_type`` pair are used. For each, the 50 Hz induction and
    recovery responses are averaged. Connections with a stimulus artifact
    above 0.03e-3, or with 5 or fewer filtered sweeps, are skipped.
    Per-connection averages and their grand means are drawn in a new plot
    window, and the grand means are passed to ``summary_plot_train``.

    Parameters
    ----------
    expt_list : iterable
        Experiments providing ``connections``, ``cells`` and ``nwb_file``.
    name : str or None
        Label used in the printed count and in the plot legend.
    summary_plots : sequence of two plots, optional
        Targets for the raw (index 0) and deconvolved (index 1) summary plots.
        Defaults to ``[None, None]``.
    color
        Color forwarded to ``summary_plot_train``.

    Returns
    -------
    tuple or None
        ``(train_plots, train_plots2, train_amps)`` when at least one
        connection contributed, otherwise ``None``.
    """
    # Build a fresh default per call instead of sharing a mutable default list.
    if summary_plots is None:
        summary_plots = [None, None]
    grand_train = [[], []]
    train_plots = pg.plot()
    train_plots.setLabels(left=('Vm', 'V'))
    tau = 15e-3  # time constant passed to exp_deconvolve
    lp = 1000  # cutoff passed to bessel_filter (presumably Hz)
    for expt in expt_list:
        for pre, post in expt.connections:
            if expt.cells[pre].cre_type == cre_type[0] and expt.cells[
                    post].cre_type == cre_type[1]:
                print('Processing experiment: %s' % (expt.nwb_file))

                train_responses, artifact = get_response(expt,
                                                         pre,
                                                         post,
                                                         analysis_type='train')
                if artifact > 0.03e-3:
                    continue

                train_filter = response_filter(train_responses['responses'],
                                               freq_range=[50, 50],
                                               train=0,
                                               delta_t=250)
                pulse_offsets = response_filter(
                    train_responses['pulse_offsets'],
                    freq_range=[50, 50],
                    train=0,
                    delta_t=250)

                if len(train_filter[0]) > 5:
                    ind_avg = TraceList(train_filter[0]).mean()
                    rec_avg = TraceList(train_filter[1]).mean()
                    rec_avg.t0 = 0.3
                    grand_train[0].append(ind_avg)
                    grand_train[1].append(rec_avg)
                    train_plots.plot(ind_avg.time_values, ind_avg.data)
                    train_plots.plot(rec_avg.time_values, rec_avg.data)
                    app.processEvents()
    if len(grand_train[0]) != 0:
        # '%s' formatting also tolerates name=None; string concatenation raised TypeError.
        print('%s n = %d' % (name, len(grand_train[0])))
        ind_grand_mean = TraceList(grand_train[0]).mean()
        rec_grand_mean = TraceList(grand_train[1]).mean()
        ind_grand_mean_dec = bessel_filter(exp_deconvolve(ind_grand_mean, tau),
                                           lp)
        train_plots.addLegend()
        train_plots.plot(ind_grand_mean.time_values,
                         ind_grand_mean.data,
                         pen={
                             'color': 'g',
                             'width': 3
                         },
                         name=name)
        train_plots.plot(rec_grand_mean.time_values,
                         rec_grand_mean.data,
                         pen={
                             'color': 'g',
                             'width': 3
                         },
                         name=name)
        # NOTE(review): ``pulse_offsets`` comes from the last connection that
        # passed the artifact check, not necessarily one added to grand_train.
        train_amps = train_amp([grand_train[0], grand_train[1]], pulse_offsets,
                               '+')
        if ind_grand_mean is not None:
            # NOTE(review): ``legend`` is not defined in this function and must be
            # a module-level global; ``name`` may have been intended. Confirm.
            train_plots = summary_plot_train(ind_grand_mean,
                                             plot=summary_plots[0],
                                             color=color,
                                             name=(legend +
                                                   ' 50 Hz induction'))
            train_plots = summary_plot_train(rec_grand_mean,
                                             plot=summary_plots[0],
                                             color=color)
            train_plots2 = summary_plot_train(ind_grand_mean_dec,
                                              plot=summary_plots[1],
                                              color=color,
                                              name=(legend +
                                                    ' 50 Hz induction'))
            return train_plots, train_plots2, train_amps
    else:
        print("No Traces")
        return None
def analyze_pair_connectivity(amps, sign=None):
    """Given response strength records for a single pair, generate summary
    statistics characterizing strength, latency, and connectivity.

    Parameters
    ----------
    amps : dict
        Contains foreground and background strength analysis records
        (see input format below)
    sign : None, -1, or +1
        If None, then automatically determine whether to treat this connection as
        inhibitory or excitatory.

    Input must have the following structure::

        amps = {
            ('ic', 'fg'): recs,
            ('ic', 'bg'): recs,
            ('vc', 'fg'): recs,
            ('vc', 'bg'): recs,
        }

    Where each *recs* must be a structured array containing fields as returned
    by get_amps() and get_baseline_amps().

    The overall strategy here is:

    1. Make an initial decision on whether to treat this pair as excitatory or
       inhibitory, based on differences between foreground and background amplitude
       measurements
    2. Generate mean and stdev for amplitudes, deconvolved amplitudes, and deconvolved
       latencies
    3. Generate KS test p values describing the differences between foreground
       and background distributions for amplitude, deconvolved amplitude, and
       deconvolved latency

    Returns
    -------
    fields : dict
        Values for a new DB record.

        Always present:

        - ``'synapse_type'``: ``'ex'`` or ``'in'``.
        - ``'ic_n_samples'`` and ``'vc_n_samples'``.

        For each clamp mode with QC-passing foreground and background records
        of the chosen deflection sign, it also contains:

        - crosstalk means;
        - mean, stdev, t-test and KS statistics for amplitude, deconvolved
          amplitude and latency;
        - the average response.

        If the PSP fit succeeds, the fit parameters and NRMSE are included.
        If it fails, the traceback is printed and the fit fields are omitted.
    """
    requested_sign = sign
    fields = {}  # used to fill the new DB record

    # QC flag used for each (clamp mode, deflection polarity). A positive
    # deflection counts as excitatory in current clamp but inhibitory in
    # voltage clamp.
    qc_field_for = {
        'vc': {'pos': 'in_qc_pass', 'neg': 'ex_qc_pass'},
        'ic': {'pos': 'ex_qc_pass', 'neg': 'in_qc_pass'},
    }

    # Use KS p value to check for differences between foreground and background
    qc_amps = {}
    ks_pvals = {}
    amp_diffs = {}
    for clamp_mode in ('ic', 'vc'):
        clamp_mode_fg = amps[clamp_mode, 'fg']
        clamp_mode_bg = amps[clamp_mode, 'bg']
        if len(clamp_mode_fg) == 0 or len(clamp_mode_bg) == 0:
            continue
        for polarity in ('pos', 'neg'):
            # Separate into positive/negative tests and filter out responses that failed qc
            qc_field = qc_field_for[clamp_mode][polarity]
            fg = clamp_mode_fg[clamp_mode_fg[qc_field]]
            bg = clamp_mode_bg[clamp_mode_bg[qc_field]]
            qc_amps[polarity, clamp_mode, 'fg'] = fg
            qc_amps[polarity, clamp_mode, 'bg'] = bg
            if len(fg) == 0 or len(bg) == 0:
                continue

            # Compare the deconvolved amplitude distributions
            fg_amp = fg[polarity + '_dec_amp']
            bg_amp = bg[polarity + '_dec_amp']
            ks_pvals[polarity, clamp_mode] = scipy.stats.ks_2samp(fg_amp, bg_amp).pvalue
            amp_diffs[polarity, clamp_mode] = np.mean(fg_amp) - np.mean(bg_amp)

    if requested_sign is None:
        # Decide whether to treat this connection as excitatory (>0) or
        # inhibitory. Every available (polarity, mode) comparison votes with
        # the direction of its fg-bg mean difference (flipped in voltage
        # clamp), weighted by its normalized KS p value.
        is_exc = 0
        for polarity in ('pos', 'neg'):
            for mode in ('ic', 'vc'):
                ks = ks_pvals.get((polarity, mode))
                if ks is None:
                    continue
                weight = norm_pvalue(ks)
                direction = 1 if amp_diffs[polarity, mode] > 0 else -1
                if mode == 'vc':
                    direction = -direction
                is_exc += direction * weight
    else:
        is_exc = requested_sign

    if is_exc > 0:
        fields['synapse_type'] = 'ex'
        signs = {'ic': 'pos', 'vc': 'neg'}
    else:
        fields['synapse_type'] = 'in'
        signs = {'ic': 'neg', 'vc': 'pos'}

    # compute the rest of statistics for only positive or negative deflections
    for clamp_mode in ('ic', 'vc'):
        polarity = signs[clamp_mode]
        fg = qc_amps.get((polarity, clamp_mode, 'fg'))
        bg = qc_amps.get((polarity, clamp_mode, 'bg'))
        if fg is None or bg is None or len(fg) == 0 or len(bg) == 0:
            fields[clamp_mode + '_n_samples'] = 0
            continue

        fields[clamp_mode + '_n_samples'] = len(fg)
        fields[clamp_mode + '_crosstalk_mean'] = np.mean(fg['crosstalk'])
        fields[clamp_mode + '_base_crosstalk_mean'] = np.mean(bg['crosstalk'])

        # measure mean, stdev, and statistical differences between
        # fg and bg for each measurement
        for val, field in [('amp', 'amp'), ('deconv_amp', 'dec_amp'), ('latency', 'dec_latency')]:
            f = fg[polarity + '_' + field]
            b = bg[polarity + '_' + field]
            fields[clamp_mode + '_' + val + '_mean'] = np.mean(f)
            fields[clamp_mode + '_' + val + '_stdev'] = np.std(f)
            fields[clamp_mode + '_base_' + val + '_mean'] = np.mean(b)
            fields[clamp_mode + '_base_' + val + '_stdev'] = np.std(b)
            # statistical tests comparing fg vs bg, stored as norm_pvalue()
            # of each p value
            tt_pval = scipy.stats.ttest_ind(f, b, equal_var=False).pvalue
            ks_pval = scipy.stats.ks_2samp(f, b).pvalue
            fields[clamp_mode + '_' + val + '_ttest'] = norm_pvalue(tt_pval)
            fields[clamp_mode + '_' + val + '_ks2samp'] = norm_pvalue(ks_pval)

        ### generate the average response and psp fit
        fg_traces = TraceList()
        for rec in fg:
            t0 = rec['response_start_time'] - rec['max_dvdt_time']   # time-align to presynaptic spike
            fg_traces.append(Trace(rec['data'], sample_rate=db.default_sample_rate, t0=t0))

        fg_avg = fg_traces.mean()
        base_rgn = fg_avg.time_slice(-6e-3, 0)
        base = float_mode(base_rgn.data)
        fields[clamp_mode + '_average_response'] = fg_avg.data
        fields[clamp_mode + '_average_response_t0'] = fg_avg.t0
        fields[clamp_mode + '_average_base_stdev'] = base_rgn.std()

        fit_sign = {'pos': '+', 'neg': '-'}[polarity]
        fg_bsub = fg_avg.copy(data=fg_avg.data - base)  # remove base to help fitting
        try:
            fit = fit_psp(fg_bsub, mode=clamp_mode, sign=fit_sign, xoffset=(1e-3, 0, 6e-3), yoffset=(0, None, None), rise_time_mult_factor=4)
            for param, param_val in fit.best_values.items():
                fields['%s_fit_%s' % (clamp_mode, param)] = param_val
            # the fit ran on baseline-subtracted data; add the baseline back
            fields[clamp_mode + '_fit_yoffset'] = fit.best_values['yoffset'] + base
            fields[clamp_mode + '_fit_nrmse'] = fit.nrmse()
        except Exception:
            # Best effort: report the failure and keep the statistics gathered
            # so far. Catching Exception (not a bare except) lets
            # KeyboardInterrupt/SystemExit propagate.
            print("Error in PSP fit:")
            sys.excepthook(*sys.exc_info())

    return fields
        join patch_clamp_recording post_pcrec on post_pcrec.recording_id=post_rec.id
        join multi_patch_probe on multi_patch_probe.patch_clamp_recording_id=post_pcrec.id
        join recording pre_rec on stim_pulse.recording_id=pre_rec.id
        join sync_rec on post_rec.sync_rec_id=sync_rec.id
        join experiment on sync_rec.experiment_id=experiment.id
    where
        {conditions}
""".format(conditions='\n        and '.join(conditions))

# Show the generated SQL for debugging.
print(query)

# NOTE(review): a raw SQL string is passed to session.execute(); SQLAlchemy 2.x
# requires wrapping it in sqlalchemy.text(). Confirm the version in use.
rp = session.execute(query)

recs = rp.fetchall()
# The first column of each row holds a serialized numpy array (presumably .npy
# bytes). Deserialize each into an ndarray.
data = [np.load(io.BytesIO(rec[0])) for rec in recs]
print("\n\nloaded %d records" % len(data))



# Plot individual traces (faint white) and their average (green).
plt = pg.plot(labels={'left': ('Vm', 'V')})
traces = TraceList()
for i,x in enumerate(data):
    # Subtract the median of the first 100 samples as a per-trace baseline.
    # The 20 kHz sample rate is hard-coded; this assumes every record shares it.
    trace = Trace(x - np.median(x[:100]), sample_rate=20000)
    traces.append(trace)
    # Only the first 100 traces are drawn (presumably to keep the plot light);
    # all traces still contribute to the average below.
    if i<100:
        plt.plot(trace.time_values, trace.data, pen=(255, 255, 255, 100))

avg = traces.mean()
plt.plot(avg.time_values, avg.data, pen='g')

Exemple #26
0
             pre_base = float_mode(pre_spike.data[:int(10e-3 /
                                                       pre_spike.dt)])
             sweep_list['response'].append(
                 sweep_trace.copy(data=sweep_trace.data - post_base))
             sweep_list['spike'].append(
                 pre_spike.copy(data=pre_spike.data - pre_base))
 if len(sweep_list['response']) > 5:
     n = len(sweep_list['response'])
     if plot_sweeps is True:
         for sweep in range(n):
             current_sweep = sweep_list['response'][sweep]
             current_sweep.t0 = 0
             grid[row[1], 0].plot(current_sweep.time_values,
                                  current_sweep.data,
                                  pen=sweep_color)
     avg_first_pulse = TraceList(sweep_list['response']).mean()
     avg_first_pulse.t0 = 0
     avg_spike = TraceList(sweep_list['spike']).mean()
     avg_spike.t0 = 0
     grid[row[1], 0].setLabels(left=('Vm', 'V'))
     grid[row[1], 0].setLabels(bottom=('t', 's'))
     grid[row[1], 0].setXRange(-2e-3, 27e-3)
     grid[row[1], 0].plot(avg_first_pulse.time_values,
                          avg_first_pulse.data,
                          pen={
                              'color': (255, 0, 255),
                              'width': 2
                          })
     grid[row[0], 0].setLabels(left=('Vm', 'V'))
     sweep_list['spike'][0].t0 = 0
     grid[row[0], 0].plot(avg_spike.time_values, avg_spike.data, pen='k')
                                   trace_color,
                                   plot=synapse_plot[c, 0],
                                   x_range=[0, 27e-3])
                        app.processEvents()
#                    decay_response = response_filter(pulse_response, freq_range=[0, 20], holding_range=holding)
#                    qc_list = pulse_qc(response_subset, baseline=2, pulse=None, plot=qc_plot)
#                    if len(qc_list) >= sweep_threshold:
#                        avg_trace, avg_amp, amp_sign, peak_t = get_amplitude(qc_list)
#                        if amp_sign is '-':
#                            continue
#                        psp_fits = fit_psp(avg_trace, sign=amp_sign, yoffset=0, amp=avg_amp, method='leastsq', stacked = False,  fit_kws={})
#                        grand_response[conn_type[0]]['decay'].append(psp_fits.best_values['decay_tau'])
    if len(grand_response[conn_type[0]]['trace']) == 0:
        continue
    if len(grand_response[conn_type[0]]['trace']) > 1:
        grand_trace = TraceList(grand_response[conn_type[0]]['trace']).mean()
    else:
        grand_trace = grand_response[conn_type[0]]['trace'][0]
    n_synapses = len(grand_response[conn_type[0]]['trace'])
    trace_plot(grand_trace,
               color={
                   'color': color,
                   'width': 2
               },
               plot=synapse_plot[c, 0],
               x_range=[0, 27e-3],
               name=('%s, n = %d' % (connection_types[c], n_synapses)))
    synapse_plot[c, 0].hideAxis('bottom')
    # all_amps = np.hstack(np.asarray(grand_response[cre_type[0]]['fail_rate']))
    # y, x = np.histogram(all_amps, bins=np.linspace(0, 2e-3, 40))
    # synapse_plot[c, 1].plot(x, y, stepMode=True, fillLevel=0, brush='k')
    def add_connection_plots(i, name, timestamp, pre_id, post_id):
        """Add one row of example plots for a single connection to ``win``.

        Row ``i`` gets three plots:

        - column 1: up to 100 example responses with their average;
        - column 2: background vs. foreground histograms of ``pos_dec_amp``;
        - column 3: simulated detection confidence vs. event amplitude for
          several rise times, with the classifier threshold marked.

        Parameters
        ----------
        i : int
            Row index in ``win`` for the new plots.
        name : str
            Label drawn on the trace plot.
        timestamp : float
            Acquisition timestamp, matched within 1 s against
            ``filtered['acq_timestamp']``.
        pre_id, post_id
            Matched against ``filtered['pre_cell_id']`` and
            ``filtered['post_cell_id']``.

        Returns early, without plotting traces, when the connection is not
        found in ``filtered``. Simulation results are cached in
        ``fig_3_cache.pkl``, keyed by ``(timestamp, pre_id, post_id)`` and then
        ``(rise_time, amp)``.
        """
        global session, win, filtered
        p = pg.debug.Profiler(disabled=True, delayed=False)
        trace_plot = win.addPlot(i, 1)
        trace_plots.append(trace_plot)
        trace_plot.setYRange(-1.4e-3, 2.1e-3)
        # deconv_plot = win.addPlot(i, 2)
        # deconv_plots.append(deconv_plot)
        # deconv_plot.hide()

        hist_plot = win.addPlot(i, 2)
        hist_plots.append(hist_plot)
        limit_plot = win.addPlot(i, 3)
        limit_plot.addLegend()
        limit_plot.setLogMode(True, False)
        limit_plot.addLine(y=classifier.prob_threshold)

        # Find this connection in the pair list (timestamps match within 1 s)
        idx = np.argwhere((abs(filtered['acq_timestamp'] - timestamp) < 1)
                          & (filtered['pre_cell_id'] == pre_id)
                          & (filtered['post_cell_id'] == post_id))
        if idx.size == 0:
            print("not in filtered connections")
            return
        idx = idx[0, 0]
        p()

        # Mark the point in scatter plot
        scatter_plot.plot([background[idx]], [signal[idx]],
                          pen='k',
                          symbol='o',
                          size=10,
                          symbolBrush='r',
                          symbolPen=None)

        # Plot example traces and histograms
        for plts in [trace_plots]:  #, deconv_plots]:
            plt = plts[-1]
            plt.setXLink(plts[0])
            plt.setYLink(plts[0])
            plt.setXRange(-10e-3, 17e-3, padding=0)
            plt.hideAxis('left')
            plt.hideAxis('bottom')
            plt.addLine(x=0)
            plt.setDownsampling(auto=True, mode='peak')
            plt.setClipToView(True)
            # scale bars: 2 ms horizontal, 100 uV vertical
            hbar = pg.QtGui.QGraphicsLineItem(0, 0, 2e-3, 0)
            hbar.setPen(pg.mkPen(color='k', width=5))
            plt.addItem(hbar)
            vbar = pg.QtGui.QGraphicsLineItem(0, 0, 0, 100e-6)
            vbar.setPen(pg.mkPen(color='k', width=5))
            plt.addItem(vbar)

        hist_plot.setXLink(hist_plots[0])

        pair = session.query(
            db.Pair).filter(db.Pair.id == filtered[idx]['pair_id']).all()[0]
        p()
        amps = strength_analysis.get_amps(session, pair)
        p()
        base_amps = strength_analysis.get_baseline_amps(session,
                                                        pair,
                                                        amps=amps,
                                                        clamp_mode='ic')
        p()

        q = strength_analysis.response_query(session)
        p()
        q = q.join(strength_analysis.PulseResponseStrength)
        q = q.filter(strength_analysis.PulseResponseStrength.id.in_(
            amps['id']))
        q = q.join(db.MultiPatchProbe)
        q = q.filter(db.MultiPatchProbe.induction_frequency < 100)
        # pre_cell = db.aliased(db.Cell)
        # post_cell = db.aliased(db.Cell)
        # q = q.join(db.Pair).join(db.Experiment).join(pre_cell, db.Pair.pre_cell_id==pre_cell.id).join(post_cell, db.Pair.post_cell_id==post_cell.id)
        # q = q.filter(db.Experiment.id==filtered[idx]['experiment_id'])
        # q = q.filter(pre_cell.ext_id==pre_id)
        # q = q.filter(post_cell.ext_id==post_id)

        fg_recs = q.all()
        p()

        traces = []
        deconvs = []
        for rec in fg_recs[:100]:
            result = strength_analysis.analyze_response_strength(
                rec,
                source='pulse_response',
                lpf=True,
                lowpass=2000,
                remove_artifacts=False,
                bsub=True)
            # align to the spike and zero the signal in a +/-0.5 ms window around it
            trace = result['raw_trace']
            trace.t0 = -result['spike_time']
            trace = trace - np.median(trace.time_slice(-0.5e-3, 0.5e-3).data)
            traces.append(trace)
            trace_plot.plot(trace.time_values, trace.data, pen=(0, 0, 0, 20))

            trace = result['dec_trace']
            trace.t0 = -result['spike_time']
            trace = trace - np.median(trace.time_slice(-0.5e-3, 0.5e-3).data)
            deconvs.append(trace)
            # deconv_plot.plot(trace.time_values, trace.data, pen=(0, 0, 0, 20))

        # plot average trace
        mean = TraceList(traces).mean()
        trace_plot.plot(mean.time_values,
                        mean.data,
                        pen={
                            'color': 'g',
                            'width': 2
                        },
                        shadowPen={
                            'color': 'k',
                            'width': 3
                        },
                        antialias=True)
        # only consumed by the (currently disabled) deconvolved plot
        mean = TraceList(deconvs).mean()
        # deconv_plot.plot(mean.time_values, mean.data, pen={'color':'g', 'width': 2}, shadowPen={'color':'k', 'width': 3}, antialias=True)

        # add label
        label = pg.LabelItem(name)
        label.setParentItem(trace_plot)

        p("analyze_response_strength")

        # bins = np.arange(-0.0005, 0.002, 0.0001)
        # field = 'pos_amp'
        bins = np.arange(-0.001, 0.015, 0.0005)
        field = 'pos_dec_amp'
        n = min(len(amps), len(base_amps))
        hist_y, hist_bins = np.histogram(base_amps[:n][field], bins=bins)
        hist_plot.plot(hist_bins,
                       hist_y,
                       stepMode=True,
                       pen=None,
                       brush=(200, 0, 0, 150),
                       fillLevel=0)
        hist_y, hist_bins = np.histogram(amps[:n][field], bins=bins)
        hist_plot.plot(hist_bins,
                       hist_y,
                       stepMode=True,
                       pen='k',
                       brush=(0, 150, 150, 100),
                       fillLevel=0)
        p()

        pg.QtGui.QApplication.processEvents()

        # Plot detectability analysis
        q = strength_analysis.baseline_query(session)
        q = q.join(strength_analysis.BaselineResponseStrength)
        q = q.filter(
            strength_analysis.BaselineResponseStrength.id.in_(base_amps['id']))
        # q = q.limit(100)
        bg_recs = q.all()

        def clicked(sp, pts):
            """Print details of a clicked simulation point and plot its traces."""
            data = pts[0].data()
            print("-----------------------\nclicked:", data['rise_time'],
                  data['amp'], data['prediction'], data['confidence'])
            for r in data['results']:
                print({k: r[k] for k in classifier.features})
            traces = data['traces']
            plt = pg.plot()
            bsub = [
                t.copy(data=t.data - np.median(t.time_slice(0, 1e-3).data))
                for t in traces
            ]
            for t in bsub:
                plt.plot(t.time_values, t.data, pen=(0, 0, 0, 50))
            mean = TraceList(bsub).mean()
            plt.plot(mean.time_values, mean.data, pen='g')

        # measure background connection strength
        bg_results = [
            strength_analysis.analyze_response_strength(rec, 'baseline')
            for rec in bg_recs
        ]
        bg_results = strength_analysis.str_analysis_result_table(
            bg_results, bg_recs)

        # for this example, we use background data to simulate foreground
        # (but this will be biased due to lack of crosstalk in background data)
        fg_recs = bg_recs

        # now measure foreground simulated under different conditions
        amps = 2e-6 * 2**np.arange(9)
        amps[0] = 0
        rtimes = [1e-3, 2e-3, 4e-3, 6e-3]
        dt = 1 / db.default_sample_rate
        results = np.empty((len(amps), len(rtimes)),
                           dtype=[('results', object), ('predictions', object),
                                  ('confidence', object), ('traces', object),
                                  ('rise_time', float), ('amp', float)])
        print("  Simulating synaptic events..")

        # NOTE: pickle is acceptable only because this cache file is written by
        # this function itself; never point it at untrusted data.
        cachefile = 'fig_3_cache.pkl'
        if os.path.exists(cachefile):
            with open(cachefile, 'rb') as cache_fh:
                cache = pickle.load(cache_fh)
        else:
            cache = {}
        pair_key = (timestamp, pre_id, post_id)
        pair_cache = cache.setdefault(pair_key, {})

        for j, rtime in enumerate(rtimes):
            new_results = False
            for i, amp in enumerate(amps):
                # '\r' + end='' keeps this progress line updating in place
                # (the py2 trailing-comma idiom no longer suppresses the newline)
                print(
                    "---------------------------------------    %d/%d  %d/%d      \r"
                    % (i, len(amps), j, len(rtimes)), end='')
                result = pair_cache.get((rtime, amp))
                if result is None:
                    result = strength_analysis.simulate_connection(
                        fg_recs, bg_results, classifier, amp, rtime)
                    pair_cache[rtime, amp] = result
                    new_results = True

                for k, v in result.items():
                    results[i, j][k] = v

            c = limit_plot.plot(
                amps, [np.mean(x) for x in results[:, j]['confidence']],
                pen=pg.intColor(j, len(rtimes) * 1.3, maxValue=150),
                symbol='o',
                antialias=True,
                name="%dus" % (rtime * 1e6),
                data=results[:, j],
                symbolSize=4)
            c.scatter.sigClicked.connect(clicked)
            pg.QtGui.QApplication.processEvents()

            if new_results:
                # close (and flush) the cache file deterministically
                with open(cachefile, 'wb') as cache_fh:
                    pickle.dump(cache, cache_fh)

        pg.QtGui.QApplication.processEvents()
 
 #analyzer = MultiPatchExperimentAnalyzer(expt.data)
 #pulses = analyzer.get_evoked_responses(pre_id, post_id, clamp_mode='ic', pulse_ids=[0])
 
 analyzer = DynamicsAnalyzer(expt, pre_id, post_id, align_to='spike')
 
 # collect all first pulse responses
 responses = analyzer.amp_group
 
 # collect all events
 #responses = analyzer.all_events
 
 n_responses = len(responses)
 
 # do exponential deconvolution on all responses
 deconv = TraceList()
 grid1 = PlotGrid()
 grid1.set_shape(2, 1)
 for i in range(n_responses):
     r = responses.responses[i]
     grid1[0, 0].plot(r.time_values, r.data)
     
     filt = bessel_filter(r - np.median(r.time_slice(0, 10e-3).data), 300.)
     responses.responses[i] = filt
     
     dec = exp_deconvolve(r, 15e-3)
     baseline = np.median(dec.data[:100])
     r2 = bessel_filter(dec-baseline, 300.)
     grid1[1, 0].plot(r2.time_values, r2.data)
     
     deconv.append(r2)
def first_pulse_plot(expt_list,
                     name=None,
                     summary_plot=None,
                     color=None,
                     scatter=0,
                     features=False):
    """Plot per-connection average first-pulse responses and their grand mean.

    Every connection in *expt_list* whose pre/post cre types equal the
    module-level ``cre_type`` pair is considered.

    A connection is skipped if:

    - ``get_response`` reports an artifact larger than 0.03e-3;
    - fewer than 10 sweeps survive ``response_filter``;
    - the sign of the average amplitude contradicts the presynaptic cre class
      (negative for excitatory, positive for inhibitory).

    Each remaining average trace is drawn in one plot window; reciprocal pairs
    are drawn in red when ``args.recip`` is True. The grand mean of all
    averages is drawn in green, with a green line at the mean amplitude.

    Parameters
    ----------
    expt_list : iterable
        Experiments exposing ``connections`` and ``cells``.
    name : str, optional
        Label used in printed summaries, as the legend entry of the grand
        mean, and passed on to ``summary_plot_pulse``.
    summary_plot : optional
        Passed through to ``summary_plot_pulse`` as ``plot``.
    color : optional
        Passed through to ``summary_plot_pulse``.
    scatter : int
        Passed through to ``summary_plot_pulse`` as ``i``.
    features : bool
        If True, also fit each average with ``fit_psp`` and collect latency
        and rise time.

    Returns
    -------
    tuple
        ``(avg_amps, summary_plots)``. ``avg_amps`` is a dict of lists keyed
        'amp', 'latency' and 'rise'. ``summary_plots`` is the return value of
        ``summary_plot_pulse``, or None when no connection qualified.
    """
    amp_plots = pg.plot()
    amp_plots.setLabels(left=('Vm', 'V'))
    grand_response = []
    avg_amps = {'amp': [], 'latency': [], 'rise': []}
    for expt in expt_list:
        if expt.connections is None:
            continue
        for pre, post in expt.connections:
            # Only connections matching the module-level cre_type pair.
            if not (expt.cells[pre].cre_type == cre_type[0]
                    and expt.cells[post].cre_type == cre_type[1]):
                continue
            all_responses, artifact = get_response(
                expt, pre, post, analysis_type='pulse')
            if artifact > 0.03e-3:
                continue
            filtered_responses = response_filter(
                all_responses,
                freq_range=[0, 50],
                holding_range=[-68, -72],
                pulse=True)
            if len(filtered_responses) < 10:
                continue
            avg_trace, avg_amp, amp_sign, _ = get_amplitude(
                filtered_responses)

            # Drop averages whose sign contradicts the presynaptic cell class.
            pre_cre = expt.cells[pre].cre_type
            if pre_cre in EXCITATORY_CRE_TYPES and avg_amp < 0:
                continue
            elif pre_cre in INHIBITORY_CRE_TYPES and avg_amp > 0:
                continue

            avg_trace.t0 = 0
            avg_amps['amp'].append(avg_amp)
            grand_response.append(avg_trace)
            if features is True:
                psp_fits = fit_psp(avg_trace,
                                   sign=amp_sign,
                                   yoffset=0,
                                   amp=avg_amp,
                                   method='leastsq',
                                   fit_kws={})
                # Latency is the fitted xoffset minus 10e-3 (presumably the
                # pulse time within the averaged trace -- confirm).
                avg_amps['latency'].append(
                    psp_fits.best_values['xoffset'] - 10e-3)
                avg_amps['rise'].append(psp_fits.best_values['rise_time'])

            # A pair is reciprocal when the reverse (post, pre) connection is
            # also listed; those are drawn in red when args.recip is True.
            reverse_pair = (post, pre)
            if (len(expt.connections) > 1 and args.recip is True
                    and reverse_pair in expt.connections):
                amp_plots.plot(avg_trace.time_values,
                               avg_trace.data,
                               pen={
                                   'color': 'r',
                                   'width': 1
                               })
            else:
                amp_plots.plot(avg_trace.time_values, avg_trace.data)

            app.processEvents()

    if len(grand_response) == 0:
        print("No Traces")
        return avg_amps, None

    # `name` labels every output; %s formatting also tolerates name=None.
    print('%s n = %d' % (name, len(grand_response)))
    grand_mean = TraceList(grand_response).mean()
    grand_amp = np.mean(np.array(avg_amps['amp']))
    grand_amp_sem = stats.sem(np.array(avg_amps['amp']))
    amp_plots.addLegend()
    amp_plots.plot(grand_mean.time_values,
                   grand_mean.data,
                   pen={
                       'color': 'g',
                       'width': 3
                   },
                   name=name)
    amp_plots.addLine(y=grand_amp, pen={'color': 'g'})
    print('%s Grand mean amplitude = %f +- %f' %
          (name, grand_amp, grand_amp_sem))
    if features is True:
        feature_list = (avg_amps['amp'], avg_amps['latency'],
                        avg_amps['rise'])
        labels = (['Vm', 'V'], ['t', 's'], ['t', 's'])
        titles = ('Amplitude', 'Latency', 'Rise time')
    else:
        feature_list = [avg_amps['amp']]
        # NOTE(review): this is a plain list, not a 1-tuple of lists.
        labels = (['Vm', 'V'])
        titles = 'Amplitude'
    # NOTE(review): only feature_list[0] (the amplitudes) is passed, even when
    # features=True builds three labels/titles -- confirm against
    # summary_plot_pulse.
    summary_plots = summary_plot_pulse(feature_list[0],
                                       labels=labels,
                                       titles=titles,
                                       i=scatter,
                                       grand_trace=grand_mean,
                                       plot=summary_plot,
                                       color=color,
                                       name=name)
    return avg_amps, summary_plots
Exemple #31
0
        pulse_response
        join stim_pulse on pulse_response.pulse_id=stim_pulse.id
        join recording post_rec on pulse_response.recording_id=post_rec.id
        join patch_clamp_recording post_pcrec on post_pcrec.recording_id=post_rec.id
        join multi_patch_probe on multi_patch_probe.patch_clamp_recording_id=post_pcrec.id
        join recording pre_rec on stim_pulse.recording_id=pre_rec.id
        join sync_rec on post_rec.sync_rec_id=sync_rec.id
        join experiment on sync_rec.experiment_id=experiment.id
    where
        {conditions}
""".format(conditions='\n        and '.join(conditions))

# Execute the SQL assembled above, decode each row's first column with
# np.load, and plot the baseline-subtracted traces together with their mean.
print(query)

rp = session.execute(query)

recs = rp.fetchall()
# rec[0] holds bytes that np.load can read (presumably .npy) -- one array per row.
data = [np.load(io.BytesIO(rec[0])) for rec in recs]
print("\n\nloaded %d records" % len(data))

plt = pg.plot(labels={'left': ('Vm', 'V')})
traces = TraceList()
for i, x in enumerate(data):
    # Subtract the median of the first 100 samples. sample_rate is hard-coded
    # to 20000 (presumably Hz) -- confirm it matches the stored recordings.
    trace = Trace(x - np.median(x[:100]), sample_rate=20000)
    traces.append(trace)
    # only the first 100 traces are drawn, as translucent white lines
    if i < 100:
        plt.plot(trace.time_values, trace.data, pen=(255, 255, 255, 100))

# mean of all traces (not only the drawn ones), shown in green
avg = traces.mean()
plt.plot(avg.time_values, avg.data, pen='g')
Exemple #32
0
 def mean(self):
     """Average all response traces into a single trace; None when empty."""
     return TraceList(self.responses).mean() if len(self) else None
    def plot_prd_ids(self, ids, source, pen=None, trace_list=None, avg=False):
        """Plot raw or deconvolved PulseResponse data, given IDs of records in
        a PulseResponseStrength table.

        Parameters
        ----------
        ids : sequence
            Record ids in PulseResponseStrength (``source == 'fg'``) or in
            BaselineResponseStrength (any other source).
        source : str
            'fg' for foreground pulse responses, 'bg' for baseline records.
        pen : optional
            Pen for the individual traces. Defaults to white, with an alpha
            that falls (clipped to 30..255) as the number of records grows.
        trace_list : list, optional
            Plot items drawn by a previous call. They are removed from the
            plot first, and every new item is appended so the caller can
            clear them next time. If None, a fresh (caller-invisible) list
            is used instead of failing.
        avg : bool
            If True, also draw the mean of the plotted traces in green.

        Returns None; returns early without touching the plot when no
        records match *ids*.
        """
        if trace_list is None:
            trace_list = []

        with pg.BusyCursor():
            if source == 'fg':
                q = response_query(self.session)
                q = q.join(PulseResponseStrength)
                q = q.filter(PulseResponseStrength.id.in_(ids))
                plot = self.fg_trace_plot
            else:
                q = baseline_query(self.session)
                q = q.join(BaselineResponseStrength)
                q = q.filter(BaselineResponseStrength.id.in_(ids))
                plot = self.bg_trace_plot
            recs = q.all()
            if len(recs) == 0:
                return

            # Remove items from a previous call; iterate over a copy because
            # trace_list is modified inside the loop.
            for item in trace_list[:]:
                plot.removeItem(item)
                trace_list.remove(item)

            if pen is None:
                alpha = np.clip(1000 / len(recs), 30, 255)
                pen = (255, 255, 255, alpha)

            # Source name expected by analyze_response_strength (same for every record).
            s = {'fg': 'pulse_response', 'bg': 'baseline'}[source]

            traces = []
            spike_times = []
            spike_values = []
            for rec in recs:
                result = analyze_response_strength(rec, source=s, lpf=self.lpf_check.isChecked(),
                                                   remove_artifacts=self.ar_check.isChecked(), bsub=self.bsub_check.isChecked())

                if self.deconv_check.isChecked():
                    trace = result['dec_trace']
                else:
                    trace = result['raw_trace']
                    if self.bsub_check.isChecked():
                        # subtract the median of the 0..9e-3 window
                        trace = trace - np.median(trace.time_slice(0, 9e-3).data)
                    if self.lpf_check.isChecked():
                        trace = filter.bessel_filter(trace, 500)

                spike_values.append(trace.value_at([result['spike_time']])[0])
                if self.align_check.isChecked():
                    # shift the time axis so the spike lands at t = 0
                    # (assumes the trace's t0 was 0 -- confirm)
                    trace.t0 = -result['spike_time']
                    spike_times.append(0)
                else:
                    spike_times.append(result['spike_time'])

                traces.append(trace)
                trace_list.append(plot.plot(trace.time_values, trace.data, pen=pen))

            if avg:
                mean = TraceList(traces).mean()
                trace_list.append(plot.plot(mean.time_values, mean.data, pen='g'))
                trace_list[-1].setZValue(10)

            # spike markers drawn underneath the traces
            spike_scatter = pg.ScatterPlotItem(spike_times, spike_values, size=4, pen=None, brush=(200, 200, 0))
            spike_scatter.setZValue(-100)
            plot.addItem(spike_scatter)
            trace_list.append(spike_scatter)
    def add_connection_plots(i, name, timestamp, pre_id, post_id):
        """Add one row of plots to ``win`` summarizing a single connection.

        Row *i* receives four plots:

        1. up to 100 raw foreground responses (induction frequency < 100),
           aligned to the presynaptic spike, with their mean in green;
        2. the matching deconvolved traces and their mean;
        3. histograms of ``pos_dec_amp`` for background (red) and
           foreground (teal) measurements;
        4. a log-log plot, for each simulated rise time, of the
           median-amplitude / latency-std statistic measured on background
           recordings with synthetic PSPs injected, relative to the same
           statistic on shuffled background, versus PSP amplitude.

        The connection is located in ``filtered`` by acquisition timestamp
        (difference < 1) and pre/post cell id. If it is not found, a message
        is printed and the function returns after creating the empty plots.
        Also uses ``trace_plots``, ``deconv_plots``, ``hist_plots``,
        ``scatter_plot``, ``background`` and ``signal`` from the surrounding
        scope.

        Parameters
        ----------
        i : int
            Row index in ``win``.
        name : str
            Label attached to the trace plot.
        timestamp : float
            Acquisition timestamp of the experiment.
        pre_id, post_id
            Pre- and postsynaptic cell ids as stored in ``filtered``.
        """
        global session, win, filtered
        p = pg.debug.Profiler(disabled=True, delayed=False)
        trace_plot = win.addPlot(i, 1)
        trace_plots.append(trace_plot)
        deconv_plot = win.addPlot(i, 2)
        deconv_plots.append(deconv_plot)
        hist_plot = win.addPlot(i, 3)
        hist_plots.append(hist_plot)
        limit_plot = win.addPlot(i, 4)
        limit_plot.addLegend()
        limit_plot.setLogMode(True, True)
        # Find this connection in the pair list
        idx = np.argwhere((abs(filtered['acq_timestamp'] - timestamp) < 1) & (filtered['pre_cell_id'] == pre_id) & (filtered['post_cell_id'] == post_id))
        if idx.size == 0:
            print("not in filtered connections")
            return
        idx = idx[0, 0]
        p()

        # Mark the point in scatter plot
        scatter_plot.plot([background[idx]], [signal[idx]], pen='k', symbol='o', size=10, symbolBrush='r', symbolPen=None)

        # Link axes, set the view window and add scale bars on the trace and
        # deconvolution plots.
        for plts in [trace_plots, deconv_plots]:
            plt = plts[-1]
            plt.setXLink(plts[0])
            plt.setYLink(plts[0])
            plt.setXRange(-10e-3, 17e-3, padding=0)
            plt.hideAxis('left')
            plt.hideAxis('bottom')
            plt.addLine(x=0)
            plt.setDownsampling(auto=True, mode='peak')
            plt.setClipToView(True)
            # scale bars: 2e-3 wide, 100e-6 tall
            hbar = pg.QtGui.QGraphicsLineItem(0, 0, 2e-3, 0)
            hbar.setPen(pg.mkPen(color='k', width=5))
            plt.addItem(hbar)
            vbar = pg.QtGui.QGraphicsLineItem(0, 0, 0, 100e-6)
            vbar.setPen(pg.mkPen(color='k', width=5))
            plt.addItem(vbar)

        hist_plot.setXLink(hist_plots[0])

        pair = session.query(db.Pair).filter(db.Pair.id==filtered[idx]['pair_id']).all()[0]
        p()
        amps = strength_analysis.get_amps(session, pair)
        p()
        base_amps = strength_analysis.get_baseline_amps(session, pair)
        p()

        # Foreground pulse responses for this pair with induction frequency < 100
        q = strength_analysis.response_query(session)
        p()
        q = q.join(strength_analysis.PulseResponseStrength)
        q = q.filter(strength_analysis.PulseResponseStrength.id.in_(amps['id']))
        q = q.join(db.Recording, db.Recording.id==db.PulseResponse.recording_id).join(db.PatchClampRecording).join(db.MultiPatchProbe)
        q = q.filter(db.MultiPatchProbe.induction_frequency < 100)
        # pre_cell = db.aliased(db.Cell)
        # post_cell = db.aliased(db.Cell)
        # q = q.join(db.Pair).join(db.Experiment).join(pre_cell, db.Pair.pre_cell_id==pre_cell.id).join(post_cell, db.Pair.post_cell_id==post_cell.id)
        # q = q.filter(db.Experiment.id==filtered[idx]['experiment_id'])
        # q = q.filter(pre_cell.ext_id==pre_id)
        # q = q.filter(post_cell.ext_id==post_id)

        fg_recs = q.all()
        p()

        traces = []
        deconvs = []
        for rec in fg_recs[:100]:
            result = strength_analysis.analyze_response_strength(rec, source='pulse_response', lpf=True, lowpass=2000,
                                                remove_artifacts=False, bsub=True)
            # Align on the spike, then zero using the median of +/-0.5e-3 around it.
            trace = result['raw_trace']
            trace.t0 = -result['spike_time']
            trace = trace - np.median(trace.time_slice(-0.5e-3, 0.5e-3).data)
            traces.append(trace)
            trace_plot.plot(trace.time_values, trace.data, pen=(0, 0, 0, 20))

            trace = result['dec_trace']
            trace.t0 = -result['spike_time']
            trace = trace - np.median(trace.time_slice(-0.5e-3, 0.5e-3).data)
            deconvs.append(trace)
            deconv_plot.plot(trace.time_values, trace.data, pen=(0, 0, 0, 20))

        # plot average trace
        mean = TraceList(traces).mean()
        trace_plot.plot(mean.time_values, mean.data, pen={'color':'g', 'width': 2}, shadowPen={'color':'k', 'width': 3}, antialias=True)
        mean = TraceList(deconvs).mean()
        deconv_plot.plot(mean.time_values, mean.data, pen={'color':'g', 'width': 2}, shadowPen={'color':'k', 'width': 3}, antialias=True)

        # add label
        label = pg.LabelItem(name)
        label.setParentItem(trace_plot)

        p("analyze_response_strength")

        # bins = np.arange(-0.0005, 0.002, 0.0001)
        # field = 'pos_amp'
        bins = np.arange(-0.001, 0.015, 0.0005)
        field = 'pos_dec_amp'
        # compare equal-sized samples of background and foreground
        n = min(len(amps), len(base_amps))
        hist_y, hist_bins = np.histogram(base_amps[:n][field], bins=bins)
        hist_plot.plot(hist_bins, hist_y, stepMode=True, pen=None, brush=(200, 0, 0, 150), fillLevel=0)
        hist_y, hist_bins = np.histogram(amps[:n][field], bins=bins)
        hist_plot.plot(hist_bins, hist_y, stepMode=True, pen='k', brush=(0, 150, 150, 100), fillLevel=0)
        p()

        pg.QtGui.QApplication.processEvents()

        # Plot detectability analysis
        q = strength_analysis.baseline_query(session)
        q = q.join(strength_analysis.BaselineResponseStrength)
        q = q.filter(strength_analysis.BaselineResponseStrength.id.in_(base_amps['id']))
        # q = q.limit(100)
        bg_recs = q.all()

        def clicked(sp, pts):
            # Show the simulated traces stored with the clicked point.
            traces = pts[0].data()['traces']
            print([t.amp for t in traces])
            plt = pg.plot()
            bsub = [t.copy(data=t.data - np.median(t.time_slice(0, 1e-3).data)) for t in traces]
            for t in bsub:
                plt.plot(t.time_values, t.data, pen=(0, 0, 0, 50))
            mean = TraceList(bsub).mean()
            plt.plot(mean.time_values, mean.data, pen='g')

        # first measure background a few times
        N = len(fg_recs)
        N = 50  # temporary for testing
        print("Testing %d trials" % N)

        # Background statistic: median amplitude / latency std over N
        # randomly chosen baseline measurements, repeated M times.
        bg_results = []
        M = 500
        print("  Grinding on %d background trials" % len(bg_recs))
        for trial in range(M):
            shuffled = base_amps.copy()
            np.random.shuffle(shuffled)
            bg_results.append(np.median(shuffled[:N]['pos_dec_amp']) / np.std(shuffled[:N]['pos_dec_latency']))
            print("    %d/%d      \r" % (trial, M), end='')
        print("    done.            ")
        print("    ", bg_results)

        # now measure foreground simulated under different conditions
        amps = 5e-6 * 2**np.arange(6)
        amps[0] = 0
        rtimes = 1e-3 * 1.71**np.arange(4)
        dt = 1 / db.default_sample_rate
        # np.zeros rather than np.empty: 'pos_dec_amp' and 'latency_stdev' are
        # never filled in below and must not hold uninitialized memory.
        results = np.zeros((len(amps), len(rtimes)), dtype=[('pos_dec_amp', float), ('latency_stdev', float), ('result', float), ('percentile', float), ('traces', object)])
        print("  Simulating synaptic events..")
        t = np.arange(0, 15e-3, dt)
        for j, rtime in enumerate(rtimes):
            # The PSP template depends only on the rise time.
            template = Psp.psp_func(t, xoffset=0, yoffset=0, rise_time=rtime, decay_tau=15e-3, amp=1, rise_power=2)
            for ai, amp in enumerate(amps):
                trial_results = []
                for _ in range(20):
                    print("    %d/%d  %d/%d      \r" % (ai, len(amps), j, len(rtimes)), end='')
                    r_amps = amp * 2**np.random.normal(size=N, scale=0.5)
                    r_latency = np.random.normal(size=N, scale=600e-6, loc=12.5e-3)
                    fg_results = []
                    traces = []
                    np.random.shuffle(bg_recs)
                    for k, rec in enumerate(bg_recs[:N]):
                        # rec cannot be modified, so the synthetic PSP is added
                        # to its data array in place and the array is restored
                        # afterward, even if the analysis raises.
                        data = rec.data.copy()
                        start = int(r_latency[k] / dt)
                        # Inject only the part of the template that fits both
                        # the template and the remaining recording.
                        n_inject = max(0, min(len(rec.data) - start, len(template)))
                        try:
                            rec.data[start:start + n_inject] += template[:n_inject] * r_amps[k]

                            fg_result = strength_analysis.analyze_response_strength(rec, 'baseline')
                            fg_results.append((fg_result['pos_dec_amp'], fg_result['pos_dec_latency']))

                            traces.append(Trace(rec.data.copy(), dt=dt))
                            traces[-1].amp = r_amps[k]
                        finally:
                            rec.data[:] = data

                    fg_amp = np.array([r[0] for r in fg_results])
                    fg_latency = np.array([r[1] for r in fg_results])
                    trial_results.append(np.median(fg_amp) / np.std(fg_latency))
                results[ai, j]['result'] = np.median(trial_results) / np.median(bg_results)
                results[ai, j]['percentile'] = stats.percentileofscore(bg_results, results[ai, j]['result'])
                # keep the traces of the last trial for the click handler
                results[ai, j]['traces'] = traces

            # report the column just finished (one value per amplitude)
            print(rtime, results[:, j]['result'])
            print(rtime, results[:, j]['percentile'])

            c = limit_plot.plot(amps, results[:, j]['result'], pen=(j, len(rtimes)*1.3), symbol='o', antialias=True, name="%dus" % (rtime*1e6), data=results[:, j], symbolSize=4)
            c.scatter.sigClicked.connect(clicked)
            pg.QtGui.QApplication.processEvents()

        pg.QtGui.QApplication.processEvents()
Exemple #35
0
 
 # Script fragment: collect spike-aligned responses for the pre_id -> post_id
 # pair in `expt`, baseline-subtract and filter them in place, and plot the raw
 # responses (top panel) and their exponential deconvolutions (bottom panel).
 # `expt`, `pre_id` and `post_id` must already be defined at module level.
 #analyzer = MultiPatchExperimentAnalyzer(expt.data)
 #pulses = analyzer.get_evoked_responses(pre_id, post_id, clamp_mode='ic', pulse_ids=[0])
 
 analyzer = DynamicsAnalyzer(expt, pre_id, post_id, align_to='spike')
 
 # collect all first pulse responses
 responses = analyzer.amp_group
 
 # collect all events
 #responses = analyzer.all_events
 
 n_responses = len(responses)
 
 # do exponential deconvolution on all responses
 # `deconv` accumulates the processed deconvolved traces; grid1 is 2 rows x 1 column.
 deconv = TraceList()
 grid1 = PlotGrid()
 grid1.set_shape(2, 1)
 for i in range(n_responses):
     r = responses.responses[i]
     # top panel: the response as originally stored
     grid1[0, 0].plot(r.time_values, r.data)
     
     # Overwrite the stored response with a copy that has the median of
     # time_slice(0, 10e-3) (presumably the first 10 ms) subtracted and is then
     # passed through bessel_filter with cutoff 300. (presumably Hz -- confirm).
     filt = bessel_filter(r - np.median(r.time_slice(0, 10e-3).data), 300.)
     responses.responses[i] = filt
     
     # Deconvolve with second argument 15e-3 (presumably the exponential time
     # constant in seconds). NOTE(review): this uses the unfiltered `r`, not
     # `filt`; the baseline is instead taken as the median of the first 100
     # deconvolved samples and removed before filtering -- confirm intended.
     dec = exp_deconvolve(r, 15e-3)
     baseline = np.median(dec.data[:100])
     r2 = bessel_filter(dec-baseline, 300.)
     # bottom panel: baseline-subtracted, filtered deconvolution
     grid1[1, 0].plot(r2.time_values, r2.data)
     
     deconv.append(r2)