def main(nwb_file, output_dir, project, **kwargs):
    """Compute long-square feature vectors for one MIES NWB file and save them.

    Sweeps for a single headstage channel are read from *nwb_file*, converted
    to ``SweepSet`` objects and analyzed with ``LongSquareAnalysis``. The
    resulting features are turned into feature vectors (subthreshold
    responses, ISI shape, first-AP waveform, PSTH, instantaneous frequency and
    per-spike features). These are written to *output_dir* as
    ``fv_<feature>_<project>.npy`` files, alongside ``fv_ids_<project>.npy``
    and ``fv_errors_<project>.json``.

    Parameters
    ----------
    nwb_file : str
        Path of the MIES NWB file to analyze.
    output_dir : str
        Directory receiving the output files.
    project : str
        Tag inserted into every output file name.
    **kwargs
        ``ap_window_length`` may be passed to override the window length given
        to ``fv.first_ap_vectors``. When omitted, the name ``ap_window_length``
        is resolved from the enclosing module scope, as before.
    """
    nwb = MiesNwb(nwb_file)

    # SPECIFICS FOR EXAMPLE NWB =========

    # Only analyze one channel at a time
    channel = 0

    # We can work out code to automatically extract these based on stimulus names later.
    # Each pair is a (start, stop) sweep-index range passed to range(), so stop is exclusive.
    if_sweep_inds = [39, 45]
    targetv_sweep_inds = [15, 21]

    # END SPECIFICS =====================

    if "ap_window_length" in kwargs:
        window_length = kwargs["ap_window_length"]
    else:
        # NOTE(review): not defined in this function; presumably a module-level
        # constant defined elsewhere in the file -- confirm.
        window_length = ap_window_length

    # Assemble all Recordings and convert to Sweeps
    supra_sweep_ids = list(range(*if_sweep_inds))
    sub_sweep_ids = list(range(*targetv_sweep_inds))

    supra_recs = [nwb.contents[i][channel] for i in supra_sweep_ids]
    sub_recs = [nwb.contents[i][channel] for i in sub_sweep_ids]

    # Build sweep sets
    lsq_supra_sweep_list, lsq_supra_dur = recs_to_sweeps(supra_recs)
    lsq_sub_sweep_list, lsq_sub_dur = recs_to_sweeps(sub_recs)
    lsq_supra_sweeps = SweepSet(lsq_supra_sweep_list)
    lsq_sub_sweeps = SweepSet(lsq_sub_sweep_list)

    lsq_supra_start = 0
    lsq_supra_end = lsq_supra_dur
    lsq_sub_start = 0
    lsq_sub_end = lsq_sub_dur

    # Pre-process sweeps
    lsq_supra_spx, lsq_supra_spfx = dsf.extractors_for_sweeps(
        lsq_supra_sweeps, start=lsq_supra_start, end=lsq_supra_end)
    lsq_supra_an = spa.LongSquareAnalysis(lsq_supra_spx,
                                          lsq_supra_spfx,
                                          subthresh_min_amp=-100.,
                                          require_subthreshold=False)
    lsq_supra_features = lsq_supra_an.analyze(lsq_supra_sweeps)

    lsq_sub_spx, lsq_sub_spfx = dsf.extractors_for_sweeps(lsq_sub_sweeps,
                                                          start=lsq_sub_start,
                                                          end=lsq_sub_end)
    lsq_sub_an = spa.LongSquareAnalysis(lsq_sub_spx,
                                        lsq_sub_spfx,
                                        subthresh_min_amp=-100.,
                                        require_suprathreshold=False)
    lsq_sub_features = lsq_sub_an.analyze(lsq_sub_sweeps)

    # Calculate feature vectors
    result = {}
    (subthresh_hyperpol_dict, hyperpol_deflect_dict
     ) = fv.identify_subthreshold_hyperpol_with_amplitudes(
         lsq_sub_features, lsq_sub_sweeps)
    target_amps_for_step_subthresh = [-90, -70, -50, -30, -10]
    result["step_subthresh"] = fv.step_subthreshold(
        subthresh_hyperpol_dict,
        target_amps_for_step_subthresh,
        lsq_sub_start,
        lsq_sub_end,
        amp_tolerance=5)
    result["subthresh_norm"] = fv.subthresh_norm(subthresh_hyperpol_dict,
                                                 hyperpol_deflect_dict,
                                                 lsq_sub_start, lsq_sub_end)

    (subthresh_depol_dict,
     depol_deflect_dict) = fv.identify_subthreshold_depol_with_amplitudes(
         lsq_supra_features, lsq_supra_sweeps)
    result["subthresh_depol_norm"] = fv.subthresh_depol_norm(
        subthresh_depol_dict, depol_deflect_dict, lsq_supra_start,
        lsq_supra_end)
    isi_sweep, isi_sweep_spike_info = fv.identify_sweep_for_isi_shape(
        lsq_supra_sweeps, lsq_supra_features, lsq_supra_end - lsq_supra_start)
    result["isi_shape"] = fv.isi_shape(isi_sweep, isi_sweep_spike_info,
                                       lsq_supra_end)

    # Calculate AP waveform from long squares
    rheo_ind = lsq_supra_features["rheobase_sweep"].name
    sweep = lsq_supra_sweeps.sweeps[rheo_ind]
    lsq_ap_v, lsq_ap_dv = fv.first_ap_vectors(
        [sweep], [lsq_supra_features["spikes_set"][rheo_ind]],
        window_length=window_length)

    result["first_ap_v"] = lsq_ap_v
    result["first_ap_dv"] = lsq_ap_dv

    target_amplitudes = np.arange(0, 120, 20)
    supra_info_list = fv.identify_suprathreshold_sweep_sequence(
        lsq_supra_features, target_amplitudes, shift=10)
    result["psth"] = fv.psth_vector(supra_info_list, lsq_supra_start,
                                    lsq_supra_end)
    result["inst_freq"] = fv.inst_freq_vector(supra_info_list, lsq_supra_start,
                                              lsq_supra_end)
    spike_feature_list = [
        "upstroke_downstroke_ratio",
        "peak_v",
        "fast_trough_v",
        "threshold_v",
        "width",
    ]
    for feature in spike_feature_list:
        result["spiking_" + feature] = fv.spike_feature_vector(
            feature, supra_info_list, lsq_supra_start, lsq_supra_end)

    # Save the results
    specimen_ids = [0]
    results = [result]
    _save_feature_vectors(specimen_ids, results, output_dir, project)


def _save_feature_vectors(specimen_ids, results, output_dir, project):
    """Write the error log, one ``.npy`` array per feature, and the id list.

    Results containing an ``"error"`` key are written to
    ``fv_errors_<project>.json``. The remaining results are stacked feature by
    feature into ``fv_<feature>_<project>.npy``. A specimen that lacks a
    feature gets a NaN row of the size seen in the first result. The ids of
    the successful specimens go to ``fv_ids_<project>.npy``.
    """
    filtered_set = [(i, r) for i, r in zip(specimen_ids, results)
                    if "error" not in r]
    error_set = [{
        "id": i,
        "error": d
    } for i, d in zip(specimen_ids, results) if "error" in d]

    # Write the error log before the early return so failures are recorded
    # even when no specimen produced results.
    with open(os.path.join(output_dir, "fv_errors_{:s}.json".format(project)),
              "w") as f:
        json.dump(error_set, f, indent=4)

    if not filtered_set:
        logging.info("No specimens had results")
        return

    used_ids, good_results = zip(*filtered_set)
    logging.info("Finished with {:d} processed specimens".format(
        len(used_ids)))

    # Vector length per feature, taken from the first specimen; used to fill
    # NaN rows for specimens that lack the feature.
    k_sizes = {}
    for k in good_results[0]:
        if good_results[0][k] is not None:
            k_sizes[k] = len(good_results[0][k])

        missing = np.array([k not in r for r in good_results])
        if missing.any():
            logging.warning("Missing data! %s: %s", k,
                            np.array(used_ids)[missing])

        data = np.array([
            r[k] if k in r else np.nan * np.zeros(k_sizes[k])
            for r in good_results
        ])
        if data.ndim == 1:  # e.g. scalar-valued features: store as one row
            data = np.reshape(data, (1, -1))
        np.save(
            os.path.join(output_dir, "fv_{:s}_{:s}.npy".format(k, project)),
            data)

    np.save(os.path.join(output_dir, "fv_ids_{:s}.npy".format(project)),
            used_ids)
Exemplo n.º 2
0
        self.channels = channels
        self.update_analysis()

    def update_analysis(self, param=None, changes=None):
        """Called when the user changes control parameters.

        Currently a no-op placeholder.

        Parameters
        ----------
        param, changes : optional
            Presumably the ``(param, changes)`` pair emitted by a parameter-tree
            change signal -- confirm against the connecting code. Both default
            to None so the method can also be invoked directly with no
            arguments, e.g. ``self.update_analysis()``.
        """
        pass


if __name__ == '__main__':
    # Usage: python <this file> <path/to/file.nwb>
    # Opens the given MIES NWB file in a MiesNwbViewer window.
    import sys
    pg.dbg()

    filename = sys.argv[1]
    nwb = MiesNwb(filename)

    app = pg.mkQApp()  # keep a reference to the QApplication
    w = MiesNwbViewer(nwb)
    w.show()

    # Run the Qt event loop unless started with `python -i`, where the
    # interactive prompt keeps the window alive instead. Without this the
    # script would exit (closing the window) as soon as it reaches the end.
    if sys.flags.interactive == 0:
        app.exec_()
# Plot averaged responses for one experiment. The single argument is either a
# path to a MIES NWB file or a key into the cached experiment list:
#
#     python <this file> <experiment key | path/to/file.nwb>
import os
import sys

try:
    import user  # Python 2 only: runs the user's startup file (~/.pythonrc.py)
except ImportError:
    pass  # the `user` module was removed in Python 3

import pyqtgraph as pg
pg.dbg()

from neuroanalysis.miesnwb import MiesNwb
from multipatch_analysis.experiment_list import cached_experiments
from multipatch_analysis.connection_detection import plot_response_averages

arg = sys.argv[1]
if os.path.isfile(arg):
    # An existing file: open it directly.
    expt = MiesNwb(arg)
else:
    # Otherwise treat the argument as a key into the cached experiment list.
    all_expts = cached_experiments()
    expt = all_expts[arg].data

plots = plot_response_averages(expt,
                               show_baseline=True,
                               clamp_mode='ic',
                               min_duration=25e-3,
                               pulse_ids=None)

# Run the Qt event loop unless started with `python -i`.
if sys.flags.interactive == 0:
    pg.mkQApp().exec_()
Exemplo n.º 4
0
    def _load_nwb(self, nwb_handle):
        """Open the NWB file behind *nwb_handle* and plot per-channel QC metrics.

        Recordings are grouped by device (channel). For every channel, QC
        metrics are scattered against time since the earliest recording. Each
        metric is scaled so that a point whose magnitude exceeds 1.0 is drawn
        with the red "fail" brush. Finally, each ``Pipette N`` parameter is
        given the (1 s padded) time span of its channel's recordings and a
        ``got data`` flag that is True when the channel has more than two
        recordings.
        """
        self.nwb_handle = nwb_handle
        self.nwb = MiesNwb(nwb_handle.name())

        # Group recordings by the device (channel) they came from.
        by_chan = {}
        for sweep_rec in self.nwb.contents:
            for dev in sweep_rec.devices:
                by_chan.setdefault(dev, []).append(sweep_rec[dev])

        # The time axis starts at the earliest first recording of any channel.
        t_origin = min([chan_recs[0].start_time for chan_recs in by_chan.values()])
        self.start_time = t_origin
        t_last = max([chan_recs[-1].start_time for chan_recs in by_chan.values()])
        self.plots.setXRange(0, (t_last - t_origin).seconds)

        def elapsed(rec):
            # Seconds component (timedelta.seconds) of the offset from t_origin.
            return (rec.start_time - t_origin).seconds

        pass_brush = pg.mkBrush(100, 100, 255, 200)
        fail_brush = pg.mkBrush(255, 0, 0, 200)

        for chan in sorted(by_chan):
            chan_recs = by_chan[chan]

            # Rows: elapsed time, baseline potential, baseline current,
            # voltage noise (non-'vc' recordings), current noise ('vc' recordings).
            qc = np.empty((5, len(chan_recs)))
            for col, rec in enumerate(chan_recs):
                qc[0, col] = elapsed(rec)
                qc[1, col] = rec.baseline_potential
                qc[2, col] = rec.baseline_current
                if rec.clamp_mode == 'vc':
                    qc[3, col] = np.nan
                    qc[4, col] = rec.baseline_rms_noise
                else:
                    qc[3, col] = rec.baseline_rms_noise
                    qc[4, col] = np.nan
            times = qc[0]

            # Scale every metric so that the acceptable range maps onto [-1, 1].
            markers = (
                (np.zeros_like(times), 'o'),
                ((qc[1] + 60e-3) / 20e-3, 't'),
                (qc[2] / 400e-12, 'x'),
                (qc[3] / 5e-3, 't1'),
                (qc[4] / 100e-12, 'x'),
            )

            chan_plot = self.get_channel_plot(chan)
            chan_plot.setLabels(left=("Ch %d" % chan))
            for values, marker in markers:
                brushes = np.where(np.abs(values) > 1.0, fail_brush, pass_brush)
                chan_plot.plot(times,
                               values,
                               pen=None,
                               symbol=marker,
                               symbolPen=None,
                               symbolBrush=brushes)

        # Give each pipette parameter its channel's recording span, padded by
        # one second on each side, and flag channels with enough recordings.
        for chan, chan_recs in by_chan.items():
            t_start = elapsed(chan_recs[0]) - 1
            t_stop = elapsed(chan_recs[-1]) + 1
            pip_param = self.params.child('Pipette %d' % (chan + 1))
            pip_param.set_time_range(t_start, t_stop)
            pip_param['got data'] = len(chan_recs) > 2
Exemplo n.º 5
0
def load_nwb(filename):
    """Open *filename* as a MiesNwb, store it in the module global ``nwb``
    and hand it to the viewer ``v``."""
    global nwb
    opened = MiesNwb(filename)
    nwb = opened
    v.set_nwb(opened)
    def _update_plots(self, auto_range=False):
        """Redraw the channel-by-channel response matrix for the selected sweeps.

        Cell ``(i, j)`` of ``self.plots`` shows data recorded on channel ``i``
        around the stimulus pulses delivered on channel ``j``. Depending on
        ``self.params``, the method can:

        - blank capacitive artifacts on nearby headstages;
        - low-pass filter the data;
        - subtract a per-sweep baseline;
        - draw individual sweeps, the sweep average and/or the pulse average.

        Parameters
        ----------
        auto_range : bool
            If True, reset the y ranges of the first row and the shared x
            range after plotting.
        """
        sweeps = self.sweeps
        chans = self.channels
        self.plots.clear()
        if len(sweeps) == 0 or len(chans) == 0:
            return

        # collect data; arrays are indexed [sweep, channel, sample]
        data = MiesNwb.pack_sweep_data(sweeps)
        data, stim = data[...,0], data[...,1]  # unpack stim and recordings
        dt = sweeps[0].recordings[0]['primary'].dt

        # mask for selected channels
        mask = np.array([ch in chans for ch in sweeps[0].devices])
        data = data[:, mask]
        stim = stim[:, mask]
        chans = np.array(sweeps[0].devices)[mask]

        modes = [sweeps[0][ch].clamp_mode for ch in chans]

        # get pulse times for each channel (from the first sweep's stimulus)
        stim = stim[0]
        diff = stim[:,1:] - stim[:,:-1]
        # note: the [1:] here skips the test pulse
        on_times = [np.argwhere(diff[i] > 0)[1:,0] for i in range(diff.shape[0])]
        off_times = [np.argwhere(diff[i] < 0)[1:,0] for i in range(diff.shape[0])]

        # remove capacitive artifacts from adjacent electrodes
        if self.params['remove artifacts']:
            npts = int(self.params['remove artifacts', 'window'] / dt)
            for i in range(stim.shape[0]):
                for j in range(stim.shape[0]):
                    if i == j:
                        continue

                    # are these headstages adjacent?
                    hs1, hs2 = chans[i], chans[j]
                    if abs(hs2-hs1) > 3:
                        continue

                    # Replace npts samples after each stimulus edge on channel i
                    # with the mean of the npts samples preceding that edge on
                    # channel j. zip() tolerates unequal on/off counts.
                    for on, off in zip(on_times[i], off_times[i]):
                        data[:, j, on:on+npts] = data[:, j, max(0,on-npts):on].mean(axis=1)[:,None]
                        data[:, j, off:off+npts] = data[:, j, max(0,off-npts):off].mean(axis=1)[:,None]

        # lowpass filter (time axis only)
        if self.params['lowpass']:
            data = gaussian_filter(data, (0, 0, self.params['lowpass', 'sigma'] / dt))

        # prepare to plot
        window = int(self.params['window'] / dt)
        n_sweeps = data.shape[0]
        n_channels = data.shape[1]
        n_samples = data.shape[2]
        self.plots.set_shape(n_channels, n_channels)
        self.plots.setClipToView(True)
        self.plots.setDownsampling(True, True, 'peak')
        self.plots.enableAutoRange(False, False)

        show_sweeps = 'sweeps' in self.params['show']
        show_sweep_avg = 'sweep avg' in self.params['show']
        show_pulse_avg = self.params['show'] == 'pulse avg'

        t = None  # time axis of the last plotted trace (used by auto_range)

        for i in range(n_channels):
            for j in range(n_channels):
                plt = self.plots[i, j]
                start = on_times[j][self.params['first pulse']] - window
                if start < 0:
                    frontpad = -start
                    start = 0
                else:
                    frontpad = 0
                # Clamp to the recorded length so the padded copy below has
                # matching shapes.
                stop = min(on_times[j][self.params['last pulse']] + window, n_samples)

                # select the data segment to be displayed in this matrix cell
                # add padding if necessary
                if frontpad == 0:
                    seg = data[:, i, start:stop].copy()
                else:
                    seg = np.empty((data.shape[0], stop + frontpad), data.dtype)
                    seg[:, frontpad:] = data[:, i, start:stop]
                    seg[:, :frontpad] = seg[:, frontpad:frontpad+1]

                # subtract off baseline for each sweep
                if self.params['remove baseline']:
                    seg -= seg[:, :window].mean(axis=1)[:,None]

                if show_sweeps:
                    alpha = 100 if show_sweep_avg else 200
                    color = (255, 255, 255, alpha)
                    t = np.arange(seg.shape[1]) * dt
                    for k in range(n_sweeps):
                        plt.plot(t, seg[k], pen={'color': color, 'width': 1}, antialias=True)

                if show_sweep_avg or show_pulse_avg:
                    # average selected segments over all sweeps
                    segm = seg.mean(axis=0)

                    if show_pulse_avg:
                        # average over all selected pulses
                        pulses = []
                        for k in range(self.params['first pulse'], self.params['last pulse'] + 1):
                            pstart = on_times[j][k] - on_times[j][self.params['first pulse']]
                            pstop = pstart + (window * 2)
                            pulses.append(segm[pstart:pstop])
                        segm = np.vstack(pulses).mean(axis=0)

                    t = np.arange(segm.shape[0]) * dt

                    if i == j:
                        color = (80, 80, 80)
                    else:
                        # Color encodes response polarity relative to baseline
                        # noise: red for inhibitory-like, blue for excitatory-like.
                        dif = segm - segm[:window].mean()
                        qe = 30 * np.clip(dif, 0, 1e20).mean() / segm[:window].std()
                        qi = 30 * np.clip(-dif, 0, 1e20).mean() / segm[:window].std()
                        if modes[i] == 'ic':
                            qi, qe = qe, qi  # invert color metric for current clamp
                        g = 100
                        r = np.clip(g + max(qi, 0), 0, 255)
                        b = np.clip(g + max(qe, 0), 0, 255)
                        color = (r, g, b)

                    plt.plot(t, segm, pen={'color': color, 'width': 1}, antialias=True)

                if self.params['show ticks']:
                    # Pulse onsets in segment time. Segment sample s corresponds
                    # to data sample s + start - frontpad, so add frontpad back.
                    vt = pg.VTickGroup((on_times[j] - start + frontpad) * dt, [0, 0.15], pen=0.4)
                    plt.addItem(vt)

                # Link all plots along x axis
                plt.setXLink(self.plots[0, 0])

                if i == j:
                    # link y axes of all diagonal plots
                    plt.setYLink(self.plots[0, 0])
                else:
                    # link y axes of all plots within a row
                    plt.setYLink(self.plots[i, (i+1) % 2])  # (i+1)%2 just avoids linking to 0,0

                if i < n_channels - 1:
                    plt.getAxis('bottom').setVisible(False)
                if j > 0:
                    plt.getAxis('left').setVisible(False)

                if i == n_channels - 1:
                    plt.setLabels(bottom=('CH%d'%chans[j], 's'))
                if j == 0:
                    plt.setLabels(left=('CH%d'%chans[i], 'A' if modes[i] == 'vc' else 'V'))

        if auto_range:
            # Row 0 (plots [0, 0] and [0, 1]) shows channel 0's data, so its
            # clamp mode determines the units of those y ranges.
            mode0 = modes[0]
            if n_channels > 1:
                r = 14e-12 if mode0 == 'vc' else 5e-3
                self.plots[0, 1].setYRange(-r, r)
            r = 2e-9 if mode0 == 'vc' else 100e-3
            self.plots[0, 0].setYRange(-r, r)

            if t is not None:
                self.plots[0, 0].setXRange(t[0], t[-1])
Exemplo n.º 7
0
def load_nwb(filename):
    """Open *filename* as a MiesNwb and publish it in three places.

    The file is stored in the module global ``nwb``, handed to the viewer
    ``v``, and exposed as ``nwb`` in the interactive console namespace.
    """
    global nwb
    opened = MiesNwb(filename)
    nwb = opened
    v.set_nwb(opened)
    console.localNamespace['nwb'] = opened
Exemplo n.º 8
0
# Scratch script. It loads the same MIES NWB file both through the
# Dataset/MiesNwbLoader API and through the older MiesNwb class. It also loads
# an opto NWB file and an ACQ4 directory. Apparently this is for comparing and
# profiling the loaders interactively (see the profiling notes at the end).
# NOTE(review): Dataset, MiesNwbLoader and Acq4DatasetLoader are used below but
# are not imported here -- presumably they come from neuroanalysis; add the
# imports (or confirm they are defined elsewhere) before running.
# NOTE(review): the data paths are hard-coded to one developer's machine.
from optoanalysis.analyzers import OptoBaselineAnalyzer
from aisynphys.analyzers import MPBaselineAnalyzer
from neuroanalysis.miesnwb import MiesNwb
import pyqtgraph as pg

pg.dbg()


f = "/Users/meganbkratz/Code/ai/example_data/data/2019-06-13_000/slice_000/site_000/2019_06_13_exp1_TH-compressed.nwb"
f2 = "/Users/meganbkratz/Code/ai/example_data/2019_06_24_131623-compressed.nwb"
f3 = "/Users/meganbkratz/Documents/ManisLab/L4Mapping/ExcitationProfileData/2012.11.09_000/slice_000/cell_004"

#hdf = h5py.File(f, 'r')

# The same file (f2) opened with the loader-based Dataset API and with MiesNwb.
mies_nwb = Dataset(loader=MiesNwbLoader(f2, baseline_analyzer_class=MPBaselineAnalyzer))
mies_nwb_old = MiesNwb(f2)
opto_nwb = Dataset(loader=MiesNwbLoader(f, baseline_analyzer_class=OptoBaselineAnalyzer))
acq4_dataset = Dataset(loader=Acq4DatasetLoader(f3))

#old = OptoNwb(f)


### for profiling lazy load stimulus vs stimulus
# prof = pg.debug.Profiler(disabled=False)

# for srec in mies_nwb.contents:
#     recs = srec.recordings

# prof('made recordings')

# for srec in mies_nwb.contents:
Exemplo n.º 9
0
    def load_experiment(self, nwb_handle):
        """Load the MIES NWB file behind *nwb_handle*, plot QC metrics and add pipettes.

        Recordings are grouped per device (channel). Each channel gets one plot
        row showing QC metrics against time since the earliest recording.
        Existing pipettes are then removed and re-added from the
        ``Headstage N`` entries of ``nwb_handle.parent().info()``.

        Parameters
        ----------
        nwb_handle
            Handle whose ``name()`` is the NWB file path and whose
            ``parent().info()`` maps ``'Headstage N'`` to a seal-state code
            (presumably an ACQ4 data-manager handle -- confirm).
        """
        self.nwb_handle = nwb_handle
        self.nwb = MiesNwb(nwb_handle.name())

        # load all recordings
        # recs maps each device/channel to its recordings, in the order of
        # self.nwb.contents.
        recs = {}
        for srec in self.nwb.contents:
            for chan in srec.devices:
                recs.setdefault(chan, []).append(srec[chan])

        chans = sorted(recs.keys())
        self.channels = chans
        # One plot row per channel, all x-linked to the first row.
        self.plots.set_shape(len(chans), 1)
        self.plots.setXLink(self.plots[0, 0])

        # find time of first recording
        start_time = min([rec[0].start_time for rec in recs.values()])
        self.start_time = start_time
        end_time = max([rec[-1].start_time for rec in recs.values()])
        # NOTE(review): timedelta.seconds is only the seconds component (it
        # drops days and fractional seconds); total_seconds() may be intended.
        self.plots.setXRange(0, (end_time - start_time).seconds)

        # plot all recordings
        for i, chan in enumerate(chans):
            n_recs = len(recs[chan])
            times = np.empty(n_recs)
            i_hold = np.empty(n_recs)
            v_hold = np.empty(n_recs)
            v_noise = np.empty(n_recs)
            i_noise = np.empty(n_recs)

            # load QC metrics for all recordings
            # The RMS noise is stored as current noise in 'vc' mode and as
            # voltage noise otherwise; the other slot is NaN (not plotted).
            for j, rec in enumerate(recs[chan]):
                dt = (rec.start_time - start_time).seconds
                times[j] = dt
                v_hold[j] = rec.baseline_potential
                i_hold[j] = rec.baseline_current
                if rec.clamp_mode == 'vc':
                    v_noise[j] = np.nan
                    i_noise[j] = rec.baseline_rms_noise
                else:
                    v_noise[j] = rec.baseline_rms_noise
                    i_noise[j] = np.nan

            # scale all qc metrics so that |value| > 1.0 marks a failing point
            # (drawn with fail_brush below). Assuming SI units (V, A), the
            # passing ranges are: holding -80..-40 mV, |holding current|
            # <= 400 pA, voltage noise <= 5 mV, current noise <= 100 pA.
            pass_brush = pg.mkBrush(100, 100, 255, 200)
            fail_brush = pg.mkBrush(255, 0, 0, 200)
            v_hold = (v_hold + 60e-3) / 20e-3
            i_hold = i_hold / 400e-12
            v_noise = v_noise / 5e-3
            i_noise = i_noise / 100e-12

            plt = self.plots[i, 0]
            plt.setLabels(left=("Ch %d" % chan))
            # The zero row ('o') marks when each recording was made.
            for data, symbol in [(np.zeros_like(times), 'o'), (v_hold, 't'),
                                 (i_hold, 'x'), (v_noise, 't1'),
                                 (i_noise, 'x')]:
                brushes = np.where(np.abs(data) > 1.0, fail_brush, pass_brush)
                plt.plot(times,
                         data,
                         pen=None,
                         symbol=symbol,
                         symbolPen=None,
                         symbolBrush=brushes)

        # automatically select electrode regions
        self.remove_pipettes()
        site_info = self.nwb_handle.parent().info()
        for i in self.channels:
            hs_state = site_info.get('Headstage %d' % i, None)
            if hs_state is None:
                continue
            # NOTE(review): a state code not listed here raises KeyError.
            status = {
                'NS': 'No seal',
                'LS': 'Low seal',
                'GS': 'GOhm seal',
                'TF': 'Technical failure',
            }[hs_state]
            # Pipette time span: first to last recording, padded by 1 s each side.
            start = (recs[i][0].start_time - start_time).seconds - 1
            stop = (recs[i][-1].start_time - start_time).seconds + 1

            # assume if we got more than two recordings, then a cell was present.
            got_cell = len(recs[i]) > 2

            self.add_pipette(i, start, stop, status=status, got_cell=got_cell)