def format_trace(trace, baseline_win, x_offset, align='spike'):
    """Baseline-subtract a raw sweep and wrap it in a Trace.

    Parameters
    ----------
    trace : array-like
        Raw samples; must support slicing and scalar subtraction.
    baseline_win : int
        Number of leading samples whose ``float_mode`` is used as the baseline.
    x_offset : float
        When ``align == 'psp'``, the returned trace gets ``t0 = -x_offset``.
    align : str
        'spike' (default) aligns to the pre-synaptic spike and leaves ``t0``
        untouched; 'psp' aligns to the onset of the PSP by shifting ``t0``.

    Returns
    -------
    Trace
        Baseline-subtracted trace at ``db.default_sample_rate``.
    """
    offset = float_mode(trace[:baseline_win])
    result = Trace(data=(trace - offset), sample_rate=db.default_sample_rate)
    if align != 'psp':
        return result
    result.t0 = -x_offset
    return result
Exemplo n.º 2
0
def cache_response(expt, pre, post, cache_file, cache):
    key = (expt.uid, pre, post)
    if key in cache:
        responses = cache[key]
        n_trials = len(responses['data'])
        response = []
        if n_trials != 0:
            for trial in range(n_trials):
                response.append(
                    Trace(data=responses['data'][trial],
                          dt=responses['dt'][trial],
                          stim_param=[responses['stim_param'][trial]]))
        return response

    response = get_response(expt, pre, post)
    cache[key] = response

    data = pickle.dumps(cache)
    open(cache_file, 'wb').write(data)
    print("cached connection %s, %d -> %d" % (key[0], key[1], key[2]))
    n_trials = len(response['data'])
    response2 = []
    if n_trials != 0:
        for trial in range(n_trials):
            response2.append(
                Trace(data=response['data'][trial],
                      dt=response['dt'][trial],
                      stim_param=[response['stim_param'][trial]]))
    return response2
Exemplo n.º 3
0
def format_trace(trace, baseline_win, x_offset, align='spike'):
    """Return *trace* minus its baseline mode as a Trace, optionally re-aligned.

    The baseline is ``float_mode`` of the first *baseline_win* samples. With the
    default ``align='spike'`` the trace stays aligned to the pre-synaptic spike
    (``t0`` untouched). With ``align='psp'`` the trace's ``t0`` is set to
    ``-x_offset`` so it is aligned to the PSP onset.
    """
    baseline_window = trace[:baseline_win]
    corrected = trace - float_mode(baseline_window)
    formatted = Trace(data=corrected, sample_rate=db.default_sample_rate)
    if align == 'psp':
        formatted.t0 = -x_offset
    return formatted
def analyze_response_strength(rec, source, remove_artifacts=False, deconvolve=True, lpf=True, bsub=True, lowpass=1000):
    """Perform a standardized strength analysis on a record selected by response_query or baseline_query.

    Steps:

    1. Determine the presynaptic pulse edges and spike time (real ones for
       ``source='pulse_response'``, synthetic ones for ``source='baseline'`` so
       that background data receives the same filtering / windowing).
    2. Measure the crosstalk step and peak deflections on the raw trace.
    3. Apply deconvolution / artifact removal / lpf via ``deconv_filter``.
    4. Measure peak deflections and latencies on the deconvolved trace.

    Parameters
    ----------
    rec
        Record exposing ``data`` and ``clamp_mode``; pulse responses also need
        ``pulse_start``, ``pulse_dur``, ``rec_start`` and ``spike_time``.
    source : str
        'pulse_response' or 'baseline'.
    remove_artifacts : bool
        In 'vc' mode, exactly ``True`` removes crosstalk artifacts before the raw
        measurement and disables removal inside ``deconv_filter``; otherwise the
        flag is forwarded to ``deconv_filter`` unchanged.
    deconvolve : bool
        If falsy, ``tau=None`` is passed to ``deconv_filter``; otherwise tau is
        15e-3 for 'ic' records and 5e-3 for all others.
    lpf, bsub, lowpass
        Forwarded to ``deconv_filter``.

    Returns
    -------
    dict
        raw_trace, pulse_times, spike_time, crosstalk, pos_amp, neg_amp,
        dec_trace, pos_dec_amp, pos_dec_latency, neg_dec_amp, neg_dec_latency.

    Raises
    ------
    ValueError
        If *source* is not recognized.
    """
    raw = Trace(rec.data, sample_rate=db.default_sample_rate)

    if source == 'baseline':
        # Fake stimulus information to ensure that background data receives
        # the same filtering / windowing treatment
        pulse_times, spike_time = [10e-3, 12e-3], 11e-3
    elif source == 'pulse_response':
        # Stimulus pulse edges, relative to recording start, for artifact removal
        onset = rec.pulse_start - rec.rec_start
        pulse_times = [onset, onset + rec.pulse_dur]
        # Pulses without a spike failed QC but are analyzed anyway (nominal
        # 11 ms spike time) so that all data stays visible.
        spike_time = 11e-3 if rec.spike_time is None else rec.spike_time - rec.rec_start
    else:
        raise ValueError("Invalid source %s" % source)

    # Crosstalk: change in median level across 200 us windows on either side of pulse onset
    edge = pulse_times[0]
    level_before = raw.time_slice(edge - 200e-6, edge).median()
    level_after = raw.time_slice(edge, edge + 200e-6).median()

    results = {
        'raw_trace': raw,
        'pulse_times': pulse_times,
        'spike_time': spike_time,
        'crosstalk': level_after - level_before,
    }

    working = raw
    if rec.clamp_mode == 'vc' and remove_artifacts is True:
        # VC crosstalk artifacts are stripped here, before deconvolution, so
        # deconv_filter must not remove them a second time.
        working = remove_crosstalk_artifacts(working, pulse_times)
        remove_artifacts = False

    for polarity, amp_key in (('+', 'pos_amp'), ('-', 'neg_amp')):
        results[amp_key], _ = measure_peak(working, polarity, spike_time, pulse_times)

    if not deconvolve:
        tau = None
    elif rec.clamp_mode == 'ic':
        tau = 15e-3
    else:
        tau = 5e-3
    deconvolved = deconv_filter(working, pulse_times, tau=tau, lpf=lpf, remove_artifacts=remove_artifacts, bsub=bsub, lowpass=lowpass)
    results['dec_trace'] = deconvolved

    dec_keys = (('+', 'pos_dec_amp', 'pos_dec_latency'), ('-', 'neg_dec_amp', 'neg_dec_latency'))
    for polarity, amp_key, latency_key in dec_keys:
        results[amp_key], results[latency_key] = measure_peak(deconvolved, polarity, spike_time, pulse_times)

    return results
Exemplo n.º 5
0
def analyze_response_strength(rec, source, remove_artifacts=False, lpf=True, bsub=True, lowpass=1000):
    """Perform a standardized strength analysis on a record selected by response_query or baseline_query.

    Steps:

    1. Determine timing of presynaptic stimulus pulse edges and spike
       (synthetic values for ``source='baseline'``).
    2. Measure the crosstalk step and peak deflection on the raw trace.
    3. Apply deconvolution (tau 15e-3 for 'ic', 5e-3 otherwise) / artifact
       removal / lpf via ``deconv_filter``.
    4. Measure peak deflection and latency on the deconvolved trace.

    Parameters
    ----------
    rec
        Record exposing ``data`` and ``clamp_mode``; pulse responses also need
        ``pulse_start``, ``pulse_dur``, ``rec_start`` and ``spike_time``.
    source : str
        'pulse_response' or 'baseline'.
    remove_artifacts : bool
        In 'vc' mode, exactly ``True`` removes crosstalk artifacts up front and
        then disables removal in ``deconv_filter``.
    lpf, bsub, lowpass
        Forwarded to ``deconv_filter``.

    Returns
    -------
    dict
        raw_trace, pulse_times, spike_time, crosstalk, pos_amp, neg_amp,
        dec_trace, pos_dec_amp, pos_dec_latency, neg_dec_amp, neg_dec_latency.

    Raises
    ------
    ValueError
        If *source* is not recognized.
    """
    trace = Trace(rec.data, sample_rate=db.default_sample_rate)

    if source == 'pulse_response':
        pulse_onset = rec.pulse_start - rec.rec_start
        pulse_times = [pulse_onset, pulse_onset + rec.pulse_dur]
        if rec.spike_time is not None:
            spike_time = rec.spike_time - rec.rec_start
        else:
            # failed QC; analyzed anyway with a nominal spike time so all data is visible
            spike_time = 11e-3
    elif source == 'baseline':
        # synthetic stimulus timing so background data is filtered / windowed identically
        pulse_times = [10e-3, 12e-3]
        spike_time = 11e-3
    else:
        raise ValueError("Invalid source %s" % source)

    onset_time = pulse_times[0]
    pre_median = trace.time_slice(onset_time - 200e-6, onset_time).median()
    post_median = trace.time_slice(onset_time, onset_time + 200e-6).median()

    results = dict(raw_trace=trace, pulse_times=pulse_times, spike_time=spike_time)
    results['crosstalk'] = post_median - pre_median

    cleaned = trace
    if rec.clamp_mode == 'vc' and remove_artifacts is True:
        # VC crosstalk is removed before deconvolution rather than inside deconv_filter
        cleaned = remove_crosstalk_artifacts(cleaned, pulse_times)
        remove_artifacts = False

    results['pos_amp'] = measure_peak(cleaned, '+', spike_time, pulse_times)[0]
    results['neg_amp'] = measure_peak(cleaned, '-', spike_time, pulse_times)[0]

    tau = 15e-3 if rec.clamp_mode == 'ic' else 5e-3
    dec = deconv_filter(cleaned, pulse_times, tau=tau, lpf=lpf, remove_artifacts=remove_artifacts, bsub=bsub, lowpass=lowpass)
    results['dec_trace'] = dec

    pos_amp, pos_latency = measure_peak(dec, '+', spike_time, pulse_times)
    results['pos_dec_amp'] = pos_amp
    results['pos_dec_latency'] = pos_latency
    neg_amp, neg_latency = measure_peak(dec, '-', spike_time, pulse_times)
    results['neg_dec_amp'] = neg_amp
    results['neg_dec_latency'] = neg_latency

    return results
Exemplo n.º 6
0
def simulate_response(fg_recs, bg_results, amp, rtime, seed=None):
    """Inject synthetic PSPs into recorded traces and run connectivity analysis on them.

    A PSP template (15 ms long, ``rise_time=rtime``, ``decay_tau=15e-3``) is added
    to a copy of each record's data. Per-record amplitudes are drawn as a
    binomial(24, 0.2) quantal count times a Normal(1, 0.3) size, rescaled so
    their mean equals *amp*. Latencies are drawn from Normal(13 ms, 200 us).

    Parameters
    ----------
    fg_recs : sequence
        Records to inject into; they are wrapped in ``RecordWrapper`` and their
        ``data`` arrays are copied, so the originals are not modified.
    bg_results
        Background strength results passed as the ('ic', 'bg') entry.
    amp : float
        Target mean injected amplitude.
    rtime : float
        Template rise time.
    seed : int or None
        If given, seeds ``np.random`` (which the scipy draws also use by default).

    Returns
    -------
    (conn_result, traces)
        ``conn_result`` from ``analyze_pair_connectivity(..., sign=1)`` and a list
        of the injected Traces, each with an ``amp`` attribute holding its
        injected amplitude.
    """
    if seed is not None:
        np.random.seed(seed)

    dt = 1.0 / db.default_sample_rate
    t = np.arange(0, 15e-3, dt)
    template = Psp.psp_func(t, xoffset=0, yoffset=0, rise_time=rtime, decay_tau=15e-3, amp=1, rise_power=2)

    n_recs = len(fg_recs)
    r_amps = scipy.stats.binom.rvs(p=0.2, n=24, size=n_recs) * scipy.stats.norm.rvs(scale=0.3, loc=1, size=n_recs)
    r_amps *= amp / r_amps.mean()
    r_latency = np.random.normal(size=n_recs, scale=200e-6, loc=13e-3)
    fg_results = []
    traces = []
    fg_recs = [RecordWrapper(rec) for rec in fg_recs]  # can't modify fg_recs, so we wrap records with a mutable shell
    for k, rec in enumerate(fg_recs):
        rec.data = rec.data.copy()
        start = int(r_latency[k] * db.default_sample_rate)
        # Inject over the overlap of the remaining record and the template; adding
        # template[:length] to rec.data[start:] fails with a broadcast error
        # whenever the record extends past the end of the template.
        n = max(0, min(len(rec.data) - start, len(template)))
        rec.data[start:start + n] += template[:n] * r_amps[k]

        # Records carry no real stimulus info, so the 'baseline' timing path is used.
        fg_result = analyze_response_strength(rec, 'baseline')
        fg_results.append(fg_result)

        traces.append(Trace(rec.data, sample_rate=db.default_sample_rate))
        traces[-1].amp = r_amps[k]
    fg_results = str_analysis_result_table(fg_results, fg_recs)
    conn_result = analyze_pair_connectivity({('ic', 'fg'): fg_results, ('ic', 'bg'): bg_results, ('vc', 'fg'): [], ('vc', 'bg'): []}, sign=1)
    return conn_result, traces
Exemplo n.º 7
0
def create_test_pulse(start=5 * ms,
                      pdur=10 * ms,
                      pamp=-10 * pA,
                      mode='ic',
                      dt=10 * us,
                      r_access=10 * MOhm,
                      c_soma=5 * pF,
                      noise=5 * pA):
    """Simulate the model cell's response to a single square test pulse.

    Side effects: sets ``model_cell.clamp.ra`` and
    ``model_cell.mechs['noise'].stdev`` on the global model cell before running.

    Parameters
    ----------
    start, pdur : float
        Pulse onset and duration, in the same time units as *dt*.
    pamp : float
        Amplitude written into the command waveform during the pulse.
    mode : str
        Passed through to ``model_cell.test``.
    dt : float
        Sample interval of the command waveform.
    r_access : float
        Pipette access resistance assigned to ``model_cell.clamp.ra``.
    c_soma : float
        NOTE(review): accepted but not used by this function.
    noise : float
        Assigned to ``model_cell.mechs['noise'].stdev``.

    Returns
    -------
    Whatever ``model_cell.test`` returns for the command Trace.
    """
    model_cell.clamp.ra = r_access
    model_cell.mechs['noise'].stdev = noise

    # Command: zeros spanning start + 3 * pdur, with a single step of *pamp*
    # from start to start + pdur (sample indices truncated by int()).
    total_duration = start + pdur * 3
    on_index = int(start / dt)
    off_index = on_index + int(pdur / dt)
    command = np.zeros(int(total_duration / dt))
    command[on_index:off_index] = pamp

    return model_cell.test(Trace(command, dt), mode)
Exemplo n.º 8
0
def responses(expt, pre, post):
    """Return the average response for one connection, memoized in a pickle cache.

    The cache is loaded from ``result_cache_file`` and keyed by
    ``(expt.nwb_file, pre, post)``. On a miss, the amplitude is estimated with a
    spike-aligned ``DynamicsAnalyzer``, stored, and the whole cache is rewritten.

    NOTE(review): a later function in this file is also named ``responses``; if
    both live in one module, the later definition replaces this one.

    Returns
    -------
    tuple
        ``(avg_est, avg_amp, n_sweeps)`` where ``avg_amp`` is a Trace. When the
        pair had no sweeps, ``avg_est`` and ``avg_amp`` are None. ``n_sweeps`` is
        None on a cache hit for such a pair and the computed count otherwise.
    """
    key = (expt.nwb_file, pre, post)
    result_cache = load_cache(result_cache_file)
    if key in result_cache:
        res = result_cache[key]
        if 'avg_est' not in res:
            return None, None, None
        avg_est = res['avg_est']
        avg_amp = Trace(data=res['data'], dt=res['dt'])
        n_sweeps = res['n_sweeps']
        return avg_est, avg_amp, n_sweeps

    analyzer = DynamicsAnalyzer(expt, pre, post, align_to='spike')
    avg_est, _, avg_amp, _, n_sweeps = analyzer.estimate_amplitude(plot=False)
    if n_sweeps == 0:
        result_cache[key] = {}
        ret = None, None, n_sweeps
    else:
        result_cache[key] = {'avg_est': avg_est, 'data': avg_amp.data, 'dt': avg_amp.dt, 'n_sweeps': n_sweeps}
        ret = avg_est, avg_amp, n_sweeps

    # Serialize before opening so a pickling error cannot truncate the cache
    # file; the context manager closes (and flushes) the handle.
    data = pickle.dumps(result_cache)
    with open(result_cache_file, 'wb') as fh:
        fh.write(data)
    print(key)
    return ret
def responses(expt, pre, post, thresh, filter=None):
    """Return the average response for one connection, memoized in a pickle cache.

    The cache is loaded from ``result_cache_file`` and keyed by
    ``(expt.nwb_file, pre, post)``. A pair is rejected (cached as an empty dict)
    when it has no sweeps or when its crosstalk ``artifact`` exceeds *thresh*.

    Parameters
    ----------
    expt, pre, post
        Experiment and pre/post cell identifiers.
    thresh : float
        Maximum accepted crosstalk artifact.
    filter : tuple or None
        If given, ``(freq_range, holding_range)`` passed to ``response_filter``
        on the pulse responses from ``get_response``. If None, a spike-aligned
        ``DynamicsAnalyzer`` estimates the amplitude and crosstalk instead.

    Returns
    -------
    tuple
        ``(avg_amp, avg_trace, n_sweeps)``. On rejection ``avg_amp`` and
        ``avg_trace`` are None; ``n_sweeps`` is None on a cache hit for a
        rejected pair and the computed count otherwise.
    """
    key = (expt.nwb_file, pre, post)
    result_cache = load_cache(result_cache_file)
    if key in result_cache:
        res = result_cache[key]
        if 'avg_amp' not in res:
            return None, None, None
        avg_amp = res['avg_amp']
        avg_trace = Trace(data=res['data'], dt=res['dt'])
        n_sweeps = res['n_sweeps']
        return avg_amp, avg_trace, n_sweeps

    if filter is not None:
        # Named pulse_responses so it does not shadow this function's own name.
        pulse_responses, artifact = get_response(expt, pre, post, type='pulse')
        response_subset = response_filter(pulse_responses, freq_range=filter[0], holding_range=filter[1], pulse=True)
        n_sweeps = len(response_subset)
        # Explicit defaults: with an empty subset these are never read (n_sweeps == 0),
        # but binding them keeps that invariant from depending on short-circuiting.
        avg_trace = avg_amp = None
        if n_sweeps > 0:
            avg_trace, avg_amp, _, _ = get_amplitude(response_subset)
    else:
        analyzer = DynamicsAnalyzer(expt, pre, post, align_to='spike')
        avg_amp, _, avg_trace, _, n_sweeps = analyzer.estimate_amplitude(plot=False)
        artifact = analyzer.cross_talk()
    if n_sweeps == 0 or artifact > thresh:
        result_cache[key] = {}
        ret = None, None, n_sweeps
    else:
        result_cache[key] = {'avg_amp': avg_amp, 'data': avg_trace.data, 'dt': avg_trace.dt, 'n_sweeps': n_sweeps}
        ret = avg_amp, avg_trace, n_sweeps

    # Serialize first so a pickling error cannot truncate the cache file;
    # the context manager closes (and flushes) the handle.
    data = pickle.dumps(result_cache)
    with open(result_cache_file, 'wb') as fh:
        fh.write(data)
    print(key)
    return ret
def format_responses(responses):
    """Group per-trial responses into Traces keyed by their stimulus parameters.

    Parameters
    ----------
    responses : dict
        Parallel lists under 'data', 'dt' and 'stim_param'; the number of trials
        is taken from ``len(responses['data'])``. Each stim_param must be hashable.

    Returns
    -------
    dict
        Maps each stim_param value to a list of Traces (in trial order), each
        built with that trial's data and dt and ``stim_param=[stim_param]``.
    """
    by_stim = {}
    for idx in range(len(responses['data'])):
        stim = responses['stim_param'][idx]
        bucket = by_stim.setdefault(stim, [])
        bucket.append(Trace(data=responses['data'][idx], dt=responses['dt'][idx],
                            stim_param=[responses['stim_param'][idx]]))
    return by_stim
Exemplo n.º 11
0
def load_next():
    """Fit the next usable pulse response and store it as a PspFitTestCase.

    Pulls ``(expt_id, pre_cell_id, post_cell_id, sweep, pr)`` tuples from the
    global iterator ``all_pulses`` and skips pulses whose first spike has no
    ``max_slope_time``. The first usable pulse is fit with ``fit_psp`` over an
    8e-3 window starting at the spike time, shown via ``ui.show_result``, and
    saved in the global ``last_result``. When the iterator is exhausted,
    ``ui.widget`` is hidden and nothing is fit.
    """
    global all_pulses, ui, last_result
    # Iterate rather than recurse into load_next() per skipped pulse, so long
    # runs of pulses without a spike time cannot hit the recursion limit.
    while True:
        try:
            (expt_id, pre_cell_id, post_cell_id, sweep, pr) = next(all_pulses)
        except StopIteration:
            ui.widget.hide()
            return
        print(expt_id, pre_cell_id, post_cell_id, sweep, pr['pulse_n'])

        spiketime = pr['spikes'][0]['max_slope_time']
        if spiketime is not None:
            break
        print("  skip - unknown spike time")

    # run psp fit on this chunk
    rec = pr['response']
    kwds = {
        'data': rec['primary'],
        'search_window': (spiketime, spiketime + 8e-3),
        'clamp_mode': rec.clamp_mode,
    }
    fit = fit_psp(ui=ui, **kwds)
    ui.show_result(fit)

    # copy just the necessary parts of recording data for export to file
    data = kwds['data']
    export_data = Trace(data.data,
                        t0=data.t0,
                        dt=data.meta['dt'],
                        sample_rate=data.meta['sample_rate'])
    kwds['data'] = export_data

    # construct test case
    tc = PspFitTestCase()
    tc._meta = {
        'expt_id': expt_id,
        'pre_cell_id': pre_cell_id,
        'post_cell_id': post_cell_id,
        'sweep_id': sweep.key,
        'pulse_n': pr['pulse_n'],
    }
    tc._input_args = kwds
    last_result = tc
def analyze_response_strength(rec, source, remove_artifacts=False, lpf=True, bsub=True, lowpass=1000):
    """Perform a standardized strength analysis on a record selected by response_query or baseline_query.

    1. Determine timing of presynaptic stimulus pulse edges and spike
    2. Measure peak deflection on raw trace
    3. Apply deconvolution / artifact removal / lpf
    4. Measure peak deflection on deconvolved trace

    Parameters
    ----------
    rec
        Record exposing ``data``; pulse responses also need ``pulse_start``,
        ``pulse_dur``, ``rec_start`` and ``spike_time`` (may be None).
    source : str
        'pulse_response' or 'baseline'.
    remove_artifacts, lpf, bsub, lowpass
        Forwarded to ``deconv_filter``.

    Returns
    -------
    dict
        raw_trace, pulse_times, spike_time, pos_amp, neg_amp, dec_trace,
        pos_dec_amp, pos_dec_latency, neg_dec_amp, neg_dec_latency.

    Raises
    ------
    ValueError
        If *source* is not recognized.
    """
    data = Trace(rec.data, sample_rate=db.default_sample_rate)
    if source == 'pulse_response':
        # Find stimulus pulse edges for artifact removal
        start = rec.pulse_start - rec.rec_start
        pulse_times = [start, start + rec.pulse_dur]
        if rec.spike_time is None:
            # No detected spike (pulse failed QC): analyze anyway with the same
            # nominal 11 ms spike time used for baseline data, instead of
            # raising TypeError on ``None - rec_start``.
            spike_time = 11e-3
        else:
            spike_time = rec.spike_time - rec.rec_start
    elif source == 'baseline':
        # Fake stimulus information to ensure that background data receives
        # the same filtering / windowing treatment
        pulse_times = [10e-3, 12e-3]
        spike_time = 11e-3
    else:
        raise ValueError("Invalid source %s" % source)

    results = {}

    results['raw_trace'] = data
    results['pulse_times'] = pulse_times
    results['spike_time'] = spike_time

    # Measure deflection on raw data
    results['pos_amp'], _ = measure_peak(data, '+', spike_time, pulse_times)
    results['neg_amp'], _ = measure_peak(data, '-', spike_time, pulse_times)

    # Deconvolution / artifact removal / filtering
    dec_data = deconv_filter(data, pulse_times, lpf=lpf, remove_artifacts=remove_artifacts, bsub=bsub, lowpass=lowpass)
    results['dec_trace'] = dec_data

    # Measure deflection on deconvolved data
    results['pos_dec_amp'], results['pos_dec_latency'] = measure_peak(dec_data, '+', spike_time, pulse_times)
    results['neg_dec_amp'], results['neg_dec_latency'] = measure_peak(dec_data, '-', spike_time, pulse_times)

    return results
Exemplo n.º 13
0
def plot_response_averages(expt, show_baseline=False, **kwds):
    """Plot the average evoked response for every pre/post device pair in a grid.

    Also opens a separate ln(SNR) vs ln(NRMSE) scatter plot with a threshold line.

    Parameters
    ----------
    expt
        Experiment passed to ``MultiPatchExperimentAnalyzer.get``.
    show_baseline : bool
        Currently unused; the only code that reads it is commented out below.
    **kwds
        Forwarded to ``analyzer.get_evoked_response_matrix``. ``clamp_mode``
        ('ic' by default) also picks the y-axis unit label ('V' for 'ic', 'A' otherwise).

    Returns
    -------
    PlotGrid
        Grid with one plot per (row, column) device pair; diagonal cells are blank.
    """
    analyzer = MultiPatchExperimentAnalyzer.get(expt)

    # First get average evoked responses for all pre/post pairs
    responses, rows, cols = analyzer.get_evoked_response_matrix(**kwds)

    # resize plot grid accordingly
    plots = PlotGrid()
    plots.set_shape(len(rows), len(cols))
    plots.show()

    # ranges[0] = (y minima, y maxima); ranges[1] = (first t, last t) across plots
    ranges = [([], []), ([], [])]
    # NOTE(review): only the commented-out fit code below appends to `points`,
    # so the SNR/NRMSE scatter plot is currently always empty.
    points = []

    # Plot each matrix element with PSP fit
    for i, dev1 in enumerate(rows):
        for j, dev2 in enumerate(cols):
            # select plot; only the last grid row keeps its bottom axis and only
            # the first column keeps its left axis (the grid has len(rows) rows)
            plt = plots[i, j]
            if i < len(rows) - 1:
                plt.getAxis('bottom').setVisible(False)
            if j > 0:
                plt.getAxis('left').setVisible(False)

            if dev1 == dev2:
                plt.getAxis('bottom').setVisible(False)
                plt.getAxis('left').setVisible(False)
                continue

            # adjust axes / labels
            plt.setXLink(plots[0, 0])
            plt.setYLink(plots[0, 0])
            plt.addLine(x=10e-3, pen=0.3)
            plt.addLine(y=0, pen=0.3)
            plt.setLabels(bottom=(str(dev2), 's'))
            if kwds.get('clamp_mode', 'ic') == 'ic':
                plt.setLabels(left=('%s' % dev1, 'V'))
            else:
                plt.setLabels(left=('%s' % dev1, 'A'))

            avg_response = responses[(dev1, dev2)].bsub_mean()
            if avg_response is None:
                continue
            avg_response.t0 = 0
            t = avg_response.time_values
            # display-only smoothing (presumably a 2 kHz low-pass Bessel filter)
            y = bessel_filter(Trace(avg_response.data, dt=avg_response.dt), 2e3).data
            plt.plot(t, y, antialias=True)

            # fit!
            #fit = responses[(dev1, dev2)].fit_psp(yoffset=0, mask_stim_artifact=(abs(dev1-dev2) < 3))
            #lsnr = np.log(fit.snr)
            #lerr = np.log(fit.nrmse())

            #color = (
                #np.clip(255 * (-lerr/3.), 0, 255),
                #np.clip(50 * lsnr, 0, 255),
                #np.clip(255 * (1+lerr/3.), 0, 255)
            #)

            #plt.plot(t, fit.best_fit, pen=color)
            ## plt.plot(t, fit.init_fit, pen='y')

            #points.append({'x': lerr, 'y': lsnr, 'brush': color})

            #if show_baseline:
                ## plot baseline for reference
                #bl = avg_response.meta['baseline'] - avg_response.meta['baseline_med']
                #plt.plot(np.arange(len(bl)) * avg_response.dt, bl, pen=(0, 100, 0), antialias=True)

            # keep track of data range across all plots
            ranges[0][0].append(y.min())
            ranges[0][1].append(y.max())
            ranges[1][0].append(t[0])
            ranges[1][1].append(t[-1])

    # min()/max() of an empty list raises ValueError, which happens when no
    # pair produced an average response; only set ranges when data exists.
    if ranges[0][0]:
        plots[0, 0].setYRange(min(ranges[0][0]), max(ranges[0][1]))
        plots[0, 0].setXRange(min(ranges[1][0]), max(ranges[1][1]))

    # scatter plot of SNR vs NRMSE
    plt = pg.plot()
    plt.setLabels(left='ln(SNR)', bottom='ln(NRMSE)')
    plt.plot([p['x'] for p in points], [p['y'] for p in points], pen=None, symbol='o', symbolBrush=[pg.mkBrush(p['brush']) for p in points])
    # show threshold line
    line = pg.InfiniteLine(pos=[0, 6], angle=180/np.pi * np.arctan(1))
    plt.addItem(line, ignoreBounds=True)

    return plots
Exemplo n.º 14
0
import pyqtgraph as pg
from pyqtgraph.Qt import QtGui, QtCore
import numpy as np
from neuroanalysis.data import Trace
from neuroanalysis.ui.event_detection import EventDetector
from neuroanalysis.ui.plot_grid import PlotGrid

pg.mkQApp()

# Example sweeps: every "trace*" array in the archive (sorted by name), sampled
# at the rate found at the matching index of data['sample_rates'].
data = np.load("test_data/synaptic_events/events1.npz")
trace_names = sorted(key for key in data.keys() if key.startswith('trace'))
traces = {
    trace_name: Trace(data[trace_name], dt=1.0 / data['sample_rates'][rate_index])
    for rate_index, trace_name in enumerate(trace_names)
}

# Event detector with a fixed detection threshold
evd = EventDetector()
evd.params['threshold'] = 5e-10

# Horizontal splitter and a header-less parameter tree
hs = QtGui.QSplitter(QtCore.Qt.Horizontal)
pt = pg.parametertree.ParameterTree(showHeader=False)

# Parameter group: a 'data' list selector over the trace names, plus the detector's params
params = pg.parametertree.Parameter.create(
    name='params',
    type='group',
    children=[
        {'name': 'data', 'type': 'list', 'values': trace_names},
        evd.params,
    ],
)
def analyze_pair_connectivity(amps, sign=None):
    """Given response strength records for a single pair, generate summary
    statistics characterizing strength, latency, and connectivity.

    Parameters
    ----------
    amps : dict
        Contains foreground and background strength analysis records
        (see input format below)
    sign : None, -1, or +1
        If None, then automatically determine whether to treat this connection as
        inhibitory or excitatory. Otherwise a value > 0 selects excitatory and
        anything else inhibitory.

    Input must have the following structure::

        amps = {
            ('ic', 'fg'): recs,
            ('ic', 'bg'): recs,
            ('vc', 'fg'): recs,
            ('vc', 'bg'): recs,
        }

    Where each *recs* must be a structured array containing fields as returned
    by get_amps() and get_baseline_amps().

    The overall strategy here is:

    1. Make an initial decision on whether to treat this pair as excitatory or
       inhibitory, based on differences between foreground and background amplitude
       measurements
    2. Generate mean and stdev for amplitudes, deconvolved amplitudes, and deconvolved
       latencies
    3. Generate KS test p values describing the differences between foreground
       and background distributions for amplitude, deconvolved amplitude, and
       deconvolved latency

    Returns
    -------
    dict
        Fields for the new DB record. Always contains 'synapse_type' ('ex' or
        'in') and '<mode>_n_samples' for 'ic' and 'vc'; statistics, the average
        response and PSP fit values are added for modes with usable data.
    """
    requested_sign = sign
    fields = {}  # used to fill the new DB record

    # Use KS p value to check for differences between foreground and background
    qc_amps = {}
    ks_pvals = {}
    amp_diffs = {}
    for clamp_mode in ('ic', 'vc'):
        clamp_mode_fg = amps[clamp_mode, 'fg']
        clamp_mode_bg = amps[clamp_mode, 'bg']
        if len(clamp_mode_fg) == 0 or len(clamp_mode_bg) == 0:
            continue
        for polarity in ('pos', 'neg'):
            # Separate into positive/negative tests and filter out responses that
            # failed qc. A positive deflection is checked against the excitatory
            # QC flag in IC but the inhibitory one in VC.
            qc_field = {'vc': {'pos': 'in_qc_pass', 'neg': 'ex_qc_pass'}, 'ic': {'pos': 'ex_qc_pass', 'neg': 'in_qc_pass'}}[clamp_mode][polarity]
            fg = clamp_mode_fg[clamp_mode_fg[qc_field]]
            bg = clamp_mode_bg[clamp_mode_bg[qc_field]]
            qc_amps[polarity, clamp_mode, 'fg'] = fg
            qc_amps[polarity, clamp_mode, 'bg'] = bg
            if len(fg) == 0 or len(bg) == 0:
                continue

            # Measure some statistics from these records
            fg = fg[polarity + '_dec_amp']
            bg = bg[polarity + '_dec_amp']
            ks_pvals[(polarity, clamp_mode)] = scipy.stats.ks_2samp(fg, bg).pvalue
            amp_diffs[polarity, clamp_mode] = np.mean(fg) - np.mean(bg)

    if requested_sign is None:
        # Decide whether to treat this connection as excitatory or inhibitory.
        #   strategy: accumulate evidence for either possibility by checking
        #   the ks p-values for each sign/clamp mode and the direction of the
        #   deflection (direction is flipped in VC).
        is_exc = 0
        for polarity in ('pos', 'neg'):
            for mode in ('ic', 'vc'):
                ks = ks_pvals.get((polarity, mode), None)
                if ks is None:
                    continue
                # turn p value into a reasonable scale factor
                ks = norm_pvalue(ks)
                dif_sign = 1 if amp_diffs[polarity, mode] > 0 else -1
                if mode == 'vc':
                    dif_sign *= -1
                is_exc += dif_sign * ks
    else:
        is_exc = requested_sign

    if is_exc > 0:
        fields['synapse_type'] = 'ex'
        signs = {'ic': 'pos', 'vc': 'neg'}
    else:
        fields['synapse_type'] = 'in'
        signs = {'ic': 'neg', 'vc': 'pos'}

    # compute the rest of statistics for only positive or negative deflections
    for clamp_mode in ('ic', 'vc'):
        polarity = signs[clamp_mode]
        fg = qc_amps.get((polarity, clamp_mode, 'fg'))
        bg = qc_amps.get((polarity, clamp_mode, 'bg'))
        if fg is None or bg is None or len(fg) == 0 or len(bg) == 0:
            fields[clamp_mode + '_n_samples'] = 0
            continue

        fields[clamp_mode + '_n_samples'] = len(fg)
        fields[clamp_mode + '_crosstalk_mean'] = np.mean(fg['crosstalk'])
        fields[clamp_mode + '_base_crosstalk_mean'] = np.mean(bg['crosstalk'])

        # measure mean, stdev, and statistical differences between
        # fg and bg for each measurement
        for val, field in [('amp', 'amp'), ('deconv_amp', 'dec_amp'), ('latency', 'dec_latency')]:
            f = fg[polarity + '_' + field]
            b = bg[polarity + '_' + field]
            fields[clamp_mode + '_' + val + '_mean'] = np.mean(f)
            fields[clamp_mode + '_' + val + '_stdev'] = np.std(f)
            fields[clamp_mode + '_base_' + val + '_mean'] = np.mean(b)
            fields[clamp_mode + '_base_' + val + '_stdev'] = np.std(b)
            # statistical tests comparing fg vs bg; p values are stored after
            # norm_pvalue (per the original note, log(1-log(pval)), which is
            # nicer to plot and easier to use as a classifier input)
            tt_pval = scipy.stats.ttest_ind(f, b, equal_var=False).pvalue
            ks_pval = scipy.stats.ks_2samp(f, b).pvalue
            fields[clamp_mode + '_' + val + '_ttest'] = norm_pvalue(tt_pval)
            fields[clamp_mode + '_' + val + '_ks2samp'] = norm_pvalue(ks_pval)

        ### generate the average response and psp fit

        # collect all fg traces, time-aligned to the presynaptic spike
        fg_traces = TraceList()
        for rec in fg:
            t0 = rec['response_start_time'] - rec['max_dvdt_time']
            fg_traces.append(Trace(rec['data'], sample_rate=db.default_sample_rate, t0=t0))

        # get averages; baseline is the float mode of the 6 ms before t=0
        fg_avg = fg_traces.mean()
        base_rgn = fg_avg.time_slice(-6e-3, 0)
        base = float_mode(base_rgn.data)
        fields[clamp_mode + '_average_response'] = fg_avg.data
        fields[clamp_mode + '_average_response_t0'] = fg_avg.t0
        fields[clamp_mode + '_average_base_stdev'] = base_rgn.std()

        fit_sign = {'pos': '+', 'neg': '-'}[polarity]
        fg_bsub = fg_avg.copy(data=fg_avg.data - base)  # remove base to help fitting
        try:
            fit = fit_psp(fg_bsub, mode=clamp_mode, sign=fit_sign, xoffset=(1e-3, 0, 6e-3), yoffset=(0, None, None), rise_time_mult_factor=4)
            for param, val in fit.best_values.items():
                fields['%s_fit_%s' % (clamp_mode, param)] = val
            fields[clamp_mode + '_fit_yoffset'] = fit.best_values['yoffset'] + base
            fields[clamp_mode + '_fit_nrmse'] = fit.nrmse()
        except Exception:
            # Best effort: report the fit failure and move on to the next clamp
            # mode, but let KeyboardInterrupt / SystemExit propagate.
            print("Error in PSP fit:")
            sys.excepthook(*sys.exc_info())
            continue

    return fields
    def add_connection_plots(i, name, timestamp, pre_id, post_id):
        global session, win, filtered
        p = pg.debug.Profiler(disabled=True, delayed=False)
        trace_plot = win.addPlot(i, 1)
        trace_plots.append(trace_plot)
        deconv_plot = win.addPlot(i, 2)
        deconv_plots.append(deconv_plot)
        hist_plot = win.addPlot(i, 3)
        hist_plots.append(hist_plot)
        limit_plot = win.addPlot(i, 4)
        limit_plot.addLegend()
        limit_plot.setLogMode(True, True)
        # Find this connection in the pair list
        idx = np.argwhere((abs(filtered['acq_timestamp'] - timestamp) < 1) & (filtered['pre_cell_id'] == pre_id) & (filtered['post_cell_id'] == post_id))
        if idx.size == 0:
            print("not in filtered connections")
            return
        idx = idx[0,0]
        p()

        # Mark the point in scatter plot
        scatter_plot.plot([background[idx]], [signal[idx]], pen='k', symbol='o', size=10, symbolBrush='r', symbolPen=None)
            
        # Plot example traces and histograms
        for plts in [trace_plots, deconv_plots]:
            plt = plts[-1]
            plt.setXLink(plts[0])
            plt.setYLink(plts[0])
            plt.setXRange(-10e-3, 17e-3, padding=0)
            plt.hideAxis('left')
            plt.hideAxis('bottom')
            plt.addLine(x=0)
            plt.setDownsampling(auto=True, mode='peak')
            plt.setClipToView(True)
            hbar = pg.QtGui.QGraphicsLineItem(0, 0, 2e-3, 0)
            hbar.setPen(pg.mkPen(color='k', width=5))
            plt.addItem(hbar)
            vbar = pg.QtGui.QGraphicsLineItem(0, 0, 0, 100e-6)
            vbar.setPen(pg.mkPen(color='k', width=5))
            plt.addItem(vbar)


        hist_plot.setXLink(hist_plots[0])
        
        pair = session.query(db.Pair).filter(db.Pair.id==filtered[idx]['pair_id']).all()[0]
        p()
        amps = strength_analysis.get_amps(session, pair)
        p()
        base_amps = strength_analysis.get_baseline_amps(session, pair)
        p()
        
        q = strength_analysis.response_query(session)
        p()
        q = q.join(strength_analysis.PulseResponseStrength)
        q = q.filter(strength_analysis.PulseResponseStrength.id.in_(amps['id']))
        q = q.join(db.Recording, db.Recording.id==db.PulseResponse.recording_id).join(db.PatchClampRecording).join(db.MultiPatchProbe)
        q = q.filter(db.MultiPatchProbe.induction_frequency < 100)
        # pre_cell = db.aliased(db.Cell)
        # post_cell = db.aliased(db.Cell)
        # q = q.join(db.Pair).join(db.Experiment).join(pre_cell, db.Pair.pre_cell_id==pre_cell.id).join(post_cell, db.Pair.post_cell_id==post_cell.id)
        # q = q.filter(db.Experiment.id==filtered[idx]['experiment_id'])
        # q = q.filter(pre_cell.ext_id==pre_id)
        # q = q.filter(post_cell.ext_id==post_id)

        fg_recs = q.all()
        p()

        traces = []
        deconvs = []
        for rec in fg_recs[:100]:
            result = strength_analysis.analyze_response_strength(rec, source='pulse_response', lpf=True, lowpass=2000,
                                                remove_artifacts=False, bsub=True)
            trace = result['raw_trace']
            trace.t0 = -result['spike_time']
            trace = trace - np.median(trace.time_slice(-0.5e-3, 0.5e-3).data)
            traces.append(trace)            
            trace_plot.plot(trace.time_values, trace.data, pen=(0, 0, 0, 20))

            trace = result['dec_trace']
            trace.t0 = -result['spike_time']
            trace = trace - np.median(trace.time_slice(-0.5e-3, 0.5e-3).data)
            deconvs.append(trace)            
            deconv_plot.plot(trace.time_values, trace.data, pen=(0, 0, 0, 20))

        # plot average trace
        mean = TraceList(traces).mean()
        trace_plot.plot(mean.time_values, mean.data, pen={'color':'g', 'width': 2}, shadowPen={'color':'k', 'width': 3}, antialias=True)
        mean = TraceList(deconvs).mean()
        deconv_plot.plot(mean.time_values, mean.data, pen={'color':'g', 'width': 2}, shadowPen={'color':'k', 'width': 3}, antialias=True)

        # add label
        label = pg.LabelItem(name)
        label.setParentItem(trace_plot)


        p("analyze_response_strength")

        # bins = np.arange(-0.0005, 0.002, 0.0001) 
        # field = 'pos_amp'
        bins = np.arange(-0.001, 0.015, 0.0005) 
        field = 'pos_dec_amp'
        n = min(len(amps), len(base_amps))
        hist_y, hist_bins = np.histogram(base_amps[:n][field], bins=bins)
        hist_plot.plot(hist_bins, hist_y, stepMode=True, pen=None, brush=(200, 0, 0, 150), fillLevel=0)
        hist_y, hist_bins = np.histogram(amps[:n][field], bins=bins)
        hist_plot.plot(hist_bins, hist_y, stepMode=True, pen='k', brush=(0, 150, 150, 100), fillLevel=0)
        p()

        pg.QtGui.QApplication.processEvents()


        # Plot detectability analysis
        q = strength_analysis.baseline_query(session)
        q = q.join(strength_analysis.BaselineResponseStrength)
        q = q.filter(strength_analysis.BaselineResponseStrength.id.in_(base_amps['id']))
        # q = q.limit(100)
        bg_recs = q.all()

        def clicked(sp, pts):
            traces = pts[0].data()['traces']
            print([t.amp for t in traces])
            plt = pg.plot()
            bsub = [t.copy(data=t.data - np.median(t.time_slice(0, 1e-3).data)) for t in traces]
            for t in bsub:
                plt.plot(t.time_values, t.data, pen=(0, 0, 0, 50))
            mean = TraceList(bsub).mean()
            plt.plot(mean.time_values, mean.data, pen='g')

        # first measure background a few times
        N = len(fg_recs)
        N = 50  # temporary for testing
        print("Testing %d trials" % N)


        bg_results = []
        M = 500
        print("  Grinding on %d background trials" % len(bg_recs))
        for i in range(M):
            amps = base_amps.copy()
            np.random.shuffle(amps)
            bg_results.append(np.median(amps[:N]['pos_dec_amp']) / np.std(amps[:N]['pos_dec_latency']))
            print("    %d/%d      \r" % (i, M),)
        print("    done.            ")
        print("    ", bg_results)


        # now measure foreground simulated under different conditions
        amps = 5e-6 * 2**np.arange(6)
        amps[0] = 0
        rtimes = 1e-3 * 1.71**np.arange(4)
        dt = 1 / db.default_sample_rate
        results = np.empty((len(amps), len(rtimes)), dtype=[('pos_dec_amp', float), ('latency_stdev', float), ('result', float), ('percentile', float), ('traces', object)])
        print("  Simulating synaptic events..")
        for j,rtime in enumerate(rtimes):
            for i,amp in enumerate(amps):
                trial_results = []
                t = np.arange(0, 15e-3, dt)
                template = Psp.psp_func(t, xoffset=0, yoffset=0, rise_time=rtime, decay_tau=15e-3, amp=1, rise_power=2)

                for l in range(20):
                    print("    %d/%d  %d/%d      \r" % (i,len(amps),j,len(rtimes)),)
                    r_amps = amp * 2**np.random.normal(size=N, scale=0.5)
                    r_latency = np.random.normal(size=N, scale=600e-6, loc=12.5e-3)
                    fg_results = []
                    traces = []
                    np.random.shuffle(bg_recs)
                    for k,rec in enumerate(bg_recs[:N]):
                        data = rec.data.copy()
                        start = int(r_latency[k] / dt)
                        length = len(rec.data) - start
                        rec.data[start:] += template[:length] * r_amps[k]

                        fg_result = strength_analysis.analyze_response_strength(rec, 'baseline')
                        fg_results.append((fg_result['pos_dec_amp'], fg_result['pos_dec_latency']))

                        traces.append(Trace(rec.data.copy(), dt=dt))
                        traces[-1].amp = r_amps[k]
                        rec.data[:] = data  # can't modify rec, so we have to muck with the array (and clean up afterward) instead
                    
                    fg_amp = np.array([r[0] for r in fg_results])
                    fg_latency = np.array([r[1] for r in fg_results])
                    trial_results.append(np.median(fg_amp) / np.std(fg_latency))
                results[i,j]['result'] = np.median(trial_results) / np.median(bg_results)
                results[i,j]['percentile'] = stats.percentileofscore(bg_results, results[i,j]['result'])
                results[i,j]['traces'] = traces

            assert all(np.isfinite(results[i]['pos_dec_amp']))
            print(i, results[i]['result'])
            print(i, results[i]['percentile'])
            

            # c = limit_plot.plot(rtimes, results[i]['result'], pen=(i, len(amps)*1.3), symbol='o', antialias=True, name="%duV"%(amp*1e6), data=results[i], symbolSize=4)
            # c.scatter.sigClicked.connect(clicked)
            # pg.QtGui.QApplication.processEvents()
            c = limit_plot.plot(amps, results[:,j]['result'], pen=(j, len(rtimes)*1.3), symbol='o', antialias=True, name="%dus"%(rtime*1e6), data=results[:,j], symbolSize=4)
            c.scatter.sigClicked.connect(clicked)
            pg.QtGui.QApplication.processEvents()

                
        pg.QtGui.QApplication.processEvents()
Exemplo n.º 17
0
def filter_pulse_responses(pair):
    """Return the QC-passing first-pulse current-clamp responses recorded for *pair*.

    A pulse response is kept only when all of the following hold:

    * it was recorded in current clamp (``clamp_mode == 'ic'``),
    * its stimulus pulse evoked exactly one presynaptic spike,
    * it is the first pulse of its train,
    * the train's induction frequency is not above 50 Hz, and
    * it passes the QC flag matching the pair's synapse type
      (``ex_qc_pass`` for 'ex', ``in_qc_pass`` for 'in').

    Returns
    -------
    pulse_responses : list of Trace
        Response data with ``t0`` chosen so that t = 0 is the presynaptic
        spike's ``max_dvdt_time``.
    pulse_ids : list
        The ``stim_pulse_id`` of each kept response.
    pulse_response_amps : list
        ``pos_amp`` (excitatory) or ``neg_amp`` (inhibitory) taken from each
        kept response's ``pulse_response_strength``.
    """
    # TODO: express this filtering as a single database query (joining
    # PulseResponse, StimPulse, StimSpike and PatchClampRecording) instead of
    # walking the ORM relationships in Python.
    synapse_type = pair.connection_strength.synapse_type
    kept_traces, kept_ids, kept_amps = [], [], []

    for response in pair.pulse_responses:
        pulse = response.stim_pulse
        spike_count = pulse.n_spikes
        position_in_train = pulse.pulse_number
        recording = pulse.recording.patch_clamp_recording
        frequency = recording.multi_patch_probe[0].induction_frequency

        # current clamp, a single presynaptic spike, first pulse of the train,
        # and induction frequencies up to 50 Hz only
        if (recording.clamp_mode != 'ic'
                or spike_count != 1
                or position_in_train != 1
                or frequency > 50):
            continue

        # shift time so the presynaptic spike (max dV/dt) lands on t = 0
        spike_aligned = Trace(
            data=response.data,
            t0=response.start_time - pulse.spikes[0].max_dvdt_time,
            sample_rate=db.default_sample_rate,
        )

        if synapse_type == 'ex' and response.ex_qc_pass is True:
            amplitude = response.pulse_response_strength.pos_amp
        elif synapse_type == 'in' and response.in_qc_pass is True:
            amplitude = response.pulse_response_strength.neg_amp
        else:
            continue

        kept_traces.append(spike_aligned)
        kept_ids.append(response.stim_pulse_id)
        kept_amps.append(amplitude)

    return kept_traces, kept_ids, kept_amps
Exemplo n.º 18
0
def save_fit_psp_test_set():
    """Create a test set of data for testing the fit_psp function.

    NOTE THIS CODE DOES NOT WORK BUT IS HERE FOR DOCUMENTATION PURPOSES SO
    THAT WE CAN TRACE BACK HOW THE TEST DATA WAS CREATED IF NEEDED.

    Uses Steph's original first_puls_feature.py code to filter out
    error-causing data.

    For every connection that passes the artifact, response and QC filters:

    * the averaged first-pulse response and its fitting parameters are
      collected into ``save_dict['input']``;
    * ``fit_psp`` is run on the original averaged Trace and on a Trace rebuilt
      from the serialized input;
    * any NRMSE mismatch between the two fits is printed.

    Example run statement
    python save_fit_psp_test_set.py --organism mouse --connection ee

    Comment in the code that does the saving at the bottom (that code is not
    present in this version of the function).
    """

    import pyqtgraph as pg
    import numpy as np
    import csv
    import sys
    import argparse
    from multipatch_analysis.experiment_list import cached_experiments
    from manuscript_figures import get_response, get_amplitude, response_filter, feature_anova, write_cache, trace_plot, \
        colors_human, colors_mouse, fail_rate, pulse_qc, feature_kw
    from synapse_comparison import load_cache, summary_plot_pulse
    from neuroanalysis.data import TraceList, Trace
    from neuroanalysis.ui.plot_grid import PlotGrid
    from multipatch_analysis.connection_detection import fit_psp
    from rep_connections import ee_connections, human_connections, no_include, all_connections, ie_connections, ii_connections, ei_connections
    from multipatch_analysis.synaptic_dynamics import DynamicsAnalyzer
    from scipy import stats
    import time
    import pandas as pd
    import json
    import os

    app = pg.mkQApp()
    pg.dbg()
    pg.setConfigOption('background', 'w')
    pg.setConfigOption('foreground', 'k')

    parser = argparse.ArgumentParser(
        description=
        'Enter organism and type of connection you"d like to analyze ex: mouse ee (all mouse excitatory-'
        'excitatory). Alternatively enter a cre-type connection ex: sim1-sim1')
    parser.add_argument('--organism',
                        dest='organism',
                        help='Select mouse or human')
    parser.add_argument('--connection',
                        dest='connection',
                        help='Specify connections to analyze')
    args = vars(parser.parse_args(sys.argv[1:]))

    all_expts = cached_experiments()
    manifest = {
        'Type': [],
        'Connection': [],
        'amp': [],
        'latency': [],
        'rise': [],
        'rise2080': [],
        'rise1090': [],
        'rise1080': [],
        'decay': [],
        'nrmse': [],
        'CV': []
    }
    fit_qc = {'nrmse': 8, 'decay': 499e-3}

    # connection_types must be an indexable sequence (it is indexed below);
    # on Python 3 dict.keys() is a view, so it is materialized with list().
    if args['organism'] == 'mouse':
        color_palette = colors_mouse
        calcium = 'high'
        age = '40-60'
        sweep_threshold = 3
        threshold = 0.03e-3
        connection = args['connection']
        if connection == 'ee':
            connection_types = list(ee_connections.keys())
        elif connection == 'ii':
            connection_types = list(ii_connections.keys())
        elif connection == 'ei':
            connection_types = list(ei_connections.keys())
        elif connection == 'ie':
            connection_types = list(ie_connections.keys())
        elif connection == 'all':
            connection_types = list(all_connections.keys())
        elif len(connection.split('-')) == 2:
            # 'pre-post' cre-type pair, e.g. 'sim1-sim1'
            c_type = connection.split('-')
            if c_type[0] == '2/3':
                pre_type = ('2/3', 'unknown')
            else:
                pre_type = (None, c_type[0])
            if c_type[1] == '2/3':
                post_type = ('2/3', 'unknown')
            else:
                post_type = (None, c_type[1])
            connection_types = [(pre_type, post_type)]
        else:
            parser.error('unrecognized --connection value: %r' % connection)
    elif args['organism'] == 'human':
        color_palette = colors_human
        calcium = None
        age = None
        sweep_threshold = 5
        threshold = None
        connection = args['connection']
        if connection == 'ee':
            connection_types = list(human_connections.keys())
        else:
            c_type = connection.split('-')
            connection_types = [((c_type[0], 'unknown'), (c_type[1], 'unknown'))]
    else:
        parser.error('--organism must be "mouse" or "human"')

    plt = pg.plot()

    scale_offset = (-20, -20)
    scale_anchor = (0.4, 1)
    holding = [-65, -75]
    qc_plot = pg.plot()
    grand_response = {}
    expt_ids = {}
    feature_plot = None
    feature2_plot = PlotGrid()
    feature2_plot.set_shape(5, 1)
    feature2_plot.show()
    feature3_plot = PlotGrid()
    feature3_plot.set_shape(1, 3)
    feature3_plot.show()
    amp_plot = pg.plot()
    synapse_plot = PlotGrid()
    synapse_plot.set_shape(len(connection_types), 1)
    synapse_plot.show()
    for c in range(len(connection_types)):
        # each connection type is ((pre_layer, pre_cre), (post_layer, post_cre))
        cre_type = (connection_types[c][0][1], connection_types[c][1][1])
        target_layer = (connection_types[c][0][0], connection_types[c][1][0])
        conn_type = connection_types[c]
        expt_list = all_expts.select(cre_type=cre_type,
                                     target_layer=target_layer,
                                     calcium=calcium,
                                     age=age)
        color = color_palette[c]
        grand_response[conn_type[0]] = {
            'trace': [],
            'amp': [],
            'latency': [],
            'rise': [],
            'dist': [],
            'decay': [],
            'CV': [],
            'amp_measured': []
        }
        expt_ids[conn_type[0]] = []
        synapse_plot[c, 0].addLegend()
        for expt in expt_list:
            for pre, post in expt.connections:
                if [expt.uid, pre, post] in no_include:
                    continue
                cre_check = (expt.cells[pre].cre_type == cre_type[0]
                             and expt.cells[post].cre_type == cre_type[1])
                layer_check = (expt.cells[pre].target_layer == target_layer[0]
                               and expt.cells[post].target_layer == target_layer[1])
                if cre_check is True and layer_check is True:
                    pulse_response, artifact = get_response(
                        expt, pre, post, analysis_type='pulse')
                    if threshold is not None and artifact > threshold:
                        continue
                    response_subset, hold = response_filter(
                        pulse_response,
                        freq_range=[0, 50],
                        holding_range=holding,
                        pulse=True)
                    if len(response_subset) >= sweep_threshold:
                        qc_plot.clear()
                        qc_list = pulse_qc(response_subset,
                                           baseline=1.5,
                                           pulse=None,
                                           plot=qc_plot)
                        if len(qc_list) >= sweep_threshold:
                            avg_trace, avg_amp, amp_sign, peak_t = get_amplitude(
                                qc_list)
                            #                        if amp_sign is '-':
                            #                            continue
                            #                        #print ('%s, %0.0f' %((expt.uid, pre, post), hold, ))
                            #                        all_amps = fail_rate(response_subset, '+', peak_t)
                            #                        cv = np.std(all_amps)/np.mean(all_amps)

                            # weight parts of the trace during fitting
                            dt = avg_trace.dt
                            weight = np.ones(len(avg_trace.data)) * 10.  # set everything to ten initially
                            weight[int(10e-3 / dt):int(12e-3 / dt)] = 0.  # area around stim artifact
                            weight[int(12e-3 / dt):int(19e-3 / dt)] = 30.  # area around steep PSP rise

                            # check if the test data dir is there and if not create it
                            test_data_dir = 'test_psp_fit'
                            if not os.path.isdir(test_data_dir):
                                os.mkdir(test_data_dir)

                            save_dict = {}
                            save_dict['input'] = {
                                'data': avg_trace.data.tolist(),
                                'dtype': str(avg_trace.data.dtype),
                                'dt': float(avg_trace.dt),
                                'amp_sign': amp_sign,
                                'yoffset': 0,
                                'xoffset': 14e-3,
                                'avg_amp': float(avg_amp),
                                'method': 'leastsq',
                                'stacked': False,
                                'rise_time_mult_factor': 10.,
                                'weight': weight.tolist()
                            }

                            # need to remake trace because different output is created
                            avg_trace_simple = Trace(
                                data=np.array(save_dict['input']['data']),
                                dt=save_dict['input']['dt'])  # create Trace object

                            psp_fits_original = fit_psp(
                                avg_trace,
                                sign=save_dict['input']['amp_sign'],
                                yoffset=save_dict['input']['yoffset'],
                                xoffset=save_dict['input']['xoffset'],
                                amp=save_dict['input']['avg_amp'],
                                method=save_dict['input']['method'],
                                stacked=save_dict['input']['stacked'],
                                rise_time_mult_factor=save_dict['input']['rise_time_mult_factor'],
                                fit_kws={
                                    'weights': save_dict['input']['weight']
                                })

                            psp_fits_simple = fit_psp(
                                avg_trace_simple,
                                sign=save_dict['input']['amp_sign'],
                                yoffset=save_dict['input']['yoffset'],
                                xoffset=save_dict['input']['xoffset'],
                                amp=save_dict['input']['avg_amp'],
                                method=save_dict['input']['method'],
                                stacked=save_dict['input']['stacked'],
                                rise_time_mult_factor=save_dict['input']['rise_time_mult_factor'],
                                fit_kws={
                                    'weights': save_dict['input']['weight']
                                })
                            # single-argument print() calls produce the same
                            # output under Python 2 and Python 3
                            print('%s %s %s' % (expt.uid, pre, post))
                            if psp_fits_original.nrmse() != psp_fits_simple.nrmse():
                                print('  the nrmse values dont match')
                                print('\toriginal %s' % psp_fits_original.nrmse())
                                print('\tsimple %s' % psp_fits_simple.nrmse())
Exemplo n.º 19
0
def test_trace_timing():
    """Check that Trace timing metadata (dt, sample_rate, t0, time_values) is exact.

    The cases covered are traces built with:

    * no timing,
    * only ``dt``,
    * only ``sample_rate``,
    * only regularly-sampled ``time_values``,
    * irregularly-sampled ``time_values``.

    Views taken by index slicing and by ``time_slice`` are checked as well.
    Several assertions compare float times with ``==`` on purpose: sample
    timing must be handled exactly, with no floating-point drift.
    """
    a = np.random.normal(size=300)
    sr = 50000
    dt = 2e-5
    t = np.arange(len(a)) * dt

    # trace with no timing information
    tr = Trace(a)
    assert not tr.has_timing
    assert not tr.has_time_values
    with raises(TypeError):
        tr.dt
    with raises(TypeError):
        tr.sample_rate
    with raises(TypeError):
        tr.time_values

    # a view of an untimed trace must not acquire timing either
    # (these two asserts previously re-checked `tr` instead of `view`)
    view = tr[100:200]
    assert not view.has_timing
    assert not view.has_time_values

    # invalid data: only 1D arrays are accepted
    with raises(ValueError):
        Trace(data=np.zeros((10, 10)))

    # invalid timing information: conflicting or mismatched timing sources
    with raises(TypeError):
        Trace(data=a, dt=dt, time_values=t)
    with raises(TypeError):
        Trace(data=a, sample_rate=sr, time_values=t)
    with raises(TypeError):
        Trace(data=a, dt=dt, t0=0, time_values=t)
    with raises(TypeError):
        Trace(data=a, dt=dt, t0=0, sample_rate=sr)
    with raises(ValueError):
        Trace(data=a, time_values=t[:-1])

    # trace with only dt
    tr = Trace(a, dt=dt)
    assert tr.dt == dt
    assert np.allclose(tr.sample_rate, sr)
    assert np.all(tr.time_values == t)
    assert tr.has_timing
    assert not tr.has_time_values
    assert tr.regularly_sampled

    # test view
    view = tr[100:200]
    assert view.t0 == tr.time_values[100]
    assert view.time_values[0] == view.t0
    assert view.dt == tr.dt
    assert view._meta['sample_rate'] is None
    assert not view.has_time_values

    view = tr.time_slice(100*dt, 200*dt)
    assert view.t0 == tr.time_values[100]
    assert view.time_values[0] == view.t0
    assert view.dt == tr.dt
    assert view._meta['sample_rate'] is None
    assert not view.has_time_values

    # test nested view
    view2 = view.time_slice(view.t0 + 20*dt, view.t0 + 50*dt)
    assert view2.t0 == view.time_values[20] == tr.time_values[120]

    # trace with only sample_rate
    tr = Trace(a, sample_rate=sr)
    assert tr.dt == dt
    assert tr.sample_rate == sr
    assert np.all(tr.time_values == t)
    assert tr.has_timing
    assert not tr.has_time_values
    assert tr.regularly_sampled

    # test view
    view = tr[100:200]
    assert view.t0 == tr.time_values[100]
    assert view.time_values[0] == view.t0
    assert view.sample_rate == tr.sample_rate
    assert view._meta['dt'] is None
    assert not view.has_time_values

    # trace with only regularly-sampled time_values
    tr = Trace(a, time_values=t)
    assert tr.dt == dt
    assert np.allclose(tr.sample_rate, sr)
    assert np.all(tr.time_values == t)
    assert tr.has_timing
    assert tr.has_time_values
    assert tr.regularly_sampled

    # test view
    view = tr[100:200]
    assert view.t0 == tr.time_values[100]
    assert view.time_values[0] == view.t0
    assert view._meta['dt'] is None
    assert view._meta['sample_rate'] is None
    assert view.has_time_values
    assert view.regularly_sampled

    # trace with irregularly-sampled time values
    t1 = np.cumsum(np.random.normal(loc=1, scale=0.02, size=a.shape))
    tr = Trace(a, time_values=t1)
    assert tr.dt == t1[1] - t1[0]
    assert np.all(tr.time_values == t1)
    assert tr.has_timing
    assert tr.has_time_values
    assert not tr.regularly_sampled

    # test view
    view = tr[100:200]
    assert view.t0 == tr.time_values[100]
    assert view.time_values[0] == view.t0
    assert view._meta['dt'] is None
    assert view._meta['sample_rate'] is None
    assert view.has_time_values
    assert not view.regularly_sampled
    
Exemplo n.º 20
0
# Per-clamp-mode result fields, in output-table order; each is emitted as
# 'ic_<field>' and 'vc_<field>' by fit_single_first_pulse.
_FIRST_PULSE_FIELDS = ('amp', 'latency', 'rise_time', 'decay_tau',
                       'psp_data', 'psp_fit', 'dt', 'nrmse')


def _empty_first_pulse_results():
    """Placeholder results for the clamp mode a pulse was not recorded in."""
    results = dict.fromkeys(_FIRST_PULSE_FIELDS)  # every field None
    results['psp_data'] = np.array([0])
    results['psp_fit'] = np.array([0])
    return results


def _spike_aligned_first_pulse_trace(pr):
    """Return *pr*'s data as a Trace with the presynaptic spike at t = time_before_spike.

    Samples before t = 0 (more than time_before_spike ahead of the spike) are cropped off.
    """
    trace = Trace(data=pr.data,
                  t0=pr.response_start_time - pr.spike_time + time_before_spike,
                  sample_rate=db.default_sample_rate)
    return trace.time_slice(start=0, stop=None)


def _first_pulse_fit_weight(data_trace, xoffset, mask_artifact):
    """Build the per-sample fit weights for a spike-aligned first-pulse trace.

    Everything starts at weight 10. If *mask_artifact* is true, the 3 ms
    before the spike get weight 0; the cross talk there is blurred because
    traces are spike aligned. The 4 ms window that starts 0.1 ms after the
    expected latency gets weight 30, to emphasize the steep PSP rise.
    """
    dt = data_trace.dt
    weight = np.ones(len(data_trace.data)) * 10.
    if mask_artifact:
        weight[int((time_before_spike-3e-3) / dt):int(time_before_spike / dt)] = 0.
    weight[int((time_before_spike+.0001+xoffset) / dt):int((time_before_spike+.0001+xoffset+4e-3) / dt)] = 30.
    return weight


def _fit_first_pulse_trace(data_trace, excitation, xoffset, weight, **fit_kws):
    """Run fit_trace on *data_trace* and collect the fields listed in _FIRST_PULSE_FIELDS."""
    fit = fit_trace(data_trace, excitation=excitation, weight=weight,
                    latency=xoffset, latency_jitter=.5e-3, **fit_kws)
    return {
        'amp': fit.best_values['amp'],
        # the fitted xoffset is measured from the trace start; report it
        # relative to the spike instead
        'latency': fit.best_values['xoffset'] - time_before_spike,
        'rise_time': fit.best_values['rise_time'],
        'decay_tau': fit.best_values['decay_tau'],
        'psp_data': data_trace.data,
        'psp_fit': fit.best_fit,
        'dt': data_trace.dt,
        'nrmse': fit.nrmse(),
    }


def fit_single_first_pulse(pr, pair):
    """Fit the first-pulse response *pr* of *pair* and package it for the output table.

    The response is aligned so the presynaptic spike sits at
    ``t = time_before_spike``. It is then fit with ``fit_trace``, passing the
    pair's averaged latency as ``latency`` with ``latency_jitter=.5e-3``. The
    latency is ``pair.avg_first_pulse_fit.vc_latency`` or ``.ic_latency``,
    chosen by ``pr.clamp_mode``.

    Returns
    -------
    dict
        On success, ``ic_<field>`` and ``vc_<field>`` entries for amp,
        latency (relative to the spike), rise_time, decay_tau, psp_data,
        psp_fit, dt and nrmse. The clamp mode that was not recorded gets
        None / ``np.array([0])`` placeholders, and ``error`` is None.
        Otherwise a dict holding only an ``error`` message. That happens when
        the pulse fails the QC matching the synapse type, or when the pair
        has no avg_first_pulse_fit entry or lacks the needed latency.

    Raises
    ------
    Exception
        If the pair has no synapse_type, or *pr* has a clamp mode other than
        'ic' or 'vc'.
    """
    #TODO: HAS THE APPROPRIATE QC HAPPENED?
    # excitatory or inhibitory?
    excitation = pair.connection_strength.synapse_type
    if not excitation:
        raise Exception('there is no synapse_type in connection_strength')

    if excitation == 'in':
        if not pr.in_qc_pass:
            return {'error': 'this pulse does not pass inhibitory qc'}
    if excitation == 'ex':
        if not pr.ex_qc_pass:
            return {'error': 'this pulse does not pass excitatory qc'}

    # get response latency from average first pulse table
    if not pair.avg_first_pulse_fit:
        return {'error': 'no entry in avg_first_pulse_fit table for this pair'}

    ic_results = _empty_first_pulse_results()
    vc_results = _empty_first_pulse_results()

    # an 'error' dict means no row will be made in the output table
    if pr.clamp_mode == 'vc':
        xoffset = pair.avg_first_pulse_fit.vc_latency
        if not xoffset:
            return {'error': 'no vc_latency available from avg_first_pulse_fit table'}
        data_trace = _spike_aligned_first_pulse_trace(pr)
        weight = _first_pulse_fit_weight(data_trace, xoffset, mask_artifact=False)
        vc_results = _fit_first_pulse_trace(data_trace, excitation, xoffset, weight,
                                            clamp_mode='vc')
    elif pr.clamp_mode == 'ic':
        xoffset = pair.avg_first_pulse_fit.ic_latency
        if not xoffset:
            return {'error': 'no ic_latency available from avg_first_pulse_fit table'}
        data_trace = _spike_aligned_first_pulse_trace(pr)
        weight = _first_pulse_fit_weight(data_trace, xoffset, mask_artifact=True)
        ic_results = _fit_first_pulse_trace(data_trace, excitation, xoffset, weight)
    else:
        raise Exception('There is no clamp mode associated with this pulse')

    # flatten into 'ic_*' then 'vc_*' keys for ease of translation into the output table
    out_dict = {}
    for mode, results in (('ic', ic_results), ('vc', vc_results)):
        for field in _FIRST_PULSE_FIELDS:
            out_dict['%s_%s' % (mode, field)] = results[field]
    out_dict['error'] = None
    return out_dict
Exemplo n.º 21
0
 def recorded_tseries(self):
     """Return this record's data as a Trace whose t0 is ``self.data_start_time``.

     The Trace is built on first access and cached in ``self._rec_tseries``.
     """
     tseries = self._rec_tseries
     if tseries is None:
         tseries = Trace(self.data, sample_rate=default_sample_rate, t0=self.data_start_time)
         self._rec_tseries = tseries
     return tseries
Exemplo n.º 22
0
 def post_tseries(self):
     """Return this response's data as a Trace whose t0 is ``self.start_time``.

     The Trace is built on first access and cached in ``self._post_tseries``.
     """
     tseries = self._post_tseries
     if tseries is None:
         tseries = Trace(self.data, sample_rate=default_sample_rate, t0=self.start_time)
         self._post_tseries = tseries
     return tseries
Exemplo n.º 23
0
        pulse_response
        join stim_pulse on pulse_response.pulse_id=stim_pulse.id
        join recording post_rec on pulse_response.recording_id=post_rec.id
        join patch_clamp_recording post_pcrec on post_pcrec.recording_id=post_rec.id
        join multi_patch_probe on multi_patch_probe.patch_clamp_recording_id=post_pcrec.id
        join recording pre_rec on stim_pulse.recording_id=pre_rec.id
        join sync_rec on post_rec.sync_rec_id=sync_rec.id
        join experiment on sync_rec.experiment_id=experiment.id
    where
        {conditions}
""".format(conditions='\n        and '.join(conditions))

print(query)

# Run the raw SQL assembled above. The first column of each row is passed to
# np.load, so it is expected to hold a serialized numpy array (.npy bytes).
rp = session.execute(query)

recs = rp.fetchall()
data = [np.load(io.BytesIO(rec[0])) for rec in recs]
print("\n\nloaded %d records" % len(data))

plt = pg.plot(labels={'left': ('Vm', 'V')})
traces = TraceList()
for i, x in enumerate(data):
    # baseline-subtract each response using the median of its first 100 samples
    # NOTE(review): sample rate is hard-coded to 20 kHz here -- confirm it
    # matches the stored data (elsewhere db.default_sample_rate is used)
    trace = Trace(x - np.median(x[:100]), sample_rate=20000)
    traces.append(trace)
    # only the first 100 individual responses are drawn (semi-transparent
    # white), presumably to limit clutter; all of them go into the average
    if i < 100:
        plt.plot(trace.time_values, trace.data, pen=(255, 255, 255, 100))

# overlay the average of every loaded response in green
avg = traces.mean()
plt.plot(avg.time_values, avg.data, pen='g')
Exemplo n.º 24
0
    on_times = np.argwhere(diff > 0)[:, 0]
    off_times = np.argwhere(diff < 0)[:, 0]

    # decide on the region of the trace to focus on
    start = on_times[1] - 1000
    stop = off_times[8] + 1000
    chunk = trace[start:stop]

    # plot the selected chunk
    t = np.arange(chunk.shape[0]) * dt
    plot.plot(t[:-1], np.diff(ndi.gaussian_filter(chunk, sigma)), pen=0.5)
    plot.plot(t, chunk)

    # detect spike times
    peak_inds = []
    rise_inds = []
    for j in range(8):  # loop over pulses
        pstart = on_times[j + 1] - start
        pstop = off_times[j + 1] - start
        spike_info = detect_vc_evoked_spike(Trace(chunk, dt=dt),
                                            pulse_edges=(pstart, pstop))
        if spike_info is not None:
            peak_inds.append(spike_info['peak_index'])
            rise_inds.append(spike_info['rise_index'])

    # display spike rise and peak times as ticks
    pticks = pg.VTickGroup(np.array(peak_inds) * dt, yrange=[0, 0.3], pen='r')
    rticks = pg.VTickGroup(np.array(rise_inds) * dt, yrange=[0, 0.3], pen='y')
    plot.addItem(pticks)
    plot.addItem(rticks)
Exemplo n.º 25
0
def extract_first_pulse_info_from_Pair_object(pair, desired_clamp='ic'):
    """Extract first-pulse responses and related information for one pair.

    Iterates over ``pair.pulse_responses`` and keeps only responses that
    (a) were recorded in the requested clamp mode, (b) were evoked by a
    stimulus producing exactly one presynaptic spike, (c) are the first
    pulse of their train, and (d) pass the QC flag matching the pair's
    synapse type (``ex_qc_pass`` for 'ex', ``in_qc_pass`` for 'in').

    Input
    -----
    pair: multipatch_analysis.database.database.Pair object
    desired_clamp: string
        Specifies whether current or voltage clamp sweeps are desired.
        Options are:
            'ic': current clamp
            'vc': voltage clamp

    Return
    ------
    pulse_responses: list of Trace
        response traces re-referenced so that t=0 lies ``time_before_spike``
        (a name resolved from module scope) before the presynaptic spike's
        max dV/dt time, and sliced to start at t=0
    pulse_ids: list of ints
        stim pulse ids of *pulse_responses*
    psp_amps_measured: list of floats
        amplitude of each of *pulse_responses* taken from its
        ``pulse_response_strength`` (``pos_amp`` for excitatory pairs,
        ``neg_amp`` for inhibitory pairs)
    stim_freqs: list of floats
        the induction frequency corresponding to each of *pulse_responses*

    All four lists are index-aligned. All four are empty if the pair has no
    connection strength, no synapse type, or no pulse responses.
    """
    strength = pair.connection_strength
    if strength is None or strength.synapse_type is None:
        return [], [], [], []
    synapse_type = strength.synapse_type
    if not pair.pulse_responses:
        return [], [], [], []

    pulse_responses = []
    pulse_ids = []
    psp_amps_measured = []
    stim_freqs = []

    for pr in pair.pulse_responses:
        stim_pulse = pr.stim_pulse
        pcr = stim_pulse.recording.patch_clamp_recording

        # Keep only sweeps in the requested clamp mode ...
        if pcr.clamp_mode != desired_clamp:
            continue
        # ... with exactly one presynaptic spike ...
        if stim_pulse.n_spikes != 1:
            continue
        # ... that are the first pulse of the train.
        if stim_pulse.pulse_number != 1:
            continue

        # Pick the QC flag and amplitude polarity matching the synapse type.
        # ``is True`` so an unset (None) flag counts as failing; pairs of any
        # other synapse type yield no responses.
        if synapse_type == 'ex' and pr.ex_qc_pass is True:
            amp = pr.pulse_response_strength.pos_amp
        elif synapse_type == 'in' and pr.in_qc_pass is True:
            amp = pr.pulse_response_strength.neg_amp
        else:
            continue

        spike_time = stim_pulse.spikes[0].max_dvdt_time
        # Shift the time base so t=0 falls time_before_spike before the spike
        # (assuming start_time and spike_time share a time base), then drop
        # all samples earlier than t=0.
        data_trace = Trace(
            data=pr.data,
            t0=pr.start_time - spike_time + time_before_spike,
            sample_rate=db.default_sample_rate,
        ).time_slice(start=0, stop=None)

        pulse_responses.append(data_trace)
        pulse_ids.append(pr.stim_pulse_id)
        psp_amps_measured.append(amp)
        stim_freqs.append(pcr.multi_patch_probe[0].induction_frequency)

    # Return the accumulated list of frequencies (previously the last loop
    # scalar was returned by mistake).
    return pulse_responses, pulse_ids, psp_amps_measured, stim_freqs