Example #1
def format_trace(trace, baseline_win, x_offset, align='spike'):
    # align can be to the pre-synaptic spike (default) or the onset of the PSP ('psp')
    baseline = float_mode(trace[0:baseline_win])
    trace = Trace(data=(trace-baseline), sample_rate=db.default_sample_rate)
    if align == 'psp':
        trace.t0 = -x_offset
    return trace
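In the snippet above, float_mode, Trace, and db.default_sample_rate come from the surrounding project (neuroanalysis and its database layer), and baseline_win is a sample count. A minimal self-contained sketch of the same baseline-subtraction step, using a hypothetical histogram-based stand-in for float_mode:

import numpy as np

def float_mode_approx(data, bins=100):
    # Hypothetical stand-in for float_mode: estimate the most common value.
    counts, edges = np.histogram(data, bins=bins)
    i = np.argmax(counts)
    return 0.5 * (edges[i] + edges[i + 1])

rng = np.random.default_rng(0)
trace = rng.normal(-65e-3, 0.2e-3, 2000)   # synthetic 100 ms Vm trace at 20 kHz
baseline_win = 200                         # first 10 ms used for the baseline
baseline = float_mode_approx(trace[0:baseline_win])
bsub = trace - baseline                    # the array that format_trace wraps in a Trace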
Example #2
def measure_peak(trace,
                 sign,
                 spike_time,
                 pulse_times,
                 spike_delay=1e-3,
                 response_window=4e-3):
    # Start measuring response after the pulse has finished, and no earlier than 1 ms after spike onset
    # response_start = max(spike_time + spike_delay, pulse_times[1])

    # Start measuring after spike and hope that the pulse offset doesn't get in the way
    # (if we wait for the pulse to end, then we miss too many fast rise / short latency events)
    response_start = spike_time + spike_delay
    response_stop = response_start + response_window

    # measure baseline from beginning of data until 50µs before pulse onset
    baseline_start = 0
    baseline_stop = pulse_times[0] - 50e-6

    baseline = float_mode(trace.time_slice(baseline_start, baseline_stop).data)
    response = trace.time_slice(response_start, response_stop)

    if sign == '+':
        i = np.argmax(response.data)
    else:
        i = np.argmin(response.data)
    peak = response.data[i]
    latency = response.time_values[i] - spike_time
    return peak - baseline, latency
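A worked example of the measurement windows above. MiniTrace is a hypothetical stand-in for the trace object (a TSeries-like API with time_slice and time_values, as in neuroanalysis); the actual call is shown commented out because float_mode is not defined in this listing:

import numpy as np

class MiniTrace:
    # Hypothetical stand-in exposing just what measure_peak needs.
    def __init__(self, data, dt, t0=0.0):
        self.data, self.dt, self.t0 = data, dt, t0
    @property
    def time_values(self):
        return self.t0 + np.arange(len(self.data)) * self.dt
    def time_slice(self, start, stop):
        i0 = int(round((start - self.t0) / self.dt))
        i1 = int(round((stop - self.t0) / self.dt))
        return MiniTrace(self.data[i0:i1], self.dt, t0=self.t0 + i0 * self.dt)

dt = 50e-6                                     # 20 kHz sampling
t = np.arange(0, 20e-3, dt)
spike_time, pulse_times = 5.2e-3, (5e-3, 7e-3)
psp = 0.5e-3 * np.exp(-(t - 8e-3) ** 2 / (2 * (0.5e-3) ** 2))  # fake 0.5 mV PSP peaking at 8 ms
trace = MiniTrace(psp, dt)

# amp, latency = measure_peak(trace, '+', spike_time, pulse_times)
# expected: amp close to 0.5e-3 V, latency close to 8e-3 - spike_time = 2.8e-3 s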
Example #3
def format_trace(trace, baseline_win, x_offset=1e-3, align='spike'):
    # align can be to the pre-synaptic spike (default) or the onset of the PSP ('psp')
    baseline = float_mode(trace.time_slice(baseline_win[0], baseline_win[1]).data)
    trace = TSeries(data=(trace.data-baseline), sample_rate=db.default_sample_rate)
    if align == 'psp':
        trace.t0 = x_offset
    return trace
Example #4
    def estimate_amplitude(self, plot=False):
        amp_group = self.amp_group
        amp_est = None
        amp_plot = None
        amp_sign = None
        avg_amp = None
        n_sweeps = len(amp_group)
        if n_sweeps == 0:
            return amp_est, amp_sign, avg_amp, amp_plot, n_sweeps
        # Generate average first response
        avg_amp = amp_group.bsub_mean()
        if plot:
            amp_plot = pg.plot(title='First pulse amplitude')
            amp_plot.plot(avg_amp.time_values, avg_amp.data)

        # Make initial amplitude estimate
        ad = avg_amp.data
        dt = avg_amp.dt
        base = float_mode(ad[:int(10e-3/dt)])
        neg = ad[int(13e-3/dt):].min() - base
        pos = ad[int(13e-3/dt):].max() - base
        amp_est = neg if abs(neg) > abs(pos) else pos
        if plot:
            amp_plot.addLine(y=base + amp_est)
        amp_sign = '-' if amp_est < 0 else '+'
        
        self._psp_estimate['amp'] = amp_est
        self._psp_estimate['amp_sign'] = amp_sign
        
        return amp_est, amp_sign, avg_amp, amp_plot, n_sweeps
Example #5
    def _get_tserieslist(self, ts_name, align, bsub):
        tsl = []
        for pr in self.prs:
            ts = getattr(pr, ts_name)
            stim_time = pr.stim_pulse.onset_time

            if bsub is True:
                start_time = max(ts.t0, stim_time - 5e-3)
                baseline_data = ts.time_slice(start_time, stim_time).data
                if len(baseline_data) == 0:
                    baseline = ts.data[0]
                else:
                    baseline = float_mode(baseline_data)
                ts = ts - baseline

            if align is not None:
                if align == 'spike':
                    align_t = pr.stim_pulse.first_spike_time
                    # ignore PRs with no known spike time
                    if align_t is None:
                        continue
                elif align == 'pulse':
                    align_t = stim_time
                else:
                    raise ValueError(
                        "align must be None, 'spike', or 'pulse'.")
                ts = ts.copy(t0=ts.t0 - align_t)

            tsl.append(ts)
        return TSeriesList(tsl)
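The align branch above only shifts each trace's clock so the chosen event sits at t=0 (ts.copy(t0=ts.t0 - align_t)). The same idea in plain numpy, with made-up times:

import numpy as np

dt = 50e-6
t0 = 1.2345          # absolute start time of one sweep, in seconds
spike_time = 1.2400  # absolute spike time for that sweep

abs_times = t0 + np.arange(400) * dt
rel_times = abs_times - spike_time   # equivalent to copying with t0 = t0 - align_t
# rel_times now crosses 0 at the spike, so responses from different sweeps
# line up sample-by-sample before averaging into a TSeriesList mean.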
Example #6
    def estimate_amplitude(self, plot=False):
        amp_group = self.amp_group
        amp_est = None
        amp_plot = None
        amp_sign = None
        avg_amp = None
        n_sweeps = len(amp_group)
        if n_sweeps == 0:
            return amp_est, amp_sign, avg_amp, amp_plot, n_sweeps
        # Generate average first response
        avg_amp = amp_group.bsub_mean()
        if plot:
            amp_plot = pg.plot(title='First pulse amplitude')
            amp_plot.plot(avg_amp.time_values, avg_amp.data)

        # Make initial amplitude estimate
        ad = avg_amp.data
        dt = avg_amp.dt
        base = float_mode(ad[:int(10e-3 / dt)])
        neg = ad[int(13e-3 / dt):].min() - base
        pos = ad[int(13e-3 / dt):].max() - base
        amp_est = neg if abs(neg) > abs(pos) else pos
        if plot:
            amp_plot.addLine(y=base + amp_est)
        amp_sign = '-' if amp_est < 0 else '+'

        self._psp_estimate['amp'] = amp_est
        self._psp_estimate['amp_sign'] = amp_sign

        return amp_est, amp_sign, avg_amp, amp_plot, n_sweeps
Example #7
def format_trace(trace, baseline_win, x_offset, align='spike'):
    # align can be to the pre-synaptic spike (default) or the onset of the PSP ('psp')
    baseline = float_mode(trace[0:baseline_win])
    trace = Trace(data=(trace - baseline), sample_rate=db.default_sample_rate)
    if align == 'psp':
        trace.t0 = -x_offset
    return trace
Example #8
def train_response_plot(expt_list, name=None, summary_plots=[None, None], color=None):
    ind_base_subtract = []
    rec_base_subtract = []
    train_plots = pg.plot()
    train_plots.setLabels(left=('Vm', 'V'))
    tau = 15e-3
    lp = 1000
    for expt in expt_list:
        for pre, post in expt.connections:
            if expt.cells[pre].cre_type == cre_type[0] and expt.cells[post].cre_type == cre_type[1]:
                print('Processing experiment: %s' % expt.nwb_file)
                ind = []
                rec = []
                analyzer = DynamicsAnalyzer(expt, pre, post)
                train_responses = analyzer.train_responses
                artifact = analyzer.cross_talk()
                if artifact > 0.03e-3:
                    continue
                for i, stim_params in enumerate(train_responses.keys()):
                    rec_t = int(np.round(stim_params[1] * 1e3, -1))
                    if stim_params[0] == 50 and rec_t == 250:
                        pulse_offsets = analyzer.pulse_offsets
                        if len(train_responses[stim_params][0]) != 0:
                            ind_group = train_responses[stim_params][0]
                            rec_group = train_responses[stim_params][1]
                            for j in range(len(ind_group)):
                                ind.append(ind_group.responses[j])
                                rec.append(rec_group.responses[j])
                if len(ind) > 5:
                    ind_avg = TraceList(ind).mean()
                    rec_avg = TraceList(rec).mean()
                    rec_avg.t0 = 0.3
                    base = float_mode(ind_avg.data[:int(10e-3 / ind_avg.dt)])
                    ind_base_subtract.append(ind_avg.copy(data=ind_avg.data - base))
                    rec_base_subtract.append(rec_avg.copy(data=rec_avg.data - base))
                    train_plots.plot(ind_avg.time_values, ind_avg.data - base)
                    train_plots.plot(rec_avg.time_values, rec_avg.data - base)
                    app.processEvents()
    if len(ind_base_subtract) != 0:
        print(name + ' n = %d' % len(ind_base_subtract))
        ind_grand_mean = TraceList(ind_base_subtract).mean()
        rec_grand_mean = TraceList(rec_base_subtract).mean()
        ind_grand_mean_dec = bessel_filter(exp_deconvolve(ind_grand_mean, tau), lp)
        train_plots.addLegend()
        train_plots.plot(ind_grand_mean.time_values, ind_grand_mean.data, pen={'color': 'g', 'width': 3}, name=name)
        train_plots.plot(rec_grand_mean.time_values, rec_grand_mean.data, pen={'color': 'g', 'width': 3}, name=name)
        #train_plots.plot(ind_grand_mean_dec.time_values, ind_grand_mean_dec.data, pen={'color': 'g', 'dash': [1,5,3,2]})
        train_amps = train_amp([ind_base_subtract, rec_base_subtract], pulse_offsets, '+')
        if ind_grand_mean is not None:
            train_plots = summary_plot_train(ind_grand_mean, plot=summary_plots[0], color=color,
                                             name=(legend + ' 50 Hz induction'))
            train_plots = summary_plot_train(rec_grand_mean, plot=summary_plots[0], color=color)
            train_plots2 = summary_plot_train(ind_grand_mean_dec, plot=summary_plots[1], color=color,
                                              name=(legend + ' 50 Hz induction'))
            return train_plots, train_plots2, train_amps
    else:
        print ("No Traces")
        return None
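The deconvolution step above (exp_deconvolve with tau = 15 ms, then a 1 kHz Bessel filter) compensates for the postsynaptic membrane's exponential filtering so that overlapping PSPs in a train separate cleanly. A simplified, self-contained stand-in for the exponential deconvolution:

import numpy as np

def exp_deconv(y, dt, tau):
    # d(t) = y(t) + tau * dy/dt: collapses an exponential decay with time
    # constant tau back into a brief pulse (simplified stand-in for
    # neuroanalysis's exp_deconvolve).
    return y[:-1] + tau * np.diff(y) / dt

dt, tau = 50e-6, 15e-3
t = np.arange(0, 0.1, dt)
psp = np.where(t < 10e-3, 0.0, np.exp(-(t - 10e-3) / tau))  # decay starting at 10 ms
dec = exp_deconv(psp, dt, tau)   # one large sample at the 10 ms onset, ~zero elsewhere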
Example #9
def get_amplitude(response_list):
    avg_trace = TraceList(response_list).mean()
    data = avg_trace.data
    dt = avg_trace.dt
    base = float_mode(data[:int(10e-3 / dt)])
    bsub_mean = avg_trace.copy(data=data - base)
    neg = bsub_mean.data[int(13e-3 / dt):].min()
    pos = bsub_mean.data[int(13e-3 / dt):].max()
    avg_amp = neg if abs(neg) > abs(pos) else pos
    amp_sign = '-' if avg_amp < 0 else '+'
    peak_ind = list(bsub_mean.data).index(avg_amp)
    peak_t = bsub_mean.time_values[peak_ind]
    return bsub_mean, avg_amp, amp_sign, peak_t
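The amplitude logic in get_amplitude keeps whichever deflection after 13 ms is larger in magnitude relative to baseline, so it works for both EPSPs and IPSPs. The same selection in isolation, with np.median standing in for float_mode:

import numpy as np

dt = 50e-6
data = np.zeros(int(30e-3 / dt))
data[int(15e-3 / dt):int(17e-3 / dt)] = -0.8e-3   # synthetic 0.8 mV IPSP

base = np.median(data[:int(10e-3 / dt)])          # stand-in baseline estimate
neg = data[int(13e-3 / dt):].min() - base
pos = data[int(13e-3 / dt):].max() - base
avg_amp = neg if abs(neg) > abs(pos) else pos
amp_sign = '-' if avg_amp < 0 else '+'
print(avg_amp, amp_sign)                          # -0.0008 '-'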
Example #10
def first_pulse_plot(expt_list, name=None, summary_plot=None, color=None, scatter=0):
    amp_plots = pg.plot()
    amp_plots.setLabels(left=('Vm', 'V'))
    amp_base_subtract = []
    avg_ests = []
    for expt in expt_list:
        for pre, post in expt.connections:
            if expt.cells[pre].cre_type == cre_type[0] and expt.cells[post].cre_type == cre_type[1]:
                avg_est, avg_amp, n_sweeps = responses(expt, pre, post)
                if expt.cells[pre].cre_type in EXCITATORY_CRE_TYPES and avg_est < 0:
                    continue
                elif expt.cells[pre].cre_type in INHIBITORY_CRE_TYPES and avg_est > 0:
                    continue
                if n_sweeps >= 10:
                    avg_amp.t0 = 0
                    avg_ests.append(avg_est)
                    base = float_mode(avg_amp.data[:int(10e-3 / avg_amp.dt)])
                    amp_base_subtract.append(avg_amp.copy(data=avg_amp.data - base))

                    current_connection_HS = post, pre
                    if len(expt.connections) > 1 and args.recip is True:
                        for i, x in enumerate(expt.connections):
                            if x == current_connection_HS:  # determine if a reciprocal connection
                                amp_plots.plot(avg_amp.time_values, avg_amp.data - base, pen={'color': 'r', 'width': 1})
                                break
                            elif x != current_connection_HS and i == len(expt.connections) - 1:  # reciprocal connection was not found
                                amp_plots.plot(avg_amp.time_values, avg_amp.data - base)
                    else:
                        amp_plots.plot(avg_amp.time_values, avg_amp.data - base)

                    app.processEvents()

    if len(amp_base_subtract) != 0:
        print(name + ' n = %d' % len(amp_base_subtract))
        grand_mean = TraceList(amp_base_subtract).mean()
        grand_est = np.mean(np.array(avg_ests))
        amp_plots.addLegend()
        amp_plots.plot(grand_mean.time_values, grand_mean.data, pen={'color': 'g', 'width': 3}, name=name)
        amp_plots.addLine(y=grand_est, pen={'color': 'g'})
        if grand_mean is not None:
            print(legend + ' Grand mean amplitude = %f' % grand_est)
            summary_plots = summary_plot_pulse(grand_mean, avg_ests, grand_est, labels=['Vm', 'V'], titles='Amplitude', i=scatter, plot=summary_plot, color=color, name=legend)
            return avg_ests, summary_plots
    else:
        print ("No Traces")
        return None, avg_ests, None, None
Example #11
def first_pulse_features(pair, pulse_responses, pulse_response_amps):

    avg_psp = TSeriesList(pulse_responses).mean()
    dt = avg_psp.dt
    avg_psp_baseline = float_mode(avg_psp.data[:int(10e-3 / dt)])
    avg_psp_bsub = avg_psp.copy(data=avg_psp.data - avg_psp_baseline)
    lower_bound = -float('inf')
    upper_bound = float('inf')
    xoffset = pair.synapse_prediction.ic_fit_xoffset
    if xoffset is None:
        xoffset = 14 * 10e-3
    synapse_type = pair.synapse_prediction.synapse_type
    if synapse_type == 'ex':
        amp_sign = '+'
    elif synapse_type == 'in':
        amp_sign = '-'
    else:
        raise Exception(
            'Synapse type is not defined, reconsider fitting this pair %s %d->%d'
            % (pair.expt_id, pair.pre_cell_id, pair.post_cell_id))

    weight = np.ones(len(avg_psp.data)) * 10.  # set everything to ten initially
    weight[int(10e-3 / dt):int(12e-3 / dt)] = 0.  # area around stim artifact
    weight[int(12e-3 / dt):int(19e-3 / dt)] = 30.  # area around steep PSP rise

    psp_fits = fit_psp(avg_psp,
                       xoffset=(xoffset, lower_bound, upper_bound),
                       yoffset=(avg_psp_baseline, lower_bound, upper_bound),
                       sign=amp_sign,
                       weight=weight)

    amp_cv = np.std(pulse_response_amps) / np.mean(pulse_response_amps)

    features = {
        'ic_fit_amp': psp_fits.best_values['amp'],
        'ic_fit_latency': psp_fits.best_values['xoffset'] - 10e-3,
        'ic_fit_rise_time': psp_fits.best_values['rise_time'],
        'ic_fit_decay_tau': psp_fits.best_values['decay_tau'],
        'ic_amp_cv': amp_cv,
        'avg_psp': avg_psp_bsub.data
    }
    #'ic_fit_NRMSE': psp_fits.nrmse()} TODO: nrmse not returned from psp_fits?

    return features
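The weight vector passed to fit_psp is just a per-sample array that, as is typical for lmfit-style fitters, gets multiplied into the residuals. Built on its own, using the hard-coded 10/12/19 ms windows from above (which assume the stimulus lands 10 ms into the averaged trace):

import numpy as np

dt = 50e-6
n_samples = int(50e-3 / dt)                     # e.g. a 50 ms averaged PSP
weight = np.ones(n_samples) * 10.               # moderate weight by default
weight[int(10e-3 / dt):int(12e-3 / dt)] = 0.    # ignore the stimulus artifact
weight[int(12e-3 / dt):int(19e-3 / dt)] = 30.   # emphasize the PSP rise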
Example #12
def bsub(trace):
    """Returns a copy of the neuroanalysis.data.Trace object 
    where the ephys data waveform is replaced with a baseline 
    subtracted ephys data waveform.  
    
    Parameters
    ----------
    trace : neuroanalysis.data.Trace object  
        
    Returns
    -------
    bsub_trace : neuroanalysis.data.Trace object
       Ephys data waveform is replaced with a baseline subtracted ephys data waveform
    """
    data = trace.data # actual numpy array of time series ephys waveform
    dt = trace.dt # time step of the data
    base = float_mode(data[:int(10e-3 / dt)]) # baseline value for trace 
    bsub_trace = trace.copy(data=data - base) # new neuroanalysis.data.Trace object for baseline subtracted data
    return bsub_trace
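A small usage sketch for bsub. FakeTrace is a hypothetical stand-in for neuroanalysis.data.Trace (just data, dt, and copy), and the call itself is commented out because float_mode is not defined in this listing:

import numpy as np

class FakeTrace:
    # Hypothetical minimal Trace: data array, sample period, and copy().
    def __init__(self, data, dt):
        self.data, self.dt = data, dt
    def copy(self, data=None):
        return FakeTrace(self.data if data is None else data, self.dt)

vm = FakeTrace(np.full(4000, -65e-3), dt=50e-6)   # 200 ms of resting Vm
# b = bsub(vm)
# b.data would then hover around 0 V instead of -65 mV.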
Example #13
def first_pulse_features(pair, pulse_responses, pulse_response_amps):

    avg_psp = TraceList(pulse_responses).mean()
    dt = avg_psp.dt
    avg_psp_baseline = float_mode(avg_psp.data[:int(10e-3/dt)])
    avg_psp_bsub = avg_psp.copy(data=avg_psp.data - avg_psp_baseline)
    lower_bound = -float('inf')
    upper_bound = float('inf')
    xoffset = pair.connection_strength.ic_fit_xoffset
    if xoffset is None:
        xoffset = 14*10e-3
    synapse_type = pair.connection_strength.synapse_type
    if synapse_type == 'ex':
        amp_sign = '+'
    elif synapse_type == 'in':
        amp_sign = '-'
    else:
        raise Exception('Synapse type is not defined, reconsider fitting this pair %s %d->%d' %
                        (pair.expt_id, pair.pre_cell_id, pair.post_cell_id))

    weight = np.ones(len(avg_psp.data)) * 10.  # set everything to ten initially
    weight[int(10e-3 / dt):int(12e-3 / dt)] = 0.  # area around stim artifact
    weight[int(12e-3 / dt):int(19e-3 / dt)] = 30.  # area around steep PSP rise

    psp_fits = fit_psp(avg_psp,
                       xoffset=(xoffset, lower_bound, upper_bound),
                       yoffset=(avg_psp_baseline, lower_bound, upper_bound),
                       sign=amp_sign,
                       weight=weight)

    amp_cv = np.std(pulse_response_amps)/np.mean(pulse_response_amps)

    features = {'ic_fit_amp': psp_fits.best_values['amp'],
                'ic_fit_latency': psp_fits.best_values['xoffset'] - 10e-3,
                'ic_fit_rise_time': psp_fits.best_values['rise_time'],
                'ic_fit_decay_tau': psp_fits.best_values['decay_tau'],
                'ic_amp_cv': amp_cv,
                'avg_psp': avg_psp_bsub.data}
                #'ic_fit_NRMSE': psp_fits.nrmse()} TODO: nrmse not returned from psp_fits?

    return features
Example #14
    def fit_psp(self, data, t, dt, clamp_mode):
        mode = float_mode(data[:int(1e-3 / dt)])
        sign = -1 if data.mean() - mode < 0 else 1
        params = OrderedDict([
            ('xoffset', (2e-3, 3e-4, 5e-3)),
            ('yoffset', data[0]),
            ('amp', sign * 10e-12),
            #('k', (2e-3, 50e-6, 10e-3)),
            ('rise_time', (2e-3, 50e-6, 10e-3)),
            ('decay_tau', (4e-3, 500e-6, 50e-3)),
            ('rise_power', (2.0, 'fixed')),
        ])
        if clamp_mode == 'ic':
            params['amp'] = sign * 10e-3
            #params['k'] = (5e-3, 50e-6, 20e-3)
            params['rise_time'] = (5e-3, 50e-6, 20e-3)
            params['decay_tau'] = (15e-3, 500e-6, 150e-3)

        fit_kws = {'xtol': 1e-3, 'maxfev': 100}

        psp = fitting.Psp()
        return psp.fit(data, x=t, fit_kws=fit_kws, params=params)
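The params mapping above appears to follow the lmfit-style convention used by neuroanalysis fitting models: a bare number is an initial guess, an (initial, min, max) tuple adds bounds, and (value, 'fixed') pins the parameter. Note also how the amplitude guess scales with clamp mode: ~10 pA in voltage clamp versus ~10 mV in current clamp.

from collections import OrderedDict

# Sketch of the parameter-spec convention (assumed, not verified against
# the fitting library's docs):
params = OrderedDict([
    ('xoffset', (2e-3, 3e-4, 5e-3)),   # initial 2 ms, bounded to [0.3 ms, 5 ms]
    ('yoffset', 0.0),                  # bare value: initial guess only
    ('rise_power', (2.0, 'fixed')),    # held constant during the fit
])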
Example #15
    def fit_psp(self, data, t, dt, clamp_mode):
        mode = float_mode(data[:int(1e-3 / dt)])
        sign = -1 if data.mean() - mode < 0 else 1
        params = OrderedDict([
            ('xoffset', (2e-3, 3e-4, 5e-3)),
            ('yoffset', data[0]),
            ('amp', sign * 10e-12),
            #('k', (2e-3, 50e-6, 10e-3)),
            ('rise_time', (2e-3, 50e-6, 10e-3)),
            ('decay_tau', (4e-3, 500e-6, 50e-3)),
            ('rise_power', (2.0, 'fixed')),
        ])
        if clamp_mode == 'ic':
            params['amp'] = sign * 10e-3
            #params['k'] = (5e-3, 50e-6, 20e-3)
            params['rise_time'] = (5e-3, 50e-6, 20e-3)
            params['decay_tau'] = (15e-3, 500e-6, 150e-3)

        fit_kws = {'xtol': 1e-3, 'maxfev': 100}

        psp = fitting.Psp()
        return psp.fit(data, x=t, fit_kws=fit_kws, params=params)
Example #16
def measure_peak(trace, sign, spike_time, pulse_times, spike_delay=1e-3, response_window=4e-3):
    # Start measuring response after the pulse has finished, and no earlier than 1 ms after spike onset
    # response_start = max(spike_time + spike_delay, pulse_times[1])

    # Start measuring after spike and hope that the pulse offset doesn't get in the way
    # (if we wait for the pulse to end, then we miss too many fast rise / short latency events)
    response_start = spike_time + spike_delay
    response_stop = response_start + response_window

    # measure baseline from beginning of data until 50µs before pulse onset
    baseline_start = 0
    baseline_stop = pulse_times[0] - 50e-6

    baseline = float_mode(trace.time_slice(baseline_start, baseline_stop).data)
    response = trace.time_slice(response_start, response_stop)

    if sign == '+':
        i = np.argmax(response.data)
    else:
        i = np.argmin(response.data)
    peak = response.data[i]
    latency = response.time_values[i] - spike_time
    return peak - baseline, latency
Example #17
    def _load_nwb(self, session, expt_entry, elecs_by_ad_channel, pairs_by_device_id):
        nwb = self.expt.data
        
        for srec in nwb.contents:
            temp = srec.meta.get('temperature', None)
            srec_entry = db.SyncRec(ext_id=srec.key, experiment=expt_entry, temperature=temp)
            session.add(srec_entry)
            
            srec_has_mp_probes = False
            
            rec_entries = {}
            all_pulse_entries = {}
            for rec in srec.recordings:
                
                # import all recordings
                rec_entry = db.Recording(
                    sync_rec=srec_entry,
                    electrode=elecs_by_ad_channel[rec.device_id],  # should probably just skip if this causes KeyError?
                    start_time=rec.start_time,
                )
                session.add(rec_entry)
                rec_entries[rec.device_id] = rec_entry
                
                # import patch clamp recording information
                if not isinstance(rec, PatchClampRecording):
                    continue
                qc_pass = qc.recording_qc_pass(rec)
                pcrec_entry = db.PatchClampRecording(
                    recording=rec_entry,
                    clamp_mode=rec.clamp_mode,
                    patch_mode=rec.patch_mode,
                    stim_name=rec.meta['stim_name'],
                    baseline_potential=rec.baseline_potential,
                    baseline_current=rec.baseline_current,
                    baseline_rms_noise=rec.baseline_rms_noise,
                    qc_pass=qc_pass,
                )
                session.add(pcrec_entry)

                # import test pulse information
                tp = rec.nearest_test_pulse
                if tp is not None:
                    tp_entry = db.TestPulse(
                        start_index=tp.indices[0],
                        stop_index=tp.indices[1],
                        baseline_current=tp.baseline_current,
                        baseline_potential=tp.baseline_potential,
                        access_resistance=tp.access_resistance,
                        input_resistance=tp.input_resistance,
                        capacitance=tp.capacitance,
                        time_constant=tp.time_constant,
                    )
                    session.add(tp_entry)
                    pcrec_entry.nearest_test_pulse = tp_entry
                    
                # import information about STP protocol
                if not isinstance(rec, MultiPatchProbe):
                    continue
                srec_has_mp_probes = True
                psa = PulseStimAnalyzer.get(rec)
                ind_freq, rec_delay = psa.stim_params()
                mprec_entry = db.MultiPatchProbe(
                    patch_clamp_recording=pcrec_entry,
                    induction_frequency=ind_freq,
                    recovery_delay=rec_delay,
                )
                session.add(mprec_entry)
            
                # import presynaptic stim pulses
                pulses = psa.pulses()
                
                pulse_entries = {}
                all_pulse_entries[rec.device_id] = pulse_entries
                
                rec_tvals = rec['primary'].time_values

                for i, pulse in enumerate(pulses):
                    # Record information about all pulses, including test pulse.
                    t0 = rec_tvals[pulse[0]]
                    t1 = rec_tvals[pulse[1]]
                    data_start = max(0, t0 - 10e-3)
                    data_stop = t0 + 10e-3
                    pulse_entry = db.StimPulse(
                        recording=rec_entry,
                        pulse_number=i,
                        onset_time=t0,
                        amplitude=pulse[2],
                        duration=t1-t0,
                        data=rec['primary'].time_slice(data_start, data_stop).resample(sample_rate=20000).data,
                        data_start_time=data_start,
                    )
                    session.add(pulse_entry)
                    pulse_entries[i] = pulse_entry
                    

                # import presynaptic evoked spikes
                # For now, we only detect up to 1 spike per pulse, but eventually
                # this may be adapted for more.
                spikes = psa.evoked_spikes()
                for i, sp in enumerate(spikes):
                    pulse = pulse_entries[sp['pulse_n']]
                    if sp['spike'] is not None:
                        spinfo = sp['spike']
                        extra = {
                            'peak_time': rec_tvals[spinfo['peak_index']],
                            'max_dvdt_time': rec_tvals[spinfo['rise_index']],
                            'max_dvdt': spinfo['max_dvdt'],
                        }
                        if 'peak_diff' in spinfo:
                            extra['peak_diff'] = spinfo['peak_diff']
                        if 'peak_value' in spinfo:
                            extra['peak_value'] = spinfo['peak_value']
                        
                        pulse.n_spikes = 1
                    else:
                        extra = {}
                        pulse.n_spikes = 0
                    
                    spike_entry = db.StimSpike(
                        pulse=pulse,
                        **extra
                    )
                    session.add(spike_entry)
                    pulse.first_spike = spike_entry
            
            if not srec_has_mp_probes:
                continue
            
            # import postsynaptic responses
            mpa = MultiPatchSyncRecAnalyzer(srec)
            for pre_dev in srec.devices:
                for post_dev in srec.devices:
                    if pre_dev == post_dev:
                        continue

                    # get all responses, regardless of the presence of a spike
                    responses = mpa.get_spike_responses(srec[pre_dev], srec[post_dev], align_to='pulse', require_spike=False)
                    post_tvals = srec[post_dev]['primary'].time_values
                    for resp in responses:
                        # base_entry = db.Baseline(
                        #     recording=rec_entries[post_dev],
                        #     start_index=resp['baseline_start'],
                        #     stop_index=resp['baseline_stop'],
                        #     data=resp['baseline'].resample(sample_rate=20000).data,
                        #     mode=float_mode(resp['baseline'].data),
                        # )
                        # session.add(base_entry)
                        pair_entry = pairs_by_device_id[(pre_dev, post_dev)]
                        if resp['ex_qc_pass']:
                            pair_entry.n_ex_test_spikes += 1
                        if resp['in_qc_pass']:
                            pair_entry.n_in_test_spikes += 1
                            
                        resp_entry = db.PulseResponse(
                            recording=rec_entries[post_dev],
                            stim_pulse=all_pulse_entries[pre_dev][resp['pulse_n']],
                            pair=pair_entry,
                            # baseline=base_entry,
                            start_time=post_tvals[resp['rec_start']],
                            data=resp['response'].resample(sample_rate=20000).data,
                            ex_qc_pass=resp['ex_qc_pass'],
                            in_qc_pass=resp['in_qc_pass'],
                        )
                        session.add(resp_entry)
                        
            # generate up to 20 baseline snippets for each recording
            for dev in srec.devices:
                rec = srec[dev]
                rec_tvals = rec['primary'].time_values
                dist = BaselineDistributor.get(rec)
                for i in range(20):
                    base = dist.get_baseline_chunk(20e-3)
                    if base is None:
                        # all out!
                        break
                    start, stop = base
                    data = rec['primary'][start:stop].resample(sample_rate=20000).data

                    ex_qc_pass = qc.pulse_response_qc_pass(+1, rec, [start, stop], None)
                    in_qc_pass = qc.pulse_response_qc_pass(-1, rec, [start, stop], None)

                    base_entry = db.Baseline(
                        recording=rec_entries[dev],
                        start_time=rec_tvals[start],
                        data=data,
                        mode=float_mode(data),
                        ex_qc_pass=ex_qc_pass,
                        in_qc_pass=in_qc_pass,
                    )
                    session.add(base_entry)
Example #18
    def create_db_entries(cls, job, session):
        db = job['database']
        job_id = job['job_id']

        # Load experiment from DB
        expt_entry = db.experiment_from_ext_id(job_id, session=session)
        elecs_by_ad_channel = {elec.device_id:elec for elec in expt_entry.electrodes}
        cell_entries = expt_entry.cells ## do this once here instead of multiple times later because it's slooooowwwww
        pairs_by_cell_id = expt_entry.pairs

        # load NWB file
        path = os.path.join(config.synphys_data, expt_entry.storage_path)
        expt = AI_Experiment(loader=OptoExperimentLoader(site_path=path))
        nwb = expt.data
        stim_log = expt.loader.load_stimulation_log()
        if stim_log['version'] < 3:
            ## gonna need to load an image in order to calculate spiral size later
            from acq4.util.DataManager import getHandle

        last_stim_pulse_time = {}
        # Load all data from NWB into DB
        for srec in nwb.contents:
            temp = srec.meta.get('temperature', None)
            srec_entry = db.SyncRec(ext_id=srec.key, experiment=expt_entry, temperature=temp)
            session.add(srec_entry)

            rec_entries = {}
            all_pulse_entries = {}
            for rec in srec.recordings:
                
                # import all recordings
                electrode_entry = elecs_by_ad_channel.get(rec.device_id, None)

                rec_entry = db.Recording(
                    sync_rec=srec_entry,
                    electrode=electrode_entry,
                    start_time=rec.start_time,
                    device_name=str(rec.device_id)
                )
                session.add(rec_entry)
                rec_entries[rec.device_id] = rec_entry

                # import patch clamp recording information
                if isinstance(rec, PatchClampRecording):
                    qc_pass, qc_failures = qc.recording_qc_pass(rec)
                    pcrec_entry = db.PatchClampRecording(
                        recording=rec_entry,
                        clamp_mode=rec.clamp_mode,
                        patch_mode=rec.patch_mode,
                        stim_name=rec.stimulus.description,
                        baseline_potential=rec.baseline_potential,
                        baseline_current=rec.baseline_current,
                        baseline_rms_noise=rec.baseline_rms_noise,
                        qc_pass=qc_pass,
                        meta=None if len(qc_failures) == 0 else {'qc_failures': qc_failures},
                    )
                    session.add(pcrec_entry)

                    # import test pulse information
                    tp = rec.nearest_test_pulse
                    if tp is not None:
                        indices = tp.indices or [None, None]
                        tp_entry = db.TestPulse(
                            electrode=electrode_entry,
                            recording=rec_entry,
                            start_index=indices[0],
                            stop_index=indices[1],
                            baseline_current=tp.baseline_current,
                            baseline_potential=tp.baseline_potential,
                            access_resistance=tp.access_resistance,
                            input_resistance=tp.input_resistance,
                            capacitance=tp.capacitance,
                            time_constant=tp.time_constant,
                        )
                        session.add(tp_entry)
                        pcrec_entry.nearest_test_pulse = tp_entry

                    psa = PatchClampStimPulseAnalyzer.get(rec)
                    pulses = psa.pulse_chunks()
                    pulse_entries = {}
                    all_pulse_entries[rec.device_id] = pulse_entries
                    cell_entry = electrode_entry.cell

                    for i,pulse in enumerate(pulses):
                        # Record information about all pulses, including test pulse.
                        t0, t1 = pulse.meta['pulse_edges']
                        resampled = pulse['primary'].resample(sample_rate=20000)
                        
                        clock_time = t0 + datetime_to_timestamp(rec_entry.start_time)
                        prev_pulse_dt = clock_time - last_stim_pulse_time.get(cell_entry.ext_id, -np.inf)
                        last_stim_pulse_time[cell_entry.ext_id] = clock_time

                        pulse_entry = db.StimPulse(
                            recording=rec_entry,
                            pulse_number=pulse.meta['pulse_n'],
                            onset_time=t0,
                            amplitude=pulse.meta['pulse_amplitude'],
                            duration=t1-t0,
                            data=resampled.data,
                            data_start_time=resampled.t0,
                            #cell=electrode_entry.cell if electrode_entry is not None else None,
                            cell=cell_entry,
                            #device_name=str(rec.device_id),
                            previous_pulse_dt=prev_pulse_dt
                        )
                        session.add(pulse_entry)
                        pulse_entries[pulse.meta['pulse_n']] = pulse_entry


                #elif isinstance(rec, OptoRecording) and (rec.device_name=='Fidelity'): 
                elif rec.device_type == 'Fidelity':
                    ## This is a 2p stimulation

                    ## get cell entry
                    stim_num = rec.meta['notebook']['USER_stim_num']
                    if stim_num is None: ### this is a trace that would have been labeled as 'unknown'
                        continue
                    stim = stim_log[str(int(stim_num))]
                    cell_entry = cell_entries[stim['stimulationPoint']['name']]

                    ## get stimulation shape parameters
                    if stim_log['version'] >= 3:
                        shape = {'spiral_revolutions': stim['shape']['spiral revolutions'], 'spiral_size': stim['shape']['size']}
                    else:
                        ## need to calculate spiral size from reference image, cause stimlog is from before we were saving spiral size
                        shape={'spiral_revolutions':stim.get('prairieCmds', {}).get('spiralRevolutions')}
                        prairie_size = stim['prairieCmds']['spiralSize']
                        ref_image = os.path.join(expt.path, stim['prairieImage'][-23:])
                        if os.path.exists(ref_image):
                            h = getHandle(ref_image)
                            xPixels = h.info()['PrairieMetaInfo']['Environment']['PixelsPerLine']
                            pixelLength = h.info()['PrairieMetaInfo']['Environment']['XAxis_umPerPixel']
                            size = prairie_size * pixelLength * xPixels * 1e-6
                            shape['spiral_size'] = size
                        else:
                            shape['spiral_size'] = None

                    ## calculate offset_distance
                    offset = stim.get('offset')
                    if offset is not None:
                        offset_distance = (offset[0]**2 + offset[1]**2 + offset[2]**2)**0.5
                    else:
                        offset_distance = None

                    pulse_entries = {}
                    all_pulse_entries[rec.device_id] = pulse_entries

                    ospa = GenericStimPulseAnalyzer.get(rec)

                    for i, pulse in enumerate(ospa.pulses(channel='reporter')):
                        ### pulse is (start, stop, amplitude)
                        # Record information about all pulses, including test pulse.
                        #t0, t1 = pulse.meta['pulse_edges']
                        #resampled = pulse['reporter'].resample(sample_rate=20000)

                        t0, t1 = pulse[0], pulse[1]
                        
                        clock_time = t0 + datetime_to_timestamp(rec_entry.start_time)
                        prev_pulse_dt = clock_time - last_stim_pulse_time.get(cell_entry.ext_id, -np.inf)
                        last_stim_pulse_time[cell_entry.ext_id] = clock_time
                        pulse_entry = db.StimPulse(
                            recording=rec_entry,
                            cell=cell_entry,
                            pulse_number=i, #pulse.meta['pulse_n'],
                            onset_time=pulse[0],#rec.pulse_start_times[i], #t0,
                            amplitude=power_cal.convert_voltage_to_power(pulse[2], timestamp_to_datetime(expt_entry.acq_timestamp), expt_entry.rig_name), ## need to fill in laser/objective correctly
                            duration=pulse[1]-pulse[0],#rec.pulse_duration()[i],
                            previous_pulse_dt=prev_pulse_dt,
                            #data=resampled.data,
                            #data_start_time=resampled.t0,
                            #wavelength,
                            #light_source,
                            position=stim['stimPos'],
                            #position_offset=stim['offset'],
                            #device_name=rec.device_id,
                            #qc_pass=None
                            meta = {'shape': shape,
                                    'pockel_cmd':stim.get('prairieCmds',{}).get('laserPower', [None]*100)[i],
                                    'pockel_voltage': float(pulse[2]),#rec.pulse_power()[i],
                                    'position_offset':offset,
                                    'offset_distance':offset_distance,
                                    'wavelength': 1070e-9
                                    } # TODO: put in light_source and wavelength
                            )
                        qc_pass, qc_failures = qc.opto_stim_pulse_qc_pass(pulse_entry)
                        pulse_entry.qc_pass = qc_pass
                        if not qc_pass:
                            pulse_entry.meta['qc_failures'] = qc_failures

                        session.add(pulse_entry)
                        pulse_entries[i] = pulse_entry


                elif 'LED' in rec.device_type:
                    #if rec.device_id == 'TTL1P_0': ## this is the ttl output to Prairie, not an LED stimulation
                    #    continue

                    ### This is an LED stimulation
                    #if rec.device_id in ['TTL1_1', 'TTL1P_1']:
                    #    lightsource = 'LED-470nm'
                    #elif rec.device_id in ['TTL1_2', 'TTL1P_2']:
                    #    lightsource = 'LED-590nm'
                    #else:
                    #    raise Exception("Don't know lightsource for device: %s" % rec.device_id)

                    pulse_entries = {}
                    all_pulse_entries[rec.device_id] = pulse_entries

                    spa = PWMStimPulseAnalyzer.get(rec)
                    pulses = spa.pulses(channel='reporter')
                    max_power = power_cal.get_led_power(timestamp_to_datetime(expt_entry.acq_timestamp), expt_entry.rig_name, rec.device_id)

                    for i, pulse in enumerate(pulses):
                        pulse_entry = db.StimPulse(
                            recording=rec_entry,
                            #cell=cell_entry, ## we're not stimulating just one cell here TODO: but maybe this should be a list of cells in the fov?
                            pulse_number=i,
                            onset_time=pulse.global_start_time,
                            amplitude=max_power*pulse.amplitude,
                            duration=pulse.duration,
                            #data=resampled.data, ## don't need data, it's just a square pulse
                            #data_start_time=resampled.t0,
                            #position=None, # don't have a 3D position, have a field
                            #device_name=rec.device_id,
                            meta = {'shape': 'wide-field', ## TODO: description of field of view
                                    'LED_voltage':str(pulse.amplitude),
                                    'light_source':rec.device_id,
                                    'pulse_width_modulation': spa.pwm_params(channel='reporter', pulse_n=i),
                                    #'position_offset':offset,
                                    #'offset_distance':offset_distance,
                                    } ## TODO: put in lightsource and wavelength
                            )
                        ## TODO: make qc function for LED stimuli
                        #qc_pass, qc_failures = qc.opto_stim_pulse_qc_pass(pulse_entry)
                        #pulse_entry.qc_pass = qc_pass
                        #if not qc_pass:
                        #    pulse_entry.meta['qc_failures'] = qc_failures

                        session.add(pulse_entry)
                        pulse_entries[i] = pulse_entry
                    
                elif rec.device_id == 'unknown': 
                    ## At the end of some .nwbs there are vc traces to check access resistance.
                    ## These have an AD6(fidelity) channel, but do not have an optical stimulation and
                    ## this channel is labeled unknown when it gets created in OptoRecording
                    pass

                elif rec.device_id == 'Prairie_Command':
                    ### this is just the TTL command sent to the laser, the actually data about when the Laser was active is in the Fidelity channel
                    pass
                    
                else:
                    raise Exception('Need to figure out recording type for %s (device_id:%s)' % (rec, rec.device_id))

            # collect and shuffle baseline chunks for each recording
            baseline_chunks = {}
            for post_rec in [rec for rec in srec.recordings if isinstance(rec, PatchClampRecording)]:
                post_dev = post_rec.device_id

                base_dist = BaselineDistributor.get(post_rec)
                chunks = list(base_dist.baseline_chunks())
                
                # generate a different random shuffle for each combination pre,post device
                # (we are not allowed to reuse the same baseline chunks for a particular pre-post pair,
                # but it is ok to reuse them across pairs)
                for pre_dev in srec.devices: 
                    # shuffle baseline chunks in a deterministic way:
                    # convert expt_id/srec_id/pre/post into an integer seed
                    seed_str = ("%s %s %s %s" % (job_id, srec.key, pre_dev, post_dev)).encode()
                    seed = struct.unpack('I', hashlib.sha1(seed_str).digest()[:4])[0]
                    rng = np.random.RandomState(seed)
                    rng.shuffle(chunks)
                    
                    baseline_chunks[pre_dev, post_dev] = chunks[:]

            baseline_qc_cache = {}
            baseline_entry_cache = {}

            ### import postsynaptic responses
            unmatched = 0
            osra = OptoSyncRecAnalyzer.get(srec)
            for stim_rec in srec.recordings:
                if stim_rec.device_type in ['Prairie_Command', 'unknown']: ### these don't actually contain data we want to use -- ignore them
                    continue
                if isinstance(stim_rec, PatchClampRecording):
                    ### exclude trying to analyze intrinsic pulses
                    stim_name = stim_rec.stimulus.description
                    if any(substr in stim_name for substr in ['intrins']):
                        continue

                for post_rec in [x for x in srec.recordings if isinstance(x, PatchClampRecording)]:
                    if stim_rec == post_rec:
                        continue

                    if 'Fidelity' in stim_rec.device_type:
                        stim_num = stim_rec.meta['notebook']['USER_stim_num']
                        if stim_num is None: ## happens when last sweep records a voltage offset - used to be labelled as 'unknown' device
                            continue
                        
                        stim = stim_log[str(int(stim_num))]
                        pre_cell_name = str(stim['stimulationPoint']['name'])

                        post_cell_name = str('electrode_'+ str(post_rec.device_id))

                        pair_entry = pairs_by_cell_id.get((pre_cell_name, post_cell_name))

                    elif 'led' in stim_rec.device_type.lower():
                        pair_entry = None

                    elif isinstance(stim_rec, PatchClampRecording):
                        pre_cell_name = str('electrode_' + str(stim_rec.device_id))
                        post_cell_name = str('electrode_'+ str(post_rec.device_id))
                        pair_entry = pairs_by_cell_id.get((pre_cell_name, post_cell_name))

                    # get all responses, regardless of the presence of a spike
                    responses = osra.get_responses(stim_rec, post_rec)
                    if len(responses) > 10:
                        raise Exception('Found more than 10 pulse responses for %s. Please investigate.'%srec)
                    for resp in responses:
                        if pair_entry is not None: ### when recordings are crappy cells are not always included in connections files so won't exist as pairs in the db, also led stimulations don't have pairs
                            if resp['ex_qc_pass']:
                                pair_entry.n_ex_test_spikes += 1
                            if resp['in_qc_pass']:
                                pair_entry.n_in_test_spikes += 1
                            
                        resampled = resp['response']['primary'].resample(sample_rate=20000)
                        resp_entry = db.PulseResponse(
                            recording=rec_entries[post_rec.device_id],
                            stim_pulse=all_pulse_entries[stim_rec.device_id][resp['pulse_n']],
                            pair=pair_entry,
                            data=resampled.data,
                            data_start_time=resampled.t0,
                            ex_qc_pass=resp['ex_qc_pass'],
                            in_qc_pass=resp['in_qc_pass'],
                            meta=None if resp['ex_qc_pass'] and resp['in_qc_pass'] else {'qc_failures': resp['qc_failures']},
                        )
                        session.add(resp_entry)

                        # find a baseline chunk from this recording with compatible qc metrics
                        got_baseline = False
                        for i, (start, stop) in enumerate(baseline_chunks[stim_rec.device_id, post_rec.device_id]):
                            key = (post_rec.device_id, start, stop)

                            # pull data and run qc if needed
                            if key not in baseline_qc_cache:
                                data = post_rec['primary'].time_slice(start, stop).resample(sample_rate=db.default_sample_rate).data
                                ex_qc_pass, in_qc_pass, qc_failures = qc.opto_pulse_response_qc_pass(post_rec, [start, stop])
                                baseline_qc_cache[key] = (data, ex_qc_pass, in_qc_pass)
                            else:
                                (data, ex_qc_pass, in_qc_pass) = baseline_qc_cache[key]

                            if resp_entry.ex_qc_pass is True and ex_qc_pass is not True:
                                continue
                            elif resp_entry.in_qc_pass is True and in_qc_pass is not True:
                                continue
                            else:
                                got_baseline = True
                                baseline_chunks[stim_rec.device_id, post_rec.device_id].pop(i)
                                break

                        if not got_baseline:
                            # no matching baseline available
                            unmatched += 1
                            continue

                        if key not in baseline_entry_cache:
                            # create a db record for this baseline chunk if it has not already appeared elsewhere
                            base_entry = db.Baseline(
                                recording=rec_entries[post_rec.device_id],
                                data=data,
                                data_start_time=start,
                                mode=float_mode(data),
                                ex_qc_pass=ex_qc_pass,
                                in_qc_pass=in_qc_pass,
                                meta=None if ex_qc_pass is True and in_qc_pass is True else {'qc_failures': qc_failures},
                            )
                            session.add(base_entry)
                            baseline_entry_cache[key] = base_entry
                        
                        resp_entry.baseline = baseline_entry_cache[key]

            if unmatched > 0:
                print("%s %s: %d pulse responses without matched baselines" % (job_id, srec, unmatched))
Example #19
    def _update_plots(self):
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()
        
        # clear all plots
        self.pre_plot.clear()
        self.post_plot.clear()

        pre = self.params['pre']
        post = self.params['post']
        
        # If there are no selected sweeps or channels have not been set, return
        if len(sweeps) == 0 or pre == post or pre not in self.channels or post not in self.channels:
            return

        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot), (post, post_mode, self.post_plot)]:
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units), bottom=("Time", 's'))
        
        # Iterate over selected channels of all sweeps, plotting traces one at a time
        # Collect information about pulses and spikes
        pulses = []
        spikes = []
        post_traces = []
        for i,sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']
            
            # Detect pulse times
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data
            post_filt = self.artifact_remover.process(post_trace, list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)
            
            # plot raw data
            color = pg.intColor(i, hues=len(sweeps)*1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot), (post_filt, self.post_plot)]:
                plot.plot(trace.time_values, trace.data, pen=color, antialias=False)

            # detect spike times
            spike_inds = []
            spike_info = []
            for on, off in zip(on_times, off_times):
                spike = detect_evoked_spike(sweep[pre], [on, off])
                spike_info.append(spike)
                if spike is None:
                    spike_inds.append(None)
                else:
                    spike_inds.append(spike['rise_index'])
            spikes.append(spike_info)
                    
            dt = pre_trace.dt
            vticks = pg.VTickGroup([x * dt for x in spike_inds if x is not None], yrange=[0.0, 0.2], pen=color)
            self.pre_plot.addItem(vticks)

        # Iterate over spikes, plotting average response
        all_responses = []
        avg_responses = []
        fits = []
        fit = None
        
        npulses = max(map(len, pulses))
        self.response_plots.clear()
        self.response_plots.set_shape(1, npulses+1) # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0,0])
        for i in range(1, npulses+1):
            self.response_plots[0,i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0, 0].setLabels(left=("Averaged events (Channel %d)" % post, units))
        
        fit_pen = {'color':(30, 30, 255), 'width':2, 'dash': [1, 1]}
        for i in range(npulses):
            # get the chunk of each sweep between spikes
            responses = []
            all_responses.append(responses)
            for j, sweep in enumerate(sweeps):
                # get the current spike
                if i >= len(spikes[j]):
                    continue
                spike = spikes[j][i]
                if spike is None:
                    continue
                
                # find next spike
                next_spike = None
                for sp in spikes[j][i+1:]:
                    if sp is not None:
                        next_spike = sp
                        break
                    
                # determine time range for response
                max_len = int(40e-3 / dt)  # don't take more than 40 ms for any response
                start = spike['rise_index']
                if next_spike is not None:
                    stop = min(start + max_len, next_spike['rise_index'])
                else:
                    stop = start + max_len
                    
                # collect data from this trace
                trace = post_traces[j]
                d = trace.data[start:stop].copy()
                responses.append(d)

            if len(responses) == 0:
                continue
                
            # extend all responses to the same length and take nanmean
            avg = ragged_mean(responses, method='clip')
            avg -= float_mode(avg[:int(1e-3/dt)])
            avg_responses.append(avg)
            
            # plot average response for this pulse
            start = np.median([sp[i]['rise_index'] for sp in spikes if sp[i] is not None]) * dt
            t = np.arange(len(avg)) * dt
            self.response_plots[0,i].plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append(fit)
            
            # let the user mess with this fit
            curve = self.response_plots[0,i].plot(t, fit.eval(), pen=fit_pen, antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)
            
        # display global average
        global_avg = ragged_mean(avg_responses, method='clip')
        t = np.arange(len(global_avg)) * dt
        self.response_plots[0,-1].plot(t, global_avg, pen='w', antialias=True)
        global_fit = self.fit_psp(global_avg, t, dt, post_mode)
        self.response_plots[0,-1].plot(t, global_fit.eval(), pen=fit_pen, antialias=True)
            
        # display fit parameters in table
        events = []
        for i,f in enumerate(fits + [global_fit]):
            if f is None:
                continue
            if i >= len(fits):
                vals = OrderedDict([('id', 'avg'), ('spike_time', np.nan), ('spike_stdev', np.nan)])
            else:
                spt = [s[i]['peak_index'] * dt for s in spikes if s[i] is not None]
                vals = OrderedDict([('id', i), ('spike_time', np.mean(spt)), ('spike_stdev', np.std(spt))])
            vals.update(OrderedDict([(k,f.best_values[k]) for k in f.params.keys()]))
            events.append(vals)
            
        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        self.event_set_selected()
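The averaging step in _update_plots uses ragged_mean(..., method='clip') because the responses extracted between spikes have different lengths. A simplified stand-in for what that does:

import numpy as np

def ragged_mean_clip(arrays):
    # Clip all responses to the shortest length, then average sample-by-sample
    # (simplified stand-in; the real ragged_mean also tolerates NaNs).
    n = min(len(a) for a in arrays)
    return np.mean([a[:n] for a in arrays], axis=0)

print(ragged_mean_clip([np.ones(5), 2 * np.ones(7)]))   # [1.5 1.5 1.5 1.5 1.5]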
Example #20
def analyze_pair_connectivity(amps, sign=None):
    """Given response strength records for a single pair, generate summary
    statistics characterizing strength, latency, and connectivity.
    
    Parameters
    ----------
    amps : dict
        Contains foreground and background strength analysis records
        (see input format below)
    sign : None, -1, or +1
        If None, then automatically determine whether to treat this connection as
        inhibitory or excitatory.

    Input must have the following structure::
    
        amps = {
            ('ic', 'fg'): recs, 
            ('ic', 'bg'): recs,
            ('vc', 'fg'): recs, 
            ('vc', 'bg'): recs,
        }
        
    Where each *recs* must be a structured array containing fields as returned
    by get_amps() and get_baseline_amps().
    
    The overall strategy here is:
    
    1. Make an initial decision on whether to treat this pair as excitatory or
       inhibitory, based on differences between foreground and background amplitude
       measurements
    2. Generate mean and stdev for amplitudes, deconvolved amplitudes, and deconvolved
       latencies
    3. Generate KS test p values describing the differences between foreground
       and background distributions for amplitude, deconvolved amplitude, and
       deconvolved latency    
    """
    requested_sign = sign
    fields = {}  # used to fill the new DB record
    
    # Use KS p value to check for differences between foreground and background
    qc_amps = {}
    ks_pvals = {}
    amp_means = {}
    amp_diffs = {}
    for clamp_mode in ('ic', 'vc'):
        clamp_mode_fg = amps[clamp_mode, 'fg']
        clamp_mode_bg = amps[clamp_mode, 'bg']
        if (len(clamp_mode_fg) == 0 or len(clamp_mode_bg) == 0):
            continue
        for sign in ('pos', 'neg'):
            # Separate into positive/negative tests and filter out responses that failed qc
            qc_field = {'vc': {'pos': 'in_qc_pass', 'neg': 'ex_qc_pass'}, 'ic': {'pos': 'ex_qc_pass', 'neg': 'in_qc_pass'}}[clamp_mode][sign]
            fg = clamp_mode_fg[clamp_mode_fg[qc_field]]
            bg = clamp_mode_bg[clamp_mode_bg[qc_field]]
            qc_amps[sign, clamp_mode, 'fg'] = fg
            qc_amps[sign, clamp_mode, 'bg'] = bg
            if (len(fg) == 0 or len(bg) == 0):
                continue
            
            # Measure some statistics from these records
            fg = fg[sign + '_dec_amp']
            bg = bg[sign + '_dec_amp']
            pval = scipy.stats.ks_2samp(fg, bg).pvalue
            ks_pvals[(sign, clamp_mode)] = pval
            # we could ensure that the average amplitude is in the right direction:
            fg_mean = np.mean(fg)
            bg_mean = np.mean(bg)
            amp_means[sign, clamp_mode] = {'fg': fg_mean, 'bg': bg_mean}
            amp_diffs[sign, clamp_mode] = fg_mean - bg_mean

    if requested_sign is None:
        # Decide whether to treat this connection as excitatory or inhibitory.
        #   strategy: accumulate evidence for either possibility by checking
        #   the ks p-values for each sign/clamp mode and the direction of the deflection
        is_exc = 0
        # print(expt.acq_timestamp, pair.pre_cell.ext_id, pair.post_cell.ext_id)
        for sign in ('pos', 'neg'):
            for mode in ('ic', 'vc'):
                ks = ks_pvals.get((sign, mode), None)
                if ks is None:
                    continue
                # turn p value into a reasonable scale factor
                ks = norm_pvalue(ks)
                dif_sign = 1 if amp_diffs[sign, mode] > 0 else -1
                if mode == 'vc':
                    dif_sign *= -1
                is_exc += dif_sign * ks
                # print("    ", sign, mode, is_exc, dif_sign * ks)
    else:
        is_exc = requested_sign

    if is_exc > 0:
        fields['synapse_type'] = 'ex'
        signs = {'ic':'pos', 'vc':'neg'}
    else:
        fields['synapse_type'] = 'in'
        signs = {'ic':'neg', 'vc':'pos'}

    # compute the rest of statistics for only positive or negative deflections
    for clamp_mode in ('ic', 'vc'):
        sign = signs[clamp_mode]
        fg = qc_amps.get((sign, clamp_mode, 'fg'))
        bg = qc_amps.get((sign, clamp_mode, 'bg'))
        if fg is None or bg is None or len(fg) == 0 or len(bg) == 0:
            fields[clamp_mode + '_n_samples'] = 0
            continue
        
        fields[clamp_mode + '_n_samples'] = len(fg)
        fields[clamp_mode + '_crosstalk_mean'] = np.mean(fg['crosstalk'])
        fields[clamp_mode + '_base_crosstalk_mean'] = np.mean(bg['crosstalk'])
        
        # measure mean, stdev, and statistical differences between
        # fg and bg for each measurement
        for val, field in [('amp', 'amp'), ('deconv_amp', 'dec_amp'), ('latency', 'dec_latency')]:
            f = fg[sign + '_' + field]
            b = bg[sign + '_' + field]
            fields[clamp_mode + '_' + val + '_mean'] = np.mean(f)
            fields[clamp_mode + '_' + val + '_stdev'] = np.std(f)
            fields[clamp_mode + '_base_' + val + '_mean'] = np.mean(b)
            fields[clamp_mode + '_base_' + val + '_stdev'] = np.std(b)
            # statistical tests comparing fg vs bg
            # Note: we use log(1-log(pval)) because it's nicer to plot and easier to
            # use as a classifier input
            tt_pval = scipy.stats.ttest_ind(f, b, equal_var=False).pvalue
            ks_pval = scipy.stats.ks_2samp(f, b).pvalue
            fields[clamp_mode + '_' + val + '_ttest'] = norm_pvalue(tt_pval)
            fields[clamp_mode + '_' + val + '_ks2samp'] = norm_pvalue(ks_pval)


        ### generate the average response and psp fit
        
        # collect all bg and fg traces
        # bg_traces = TraceList([Trace(data, sample_rate=db.default_sample_rate) for data in amps[clamp_mode, 'bg']['data']])
        fg_traces = TraceList()
        for rec in fg:
            t0 = rec['response_start_time'] - rec['max_dvdt_time']   # time-align to presynaptic spike
            trace = Trace(rec['data'], sample_rate=db.default_sample_rate, t0=t0)
            fg_traces.append(trace)
        
        # get averages
        # bg_avg = bg_traces.mean()        
        fg_avg = fg_traces.mean()
        base_rgn = fg_avg.time_slice(-6e-3, 0)
        base = float_mode(base_rgn.data)
        fields[clamp_mode + '_average_response'] = fg_avg.data
        fields[clamp_mode + '_average_response_t0'] = fg_avg.t0
        fields[clamp_mode + '_average_base_stdev'] = base_rgn.std()

        sign = {'pos':'+', 'neg':'-'}[signs[clamp_mode]]
        fg_bsub = fg_avg.copy(data=fg_avg.data - base)  # remove base to help fitting
        try:
            fit = fit_psp(fg_bsub, mode=clamp_mode, sign=sign, xoffset=(1e-3, 0, 6e-3), yoffset=(0, None, None), rise_time_mult_factor=4)              
            for param, val in fit.best_values.items():
                fields['%s_fit_%s' % (clamp_mode, param)] = val
            fields[clamp_mode + '_fit_yoffset'] = fit.best_values['yoffset'] + base
            fields[clamp_mode + '_fit_nrmse'] = fit.nrmse()
        except:
            print("Error in PSP fit:")
            sys.excepthook(*sys.exc_info())
            continue
        
        #global fit_plot
        #if fit_plot is None:
            #fit_plot = FitExplorer(fit)
            #fit_plot.show()
        #else:
            #fit_plot.set_fit(fit)
        #raw_input("Waiting to continue..")

    return fields
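The comments in analyze_pair_connectivity describe turning a p value into "a reasonable scale factor" with log(1 - log(pval)). norm_pvalue itself is defined elsewhere in the codebase; assuming it implements exactly that transform, a minimal sketch would be:

import numpy as np

def norm_pvalue_sketch(pval):
    # Assumed transform from the comments above: p = 1 maps to 0,
    # smaller p values map to increasingly large positive scores.
    return np.log(1 - np.log(pval))

# norm_pvalue_sketch(1.0) -> 0.0;  norm_pvalue_sketch(1e-6) -> ~2.7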
Example No. 21
    def create_db_entries(cls, job_id, session):
        
        # Load experiment from DB
        expt_entry = db.experiment_from_timestamp(job_id, session=session)
        elecs_by_ad_channel = {elec.device_id:elec for elec in expt_entry.electrodes}
        pairs_by_device_id = {}
        for pair in expt_entry.pairs.values():
            pre_dev_id = pair.pre_cell.electrode.device_id
            post_dev_id = pair.post_cell.electrode.device_id
            pairs_by_device_id[(pre_dev_id, post_dev_id)] = pair
        
        # load NWB file
        path = os.path.join(config.synphys_data, expt_entry.storage_path)
        expt = Experiment(path)
        nwb = expt.data
        
        # Load all data from NWB into DB
        for srec in nwb.contents:
            temp = srec.meta.get('temperature', None)
            srec_entry = db.SyncRec(ext_id=srec.key, experiment=expt_entry, temperature=temp)
            session.add(srec_entry)
            
            srec_has_mp_probes = False
            
            rec_entries = {}
            all_pulse_entries = {}
            for rec in srec.recordings:
                
                # import all recordings
                electrode_entry = elecs_by_ad_channel[rec.device_id]  # should probably just skip if this causes KeyError?
                rec_entry = db.Recording(
                    sync_rec=srec_entry,
                    electrode=electrode_entry,
                    start_time=rec.start_time,
                )
                session.add(rec_entry)
                rec_entries[rec.device_id] = rec_entry
                
                # import patch clamp recording information
                if not isinstance(rec, PatchClampRecording):
                    continue
                qc_pass = qc.recording_qc_pass(rec)
                pcrec_entry = db.PatchClampRecording(
                    recording=rec_entry,
                    clamp_mode=rec.clamp_mode,
                    patch_mode=rec.patch_mode,
                    stim_name=rec.stimulus.description,
                    baseline_potential=rec.baseline_potential,
                    baseline_current=rec.baseline_current,
                    baseline_rms_noise=rec.baseline_rms_noise,
                    qc_pass=qc_pass,
                )
                session.add(pcrec_entry)

                # import test pulse information
                tp = rec.nearest_test_pulse
                if tp is not None:
                    indices = tp.indices or [None, None]
                    tp_entry = db.TestPulse(
                        electrode=electrode_entry,
                        recording=rec_entry,
                        start_index=indices[0],
                        stop_index=indices[1],
                        baseline_current=tp.baseline_current,
                        baseline_potential=tp.baseline_potential,
                        access_resistance=tp.access_resistance,
                        input_resistance=tp.input_resistance,
                        capacitance=tp.capacitance,
                        time_constant=tp.time_constant,
                    )
                    session.add(tp_entry)
                    pcrec_entry.nearest_test_pulse = tp_entry
                    
                # import information about STP protocol
                if not isinstance(rec, MultiPatchProbe):
                    continue
                srec_has_mp_probes = True
                psa = PulseStimAnalyzer.get(rec)
                ind_freq, rec_delay = psa.stim_params()
                mprec_entry = db.MultiPatchProbe(
                    patch_clamp_recording=pcrec_entry,
                    induction_frequency=ind_freq,
                    recovery_delay=rec_delay,
                )
                session.add(mprec_entry)
            
                # import presynaptic stim pulses
                pulses = psa.pulses()
                
                pulse_entries = {}
                all_pulse_entries[rec.device_id] = pulse_entries
                
                rec_tvals = rec['primary'].time_values

                for i,pulse in enumerate(pulses):
                    # Record information about all pulses, including test pulse.
                    t0 = rec_tvals[pulse[0]]
                    t1 = rec_tvals[pulse[1]]
                    data_start = max(0, t0 - 10e-3)
                    data_stop = t0 + 10e-3
                    pulse_entry = db.StimPulse(
                        recording=rec_entry,
                        pulse_number=i,
                        onset_time=t0,
                        amplitude=pulse[2],
                        duration=t1-t0,
                        data=rec['primary'].time_slice(data_start, data_stop).resample(sample_rate=20000).data,
                        data_start_time=data_start,
                    )
                    session.add(pulse_entry)
                    pulse_entries[i] = pulse_entry
                    

                # import presynaptic evoked spikes
                # For now, we only detect up to 1 spike per pulse, but eventually
                # this may be adapted for more.
                spikes = psa.evoked_spikes()
                for i,sp in enumerate(spikes):
                    pulse = pulse_entries[sp['pulse_n']]
                    if sp['spike'] is not None:
                        spinfo = sp['spike']
                        extra = {
                            'peak_time': rec_tvals[spinfo['peak_index']],
                            'max_dvdt_time': rec_tvals[spinfo['rise_index']],
                            'max_dvdt': spinfo['max_dvdt'],
                        }
                        if 'peak_diff' in spinfo:
                            extra['peak_diff'] = spinfo['peak_diff']
                        if 'peak_value' in spinfo:
                            extra['peak_value'] = spinfo['peak_value']
                        
                        pulse.n_spikes = 1
                    else:
                        extra = {}
                        pulse.n_spikes = 0
                    
                    spike_entry = db.StimSpike(
                        stim_pulse=pulse,
                        **extra
                    )
                    session.add(spike_entry)
                    pulse.first_spike = spike_entry
            
            if not srec_has_mp_probes:
                continue
            
            # import postsynaptic responses
            mpa = MultiPatchSyncRecAnalyzer(srec)
            for pre_dev in srec.devices:
                for post_dev in srec.devices:
                    if pre_dev == post_dev:
                        continue

                    # get all responses, regardless of the presence of a spike
                    responses = mpa.get_spike_responses(srec[pre_dev], srec[post_dev], align_to='pulse', require_spike=False)
                    post_tvals = srec[post_dev]['primary'].time_values
                    for resp in responses:
                        # base_entry = db.Baseline(
                        #     recording=rec_entries[post_dev],
                        #     start_index=resp['baseline_start'],
                        #     stop_index=resp['baseline_stop'],
                        #     data=resp['baseline'].resample(sample_rate=20000).data,
                        #     mode=float_mode(resp['baseline'].data),
                        # )
                        # session.add(base_entry)
                        pair_entry = pairs_by_device_id.get((pre_dev, post_dev), None)
                        if pair_entry is None:
                            continue  # no data for one or both channels
                        if resp['ex_qc_pass']:
                            pair_entry.n_ex_test_spikes += 1
                        if resp['in_qc_pass']:
                            pair_entry.n_in_test_spikes += 1
                        resp_entry = db.PulseResponse(
                            recording=rec_entries[post_dev],
                            stim_pulse=all_pulse_entries[pre_dev][resp['pulse_n']],
                            pair=pair_entry,
                            start_time=post_tvals[resp['rec_start']],
                            data=resp['response'].resample(sample_rate=20000).data,
                            ex_qc_pass=resp['ex_qc_pass'],
                            in_qc_pass=resp['in_qc_pass'],
                        )
                        session.add(resp_entry)
                        
            # generate up to 20 baseline snippets for each recording
            for dev in srec.devices:
                rec = srec[dev]
                rec_tvals = rec['primary'].time_values
                dist = BaselineDistributor.get(rec)
                for i in range(20):
                    base = dist.get_baseline_chunk(20e-3)
                    if base is None:
                        # all out!
                        break
                    start, stop = base
                    data = rec['primary'][start:stop].resample(sample_rate=20000).data

                    ex_qc_pass, in_qc_pass = qc.pulse_response_qc_pass(rec, [start, stop], None, [])

                    base_entry = db.Baseline(
                        recording=rec_entries[dev],
                        start_time=rec_tvals[start],
                        data=data,
                        mode=float_mode(data),
                        ex_qc_pass=ex_qc_pass,
                        in_qc_pass=in_qc_pass,
                    )
                    session.add(base_entry)
Example No. 22
n_sweeps = len(amp_group)
if n_sweeps == 0:
    print("No Sweeps")
for sweep in range(n_sweeps):
    stim_name = amp_group.responses[sweep].recording.meta['stim_name']
    stim_param = stim_name.split('_')
    freq = stim_param[1]
    freq = int(freq.split('H')[0])
    if freq <= stop_freq:
        sweep_trace = amp_group.responses[sweep]
        holding_potential = int(sweep_trace.recording.holding_potential * 1000)
        holding = []
        if -72 <= holding_potential <= -68:
            holding.append(holding_potential)
            post_base = float_mode(sweep_trace.data[:int(10e-3 / sweep_trace.dt)])
            pre_spike = amp_group.spikes[sweep]
            pre_base = float_mode(pre_spike.data[:int(10e-3 / pre_spike.dt)])
            sweep_list['response'].append(sweep_trace.copy(data=sweep_trace.data - post_base))
            sweep_list['spike'].append(pre_spike.copy(data=pre_spike.data - pre_base))
if len(sweep_list['response']) > 5:
    n = len(sweep_list['response'])
    if plot_sweeps is True:
        for sweep in range(n):
            current_sweep = sweep_list['response'][sweep]
            current_sweep.t0 = 0
            grid[row[1], 0].plot(current_sweep.time_values,
                                 current_sweep.data,
Example No. 23
def bsub(trace):
    data = trace.data
    dt = trace.dt
    base = float_mode(data[:int(10e-3 / dt)])
    bsub_trace = trace.copy(data=data - base)
    return bsub_trace
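A usage sketch for bsub, assuming the TSeries class used throughout these examples (the import path is an assumption):

import numpy as np
from neuroanalysis.data import TSeries  # import path assumed

# synthetic 100 ms trace at 20 kHz riding on a -70 mV baseline
data = -70e-3 + 1e-4 * np.random.randn(2000)
trace = TSeries(data=data, sample_rate=20000)

bsub_trace = bsub(trace)
# after subtraction the quiescent baseline sits near 0 V
print(bsub_trace.data[:int(10e-3 / bsub_trace.dt)].mean())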
Example No. 24
    def create_db_entries(cls, job, session):
        db = job['database']
        job_id = job['job_id']

        # Load experiment from DB
        expt_entry = db.experiment_from_ext_id(job_id, session=session)
        elecs_by_ad_channel = {
            elec.device_id: elec
            for elec in expt_entry.electrodes
        }
        pairs_by_device_id = {}
        for pair in expt_entry.pairs.values():
            pre_dev_id = pair.pre_cell.electrode.device_id
            post_dev_id = pair.post_cell.electrode.device_id
            pairs_by_device_id[(pre_dev_id, post_dev_id)] = pair
            pair.n_ex_test_spikes = 0
            pair.n_in_test_spikes = 0

        # load NWB file
        path = os.path.join(config.synphys_data, expt_entry.storage_path)
        expt = Experiment(path)
        nwb = expt.data

        last_stim_pulse_time = {}

        # Load all data from NWB into DB
        for srec in nwb.contents:
            temp = srec.meta.get('temperature', None)
            srec_entry = db.SyncRec(ext_id=srec.key,
                                    experiment=expt_entry,
                                    temperature=temp)
            session.add(srec_entry)

            srec_has_mp_probes = False

            rec_entries = {}
            all_pulse_entries = {}
            for rec in srec.recordings:
                if rec.aborted:
                    # skip incomplete recordings
                    continue

                # import all recordings
                electrode_entry = elecs_by_ad_channel[rec.device_id]  # should probably just skip if this causes KeyError?
                rec_entry = db.Recording(
                    sync_rec=srec_entry,
                    electrode=electrode_entry,
                    start_time=rec.start_time,
                    stim_name=(None if rec.stimulus is None else
                               rec.stimulus.description),
                    stim_meta=(None if rec.stimulus is None else
                               rec.stimulus.save()),
                )
                session.add(rec_entry)
                rec_entries[rec.device_id] = rec_entry

                # import patch clamp recording information
                if not isinstance(rec, PatchClampRecording):
                    continue
                qc_pass, qc_failures = qc.recording_qc_pass(rec)
                pcrec_entry = db.PatchClampRecording(
                    recording=rec_entry,
                    clamp_mode=rec.clamp_mode,
                    patch_mode=rec.patch_mode,
                    baseline_potential=rec.baseline_potential,
                    baseline_current=rec.baseline_current,
                    baseline_rms_noise=rec.baseline_rms_noise,
                    qc_pass=qc_pass,
                    meta=None if len(qc_failures) == 0 else {'qc_failures': qc_failures},
                )
                session.add(pcrec_entry)

                # import test pulse information
                tp = rec.nearest_test_pulse
                if tp is not None:
                    indices = tp.indices or [None, None]
                    tp_entry = db.TestPulse(
                        electrode=electrode_entry,
                        recording=rec_entry,
                        start_index=indices[0],
                        stop_index=indices[1],
                        baseline_current=tp.baseline_current,
                        baseline_potential=tp.baseline_potential,
                        access_resistance=tp.access_resistance,
                        input_resistance=tp.input_resistance,
                        capacitance=tp.capacitance,
                        time_constant=tp.time_constant,
                    )
                    session.add(tp_entry)
                    pcrec_entry.nearest_test_pulse = tp_entry

                # import information about STP protocol
                if not isinstance(rec, (MultiPatchProbe, MultiPatchMixedFreqTrain)):
                    continue

                srec_has_mp_probes = True

                if isinstance(rec, MultiPatchProbe):
                    ind_freq, rec_delay = rec.stim_params()
                    mprec_entry = db.MultiPatchProbe(
                        patch_clamp_recording=pcrec_entry,
                        induction_frequency=ind_freq,
                        recovery_delay=rec_delay,
                    )
                    session.add(mprec_entry)

                # import presynaptic stim pulses
                psa = PatchClampStimPulseAnalyzer.get(rec)
                pulses = psa.pulse_chunks()

                pulse_entries = {}
                all_pulse_entries[rec.device_id] = pulse_entries

                for i, pulse in enumerate(pulses):
                    # Record information about all pulses, including test pulse.
                    t0, t1 = pulse.meta['pulse_edges']
                    resampled = pulse['primary'].resample(sample_rate=db.default_sample_rate)
                    clock_time = t0 + datetime_to_timestamp(rec_entry.start_time)
                    prev_pulse_dt = clock_time - last_stim_pulse_time.get(rec.device_id, -np.inf)
                    last_stim_pulse_time[rec.device_id] = clock_time
                    pulse_entry = db.StimPulse(
                        recording=rec_entry,
                        pulse_number=pulse.meta['pulse_n'],
                        onset_time=t0,
                        amplitude=pulse.meta['pulse_amplitude'],
                        duration=t1 - t0,
                        data=resampled.data,
                        data_start_time=resampled.t0,
                        previous_pulse_dt=prev_pulse_dt,
                    )
                    session.add(pulse_entry)
                    pulse_entries[pulse.meta['pulse_n']] = pulse_entry

                # import presynaptic evoked spikes
                # For now, we only detect up to 1 spike per pulse, but eventually
                # this may be adapted for more.
                spikes = psa.evoked_spikes()
                for i, sp in enumerate(spikes):
                    pulse = pulse_entries[sp['pulse_n']]
                    pulse.n_spikes = len(sp['spikes'])
                    for i, spike in enumerate(sp['spikes']):
                        spike_entry = db.StimSpike(
                            stim_pulse=pulse,
                            onset_time=spike['onset_time'],
                            peak_time=spike['peak_time'],
                            max_slope_time=spike['max_slope_time'],
                            max_slope=spike['max_slope'],
                            peak_diff=spike.get('peak_diff'),
                            peak_value=spike['peak_value'],
                        )
                        session.add(spike_entry)
                        if i == 0:
                            # pulse.first_spike = spike_entry
                            pulse.first_spike_time = spike_entry.max_slope_time

            if not srec_has_mp_probes:
                continue

            # collect and shuffle baseline chunks for each recording
            baseline_chunks = {}
            for post_dev in srec.devices:

                base_dist = BaselineDistributor.get(srec[post_dev])
                chunks = list(base_dist.baseline_chunks())

                # generate a different random shuffle for each combination pre,post device
                # (we are not allowed to reuse the same baseline chunks for a particular pre-post pair,
                # but it is ok to reuse them across pairs)
                for pre_dev in srec.devices:
                    # shuffle baseline chunks in a deterministic way:
                    # convert expt_id/srec_id/pre/post into an integer seed
                    seed_str = ("%s %s %s %s" % (job_id, srec.key, pre_dev, post_dev)).encode()
                    seed = struct.unpack('I', hashlib.sha1(seed_str).digest()[:4])[0]
                    rng = np.random.RandomState(seed)
                    rng.shuffle(chunks)

                    baseline_chunks[pre_dev, post_dev] = chunks[:]

            baseline_qc_cache = {}
            baseline_entry_cache = {}

            # import postsynaptic responses
            unmatched = 0
            mpa = MultiPatchSyncRecAnalyzer(srec)
            for pre_dev in srec.devices:
                for post_dev in srec.devices:
                    if pre_dev == post_dev:
                        continue

                    # get all responses, regardless of the presence of a spike
                    responses = mpa.get_spike_responses(srec[pre_dev],
                                                        srec[post_dev],
                                                        align_to='pulse',
                                                        require_spike=False)

                    pair_entry = pairs_by_device_id.get((pre_dev, post_dev),
                                                        None)
                    if pair_entry is None:
                        continue  # no data for one or both channels

                    for resp in responses:
                        if resp['ex_qc_pass']:
                            pair_entry.n_ex_test_spikes += 1
                        if resp['in_qc_pass']:
                            pair_entry.n_in_test_spikes += 1

                        resampled = resp['response']['primary'].resample(
                            sample_rate=db.default_sample_rate)
                        resp_entry = db.PulseResponse(
                            recording=rec_entries[post_dev],
                            stim_pulse=all_pulse_entries[pre_dev][resp['pulse_n']],
                            pair=pair_entry,
                            data=resampled.data,
                            data_start_time=resampled.t0,
                            ex_qc_pass=resp['ex_qc_pass'],
                            in_qc_pass=resp['in_qc_pass'],
                            meta=None if resp['ex_qc_pass'] and resp['in_qc_pass'] else {'qc_failures': resp['qc_failures']},
                        )
                        session.add(resp_entry)

                        # find a baseline chunk from this recording with compatible qc metrics
                        got_baseline = False
                        for i, (start, stop) in enumerate(baseline_chunks[pre_dev, post_dev]):
                            key = (post_dev, start, stop)

                            # pull data and run qc if needed
                            if key not in baseline_qc_cache:
                                data = srec[post_dev]['primary'].time_slice(start, stop).resample(
                                    sample_rate=db.default_sample_rate).data
                                ex_qc_pass, in_qc_pass, qc_failures = qc.pulse_response_qc_pass(
                                    srec[post_dev], [start, stop], None, [])
                                baseline_qc_cache[key] = (data, ex_qc_pass, in_qc_pass)
                            else:
                                data, ex_qc_pass, in_qc_pass = baseline_qc_cache[key]

                            if resp_entry.ex_qc_pass is True and ex_qc_pass is not True:
                                continue
                            elif resp_entry.in_qc_pass is True and in_qc_pass is not True:
                                continue
                            else:
                                got_baseline = True
                                baseline_chunks[pre_dev, post_dev].pop(i)
                                break

                        if not got_baseline:
                            # no matching baseline available
                            unmatched += 1
                            continue

                        if key not in baseline_entry_cache:
                            # create a db record for this baseline chunk if it has not already appeared elsewhere
                            base_entry = db.Baseline(
                                recording=rec_entries[post_dev],
                                data=data,
                                data_start_time=start,
                                mode=float_mode(data),
                                ex_qc_pass=ex_qc_pass,
                                in_qc_pass=in_qc_pass,
                                meta=None if ex_qc_pass is True and in_qc_pass is True else {'qc_failures': qc_failures},
                            )
                            session.add(base_entry)
                            baseline_entry_cache[key] = base_entry

                        resp_entry.baseline = baseline_entry_cache[key]

            if unmatched > 0:
                print("%s %s: %d pulse responses without matched baselines" %
                      (job_id, srec, unmatched))
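The deterministic shuffle above is a reusable idiom: hash a descriptive key string down to a 32-bit seed, so the same experiment/sync-rec/pre/post combination always yields the same "random" ordering across pipeline runs. A self-contained sketch of the same idiom:

import hashlib
import struct
import numpy as np

def seeded_shuffle(items, *key_parts):
    # Build a reproducible 32-bit seed from the key parts, then shuffle
    # a copy of the items with a NumPy RNG seeded from it.
    seed_str = (" ".join(str(p) for p in key_parts)).encode()
    seed = struct.unpack('I', hashlib.sha1(seed_str).digest()[:4])[0]
    rng = np.random.RandomState(seed)
    shuffled = list(items)
    rng.shuffle(shuffled)
    return shuffled

# same key -> same order on every run
print(seeded_shuffle(range(5), "expt1", "srec3", 0, 1))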
Example No. 25
def first_pulse_plot(expt_list, name=None, summary_plot=None, color=None, scatter=0, features=False):
    amp_plots = pg.plot()
    amp_plots.setLabels(left=('Vm', 'V'))
    amp_base_subtract = []
    avg_amps = {'amp': [], 'latency': [], 'rise': []}
    for expt in expt_list:
        for pre, post in expt.connections:
            if expt.cells[pre].cre_type == cre_type[0] and expt.cells[post].cre_type == cre_type[1]:
                avg_amp, avg_trace, n_sweeps = responses(expt, pre, post, thresh=0.03e-3, filter=[[0, 50], [-68, -72]])
                if expt.cells[pre].cre_type in EXCITATORY_CRE_TYPES and avg_amp < 0:
                    continue
                elif expt.cells[pre].cre_type in INHIBITORY_CRE_TYPES and avg_amp > 0:
                    continue
                if n_sweeps >= 10:
                    avg_trace.t0 = 0
                    avg_amps['amp'].append(avg_amp)
                    base = float_mode(avg_trace.data[:int(10e-3 / avg_trace.dt)])
                    amp_base_subtract.append(avg_trace.copy(data=avg_trace.data - base))
                    if features is True:
                        if avg_amp > 0:
                            amp_sign = '+'
                        else:
                            amp_sign = '-'
                        psp_fits = fit_psp(avg_trace, sign=amp_sign, yoffset=0, amp=avg_amp, method='leastsq',
                                           fit_kws={})
                        avg_amps['latency'].append(psp_fits.best_values['xoffset'] - 10e-3)
                        avg_amps['rise'].append(psp_fits.best_values['rise_time'])

                    current_connection_HS = post, pre
                    if len(expt.connections) > 1 and args.recip is True:
                        for i,x in enumerate(expt.connections):
                            if x == current_connection_HS:  # determine if a reciprocal connection
                                amp_plots.plot(avg_trace.time_values, avg_trace.data - base, pen={'color': 'r', 'width': 1})
                                break
                            elif x != current_connection_HS and i == len(expt.connections) - 1:  # reciprocal connection was not found
                                amp_plots.plot(avg_trace.time_values, avg_trace.data - base)
                    else:
                        amp_plots.plot(avg_trace.time_values, avg_trace.data - base)

                    app.processEvents()

    if len(amp_base_subtract) != 0:
        print(name + ' n = %d' % len(amp_base_subtract))
        grand_mean = TraceList(amp_base_subtract).mean()
        grand_amp = np.mean(np.array(avg_amps['amp']))
        grand_amp_sem = stats.sem(np.array(avg_amps['amp']))
        amp_plots.addLegend()
        amp_plots.plot(grand_mean.time_values, grand_mean.data, pen={'color': 'g', 'width': 3}, name=name)
        amp_plots.addLine(y=grand_amp, pen={'color': 'g'})
        if grand_mean is not None:
            print(name + ' Grand mean amplitude = %f +- %f' % (grand_amp, grand_amp_sem))
            if features is True:
                feature_list = (avg_amps['amp'], avg_amps['latency'], avg_amps['rise'])
                labels = (['Vm', 'V'], ['t', 's'], ['t', 's'])
                titles = ('Amplitude', 'Latency', 'Rise time')
            else:
                feature_list = [avg_amps['amp']]
                labels = (['Vm', 'V'])
                titles = 'Amplitude'
            summary_plots = summary_plot_pulse(feature_list[0], labels=labels, titles=titles, i=scatter,
                                               grand_trace=grand_mean, plot=summary_plot, color=color, name=name)
            return avg_amps, summary_plots
    else:
        print("No Traces")
        return avg_amps, None
Example No. 26
    def _update_plots(self):
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()

        # clear all plots
        self.pre_plot.clear()
        self.post_plot.clear()

        pre = self.params['pre']
        post = self.params['post']

        # If there are no selected sweeps or channels have not been set, return
        if len(sweeps) == 0 or pre == post or pre not in self.channels or post not in self.channels:
            return

        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot),
                               (post, post_mode, self.post_plot)]:
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units),
                           bottom=("Time", 's'))

        # Iterate over selected channels of all sweeps, plotting traces one at a time
        # Collect information about pulses and spikes
        pulses = []
        spikes = []
        post_traces = []
        for i, sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']

            # Detect pulse times
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data
            post_filt = self.artifact_remover.process(
                post_trace,
                list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)

            # plot raw data
            color = pg.intColor(i, hues=len(sweeps) * 1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot),
                                (post_filt, self.post_plot)]:
                plot.plot(trace.time_values,
                          trace.data,
                          pen=color,
                          antialias=False)

            # detect spike times
            spike_inds = []
            spike_info = []
            for on, off in zip(on_times, off_times):
                spike = detect_evoked_spike(sweep[pre], [on, off])
                spike_info.append(spike)
                if spike is None:
                    spike_inds.append(None)
                else:
                    spike_inds.append(spike['rise_index'])
            spikes.append(spike_info)

            dt = pre_trace.dt
            vticks = pg.VTickGroup(
                [x * dt for x in spike_inds if x is not None],
                yrange=[0.0, 0.2],
                pen=color)
            self.pre_plot.addItem(vticks)

        # Iterate over spikes, plotting average response
        all_responses = []
        avg_responses = []
        fits = []
        fit = None

        npulses = max(map(len, pulses))
        self.response_plots.clear()
        self.response_plots.set_shape(1, npulses + 1)  # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0, 0])
        for i in range(1, npulses + 1):
            self.response_plots[0, i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0, 0].setLabels(left=("Averaged events (Channel %d)" % post, units))

        fit_pen = {'color': (30, 30, 255), 'width': 2, 'dash': [1, 1]}
        for i in range(npulses):
            # get the chunk of each sweep between spikes
            responses = []
            all_responses.append(responses)
            for j, sweep in enumerate(sweeps):
                # get the current spike
                if i >= len(spikes[j]):
                    continue
                spike = spikes[j][i]
                if spike is None:
                    continue

                # find next spike
                next_spike = None
                for sp in spikes[j][i + 1:]:
                    if sp is not None:
                        next_spike = sp
                        break

                # determine time range for response
                max_len = int(40e-3 / dt)  # don't take more than 40 ms for any response
                start = spike['rise_index']
                if next_spike is not None:
                    stop = min(start + max_len, next_spike['rise_index'])
                else:
                    stop = start + max_len

                # collect data from this trace
                trace = post_traces[j]
                d = trace.data[start:stop].copy()
                responses.append(d)

            if len(responses) == 0:
                continue

            # extend all responses to the same length and take nanmean
            avg = ragged_mean(responses, method='clip')
            avg -= float_mode(avg[:int(1e-3 / dt)])
            avg_responses.append(avg)

            # plot average response for this pulse
            start = np.median([sp[i]['rise_index'] for sp in spikes if sp[i] is not None]) * dt
            t = np.arange(len(avg)) * dt
            self.response_plots[0, i].plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append(fit)

            # let the user mess with this fit
            curve = self.response_plots[0, i].plot(t, fit.eval(), pen=fit_pen, antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)

        # display global average
        global_avg = ragged_mean(avg_responses, method='clip')
        t = np.arange(len(global_avg)) * dt
        self.response_plots[0, -1].plot(t, global_avg, pen='w', antialias=True)
        global_fit = self.fit_psp(global_avg, t, dt, post_mode)
        self.response_plots[0, -1].plot(t, global_fit.eval(), pen=fit_pen, antialias=True)

        # display fit parameters in table
        events = []
        for i, f in enumerate(fits + [global_fit]):
            if f is None:
                continue
            if i >= len(fits):
                vals = OrderedDict([('id', 'avg'), ('spike_time', np.nan), ('spike_stdev', np.nan)])
            else:
                spt = [s[i]['peak_index'] * dt for s in spikes if s[i] is not None]
                vals = OrderedDict([('id', i), ('spike_time', np.mean(spt)), ('spike_stdev', np.std(spt))])
            vals.update(OrderedDict([(k, f.best_values[k]) for k in f.params.keys()]))
            events.append(vals)

        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        self.event_set_selected()
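ragged_mean is imported from the surrounding codebase. Going by the comment above ("extend all responses to the same length and take nanmean"), a plausible sketch nan-pads the inputs to a common length before averaging; the method='clip' option may instead trim to the shortest input, so treat this strictly as an assumption:

import numpy as np

def ragged_mean_sketch(arrays):
    # Nan-pad 1D arrays to the longest length, then average with
    # nanmean so shorter traces simply stop contributing at their end.
    n = max(len(a) for a in arrays)
    padded = np.full((len(arrays), n), np.nan)
    for row, a in zip(padded, arrays):
        row[:len(a)] = a
    return np.nanmean(padded, axis=0)

# e.g. responses of 95, 100, and 98 samples -> a 100-sample average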
Example No. 27
def analyze_pair_connectivity(amps, sign=None):
    """Given response strength records for a single pair, generate summary
    statistics characterizing strength, latency, and connectivity.
    
    Parameters
    ----------
    amps : dict
        Contains foreground and background strength analysis records
        (see input format below)
    sign : None, -1, or +1
        If None, then automatically determine whether to treat this connection as
        inhibitory or excitatory.

    Input must have the following structure::
    
        amps = {
            'ic': recs,
            'vc': recs,
        }
        
    Where each *recs* must be a structured array containing fields as returned
    by get_amps().
    
    The overall strategy here is:
    
    1. Make an initial decision on whether to treat this pair as excitatory or
       inhibitory, based on differences between foreground and background amplitude
       measurements
    2. Generate mean and stdev for amplitudes, deconvolved amplitudes, and deconvolved
       latencies
    3. Generate KS test p values describing the differences between foreground
       and background distributions for amplitude, deconvolved amplitude, and
       deconvolved latency    
    """
    # Filter by QC
    for k, v in amps.items():
        mask = v['qc_pass'].astype(bool)
        amps[k] = v[mask]

    # See if any data remains
    if all(len(a) == 0 for a in amps.values()):
        return None

    requested_sign = sign
    fields = {}  # used to fill the new DB record

    # Use KS p value to check for differences between foreground and background
    qc_amps = {}
    ks_pvals = {}
    amp_means = {}
    amp_diffs = {}
    for clamp_mode in ('ic', 'vc'):
        clamp_mode_amps = amps[clamp_mode]
        if len(clamp_mode_amps) == 0:
            continue
        for sign in ('pos', 'neg'):
            # Separate into positive/negative tests and filter out responses that failed qc
            qc_field = {
                'vc': {'pos': 'in_qc_pass', 'neg': 'ex_qc_pass'},
                'ic': {'pos': 'ex_qc_pass', 'neg': 'in_qc_pass'},
            }[clamp_mode][sign]
            fg = clamp_mode_amps[clamp_mode_amps[qc_field]]
            qc_amps[sign, clamp_mode] = fg
            if len(fg) == 0:
                continue

            # Measure some statistics from these records
            bg = fg['baseline_' + sign + '_dec_amp']
            fg = fg[sign + '_dec_amp']
            pval = scipy.stats.ks_2samp(fg, bg).pvalue
            ks_pvals[(sign, clamp_mode)] = pval
            # we could ensure that the average amplitude is in the right direction:
            fg_mean = np.mean(fg)
            bg_mean = np.mean(bg)
            amp_means[sign, clamp_mode] = {'fg': fg_mean, 'bg': bg_mean}
            amp_diffs[sign, clamp_mode] = fg_mean - bg_mean

    if requested_sign is None:
        # Decide whether to treat this connection as excitatory or inhibitory.
        #   strategy: accumulate evidence for either possibility by checking
        #   the ks p-values for each sign/clamp mode and the direction of the deflection
        is_exc = 0
        # print(expt.acq_timestamp, pair.pre_cell.ext_id, pair.post_cell.ext_id)
        for sign in ('pos', 'neg'):
            for mode in ('ic', 'vc'):
                ks = ks_pvals.get((sign, mode), None)
                if ks is None:
                    continue
                # turn p value into a reasonable scale factor
                ks = norm_pvalue(ks)
                dif_sign = 1 if amp_diffs[sign, mode] > 0 else -1
                if mode == 'vc':
                    dif_sign *= -1
                is_exc += dif_sign * ks
                # print("    ", sign, mode, is_exc, dif_sign * ks)
    else:
        is_exc = requested_sign

    if is_exc > 0:
        fields['synapse_type'] = 'ex'
        signs = {'ic': 'pos', 'vc': 'neg'}
    else:
        fields['synapse_type'] = 'in'
        signs = {'ic': 'neg', 'vc': 'pos'}

    # compute the rest of statistics for only positive or negative deflections
    for clamp_mode in ('ic', 'vc'):
        sign = signs[clamp_mode]
        fg = qc_amps.get((sign, clamp_mode))
        if fg is None or len(fg) == 0:
            fields[clamp_mode + '_n_samples'] = 0
            continue

        fields[clamp_mode + '_n_samples'] = len(fg)
        fields[clamp_mode + '_crosstalk_mean'] = np.mean(fg['crosstalk'])
        fields[clamp_mode + '_base_crosstalk_mean'] = np.mean(fg['baseline_crosstalk'])

        # measure mean, stdev, and statistical differences between
        # fg and bg for each measurement
        for val, field in [('amp', 'amp'), ('deconv_amp', 'dec_amp'), ('latency', 'dec_latency')]:
            f = fg[sign + '_' + field]
            b = fg['baseline_' + sign + '_' + field]
            fields[clamp_mode + '_' + val + '_mean'] = np.mean(f)
            fields[clamp_mode + '_' + val + '_stdev'] = np.std(f)
            fields[clamp_mode + '_base_' + val + '_mean'] = np.mean(b)
            fields[clamp_mode + '_base_' + val + '_stdev'] = np.std(b)
            # statistical tests comparing fg vs bg
            # Note: we use log(1-log(pval)) because it's nicer to plot and easier to
            # use as a classifier input
            tt_pval = scipy.stats.ttest_ind(f, b, equal_var=False).pvalue
            ks_pval = scipy.stats.ks_2samp(f, b).pvalue
            fields[clamp_mode + '_' + val + '_ttest'] = norm_pvalue(tt_pval)
            fields[clamp_mode + '_' + val + '_ks2samp'] = norm_pvalue(ks_pval)

        ### generate the average response and psp fit

        # collect all fg traces
        fg_traces = TSeriesList()
        for rec in fg:
            if rec['max_slope_time'] is None or not np.isfinite(rec['max_slope_time']):
                continue
            t0 = rec['response_start_time'] - rec['max_slope_time']  # time-align to presynaptic spike
            trace = TSeries(rec['data'], sample_rate=db.default_sample_rate, t0=t0)
            fg_traces.append(trace)

        # get averages

        if len(fg_traces) == 0:
            continue

        # bg_avg = bg_traces.mean()
        fg_avg = fg_traces.mean()
        base_rgn = fg_avg.time_slice(-6e-3, 0)
        base = float_mode(base_rgn.data)
        fields[clamp_mode + '_average_response'] = fg_avg.data
        fields[clamp_mode + '_average_response_t0'] = fg_avg.t0
        fields[clamp_mode + '_average_base_stdev'] = base_rgn.std()

        sign = {'pos': 1, 'neg': -1}[signs[clamp_mode]]
        fg_bsub = fg_avg.copy(data=fg_avg.data - base)  # remove base to help fitting
        try:
            fit = fit_psp(fg_bsub, clamp_mode=clamp_mode, sign=sign, search_window=[0, 6e-3])
            for param, val in fit.best_values.items():
                fields['%s_fit_%s' % (clamp_mode, param)] = val
            fields[clamp_mode + '_fit_yoffset'] = fit.best_values['yoffset'] + base
            fields[clamp_mode + '_fit_nrmse'] = fit.nrmse()
        except:
            print("Error in PSP fit:")
            sys.excepthook(*sys.exc_info())
            continue

        #global fit_plot
        #if fit_plot is None:
            #fit_plot = FitExplorer(fit)
            #fit_plot.show()
        #else:
            #fit_plot.set_fit(fit)
        #raw_input("Waiting to continue..")

    return fields
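A hedged usage sketch for this variant of analyze_pair_connectivity. The get_amps helper and its exact signature are assumptions; the docstring only requires that amps map 'ic' and 'vc' to structured arrays carrying the qc_pass, *_dec_amp, and baseline_* fields referenced above:

# hypothetical call pattern; get_amps and its arguments are assumed
amps = {
    'ic': get_amps(session, pair, clamp_mode='ic'),
    'vc': get_amps(session, pair, clamp_mode='vc'),
}
fields = analyze_pair_connectivity(amps)
if fields is not None:
    print(fields['synapse_type'], fields.get('ic_amp_mean'))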
Example No. 28
    def create_db_entries(cls, job, session):
        db = job['database']
        job_id = job['job_id']

        # Load experiment from DB
        expt_entry = db.experiment_from_timestamp(job_id, session=session)
        elecs_by_ad_channel = {elec.device_id:elec for elec in expt_entry.electrodes}
        pairs_by_device_id = {}
        for pair in expt_entry.pairs.values():
            pre_dev_id = pair.pre_cell.electrode.device_id
            post_dev_id = pair.post_cell.electrode.device_id
            pairs_by_device_id[(pre_dev_id, post_dev_id)] = pair
        
        # load NWB file
        path = os.path.join(config.synphys_data, expt_entry.storage_path)
        expt = Experiment(path)
        nwb = expt.data
        
        # Load all data from NWB into DB
        for srec in nwb.contents:
            temp = srec.meta.get('temperature', None)
            srec_entry = db.SyncRec(ext_id=srec.key, experiment=expt_entry, temperature=temp)
            session.add(srec_entry)
            
            srec_has_mp_probes = False
            
            rec_entries = {}
            all_pulse_entries = {}
            for rec in srec.recordings:
                
                # import all recordings
                electrode_entry = elecs_by_ad_channel[rec.device_id]  # should probably just skip if this causes KeyError?
                rec_entry = db.Recording(
                    sync_rec=srec_entry,
                    electrode=electrode_entry,
                    start_time=rec.start_time,
                )
                session.add(rec_entry)
                rec_entries[rec.device_id] = rec_entry
                
                # import patch clamp recording information
                if not isinstance(rec, PatchClampRecording):
                    continue
                qc_pass, qc_failures = qc.recording_qc_pass(rec)
                pcrec_entry = db.PatchClampRecording(
                    recording=rec_entry,
                    clamp_mode=rec.clamp_mode,
                    patch_mode=rec.patch_mode,
                    stim_name=rec.stimulus.description,
                    baseline_potential=rec.baseline_potential,
                    baseline_current=rec.baseline_current,
                    baseline_rms_noise=rec.baseline_rms_noise,
                    qc_pass=qc_pass,
                    meta=None if len(qc_failures) == 0 else {'qc_failures': qc_failures},
                )
                session.add(pcrec_entry)

                # import test pulse information
                tp = rec.nearest_test_pulse
                if tp is not None:
                    indices = tp.indices or [None, None]
                    tp_entry = db.TestPulse(
                        electrode=electrode_entry,
                        recording=rec_entry,
                        start_index=indices[0],
                        stop_index=indices[1],
                        baseline_current=tp.baseline_current,
                        baseline_potential=tp.baseline_potential,
                        access_resistance=tp.access_resistance,
                        input_resistance=tp.input_resistance,
                        capacitance=tp.capacitance,
                        time_constant=tp.time_constant,
                    )
                    session.add(tp_entry)
                    pcrec_entry.nearest_test_pulse = tp_entry
                    
                # import information about STP protocol
                if not isinstance(rec, MultiPatchProbe):
                    continue
                srec_has_mp_probes = True
                psa = PulseStimAnalyzer.get(rec)
                ind_freq, rec_delay = psa.stim_params()
                mprec_entry = db.MultiPatchProbe(
                    patch_clamp_recording=pcrec_entry,
                    induction_frequency=ind_freq,
                    recovery_delay=rec_delay,
                )
                session.add(mprec_entry)
            
                # import presynaptic stim pulses
                pulses = psa.pulse_chunks()
                
                pulse_entries = {}
                all_pulse_entries[rec.device_id] = pulse_entries

                for i,pulse in enumerate(pulses):
                    # Record information about all pulses, including test pulse.
                    t0, t1 = pulse.meta['pulse_edges']
                    resampled = pulse['primary'].resample(sample_rate=20000)
                    pulse_entry = db.StimPulse(
                        recording=rec_entry,
                        pulse_number=pulse.meta['pulse_n'],
                        onset_time=t0,
                        amplitude=pulse.meta['pulse_amplitude'],
                        duration=t1-t0,
                        data=resampled.data,
                        data_start_time=resampled.t0,
                    )
                    session.add(pulse_entry)
                    pulse_entries[pulse.meta['pulse_n']] = pulse_entry
                    

                # import presynaptic evoked spikes
                # For now, we only detect up to 1 spike per pulse, but eventually
                # this may be adapted for more.
                spikes = psa.evoked_spikes()
                for i,sp in enumerate(spikes):
                    pulse = pulse_entries[sp['pulse_n']]
                    pulse.n_spikes = len(sp['spikes'])
                    for i,spike in enumerate(sp['spikes']):
                        spike_entry = db.StimSpike(
                            stim_pulse=pulse,
                            onset_time=spike['onset_time'],
                            peak_time=spike['peak_time'],
                            max_slope_time=spike['max_slope_time'],
                            max_slope=spike['max_slope'],
                            peak_diff=spike.get('peak_diff'),
                            peak_value=spike['peak_value'],
                        )
                        session.add(spike_entry)
                        if i == 0:
                            # pulse.first_spike = spike_entry
                            pulse.first_spike_time = spike_entry.max_slope_time
            
            if not srec_has_mp_probes:
                continue
            
            # import postsynaptic responses
            mpa = MultiPatchSyncRecAnalyzer(srec)
            for pre_dev in srec.devices:
                for post_dev in srec.devices:
                    if pre_dev == post_dev:
                        continue

                    # get all responses, regardless of the presence of a spike
                    responses = mpa.get_spike_responses(srec[pre_dev], srec[post_dev], align_to='pulse', require_spike=False)
                    for resp in responses:
                        # base_entry = db.Baseline(
                        #     recording=rec_entries[post_dev],
                        #     start_index=resp['baseline_start'],
                        #     stop_index=resp['baseline_stop'],
                        #     data=resp['baseline'].resample(sample_rate=20000).data,
                        #     mode=float_mode(resp['baseline'].data),
                        # )
                        # session.add(base_entry)
                        pair_entry = pairs_by_device_id.get((pre_dev, post_dev), None)
                        if pair_entry is None:
                            continue  # no data for one or both channels
                        if resp['ex_qc_pass']:
                            pair_entry.n_ex_test_spikes += 1
                        if resp['in_qc_pass']:
                            pair_entry.n_in_test_spikes += 1
                            
                        resampled = resp['response']['primary'].resample(sample_rate=20000)
                        resp_entry = db.PulseResponse(
                            recording=rec_entries[post_dev],
                            stim_pulse=all_pulse_entries[pre_dev][resp['pulse_n']],
                            pair=pair_entry,
                            data=resampled.data,
                            data_start_time=resampled.t0,
                            ex_qc_pass=resp['ex_qc_pass'],
                            in_qc_pass=resp['in_qc_pass'],
                            meta=None if resp['ex_qc_pass'] and resp['in_qc_pass'] else {'qc_failures': resp['qc_failures']},
                        )
                        session.add(resp_entry)
                        
            # generate up to 20 baseline snippets for each recording
            for dev in srec.devices:
                rec = srec[dev]
                dist = BaselineDistributor.get(rec)
                for _ in range(20):
                    base = dist.get_baseline_chunk(20e-3)
                    if base is None:
                        # all out!
                        break
                    start, stop = base
                    data = rec['primary'].time_slice(start, stop).resample(sample_rate=20000).data

                    ex_qc_pass, in_qc_pass, qc_failures = qc.pulse_response_qc_pass(rec, [start, stop], None, [])

                    base_entry = db.Baseline(
                        recording=rec_entries[dev],
                        data=data,
                        data_start_time=start,
                        mode=float_mode(data),
                        ex_qc_pass=ex_qc_pass,
                        in_qc_pass=in_qc_pass,
                        meta=None if ex_qc_pass is True and in_qc_pass is True else {'qc_failures': qc_failures},
                    )
                    session.add(base_entry)
Example No. 29
0
    def _get_tserieslist(self,
                         ts_name,
                         align,
                         bsub,
                         bsub_win=5e-3,
                         alignment_failure_mode='ignore'):
        tsl = []
        if align is not None and alignment_failure_mode == 'average':
            if align == 'spike':
                average_align_t = np.mean([
                    p.stim_pulse.first_spike_time for p in self.prs
                    if p.stim_pulse.first_spike_time is not None
                ])
            elif align == 'peak':
                average_align_t = np.mean([
                    p.stim_pulse.spikes[0].peak_time for p in self.prs
                    if p.stim_pulse.n_spikes == 1
                    and p.stim_pulse.spikes[0].peak_time is not None
                ])
            elif align == 'pulse':
                average_align_t = np.mean([
                    p.stim_pulse.onset_time for p in self.prs
                    if p.stim_pulse.onset_time is not None
                ])
            else:
                raise ValueError(
                    "align must be None, 'spike', 'peak', or 'pulse'.")

        for pr in self.prs:
            ts = getattr(pr, ts_name)
            stim_time = pr.stim_pulse.onset_time

            if bsub is True:
                start_time = max(ts.t0, stim_time - bsub_win)
                baseline_data = ts.time_slice(start_time, stim_time).data
                if len(baseline_data) == 0:
                    baseline = ts.data[0]
                else:
                    baseline = float_mode(baseline_data)
                ts = ts - baseline

            if align is not None:
                if align == 'spike':
                    # first_spike_time is the max dv/dt of the spike
                    align_t = pr.stim_pulse.first_spike_time
                elif align == 'pulse':
                    align_t = stim_time
                elif align == 'peak':
                    # peak of the first spike
                    if pr.stim_pulse.n_spikes == 1:
                        align_t = pr.stim_pulse.spikes[0].peak_time
                    else:
                        align_t = None
                else:
                    raise ValueError(
                        "align must be None, 'spike', 'peak', or 'pulse'.")

                if align_t is None:
                    if alignment_failure_mode == 'ignore':
                        # ignore PRs with no known timing
                        continue
                    elif alignment_failure_mode == 'average':
                        align_t = average_align_t
                        if np.isnan(align_t):
                            raise Exception(
                                "average %s time is not available; try another "
                                "alignment_failure_mode" % align)
                    elif alignment_failure_mode == 'raise':
                        raise Exception(
                            "%s time is not available for pulse %s and can't be aligned"
                            % (align, pr))
                    else:
                        raise ValueError(
                            "alignment_failure_mode must be 'ignore', 'average', or 'raise'.")

                ts = ts.copy(t0=ts.t0 - align_t)

            tsl.append(ts)
        return TSeriesList(tsl)
Example No. 30
0
    def create(self, session):
        err, warn = self.check()
        if len(err) > 0:
            raise Exception("Submission has errors:\n%s" % '\n'.join(err))

        # look up slice record in DB
        slice_dir = self.dh.parent()
        ts = datetime.fromtimestamp(slice_dir.info()['__timestamp__'])
        slice_entry = db.slice_from_timestamp(ts, session=session)

        # Create entry in experiment table
        data = self.fields
        expt = db.Experiment(**data)
        expt.slice = slice_entry
        self.expt_entry = expt
        session.add(expt)

        # Load NWB file and create data entries
        nwb = MultiPatchExperiment(self.nwb_file.name())

        for srec in nwb.contents:
            temp = srec.meta.get('temperature', None)
            srec_entry = db.SyncRec(sync_rec_key=srec.key,
                                    experiment=expt,
                                    temperature=temp)
            session.add(srec_entry)

            srec_has_mp_probes = False

            rec_entries = {}
            all_pulse_entries = {}
            for rec in srec.recordings:

                # import all recordings
                rec_entry = db.Recording(
                    sync_rec=srec_entry,
                    device_key=rec.device_id,
                    start_time=rec.start_time,
                )
                session.add(rec_entry)
                rec_entries[rec.device_id] = rec_entry

                # import patch clamp recording information
                if not isinstance(rec, PatchClampRecording):
                    continue
                pcrec_entry = db.PatchClampRecording(
                    recording=rec_entry,
                    clamp_mode=rec.clamp_mode,
                    patch_mode=rec.patch_mode,
                    stim_name=rec.meta['stim_name'],
                    baseline_potential=rec.baseline_potential,
                    baseline_current=rec.baseline_current,
                    baseline_rms_noise=rec.baseline_rms_noise,
                )
                session.add(pcrec_entry)

                # import test pulse information
                tp = rec.nearest_test_pulse
                if tp is not None:
                    tp_entry = db.TestPulse(
                        start_index=tp.indices[0],
                        stop_index=tp.indices[1],
                        baseline_current=tp.baseline_current,
                        baseline_potential=tp.baseline_potential,
                        access_resistance=tp.access_resistance,
                        input_resistance=tp.input_resistance,
                        capacitance=tp.capacitance,
                        time_constant=tp.time_constant,
                    )
                    session.add(tp_entry)
                    pcrec_entry.nearest_test_pulse = tp_entry

                # import information about STP protocol
                if not isinstance(rec, MultiPatchProbe):
                    continue
                srec_has_mp_probes = True
                psa = PulseStimAnalyzer.get(rec)
                ind_freq, rec_delay = psa.stim_params()
                mprec_entry = db.MultiPatchProbe(
                    patch_clamp_recording=pcrec_entry,
                    induction_frequency=ind_freq,
                    recovery_delay=rec_delay,
                )
                session.add(mprec_entry)

                # import presynaptic stim pulses
                pulses = psa.pulses()

                pulse_entries = {}
                all_pulse_entries[rec.device_id] = pulse_entries

                for i, pulse in enumerate(pulses):
                    # Record information about all pulses, including test pulse.
                    pulse_entry = db.StimPulse(
                        recording=rec_entry,
                        pulse_number=i,
                        onset_index=pulse[0],
                        amplitude=pulse[2],
                        length=pulse[1] - pulse[0],
                    )
                    session.add(pulse_entry)
                    pulse_entries[i] = pulse_entry

                # import presynaptic evoked spikes
                spikes = psa.evoked_spikes()
                for sp in spikes:
                    pulse = pulse_entries[sp['pulse_n']]
                    if sp['spike'] is not None:
                        extra = sp['spike']
                        pulse.n_spikes = 1
                    else:
                        extra = {}
                        pulse.n_spikes = 0

                    spike_entry = db.StimSpike(recording=rec_entry,
                                               pulse=pulse,
                                               **extra)
                    session.add(spike_entry)

            if not srec_has_mp_probes:
                continue

            # import postsynaptic responses
            mpa = MultiPatchSyncRecAnalyzer(srec)
            for pre_dev in srec.devices:
                for post_dev in srec.devices:
                    if pre_dev == post_dev:
                        continue

                    # get all responses, regardless of the presence of a spike
                    responses = mpa.get_spike_responses(srec[pre_dev],
                                                        srec[post_dev],
                                                        align_to='pulse',
                                                        require_spike=False)
                    for resp in responses:
                        base_entry = db.Baseline(
                            recording=rec_entries[post_dev],
                            start_index=resp['baseline_start'],
                            stop_index=resp['baseline_stop'],
                            data=resp['baseline'].resample(
                                sample_rate=20000).data,
                            mode=float_mode(resp['baseline'].data),
                        )
                        session.add(base_entry)
                        resp_entry = db.PulseResponse(
                            recording=rec_entries[post_dev],
                            stim_pulse=all_pulse_entries[pre_dev][
                                resp['pulse_n']],
                            baseline=base_entry,
                            start_index=resp['rec_start'],
                            stop_index=resp['rec_stop'],
                            data=resp['response'].resample(
                                sample_rate=20000).data,
                        )
                        session.add(resp_entry)

        return expt