Example #1
    def populate(self):
        self.items_by_pair_id = {}
        
        self.session = db.Session()
        db_expts = db.list_experiments(session=self.session)
        db_expts.sort(key=lambda e: e.acq_timestamp)
        for expt in db_expts:
            date = expt.acq_timestamp
            date_str = datetime.fromtimestamp(date).strftime('%Y-%m-%d')
            slice = expt.slice
            expt_item = pg.TreeWidgetItem(list(map(str, [date_str, expt.rig_name, slice.species, expt.target_region, slice.genotype, expt.acsf])))
            expt_item.expt = expt
            self.addTopLevelItem(expt_item)

            for pair in expt.pairs:
                if pair.n_ex_test_spikes == 0 and pair.n_in_test_spikes == 0:
                    continue
                cells = '%d => %d' % (pair.pre_cell.ext_id, pair.post_cell.ext_id)
                conn = {True:"syn", False:"-", None:"?"}[pair.synapse]
                types = 'L%s %s => L%s %s' % (pair.pre_cell.target_layer or "?", pair.pre_cell.cre_type, pair.post_cell.target_layer or "?", pair.post_cell.cre_type)
                pair_item = pg.TreeWidgetItem([cells, conn, types])
                expt_item.addChild(pair_item)
                pair_item.pair = pair
                pair_item.expt = expt
                self.items_by_pair_id[pair.id] = pair_item
def update_DB(limit=None,
              expts=None,
              parallel=True,
              workers=6,
              raise_exceptions=False,
              session=None):
    """
    """
    session = db.Session()
    if expts is None:
        experiments = session.query(db.Experiment.acq_timestamp).all()
        expts_done = session.query(db.Experiment.acq_timestamp).join(
            db.Pair).join(AvgFirstPulseFit).all()
        print("Skipping %d already complete experiments" % (len(expts_done)))
        experiments = [e for e in experiments if e not in set(expts_done)]

        if limit is not None and limit > 0:
            np.random.shuffle(experiments)
            experiments = experiments[:limit]

        jobs = [(record.acq_timestamp, index, len(experiments))
                for index, record in enumerate(experiments)]
    else:
        jobs = [(expt, i, len(expts)) for i, expt in enumerate(expts)]
    if parallel:
        # fan the jobs out across worker processes
        # (assumes functools is imported at module level, alongside multiprocessing)
        fit_job = functools.partial(compute_fit, raise_exceptions=raise_exceptions)
        pool = multiprocessing.Pool(processes=workers)
        pool.map(fit_job, jobs)
    else:
        for job in jobs:
            compute_fit(job, raise_exceptions=raise_exceptions)
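
A minimal invocation sketch; the argument values are illustrative, not from the source:

if __name__ == '__main__':
    # fit a random subset of 10 experiments serially, stopping on the first error
    update_DB(limit=10, parallel=False, raise_exceptions=True)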
Example #3
def update_analysis(limit=None):
    s = db.Session()
    q = s.query(db.Pair,
                FirstPulseFeatures).outerjoin(FirstPulseFeatures).filter(
                    FirstPulseFeatures.pair_id == None)
    if limit is not None:
        q = q.limit(limit)
    print("Updating %d pairs.." % q.count())
    records = q.all()
    for i, record in enumerate(records):
        pair = record[0]
        pulse_responses, pulse_ids, pulse_response_amps = filter_pulse_responses(
            pair)
        if len(pulse_responses) > 0:
            results = first_pulse_features(pair, pulse_responses,
                                           pulse_response_amps)
            fpf = FirstPulseFeatures(pair=pair,
                                     n_sweeps=len(pulse_ids),
                                     pulse_ids=pulse_ids,
                                     **results)
            s.add(fpf)
            if i % 10 == 0:
                s.commit()
                print("%d pairs added to the DB of %d" % (i, len(records)))
    s.commit()
    s.close()
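
A usage sketch, assuming the FirstPulseFeatures table already exists; the limit value is illustrative:

# analyze at most 100 pairs that have no FirstPulseFeatures record yet
update_analysis(limit=100)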
def build_detection_limits():
    # silence warnings about fp issues
    np.seterr(all='ignore')

    # read all pair records from DB
    classifier = strength_analysis.get_pair_classifier(seed=0,
                                                       use_vc_features=False)
    conns = strength_analysis.query_all_pairs(classifier)

    # filter
    mask = np.isfinite(conns['ic_deconv_amp_mean'])
    filtered = conns[mask]

    # remove recordings with gain errors
    mask = filtered['ic_deconv_amp_mean'] < 0.02

    # remove recordings with high crosstalk
    mask &= abs(filtered['ic_crosstalk_mean']) < 60e-6

    # remove recordings with low sample count
    mask &= filtered['ic_n_samples'] > 50

    typs = filtered['pre_cre_type']
    mask &= typs == filtered['post_cre_type']

    typ_mask = ((typs == 'sim1') | (typs == 'tlx3') | (typs == 'unknown') |
                (typs == 'rorb') | (typs == 'ntsr1'))
    mask &= typ_mask

    filtered = filtered[mask]

    c_mask = filtered['synapse'] == True
    u_mask = ~c_mask

    signal = filtered['confidence']
    background = filtered['ic_base_deconv_amp_mean']

    session = db.Session()

    # do selected connections first
    count = 0
    for i, rec in enumerate(filtered):
        print("================== %d/%d ===================== " %
              (i, len(filtered)))
        pair = session.query(
            db.Pair).filter(db.Pair.id == rec['pair_id']).all()[0]
        if pair.detection_limit is not None:
            print("    skip!")
            continue
        try:
            measure_limit(pair, session, classifier)
        except Exception:
            sys.excepthook(*sys.exc_info())

        count += 1
        if count > 100:
            print("Bailing out before memory fills up.")
            sys.exit(0)
def query_all_pairs():
    import pandas
    query = ("""
    select """
        # ((DATE_PART('day', experiment.acq_timestamp - '1970-01-01'::timestamp) * 24 + 
        # DATE_PART('hour', experiment.acq_timestamp - '1970-01-01'::timestamp)) * 60 +
        # DATE_PART('minute', experiment.acq_timestamp - '1970-01-01'::timestamp)) * 60 +
        # DATE_PART('second', experiment.acq_timestamp - '1970-01-01'::timestamp) as acq_timestamp,
        """
        connection_strength.*,
        experiment.id as experiment_id,
        experiment.acq_timestamp as acq_timestamp,
        ABS(connection_strength.amp_med) as abs_amp_med,
        ABS(connection_strength.base_amp_med) as abs_base_amp_med,
        ABS(connection_strength.amp_med_minus_base) as abs_amp_med_minus_base,
        ABS(connection_strength.deconv_amp_med) as abs_deconv_amp_med,
        ABS(connection_strength.deconv_base_amp_med) as abs_deconv_base_amp_med,
        ABS(connection_strength.deconv_amp_med_minus_base) as abs_deconv_amp_med_minus_base,
        experiment.rig_name,
        experiment.acsf,
        slice.species,
        slice.genotype,
        slice.age,
        slice.slice_time,
        pre_cell.ext_id as pre_cell_id,
        pre_cell.cre_type as pre_cre_type,
        pre_cell.target_layer as pre_target_layer,
        post_cell.ext_id as post_cell_id,
        post_cell.cre_type as post_cre_type,
        post_cell.target_layer as post_target_layer,
        pair.synapse,
        pair.crosstalk_artifact,
        abs(post_cell.ext_id - pre_cell.ext_id) as electrode_distance
    from connection_strength
    join pair on connection_strength.pair_id=pair.id
    join cell pre_cell on pair.pre_cell_id=pre_cell.id
    join cell post_cell on pair.post_cell_id=post_cell.id
    join experiment on pair.expt_id=experiment.id
    join slice on experiment.slice_id=slice.id
    order by acq_timestamp
    """)
    session = db.Session()
    df = pandas.read_sql(query, session.bind)

    # test out a new metric:
    df['connection_signal'] = pandas.Series(df['deconv_amp_med'] / df['latency_stdev'], index=df.index)
    df['connection_background'] = pandas.Series(df['deconv_base_amp_med'] / df['base_latency_stdev'], index=df.index)

    ts = [datetime_to_timestamp(t) for t in df['acq_timestamp']]
    df['acq_timestamp'] = ts
    recs = df.to_records()
    return recs
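
The returned record array can be masked with numpy field comparisons, as build_detection_limits does above; a short sketch:

recs = query_all_pairs()
connected = recs[recs['synapse'] == True]  # keep only pairs flagged as connected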
Example #6
    def __init__(self):
        pg.QtGui.QWidget.__init__(self)
        self.layout = pg.QtGui.QVBoxLayout()
        self.setLayout(self.layout)
        self.update_button = pg.QtGui.QPushButton("Update Matrix")
        self.layout.addWidget(self.update_button)
        self.project_list = pg.QtGui.QListWidget()
        self.layout.addWidget(self.project_list)
        s = db.Session()
        projects = s.query(db.Experiment.project_name).distinct().all()
        for record in projects:
            project = record[0]
            project_item = pg.QtGui.QListWidgetItem(project)
            project_item.setFlags(project_item.flags() | pg.QtCore.Qt.ItemIsUserCheckable)
            project_item.setCheckState(pg.QtCore.Qt.Unchecked)
            self.project_list.addItem(project_item)
Example #7
    def poll(self):

        self.session = database.Session()
        expts = {}

        print("loading site paths..")
        #site_paths = glob.glob(os.path.join(config.synphys_data, '*', 'slice_*', 'site_*'))

        root_dh = getDirHandle(config.synphys_data)

        print(root_dh.name())
        for site_dh in self.list_expts(root_dh):
            expt = ExperimentMetadata(nas_path=site_dh.name())
            if expt.timestamp in expts:
                continue
            expts[expt.timestamp] = expt
            self.check(expt)
            if self.limit > 0 and len(expts) > self.limit:
                break
Example #8
    def __init__(self):
        pg.QtCore.QObject.__init__(self)

        # global session for querying from DB
        self.session = db.Session()

        win = pg.QtGui.QSplitter()
        win.setOrientation(pg.QtCore.Qt.Horizontal)
        win.resize(1000, 800)
        win.show()
        
        b = ExperimentBrowser()
        win.addWidget(b)
        
        rs_plots = ResponseStrengthPlots(self.session)
        win.addWidget(rs_plots)

        b.itemSelectionChanged.connect(self._selected)            
        b.doubleClicked.connect(self._dbl_clicked)

        self.win = win
        self.rs_plots = rs_plots
        self.browser = b
        self.nwb_viewer = MultipatchNwbViewer()
Example #9
    # strong ex with failures
    expt_id = 1537820585.767
    pre_cell_id = 1
    post_cell_id = 2

    # strong ex, depressing
    expt_id = 1536781898.381
    pre_cell_id = 8
    post_cell_id = 2

    # expt_id = float(sys.argv[1])
    # pre_cell_id = int(sys.argv[2])
    # post_cell_id = int(sys.argv[3])

    session = db.Session()

    expt = db.experiment_from_timestamp(expt_id, session=session)
    pair = expt.pairs[(pre_cell_id, post_cell_id)]

    syn_type = pair.connection_strength.synapse_type

    # 1. Get a list of all presynaptic spike times and the amplitudes of postsynaptic responses

    raw_events = get_amps(session, pair, clamp_mode='ic')
    mask = event_qc(raw_events)
    events = raw_events[mask]

    rec_times = 1e-9 * (events['rec_start_time'].astype(float) -
                        float(events['rec_start_time'][0]))
    spike_times = events['max_dvdt_time'] + rec_times
Example #10
def query_all_pairs(classifier=None):
    columns = [
        "connection_strength.*",
        "experiment.id as experiment_id",
        "experiment.acq_timestamp as acq_timestamp",
        "experiment.rig_name",
        "experiment.acsf",
        "slice.species as donor_species",
        "slice.genotype as donor_genotype",
        "slice.age as donor_age",
        "slice.sex as donor_sex",
        "slice.quality as slice_quality",
        "slice.weight as donor_weight",
        "slice.slice_time",
        "pre_cell.ext_id as pre_cell_id",
        "pre_cell.cre_type as pre_cre_type",
        "pre_cell.target_layer as pre_target_layer",
        "post_cell.ext_id as post_cell_id",
        "post_cell.cre_type as post_cre_type",
        "post_cell.target_layer as post_target_layer",
        "pair.synapse",
        "pair.distance",
        "pair.crosstalk_artifact",
        "abs(post_cell.ext_id - pre_cell.ext_id) as electrode_distance",
    ]
    # columns.extend([
    #     "detection_limit.minimum_amplitude",
    # ])

    joins = [
        "join pair on connection_strength.pair_id=pair.id",
        "join cell pre_cell on pair.pre_cell_id=pre_cell.id",
        "join cell post_cell on pair.post_cell_id=post_cell.id",
        "join experiment on pair.experiment_id=experiment.id",
        "join slice on experiment.slice_id=slice.id",
    ]
    # joins.extend([
    #     "left join detection_limit on detection_limit.pair_id=pair.id",
    # ])


    query = ("""
    select 
    {columns}
    from connection_strength
    {joins}
    order by acq_timestamp
    """).format(
        columns=", ".join(columns), 
        joins=" ".join(joins),
    )

    session = db.Session()
    df = pandas.read_sql(query, session.bind)

    recs = df.to_records()

    if classifier is None:
        return recs

    # Fit classifier and add results of classifier prediction in to records
    classifier.fit(recs)
    prediction = classifier.predict(recs)
    recs = join_struct_arrays([recs, prediction])
    return recs
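
This variant is the one build_detection_limits relies on; a usage sketch mirroring that call site:

classifier = strength_analysis.get_pair_classifier(seed=0, use_vc_features=False)
recs = query_all_pairs(classifier)  # records now carry the classifier's predictions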
Example #11
from __future__ import division
import time, datetime
import numpy as np
import pyqtgraph as pg
import multipatch_analysis.database.database as db
from neuroanalysis.ui.plot_grid import PlotGrid
s = db.Session()


q = """
    select 
        pcrec.baseline_rms_noise,
        substring(experiment.original_path from 36 for 1),
        recording.device_key,
        recording.start_time
    from 
        patch_clamp_recording pcrec
        join recording on pcrec.recording_id=recording.id
        join sync_rec on recording.sync_rec_id=sync_rec.id
        join experiment on sync_rec.experiment_id=experiment.id
    where
        pcrec.clamp_mode='ic'
        and pcrec.baseline_rms_noise is not null
        and recording.device_key is not null
        and experiment.original_path is not null
"""


rec = s.execute(q)
rows = rec.fetchall()

Example #12
    def in_database(self):
        session = database.Session()
        expts = session.query(database.Experiment).filter(
            database.Experiment.acq_timestamp == self.datetime).all()
        return len(expts) == 1
def compute_fit(job_info, raise_exceptions=False):

    session = db.Session(readonly=False)  # create a writable session

    expt_id, index, n_jobs = job_info
    print("QUERYING (expt_id=%f): %d/%d" % (expt_id, index, n_jobs))

    # query all pairs in this experiment along with pre/post cell info
    pre_cell = db.aliased(db.Cell)
    post_cell = db.aliased(db.Cell)
    expt_stuff = session.query(db.Pair, db.Experiment.acq_timestamp,
                               pre_cell.ext_id, post_cell.ext_id,
                               pre_cell.cre_type, post_cell.cre_type)\
                        .join(db.Experiment)\
                        .join(pre_cell, db.Pair.pre_cell_id == pre_cell.id)\
                        .join(post_cell, db.Pair.post_cell_id == post_cell.id)\
                        .filter(db.Experiment.acq_timestamp == expt_id).all()
    # make sure query returned something
    if len(expt_stuff) == 0:
        print('No pairs found for expt_id=%f' % expt_id)
        return

    processed_count = 0  # track how many cell pairs in this experiment have been analyzed
    for ii, (pair, uid, pre_cell_id, post_cell_id, pre_cell_cre,
             post_cell_cre) in enumerate(expt_stuff):

        print("Number %i of %i experiment pairs: %0.3f, cell ids:%s %s" %
              (ii, len(expt_stuff), uid, pre_cell_id, post_cell_id))

        # grab the synapse type from the connection_strength table
        try:
            excitation = pair.connection_strength.synapse_type
        except AttributeError:
            print('\tskipping: no pair.connection_strength.synapse_type')
            continue

        if not pair.connection_strength.ic_fit_xoffset:
            print('\tskipping: no latency available for forced-latency fitting')
            continue
        xoffset = pair.connection_strength.ic_fit_xoffset

        # -----------fit current clamp data---------------------
        # get pulses
        (pulse_responses_i, pulse_ids_i, psp_amps_measured_i, freq, avg_psp_i,
         measured_relative_amp_i,
         measured_baseline_i) = get_average_pulse_response(pair,
                                                           desired_clamp='ic')

        if pulse_responses_i:
            # weight and fit the trace
            weight_i = np.ones(len(avg_psp_i.data)) * 10.  # set everything to ten initially
            # zero out the region around the stimulus artifact; since traces are
            # spike-aligned, there is some blur in where the crosstalk lands
            weight_i[int((time_before_spike - 3e-3) / avg_psp_i.dt):
                     int(time_before_spike / avg_psp_i.dt)] = 0.
            # up-weight the region around the steep PSP rise
            weight_i[int((time_before_spike + .0001 + xoffset) / avg_psp_i.dt):
                     int((time_before_spike + .0001 + xoffset + 4e-3) / avg_psp_i.dt)] = 30.
            avg_fit_i = fit_trace(avg_psp_i,
                                  excitation=excitation,
                                  weight=weight_i,
                                  latency=xoffset,
                                  latency_jitter=.5e-3)
            latency_i = avg_fit_i.best_values['xoffset'] - time_before_spike
            amp_i = avg_fit_i.best_values['amp']
            rise_time_i = avg_fit_i.best_values['rise_time']
            decay_tau_i = avg_fit_i.best_values['decay_tau']
            avg_data_waveform_i = avg_psp_i.data
            avg_fit_waveform_i = avg_fit_i.best_fit
            dt_i = avg_psp_i.dt
            nrmse_i = avg_fit_i.nrmse()
        else:
            print(
                '\tskipping: no suitable first pulses found in current clamp')
            weight_i = np.array([0])
            latency_i = None
            amp_i = None
            rise_time_i = None
            decay_tau_i = None
            avg_data_waveform_i = np.array([0])
            avg_fit_waveform_i = np.array([0])
            dt_i = None
            nrmse_i = None
        # --------------fit voltage clamp data---------------------
        # get pulses
        (pulse_responses_v, pulse_ids_v, psp_amps_measured_v, freq_v,
         avg_psp_v, measured_relative_amp_v,
         measured_baseline_v) = get_average_pulse_response(pair,
                                                           desired_clamp='vc')

        if pulse_responses_v:
            # weight and fit the trace
            weight_v = np.ones(len(avg_psp_v.data)) * 10.  # set everything to ten initially
            # up-weight the region around the steep PSP rise
            weight_v[int((time_before_spike + .0001 + xoffset) / avg_psp_v.dt):
                     int((time_before_spike + .0001 + xoffset + 4e-3) / avg_psp_v.dt)] = 30.
            avg_fit_v = fit_trace(avg_psp_v,
                                  excitation=excitation,
                                  clamp_mode='vc',
                                  weight=weight_v,
                                  latency=xoffset,
                                  latency_jitter=.5e-3)
            latency_v = avg_fit_v.best_values['xoffset'] - time_before_spike
            amp_v = avg_fit_v.best_values['amp']
            rise_time_v = avg_fit_v.best_values['rise_time']
            decay_tau_v = avg_fit_v.best_values['decay_tau']
            avg_data_waveform_v = avg_psp_v.data
            avg_fit_waveform_v = avg_fit_v.best_fit
            dt_v = avg_psp_v.dt
            nrmse_v = avg_fit_v.nrmse()

        else:
            print(
                '\tskipping: no suitable first pulses found in voltage clamp')
            weight_v = np.array([0])
            latency_v = None
            amp_v = None
            rise_time_v = None
            decay_tau_v = None
            avg_data_waveform_v = np.array([0])
            avg_fit_waveform_v = np.array([0])
            dt_v = None
            nrmse_v = None
        #------------ done with fitting section ------------------------------

        # dictionary for ease of translation into the output table
        out_dict = {
            'ic_amp': amp_i,
            'ic_latency': latency_i,
            'ic_rise_time': rise_time_i,
            'ic_decay_tau': decay_tau_i,
            'ic_avg_psp_data': avg_data_waveform_i,
            'ic_avg_psp_fit': avg_fit_waveform_i,
            'ic_dt': dt_i,
            'ic_pulse_ids': pulse_ids_i,
            'ic_nrmse': nrmse_i,
            'ic_measured_baseline': measured_baseline_i,
            'ic_measured_amp': measured_relative_amp_i,
            'ic_weight': weight_i,
            'vc_amp': amp_v,
            'vc_latency': latency_v,
            'vc_rise_time': rise_time_v,
            'vc_decay_tau': decay_tau_v,
            'vc_avg_psp_data': avg_data_waveform_v,
            'vc_avg_psp_fit': avg_fit_waveform_v,
            'vc_dt': dt_v,
            'vc_pulse_ids': pulse_ids_v,
            'vc_nrmse': nrmse_v,
            'vc_measured_baseline': measured_baseline_v,
            'vc_measured_amp': measured_relative_amp_v,
            'vc_weight': weight_v
        }
        # map to pair table and commit
        afpf = AvgFirstPulseFit(pair=pair, **out_dict)
        if commiting:  # module-level flag (defined elsewhere) that enables DB writes
            session.add(afpf)
            session.commit()

            # TODO: I guess I need to query expt before I do this...
            # expt.meta = expt.meta.copy()  # required by sqlalchemy to flag as modified
            # expt.meta['avg_first_pulse_fit_timestamp'] = time.time()

        print("COMMITED %i pairs from expt_id=%f: %d/%d" %
              (processed_count, expt_id, index, n_jobs))
        #---------------------------------------------------------------------------------------
        processed_count += 1
        print('processed', processed_count)
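
compute_fit can also be run on a single experiment by handing it an (acq_timestamp, index, n_jobs) job tuple of the kind update_DB builds; the timestamp below is one that appears in Example #9:

# fit a single experiment, raising errors instead of swallowing them
compute_fit((1537820585.767, 0, 1), raise_exceptions=True)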