Example #1
0
    def __init__(self, kwik_path, config_dir=None, **kwargs):
        super(KwikController, self).__init__()
        kwik_path = op.realpath(kwik_path)
        _backup(kwik_path)
        self.model = KwikModel(kwik_path, **kwargs)
        m = self.model
        self.channel_vertical_order = np.argsort(m.channel_positions[:, 1])
        self.distance_max = _get_distance_max(self.model.channel_positions)
        self.cache_dir = op.join(op.dirname(kwik_path), '.phy')
        cg = kwargs.get('channel_group', None)
        if cg is not None:
            self.cache_dir = op.join(self.cache_dir, str(cg))
        self.context = Context(self.cache_dir)
        self.config_dir = config_dir

        self._set_cache()
        self.supervisor = self._set_supervisor()
        self.selector = self._set_selector()
        self.color_selector = ColorSelector()

        self._show_all_spikes = False

        attach_plugins(self,
                       plugins=kwargs.get('plugins', None),
                       config_dir=config_dir)
Example #2
0
    def __init__(self,
                 dat_path=None,
                 proc_path=None,
                 config_dir=None,
                 model=None,
                 **kwargs):
        super(TemplateController, self).__init__()
        if model is None:
            assert dat_path
            dat_path = op.abspath(dat_path)
            self.model = TemplateModel(dat_path=dat_path,
                                       proc_path=proc_path,
                                       **kwargs)
        else:
            self.model = model
        self.cache_dir = op.join(self.model.dir_path, '.phy')
        self.context = Context(self.cache_dir)
        self.config_dir = config_dir

        self._set_cache()
        self.supervisor = self._set_supervisor()
        self.selector = self._set_selector()
        self.color_selector = ColorSelector()

        self._show_all_spikes = False

        attach_plugins(self,
                       plugins=kwargs.get('plugins', None),
                       config_dir=config_dir)
Example #3
0
    def __init__(self, data_path, config_dir=None, **kwargs):
        super(NeoController, self).__init__()
        self.model = NeoModel(data_path, **kwargs)
        self.distance_max = _get_distance_max(self.model.channel_positions)
        self.cache_dir = op.join(self.model.output_dir, '.phy')
        cg = kwargs.get('channel_group', None)
        cg = cg or 0
        self.cache_dir = op.join(self.cache_dir, 'channel_group_' + str(cg))
        self.context = Context(self.cache_dir)
        self.config_dir = config_dir
        self._set_cache()
        self.supervisor = self._set_supervisor()
        self.selector = self._set_selector()
        self.color_selector = ColorSelector()

        self._show_all_spikes = False

        attach_plugins(self,
                       plugins=kwargs.get('plugins', None),
                       config_dir=config_dir)
Example #4
0
class TemplateController(EventEmitter):
    gui_name = 'TemplateGUI'

    n_spikes_waveforms = 100
    batch_size_waveforms = 10

    n_spikes_features = 10000
    n_spikes_amplitudes = 10000
    n_spikes_correlograms = 100000

    def __init__(self, dat_path=None, config_dir=None, model=None, **kwargs):
        super(TemplateController, self).__init__()
        if model is None:
            assert dat_path
            dat_path = op.abspath(dat_path)
            self.model = TemplateModel(dat_path, **kwargs)
        else:
            self.model = model
        self.cache_dir = op.join(self.model.dir_path, '.phy')
        self.context = Context(self.cache_dir)
        self.config_dir = config_dir

        self._set_cache()
        self.supervisor = self._set_supervisor()
        self.selector = self._set_selector()
        self.color_selector = ColorSelector()

        self._show_all_spikes = False

        attach_plugins(self, plugins=kwargs.get('plugins', None),
                       config_dir=config_dir)

    # Internal methods
    # -------------------------------------------------------------------------

    def _set_cache(self):
        memcached = ('get_template_counts',
                     'get_template_for_cluster',
                     'get_best_channel',
                     'get_best_channels',
                     'get_probe_depth',
                     'get_real_amplitude',
                     'get_violation_rate',
                     'get_noise_ratioz',
                     'get_ratio',
                     'get_firing_rate',
                     )
        cached = ('_get_waveforms',
                  '_get_template_waveforms',
                  '_get_features',
                  '_get_template_features',
                  '_get_amplitudes',
                  '_get_correlograms',
                  )
        _cache_methods(self, memcached, cached)

    def _set_supervisor(self):
        # Load the new cluster id.
        new_cluster_id = self.context.load('new_cluster_id'). \
            get('new_cluster_id', None)
        cluster_groups = self.model.get_metadata('group')
        supervisor = Supervisor(self.model.spike_clusters,
                                similarity=self.similarity,
                                cluster_groups=cluster_groups,
                                new_cluster_id=new_cluster_id,
                                context=self.context,
                                )

        # Load the non-group metadata from the model to the cluster_meta.
        for name in self.model.metadata_fields:
            if name == 'group':
                continue
            values = self.model.get_metadata(name)
            for cluster_id, value in values.items():
                supervisor.cluster_meta.set(name, [cluster_id], value,
                                            add_to_stack=False)

        @supervisor.connect
        def on_create_cluster_views():
            supervisor.add_column(self.get_best_channel, name='channel')
            supervisor.add_column(self.get_probe_depth, name='depth')
            supervisor.add_column(self.get_real_amplitude, name='amp')
            #supervisor.add_column(self.get_noise_estimate, name='N')
            supervisor.add_column(self.get_noise_ratioz, name='Amp.o.N')
            supervisor.add_column(self.get_ratio, name='Bin1.o.M')
            supervisor.add_column(self.get_firing_rate, name='FR')
            #supervisor.add_column(self.get_violation_rate, name='vioRate')

            @supervisor.actions.add(shortcut='shift+ctrl+k')
            def split_init(cluster_ids=None):
                """Split a cluster according to the original templates."""
                if cluster_ids is None:
                    cluster_ids = supervisor.selected
                s = supervisor.clustering.spikes_in_clusters(cluster_ids)
                supervisor.split(s, self.model.spike_templates[s])

        # Save.
        @supervisor.connect
        def on_request_save(spike_clusters, groups, *labels):
            """Save the modified data."""
            # Save the clusters.
            self.model.save_spike_clusters(spike_clusters)
            # Save cluster metadata.
            for name, values in labels:
                self.model.save_metadata(name, values)

        return supervisor

    def _set_selector(self):
        def spikes_per_cluster(cluster_id):
            return self.supervisor.clustering.spikes_per_cluster[cluster_id]
        return Selector(spikes_per_cluster)

    def _add_view(self, gui, view):
        view.attach(gui)
        self.emit('add_view', gui, view)
        return view

    # Model methods
    # -------------------------------------------------------------------------

    def get_template_counts(self, cluster_id):
        """Return a histogram of the number of spikes in each template for
        a given cluster."""
        spike_ids = self.supervisor.clustering.spikes_per_cluster[cluster_id]
        st = self.model.spike_templates[spike_ids]
        return np.bincount(st, minlength=self.model.n_templates)
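    # Illustrative sketch, not part of the original controller: with toy values,
    # np.bincount(..., minlength=n_templates) yields one spike count per template,
    # including templates that never occur in the cluster.
    #   np.bincount(np.array([1, 1, 3, 1]), minlength=5)  ->  array([0, 3, 0, 1, 0])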

    def get_template_for_cluster(self, cluster_id):
        """Return the largest template associated to a cluster."""
        spike_ids = self.supervisor.clustering.spikes_per_cluster[cluster_id]
        st = self.model.spike_templates[spike_ids]
        template_ids, counts = np.unique(st, return_counts=True)
        ind = np.argmax(counts)
        return template_ids[ind]
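    # Illustrative sketch with toy values (not real data): np.unique with
    # return_counts=True tallies spikes per template, and argmax picks the most
    # frequent one.
    #   st = np.array([2, 5, 5, 2, 5])
    #   template_ids, counts = np.unique(st, return_counts=True)  # ([2, 5], [2, 3])
    #   template_ids[np.argmax(counts)]  ->  5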

    def similarity(self, cluster_id):
        """Return the list of similar clusters to a given cluster."""
        # Templates of the cluster.
        temp_i = np.nonzero(self.get_template_counts(cluster_id))[0]
        # The similarity of the cluster with each template.
        sims = np.max(self.model.similar_templates[temp_i, :], axis=0)

        def _sim_ij(cj):
            # Templates of the cluster.
            if cj < self.model.n_templates:
                return float(sims[cj])
            temp_j = np.nonzero(self.get_template_counts(cj))[0]
            return float(np.max(sims[temp_j]))

        out = [(cj, _sim_ij(cj))
               for cj in self.supervisor.clustering.cluster_ids]
        # NOTE: hard-limit to 100 for performance reasons.
        return sorted(out, key=itemgetter(1), reverse=True)[:100]
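    # Illustrative sketch with toy (cluster_id, similarity) pairs: sorting by the
    # second tuple element in descending order and truncating gives the ranking
    # returned above.
    #   sorted([(3, 0.2), (7, 0.9), (1, 0.5)], key=itemgetter(1), reverse=True)[:2]
    #   ->  [(7, 0.9), (1, 0.5)]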

    def get_best_channel(self, cluster_id):
        """Return the best channel of a given cluster."""
        template_id = self.get_template_for_cluster(cluster_id)
        return self.model.get_template(template_id).best_channel
    
    def get_real_amplitude(self, cluster_id):
        """Peak-to-peak amplitude of the mean waveform on the best channel."""
        waveforms = self._get_mean_waveforms(cluster_id)
        amp = waveforms.data[0, :, 0].max() - waveforms.data[0, :, 0].min()
        return amp
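    # Illustrative sketch with a toy mean waveform on the best channel: the
    # amplitude is simply max minus min (peak-to-peak).
    #   w = np.array([10., -25., 40., 0.]);  w.max() - w.min()  ->  65.0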
    
    def get_violation_rate(self, cluster_id):
        """Fraction of inter-spike intervals shorter than 2 ms."""
        thresh = 0.002
        spt = self._get_spike_times(cluster_id=cluster_id, load_all=True)
        isi = np.diff(spt.data)
        violation = np.count_nonzero(isi <= thresh)
        try:
            rate = violation * 1.0 / len(isi)
        except ZeroDivisionError:
            print('cannot calculate violation rate for cluster: ' + str(cluster_id))
            rate = float('nan')
        return rate
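    # Illustrative sketch with toy spike times (seconds), not real data:
    #   spt = np.array([0.000, 0.0015, 0.010, 0.0112])
    #   isi = np.diff(spt)                      # [0.0015, 0.0085, 0.0012]
    #   np.count_nonzero(isi <= 0.002)          # 2 of 3 ISIs violate 2 ms
    #   -> rate ~= 2 / 3 ~= 0.667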
    
    def get_violation_rate_threshold(self, cluster_id):
        """Violation-rate threshold for a 0.002 s refractory period and a 5%
        false-positive rate (formula from Hill et al., 2011)."""
        ref_period = 0.002
        fp_rate = 0.05
        T = self.model.duration
        N = self.supervisor.n_spikes(cluster_id)
        R = 2 * ref_period * N * (1 - fp_rate) * fp_rate / T
        return R
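    # Worked example of the Hill et al. (2011) formula with hypothetical numbers:
    # for ref_period = 0.002 s, fp_rate = 0.05, N = 10000 spikes, T = 3600 s,
    #   R = 2 * 0.002 * 10000 * 0.95 * 0.05 / 3600  ~=  5.3e-4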
    
    def get_noise_estimate(self, cluster_id):
        """Estimate the noise level as the s.d. of the first 25 waveform
        samples on the best channel."""
        aa = self._get_waveforms(cluster_id)
        # The best channel is always first in the waveform selection.
        pre = aa.data[:, 0:25, 0]
        pre = pre.flatten()
        return pre.std()

    def get_noise_ratioz(self, cluster_id):
        """Peak-to-peak amplitude divided by the noise estimate, rounded to
        one decimal."""
        ratio = self.get_real_amplitude(cluster_id) / self.get_noise_estimate(cluster_id)
        return np.around(ratio, 1)
    
    def get_firing_rate(self, cluster_id):
        """Mean firing rate of the cluster, in spikes per second."""
        T = self.model.duration
        N = self.supervisor.n_spikes(cluster_id)
        R = N / T
        return R
    
    def get_ratio(self, cluster_id):
        """Ratio of autocorrelogram bin 26 to the correlogram maximum
        (the 'Bin1.o.M' column)."""
        cgr = self._get_correlograms([cluster_id], 0.001, 0.05)
        vio_bin = range(26, 27)
        bin1 = cgr[0, 0, vio_bin].sum()
        ma = cgr.max()
        R = bin1 / ma
        return R

    def get_best_channels(self, cluster_id):
        """Return the best channels of a given cluster."""
        template_id = self.get_template_for_cluster(cluster_id)
        return self.model.get_template(template_id).channel_ids

    def get_probe_depth(self, cluster_id):
        """Return the depth of a cluster."""
        channel_id = self.get_best_channel(cluster_id)
        return self.model.channel_positions[channel_id][1]

    # Waveforms
    # -------------------------------------------------------------------------

    def _get_waveforms(self, cluster_id):
        """Return a selection of waveforms for a cluster."""
        pos = self.model.channel_positions
        spike_ids = self.selector.select_spikes([cluster_id],
                                                self.n_spikes_waveforms,
                                                self.batch_size_waveforms,
                                                )
        channel_ids = self.get_best_channels(cluster_id)
        data = self.model.get_waveforms(spike_ids, channel_ids)
        data = data - data.mean()
        return Bunch(data=data,
                     channel_ids=channel_ids,
                     channel_positions=pos[channel_ids],
                     )

    def _get_mean_waveforms(self, cluster_id):
        b = self._get_waveforms(cluster_id)
        b.data = b.data.mean(axis=0)[np.newaxis, ...]
        b['alpha'] = 1.
        return b

    def _get_template_waveforms(self, cluster_id):
        """Return the waveforms of the templates corresponding to a cluster."""
        pos = self.model.channel_positions
        count = self.get_template_counts(cluster_id)
        template_ids = np.nonzero(count)[0]
        count = count[template_ids]
        # Get local channels.
        channel_ids = self.get_best_channels(cluster_id)
        # Get masks.
        masks = count / float(count.max())
        masks = np.tile(masks.reshape((-1, 1)), (1, len(channel_ids)))
        # Get the mean amplitude for the cluster.
        mean_amp = self._get_amplitudes(cluster_id).y.mean()
        # Get all templates from which this cluster stems.
        templates = [self.model.get_template(template_id)
                     for template_id in template_ids]
        data = np.stack([b.template * mean_amp for b in templates], axis=0)
        cols = np.stack([b.channel_ids for b in templates], axis=0)
        # NOTE: transposition because the channels should be in the second
        # dimension for from_sparse.
        data = data.transpose((0, 2, 1))
        assert data.ndim == 3
        assert data.shape[1] == cols.shape[1]
        waveforms = from_sparse(data, cols, channel_ids)
        # Transpose back.
        waveforms = waveforms.transpose((0, 2, 1))
        return Bunch(data=waveforms,
                     channel_ids=channel_ids,
                     channel_positions=pos[channel_ids],
                     masks=masks,
                     alpha=1.,
                     )
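    # Illustrative sketch of the mask construction above, with toy counts:
    #   count = np.array([5, 20])
    #   masks = count / float(count.max())              # array([0.25, 1.  ])
    #   np.tile(masks.reshape((-1, 1)), (1, 3)).shape   # (2, 3): one row per
    #                                                   # template, one column
    #                                                   # per channel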

    def add_waveform_view(self, gui):
        f = (self._get_waveforms if self.model.traces is not None
             else self._get_template_waveforms)
        v = WaveformView(waveforms=f,
                         )
        v = self._add_view(gui, v)

        v.actions.separator()

        @v.actions.add(shortcut='w')
        def toggle_templates():
            f, g = self._get_waveforms, self._get_template_waveforms
            if self.model.traces is None:
                return
            v.waveforms = f if v.waveforms == g else g
            v.on_select()

        @v.actions.add(shortcut='m')
        def toggle_mean_waveforms():
            f, g = self._get_waveforms, self._get_mean_waveforms
            v.waveforms = f if v.waveforms == g else g
            v.on_select()

        return v

    # Features
    # -------------------------------------------------------------------------

    def _get_spike_ids(self, cluster_id=None, load_all=None):
        nsf = self.n_spikes_features
        if cluster_id is None:
            # Background points.
            ns = self.model.n_spikes
            spike_ids = np.arange(0, ns, max(1, ns // nsf))
        else:
            # Load all spikes from the cluster if load_all is True.
            n = nsf if not load_all else None
            spike_ids = self.selector.select_spikes([cluster_id], n)
        # Remove spike_ids that do not belong to model.features_rows
        if self.model.features_rows is not None:
            spike_ids = np.intersect1d(spike_ids, self.model.features_rows)
        return spike_ids
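    # Illustrative sketch of the background subsampling above, with toy sizes:
    #   ns, nsf = 23, 10
    #   np.arange(0, ns, max(1, ns // nsf))
    #   ->  array([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22])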

    def _get_spike_times(self, cluster_id=None, load_all=None):
        spike_ids = self._get_spike_ids(cluster_id, load_all=load_all)
        return Bunch(data=self.model.spike_times[spike_ids],
                     spike_ids=spike_ids,
                     lim=(0., self.model.duration))

    def _get_features(self, cluster_id=None, channel_ids=None, load_all=None):
        spike_ids = self._get_spike_ids(cluster_id, load_all=load_all)
        # Use the best channels only if a cluster is specified and
        # channels are not specified.
        if cluster_id is not None and channel_ids is None:
            channel_ids = self.get_best_channels(cluster_id)
        data = self.model.get_features(spike_ids, channel_ids)
        assert data.shape[:2] == (len(spike_ids), len(channel_ids))
        # Remove rows with at least one nan value.
        nan = np.unique(np.nonzero(np.isnan(data))[0])
        nonnan = np.setdiff1d(np.arange(len(spike_ids)), nan)
        data = data[nonnan, ...]
        spike_ids = spike_ids[nonnan]
        assert data.shape[:2] == (len(spike_ids), len(channel_ids))
        assert np.isnan(data).sum() == 0
        return Bunch(data=data,
                     spike_ids=spike_ids,
                     channel_ids=channel_ids,
                     )
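    # Illustrative sketch of the NaN-row removal above, with a toy feature array:
    #   data = np.array([[1., np.nan], [3., 4.]])
    #   nan = np.unique(np.nonzero(np.isnan(data))[0])    # rows with a NaN: [0]
    #   np.setdiff1d(np.arange(len(data)), nan)           # rows kept: [1]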

    def add_feature_view(self, gui):
        v = FeatureView(features=self._get_features,
                        attributes={'time': self._get_spike_times}
                        )
        return self._add_view(gui, v)

    # Template features
    # -------------------------------------------------------------------------

    def _get_template_features(self, cluster_ids):
        assert len(cluster_ids) == 2
        clu0, clu1 = cluster_ids

        s0 = self._get_spike_ids(clu0)
        s1 = self._get_spike_ids(clu1)

        n0 = self.get_template_counts(clu0)
        n1 = self.get_template_counts(clu1)

        t0 = self.model.get_template_features(s0)
        t1 = self.model.get_template_features(s1)

        x0 = np.average(t0, weights=n0, axis=1)
        y0 = np.average(t0, weights=n1, axis=1)

        x1 = np.average(t1, weights=n0, axis=1)
        y1 = np.average(t1, weights=n1, axis=1)

        return Bunch(x0=x0, y0=y0, x1=x1, y1=y1,
                     data_bounds=(min(x0.min(), x1.min()),
                                  min(y0.min(), y1.min()),
                                  max(x0.max(), x1.max()),
                                  max(y0.max(), y1.max()),
                                  ),
                     )

    def add_template_feature_view(self, gui):
        v = TemplateFeatureView(coords=self._get_template_features,
                                )
        return self._add_view(gui, v)

    # Traces
    # -------------------------------------------------------------------------

    def _get_traces(self, interval):
        """Get traces and spike waveforms."""
        k = self.model.n_samples_templates
        m = self.model

        traces_interval = select_traces(m.traces, interval,
                                        sample_rate=m.sample_rate)
        # Reorder vertically.
        out = Bunch(data=traces_interval)
        out.waveforms = []

        def gbc(cluster_id):
            return self.get_best_channels(cluster_id)

        for b in _iter_spike_waveforms(interval=interval,
                                       traces_interval=traces_interval,
                                       model=self.model,
                                       supervisor=self.supervisor,
                                       color_selector=self.color_selector,
                                       n_samples_waveforms=k,
                                       get_best_channels=gbc,
                                       show_all_spikes=self._show_all_spikes,
                                       ):
            i = b.spike_id
            # Compute the residual: waveform - amplitude * template.
            residual = b.copy()
            template_id = m.spike_templates[i]
            template = m.get_template(template_id).template
            amplitude = m.amplitudes[i]
            residual.data = residual.data - amplitude * template
            out.waveforms.extend([b, residual])
        return out

    def _jump_to_spike(self, view, delta=+1):
        """Jump to next or previous spike from the selected clusters."""
        m = self.model
        cluster_ids = self.supervisor.selected
        if len(cluster_ids) == 0:
            return
        spc = self.supervisor.clustering.spikes_per_cluster
        spike_ids = spc[cluster_ids[0]]
        spike_times = m.spike_times[spike_ids]
        ind = np.searchsorted(spike_times, view.time)
        n = len(spike_times)
        view.go_to(spike_times[(ind + delta) % n])
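    # Illustrative sketch of the jump logic above, with toy spike times:
    #   spike_times = np.array([0.10, 0.50, 0.90]);  view_time = 0.6
    #   ind = np.searchsorted(spike_times, view_time)     # 2
    #   spike_times[(ind + 1) % len(spike_times)]         # 0.10: wraps around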

    def add_trace_view(self, gui):
        m = self.model
        v = TraceView(traces=self._get_traces,
                      n_channels=m.n_channels,
                      sample_rate=m.sample_rate,
                      duration=m.duration,
                      channel_vertical_order=m.channel_vertical_order,
                      )
        self._add_view(gui, v)

        v.actions.separator()

        @v.actions.add(shortcut='alt+pgdown')
        def go_to_next_spike():
            """Jump to the next spike from the first selected cluster."""
            self._jump_to_spike(v, +1)

        @v.actions.add(shortcut='alt+pgup')
        def go_to_previous_spike():
            """Jump to the previous spike from the first selected cluster."""
            self._jump_to_spike(v, -1)

        v.actions.separator()

        @v.actions.add(shortcut='alt+s')
        def toggle_highlighted_spikes():
            """Toggle between showing all spikes or selected spikes."""
            self._show_all_spikes = not self._show_all_spikes
            v.set_interval(force_update=True)

        @gui.connect_
        def on_spike_click(channel_id=None, spike_id=None, cluster_id=None):
            # Select the corresponding cluster.
            self.supervisor.select([cluster_id])
            # Update the trace view.
            v.on_select([cluster_id], force_update=True)

        return v

    # Correlograms
    # -------------------------------------------------------------------------

    def _get_correlograms(self, cluster_ids, bin_size, window_size):
        spike_ids = self.selector.select_spikes(cluster_ids,
                                                self.n_spikes_correlograms,
                                                subset='random',
                                                )
        st = self.model.spike_times[spike_ids]
        sc = self.supervisor.clustering.spike_clusters[spike_ids]
        return correlograms(st,
                            sc,
                            sample_rate=self.model.sample_rate,
                            cluster_ids=cluster_ids,
                            bin_size=bin_size,
                            window_size=window_size,
                            )

    def add_correlogram_view(self, gui):
        m = self.model
        v = CorrelogramView(correlograms=self._get_correlograms,
                            sample_rate=m.sample_rate,
                            )
        return self._add_view(gui, v)

    # Amplitudes
    # -------------------------------------------------------------------------

    def _get_amplitudes(self, cluster_id):
        n = self.n_spikes_amplitudes
        m = self.model
        spike_ids = self.selector.select_spikes([cluster_id], n)
        x = m.spike_times[spike_ids]
        y = m.amplitudes[spike_ids]
        return Bunch(x=x, y=y, data_bounds=(0., 0., m.duration, y.max()))

    def add_amplitude_view(self, gui):
        v = AmplitudeView(coords=self._get_amplitudes,
                          )
        return self._add_view(gui, v)

    # Probe view
    # -------------------------------------------------------------------------

    def add_probe_view(self, gui):
        v = ProbeView(positions=self.model.channel_positions,
                      best_channels=self.get_best_channels,
                      )
        v.attach(gui)
        return v

    # GUI
    # -------------------------------------------------------------------------

    def create_gui(self, **kwargs):
        gui = GUI(name=self.gui_name,
                  subtitle=self.model.dat_path,
                  config_dir=self.config_dir,
                  **kwargs)

        self.supervisor.attach(gui)

        self.add_waveform_view(gui)
        if self.model.traces is not None:
            self.add_trace_view(gui)
        if self.model.features is not None:
            self.add_feature_view(gui)
        if self.model.template_features is not None:
            self.add_template_feature_view(gui)
        self.add_correlogram_view(gui)
        if self.model.amplitudes is not None:
            self.add_amplitude_view(gui)
        self.add_probe_view(gui)

        # Save the memcache when closing the GUI.
        @gui.connect_
        def on_close():
            self.context.save_memcache()

        self.emit('gui_ready', gui)

        return gui
Example #5
0
class KwikController(EventEmitter):
    gui_name = 'KwikGUI'

    n_spikes_waveforms = 100
    batch_size_waveforms = 10

    n_spikes_features = 10000
    n_spikes_amplitudes = 10000

    n_spikes_close_clusters = 100
    n_closest_channels = 16

    def __init__(self, kwik_path, config_dir=None, **kwargs):
        super(KwikController, self).__init__()
        kwik_path = op.realpath(kwik_path)
        _backup(kwik_path)
        self.model = KwikModel(kwik_path, **kwargs)
        m = self.model
        self.channel_vertical_order = np.argsort(m.channel_positions[:, 1])
        self.distance_max = _get_distance_max(self.model.channel_positions)
        self.cache_dir = op.join(op.dirname(kwik_path), '.phy')
        cg = kwargs.get('channel_group', None)
        if cg is not None:
            self.cache_dir = op.join(self.cache_dir, str(cg))
        self.context = Context(self.cache_dir)
        self.config_dir = config_dir

        self._set_cache()
        self.supervisor = self._set_supervisor()
        self.selector = self._set_selector()
        self.color_selector = ColorSelector()

        self._show_all_spikes = False

        attach_plugins(self,
                       plugins=kwargs.get('plugins', None),
                       config_dir=config_dir)

    # Internal methods
    # -------------------------------------------------------------------------

    def _set_cache(self):
        memcached = (
            'get_best_channels',
            'get_probe_depth',
            '_get_mean_masks',
            '_get_mean_waveforms',
        )
        cached = (
            '_get_waveforms',
            '_get_features',
            '_get_masks',
        )
        _cache_methods(self, memcached, cached)

    def _set_supervisor(self):
        # Load the new cluster id.
        new_cluster_id = self.context.load('new_cluster_id'). \
            get('new_cluster_id', None)
        cluster_groups = self.model.cluster_groups
        supervisor = Supervisor(
            self.model.spike_clusters,
            similarity=self.similarity,
            cluster_groups=cluster_groups,
            new_cluster_id=new_cluster_id,
            context=self.context,
        )

        @supervisor.connect
        def on_create_cluster_views():

            supervisor.add_column(self.get_best_channel, name='channel')
            supervisor.add_column(self.get_probe_depth, name='depth')

            @supervisor.actions.add
            def recluster():
                """Relaunch KlustaKwik on the selected clusters."""
                # Selected clusters.
                cluster_ids = supervisor.selected
                spike_ids = self.selector.select_spikes(cluster_ids)
                logger.info("Running KlustaKwik on %d spikes.", len(spike_ids))

                # Run KK2 in a temporary directory to avoid side effects.
                n = 10
                with TemporaryDirectory() as tempdir:
                    spike_clusters, metadata = cluster(
                        self.model,
                        spike_ids,
                        num_starting_clusters=n,
                        tempdir=tempdir,
                    )
                self.supervisor.split(spike_ids, spike_clusters)

        # Save.
        @supervisor.connect
        def on_request_save(spike_clusters, groups, *labels):
            """Save the modified data."""
            groups = {c: g.title() for c, g in groups.items()}
            self.model.save(spike_clusters, groups)

        return supervisor

    def _set_selector(self):
        def spikes_per_cluster(cluster_id):
            return self.supervisor.clustering.spikes_per_cluster[cluster_id]

        return Selector(spikes_per_cluster)

    def _add_view(self, gui, view):
        view.attach(gui)
        self.emit('add_view', gui, view)
        return view

    # Model methods
    # -------------------------------------------------------------------------

    def get_best_channel(self, cluster_id):
        return self.get_best_channels(cluster_id)[0]

    def get_best_channels(self, cluster_id):
        """Only used in the trace view."""
        mm = self._get_mean_masks(cluster_id)
        channel_ids = np.argsort(mm)[::-1]
        ind = mm[channel_ids] > .1
        if np.sum(ind) > 0:
            channel_ids = channel_ids[ind]
        else:
            channel_ids = channel_ids[:4]
        return channel_ids
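    # Illustrative sketch of the mask thresholding above, with toy mean masks:
    #   mm = np.array([0.05, 0.8, 0.3, 0.02])
    #   channel_ids = np.argsort(mm)[::-1]        # array([1, 2, 0, 3])
    #   channel_ids[mm[channel_ids] > .1]         # array([1, 2])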

    def get_cluster_position(self, cluster_id):
        channel_id = self.get_best_channel(cluster_id)
        return self.model.channel_positions[channel_id]

    def get_probe_depth(self, cluster_id):
        return self.get_cluster_position(cluster_id)[1]

    def similarity(self, cluster_id):
        """Return the list of similar clusters to a given cluster."""

        pos_i = self.get_cluster_position(cluster_id)
        assert len(pos_i) == 2

        def _sim_ij(cj):
            """Distance between channel position of clusters i and j."""
            pos_j = self.get_cluster_position(cj)
            assert len(pos_j) == 2
            d = np.sqrt(np.sum((pos_j - pos_i)**2))
            return self.distance_max - d

        out = [(cj, _sim_ij(cj))
               for cj in self.supervisor.clustering.cluster_ids]
        return sorted(out, key=itemgetter(1), reverse=True)
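    # Illustrative sketch of the distance-based similarity above, with toy
    # channel positions (units are an assumption):
    #   pos_i, pos_j = np.array([0., 40.]), np.array([30., 0.])
    #   np.sqrt(np.sum((pos_j - pos_i) ** 2))     # 50.0, so the similarity is
    #                                             # distance_max - 50.0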

    # Waveforms
    # -------------------------------------------------------------------------

    def _get_masks(self, cluster_id):
        spike_ids = self.selector.select_spikes(
            [cluster_id],
            self.n_spikes_waveforms,
            self.batch_size_waveforms,
        )
        return self.model.all_masks[spike_ids]

    def _get_mean_masks(self, cluster_id):
        return np.mean(self._get_masks(cluster_id), axis=0)

    def _get_waveforms(self, cluster_id):
        """Return a selection of waveforms for a cluster."""
        pos = self.model.channel_positions
        spike_ids = self.selector.select_spikes(
            [cluster_id],
            self.n_spikes_waveforms,
            self.batch_size_waveforms,
        )
        data = self.model.all_waveforms[spike_ids]
        mm = self._get_mean_masks(cluster_id)
        mw = np.mean(data, axis=0)
        amp = get_waveform_amplitude(mm, mw)
        masks = self._get_masks(cluster_id)
        # Find the best channels.
        channel_ids = np.argsort(amp)[::-1]
        return Bunch(
            data=data[..., channel_ids],
            channel_ids=channel_ids,
            channel_positions=pos[channel_ids],
            masks=masks[:, channel_ids],
        )

    def _get_mean_waveforms(self, cluster_id):
        b = self._get_waveforms(cluster_id).copy()
        b.data = np.mean(b.data, axis=0)[np.newaxis, ...]
        b.masks = np.mean(b.masks, axis=0)[np.newaxis, ...]**.1
        b['alpha'] = 1.
        return b

    def add_waveform_view(self, gui):
        v = WaveformView(waveforms=self._get_waveforms, )
        v = self._add_view(gui, v)

        v.actions.separator()

        @v.actions.add(shortcut='m')
        def toggle_mean_waveforms():
            f, g = self._get_waveforms, self._get_mean_waveforms
            v.waveforms = f if v.waveforms == g else g
            v.on_select()

        return v

    # Features
    # -------------------------------------------------------------------------

    def _get_spike_ids(self, cluster_id=None, load_all=None):
        nsf = self.n_spikes_features
        if cluster_id is None:
            # Background points.
            ns = self.model.n_spikes
            return np.arange(0, ns, max(1, ns // nsf))
        else:
            # Load all spikes from the cluster if load_all is True.
            n = nsf if not load_all else None
            return self.selector.select_spikes([cluster_id], n)

    def _get_spike_times(self, cluster_id=None, load_all=None):
        spike_ids = self._get_spike_ids(cluster_id, load_all=load_all)
        return Bunch(data=self.model.spike_times[spike_ids],
                     lim=(0., self.model.duration))

    def _get_features(self, cluster_id=None, channel_ids=None, load_all=None):
        spike_ids = self._get_spike_ids(cluster_id, load_all=load_all)
        # Use the best channels only if a cluster is specified and
        # channels are not specified.
        if cluster_id is not None and channel_ids is None:
            channel_ids = self.get_best_channels(cluster_id)
        f = self.model.all_features[spike_ids][:, channel_ids]
        m = self.model.all_masks[spike_ids][:, channel_ids]
        return Bunch(
            data=f,
            masks=m,
            spike_ids=spike_ids,
            channel_ids=channel_ids,
        )

    def add_feature_view(self, gui):
        v = FeatureView(features=self._get_features,
                        attributes={'time': self._get_spike_times})
        return self._add_view(gui, v)

    # Traces
    # -------------------------------------------------------------------------

    def _get_traces(self, interval):
        """Get traces and spike waveforms."""
        ns = self.model.n_samples_waveforms
        m = self.model
        c = self.channel_vertical_order

        traces_interval = select_traces(m.traces,
                                        interval,
                                        sample_rate=m.sample_rate)
        # Reorder vertically.
        traces_interval = traces_interval[:, c]

        def gbc(cluster_id):
            ch = self.get_best_channels(cluster_id)
            return ch

        out = Bunch(data=traces_interval)
        out.waveforms = []
        for b in _iter_spike_waveforms(
                interval=interval,
                traces_interval=traces_interval,
                model=self.model,
                supervisor=self.supervisor,
                color_selector=self.color_selector,
                n_samples_waveforms=ns,
                get_best_channels=gbc,
                show_all_spikes=self._show_all_spikes,
        ):
            b.channel_labels = m.channel_order[b.channel_ids]
            out.waveforms.append(b)
        return out

    def _jump_to_spike(self, view, delta=+1):
        """Jump to next or previous spike from the selected clusters."""
        m = self.model
        cluster_ids = self.supervisor.selected
        if len(cluster_ids) == 0:
            return
        spc = self.supervisor.clustering.spikes_per_cluster
        spike_ids = spc[cluster_ids[0]]
        spike_times = m.spike_times[spike_ids]
        ind = np.searchsorted(spike_times, view.time)
        n = len(spike_times)
        view.go_to(spike_times[(ind + delta) % n])

    def add_trace_view(self, gui):
        m = self.model
        v = TraceView(
            traces=self._get_traces,
            n_channels=m.n_channels,
            sample_rate=m.sample_rate,
            duration=m.duration,
            channel_vertical_order=self.channel_vertical_order,
        )
        self._add_view(gui, v)

        v.actions.separator()

        @v.actions.add(shortcut='alt+pgdown')
        def go_to_next_spike():
            """Jump to the next spike from the first selected cluster."""
            self._jump_to_spike(v, +1)

        @v.actions.add(shortcut='alt+pgup')
        def go_to_previous_spike():
            """Jump to the previous spike from the first selected cluster."""
            self._jump_to_spike(v, -1)

        v.actions.separator()

        @v.actions.add(shortcut='alt+s')
        def toggle_highlighted_spikes():
            """Toggle between showing all spikes or selected spikes."""
            self._show_all_spikes = not self._show_all_spikes
            v.set_interval(force_update=True)

        @gui.connect_
        def on_spike_click(channel_id=None, spike_id=None, cluster_id=None):
            # Select the corresponding cluster.
            self.supervisor.select([cluster_id])
            # Update the trace view.
            v.on_select([cluster_id], force_update=True)

        return v

    # Correlograms
    # -------------------------------------------------------------------------

    def _get_correlograms(self, cluster_ids, bin_size, window_size):
        spike_ids = self.selector.select_spikes(cluster_ids, 100000)
        st = self.model.spike_times[spike_ids]
        sc = self.supervisor.clustering.spike_clusters[spike_ids]
        return correlograms(
            st,
            sc,
            sample_rate=self.model.sample_rate,
            cluster_ids=cluster_ids,
            bin_size=bin_size,
            window_size=window_size,
        )

    def add_correlogram_view(self, gui):
        m = self.model
        v = CorrelogramView(
            correlograms=self._get_correlograms,
            sample_rate=m.sample_rate,
        )
        return self._add_view(gui, v)

    # GUI
    # -------------------------------------------------------------------------

    def create_gui(self, **kwargs):
        gui = GUI(name=self.gui_name,
                  subtitle=self.model.kwik_path,
                  config_dir=self.config_dir,
                  **kwargs)

        self.supervisor.attach(gui)

        self.add_waveform_view(gui)
        if self.model.traces is not None:
            self.add_trace_view(gui)
        self.add_feature_view(gui)
        self.add_correlogram_view(gui)

        # Save the memcache when closing the GUI.
        @gui.connect_
        def on_close():
            self.context.save_memcache()

        self.emit('gui_ready', gui)

        return gui
Example #6
0
class NeoController(EventEmitter):
    gui_name = 'NeoGUI'

    n_spikes_waveforms = 200
    batch_size_waveforms = 200

    n_spikes_features = 10000
    n_spikes_amplitudes = 10000
    n_spikes_correlograms = 100000

    def __init__(self, data_path, config_dir=None, **kwargs):
        super(NeoController, self).__init__()
        self.model = NeoModel(data_path, **kwargs)
        self.distance_max = _get_distance_max(self.model.channel_positions)
        self.cache_dir = op.join(self.model.output_dir, '.phy')
        cg = kwargs.get('channel_group', None)
        cg = cg or 0
        self.cache_dir = op.join(self.cache_dir, 'channel_group_' + str(cg))
        self.context = Context(self.cache_dir)
        self.config_dir = config_dir
        self._set_cache()
        self.supervisor = self._set_supervisor()
        self.selector = self._set_selector()
        self.color_selector = ColorSelector()

        self._show_all_spikes = False

        attach_plugins(self,
                       plugins=kwargs.get('plugins', None),
                       config_dir=config_dir)

    # Internal methods
    # -------------------------------------------------------------------------

    def _set_cache(self):
        memcached = (
            'get_best_channels',
            #'get_probe_depth',
            '_get_mean_waveforms',
        )
        cached = (
            '_get_waveforms',
            '_get_features',
        )
        _cache_methods(self, memcached, cached)

    def _set_supervisor(self):
        # Load the new cluster id.
        new_cluster_id = self.context.load('new_cluster_id'). \
            get('new_cluster_id', None)
        cluster_groups = self.model.cluster_groups
        supervisor = Supervisor(
            self.model.spike_clusters,
            similarity=self.similarity,
            cluster_groups=cluster_groups,
            new_cluster_id=new_cluster_id,
            context=self.context,
        )

        @supervisor.connect
        def on_create_cluster_views():

            supervisor.add_column(self.get_best_channel, name='channel')
            supervisor.add_column(self.get_probe_depth, name='depth')

            @supervisor.actions.add
            def recluster():
                """Relaunch KlustaKwik on the selected clusters."""
                # Selected clusters.
                cluster_ids = supervisor.selected  # TODO: handle multiple selected clusters?
                spike_ids = self.selector.select_spikes(cluster_ids)
                logger.info("Running KlustaKwik on %d spikes.", len(spike_ids))
                # FIXME: get_best_channels() cannot take a list of cluster ids,
                # so only the first selected cluster is used here.
                channel_ids = self.get_best_channels(cluster_ids[0])
                spike_clusters = self.model.cluster(spike_ids, channel_ids)
                self.supervisor.split(spike_ids, spike_clusters)

        # Save.
        @supervisor.connect
        def on_request_save(spike_clusters, groups, *labels):
            """Save the modified data."""
            # Save the clusters.
            groups = {c: g.title() for c, g in groups.items()}
            self.model.save(spike_clusters, groups, *labels)

        return supervisor

    def _set_selector(self):
        def spikes_per_cluster(cluster_id):
            return self.supervisor.clustering.spikes_per_cluster[cluster_id]

        return Selector(spikes_per_cluster)

    def _add_view(self, gui, view):
        view.attach(gui)
        self.emit('add_view', gui, view)
        return view

    # Model methods
    # -------------------------------------------------------------------------

    def get_best_channel(self, cluster_id):
        channel_ids = self.get_best_channels(cluster_id)
        amps = self._get_mean_waveforms(cluster_id).data[0].min(axis=0)
        channel_id = channel_ids[np.argmin(amps)]
        return channel_id
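    # Illustrative sketch with toy per-channel trough amplitudes: argmin picks
    # the channel whose mean waveform has the deepest (most negative) trough.
    #   amps = np.array([-10., -55., -3.]);  np.argmin(amps)  ->  1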

    def get_best_channels(self, cluster_ids):  # TODO
        return np.arange(self.model.n_chans)

    def get_cluster_position(self, cluster_id):
        channel_id = self.get_best_channel(cluster_id)
        return self.model.channel_positions[channel_id]

    def get_probe_depth(self, cluster_id):
        channel_id = self.get_best_channel(cluster_id)
        return self.model.channel_positions[channel_id][1]

    def similarity(self, cluster_id):
        """Return the list of similar clusters to a given cluster."""

        pos_i = self.get_cluster_position(cluster_id)

        def _sim_ij(cj):
            """Distance between channel position of clusters i and j."""
            pos_j = self.get_cluster_position(cj)
            d = np.sqrt(np.sum((pos_j - pos_i)**2))
            return self.distance_max - d

        out = [(cj, _sim_ij(cj))
               for cj in self.supervisor.clustering.cluster_ids]
        return sorted(out, key=itemgetter(1), reverse=True)

    # Waveforms
    # -------------------------------------------------------------------------

    def _get_waveforms(self, cluster_id):
        """Return a selection of waveforms for a cluster."""
        pos = self.model.channel_positions
        spike_ids = self.selector.select_spikes(
            [cluster_id],
            self.n_spikes_waveforms,
            self.batch_size_waveforms,
        )
        channel_ids = self.get_best_channels(cluster_id)
        data = self.model.get_waveforms(spike_ids, channel_ids)
        return Bunch(data=data,
                     channel_ids=channel_ids,
                     channel_positions=pos[channel_ids],
                     alpha=0.25)

    def _get_mean_waveforms(self, cluster_id):
        b = self._get_waveforms(cluster_id)
        b.data = b.data.mean(axis=0)[np.newaxis, ...]
        b['alpha'] = 1.
        return b

    def add_waveform_view(self, gui):
        v = WaveformView(waveforms=self._get_waveforms, )
        v = self._add_view(gui, v)

        v.actions.separator()

        @v.actions.add(shortcut='m')
        def toggle_mean_waveforms():
            f, g = self._get_waveforms, self._get_mean_waveforms
            v.waveforms = f if v.waveforms == g else g
            v.on_select()

        return v

    # Features
    # -------------------------------------------------------------------------

    def _get_spike_ids(self, cluster_id=None, load_all=None):
        nsf = self.n_spikes_features
        if cluster_id is None:
            # Background points.
            ns = self.model.n_spikes
            return np.arange(0, ns, max(1, ns // nsf))
        else:
            # Load all spikes from the cluster if load_all is True.
            n = nsf if not load_all else None
            return self.selector.select_spikes([cluster_id], n)

    def _get_spike_times(self, cluster_id=None, load_all=None):
        spike_ids = self._get_spike_ids(cluster_id)
        return Bunch(data=self.model.spike_times[spike_ids],
                     lim=(0., self.model.duration))

    def _get_features(self, cluster_id=None, channel_ids=None, load_all=None):
        spike_ids = self._get_spike_ids(cluster_id, load_all=load_all)
        # Use the best channels only if a cluster is specified and
        # channels are not specified.
        if cluster_id is not None and channel_ids is None:
            channel_ids = self.get_best_channels(cluster_id)
        f = self.model.features[spike_ids][:, channel_ids]
        m = self.model.masks[spike_ids][:, channel_ids]
        return Bunch(
            data=f,
            masks=m,
            spike_ids=spike_ids,
            channel_ids=channel_ids,
        )

    def add_feature_view(self, gui):
        v = FeatureView(features=self._get_features,
                        attributes={'time': self._get_spike_times})
        return self._add_view(gui, v)

    # Correlograms
    # -------------------------------------------------------------------------

    def _get_correlograms(self, cluster_ids, bin_size, window_size):
        spike_ids = self.selector.select_spikes(
            cluster_ids,
            self.n_spikes_correlograms,
            subset='random',
        )
        st = self.model.spike_times[spike_ids]
        sc = self.supervisor.clustering.spike_clusters[spike_ids]
        return correlograms(
            st,
            sc,
            sample_rate=self.model.sample_rate,
            cluster_ids=cluster_ids,
            bin_size=bin_size,
            window_size=window_size,
        )

    def add_correlogram_view(self, gui):
        m = self.model
        v = CorrelogramView(
            correlograms=self._get_correlograms,
            sample_rate=m.sample_rate,
        )
        return self._add_view(gui, v)

    # Amplitudes
    # -------------------------------------------------------------------------

    def _get_amplitudes(self, cluster_id):
        n = self.n_spikes_amplitudes
        m = self.model
        spike_ids = self.selector.select_spikes([cluster_id], n)
        channel_id = self.get_best_channel(cluster_id)
        x = m.spike_times[spike_ids]
        y = m.amplitudes[spike_ids, channel_id]
        return Bunch(x=x, y=y, data_bounds=(0., y.min(), m.duration, y.max()))

    def add_amplitude_view(self, gui):
        v = AmplitudeView(coords=self._get_amplitudes, )
        return self._add_view(gui, v)

    # GUI
    # -------------------------------------------------------------------------

    def create_gui(self, **kwargs):
        gui = GUI(name=self.gui_name,
                  subtitle=self.model.data_path,
                  config_dir=self.config_dir,
                  **kwargs)

        self.supervisor.attach(gui)

        self.add_waveform_view(gui)
        if self.model.features is not None:
            self.add_feature_view(gui)
        self.add_correlogram_view(gui)
        if self.model.amplitudes is not None:
            self.add_amplitude_view(gui)

        # Save the memcache when closing the GUI.
        @gui.connect_
        def on_close():
            self.context.save_memcache()

        self.emit('gui_ready', gui)

        return gui