class PairView(QtGui.QWidget):
    """For analyzing pre/post-synaptic pairs.

    Plots the presynaptic and (filtered) postsynaptic traces of the selected
    sweeps, detects evoked presynaptic spikes, averages and fits the
    postsynaptic response following each pulse, and lists the fit parameters
    in a table. Event sets can be stored with "add" and fed to a release
    model with "fit all".
    """

    def __init__(self, parent=None):
        self.sweeps = []
        self.channels = []
        # (pre_channel, post_channel, event_rows, sweeps) from the most recent
        # analysis, or None when the last analysis produced nothing.
        self.current_event_set = None
        # event sets stored by the user via the "add" button
        self.event_sets = []

        QtGui.QWidget.__init__(self, parent)

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)

        self.vsplit = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.layout.addWidget(self.vsplit, 0, 0)

        self.pre_plot = pg.PlotWidget()
        self.post_plot = pg.PlotWidget()
        for plt in (self.pre_plot, self.post_plot):
            plt.setClipToView(True)
            plt.setDownsampling(True, True, 'peak')

        self.post_plot.setXLink(self.pre_plot)
        self.vsplit.addWidget(self.pre_plot)
        self.vsplit.addWidget(self.post_plot)

        self.response_plots = PlotGrid()
        self.vsplit.addWidget(self.response_plots)

        self.event_splitter = QtGui.QSplitter(QtCore.Qt.Horizontal)
        self.vsplit.addWidget(self.event_splitter)

        self.event_table = pg.TableWidget()
        self.event_splitter.addWidget(self.event_table)

        self.event_widget = QtGui.QWidget()
        self.event_widget_layout = QtGui.QGridLayout()
        self.event_widget.setLayout(self.event_widget_layout)
        self.event_splitter.addWidget(self.event_widget)

        self.event_splitter.setSizes([600, 100])

        # Row 0 of this list is always the "current" (unsaved) event set.
        self.event_set_list = QtGui.QListWidget()
        self.event_widget_layout.addWidget(self.event_set_list, 0, 0, 1, 2)
        self.event_set_list.addItem("current")
        self.event_set_list.itemSelectionChanged.connect(self.event_set_selected)

        self.add_set_btn = QtGui.QPushButton('add')
        self.event_widget_layout.addWidget(self.add_set_btn, 1, 0, 1, 1)
        self.add_set_btn.clicked.connect(self.add_set_clicked)

        self.remove_set_btn = QtGui.QPushButton('del')
        self.event_widget_layout.addWidget(self.remove_set_btn, 1, 1, 1, 1)
        self.remove_set_btn.clicked.connect(self.remove_set_clicked)

        self.fit_btn = QtGui.QPushButton('fit all')
        self.event_widget_layout.addWidget(self.fit_btn, 2, 0, 1, 2)
        self.fit_btn.clicked.connect(self.fit_clicked)

        self.fit_plot = PlotGrid()

        self.fit_explorer = None

        self.artifact_remover = ArtifactRemover(user_width=True)
        self.baseline_remover = BaselineRemover()
        self.filter = SignalFilter()

        self.params = pg.parametertree.Parameter(name='params', type='group', children=[
            {'name': 'pre', 'type': 'list', 'values': []},
            {'name': 'post', 'type': 'list', 'values': []},
            self.artifact_remover.params,
            self.baseline_remover.params,
            self.filter.params,
            {'name': 'time constant', 'type': 'float', 'suffix': 's', 'siPrefix': True,
             'value': 10e-3, 'dec': True, 'minStep': 100e-6},
        ])
        self.params.sigTreeStateChanged.connect(self._update_plots)

    def data_selected(self, sweeps, channels):
        """Set the sweeps/channels to analyze, update the channel choices and re-analyze."""
        self.sweeps = sweeps
        self.channels = channels

        self.params.child('pre').setLimits(channels)
        self.params.child('post').setLimits(channels)

        self._update_plots()

    def _update_plots(self, *args):
        """Re-run the full analysis for the current sweeps and pre/post channels.

        Positional arguments (such as those emitted by ``sigTreeStateChanged``)
        are accepted and ignored.
        """
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()

        # clear all plots, including averaged responses left from a previous run
        self.pre_plot.clear()
        self.post_plot.clear()
        self.response_plots.clear()

        pre = self.params['pre']
        post = self.params['post']

        # If there are no selected sweeps or channels have not been set, return
        if len(sweeps) == 0 or pre == post or pre not in self.channels or post not in self.channels:
            return

        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot), (post, post_mode, self.post_plot)]:
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units), bottom=("Time", 's'))

        # Iterate over selected channels of all sweeps, plotting traces one at a time.
        # Collect information about pulses and spikes.
        pulses = []       # per sweep: sample indices of pulse onsets
        spikes = []       # per sweep: one detect_evoked_spike() result (or None) per pulse
        post_traces = []  # per sweep: filtered postsynaptic trace
        for i, sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']
            dt = pre_trace.dt

            # Detect pulse times from steps in the presynaptic command waveform
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data
            post_filt = self.artifact_remover.process(post_trace, list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)

            # plot raw data
            color = pg.intColor(i, hues=len(sweeps) * 1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot), (post_filt, self.post_plot)]:
                plot.plot(trace.time_values, trace.data, pen=color, antialias=False)

            # detect spike times (one entry per pulse, None where no spike was found)
            spike_info = [detect_evoked_spike(sweep[pre], [on, off]) for on, off in zip(on_times, off_times)]
            spikes.append(spike_info)

            # mark spike rise times on the presynaptic plot
            rise_inds = [None if s is None else s['rise_index'] for s in spike_info]
            vticks = pg.VTickGroup([x * dt for x in rise_inds if x is not None], yrange=[0.0, 0.2], pen=color)
            self.pre_plot.addItem(vticks)

        # Iterate over pulses, plotting and fitting the average response to each
        avg_responses = []
        fits = []
        fit_pulse_ids = []  # pulse index that each entry of `fits` belongs to
        npulses = max(map(len, pulses))
        self.response_plots.set_shape(1, npulses + 1)  # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0, 0])
        for i in range(1, npulses + 1):
            self.response_plots[0, i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0, 0].setLabels(left=("Averaged events (Channel %d)" % post, units))

        fit_pen = {'color': (30, 30, 255), 'width': 2, 'dash': [1, 1]}
        # don't take more than 40 ms for any response (dt is that of the last sweep)
        max_len = int(40e-3 / dt)
        for i in range(npulses):
            # get the chunk of each sweep between this spike and the next detected one
            responses = []
            for j, sweep_spikes in enumerate(spikes):
                if i >= len(sweep_spikes):
                    continue
                spike = sweep_spikes[i]
                if spike is None:
                    continue

                next_spike = next((sp for sp in sweep_spikes[i + 1:] if sp is not None), None)

                start = spike['rise_index']
                if next_spike is None:
                    stop = start + max_len
                else:
                    stop = min(start + max_len, next_spike['rise_index'])

                responses.append(post_traces[j].data[start:stop].copy())

            if len(responses) == 0:
                continue

            # average the (possibly unequal-length) responses, then subtract the
            # baseline estimated from the first 1 ms
            avg = ragged_mean(responses, method='clip')
            avg -= float_mode(avg[:int(1e-3 / dt)])
            avg_responses.append(avg)

            # plot average response for this pulse
            t = np.arange(len(avg)) * dt
            self.response_plots[0, i].plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append(fit)
            fit_pulse_ids.append(i)

            # let the user mess with this fit
            curve = self.response_plots[0, i].plot(t, fit.eval(), pen=fit_pen, antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)

        # display global average (only possible when at least one pulse had responses)
        global_fit = None
        if len(avg_responses) > 0:
            global_avg = ragged_mean(avg_responses, method='clip')
            t = np.arange(len(global_avg)) * dt
            self.response_plots[0, -1].plot(t, global_avg, pen='w', antialias=True)
            global_fit = self.fit_psp(global_avg, t, dt, post_mode)
            self.response_plots[0, -1].plot(t, global_fit.eval(), pen=fit_pen, antialias=True)

        # display fit parameters in table: one row per fitted pulse, then the average
        events = []
        for pulse_id, f in zip(fit_pulse_ids, fits):
            if f is None:
                continue
            # spike lists may differ in length between sweeps; skip sweeps lacking this pulse
            spt = [s[pulse_id]['peak_index'] * dt for s in spikes
                   if pulse_id < len(s) and s[pulse_id] is not None]
            events.append(self._event_row(pulse_id, np.mean(spt), np.std(spt), f))
        if global_fit is not None:
            events.append(self._event_row('avg', np.nan, np.nan, global_fit))

        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        self.event_set_selected()

    @staticmethod
    def _event_row(event_id, spike_time, spike_stdev, fit):
        """Build one event-table row: id/spike-timing columns followed by the fit's best values."""
        row = OrderedDict([('id', event_id), ('spike_time', spike_time), ('spike_stdev', spike_stdev)])
        row.update((k, fit.best_values[k]) for k in fit.params.keys())
        return row

    def fit_psp(self, data, t, dt, clamp_mode):
        """Fit a single PSP model to an averaged response.

        Parameters
        ----------
        data : array
            Averaged response samples.
        t : array
            Time values matching *data*.
        dt : float
            Sample interval; the first 1 ms of *data* is used to estimate the
            baseline, which decides the sign of the initial amplitude guess.
        clamp_mode : str
            'ic' selects a larger initial amplitude and slower kinetic bounds;
            any other value uses the default guesses.

        Returns
        -------
        The result of ``fitting.Psp().fit`` (this class uses its ``eval()``,
        ``best_values`` and ``params``).
        """
        mode = float_mode(data[:int(1e-3 / dt)])
        sign = -1 if data.mean() - mode < 0 else 1
        # NOTE(review): tuples appear to be (initial, min, max) or (value, 'fixed');
        # confirm against fitting.Psp.
        params = OrderedDict([
            ('xoffset', (2e-3, 3e-4, 5e-3)),
            ('yoffset', data[0]),
            ('amp', sign * 10e-12),
            ('rise_time', (2e-3, 50e-6, 10e-3)),
            ('decay_tau', (4e-3, 500e-6, 50e-3)),
            ('rise_power', (2.0, 'fixed')),
        ])
        if clamp_mode == 'ic':
            params['amp'] = sign * 10e-3
            params['rise_time'] = (5e-3, 50e-6, 20e-3)
            params['decay_tau'] = (15e-3, 500e-6, 150e-3)

        fit_kws = {'xtol': 1e-3, 'maxfev': 100}

        psp = fitting.Psp()
        return psp.fit(data, x=t, fit_kws=fit_kws, params=params)

    def add_set_clicked(self):
        """Store the current event set and list it as "<pre> -> <post>"."""
        ces = self.current_event_set
        # nothing to store when the last analysis produced no events
        if ces is None or len(ces[2]) == 0:
            return
        self.event_sets.append(ces)
        item = QtGui.QListWidgetItem("%d -> %d" % (ces[0], ces[1]))
        self.event_set_list.addItem(item)
        item.event_set = ces

    def remove_set_clicked(self):
        """Remove the selected stored event set (the "current" row cannot be removed)."""
        row = self.event_set_list.currentRow()
        # row 0 is the "current" set; -1 means nothing is selected
        if row <= 0:
            return
        item = self.event_set_list.takeItem(row)
        self.event_sets.remove(item.event_set)

    def event_set_selected(self):
        """Show the selected event set (or the current, unsaved one) in the table."""
        selected = self.event_set_list.selectedItems()
        if len(selected) == 0:
            return
        sel = selected[0]
        event_set = self.current_event_set if sel.text() == "current" else sel.event_set
        if event_set is None:
            self.event_table.clear()
        else:
            self.event_table.setData(event_set[2])

    def fit_clicked(self):
        """Fit the release model to all stored event sets and plot the results.

        For each set, spike times are made relative to the first event and
        scaled by 1000 (s -> ms), and amplitudes are normalized to the first
        event. The model is fit once per dynamics type, enabling mechanisms
        cumulatively ('Dep', then 'Dep' + 'Fac', ...).
        """
        from synaptic_release import ReleaseModel

        n_sets = len(self.event_sets)
        if n_sets == 0:
            return

        self.fit_plot.clear()
        self.fit_plot.show()
        self.fit_plot.set_shape(n_sets, 1)

        # replace any legend left from a previous fit; scene() is a method on graphics items
        legend = self.fit_plot[0, 0].legend
        if legend is not None:
            scene = legend.scene()
            if scene is not None:
                scene.removeItem(legend)
            self.fit_plot[0, 0].legend = None
        self.fit_plot[0, 0].addLegend()

        for i in range(1, n_sets):
            self.fit_plot[i, 0].setXLink(self.fit_plot[0, 0])

        spike_sets = []
        for i, evset in enumerate(self.event_sets):
            # select events, skipping the average row
            events = [ev for ev in evset[2] if ev['id'] != 'avg']
            x = np.array([ev['spike_time'] for ev in events])
            y = np.array([ev['amp'] for ev in events])
            x -= x[0]
            x *= 1000
            y /= y[0]
            spike_sets.append((x, y))

            self.fit_plot[i, 0].plot(x / 1000., y, pen=None, symbol='o')

        model = ReleaseModel()
        dynamics_types = ['Dep', 'Fac', 'UR', 'SMR', 'DSR']
        for k in dynamics_types:
            model.Dynamics[k] = 0

        fit_params = []
        with pg.ProgressDialog("Fitting release model..", 0, len(dynamics_types)) as dlg:
            for k in dynamics_types:
                model.Dynamics[k] = 1
                fit_params.append(model.run_fit(spike_sets))
                dlg += 1
                if dlg.wasCanceled():
                    return

        # NOTE(review): every mechanism is enabled in model.Dynamics at this point,
        # although each params set was fit with a different subset -- confirm that
        # eval() does not depend on model.Dynamics.
        max_color = len(fit_params) * 1.5
        for i, params in enumerate(fit_params):
            for j, (x, _) in enumerate(spike_sets):
                output = model.eval(x, list(params.values()))
                self.fit_plot[j, 0].plot(output[:, 0] / 1000., output[:, 1],
                                         pen=(i, max_color), name=dynamics_types[i])

    def fit_curve_clicked(self, curve):
        """Show the fit attached to a clicked curve in a (reused) FitExplorer."""
        if self.fit_explorer is None:
            self.fit_explorer = FitExplorer(curve.fit)
        else:
            self.fit_explorer.set_fit(curve.fit)
        self.fit_explorer.show()
def fit_curve_clicked(self, curve):
    """Display the fit carried by *curve* in this object's fit explorer.

    The first call builds a FitExplorer; subsequent calls reuse it and
    just swap in the clicked curve's fit before showing it again.
    """
    explorer = self.fit_explorer
    if explorer is not None:
        explorer.set_fit(curve.fit)
    else:
        explorer = FitExplorer(curve.fit)
        self.fit_explorer = explorer
    explorer.show()
def explore_fit(self, mode, holding):
    """Open a new FitExplorer on the analyzer's last fit for (mode, holding)."""
    key = (mode, holding)
    explorer = FitExplorer(self.pair_analyzer.last_fit[key])
    self.fit_explorer = explorer
    explorer.show()
class PairAnalysisWindow(pg.QtGui.QWidget):
    """Window for selecting experiments, analyzing a pair's responses and saving the result.

    Left column: experiment filters (data type, rig, operator, note hashtags)
    and the 'Set Experiments' button. Middle column: experiment/pair browser,
    analysis parameters, 'Fit Responses', fit output and 'Save Analysis'.
    Right: VC/IC plot grids and the (initially hidden) comparison panels, all
    provided by ``PairAnalysis``.
    """

    def __init__(self, default_session, notes_session):
        pg.QtGui.QWidget.__init__(self)
        self.default_session = default_session
        self.notes_session = notes_session
        self.layout = pg.QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.h_splitter = pg.QtGui.QSplitter()
        self.h_splitter.setOrientation(pg.QtCore.Qt.Horizontal)
        self.pair_analyzer = PairAnalysis()
        self.fit_explorer = None
        self.ctrl_panel = self.pair_analyzer.ctrl_panel
        self.user_params = self.ctrl_panel.user_params
        self.output_params = self.ctrl_panel.output_params
        self.ptree = ptree.ParameterTree(showHeader=False)
        self.pair_param = Parameter.create(name='Current Pair', type='str', readonly=True)
        self.ptree.addParameters(self.pair_param)
        self.ptree.addParameters(self.user_params, showTop=False)
        self.fit_ptree = ptree.ParameterTree(showHeader=False)
        self.fit_ptree.addParameters(self.output_params, showTop=False)
        self.save_btn = pg.FeedbackButton('Save Analysis')
        self.expt_btn = pg.QtGui.QPushButton('Set Experiments')
        self.fit_btn = pg.QtGui.QPushButton('Fit Responses')
        self.ic_plot = self.pair_analyzer.ic_plot
        self.vc_plot = self.pair_analyzer.vc_plot

        # experiment selection filters
        self.select_ptree = ptree.ParameterTree(showHeader=False)
        self.hash_select = Parameter.create(name='Hashtags', type='group', children=
            [{'name': 'With multiple selected:', 'type': 'list',
              'values': ['Include if any appear', 'Include if all appear'],
              'value': 'Include if any appear'}] +
            [{'name': '#', 'type': 'bool'}] +
            [{'name': ht, 'type': 'bool'} for ht in comment_hashtag[1:]])
        self.rigs = db.query(db.Experiment.rig_name).distinct().all()
        self.operators = db.query(db.Experiment.operator_name).distinct().all()
        self.rig_select = Parameter.create(name='Rig', type='group',
                                           children=[{'name': rig[0], 'type': 'bool'} for rig in self.rigs])
        self.operator_select = Parameter.create(name='Operator', type='group',
                                                children=[{'name': operator[0], 'type': 'bool'} for operator in self.operators])
        self.data_type = Parameter.create(name='Reduce data to:', type='group', children=[
            {'name': 'Pairs with data', 'type': 'bool', 'value': True},
            {'name': 'Synapse is None', 'type': 'bool'}])
        for param in [self.data_type, self.rig_select, self.operator_select, self.hash_select]:
            self.select_ptree.addParameters(param)
        self.experiment_browser = self.pair_analyzer.experiment_browser

        # layout
        self.v_splitter = pg.QtGui.QSplitter()
        self.v_splitter.setOrientation(pg.QtCore.Qt.Vertical)
        self.expt_splitter = pg.QtGui.QSplitter()
        self.expt_splitter.setOrientation(pg.QtCore.Qt.Vertical)
        self.h_splitter.addWidget(self.expt_splitter)
        self.expt_splitter.addWidget(self.select_ptree)
        self.expt_splitter.addWidget(self.expt_btn)
        self.h_splitter.addWidget(self.v_splitter)
        self.v_splitter.addWidget(self.experiment_browser)
        self.v_splitter.addWidget(self.ptree)
        self.v_splitter.addWidget(self.fit_btn)
        self.v_splitter.addWidget(self.fit_ptree)
        self.v_splitter.addWidget(self.save_btn)
        self.v_splitter.setSizes([200, 20, 20, 400, 20])
        self.h_splitter.addWidget(self.vc_plot.grid)
        self.h_splitter.addWidget(self.ic_plot.grid)
        self.fit_compare = self.pair_analyzer.fit_compare
        self.meta_compare = self.pair_analyzer.meta_compare
        self.v2_splitter = pg.QtGui.QSplitter()
        self.v2_splitter.setOrientation(pg.QtCore.Qt.Vertical)
        self.v2_splitter.addWidget(self.fit_compare)
        self.v2_splitter.addWidget(self.meta_compare)
        self.h_splitter.addWidget(self.v2_splitter)
        self.h_splitter.setSizes([100, 200, 200, 200, 500])
        self.layout.addWidget(self.h_splitter)
        self.fit_compare.hide()
        self.meta_compare.hide()
        self.setGeometry(280, 130, 1500, 900)
        self.show()

        # signal connections
        self.experiment_browser.itemSelectionChanged.connect(self.selected_pair)
        self.save_btn.clicked.connect(self.save_to_db)
        self.expt_btn.clicked.connect(self.get_expts)
        self.fit_btn.clicked.connect(self.pair_analyzer.fit_response_update)

    def save_to_db(self):
        """Save the current analysis, flashing the button green on success or red on failure.

        Exceptions are re-raised after the failure feedback is shown.
        """
        try:
            self.pair_analyzer.save_to_db()
            self.save_btn.success()
        except Exception:
            self.save_btn.failure()
            raise

    def get_expts(self):
        """Query experiments matching all active filters and load them into the browser.

        Filters with nothing selected (rigs, operators, hashtags) are not applied.
        """
        expt_query = db.query(db.Experiment)

        if self.data_type['Synapse is None']:
            # SQLAlchemy needs `== None` (not `is None`) to emit IS NULL
            subquery = db.query(db.Pair.experiment_id).filter(db.Pair.has_synapse == None).subquery()  # noqa: E711
            expt_query = expt_query.filter(db.Experiment.id.in_(subquery))

        selected_rigs = [rig.name() for rig in self.rig_select.children() if rig.value() is True]
        if selected_rigs:
            expt_query = expt_query.filter(db.Experiment.rig_name.in_(selected_rigs))

        selected_operators = [operator.name() for operator in self.operator_select.children() if operator.value() is True]
        if selected_operators:
            expt_query = expt_query.filter(db.Experiment.operator_name.in_(selected_operators))

        # children()[0] is the 'With multiple selected:' option, not a hashtag
        selected_hashtags = [ht.name() for ht in self.hash_select.children()[1:] if ht.value() is True]
        if selected_hashtags:
            timestamps = self.get_expts_hashtag(selected_hashtags)
            expt_query = expt_query.filter(db.Experiment.ext_id.in_(timestamps))

        self.set_expts(expt_query.all())

    def get_expts_hashtag(self, selected_hashtags):
        """Return the set of experiment ids whose pair-note comments match *selected_hashtags*.

        With one hashtag selected, a pair matches when its comments contain it;
        the generic '#' matches only comments that contain '#' but none of the
        specific hashtags in ``comment_hashtag[1:]``. With several selected, the
        'With multiple selected:' option chooses any- vs all-matching. Each
        matching pair is printed for reference.
        """
        multi_mode = self.hash_select['With multiple selected:']
        match_any = multi_mode == 'Include if any appear'
        match_all = multi_mode == 'Include if all appear'

        note_pairs = self.notes_session.query(notes_db.PairNotes).all()
        note_pairs.sort(key=lambda p: p.expt_id)

        pairs_to_include = []
        for p in note_pairs:
            comments = p.notes.get('comments')
            if comments is None:
                continue

            if len(selected_hashtags) == 1:
                hashtag = selected_hashtags[0]
                if hashtag == '#':
                    include = hashtag in comments and all(ht not in comments for ht in comment_hashtag[1:])
                else:
                    include = hashtag in comments
            elif len(selected_hashtags) > 1:
                present = [ht in comments for ht in selected_hashtags]
                include = (match_any and any(present)) or (match_all and all(present))
            else:
                include = False

            if include:
                print(p.expt_id, p.pre_cell_id, p.post_cell_id, comments)
                pairs_to_include.append(p)

        return {pair.expt_id for pair in pairs_to_include}

    def set_expts(self, expts):
        """Populate the experiment browser with *expts* (all pairs unless 'Pairs with data' is set)."""
        with pg.BusyCursor():
            self.experiment_browser.clear()
            if self.data_type['Pairs with data']:
                self.experiment_browser.populate(experiments=expts)
            else:
                self.experiment_browser.populate(experiments=expts, all_pairs=True)

    def selected_pair(self):
        """Load and analyze the pair selected in the browser, restoring a saved fit if one exists."""
        with pg.BusyCursor():
            self.fit_compare.hide()
            self.meta_compare.hide()
            selected = self.experiment_browser.selectedItems()
            if len(selected) != 1:
                return
            item = selected[0]
            # experiment (non-pair) rows have no `pair` attribute
            if not hasattr(item, 'pair'):
                return
            pair = item.pair

            # check to see if the pair has already been analyzed
            expt_id = pair.experiment.ext_id
            pre_cell_id = pair.pre_cell.ext_id
            post_cell_id = pair.post_cell.ext_id
            record = notes_db.get_pair_notes_record(expt_id, pre_cell_id, post_cell_id, session=self.notes_session)

            self.pair_param.setValue(pair)
            if record is None:
                self.pair_analyzer.load_pair(pair, self.default_session)
                self.pair_analyzer.analyze_responses()
            else:
                self.pair_analyzer.load_pair(pair, self.default_session, record=record)
                self.pair_analyzer.analyze_responses()
                self.pair_analyzer.load_saved_fit(record)

    def explore_fit(self, mode, holding):
        """Open a FitExplorer on the analyzer's last fit for (mode, holding)."""
        fit = self.pair_analyzer.last_fit[mode, holding]
        self.fit_explorer = FitExplorer(fit)
        self.fit_explorer.show()
class PairView(QtGui.QWidget):
    """For analyzing pre/post-synaptic pairs.

    Plots the presynaptic and (filtered) postsynaptic traces of the selected
    sweeps, detects evoked presynaptic spikes, averages and fits the
    postsynaptic response following each pulse, and lists the fit parameters
    in a table. Event sets can be stored with "add" and fed to a release
    model with "fit all".
    """

    def __init__(self, parent=None):
        self.sweeps = []
        self.channels = []
        # (pre_channel, post_channel, event_rows, sweeps) from the most recent
        # analysis, or None when the last analysis produced nothing.
        self.current_event_set = None
        # event sets stored by the user via the "add" button
        self.event_sets = []

        QtGui.QWidget.__init__(self, parent)

        self.layout = QtGui.QGridLayout()
        self.setLayout(self.layout)
        self.layout.setContentsMargins(0, 0, 0, 0)

        self.vsplit = QtGui.QSplitter(QtCore.Qt.Vertical)
        self.layout.addWidget(self.vsplit, 0, 0)

        self.pre_plot = pg.PlotWidget()
        self.post_plot = pg.PlotWidget()
        for plt in (self.pre_plot, self.post_plot):
            plt.setClipToView(True)
            plt.setDownsampling(True, True, 'peak')

        self.post_plot.setXLink(self.pre_plot)
        self.vsplit.addWidget(self.pre_plot)
        self.vsplit.addWidget(self.post_plot)

        self.response_plots = PlotGrid()
        self.vsplit.addWidget(self.response_plots)

        self.event_splitter = QtGui.QSplitter(QtCore.Qt.Horizontal)
        self.vsplit.addWidget(self.event_splitter)

        self.event_table = pg.TableWidget()
        self.event_splitter.addWidget(self.event_table)

        self.event_widget = QtGui.QWidget()
        self.event_widget_layout = QtGui.QGridLayout()
        self.event_widget.setLayout(self.event_widget_layout)
        self.event_splitter.addWidget(self.event_widget)

        self.event_splitter.setSizes([600, 100])

        # Row 0 of this list is always the "current" (unsaved) event set.
        self.event_set_list = QtGui.QListWidget()
        self.event_widget_layout.addWidget(self.event_set_list, 0, 0, 1, 2)
        self.event_set_list.addItem("current")
        self.event_set_list.itemSelectionChanged.connect(self.event_set_selected)

        self.add_set_btn = QtGui.QPushButton('add')
        self.event_widget_layout.addWidget(self.add_set_btn, 1, 0, 1, 1)
        self.add_set_btn.clicked.connect(self.add_set_clicked)

        self.remove_set_btn = QtGui.QPushButton('del')
        self.event_widget_layout.addWidget(self.remove_set_btn, 1, 1, 1, 1)
        self.remove_set_btn.clicked.connect(self.remove_set_clicked)

        self.fit_btn = QtGui.QPushButton('fit all')
        self.event_widget_layout.addWidget(self.fit_btn, 2, 0, 1, 2)
        self.fit_btn.clicked.connect(self.fit_clicked)

        self.fit_plot = PlotGrid()

        self.fit_explorer = None

        self.artifact_remover = ArtifactRemover(user_width=True)
        self.baseline_remover = BaselineRemover()
        self.filter = SignalFilter()

        self.params = pg.parametertree.Parameter(name='params', type='group', children=[
            {'name': 'pre', 'type': 'list', 'values': []},
            {'name': 'post', 'type': 'list', 'values': []},
            self.artifact_remover.params,
            self.baseline_remover.params,
            self.filter.params,
            {'name': 'time constant', 'type': 'float', 'suffix': 's', 'siPrefix': True,
             'value': 10e-3, 'dec': True, 'minStep': 100e-6},
        ])
        self.params.sigTreeStateChanged.connect(self._update_plots)

    def data_selected(self, sweeps, channels):
        """Set the sweeps/channels to analyze, update the channel choices and re-analyze."""
        self.sweeps = sweeps
        self.channels = channels

        self.params.child('pre').setLimits(channels)
        self.params.child('post').setLimits(channels)

        self._update_plots()

    def _update_plots(self, *args):
        """Re-run the full analysis for the current sweeps and pre/post channels.

        Positional arguments (such as those emitted by ``sigTreeStateChanged``)
        are accepted and ignored.
        """
        sweeps = self.sweeps
        self.current_event_set = None
        self.event_table.clear()

        # clear all plots, including averaged responses left from a previous run
        self.pre_plot.clear()
        self.post_plot.clear()
        self.response_plots.clear()

        pre = self.params['pre']
        post = self.params['post']

        # If there are no selected sweeps or channels have not been set, return
        if len(sweeps) == 0 or pre == post or pre not in self.channels or post not in self.channels:
            return

        pre_mode = sweeps[0][pre].clamp_mode
        post_mode = sweeps[0][post].clamp_mode
        for ch, mode, plot in [(pre, pre_mode, self.pre_plot), (post, post_mode, self.post_plot)]:
            units = 'A' if mode == 'vc' else 'V'
            plot.setLabels(left=("Channel %d" % ch, units), bottom=("Time", 's'))

        # Iterate over selected channels of all sweeps, plotting traces one at a time.
        # Collect information about pulses and spikes.
        pulses = []       # per sweep: sample indices of pulse onsets
        spikes = []       # per sweep: one detect_evoked_spike() result (or None) per pulse
        post_traces = []  # per sweep: filtered postsynaptic trace
        for i, sweep in enumerate(sweeps):
            pre_trace = sweep[pre]['primary']
            post_trace = sweep[post]['primary']
            dt = pre_trace.dt

            # Detect pulse times from steps in the presynaptic command waveform
            stim = sweep[pre]['command'].data
            sdiff = np.diff(stim)
            on_times = np.argwhere(sdiff > 0)[1:, 0]  # 1: skips test pulse
            off_times = np.argwhere(sdiff < 0)[1:, 0]
            pulses.append(on_times)

            # filter data
            post_filt = self.artifact_remover.process(post_trace, list(on_times) + list(off_times))
            post_filt = self.baseline_remover.process(post_filt)
            post_filt = self.filter.process(post_filt)
            post_traces.append(post_filt)

            # plot raw data
            color = pg.intColor(i, hues=len(sweeps) * 1.3, sat=128)
            color.setAlpha(128)
            for trace, plot in [(pre_trace, self.pre_plot), (post_filt, self.post_plot)]:
                plot.plot(trace.time_values, trace.data, pen=color, antialias=False)

            # detect spike times (one entry per pulse, None where no spike was found)
            spike_info = [detect_evoked_spike(sweep[pre], [on, off]) for on, off in zip(on_times, off_times)]
            spikes.append(spike_info)

            # mark spike rise times on the presynaptic plot
            rise_inds = [None if s is None else s['rise_index'] for s in spike_info]
            vticks = pg.VTickGroup([x * dt for x in rise_inds if x is not None], yrange=[0.0, 0.2], pen=color)
            self.pre_plot.addItem(vticks)

        # Iterate over pulses, plotting and fitting the average response to each
        avg_responses = []
        fits = []
        fit_pulse_ids = []  # pulse index that each entry of `fits` belongs to
        npulses = max(map(len, pulses))
        self.response_plots.set_shape(1, npulses + 1)  # 1 extra for global average
        self.response_plots.setYLink(self.response_plots[0, 0])
        for i in range(1, npulses + 1):
            self.response_plots[0, i].hideAxis('left')
        units = 'A' if post_mode == 'vc' else 'V'
        self.response_plots[0, 0].setLabels(left=("Averaged events (Channel %d)" % post, units))

        fit_pen = {'color': (30, 30, 255), 'width': 2, 'dash': [1, 1]}
        # don't take more than 40 ms for any response (dt is that of the last sweep)
        max_len = int(40e-3 / dt)
        for i in range(npulses):
            # get the chunk of each sweep between this spike and the next detected one
            responses = []
            for j, sweep_spikes in enumerate(spikes):
                if i >= len(sweep_spikes):
                    continue
                spike = sweep_spikes[i]
                if spike is None:
                    continue

                next_spike = next((sp for sp in sweep_spikes[i + 1:] if sp is not None), None)

                start = spike['rise_index']
                if next_spike is None:
                    stop = start + max_len
                else:
                    stop = min(start + max_len, next_spike['rise_index'])

                responses.append(post_traces[j].data[start:stop].copy())

            if len(responses) == 0:
                continue

            # average the (possibly unequal-length) responses, then subtract the
            # baseline estimated from the first 1 ms
            avg = ragged_mean(responses, method='clip')
            avg -= float_mode(avg[:int(1e-3 / dt)])
            avg_responses.append(avg)

            # plot average response for this pulse
            t = np.arange(len(avg)) * dt
            self.response_plots[0, i].plot(t, avg, pen='w', antialias=True)

            # fit!
            fit = self.fit_psp(avg, t, dt, post_mode)
            fits.append(fit)
            fit_pulse_ids.append(i)

            # let the user mess with this fit
            curve = self.response_plots[0, i].plot(t, fit.eval(), pen=fit_pen, antialias=True).curve
            curve.setClickable(True)
            curve.fit = fit
            curve.sigClicked.connect(self.fit_curve_clicked)

        # display global average (only possible when at least one pulse had responses)
        global_fit = None
        if len(avg_responses) > 0:
            global_avg = ragged_mean(avg_responses, method='clip')
            t = np.arange(len(global_avg)) * dt
            self.response_plots[0, -1].plot(t, global_avg, pen='w', antialias=True)
            global_fit = self.fit_psp(global_avg, t, dt, post_mode)
            self.response_plots[0, -1].plot(t, global_fit.eval(), pen=fit_pen, antialias=True)

        # display fit parameters in table: one row per fitted pulse, then the average
        events = []
        for pulse_id, f in zip(fit_pulse_ids, fits):
            if f is None:
                continue
            # spike lists may differ in length between sweeps; skip sweeps lacking this pulse
            spt = [s[pulse_id]['peak_index'] * dt for s in spikes
                   if pulse_id < len(s) and s[pulse_id] is not None]
            events.append(self._event_row(pulse_id, np.mean(spt), np.std(spt), f))
        if global_fit is not None:
            events.append(self._event_row('avg', np.nan, np.nan, global_fit))

        self.current_event_set = (pre, post, events, sweeps)
        self.event_set_list.setCurrentRow(0)
        self.event_set_selected()

    @staticmethod
    def _event_row(event_id, spike_time, spike_stdev, fit):
        """Build one event-table row: id/spike-timing columns followed by the fit's best values."""
        row = OrderedDict([('id', event_id), ('spike_time', spike_time), ('spike_stdev', spike_stdev)])
        row.update((k, fit.best_values[k]) for k in fit.params.keys())
        return row

    def fit_psp(self, data, t, dt, clamp_mode):
        """Fit a single PSP model to an averaged response.

        Parameters
        ----------
        data : array
            Averaged response samples.
        t : array
            Time values matching *data*.
        dt : float
            Sample interval; the first 1 ms of *data* is used to estimate the
            baseline, which decides the sign of the initial amplitude guess.
        clamp_mode : str
            'ic' selects a larger initial amplitude and slower kinetic bounds;
            any other value uses the default guesses.

        Returns
        -------
        The result of ``fitting.Psp().fit`` (this class uses its ``eval()``,
        ``best_values`` and ``params``).
        """
        mode = float_mode(data[:int(1e-3 / dt)])
        sign = -1 if data.mean() - mode < 0 else 1
        # NOTE(review): tuples appear to be (initial, min, max) or (value, 'fixed');
        # confirm against fitting.Psp.
        params = OrderedDict([
            ('xoffset', (2e-3, 3e-4, 5e-3)),
            ('yoffset', data[0]),
            ('amp', sign * 10e-12),
            ('rise_time', (2e-3, 50e-6, 10e-3)),
            ('decay_tau', (4e-3, 500e-6, 50e-3)),
            ('rise_power', (2.0, 'fixed')),
        ])
        if clamp_mode == 'ic':
            params['amp'] = sign * 10e-3
            params['rise_time'] = (5e-3, 50e-6, 20e-3)
            params['decay_tau'] = (15e-3, 500e-6, 150e-3)

        fit_kws = {'xtol': 1e-3, 'maxfev': 100}

        psp = fitting.Psp()
        return psp.fit(data, x=t, fit_kws=fit_kws, params=params)

    def add_set_clicked(self):
        """Store the current event set and list it as "<pre> -> <post>"."""
        ces = self.current_event_set
        # nothing to store when the last analysis produced no events
        if ces is None or len(ces[2]) == 0:
            return
        self.event_sets.append(ces)
        item = QtGui.QListWidgetItem("%d -> %d" % (ces[0], ces[1]))
        self.event_set_list.addItem(item)
        item.event_set = ces

    def remove_set_clicked(self):
        """Remove the selected stored event set (the "current" row cannot be removed)."""
        row = self.event_set_list.currentRow()
        # row 0 is the "current" set; -1 means nothing is selected
        if row <= 0:
            return
        item = self.event_set_list.takeItem(row)
        self.event_sets.remove(item.event_set)

    def event_set_selected(self):
        """Show the selected event set (or the current, unsaved one) in the table."""
        selected = self.event_set_list.selectedItems()
        if len(selected) == 0:
            return
        sel = selected[0]
        event_set = self.current_event_set if sel.text() == "current" else sel.event_set
        if event_set is None:
            self.event_table.clear()
        else:
            self.event_table.setData(event_set[2])

    def fit_clicked(self):
        """Fit the release model to all stored event sets and plot the results.

        For each set, spike times are made relative to the first event and
        scaled by 1000 (s -> ms), and amplitudes are normalized to the first
        event. The model is fit once per dynamics type, enabling mechanisms
        cumulatively ('Dep', then 'Dep' + 'Fac', ...).
        """
        from synaptic_release import ReleaseModel

        n_sets = len(self.event_sets)
        if n_sets == 0:
            return

        self.fit_plot.clear()
        self.fit_plot.show()
        self.fit_plot.set_shape(n_sets, 1)

        # replace any legend left from a previous fit; scene() is a method on graphics items
        legend = self.fit_plot[0, 0].legend
        if legend is not None:
            scene = legend.scene()
            if scene is not None:
                scene.removeItem(legend)
            self.fit_plot[0, 0].legend = None
        self.fit_plot[0, 0].addLegend()

        for i in range(1, n_sets):
            self.fit_plot[i, 0].setXLink(self.fit_plot[0, 0])

        spike_sets = []
        for i, evset in enumerate(self.event_sets):
            # select events, skipping the average row
            events = [ev for ev in evset[2] if ev['id'] != 'avg']
            x = np.array([ev['spike_time'] for ev in events])
            y = np.array([ev['amp'] for ev in events])
            x -= x[0]
            x *= 1000
            y /= y[0]
            spike_sets.append((x, y))

            self.fit_plot[i, 0].plot(x / 1000., y, pen=None, symbol='o')

        model = ReleaseModel()
        dynamics_types = ['Dep', 'Fac', 'UR', 'SMR', 'DSR']
        for k in dynamics_types:
            model.Dynamics[k] = 0

        fit_params = []
        with pg.ProgressDialog("Fitting release model..", 0, len(dynamics_types)) as dlg:
            for k in dynamics_types:
                model.Dynamics[k] = 1
                fit_params.append(model.run_fit(spike_sets))
                dlg += 1
                if dlg.wasCanceled():
                    return

        # NOTE(review): every mechanism is enabled in model.Dynamics at this point,
        # although each params set was fit with a different subset -- confirm that
        # eval() does not depend on model.Dynamics.
        max_color = len(fit_params) * 1.5
        for i, params in enumerate(fit_params):
            for j, (x, _) in enumerate(spike_sets):
                output = model.eval(x, list(params.values()))
                self.fit_plot[j, 0].plot(output[:, 0] / 1000., output[:, 1],
                                         pen=(i, max_color), name=dynamics_types[i])

    def fit_curve_clicked(self, curve):
        """Show the fit attached to a clicked curve in a (reused) FitExplorer."""
        if self.fit_explorer is None:
            self.fit_explorer = FitExplorer(curve.fit)
        else:
            self.fit_explorer.set_fit(curve.fit)
        self.fit_explorer.show()
# Fit the averaged responses of one synaptic pair and open a FitExplorer on
# the VC fit at -55 mV holding.
# Usage: python <script> EXPT_ID PRE_CELL_ID POST_CELL_ID
import sys
import pyqtgraph as pg
from neuroanalysis.ui.fitting import FitExplorer
from multipatch_analysis.database import default_db as db
from multipatch_analysis.avg_response_fit import get_pair_avg_fits
from multipatch_analysis.ui.avg_response_fit import AvgResponseFitUi

# Validate arguments before creating any windows.
if len(sys.argv) < 4:
    sys.exit("usage: %s EXPT_ID PRE_CELL_ID POST_CELL_ID" % sys.argv[0])
expt_id, pre_cell_id, post_cell_id = sys.argv[1:4]

app = pg.mkQApp()
pg.dbg()

session = db.session()
pair = db.experiment_from_ext_id(expt_id, session=session).pairs[pre_cell_id, post_cell_id]

ui = AvgResponseFitUi()
fits = get_pair_avg_fits(pair, session=session, ui=ui)
ui.widget.show()

fe = FitExplorer(fits['vc', -55]['fit_result'])
fe.show()

# start the Qt event loop unless running interactively (python -i)
if sys.flags.interactive == 0:
    app.exec_()
# Demo: fit a PspTrain model to a synthetic, noisy train of PSPs and open a
# FitExplorer on the result.
import sys

app = pg.mkQApp()
pg.dbg()

t = np.arange(2000) * 1e-4
rise_time = 5e-3
decay_tau = 20e-3
n_psp = 4

# shared kinetics; each PSP gets its own fixed offset and a free amplitude
args = {
    'yoffset': (0, 'fixed'),
    'xoffset': (2e-3, -1e-3, 5e-3),
    'rise_time': (rise_time, rise_time * 0.5, rise_time * 2),
    'decay_tau': (decay_tau, decay_tau * 0.5, decay_tau * 2),
    'rise_power': (2, 'fixed'),
}
for i in range(n_psp):
    args['xoffset%d' % i] = (25e-3 * i, 'fixed')
    args['amp%d' % i] = (250e-6, 0, 10e-3)

fit_kws = {'xtol': 1e-4, 'maxfev': 1000, 'nan_policy': 'omit'}

model = PspTrain(n_psp)

# evaluate the model at the initial values (first tuple element) and add noise
args2 = {k: (v[0] if isinstance(v, tuple) else v) for k, v in args.items()}
y = np.random.normal(size=len(t), scale=30e-6) + model.eval(x=t, **args2)

fit = model.fit(y, x=t, params=args, fit_kws=fit_kws, method='leastsq')

ex = FitExplorer(fit)
ex.show()

# start the Qt event loop unless running interactively (python -i)
if sys.flags.interactive == 0:
    app.exec_()
from neuroanalysis.fitting import Psp
from neuroanalysis.ui.fitting import FitExplorer

app = pg.mkQApp()
pg.dbg()

# Load PSP data from the test_data repository (or from the file given on the command line)
if len(sys.argv) == 1:
    data_file = 'test_data/test_psp_fit/1485904693.10_8_2NOTstacked.json'
else:
    data_file = sys.argv[1]
with open(data_file) as fh:
    data = json.load(fh)

y = np.array(data['input']['data'])
x = np.arange(len(y)) * data['input']['dt']

psp = Psp()
params = OrderedDict([
    ('xoffset', (10e-3, 10e-3, 15e-3)),
    ('yoffset', 0),
    ('amp', 0.1e-3),
    ('rise_time', (2e-3, 500e-6, 10e-3)),
    ('decay_tau', (4e-3, 1e-3, 50e-3)),
    ('rise_power', (2.0, 'fixed')),
])

# Minimizer options must go through fit_kws; plain keyword arguments to
# Model.fit are treated as model-function arguments, not solver settings.
fit = psp.fit(y, x=x, params=params, fit_kws={'xtol': 1e-3, 'maxfev': 100})

explorer = FitExplorer(fit=fit)
explorer.show()

# start the Qt event loop unless running interactively (python -i)
if sys.flags.interactive == 0:
    app.exec_()