def __init__(self, kwik_path, config_dir=None, **kwargs):
    """Open a Kwik file and build the controller state around it.

    Parameters
    ----------
    kwik_path : str
        Path to the ``.kwik`` file. It is resolved with ``op.realpath`` and
        backed up with ``_backup`` before being opened.
    config_dir : str or None
        Stored on the instance and forwarded to ``attach_plugins``.
    **kwargs
        Forwarded unchanged to ``KwikModel``. A ``channel_group`` entry also
        selects a cache sub-directory; a ``plugins`` entry is handed to
        ``attach_plugins``.
    """
    super(KwikController, self).__init__()
    resolved = op.realpath(kwik_path)
    _backup(resolved)

    model = KwikModel(resolved, **kwargs)
    self.model = model
    positions = model.channel_positions
    # Channel indices ordered by the second coordinate (vertical position).
    self.channel_vertical_order = np.argsort(positions[:, 1])
    self.distance_max = _get_distance_max(positions)

    # Cache lives in a '.phy' folder next to the file, with one
    # sub-directory per channel group when a group was requested.
    group = kwargs.get('channel_group', None)
    base_cache = op.join(op.dirname(resolved), '.phy')
    if group is None:
        self.cache_dir = base_cache
    else:
        self.cache_dir = op.join(base_cache, str(group))
    self.context = Context(self.cache_dir)
    self.config_dir = config_dir

    self._set_cache()
    self.supervisor = self._set_supervisor()
    self.selector = self._set_selector()
    self.color_selector = ColorSelector()
    self._show_all_spikes = False

    attach_plugins(self,
                   plugins=kwargs.get('plugins', None),
                   config_dir=config_dir)
def set_kwik_file(self, kpath):
    """!
    @brief Attaches the kwik file found at kpath to this object.

    A new KwikModel is opened on kpath; the object's name is taken from
    that model and kpath is remembered.

    Author: Nivaldo A P de Vasconcelos
    Date: 2018.Feb.02
    """
    model = KwikModel(kpath)
    self.kwik_model = model
    self.name = model.name
    self.kpath = kpath
def __init__(self, kpath=None, name=None):
    """!
    @brief Creates the object, optionally opening a kwik file.

    Parameters:
        kpath: path of the kwik file. When None, no KwikModel is opened.
        name: name of the object. When None and kpath is given, the name
            stored in the kwik file is used instead.
    """
    self.kwik_model = None
    self.name = name
    self.kpath = None
    # Guard clause: nothing else to set up without a path.
    if kpath is None:
        print("It still with no path:(")
        return
    self.kwik_model = KwikModel(kpath)
    self.kpath = kpath
    if name is None:
        self.name = self.kwik_model.name
    print("Created class on = %s !" % kpath)
def __init__(self, path, channel_group=None, clustering=None):
    """Open a Kwik file for the GUI controller.

    Parameters
    ----------
    path : str
        Path to the Kwik file. ``~`` is expanded and the path is resolved
        before a backup is made with ``_backup``.
    channel_group : optional
        Forwarded to ``KwikModel``; also names the cache sub-directory
        (``'default'`` when falsy -- NOTE(review): this includes 0).
    clustering : str or None
        Forwarded to ``KwikModel``; also names the cache sub-directory
        (``'main'`` when falsy).
    """
    path = op.realpath(op.expanduser(path))
    _backup(path)
    self.path = path
    self.cache_dir = op.join(op.dirname(path), '.phy',
                             str(clustering or 'main'),
                             str(channel_group or 'default'),
                             )
    # Forward the requested clustering: the cache directory above is keyed
    # on it, so always loading the default clustering would mix cached data
    # from one clustering with the spikes of another.
    self.model = KwikModel(path,
                           channel_group=channel_group,
                           clustering=clustering,
                           )
    # Called last; presumably the base initializer relies on self.model and
    # self.cache_dir -- confirm before reordering.
    super(KwikController, self).__init__()
def describe(path, channel_group=0, clustering='main'):
    """Describe a Kwik file.

    Opens ``path`` with ``KwikModel`` for the given channel group and
    clustering, then calls the model's ``describe`` method.
    """
    model = KwikModel(path,
                      channel_group=channel_group,
                      clustering=clustering)
    model.describe()
def autoAnalysis(kwik_fullpath, brush_sample_rate=20000):
    """Plot the pooled spike raster of a Kwik file above the Matlab brush signal.

    Spike samples and cluster labels of every channel group of the Kwik file
    are pooled into one raster (cluster label vs. time in seconds). The first
    value change on Intan digital input ``'0'`` is passed to
    ``loadMatlabFile``, and the brush signal it returns is plotted below the
    raster. The figure is saved as ``brushSpikes.png`` in the directory that
    contains the Kwik file.

    Parameters
    ----------
    kwik_fullpath : str
        Path to the Kwik file.
    brush_sample_rate : float, optional
        Sampling rate (Hz) of the Matlab brush signal, used to build its time
        axis. Defaults to 20000, the previously hard-coded value.
    """
    model = KwikModel(kwik_fullpath)
    kwikPath = os.path.dirname(kwik_fullpath)
    try:
        sample_chunks = []
        cluster_chunks = []
        # Walk through the channel groups (those on different shanks),
        # collecting each group's spike samples and cluster labels.
        for j in model.channel_groups:
            model.channel_group = j
            sample_chunks.append(np.asarray(model.spike_samples))
            cluster_chunks.append(np.asarray(model.spike_clusters))
        if sample_chunks:
            allsamples = np.concatenate(sample_chunks)
            allspikes = np.concatenate(cluster_chunks)
        else:
            allsamples = np.array([])
            allspikes = np.array([])
        # float() keeps this a true division even on Python 2.
        alltimes = allsamples / float(model.sample_rate)
    finally:
        # np.concatenate returned fresh arrays, so the file can be released.
        model.close()

    di, ai = liic.load_intan_input_channels()
    intan_trigger = di['0'][:]
    # Indices where the trigger channel changes value.
    intan_transitions = np.where(intan_trigger[:-1] != intan_trigger[1:])[0]
    # NOTE(review): raises IndexError when the trigger never changes value.
    adjustmentAmount = intan_transitions[0]
    # Presumably loadMatlabFile uses this offset to align the Matlab signals
    # with the Intan recording -- confirm.
    _matlab_trigger, _matlab_triggerCMD, matlab_brush = \
        loadMatlabFile(adjustmentAmount)

    plt.figure(figsize=(15, 10))
    plt.subplot(211)
    plt.plot(alltimes, allspikes, '|', color=[.5, .5, .5], markersize=5, mew=.4)
    plt.ylabel('unit')
    plt.subplot(212)
    # One time stamp per sample: a float-step np.arange can yield one extra
    # element and make plot() fail on mismatched lengths.
    brush_times = np.arange(len(matlab_brush)) / float(brush_sample_rate)
    plt.plot(brush_times, matlab_brush, color=[.5, .5, .5], linewidth=.5)
    # os.path.join instead of a hard-coded '\\' so the file lands in the
    # Kwik file's directory on every OS.
    plt.savefig(os.path.join(kwikPath, 'brushSpikes.png'))
def __init__(self, path, channel_group=None, clustering=None):
    """Open a Kwik file for the GUI controller.

    Parameters
    ----------
    path : str
        Path to the Kwik file. ``~`` is expanded and the path is resolved
        before a backup is made with ``_backup``.
    channel_group : optional
        Forwarded to ``KwikModel``; also names the cache sub-directory
        (``0`` when falsy).
    clustering : str or None
        Forwarded to ``KwikModel``; also names the cache sub-directory
        (``'main'`` when falsy).
    """
    path = op.realpath(op.expanduser(path))
    _backup(path)
    self.path = path
    # The cache directory depends on the filename, clustering, and
    # channel_group to avoid conflicts.
    self.cache_dir = op.join(op.dirname(path), '.phy',
                             op.splitext(op.basename(path))[0],
                             str(clustering or 'main'),
                             str(channel_group or 0),
                             )
    # Forward the requested clustering so the loaded data matches the
    # clustering the cache directory is keyed on.
    self.model = KwikModel(path,
                           channel_group=channel_group,
                           clustering=clustering,
                           )
    # Called last; presumably the base initializer relies on self.model and
    # self.cache_dir -- confirm before reordering.
    super(KwikController, self).__init__()
class KwikFile:
    """!
    @brief Model for kwik files, strongly based on KwikModel from the phy project.

    The main purpose of this class is to provide an abstraction for kwik
    files produced with the phy project. The current version contains a
    basic set of fundamental methods used on kwik files.

    @author: Nivaldo A P de Vasconcelos
    @date: 2018.Feb.02
    """

    def __init__(self, kpath=None, name=None):
        """!
        @brief Creates the object, optionally opening a kwik file.

        Parameters:
            kpath: path of the kwik file. When None, no KwikModel is opened
                and set_kwik_file can be used later.
            name: name of the object. When None and kpath is given, the name
                stored in the kwik file is used.
        """
        self.kwik_model = None
        self.name = name
        self.kpath = None
        if kpath is not None:
            self.kwik_model = KwikModel(kpath)
            self.kpath = kpath
            if name is None:
                self.name = self.kwik_model.name
            print("Created class on = %s !" % kpath)
        else:
            print("It still with no path:(")

    def get_name(self):
        """!
        @brief Returns the name of this object (by default, the name field
        found in the kwik file).

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        return self.name

    def set_kwik_file(self, kpath):
        """!
        @brief Defines the corresponding kwik file.

        Opens a KwikModel on kpath and takes the object's name from it.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        self.kwik_model = KwikModel(kpath)
        self.name = self.kwik_model.name
        self.kpath = kpath

    def sampling_rate(self):
        """!
        @brief Returns the sampling rate used during the recordings.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        return self.kwik_model.sample_rate

    def shank(self):
        """!
        @brief Returns the shank/population's id used to group the recordings.

        NOTE(review): this returns kwik_model.name, the same field used as
        the default name of the object; confirm it holds the shank id in
        your kwik files.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        return self.kwik_model.name

    def get_spike_samples(self):
        """!
        @brief Returns the spike samples of the recordings.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        return self.kwik_model.spike_samples

    def get_spike_clusters(self):
        """!
        @brief Returns the cluster label of each spike of the recordings.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        return self.kwik_model.spike_clusters

    def describe(self):
        """!
        @brief Describes the kwik file by calling KwikModel.describe.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        self.kwik_model.describe()

    def close(self):
        """!
        @brief Closes the corresponding kwik model (calls KwikModel.close).

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        self.kwik_model.close()

    # Validation helpers shared by the query methods below.

    def _check_group_list(self, group_names):
        """Raises ValueError unless group_names is a list."""
        if not isinstance(group_names, list):
            raise ValueError("\nThe argument must be a list.")

    def _check_group_name(self, group_name, known_groups):
        """Raises ValueError when group_name is not one of known_groups."""
        if group_name not in known_groups:
            raise ValueError(
                "\nThis group was not found in kwik file: %s\n" % group_name)

    def _check_cluster(self, cluster_id):
        """Raises ValueError when cluster_id is not a cluster of the file."""
        if cluster_id not in self.all_clusters():
            raise ValueError(
                "\nThis cluster was not found in kwik file: %s\n" % cluster_id)

    def _check_interval(self, a, b):
        """Raises ValueError unless the interval [a, b] has positive length."""
        if a == b:
            raise ValueError("\nThe limits of the time interval are equal\n")
        if b < a:
            raise ValueError(
                "\nThe upper limit b must be greater than the lower limit a\n")

    def list_of_groups(self):
        """!
        @brief Returns the list of distinct groups found in the kwik file.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        return list(set(self.groups().values()))

    def list_of_non_noisy_groups(self):
        """!
        @brief Returns the list of groups found in the kwik file, except
        the one called 'noise'.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        return list(set(self.groups().values()) - {'noise'})

    def all_clusters(self):
        """!
        @brief Returns the list of all cluster labels in the kwik file.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        return list(set(self.groups().keys()))

    def groups(self):
        """!
        @brief Returns a dict mapping each cluster label to its group.

        Raises ValueError when no KwikModel is assigned to this object.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        if not isinstance(self.kwik_model, KwikModel):
            raise ValueError("There is no KwikModel assigned for this object.")
        return self.kwik_model.cluster_groups

    def clusters(self, group_name=None):
        """!
        @brief Returns the list of clusters in the kwik file.

        When group_name is given, only the clusters of that group are
        returned, sorted. Raises ValueError for an unknown group.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        if group_name is None:
            return self.all_clusters()
        self._check_group_name(group_name, self.list_of_groups())
        group = self.groups()
        return sorted(c for c in self.all_clusters() if group[c] == group_name)

    def all_spikes_on_groups(self, group_names):
        """!
        @brief Returns all spike samples within a list of groups.

        Usually the clusters are organized in groups, e.g. noise, mua, sua,
        unsorted. This method returns, in a single sorted list, all spike
        samples found in the given groups.

        Parameters:
            group_names: list of group names where the spikes are searched.

        Raises ValueError when group_names is not a list or contains an
        unknown group.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        self._check_group_list(group_names)
        all_spikes = self.get_spike_samples()
        all_labels = self.get_spike_clusters()
        known_groups = self.list_of_groups()
        spikes = []
        for group_name in group_names:
            self._check_group_name(group_name, known_groups)
            for c in self.clusters(group_name=group_name):
                # extend() keeps this linear (list + list was quadratic).
                spikes.extend(all_spikes[all_labels == c])
        spikes.sort()
        return spikes

    def all_spike_id_on_groups(self, group_names):
        """!
        @brief Returns all spike ids within a list of groups.

        Usually the clusters are organized in groups, e.g. noise, mua, sua,
        unsorted. This method returns, in a single sorted list, the ids of
        all spikes found in the given groups.

        Parameters:
            group_names: list of group names where the spike ids are searched.

        Raises ValueError when group_names is not a list or contains an
        unknown group.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Jun.05
        """
        self._check_group_list(group_names)
        all_spk_id = self.kwik_model.spike_ids
        all_labels = self.get_spike_clusters()
        known_groups = self.list_of_groups()
        spk_id = []
        for group_name in group_names:
            self._check_group_name(group_name, known_groups)
            for c in self.clusters(group_name=group_name):
                spk_id.extend(all_spk_id[all_labels == c])
        spk_id.sort()
        return spk_id

    def all_spike_id_on_cluster(self, cluster_id):
        """!
        @brief Returns the sorted spike ids of a single cluster.

        Raises ValueError when cluster_id is not found in the kwik file.
        """
        self._check_cluster(cluster_id)
        all_spk_id = self.kwik_model.spike_ids
        all_labels = self.get_spike_clusters()
        return sorted(all_spk_id[all_labels == cluster_id])

    def group_firing_rate(self, group_names=None, a=None, b=None):
        """!
        @brief Returns the firing rate of each cluster in a set of groups.

        Usually the clusters are organized in groups, e.g. noise, mua, sua,
        unsorted. This method returns a nested dictionary: the first keys
        are groups, the second keys are the cluster ids of that group, and
        the values are the corresponding firing rates within [a, b].

        Parameters:
            group_names: list of group names where the spikes are searched.
                When None, every group except 'noise' is taken.
            a, b: limits of the time period, see cluster_firing_rate.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        if not isinstance(group_names, list) and group_names is not None:
            raise ValueError("\nThe argument must be a list or a None.")
        if group_names is None:
            group_names = self.list_of_non_noisy_groups()
        known_groups = self.list_of_groups()
        spk = dict()
        for group_name in group_names:
            self._check_group_name(group_name, known_groups)
            spk[group_name] = dict(
                (c, self.cluster_firing_rate(c, a=a, b=b))
                for c in self.clusters(group_name=group_name))
        return spk

    def cluster_firing_rate(self, cluster_id, a=None, b=None):
        """!
        @brief Returns the firing rate of a given cluster in the kwik file.

        In the kwik file, a cluster stores the spike times sorted for a
        given neuronal unit. The firing rate is the number of spikes within
        [a, b] divided by the length of that period in seconds. If a is None
        it is set to zero; if b is None it is set to the time of the last
        spike of the cluster.

        Parameters:
            cluster_id: id which identifies the cluster.
            a, b: limits (in seconds) of the time period where the firing
                rate is calculated.

        Raises ValueError when the interval is empty (a == b) or reversed
        (b < a), or when the cluster is not found.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        if a is None:
            a = 0
        # Fail fast on an explicit bad interval before loading any spikes.
        if b is not None:
            self._check_interval(a, b)
        sr = self.sampling_rate()
        spikes = np.array(self.spikes_on_cluster(cluster_id)) / sr
        if b is None:
            b = spikes[-1]
            self._check_interval(a, b)
        n_spikes = np.count_nonzero((spikes >= a) & (spikes <= b))
        # float() keeps this a true division even on Python 2.
        return float(n_spikes) / (b - a)

    def spikes_on_cluster(self, cluster_id):
        """!
        @brief Returns the sorted spike samples of a single cluster.

        Parameters:
            cluster_id: id used to identify the cluster.

        Raises ValueError when cluster_id is not found in the kwik file.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        self._check_cluster(cluster_id)
        all_spikes = self.get_spike_samples()
        all_labels = self.get_spike_clusters()
        return sorted(all_spikes[all_labels == cluster_id])

    def group_firing_rate_to_dataframe(self, group_names=None, a=None, b=None):
        """!
        @brief Exports the groups' firing rates into a pandas DataFrame.

        The DataFrame has one row per cluster with the columns 'shank_id'
        (the object's name), 'group', 'label' (cluster id) and 'fr' (firing
        rate within [a, b]).

        Parameters:
            group_names: list of group names where the spikes are searched.
                When None, every group except 'noise' is taken.
            a, b: limits of the time period, see cluster_firing_rate.

        Author: Nivaldo A P de Vasconcelos
        Date: 2018.Feb.02
        """
        rates = self.group_firing_rate(group_names=group_names, a=a, b=b)
        shank_id = self.name
        data = []
        for group_name, per_cluster in rates.items():
            for label, fr in per_cluster.items():
                data.append({
                    "shank_id": shank_id,
                    "group": group_name,
                    "label": label,
                    "fr": fr,
                })
        return pd.DataFrame(data)
class KwikController(Controller):
    """Kwik-file controller built on ``Controller``.

    ``_init_data`` copies the model data onto the instance (presumably read
    by the ``Controller`` base class) and ``create_gui`` adds a KlustaKwik
    ``recluster`` action and a save handler to the GUI.
    """
    gui_name = 'KwikGUI'

    def __init__(self, path, channel_group=None, clustering=None):
        """Open the Kwik file at ``path`` (backed up first).

        ``channel_group`` and ``clustering`` are forwarded to ``KwikModel``
        and also name the cache sub-directories (``'default'`` / ``'main'``
        when falsy).
        """
        path = op.realpath(op.expanduser(path))
        _backup(path)
        self.path = path
        self.cache_dir = op.join(
            op.dirname(path),
            '.phy',
            str(clustering or 'main'),
            str(channel_group or 'default'),
        )
        # Forward the requested clustering so the loaded data matches the
        # clustering the cache directory is keyed on.
        self.model = KwikModel(
            path,
            channel_group=channel_group,
            clustering=clustering,
        )
        # Called last; presumably the base initializer relies on self.model
        # and self.cache_dir -- confirm before reordering.
        super(KwikController, self).__init__()

    def _init_data(self):
        """Copy the model's data onto the controller."""
        m = self.model
        self.spike_times = m.spike_times
        self.spike_clusters = m.spike_clusters
        self.cluster_groups = m.cluster_groups
        self.cluster_ids = m.cluster_ids
        self.channel_positions = m.channel_positions
        self.n_samples_waveforms = m.n_samples_waveforms
        self.n_channels = m.n_channels
        self.n_features_per_channel = m.n_features_per_channel
        self.sample_rate = m.sample_rate
        self.duration = m.duration
        self.all_masks = m.all_masks
        self.all_waveforms = m.all_waveforms
        self.all_features = m.all_features
        # WARNING: m.all_traces contains the dead channels, m.traces doesn't.
        # Also, m.traces has the reordered channels as per the prb file.
        self.all_traces = m.traces

    def create_gui(self, config_dir=None):
        """Create the kwik GUI."""
        f = super(KwikController, self).create_gui
        gui = f(
            name=self.gui_name,
            subtitle=self.path,
            config_dir=config_dir,
        )

        @self.manual_clustering.actions.add
        def recluster():
            """Relaunch KlustaKwik on the selected clusters."""
            # Selected clusters.
            cluster_ids = self.manual_clustering.selected
            spike_ids = self.selector.select_spikes(cluster_ids)
            logger.info("Running KlustaKwik on %d spikes.", len(spike_ids))
            spike_clusters, metadata = cluster(
                self.model,
                spike_ids,
                num_starting_clusters=10,
            )
            self.manual_clustering.split(spike_ids, spike_clusters)

        # Save.
        @gui.connect_
        def on_request_save(spike_clusters, groups):
            # Group names are title-cased before being written to the file.
            groups = {c: g.title() for c, g in groups.items()}
            self.model.save(spike_clusters, groups)

        return gui
def kwik_describe(path, channel_group=None, clustering=None):
    """Describe a Kwik dataset.

    Opens ``path`` with ``KwikModel`` (forwarding ``channel_group`` and
    ``clustering``) and calls the model's ``describe`` method.

    Raises
    ------
    ValueError
        If ``path`` is empty or None.
    """
    # Explicit check instead of ``assert``, which is stripped under -O.
    if not path:
        raise ValueError("A path to a Kwik file is required.")
    KwikModel(path,
              channel_group=channel_group,
              clustering=clustering).describe()
class KwikController(Controller):
    """Kwik-file controller built on ``Controller``.

    ``_init_data`` copies the model data onto the instance (presumably read
    by the ``Controller`` base class) and ``create_gui`` adds a KlustaKwik
    ``recluster`` action and a save handler to the GUI.
    """
    gui_name = 'KwikGUI'

    def __init__(self, path, channel_group=None, clustering=None):
        """Open the Kwik file at ``path`` (backed up first).

        ``channel_group`` and ``clustering`` are forwarded to ``KwikModel``
        and also name the cache sub-directories (``0`` / ``'main'`` when
        falsy).
        """
        path = op.realpath(op.expanduser(path))
        _backup(path)
        self.path = path
        # The cache directory depends on the filename, clustering, and
        # channel_group to avoid conflicts.
        self.cache_dir = op.join(op.dirname(path), '.phy',
                                 op.splitext(op.basename(path))[0],
                                 str(clustering or 'main'),
                                 str(channel_group or 0),
                                 )
        # Forward the requested clustering so the loaded data matches the
        # clustering the cache directory is keyed on.
        self.model = KwikModel(path,
                               channel_group=channel_group,
                               clustering=clustering,
                               )
        # Called last; presumably the base initializer relies on self.model
        # and self.cache_dir -- confirm before reordering.
        super(KwikController, self).__init__()

    def _init_data(self):
        """Copy the model's data onto the controller."""
        m = self.model
        self.spike_times = m.spike_times
        self.spike_clusters = m.spike_clusters
        self.cluster_groups = m.cluster_groups
        self.cluster_ids = m.cluster_ids
        self.channel_positions = m.channel_positions
        self.channel_order = m.channel_order
        self.n_samples_waveforms = m.n_samples_waveforms
        self.n_channels = m.n_channels
        self.n_features_per_channel = m.n_features_per_channel
        self.sample_rate = m.sample_rate
        self.duration = m.duration
        self.all_masks = m.all_masks
        self.all_waveforms = m.all_waveforms
        self.all_features = m.all_features
        # WARNING: m.all_traces contains the dead channels, m.traces doesn't.
        # Also, m.traces has the reordered channels as per the prb file.
        self.all_traces = m.traces

    def create_gui(self, config_dir=None):
        """Create the kwik GUI."""
        f = super(KwikController, self).create_gui
        gui = f(name=self.gui_name,
                subtitle=self.path,
                config_dir=config_dir,
                )

        @self.manual_clustering.actions.add
        def recluster():
            """Relaunch KlustaKwik on the selected clusters."""
            # Selected clusters.
            cluster_ids = self.manual_clustering.selected
            spike_ids = self.selector.select_spikes(cluster_ids)
            logger.info("Running KlustaKwik on %d spikes.", len(spike_ids))

            # Run KK2 in a temporary directory to avoid side effects.
            with TemporaryDirectory() as tempdir:
                spike_clusters, metadata = cluster(self.model,
                                                   spike_ids,
                                                   num_starting_clusters=10,
                                                   tempdir=tempdir,
                                                   )
            self.manual_clustering.split(spike_ids, spike_clusters)

        # Save.
        @gui.connect_
        def on_request_save(spike_clusters, groups):
            # Group names are title-cased before being written to the file.
            groups = {c: g.title() for c, g in groups.items()}
            self.model.save(spike_clusters, groups)

        return gui
class KwikController(EventEmitter):
    """Controller connecting a ``KwikModel`` to the supervisor and views.

    Emits ``'add_view'`` (with ``gui, view``) each time a view is attached
    and ``'gui_ready'`` (with ``gui``) at the end of ``create_gui``.
    """
    gui_name = 'KwikGUI'

    # Subsampling parameters. n_spikes_amplitudes, n_spikes_close_clusters
    # and n_closest_channels are not read in this class (presumably used by
    # plugins or other code); they are kept for compatibility.
    n_spikes_waveforms = 100
    batch_size_waveforms = 10
    n_spikes_features = 10000
    n_spikes_amplitudes = 10000
    n_spikes_close_clusters = 100
    n_closest_channels = 16
    # Maximum number of spikes used when computing correlograms.
    n_spikes_correlograms = 100000

    def __init__(self, kwik_path, config_dir=None, **kwargs):
        """Open ``kwik_path`` and set up cache, supervisor and selector.

        ``kwargs`` are forwarded to ``KwikModel``; ``channel_group`` (if
        given) also selects a cache sub-directory and ``plugins`` (if given)
        is passed to ``attach_plugins``.
        """
        super(KwikController, self).__init__()
        kwik_path = op.realpath(kwik_path)
        _backup(kwik_path)
        self.model = KwikModel(kwik_path, **kwargs)
        m = self.model
        # Channel indices ordered by their vertical coordinate (column 1).
        self.channel_vertical_order = np.argsort(m.channel_positions[:, 1])
        self.distance_max = _get_distance_max(self.model.channel_positions)
        self.cache_dir = op.join(op.dirname(kwik_path), '.phy')
        cg = kwargs.get('channel_group', None)
        if cg is not None:
            self.cache_dir = op.join(self.cache_dir, str(cg))
        self.context = Context(self.cache_dir)
        self.config_dir = config_dir

        self._set_cache()
        self.supervisor = self._set_supervisor()
        self.selector = self._set_selector()
        self.color_selector = ColorSelector()

        self._show_all_spikes = False

        attach_plugins(self, plugins=kwargs.get('plugins', None),
                       config_dir=config_dir)

    # Internal methods
    # -------------------------------------------------------------------------

    def _set_cache(self):
        """Wrap the expensive per-cluster methods with ``_cache_methods``."""
        memcached = ('get_best_channels',
                     'get_probe_depth',
                     '_get_mean_masks',
                     '_get_mean_waveforms',
                     )
        cached = ('_get_waveforms',
                  '_get_features',
                  '_get_masks',
                  )
        _cache_methods(self, memcached, cached)

    def _set_supervisor(self):
        """Create the Supervisor with its columns, recluster action and save handler."""
        # Load the new cluster id.
        new_cluster_id = self.context.load('new_cluster_id'). \
            get('new_cluster_id', None)
        cluster_groups = self.model.cluster_groups
        supervisor = Supervisor(self.model.spike_clusters,
                                similarity=self.similarity,
                                cluster_groups=cluster_groups,
                                new_cluster_id=new_cluster_id,
                                context=self.context,
                                )

        @supervisor.connect
        def on_create_cluster_views():
            supervisor.add_column(self.get_best_channel, name='channel')
            supervisor.add_column(self.get_probe_depth, name='depth')

        @supervisor.actions.add
        def recluster():
            """Relaunch KlustaKwik on the selected clusters."""
            # Selected clusters.
            cluster_ids = supervisor.selected
            spike_ids = self.selector.select_spikes(cluster_ids)
            logger.info("Running KlustaKwik on %d spikes.", len(spike_ids))

            # Run KK2 in a temporary directory to avoid side effects.
            n = 10
            with TemporaryDirectory() as tempdir:
                spike_clusters, metadata = cluster(self.model,
                                                   spike_ids,
                                                   num_starting_clusters=n,
                                                   tempdir=tempdir,
                                                   )
            self.supervisor.split(spike_ids, spike_clusters)

        # Save.
        @supervisor.connect
        def on_request_save(spike_clusters, groups, *labels):
            """Save the modified data."""
            groups = {c: g.title() for c, g in groups.items()}
            self.model.save(spike_clusters, groups)

        return supervisor

    def _set_selector(self):
        """Create the spike Selector backed by the supervisor's clustering."""
        def spikes_per_cluster(cluster_id):
            return self.supervisor.clustering.spikes_per_cluster[cluster_id]
        return Selector(spikes_per_cluster)

    def _add_view(self, gui, view):
        """Attach ``view`` to ``gui`` and emit ``'add_view'``."""
        view.attach(gui)
        self.emit('add_view', gui, view)
        return view

    # Model methods
    # -------------------------------------------------------------------------

    def get_best_channel(self, cluster_id):
        """Return the first of the cluster's best channels."""
        return self.get_best_channels(cluster_id)[0]

    def get_best_channels(self, cluster_id):
        """Return channel ids sorted by decreasing mean mask of the cluster.

        Only channels whose mean mask exceeds 0.1 are kept; if none does,
        the first four channels of that ordering are returned.
        """
        mm = self._get_mean_masks(cluster_id)
        channel_ids = np.argsort(mm)[::-1]
        ind = mm[channel_ids] > .1
        if np.sum(ind) > 0:
            channel_ids = channel_ids[ind]
        else:
            channel_ids = channel_ids[:4]
        return channel_ids

    def get_cluster_position(self, cluster_id):
        """Return the position of the cluster's best channel."""
        channel_id = self.get_best_channel(cluster_id)
        return self.model.channel_positions[channel_id]

    def get_probe_depth(self, cluster_id):
        """Return the second coordinate of the cluster position."""
        return self.get_cluster_position(cluster_id)[1]

    def similarity(self, cluster_id):
        """Return the list of similar clusters to a given cluster."""

        pos_i = self.get_cluster_position(cluster_id)
        assert len(pos_i) == 2

        def _sim_ij(cj):
            """Distance between channel position of clusters i and j."""
            pos_j = self.get_cluster_position(cj)
            assert len(pos_j) == 2
            d = np.sqrt(np.sum((pos_j - pos_i) ** 2))
            return self.distance_max - d

        out = [(cj, _sim_ij(cj))
               for cj in self.supervisor.clustering.cluster_ids]
        return sorted(out, key=itemgetter(1), reverse=True)

    # Waveforms
    # -------------------------------------------------------------------------

    def _get_masks(self, cluster_id):
        """Return the masks of a subsample of the cluster's spikes."""
        spike_ids = self.selector.select_spikes([cluster_id],
                                                self.n_spikes_waveforms,
                                                self.batch_size_waveforms,
                                                )
        return self.model.all_masks[spike_ids]

    def _get_mean_masks(self, cluster_id):
        """Return the per-channel mean of ``_get_masks``."""
        return np.mean(self._get_masks(cluster_id), axis=0)

    def _get_waveforms(self, cluster_id):
        """Return a selection of waveforms for a cluster."""
        pos = self.model.channel_positions
        spike_ids = self.selector.select_spikes([cluster_id],
                                                self.n_spikes_waveforms,
                                                self.batch_size_waveforms,
                                                )
        data = self.model.all_waveforms[spike_ids]
        mm = self._get_mean_masks(cluster_id)
        mw = np.mean(data, axis=0)
        amp = get_waveform_amplitude(mm, mw)
        masks = self._get_masks(cluster_id)
        # Find the best channels.
        channel_ids = np.argsort(amp)[::-1]
        return Bunch(data=data[..., channel_ids],
                     channel_ids=channel_ids,
                     channel_positions=pos[channel_ids],
                     masks=masks[:, channel_ids],
                     )

    def _get_mean_waveforms(self, cluster_id):
        """Return ``_get_waveforms`` averaged over spikes (opaque, alpha 1)."""
        b = self._get_waveforms(cluster_id).copy()
        b.data = np.mean(b.data, axis=0)[np.newaxis, ...]
        b.masks = np.mean(b.masks, axis=0)[np.newaxis, ...] ** .1
        b['alpha'] = 1.
        return b

    def add_waveform_view(self, gui):
        """Add a waveform view with a 'm' shortcut toggling mean waveforms."""
        v = WaveformView(waveforms=self._get_waveforms,
                         )
        v = self._add_view(gui, v)

        v.actions.separator()

        @v.actions.add(shortcut='m')
        def toggle_mean_waveforms():
            f, g = self._get_waveforms, self._get_mean_waveforms
            v.waveforms = f if v.waveforms == g else g
            v.on_select()

        return v

    # Features
    # -------------------------------------------------------------------------

    def _get_spike_ids(self, cluster_id=None, load_all=None):
        """Return spike ids: a regular background subsample when
        ``cluster_id`` is None, otherwise the cluster's spikes (capped at
        ``n_spikes_features`` unless ``load_all``)."""
        nsf = self.n_spikes_features
        if cluster_id is None:
            # Background points.
            ns = self.model.n_spikes
            return np.arange(0, ns, max(1, ns // nsf))
        else:
            # Load all spikes from the cluster if load_all is True.
            n = nsf if not load_all else None
            return self.selector.select_spikes([cluster_id], n)

    def _get_spike_times(self, cluster_id=None, load_all=None):
        """Return spike times of the selected spikes with the time limits."""
        spike_ids = self._get_spike_ids(cluster_id, load_all=load_all)
        return Bunch(data=self.model.spike_times[spike_ids],
                     lim=(0., self.model.duration))

    def _get_features(self, cluster_id=None, channel_ids=None, load_all=None):
        """Return features and masks of the selected spikes."""
        spike_ids = self._get_spike_ids(cluster_id, load_all=load_all)
        # Use the best channels only if a cluster is specified and
        # channels are not specified.
        if cluster_id is not None and channel_ids is None:
            channel_ids = self.get_best_channels(cluster_id)
        f = self.model.all_features[spike_ids][:, channel_ids]
        m = self.model.all_masks[spike_ids][:, channel_ids]
        return Bunch(data=f,
                     masks=m,
                     spike_ids=spike_ids,
                     channel_ids=channel_ids,
                     )

    def add_feature_view(self, gui):
        """Add a feature view with spike time as an extra attribute."""
        v = FeatureView(features=self._get_features,
                        attributes={'time': self._get_spike_times}
                        )
        return self._add_view(gui, v)

    # Traces
    # -------------------------------------------------------------------------

    def _get_traces(self, interval):
        """Get traces and spike waveforms."""
        ns = self.model.n_samples_waveforms
        m = self.model
        c = self.channel_vertical_order

        traces_interval = select_traces(m.traces, interval,
                                        sample_rate=m.sample_rate)
        # Reorder vertically.
        traces_interval = traces_interval[:, c]

        out = Bunch(data=traces_interval)
        out.waveforms = []
        for b in _iter_spike_waveforms(interval=interval,
                                       traces_interval=traces_interval,
                                       model=self.model,
                                       supervisor=self.supervisor,
                                       color_selector=self.color_selector,
                                       n_samples_waveforms=ns,
                                       get_best_channels=self.get_best_channels,
                                       show_all_spikes=self._show_all_spikes,
                                       ):
            b.channel_labels = m.channel_order[b.channel_ids]
            out.waveforms.append(b)
        return out

    def _jump_to_spike(self, view, delta=+1):
        """Jump to next or previous spike from the selected clusters."""
        m = self.model
        cluster_ids = self.supervisor.selected
        if len(cluster_ids) == 0:
            return
        spc = self.supervisor.clustering.spikes_per_cluster
        spike_ids = spc[cluster_ids[0]]
        # np.searchsorted assumes these times are sorted.
        spike_times = m.spike_times[spike_ids]
        n = len(spike_times)
        if n == 0:
            # Nothing to jump to (and avoids a modulo by zero below).
            return
        if delta > 0:
            # First spike strictly after the current time, then delta - 1
            # more; a left insertion point would skip the next spike when
            # the current time falls between two spikes.
            ind = np.searchsorted(spike_times, view.time, side='right') + delta - 1
        else:
            # The last spike strictly before the current time is at left - 1.
            ind = np.searchsorted(spike_times, view.time, side='left') + delta
        view.go_to(spike_times[ind % n])

    def add_trace_view(self, gui):
        """Add a trace view with spike navigation and highlighting actions."""
        m = self.model
        v = TraceView(traces=self._get_traces,
                      n_channels=m.n_channels,
                      sample_rate=m.sample_rate,
                      duration=m.duration,
                      channel_vertical_order=self.channel_vertical_order,
                      )
        self._add_view(gui, v)

        v.actions.separator()

        @v.actions.add(shortcut='alt+pgdown')
        def go_to_next_spike():
            """Jump to the next spike from the first selected cluster."""
            self._jump_to_spike(v, +1)

        @v.actions.add(shortcut='alt+pgup')
        def go_to_previous_spike():
            """Jump to the previous spike from the first selected cluster."""
            self._jump_to_spike(v, -1)

        v.actions.separator()

        @v.actions.add(shortcut='alt+s')
        def toggle_highlighted_spikes():
            """Toggle between showing all spikes or selected spikes."""
            self._show_all_spikes = not self._show_all_spikes
            v.set_interval(force_update=True)

        @gui.connect_
        def on_spike_click(channel_id=None, spike_id=None, cluster_id=None):
            # Select the corresponding cluster.
            self.supervisor.select([cluster_id])
            # Update the trace view.
            v.on_select([cluster_id], force_update=True)

        return v

    # Correlograms
    # -------------------------------------------------------------------------

    def _get_correlograms(self, cluster_ids, bin_size, window_size):
        """Compute correlograms on at most ``n_spikes_correlograms`` spikes."""
        spike_ids = self.selector.select_spikes(cluster_ids,
                                                self.n_spikes_correlograms)
        st = self.model.spike_times[spike_ids]
        sc = self.supervisor.clustering.spike_clusters[spike_ids]
        return correlograms(st,
                            sc,
                            sample_rate=self.model.sample_rate,
                            cluster_ids=cluster_ids,
                            bin_size=bin_size,
                            window_size=window_size,
                            )

    def add_correlogram_view(self, gui):
        """Add a correlogram view."""
        m = self.model
        v = CorrelogramView(correlograms=self._get_correlograms,
                            sample_rate=m.sample_rate,
                            )
        return self._add_view(gui, v)

    # GUI
    # -------------------------------------------------------------------------

    def create_gui(self, **kwargs):
        """Create the GUI with the supervisor and all views.

        The trace view is only added when the model has traces. Emits
        ``'gui_ready'`` once everything is attached.
        """
        gui = GUI(name=self.gui_name,
                  subtitle=self.model.kwik_path,
                  config_dir=self.config_dir,
                  **kwargs)

        self.supervisor.attach(gui)

        self.add_waveform_view(gui)
        if self.model.traces is not None:
            self.add_trace_view(gui)
        self.add_feature_view(gui)
        self.add_correlogram_view(gui)

        # Save the memcache when closing the GUI.
        @gui.connect_
        def on_close():
            self.context.save_memcache()

        self.emit('gui_ready', gui)

        return gui
class KwikController(Controller):
    """Kwik-file controller built on ``Controller``.

    ``_init_data`` copies the model data onto the instance (presumably read
    by the ``Controller`` base class) and ``create_gui`` adds a KlustaKwik
    ``recluster`` action and a save handler to the GUI.
    """
    gui_name = 'KwikGUI'

    def __init__(self, path, channel_group=None, clustering=None):
        """Open the Kwik file at ``path`` (backed up first).

        ``channel_group`` and ``clustering`` are forwarded to ``KwikModel``
        and also name the cache sub-directories (``'default'`` / ``'main'``
        when falsy).
        """
        path = op.realpath(op.expanduser(path))
        _backup(path)
        self.path = path
        self.cache_dir = op.join(op.dirname(path), '.phy',
                                 str(clustering or 'main'),
                                 str(channel_group or 'default'),
                                 )
        # Forward the requested clustering so the loaded data matches the
        # clustering the cache directory is keyed on.
        self.model = KwikModel(path,
                               channel_group=channel_group,
                               clustering=clustering,
                               )
        # Called last; presumably the base initializer relies on self.model
        # and self.cache_dir -- confirm before reordering.
        super(KwikController, self).__init__()

    def _init_data(self):
        """Copy the model's data onto the controller."""
        m = self.model
        self.spike_times = m.spike_times
        self.spike_clusters = m.spike_clusters
        self.cluster_groups = m.cluster_groups
        self.cluster_ids = m.cluster_ids
        self.channel_positions = m.channel_positions
        self.n_samples_waveforms = m.n_samples_waveforms
        self.n_channels = m.n_channels
        self.n_features_per_channel = m.n_features_per_channel
        self.sample_rate = m.sample_rate
        self.duration = m.duration
        self.all_masks = m.all_masks
        self.all_waveforms = m.all_waveforms
        self.all_features = m.all_features
        # m.all_traces contains the dead channels, m.traces doesn't, and
        # m.traces has the channels reordered as per the prb file -- so it is
        # the one consistent with channel_positions / n_channels above.
        self.all_traces = m.traces

    def create_gui(self, config_dir=None):
        """Create the kwik GUI."""
        f = super(KwikController, self).create_gui
        gui = f(name=self.gui_name,
                subtitle=self.path,
                config_dir=config_dir,
                )

        @self.manual_clustering.actions.add
        def recluster():
            """Relaunch KlustaKwik on the selected clusters."""
            # Selected clusters.
            cluster_ids = self.manual_clustering.selected
            spike_ids = self.selector.select_spikes(cluster_ids)
            logger.info("Running KlustaKwik on %d spikes.", len(spike_ids))
            spike_clusters, metadata = cluster(self.model,
                                               spike_ids,
                                               num_starting_clusters=10,
                                               )
            self.manual_clustering.split(spike_ids, spike_clusters)

        # Save.
        @gui.connect_
        def on_request_save(spike_clusters, groups):
            # Group names are title-cased before being written to the file.
            groups = {c: g.title() for c, g in groups.items()}
            self.model.save(spike_clusters, groups)

        return gui