def __init__(self, kwik_path, config_dir=None, **kwargs):
    """Set up a Kwik controller for the dataset at *kwik_path*.

    The path is resolved with ``op.realpath`` and given to ``_backup``
    before the ``KwikModel`` is opened (extra keyword arguments are
    forwarded to the model). The cache lives in a ``.phy`` folder next to
    the file, inside a per-group subfolder when a ``channel_group`` keyword
    is given. Plugins from the ``plugins`` keyword are attached last.
    """
    super(KwikController, self).__init__()
    resolved_path = op.realpath(kwik_path)
    _backup(resolved_path)
    self.model = KwikModel(resolved_path, **kwargs)

    positions = self.model.channel_positions
    # Channels sorted by their second (vertical) coordinate.
    self.channel_vertical_order = np.argsort(positions[:, 1])
    self.distance_max = _get_distance_max(positions)

    cache_dir = op.join(op.dirname(resolved_path), '.phy')
    channel_group = kwargs.get('channel_group', None)
    if channel_group is not None:
        cache_dir = op.join(cache_dir, str(channel_group))
    self.cache_dir = cache_dir
    self.context = Context(self.cache_dir)
    self.config_dir = config_dir

    self._set_cache()
    self.supervisor = self._set_supervisor()
    self.selector = self._set_selector()
    self.color_selector = ColorSelector()
    self._show_all_spikes = False

    attach_plugins(self,
                   plugins=kwargs.get('plugins', None),
                   config_dir=config_dir)
def __init__(self, dat_path=None, proc_path=None, config_dir=None,
             model=None, **kwargs):
    """Set up a template controller.

    When *model* is given it is used as is; otherwise a ``TemplateModel``
    is created from the absolute *dat_path* and *proc_path* (extra keyword
    arguments are forwarded to it). The cache lives in a ``.phy`` folder
    inside the model's ``dir_path``. Plugins from the ``plugins`` keyword
    are attached last.
    """
    super(TemplateController, self).__init__()
    if model is not None:
        self.model = model
    else:
        assert dat_path
        self.model = TemplateModel(dat_path=op.abspath(dat_path),
                                   proc_path=proc_path,
                                   **kwargs)

    self.cache_dir = op.join(self.model.dir_path, '.phy')
    self.context = Context(self.cache_dir)
    self.config_dir = config_dir

    self._set_cache()
    self.supervisor = self._set_supervisor()
    self.selector = self._set_selector()
    self.color_selector = ColorSelector()
    self._show_all_spikes = False

    attach_plugins(self,
                   plugins=kwargs.get('plugins', None),
                   config_dir=config_dir)
def __init__(self, data_path, config_dir=None, **kwargs):
    """Set up a Neo controller for the data at *data_path*.

    Extra keyword arguments are forwarded to ``NeoModel``. The cache lives
    in ``<output_dir>/.phy/channel_group_<group>``, where the group comes
    from the ``channel_group`` keyword and falls back to 0 when it is
    missing or falsy. Plugins from the ``plugins`` keyword are attached
    last.
    """
    super(NeoController, self).__init__()
    self.model = NeoModel(data_path, **kwargs)
    self.distance_max = _get_distance_max(self.model.channel_positions)

    group = kwargs.get('channel_group', None) or 0
    self.cache_dir = op.join(self.model.output_dir, '.phy',
                             'channel_group_' + str(group))
    self.context = Context(self.cache_dir)
    self.config_dir = config_dir

    self._set_cache()
    self.supervisor = self._set_supervisor()
    self.selector = self._set_selector()
    self.color_selector = ColorSelector()
    self._show_all_spikes = False

    attach_plugins(self,
                   plugins=kwargs.get('plugins', None),
                   config_dir=config_dir)
class TemplateController(EventEmitter):
    """Controller of the template GUI, built on top of a ``TemplateModel``.

    It creates the cache context, the cluster ``Supervisor`` and the spike
    ``Selector``, computes per-cluster quantities (best channels, depth,
    amplitude and a few quality metrics shown as cluster-view columns) and
    builds the views of the GUI.
    """

    gui_name = 'TemplateGUI'

    # Spike-subsampling parameters used by the views.
    n_spikes_waveforms = 100
    batch_size_waveforms = 10
    n_spikes_features = 10000
    n_spikes_amplitudes = 10000
    n_spikes_correlograms = 100000

    def __init__(self, dat_path=None, config_dir=None, model=None, **kwargs):
        super(TemplateController, self).__init__()
        if model is None:
            assert dat_path
            dat_path = op.abspath(dat_path)
            self.model = TemplateModel(dat_path, **kwargs)
        else:
            self.model = model
        self.cache_dir = op.join(self.model.dir_path, '.phy')
        self.context = Context(self.cache_dir)
        self.config_dir = config_dir

        self._set_cache()
        self.supervisor = self._set_supervisor()
        self.selector = self._set_selector()
        self.color_selector = ColorSelector()

        self._show_all_spikes = False

        attach_plugins(self, plugins=kwargs.get('plugins', None),
                       config_dir=config_dir)

    # Internal methods
    # -------------------------------------------------------------------------

    def _set_cache(self):
        """Register the methods wrapped by ``_cache_methods``."""
        memcached = ('get_template_counts',
                     'get_template_for_cluster',
                     'get_best_channel',
                     'get_best_channels',
                     'get_probe_depth',
                     'get_real_amplitude',
                     'get_violation_rate',
                     'get_noise_ratioz',
                     'get_ratio',
                     'get_firing_rate',
                     )
        cached = ('_get_waveforms',
                  '_get_template_waveforms',
                  '_get_features',
                  '_get_template_features',
                  '_get_amplitudes',
                  '_get_correlograms',
                  )
        _cache_methods(self, memcached, cached)

    def _set_supervisor(self):
        """Create the cluster ``Supervisor`` and connect its callbacks."""
        # Load the new cluster id.
        new_cluster_id = self.context.load('new_cluster_id'). \
            get('new_cluster_id', None)
        cluster_groups = self.model.get_metadata('group')
        supervisor = Supervisor(self.model.spike_clusters,
                                similarity=self.similarity,
                                cluster_groups=cluster_groups,
                                new_cluster_id=new_cluster_id,
                                context=self.context,
                                )

        # Load the non-group metadata from the model to the cluster_meta.
        for name in self.model.metadata_fields:
            if name == 'group':
                continue
            values = self.model.get_metadata(name)
            for cluster_id, value in values.items():
                supervisor.cluster_meta.set(name, [cluster_id], value,
                                            add_to_stack=False)

        @supervisor.connect
        def on_create_cluster_views():

            supervisor.add_column(self.get_best_channel, name='channel')
            supervisor.add_column(self.get_probe_depth, name='depth')
            supervisor.add_column(self.get_real_amplitude, name='amp')
            # supervisor.add_column(self.get_noise_estimate, name='N')
            supervisor.add_column(self.get_noise_ratioz, name='Amp.o.N')
            supervisor.add_column(self.get_ratio, name='Bin1.o.M')
            supervisor.add_column(self.get_firing_rate, name='FR')
            # supervisor.add_column(self.get_violation_rate, name='vioRate')

            @supervisor.actions.add(shortcut='shift+ctrl+k')
            def split_init(cluster_ids=None):
                """Split a cluster according to the original templates."""
                if cluster_ids is None:
                    cluster_ids = supervisor.selected
                s = supervisor.clustering.spikes_in_clusters(cluster_ids)
                supervisor.split(s, self.model.spike_templates[s])

        # Save.
        @supervisor.connect
        def on_request_save(spike_clusters, groups, *labels):
            """Save the modified data."""
            # Save the clusters.
            self.model.save_spike_clusters(spike_clusters)
            # Save cluster metadata (each label is a (name, values) pair).
            for name, values in labels:
                self.model.save_metadata(name, values)

        return supervisor

    def _set_selector(self):
        """Create the spike ``Selector`` backed by the current clustering."""
        def spikes_per_cluster(cluster_id):
            return self.supervisor.clustering.spikes_per_cluster[cluster_id]
        return Selector(spikes_per_cluster)

    def _add_view(self, gui, view):
        """Attach *view* to *gui* and emit the ``add_view`` event."""
        view.attach(gui)
        self.emit('add_view', gui, view)
        return view

    # Model methods
    # -------------------------------------------------------------------------

    def get_template_counts(self, cluster_id):
        """Return a histogram of the number of spikes in each template for
        a given cluster."""
        spike_ids = self.supervisor.clustering.spikes_per_cluster[cluster_id]
        st = self.model.spike_templates[spike_ids]
        return np.bincount(st, minlength=self.model.n_templates)

    def get_template_for_cluster(self, cluster_id):
        """Return the template with the most spikes in a cluster."""
        spike_ids = self.supervisor.clustering.spikes_per_cluster[cluster_id]
        st = self.model.spike_templates[spike_ids]
        template_ids, counts = np.unique(st, return_counts=True)
        ind = np.argmax(counts)
        return template_ids[ind]

    def similarity(self, cluster_id):
        """Return the (at most 100) clusters most similar to a cluster.

        Returns a list of ``(cluster_id, similarity)`` pairs sorted by
        decreasing similarity.
        """
        # Templates of the cluster.
        temp_i = np.nonzero(self.get_template_counts(cluster_id))[0]
        # The similarity of the cluster with each template.
        sims = np.max(self.model.similar_templates[temp_i, :], axis=0)

        def _sim_ij(cj):
            # Cluster ids below n_templates are used directly as template
            # indices (presumably untouched template clusters).
            if cj < self.model.n_templates:
                return float(sims[cj])
            # Otherwise, use the best similarity among the cluster templates.
            temp_j = np.nonzero(self.get_template_counts(cj))[0]
            return float(np.max(sims[temp_j]))

        out = [(cj, _sim_ij(cj))
               for cj in self.supervisor.clustering.cluster_ids]
        # NOTE: hard-limit to 100 for performance reasons.
        return sorted(out, key=itemgetter(1), reverse=True)[:100]

    def get_best_channel(self, cluster_id):
        """Return the best channel of a given cluster."""
        template_id = self.get_template_for_cluster(cluster_id)
        return self.model.get_template(template_id).best_channel

    def get_real_amplitude(self, cluster_id):
        """Return the peak-to-peak amplitude of the mean waveform.

        Only the first channel of the waveform Bunch is used, i.e. the first
        of ``get_best_channels(cluster_id)``.
        """
        waveforms = self._get_mean_waveforms(cluster_id)
        trace = waveforms.data[0, :, 0]
        return trace.max() - trace.min()

    def get_violation_rate(self, cluster_id):
        """Return the fraction of inter-spike intervals <= 0.002.

        The threshold is in the unit of ``model.spike_times`` (presumably
        seconds). Returns NaN when fewer than two spikes are selected, since
        no interval can then be computed.
        """
        thresh = 0.002
        spt = self._get_spike_times(cluster_id=cluster_id, load_all=True)
        isi = np.diff(spt.data)
        if len(isi) == 0:
            logger.warning("Cannot calculate the violation rate for "
                           "cluster %s.", cluster_id)
            return float('nan')
        violation = np.count_nonzero(isi <= thresh)
        return violation * 1.0 / len(isi)

    def get_violation_rate_threshold(self, cluster_id):
        """Return the violation rate threshold of a cluster.

        Based on a refractory period of 0.002 s and a false positive rate of
        5%, formula from Hill et al. 2011.
        """
        ref_period = 0.002
        fp_rate = 0.05
        T = self.model.duration
        N = self.supervisor.n_spikes(cluster_id)
        return 2 * ref_period * N * (1 - fp_rate) * fp_rate / T

    def get_noise_estimate(self, cluster_id):
        """Return the standard deviation of the first 25 samples of the
        selected waveforms on the first (best) channel."""
        aa = self._get_waveforms(cluster_id)
        # The best channel is always first.
        pre = aa.data[:, 0:25, 0].flatten()
        return pre.std()

    def get_noise_ratioz(self, cluster_id):
        """Return the amplitude over noise ratio, rounded to one decimal."""
        ratio = (self.get_real_amplitude(cluster_id) /
                 self.get_noise_estimate(cluster_id))
        return np.around(ratio, 1)

    def get_firing_rate(self, cluster_id):
        """Return the number of spikes divided by the recording duration."""
        T = self.model.duration
        N = self.supervisor.n_spikes(cluster_id)
        return N / T

    def get_ratio(self, cluster_id):
        """Return the autocorrelogram ratio shown in the 'Bin1.o.M' column.

        The autocorrelogram is computed with 0.001 bins over a 0.05 window.
        The count in bin 26 is divided by the maximum count. Returns NaN for
        an empty autocorrelogram.
        """
        cgr = self._get_correlograms([cluster_id], 0.001, 0.05)
        # NOTE(review): bin 26 is presumably the first bin after zero lag
        # for these bin/window sizes -- confirm against correlograms().
        vio_bin = 26
        bin1 = cgr[0, 0, vio_bin:vio_bin + 1].sum()
        ma = cgr.max()
        if ma == 0:
            # Avoid a 0/0 RuntimeWarning; the ratio is undefined.
            return np.nan
        return bin1 / ma

    def get_best_channels(self, cluster_id):
        """Return the best channels of a given cluster."""
        template_id = self.get_template_for_cluster(cluster_id)
        return self.model.get_template(template_id).channel_ids

    def get_probe_depth(self, cluster_id):
        """Return the depth (second coordinate of the best channel) of a
        cluster."""
        channel_id = self.get_best_channel(cluster_id)
        return self.model.channel_positions[channel_id][1]

    # Waveforms
    # -------------------------------------------------------------------------

    def _get_waveforms(self, cluster_id):
        """Return a selection of waveforms for a cluster."""
        pos = self.model.channel_positions
        spike_ids = self.selector.select_spikes([cluster_id],
                                                self.n_spikes_waveforms,
                                                self.batch_size_waveforms,
                                                )
        channel_ids = self.get_best_channels(cluster_id)
        data = self.model.get_waveforms(spike_ids, channel_ids)
        # Remove the global offset.
        data = data - data.mean()
        return Bunch(data=data,
                     channel_ids=channel_ids,
                     channel_positions=pos[channel_ids],
                     )

    def _get_mean_waveforms(self, cluster_id):
        """Return the mean of the selected waveforms as a single waveform."""
        b = self._get_waveforms(cluster_id)
        b.data = b.data.mean(axis=0)[np.newaxis, ...]
        b['alpha'] = 1.
        return b

    def _get_template_waveforms(self, cluster_id):
        """Return the waveforms of the templates corresponding to a
        cluster."""
        pos = self.model.channel_positions
        count = self.get_template_counts(cluster_id)
        template_ids = np.nonzero(count)[0]
        count = count[template_ids]
        # Get local channels.
        channel_ids = self.get_best_channels(cluster_id)
        # Get masks, proportional to the number of spikes per template.
        masks = count / float(count.max())
        masks = np.tile(masks.reshape((-1, 1)), (1, len(channel_ids)))
        # Get the mean amplitude for the cluster.
        mean_amp = self._get_amplitudes(cluster_id).y.mean()
        # Get all templates from which this cluster stems from.
        templates = [self.model.get_template(template_id)
                     for template_id in template_ids]
        data = np.stack([b.template * mean_amp for b in templates], axis=0)
        cols = np.stack([b.channel_ids for b in templates], axis=0)
        # NOTE: transposition because the channels should be in the second
        # dimension for from_sparse.
        data = data.transpose((0, 2, 1))
        assert data.ndim == 3
        assert data.shape[1] == cols.shape[1]
        waveforms = from_sparse(data, cols, channel_ids)
        # Transpose back.
        waveforms = waveforms.transpose((0, 2, 1))
        return Bunch(data=waveforms,
                     channel_ids=channel_ids,
                     channel_positions=pos[channel_ids],
                     masks=masks,
                     alpha=1.,
                     )

    def add_waveform_view(self, gui):
        """Add the waveform view, with template and mean toggles."""
        f = (self._get_waveforms if self.model.traces is not None
             else self._get_template_waveforms)
        v = WaveformView(waveforms=f,
                         )
        v = self._add_view(gui, v)

        v.actions.separator()

        @v.actions.add(shortcut='w')
        def toggle_templates():
            f, g = self._get_waveforms, self._get_template_waveforms
            if self.model.traces is None:
                return
            v.waveforms = f if v.waveforms == g else g
            v.on_select()

        @v.actions.add(shortcut='m')
        def toggle_mean_waveforms():
            f, g = self._get_waveforms, self._get_mean_waveforms
            v.waveforms = f if v.waveforms == g else g
            v.on_select()

        return v

    # Features
    # -------------------------------------------------------------------------

    def _get_spike_ids(self, cluster_id=None, load_all=None):
        """Return the spikes to display for a cluster, or background spikes
        regularly subsampled from all spikes when *cluster_id* is None."""
        nsf = self.n_spikes_features
        if cluster_id is None:
            # Background points.
            ns = self.model.n_spikes
            spike_ids = np.arange(0, ns, max(1, ns // nsf))
        else:
            # Load all spikes from the cluster if load_all is True.
            n = nsf if not load_all else None
            spike_ids = self.selector.select_spikes([cluster_id], n)
        # Remove spike_ids that do not belong to model.features_rows
        if self.model.features_rows is not None:
            spike_ids = np.intersect1d(spike_ids, self.model.features_rows)
        return spike_ids

    def _get_spike_times(self, cluster_id=None, load_all=None):
        """Return the spike times of the selected spikes."""
        spike_ids = self._get_spike_ids(cluster_id, load_all=load_all)
        return Bunch(data=self.model.spike_times[spike_ids],
                     spike_ids=spike_ids,
                     lim=(0., self.model.duration))

    def _get_features(self, cluster_id=None, channel_ids=None, load_all=None):
        """Return the features of the selected spikes, without NaN rows."""
        spike_ids = self._get_spike_ids(cluster_id, load_all=load_all)
        # Use the best channels only if a cluster is specified and
        # channels are not specified.
        if cluster_id is not None and channel_ids is None:
            channel_ids = self.get_best_channels(cluster_id)
        data = self.model.get_features(spike_ids, channel_ids)
        assert data.shape[:2] == (len(spike_ids), len(channel_ids))
        # Remove rows with at least one nan value.
        nan = np.unique(np.nonzero(np.isnan(data))[0])
        nonnan = np.setdiff1d(np.arange(len(spike_ids)), nan)
        data = data[nonnan, ...]
        spike_ids = spike_ids[nonnan]
        assert data.shape[:2] == (len(spike_ids), len(channel_ids))
        assert np.isnan(data).sum() == 0
        return Bunch(data=data,
                     spike_ids=spike_ids,
                     channel_ids=channel_ids,
                     )

    def add_feature_view(self, gui):
        """Add the feature view."""
        v = FeatureView(features=self._get_features,
                        attributes={'time': self._get_spike_times}
                        )
        return self._add_view(gui, v)

    # Template features
    # -------------------------------------------------------------------------

    def _get_template_features(self, cluster_ids):
        """Return the template features of two clusters.

        x (resp. y) coordinates are the template features averaged with the
        template counts of the first (resp. second) cluster as weights.
        """
        assert len(cluster_ids) == 2
        clu0, clu1 = cluster_ids

        s0 = self._get_spike_ids(clu0)
        s1 = self._get_spike_ids(clu1)

        n0 = self.get_template_counts(clu0)
        n1 = self.get_template_counts(clu1)

        t0 = self.model.get_template_features(s0)
        t1 = self.model.get_template_features(s1)

        x0 = np.average(t0, weights=n0, axis=1)
        y0 = np.average(t0, weights=n1, axis=1)

        x1 = np.average(t1, weights=n0, axis=1)
        y1 = np.average(t1, weights=n1, axis=1)

        # Bounds are (xmin, ymin, xmax, ymax) over both clusters.
        return Bunch(x0=x0, y0=y0, x1=x1, y1=y1,
                     data_bounds=(min(x0.min(), x1.min()),
                                  min(y0.min(), y1.min()),
                                  max(x0.max(), x1.max()),
                                  max(y0.max(), y1.max()),
                                  ),
                     )

    def add_template_feature_view(self, gui):
        """Add the template feature view."""
        v = TemplateFeatureView(coords=self._get_template_features,
                                )
        return self._add_view(gui, v)

    # Traces
    # -------------------------------------------------------------------------

    def _get_traces(self, interval):
        """Get traces and spike waveforms, plus the residual of each spike
        (waveform minus amplitude times template)."""
        k = self.model.n_samples_templates
        m = self.model

        traces_interval = select_traces(m.traces, interval,
                                        sample_rate=m.sample_rate)
        # No reordering here: the trace view receives the model's
        # channel_vertical_order (see add_trace_view).
        out = Bunch(data=traces_interval)
        out.waveforms = []

        for b in _iter_spike_waveforms(interval=interval,
                                       traces_interval=traces_interval,
                                       model=self.model,
                                       supervisor=self.supervisor,
                                       color_selector=self.color_selector,
                                       n_samples_waveforms=k,
                                       get_best_channels=self.get_best_channels,
                                       show_all_spikes=self._show_all_spikes,
                                       ):
            i = b.spike_id
            # Compute the residual: waveform - amplitude * template.
            residual = b.copy()
            template_id = m.spike_templates[i]
            template = m.get_template(template_id).template
            amplitude = m.amplitudes[i]
            residual.data = residual.data - amplitude * template
            out.waveforms.extend([b, residual])
        return out

    def _jump_to_spike(self, view, delta=+1):
        """Jump to next or previous spike from the selected clusters."""
        m = self.model
        cluster_ids = self.supervisor.selected
        if len(cluster_ids) == 0:
            return
        spc = self.supervisor.clustering.spikes_per_cluster
        spike_ids = spc[cluster_ids[0]]
        n = len(spike_ids)
        if n == 0:
            # Nothing to jump to; avoid a modulo by zero below.
            return
        spike_times = m.spike_times[spike_ids]
        ind = np.searchsorted(spike_times, view.time)
        view.go_to(spike_times[(ind + delta) % n])

    def add_trace_view(self, gui):
        """Add the trace view and its spike navigation actions."""
        m = self.model
        v = TraceView(traces=self._get_traces,
                      n_channels=m.n_channels,
                      sample_rate=m.sample_rate,
                      duration=m.duration,
                      channel_vertical_order=m.channel_vertical_order,
                      )
        self._add_view(gui, v)

        v.actions.separator()

        @v.actions.add(shortcut='alt+pgdown')
        def go_to_next_spike():
            """Jump to the next spike from the first selected cluster."""
            self._jump_to_spike(v, +1)

        @v.actions.add(shortcut='alt+pgup')
        def go_to_previous_spike():
            """Jump to the previous spike from the first selected cluster."""
            self._jump_to_spike(v, -1)

        v.actions.separator()

        @v.actions.add(shortcut='alt+s')
        def toggle_highlighted_spikes():
            """Toggle between showing all spikes or selected spikes."""
            self._show_all_spikes = not self._show_all_spikes
            v.set_interval(force_update=True)

        @gui.connect_
        def on_spike_click(channel_id=None, spike_id=None, cluster_id=None):
            # Select the corresponding cluster.
            self.supervisor.select([cluster_id])
            # Update the trace view.
            v.on_select([cluster_id], force_update=True)

        return v

    # Correlograms
    # -------------------------------------------------------------------------

    def _get_correlograms(self, cluster_ids, bin_size, window_size):
        """Return the correlograms of a random subset of the clusters'
        spikes."""
        spike_ids = self.selector.select_spikes(cluster_ids,
                                                self.n_spikes_correlograms,
                                                subset='random',
                                                )
        st = self.model.spike_times[spike_ids]
        sc = self.supervisor.clustering.spike_clusters[spike_ids]
        return correlograms(st,
                            sc,
                            sample_rate=self.model.sample_rate,
                            cluster_ids=cluster_ids,
                            bin_size=bin_size,
                            window_size=window_size,
                            )

    def add_correlogram_view(self, gui):
        """Add the correlogram view."""
        m = self.model
        v = CorrelogramView(correlograms=self._get_correlograms,
                            sample_rate=m.sample_rate,
                            )
        return self._add_view(gui, v)

    # Amplitudes
    # -------------------------------------------------------------------------

    def _get_amplitudes(self, cluster_id):
        """Return spike times (x) and amplitudes (y) of selected spikes."""
        n = self.n_spikes_amplitudes
        m = self.model
        spike_ids = self.selector.select_spikes([cluster_id], n)
        x = m.spike_times[spike_ids]
        y = m.amplitudes[spike_ids]
        return Bunch(x=x, y=y, data_bounds=(0., 0., m.duration, y.max()))

    def add_amplitude_view(self, gui):
        """Add the amplitude view."""
        v = AmplitudeView(coords=self._get_amplitudes,
                          )
        return self._add_view(gui, v)

    # Probe view
    # -------------------------------------------------------------------------

    def add_probe_view(self, gui):
        """Add the probe view (attached directly, without ``add_view``)."""
        v = ProbeView(positions=self.model.channel_positions,
                      best_channels=self.get_best_channels,
                      )
        v.attach(gui)
        return v

    # GUI
    # -------------------------------------------------------------------------

    def create_gui(self, **kwargs):
        """Create the GUI and add every view supported by the model."""
        gui = GUI(name=self.gui_name,
                  subtitle=self.model.dat_path,
                  config_dir=self.config_dir,
                  **kwargs)

        self.supervisor.attach(gui)

        self.add_waveform_view(gui)
        if self.model.traces is not None:
            self.add_trace_view(gui)
        if self.model.features is not None:
            self.add_feature_view(gui)
        if self.model.template_features is not None:
            self.add_template_feature_view(gui)
        self.add_correlogram_view(gui)
        if self.model.amplitudes is not None:
            self.add_amplitude_view(gui)
        self.add_probe_view(gui)

        # Save the memcache when closing the GUI.
        @gui.connect_
        def on_close():
            self.context.save_memcache()

        self.emit('gui_ready', gui)

        return gui
class KwikController(EventEmitter):
    """Controller of the Kwik GUI, built on top of a ``KwikModel``.

    It creates the cache context, the cluster ``Supervisor`` and the spike
    ``Selector``, and builds the views of the GUI.
    """

    gui_name = 'KwikGUI'

    # Spike-subsampling parameters used by the views.
    n_spikes_waveforms = 100
    batch_size_waveforms = 10
    n_spikes_features = 10000
    n_spikes_amplitudes = 10000
    n_spikes_correlograms = 100000

    n_spikes_close_clusters = 100
    n_closest_channels = 16

    def __init__(self, kwik_path, config_dir=None, **kwargs):
        super(KwikController, self).__init__()
        kwik_path = op.realpath(kwik_path)
        _backup(kwik_path)
        self.model = KwikModel(kwik_path, **kwargs)
        m = self.model
        # Channels sorted by their second (vertical) coordinate.
        self.channel_vertical_order = np.argsort(m.channel_positions[:, 1])
        self.distance_max = _get_distance_max(self.model.channel_positions)
        self.cache_dir = op.join(op.dirname(kwik_path), '.phy')
        cg = kwargs.get('channel_group', None)
        if cg is not None:
            self.cache_dir = op.join(self.cache_dir, str(cg))
        self.context = Context(self.cache_dir)
        self.config_dir = config_dir

        self._set_cache()
        self.supervisor = self._set_supervisor()
        self.selector = self._set_selector()
        self.color_selector = ColorSelector()

        self._show_all_spikes = False

        attach_plugins(self, plugins=kwargs.get('plugins', None),
                       config_dir=config_dir)

    # Internal methods
    # -------------------------------------------------------------------------

    def _set_cache(self):
        """Register the methods wrapped by ``_cache_methods``."""
        memcached = (
            'get_best_channels',
            'get_probe_depth',
            '_get_mean_masks',
            '_get_mean_waveforms',
        )
        cached = (
            '_get_waveforms',
            '_get_features',
            '_get_masks',
        )
        _cache_methods(self, memcached, cached)

    def _set_supervisor(self):
        """Create the cluster ``Supervisor`` and connect its callbacks."""
        # Load the new cluster id.
        new_cluster_id = self.context.load('new_cluster_id'). \
            get('new_cluster_id', None)
        cluster_groups = self.model.cluster_groups
        supervisor = Supervisor(
            self.model.spike_clusters,
            similarity=self.similarity,
            cluster_groups=cluster_groups,
            new_cluster_id=new_cluster_id,
            context=self.context,
        )

        @supervisor.connect
        def on_create_cluster_views():

            supervisor.add_column(self.get_best_channel, name='channel')
            supervisor.add_column(self.get_probe_depth, name='depth')

            @supervisor.actions.add
            def recluster():
                """Relaunch KlustaKwik on the selected clusters."""
                # Selected clusters.
                cluster_ids = supervisor.selected
                spike_ids = self.selector.select_spikes(cluster_ids)
                logger.info("Running KlustaKwik on %d spikes.",
                            len(spike_ids))

                # Run KK2 in a temporary directory to avoid side effects.
                n = 10
                with TemporaryDirectory() as tempdir:
                    spike_clusters, _ = cluster(
                        self.model,
                        spike_ids,
                        num_starting_clusters=n,
                        tempdir=tempdir,
                    )
                self.supervisor.split(spike_ids, spike_clusters)

        # Save.
        @supervisor.connect
        def on_request_save(spike_clusters, groups, *labels):
            """Save the modified data."""
            groups = {c: g.title() for c, g in groups.items()}
            self.model.save(spike_clusters, groups)

        return supervisor

    def _set_selector(self):
        """Create the spike ``Selector`` backed by the current clustering."""
        def spikes_per_cluster(cluster_id):
            return self.supervisor.clustering.spikes_per_cluster[cluster_id]
        return Selector(spikes_per_cluster)

    def _add_view(self, gui, view):
        """Attach *view* to *gui* and emit the ``add_view`` event."""
        view.attach(gui)
        self.emit('add_view', gui, view)
        return view

    # Model methods
    # -------------------------------------------------------------------------

    def get_best_channel(self, cluster_id):
        """Return the channel with the highest mean mask for a cluster."""
        return self.get_best_channels(cluster_id)[0]

    def get_best_channels(self, cluster_id):
        """Return the best channels of a cluster.

        Channels are sorted by decreasing mean mask. The ones with a mean
        mask above 0.1 are kept; if there are none, the first 4 are
        returned. Used by the waveform channel ranking, the features, the
        traces and ``get_best_channel``.
        """
        mm = self._get_mean_masks(cluster_id)
        channel_ids = np.argsort(mm)[::-1]
        ind = mm[channel_ids] > .1
        if np.sum(ind) > 0:
            channel_ids = channel_ids[ind]
        else:
            channel_ids = channel_ids[:4]
        return channel_ids

    def get_cluster_position(self, cluster_id):
        """Return the position of the best channel of a cluster."""
        channel_id = self.get_best_channel(cluster_id)
        return self.model.channel_positions[channel_id]

    def get_probe_depth(self, cluster_id):
        """Return the second coordinate of the cluster position."""
        return self.get_cluster_position(cluster_id)[1]

    def similarity(self, cluster_id):
        """Return all clusters sorted by decreasing similarity to a cluster.

        The similarity is ``distance_max`` minus the Euclidean distance
        between the best-channel positions of the two clusters.
        """
        pos_i = self.get_cluster_position(cluster_id)
        assert len(pos_i) == 2

        def _sim_ij(cj):
            """Distance between channel position of clusters i and j."""
            pos_j = self.get_cluster_position(cj)
            assert len(pos_j) == 2
            d = np.sqrt(np.sum((pos_j - pos_i) ** 2))
            return self.distance_max - d

        out = [(cj, _sim_ij(cj))
               for cj in self.supervisor.clustering.cluster_ids]
        return sorted(out, key=itemgetter(1), reverse=True)

    # Waveforms
    # -------------------------------------------------------------------------

    def _get_masks(self, cluster_id):
        """Return the masks of the spikes selected for the waveform view."""
        spike_ids = self.selector.select_spikes(
            [cluster_id],
            self.n_spikes_waveforms,
            self.batch_size_waveforms,
        )
        return self.model.all_masks[spike_ids]

    def _get_mean_masks(self, cluster_id):
        """Return the per-channel mean of ``_get_masks``."""
        return np.mean(self._get_masks(cluster_id), axis=0)

    def _get_waveforms(self, cluster_id):
        """Return a selection of waveforms for a cluster, with all channels
        sorted by decreasing waveform amplitude."""
        pos = self.model.channel_positions
        spike_ids = self.selector.select_spikes(
            [cluster_id],
            self.n_spikes_waveforms,
            self.batch_size_waveforms,
        )
        data = self.model.all_waveforms[spike_ids]
        mm = self._get_mean_masks(cluster_id)
        mw = np.mean(data, axis=0)
        amp = get_waveform_amplitude(mm, mw)
        masks = self._get_masks(cluster_id)
        # Find the best channels.
        channel_ids = np.argsort(amp)[::-1]
        return Bunch(
            data=data[..., channel_ids],
            channel_ids=channel_ids,
            channel_positions=pos[channel_ids],
            masks=masks[:, channel_ids],
        )

    def _get_mean_waveforms(self, cluster_id):
        """Return the mean waveform and mean masks (raised to the power 0.1)
        as a single waveform."""
        b = self._get_waveforms(cluster_id).copy()
        b.data = np.mean(b.data, axis=0)[np.newaxis, ...]
        b.masks = np.mean(b.masks, axis=0)[np.newaxis, ...] ** .1
        b['alpha'] = 1.
        return b

    def add_waveform_view(self, gui):
        """Add the waveform view, with a mean-waveform toggle."""
        v = WaveformView(waveforms=self._get_waveforms,
                         )
        v = self._add_view(gui, v)

        v.actions.separator()

        @v.actions.add(shortcut='m')
        def toggle_mean_waveforms():
            f, g = self._get_waveforms, self._get_mean_waveforms
            v.waveforms = f if v.waveforms == g else g
            v.on_select()

        return v

    # Features
    # -------------------------------------------------------------------------

    def _get_spike_ids(self, cluster_id=None, load_all=None):
        """Return the spikes to display for a cluster, or background spikes
        regularly subsampled from all spikes when *cluster_id* is None."""
        nsf = self.n_spikes_features
        if cluster_id is None:
            # Background points.
            ns = self.model.n_spikes
            return np.arange(0, ns, max(1, ns // nsf))
        # Load all spikes from the cluster if load_all is True.
        n = nsf if not load_all else None
        return self.selector.select_spikes([cluster_id], n)

    def _get_spike_times(self, cluster_id=None, load_all=None):
        """Return the spike times of the selected spikes."""
        spike_ids = self._get_spike_ids(cluster_id, load_all=load_all)
        return Bunch(data=self.model.spike_times[spike_ids],
                     spike_ids=spike_ids,
                     lim=(0., self.model.duration))

    def _get_features(self, cluster_id=None, channel_ids=None, load_all=None):
        """Return the features and masks of the selected spikes."""
        spike_ids = self._get_spike_ids(cluster_id, load_all=load_all)
        # Use the best channels only if a cluster is specified and
        # channels are not specified.
        if cluster_id is not None and channel_ids is None:
            channel_ids = self.get_best_channels(cluster_id)
        f = self.model.all_features[spike_ids][:, channel_ids]
        m = self.model.all_masks[spike_ids][:, channel_ids]
        return Bunch(
            data=f,
            masks=m,
            spike_ids=spike_ids,
            channel_ids=channel_ids,
        )

    def add_feature_view(self, gui):
        """Add the feature view."""
        v = FeatureView(features=self._get_features,
                        attributes={'time': self._get_spike_times})
        return self._add_view(gui, v)

    # Traces
    # -------------------------------------------------------------------------

    def _get_traces(self, interval):
        """Get traces (reordered vertically) and spike waveforms."""
        ns = self.model.n_samples_waveforms
        m = self.model
        c = self.channel_vertical_order

        traces_interval = select_traces(m.traces, interval,
                                        sample_rate=m.sample_rate)
        # Reorder vertically.
        traces_interval = traces_interval[:, c]

        out = Bunch(data=traces_interval)
        out.waveforms = []
        for b in _iter_spike_waveforms(
                interval=interval,
                traces_interval=traces_interval,
                model=self.model,
                supervisor=self.supervisor,
                color_selector=self.color_selector,
                n_samples_waveforms=ns,
                get_best_channels=self.get_best_channels,
                show_all_spikes=self._show_all_spikes,
        ):
            b.channel_labels = m.channel_order[b.channel_ids]
            out.waveforms.append(b)
        return out

    def _jump_to_spike(self, view, delta=+1):
        """Jump to next or previous spike from the selected clusters."""
        m = self.model
        cluster_ids = self.supervisor.selected
        if len(cluster_ids) == 0:
            return
        spc = self.supervisor.clustering.spikes_per_cluster
        spike_ids = spc[cluster_ids[0]]
        n = len(spike_ids)
        if n == 0:
            # Nothing to jump to; avoid a modulo by zero below.
            return
        spike_times = m.spike_times[spike_ids]
        ind = np.searchsorted(spike_times, view.time)
        view.go_to(spike_times[(ind + delta) % n])

    def add_trace_view(self, gui):
        """Add the trace view and its spike navigation actions."""
        m = self.model
        v = TraceView(
            traces=self._get_traces,
            n_channels=m.n_channels,
            sample_rate=m.sample_rate,
            duration=m.duration,
            channel_vertical_order=self.channel_vertical_order,
        )
        self._add_view(gui, v)

        v.actions.separator()

        @v.actions.add(shortcut='alt+pgdown')
        def go_to_next_spike():
            """Jump to the next spike from the first selected cluster."""
            self._jump_to_spike(v, +1)

        @v.actions.add(shortcut='alt+pgup')
        def go_to_previous_spike():
            """Jump to the previous spike from the first selected cluster."""
            self._jump_to_spike(v, -1)

        v.actions.separator()

        @v.actions.add(shortcut='alt+s')
        def toggle_highlighted_spikes():
            """Toggle between showing all spikes or selected spikes."""
            self._show_all_spikes = not self._show_all_spikes
            v.set_interval(force_update=True)

        @gui.connect_
        def on_spike_click(channel_id=None, spike_id=None, cluster_id=None):
            # Select the corresponding cluster.
            self.supervisor.select([cluster_id])
            # Update the trace view.
            v.on_select([cluster_id], force_update=True)

        return v

    # Correlograms
    # -------------------------------------------------------------------------

    def _get_correlograms(self, cluster_ids, bin_size, window_size):
        """Return the correlograms of (at most n_spikes_correlograms)
        selected spikes of the clusters."""
        spike_ids = self.selector.select_spikes(cluster_ids,
                                                self.n_spikes_correlograms)
        st = self.model.spike_times[spike_ids]
        sc = self.supervisor.clustering.spike_clusters[spike_ids]
        return correlograms(
            st,
            sc,
            sample_rate=self.model.sample_rate,
            cluster_ids=cluster_ids,
            bin_size=bin_size,
            window_size=window_size,
        )

    def add_correlogram_view(self, gui):
        """Add the correlogram view."""
        m = self.model
        v = CorrelogramView(
            correlograms=self._get_correlograms,
            sample_rate=m.sample_rate,
        )
        return self._add_view(gui, v)

    # GUI
    # -------------------------------------------------------------------------

    def create_gui(self, **kwargs):
        """Create the GUI and add the waveform, trace (when traces are
        available), feature and correlogram views."""
        gui = GUI(name=self.gui_name,
                  subtitle=self.model.kwik_path,
                  config_dir=self.config_dir,
                  **kwargs)

        self.supervisor.attach(gui)

        self.add_waveform_view(gui)
        if self.model.traces is not None:
            self.add_trace_view(gui)
        self.add_feature_view(gui)
        self.add_correlogram_view(gui)

        # Save the memcache when closing the GUI.
        @gui.connect_
        def on_close():
            self.context.save_memcache()

        self.emit('gui_ready', gui)

        return gui
class NeoController(EventEmitter):
    """Controller connecting a `NeoModel` to the phy GUI.

    It builds the model from `data_path`, sets up the cache context, the
    cluster supervisor and the spike selector, and provides `add_*_view()`
    methods plus `create_gui()` to assemble the GUI.

    Events emitted: `add_view(gui, view)` for every attached view, and
    `gui_ready(gui)` at the end of `create_gui()`.
    """

    gui_name = 'NeoGUI'

    # Number of spikes requested from the selector for each kind of view.
    n_spikes_waveforms = 200
    batch_size_waveforms = 200
    n_spikes_features = 10000
    n_spikes_amplitudes = 10000
    n_spikes_correlograms = 100000

    def __init__(self, data_path, config_dir=None, **kwargs):
        """Create the model and all controller components.

        Parameters
        ----------
        data_path : str
            Path passed to `NeoModel`.
        config_dir : str or None
            Configuration directory forwarded to the GUI and to the plugins.
        **kwargs
            Forwarded to `NeoModel`. `channel_group` also selects the cache
            subdirectory, and `plugins` the plugins to attach.
        """
        super(NeoController, self).__init__()
        self.model = NeoModel(data_path, **kwargs)
        self.distance_max = _get_distance_max(self.model.channel_positions)
        # One cache directory per channel group:
        # `<output_dir>/.phy/channel_group_<cg>`. A missing (or falsy)
        # channel group is mapped to 0.
        cg = kwargs.get('channel_group', None) or 0
        self.cache_dir = op.join(self.model.output_dir, '.phy',
                                 'channel_group_' + str(cg))
        self.context = Context(self.cache_dir)
        self.config_dir = config_dir
        self._set_cache()
        self.supervisor = self._set_supervisor()
        self.selector = self._set_selector()
        self.color_selector = ColorSelector()
        self._show_all_spikes = False
        attach_plugins(self, plugins=kwargs.get('plugins', None),
                       config_dir=config_dir)

    # Internal methods
    # -------------------------------------------------------------------------

    def _set_cache(self):
        """Register the methods whose results are cached by the context."""
        # The memcache is saved when the GUI closes (see `create_gui`);
        # `cached` presumably uses the context's on-disk cache.
        memcached = (
            'get_best_channels',
            '_get_mean_waveforms',
        )
        cached = (
            '_get_waveforms',
            '_get_features',
        )
        _cache_methods(self, memcached, cached)

    def _set_supervisor(self):
        """Create the cluster supervisor and connect its callbacks."""
        # Restore the next cluster id saved in the context, if any.
        new_cluster_id = self.context.load('new_cluster_id'). \
            get('new_cluster_id', None)
        cluster_groups = self.model.cluster_groups
        supervisor = Supervisor(
            self.model.spike_clusters,
            similarity=self.similarity,
            cluster_groups=cluster_groups,
            new_cluster_id=new_cluster_id,
            context=self.context,
        )

        @supervisor.connect
        def on_create_cluster_views():
            supervisor.add_column(self.get_best_channel, name='channel')
            supervisor.add_column(self.get_probe_depth, name='depth')

        @supervisor.actions.add
        def recluster():
            """Relaunch KlustaKwik on the selected clusters."""
            cluster_ids = supervisor.selected
            if len(cluster_ids) == 0:
                # Nothing selected: avoid an IndexError on cluster_ids[0].
                logger.info("No cluster selected, nothing to recluster.")
                return
            spike_ids = self.selector.select_spikes(cluster_ids)
            logger.info("Running KlustaKwik on %d spikes.", len(spike_ids))
            # TODO: get_best_channels() takes a single cluster id, so the
            # channels of the first selected cluster are used for all of them.
            channel_ids = self.get_best_channels(cluster_ids[0])
            spike_clusters = self.model.cluster(spike_ids, channel_ids)
            self.supervisor.split(spike_ids, spike_clusters)

        @supervisor.connect
        def on_request_save(spike_clusters, groups, *labels):
            """Save the modified data."""
            # Group names are title-cased (e.g. 'good' -> 'Good') before
            # being handed to the model.
            groups = {c: g.title() for c, g in groups.items()}
            self.model.save(spike_clusters, groups, *labels)

        return supervisor

    def _set_selector(self):
        """Create the spike selector backed by the current clustering."""
        def spikes_per_cluster(cluster_id):
            return self.supervisor.clustering.spikes_per_cluster[cluster_id]
        return Selector(spikes_per_cluster)

    def _add_view(self, gui, view):
        """Attach `view` to `gui`, emit `add_view` and return the view."""
        view.attach(gui)
        self.emit('add_view', gui, view)
        return view

    # Model methods
    # -------------------------------------------------------------------------

    def get_best_channel(self, cluster_id):
        """Return the channel where the cluster's mean waveform is minimal."""
        channel_ids = self.get_best_channels(cluster_id)
        amps = self._get_mean_waveforms(cluster_id).data[0].min(axis=0)
        channel_id = channel_ids[np.argmin(amps)]
        return channel_id

    def get_best_channels(self, cluster_ids):
        """Return the channels relevant to a cluster.

        TODO: currently the argument is ignored and every channel
        (`0 .. n_chans - 1`) is returned.
        """
        return np.arange(self.model.n_chans)

    def get_cluster_position(self, cluster_id):
        """Return the position of the cluster's best channel."""
        channel_id = self.get_best_channel(cluster_id)
        return self.model.channel_positions[channel_id]

    def get_probe_depth(self, cluster_id):
        """Return the second coordinate of the best channel's position."""
        channel_id = self.get_best_channel(cluster_id)
        return self.model.channel_positions[channel_id][1]

    def similarity(self, cluster_id):
        """Return the list of similar clusters to a given cluster."""
        pos_i = self.get_cluster_position(cluster_id)

        def _sim_ij(cj):
            """Distance between channel position of clusters i and j."""
            pos_j = self.get_cluster_position(cj)
            d = np.sqrt(np.sum((pos_j - pos_i) ** 2))
            # Closer clusters get a higher score.
            return self.distance_max - d

        out = [(cj, _sim_ij(cj))
               for cj in self.supervisor.clustering.cluster_ids]
        return sorted(out, key=itemgetter(1), reverse=True)

    # Waveforms
    # -------------------------------------------------------------------------

    def _get_waveforms(self, cluster_id):
        """Return a selection of waveforms for a cluster."""
        pos = self.model.channel_positions
        spike_ids = self.selector.select_spikes(
            [cluster_id],
            self.n_spikes_waveforms,
            self.batch_size_waveforms,
        )
        channel_ids = self.get_best_channels(cluster_id)
        data = self.model.get_waveforms(spike_ids, channel_ids)
        return Bunch(data=data,
                     channel_ids=channel_ids,
                     channel_positions=pos[channel_ids],
                     alpha=0.25)

    def _get_mean_waveforms(self, cluster_id):
        """Return the mean waveform, keeping a leading axis of length 1."""
        b = self._get_waveforms(cluster_id)
        b.data = b.data.mean(axis=0)[np.newaxis, ...]
        b['alpha'] = 1.
        return b

    def add_waveform_view(self, gui):
        """Add a waveform view, with an action toggling mean waveforms."""
        v = WaveformView(waveforms=self._get_waveforms,
                         )
        v = self._add_view(gui, v)

        v.actions.separator()

        @v.actions.add(shortcut='m')
        def toggle_mean_waveforms():
            """Toggle between individual and mean waveforms."""
            f, g = self._get_waveforms, self._get_mean_waveforms
            v.waveforms = f if v.waveforms == g else g
            v.on_select()

        return v

    # Features
    # -------------------------------------------------------------------------

    def _get_spike_ids(self, cluster_id=None, load_all=None):
        """Return the spike ids to display.

        With no cluster, return a regular subsample of all spikes with at
        most about `n_spikes_features` spikes, used as background points.
        With a cluster, select up to `n_spikes_features` of its spikes, or
        all of them if `load_all` is true.
        """
        nsf = self.n_spikes_features
        if cluster_id is None:
            # Background points.
            ns = self.model.n_spikes
            return np.arange(0, ns, max(1, ns // nsf))
        else:
            # Load all spikes from the cluster if load_all is True.
            n = nsf if not load_all else None
            return self.selector.select_spikes([cluster_id], n)

    def _get_spike_times(self, cluster_id=None, load_all=None):
        """Return the spike times matching `_get_features()`'s spikes."""
        # Forward `load_all` so that the time attribute uses the same spike
        # subset as the features it is plotted against.
        spike_ids = self._get_spike_ids(cluster_id, load_all=load_all)
        return Bunch(data=self.model.spike_times[spike_ids],
                     lim=(0., self.model.duration))

    def _get_features(self, cluster_id=None, channel_ids=None, load_all=None):
        """Return the features and masks of the selected spikes."""
        spike_ids = self._get_spike_ids(cluster_id, load_all=load_all)
        # Use the best channels only if a cluster is specified and
        # channels are not specified.
        # NOTE(review): with both cluster_id and channel_ids None, the
        # `[:, None]` indexing below inserts a new axis — confirm callers
        # always pass channel_ids for background points.
        if cluster_id is not None and channel_ids is None:
            channel_ids = self.get_best_channels(cluster_id)
        f = self.model.features[spike_ids][:, channel_ids]
        m = self.model.masks[spike_ids][:, channel_ids]
        return Bunch(
            data=f,
            masks=m,
            spike_ids=spike_ids,
            channel_ids=channel_ids,
        )

    def add_feature_view(self, gui):
        """Add a feature view with spike time as an extra attribute."""
        v = FeatureView(features=self._get_features,
                        attributes={'time': self._get_spike_times})
        return self._add_view(gui, v)

    # Correlograms
    # -------------------------------------------------------------------------

    def _get_correlograms(self, cluster_ids, bin_size, window_size):
        """Compute correlograms on a random subset of the clusters' spikes."""
        spike_ids = self.selector.select_spikes(
            cluster_ids,
            self.n_spikes_correlograms,
            subset='random',
        )
        st = self.model.spike_times[spike_ids]
        sc = self.supervisor.clustering.spike_clusters[spike_ids]
        return correlograms(
            st,
            sc,
            sample_rate=self.model.sample_rate,
            cluster_ids=cluster_ids,
            bin_size=bin_size,
            window_size=window_size,
        )

    def add_correlogram_view(self, gui):
        """Add a correlogram view."""
        m = self.model
        v = CorrelogramView(
            correlograms=self._get_correlograms,
            sample_rate=m.sample_rate,
        )
        return self._add_view(gui, v)

    # Amplitudes
    # -------------------------------------------------------------------------

    def _get_amplitudes(self, cluster_id):
        """Return spike times vs. amplitudes on the cluster's best channel."""
        n = self.n_spikes_amplitudes
        m = self.model
        spike_ids = self.selector.select_spikes([cluster_id], n)
        channel_id = self.get_best_channel(cluster_id)
        x = m.spike_times[spike_ids]
        y = m.amplitudes[spike_ids, channel_id]
        return Bunch(x=x, y=y, data_bounds=(0., y.min(), m.duration, y.max()))

    def add_amplitude_view(self, gui):
        """Add an amplitude view."""
        v = AmplitudeView(coords=self._get_amplitudes,
                          )
        return self._add_view(gui, v)

    # GUI
    # -------------------------------------------------------------------------

    def create_gui(self, **kwargs):
        """Create the GUI and add every view the model has data for.

        Feature and amplitude views are added only when the model provides
        features and amplitudes, respectively. Emits `gui_ready(gui)`.
        """
        gui = GUI(name=self.gui_name,
                  subtitle=self.model.data_path,
                  config_dir=self.config_dir,
                  **kwargs)

        self.supervisor.attach(gui)

        self.add_waveform_view(gui)
        if self.model.features is not None:
            self.add_feature_view(gui)
        self.add_correlogram_view(gui)
        if self.model.amplitudes is not None:
            self.add_amplitude_view(gui)

        # Save the memcache when closing the GUI.
        @gui.connect_
        def on_close():
            self.context.save_memcache()

        self.emit('gui_ready', gui)

        return gui