def plot_traces(self, pulse_responses):
    """Draw spike-aligned postsynaptic traces for every holding key.

    ``pulse_responses`` maps each holding key to a dict of
    ``{qc_label: sequence of pulse responses}``.  The n-th holding key (in
    dict iteration order) is drawn into ``self.trace_plots[n]``:

    * every individual trace uses the pen ``self.qc_color[qc_label]``;
    * traces of the ``'qc_fail'`` group are pushed behind the others;
    * the ``'qc_pass'`` group additionally gets its mean drawn as a thick
      blue line.

    Every created plot item is appended to ``self.items``.  Each plot is
    auto-ranged and then clipped to the x range -5e-3 .. 10e-3.
    """
    for plot_idx, holding in enumerate(pulse_responses.keys()):
        plot = self.trace_plots[plot_idx]
        for qc_label, responses in pulse_responses[holding].items():
            if len(responses) == 0:
                continue
            aligned = PulseResponseList(responses).post_tseries(align='spike', bsub=True)

            for ts in aligned:
                curve = plot.plot(ts.time_values, ts.data, pen=self.qc_color[qc_label])
                if qc_label == 'qc_fail':
                    # keep failing responses underneath the passing ones
                    curve.setZValue(-10)
                self.items.append(curve)

            if qc_label != 'qc_pass':
                continue
            avg = aligned.mean()
            avg_curve = plot.plot(avg.time_values, avg.data, pen={'color': 'b', 'width': 2})
            self.items.append(avg_curve)

        plot.autoRange()
        plot.setXRange(-5e-3, 10e-3)
示例#2
0
    def plot_spikes(self, pulse_responses):
        """Draw presynaptic spike traces for every holding key.

        For each non-empty group of pulse responses, presynaptic alignment is
        attempted with ``'spike'``, then ``'peak'``, then ``'stim'`` (always
        with ``bsub=True`` and ``alignment_failure_mode='average'``).  The
        first mode that does not raise is used; if all of them raise, the
        group is skipped.

        Each trace goes into ``self.spike_plots[n]``, where n is the position
        of the holding key.  A response counts as ``'qc_pass'`` only when its
        stimulus pulse has a ``first_spike_time``.  Otherwise it is
        ``'qc_fail'`` and drawn behind the rest.  The pen comes from
        ``self.qc_color``, and created items are appended to ``self.items``.
        """
        for plot_idx, holding in enumerate(pulse_responses.keys()):
            for responses in pulse_responses[holding].values():
                if len(responses) == 0:
                    continue
                prl = PulseResponseList(responses)

                pre_ts = None
                for mode in ('spike', 'peak', 'stim'):
                    try:
                        pre_ts = prl.pre_tseries(
                            align=mode,
                            bsub=True,
                            alignment_failure_mode='average')
                    except Exception:
                        # best effort: fall through to the next alignment mode
                        continue
                    break
                if pre_ts is None:
                    continue

                for pr, spike in zip(prl, pre_ts):
                    # A pulse may report a spike count while its spike time is None;
                    # such responses are flagged as failures, as their postsynaptic
                    # traces are.
                    has_spike_time = pr.stim_pulse.first_spike_time is not None
                    qc = 'qc_pass' if has_spike_time else 'qc_fail'
                    item = self.spike_plots[plot_idx].plot(spike.time_values,
                                                           spike.data,
                                                           pen=self.qc_color[qc])
                    if not has_spike_time:
                        item.setZValue(-10)
                    self.items.append(item)
示例#3
0
def fit_avg_pulse_response(pulse_response_list, latency_window, sign, init_params=None, ui=None):
    """Fit a PSP model to the spike-aligned average of a list of pulse responses.

    The postsynaptic traces of all responses are spike-aligned,
    baseline-subtracted and averaged.  The average is then fit with
    ``fit_psp``, using sample weights that emphasize the region where the
    response onset is expected.  No crosstalk or gap-junction correction is
    applied.

    Parameters
    ----------
    pulse_response_list : list
        A list of PulseResponse instances to be time-aligned, averaged, and fit.
        The clamp mode used for fitting is read from the first response.
    latency_window : (float, float)
        Beginning and end times of a window over which to search for a synaptic response,
        relative to the spike time (same time base as the averaged trace).
    sign : int
        +1, -1, or 0 indicating the expected sign of the response (see neuroanalysis.fitting.fit_psp)
    init_params : optional
        Initial fit parameters, forwarded unchanged to ``fit_psp``.
    ui : optional
        Not used by this function; accepted for compatibility with existing callers.

    Returns
    -------
    fit : lmfit ModelResult or None
        The resulting PSP fit, or None if there was nothing to fit.
    average : TSeries or None
        The averaged pulse response data, or None if there was nothing to fit.
    """
    # Guard before touching element 0, so an empty input returns (None, None)
    # instead of raising IndexError.
    if len(pulse_response_list) == 0:
        return None, None

    clamp_mode = pulse_response_list[0].recording.patch_clamp_recording.clamp_mode

    # make a list of spike-aligned, baseline-subtracted postsynaptic tseries
    tsl = PulseResponseList(pulse_response_list).post_tseries(align='spike', bsub=True)
    if len(tsl) == 0:
        return None, None

    # average all together
    average = tsl.mean()

    # Start with even weighting, then triple the weight from the start of the
    # latency window until 4e-3 past its end, where the PSP onset is expected.
    weight = np.ones(len(average))
    onset_start_idx = average.index_at(latency_window[0])
    onset_stop_idx = average.index_at(latency_window[1] + 4e-3)
    weight[onset_start_idx:onset_stop_idx] = 3.0

    # TODO: masking of the crosstalk artifact for nearby electrodes
    # (|pre electrode ext_id - post electrode ext_id| < 3) is not implemented.

    fit = fit_psp(average, search_window=latency_window, clamp_mode=clamp_mode, sign=sign,
                  baseline_like_psp=True, init_params=init_params, fit_kws={'weights': weight})

    return fit, average
示例#4
0
 def plot_spikes(self, pulse_responses):
     """Draw spike-aligned presynaptic traces for every holding key.

     Each non-empty group of pulse responses is aligned on the presynaptic
     spike (baseline-subtracted) and drawn into ``self.spike_plots[n]``, where
     n is the position of the holding key.  A response is ``'qc_pass'`` when
     its stimulus pulse has exactly one spike, and ``'qc_fail'`` otherwise.
     Failures are drawn behind passes.  Created plot items are appended to
     ``self.items``.
     """
     for plot_idx, holding in enumerate(pulse_responses.keys()):
         groups = pulse_responses[holding]
         for responses in groups.values():
             if len(responses) == 0:
                 continue
             prl = PulseResponseList(responses)
             aligned = prl.pre_tseries(align='spike', bsub=True)
             # NOTE(review): the pairing below assumes pre_tseries yields one
             # trace per response, in order - confirm it never drops responses.
             for pr, trace in zip(prl, aligned):
                 single_spike = pr.stim_pulse.n_spikes == 1
                 qc = 'qc_pass' if single_spike else 'qc_fail'
                 item = self.spike_plots[plot_idx].plot(trace.time_values, trace.data, pen=self.qc_color[qc])
                 if not single_spike:
                     item.setZValue(-10)
                 self.items.append(item)
    def show_pulse_responses(self, passed_prs, failed_prs):
        """Overlay spike-aligned traces of failing and passing pulse responses.

        Failed responses are drawn first in translucent red, then passing ones
        in translucent white, so passing traces end up on top.  Postsynaptic
        traces go to ``self.response_plot`` and presynaptic traces to
        ``self.spike_plot``; both are baseline-subtracted.  Every created item
        is appended to ``self.items``.  Only ``self.response_plot`` is
        auto-ranged afterwards.
        """
        draw_order = (
            (failed_prs, (255, 0, 0, 40)),
            (passed_prs, (255, 255, 255, 40)),
        )
        for group, pen in draw_order:
            if len(group) == 0:
                continue
            prl = PulseResponseList(group)

            for post in prl.post_tseries(align='spike', bsub=True):
                self.items.append(self.response_plot.plot(post.time_values, post.data, pen=pen))

            for pre in prl.pre_tseries(align='spike', bsub=True):
                self.items.append(self.spike_plot.plot(pre.time_values, pre.data, pen=pen))

        self.response_plot.autoRange()
示例#6
0
def plot_metric_pairs(pair_list, metric, db, ax, align='pulse', norm_amp=None, perc=False, labels=None, max_ind_freq=50):
    """Plot the average current-clamp response of each pair, coloured by a metric.

    For every pair, the qc-passing ('ic', -55) responses are used, plus the
    ('ic', -70) ones when the pair's synapse type is ``'ex'``.  They are
    spike-aligned, baseline-subtracted, averaged and drawn into ``ax``.

    Parameters
    ----------
    pair_list : iterable of (eid, pre, post) tuples
        Each tuple is passed to ``get_pair`` together with ``db``.
    metric :
        Metric used to choose the colormap and normalization (via
        ``get_metric_data``) and each pair's colour (via ``map_color_by_metric``).
    db :
        Database handle used by ``get_pair``, ``get_metric_data`` and to open
        the session passed to ``response_query``.
    ax :
        Matplotlib axes to draw into.
    align : str
        ``'pulse'`` shifts each trace earlier by the pair's synaptic latency.
        Any other value keeps the spike-aligned time base.
    norm_amp : None, 'exc' or 'inh'
        If ``'exc'`` or ``'inh'``, the average is divided by the pair's
        ``psp_amplitude`` (negated for ``'inh'``) instead of being multiplied
        by the metric's scale.
    perc : bool
        If True, average only the responses whose absolute fitted amplitude
        lies between the 85th and 95th percentile.  Responses without a fit
        are excluded.
    labels : sequence or None
        Legend label per pair, in the same order as ``pair_list``.
    max_ind_freq :
        Forwarded to ``response_query``.

    Pairs left with no responses to average are skipped.
    """
    pairs = [get_pair(eid, pre, post, db) for eid, pre, post in pair_list]
    _, _, _, scale, _, cmap, cmap_log, clim, _ = get_metric_data(metric, db)
    cmap = matplotlib.cm.get_cmap(cmap)
    norm_cls = matplotlib.colors.LogNorm if cmap_log else matplotlib.colors.Normalize
    norm = norm_cls(vmin=clim[0], vmax=clim[1], clip=False)
    colors = [map_color_by_metric(pair, metric, cmap, norm, scale) for pair in pairs]

    # One query session for all pairs (previously a new, never-closed session
    # was opened per pair); closed once every pair has been drawn.
    session = db.session()
    try:
        for i, pair in enumerate(pairs):
            query = response_query(session, pair, max_ind_freq=max_ind_freq)
            prs = [rec.PulseResponse for rec in query.all()]
            sorted_prs = sort_responses(prs)
            prs = sorted_prs[('ic', -55)]['qc_pass']
            if pair.synapse.synapse_type == 'ex':
                prs = prs + sorted_prs[('ic', -70)]['qc_pass']

            if perc:
                # Only fitted responses have an amplitude.  Filter first so the
                # mask lines up with the responses it is applied to, and use an
                # array so the comparisons are element-wise.
                fitted = [pr for pr in prs if pr.pulse_response_fit is not None]
                amps = np.array([abs(pr.pulse_response_fit.fit_amp) for pr in fitted])
                if len(amps) > 0:
                    amp_85, amp_95 = np.percentile(amps, [85, 95])
                    keep = (amps >= amp_85) & (amps <= amp_95)
                    prs = [pr for pr, k in zip(fitted, keep) if k]
                else:
                    prs = []

            if len(prs) == 0:
                # nothing to average for this pair
                continue

            post_ts = PulseResponseList(prs).post_tseries(align='spike', bsub=True, bsub_win=0.1e-3)
            average = post_ts.mean()
            if norm_amp == 'exc':
                trace = average / pair.synapse.psp_amplitude
            elif norm_amp == 'inh':
                trace = average / pair.synapse.psp_amplitude * -1
            else:
                trace = average * scale

            if align == 'pulse':
                # shift the time base earlier by the pair's synaptic latency
                trace.t0 = trace.t0 - pair.synapse.latency

            label = labels[i] if labels is not None else None
            # NOTE(review): the time axis is multiplied by the metric's `scale`.
            # Together with set_xlim(-2, 10) this is only milliseconds when
            # scale == 1e3 - confirm this is intended for every metric.
            ax.plot(trace.time_values * scale, trace.data, color=colors[i], linewidth=2, label=label)
    finally:
        session.close()

    ax.set_xlim(-2, 10)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    if labels is not None:
        ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
示例#7
0
    def plot_traces(self, pulse_responses):
        """Draw postsynaptic traces for every holding key, with alignment fallback.

        ``pulse_responses`` maps each holding key to a dict of
        ``{qc_label: sequence of pulse responses}``.  For each non-empty
        group, postsynaptic alignment is attempted with ``'spike'``, then
        ``'peak'``, then ``'stim'`` (always with ``bsub=True`` and
        ``alignment_failure_mode='average'``).  The first mode that does not
        raise is used; if all of them raise, the group is skipped.

        The n-th holding key is drawn into ``self.trace_plots[n]``.  Traces
        use the pen ``self.qc_color[qc_label]``, and ``'qc_fail'`` traces are
        pushed behind the others.  The ``'qc_pass'`` mean is overlaid as a
        thick blue line.  All created items are appended to ``self.items``.
        Each plot is auto-ranged and then clipped to the x range
        -5e-3 .. 10e-3.
        """
        for plot_idx, holding in enumerate(pulse_responses.keys()):
            plot = self.trace_plots[plot_idx]
            for qc_label, responses in pulse_responses[holding].items():
                if len(responses) == 0:
                    continue
                prl = PulseResponseList(responses)

                aligned = None
                for mode in ('spike', 'peak', 'stim'):
                    try:
                        aligned = prl.post_tseries(
                            align=mode,
                            bsub=True,
                            alignment_failure_mode='average')
                    except Exception:
                        # best effort: fall through to the next alignment mode
                        continue
                    break
                if aligned is None:
                    continue

                for ts in aligned:
                    curve = plot.plot(ts.time_values, ts.data, pen=self.qc_color[qc_label])
                    if qc_label == 'qc_fail':
                        curve.setZValue(-10)
                    self.items.append(curve)

                if qc_label == 'qc_pass':
                    avg = aligned.mean()
                    self.items.append(plot.plot(avg.time_values, avg.data,
                                                pen={'color': 'b', 'width': 2}))

            plot.autoRange()
            plot.setXRange(-5e-3, 10e-3)
示例#8
0
    def load_saved_fit(self, record):
        """Restore a previously saved fit from ``record.notes`` into the UI.

        Always loads the synapse / gap-junction calls, the fit warnings and
        the comments into the control panel.  If the notes also contain
        ``'fit_parameters'``, it additionally:

        * restores the fit and initial parameters;
        * positions the latency line from the stored initial latencies;
        * for every mode/holding, restores the Fit Pass flag and re-plots the
          stored fit, evaluated on the time base of the averaged qc-passing
          responses.

        NOTE(review): ``modes`` and ``holdings`` are not defined in this
        method - presumably module-level globals (modes like 'vc'/'ic',
        holdings like -55/-70, matching the '-55'/'-70' keys used below);
        confirm.
        """
        data = record.notes        
        pair_params = {'Synapse call': data['synapse_type'], 'Gap junction call': data['gap_junction']}
        self.ctrl_panel.update_user_params(**pair_params)
        self.warnings = data.get('fit_warnings', [])
        self.ctrl_panel.output_params.child('Warnings').setValue('\n'.join(self.warnings))
        self.ctrl_panel.output_params.child('Comments', '').setValue(data.get('comments', ''))

        # some records may be stored with no fit if a synapse is not present.
        if 'fit_parameters' not in data:
            return
            
        self.fit_params = data['fit_parameters']
        self.ctrl_panel.update_fit_params(data['fit_parameters']['fit'])
        self.output_fit_parameters = data['fit_parameters']['fit']
        self.initial_fit_parameters = data['fit_parameters']['initial']

        # Initial latency per clamp mode: the first truthy stored 'xoffset'
        # (holding '-55', then '-70'), else 1e-3.  Because `or` is used, a
        # stored value of 0 is treated as missing.
        initial_vc_latency = (
            data['fit_parameters']['initial']['vc']['-55'].get('xoffset') or
            data['fit_parameters']['initial']['vc']['-70'].get('xoffset') or
            1e-3
        )
        initial_ic_latency = (
            data['fit_parameters']['initial']['ic']['-55'].get('xoffset') or
            data['fit_parameters']['initial']['ic']['-70'].get('xoffset') or
            1e-3
        )

        # If the vc and ic latencies agree within 100e-6, use the vc one.
        # Otherwise prefer the mode that has any passing fit (vc checked
        # first), falling back to vc.
        latency_diff = np.diff([initial_vc_latency, initial_ic_latency])[0]
        if abs(latency_diff) < 100e-6:
            self.latency_superline.set_value(initial_vc_latency, block_fit=True)
        else:
            fit_pass_vc = [data['fit_pass']['vc'][str(h)] for h in holdings]
            fit_pass_ic = [data['fit_pass']['ic'][str(h)] for h in holdings]
            if any(fit_pass_vc):
                self.latency_superline.set_value(initial_vc_latency, block_fit=True)
            elif any(fit_pass_ic):
                self.latency_superline.set_value(initial_ic_latency, block_fit=True)
            else:
                self.latency_superline.set_value(initial_vc_latency, block_fit=True)

        for mode in modes:
            for holding in holdings:
                # restore the stored Fit Pass flag for this holding/mode
                fit_pass = data['fit_pass'][mode][str(holding)]
                self.ctrl_panel.output_params.child('Fit parameters', str(holding) + ' ' + mode.upper(), 'Fit Pass').setValue(fit_pass)
                # deep copy so the pops/setdefault below do not modify record.notes
                fit_params = copy.deepcopy(data['fit_parameters']['fit'][mode][str(holding)])
                
                if fit_params:
                    # 'nrmse' is removed so it is not passed to p.eval() as a model parameter
                    fit_params.pop('nrmse', None)
                    # NOTE(review): if mode is neither 'ic' nor 'vc', p is unbound on the
                    # first pass (or keeps the previous iteration's model) - confirm modes.
                    if mode == 'ic':
                        p = StackedPsp() 
                    if mode == 'vc':
                        p = Psp()
                        
                    # make a list of spike-aligned postsynaptic tseries
                    tsl = PulseResponseList(self.sorted_responses[mode, holding]['qc_pass']).post_tseries(align='spike', bsub=True)
                    if len(tsl) == 0:
                        continue
                    
                    # average all together
                    avg = tsl.mean()
                    
                    # ensure an 'exp_tau' parameter is present, defaulting to the stored 'decay_tau'
                    fit_params.setdefault('exp_tau', fit_params['decay_tau'])
                    # evaluate the stored fit on the averaged response's time base
                    fit_psp = p.eval(x=avg.time_values, **fit_params)
                    fit_tseries = avg.copy(data=fit_psp)
                    
                    if mode == 'vc':
                        self.vc_plot.plot_fit(fit_tseries, holding, fit_pass=fit_pass)
                    if mode == 'ic':
                        self.ic_plot.plot_fit(fit_tseries, holding, fit_pass=fit_pass)