Example #1
0
    def _get_clu_positions(self, cluster_ids):
        """Get the positions and colors of the discs marking the selected clusters.

        Each selected cluster is drawn as one disc on every channel returned by
        ``self.best_channels`` for it. When several selected clusters share a channel,
        their discs are spread horizontally and centred on the channel position.

        Parameters
        ----------
        cluster_ids : sequence
            The selected cluster ids. The position of a cluster in this sequence is its
            selection index, which is passed to ``selected_cluster_color``.

        Returns
        -------
        clu_pos : ndarray
            One ``(x, y)`` row per disc.
        clu_colors : ndarray
            One color row per disc; discs on channels listed in ``self.dead_channels``
            use ``self.dead_channel_alpha`` as alpha, others use 1.0.

        """
        # Channels per cluster, keyed by selection index.
        cluster_channels = {i: self.best_channels(cl) for i, cl in enumerate(cluster_ids)}

        # Selection indices of the clusters shown on each channel.
        clusters_per_channel = defaultdict(list)
        for clu_idx, channels in cluster_channels.items():
            for channel in channels:
                clusters_per_channel[channel].append(clu_idx)

        # Width of the data bounds, used to scale the spacing between discs.
        w = self.data_bounds[2] - self.data_bounds[0]
        clu_pos = []
        clu_colors = []
        for channel_id, (x, y) in enumerate(self.positions):
            clu_indices = clusters_per_channel.get(channel_id, ())
            n = len(clu_indices)
            alpha = 1.0 if channel_id not in self.dead_channels else self.dead_channel_alpha
            for i, clu_idx in enumerate(clu_indices):
                # Offset of the i-th disc relative to the channel position, chosen so that
                # the n discs are symmetric around it. The offset is applied to the
                # channel's own x coordinate; accumulating it into `x` (as before) shifted
                # every disc after the first and broke the symmetry.
                t = .025 * w * (i - .5 * (n - 1))
                clu_pos.append((x + t, y))
                clu_colors.append(selected_cluster_color(clu_idx, alpha=alpha))
        return np.array(clu_pos), np.array(clu_colors)
Example #2
0
def _get_point_color(clu_idx=None):
    if clu_idx is not None:
        color = selected_cluster_color(clu_idx, .5)
    else:
        color = (.5,) * 4
    assert len(color) == 4
    return color
Example #3
0
 def get_clusters_data(self, load_all=None):
     """Return a list of Bunch instances, with attributes pos and spike_ids.

     Unless ``load_all`` is true (the splitting case), a ``None`` entry is prepended to
     the cluster ids to request the background spikes, which are drawn in grey. Each
     returned Bunch also gets ``pos`` (the ``(spike_times, amplitudes)`` pairs),
     ``cluster_id`` and ``color``. Returns None when no cluster is selected.
     """
     if not len(self.cluster_ids):
         return
     cluster_ids = list(self.cluster_ids)
     # Don't need the background when splitting.
     if not load_all:
         # Add None cluster which means background spikes.
         cluster_ids = [None] + cluster_ids
     # The selection index of cluster_ids[k] is k - offset: the background entry, when
     # present, shifts every selected cluster by one. Using a fixed `i - 1` gave the first
     # selected cluster index -1 when the background is omitted (load_all).
     offset = 0 if load_all else 1
     bunchs = self.amplitudes[self.amplitudes_type](cluster_ids, load_all=load_all) or ()
     # Add a pos attribute in bunchs in addition to x and y.
     for i, (cluster_id, bunch) in enumerate(zip(cluster_ids, bunchs)):
         spike_ids = _as_array(bunch.spike_ids)
         spike_times = _as_array(bunch.spike_times)
         amplitudes = _as_array(bunch.amplitudes)
         assert spike_ids.shape == spike_times.shape == amplitudes.shape
         # Ensure that bunch.pos exists, as it used by the LassoMixin.
         bunch.pos = np.c_[spike_times, amplitudes]
         assert bunch.pos.ndim == 2
         bunch.cluster_id = cluster_id
         bunch.color = (
             selected_cluster_color(i - offset, self.marker_alpha)
             # Background amplitude color.
             if cluster_id is not None else (.5, .5, .5, .5))
     return bunchs
Example #4
0
    def get_clusters_data(self, load_all=None):
        """Return one Bunch per selected cluster holding its histogram.

        For each cluster, ``self.cluster_stat(cluster_id)`` provides a Bunch whose
        ``data`` holds the values to histogram; clusters with empty data are skipped.
        When not already set, the x range of the view (``self.x_min``, ``self.x_max``) is
        taken from the Bunch's own ``x_min``/``x_max`` if present, else from the data
        range, and is then shared by all the clusters' histograms.

        Each returned Bunch gets ``histogram``, ``ylim`` (the histogram maximum),
        ``color``, ``index`` (the position of the cluster in ``self.cluster_ids``) and
        ``cluster_id``. ``load_all`` is not used in this method.
        """
        bunchs = []
        for i, cluster_id in enumerate(self.cluster_ids):
            bunch = self.cluster_stat(cluster_id)
            if not bunch.data.size:
                continue
            bmin, bmax = bunch.data.min(), bunch.data.max()
            # Set the x range if it was not set before. Compare with None rather than
            # relying on truthiness, so that a legitimate bound of 0 (e.g. x_min=0) is
            # kept instead of being replaced by the next fallback.
            if self.x_min is None:
                x_min = bunch.get('x_min', None)
                self.x_min = bmin if x_min is None else x_min
            if self.x_max is None:
                x_max = bunch.get('x_max', None)
                self.x_max = bmax if x_max is None else x_max
            self.x_min = min(self.x_min, self.x_max)
            assert self.x_min is not None
            assert self.x_max is not None
            assert self.x_min <= self.x_max

            # Compute the histogram.
            bunch.histogram = _compute_histogram(
                bunch.data, x_min=self.x_min, x_max=self.x_max, n_bins=self.n_bins)
            bunch.ylim = bunch.histogram.max()

            bunch.color = selected_cluster_color(i)
            bunch.index = i
            bunch.cluster_id = cluster_id
            bunchs.append(bunch)
        return bunchs
Example #5
0
File: trace.py Project: zsong30/phy
    def _plot_spike(self, bunch):
        """Add one spike waveform to the batch of ``self.waveform_visual``.

        ``bunch.data`` is an ``(n_samples, n_channels)`` array whose first sample is at
        ``bunch.start_time``. Each channel is drawn in the box given by
        ``self.channel_y_ranks`` for that channel id. Spikes with a ``select_index`` use
        the selection color; other spikes use the current color scheme.
        """
        n_samples, n_channels = bunch.data.shape
        assert len(bunch.channel_ids) == n_channels

        # Time of every sample, repeated for each channel: (n_channels, n_samples).
        sample_times = bunch.start_time + self.dt * np.arange(n_samples)
        sample_times = np.tile(sample_times, (n_channels, 1))

        # Spike color.
        # A possible extension is to modulate the alpha of each channel by its template
        # amplitude (e.g. from an optional `channel_amps` attribute).
        select_index = bunch.select_index
        spike_cluster = bunch.spike_cluster
        scheme = self.color_schemes.get()
        if select_index is None:
            spike_color = scheme.get(spike_cluster, alpha=1)
        else:
            spike_color = selected_cluster_color(select_index, alpha=1)

        # Every sample of a channel goes into that channel's box.
        channel_ranks = self.channel_y_ranks[bunch.channel_ids]
        sample_boxes = np.repeat(channel_ranks[:, np.newaxis], n_samples, axis=0)
        self.waveform_visual.add_batch_data(
            box_index=sample_boxes,
            x=sample_times,
            y=bunch.data.T,
            color=spike_color,
            data_bounds=self.data_bounds,
        )
Example #6
0
            def on_select(sender, cluster_ids=None, **kwargs):
                """Highlight cluster-view rows of clusters sharing a channel with the selection.

                Each selected cluster's channel is read from ``get_cluster_info(c)['ch']``.
                Every cluster on one of those channels gets, as row background, the color
                of the first selected cluster on that channel with alpha 0.2. All other rows
                have their background cleared. The highlighted ids reported back by the
                JavaScript are logged at debug level. Extra keyword arguments are ignored.
                """
                import json

                view = gui.get_view(ClusterView)
                # `cluster_ids` defaults to None: treat it as an empty selection (which
                # clears all highlighting) instead of failing when iterating over it.
                cluster_ids = list(cluster_ids or ())

                # Get selected channels; `c_ids` holds, for each unique channel, the index
                # of the first selected cluster on it.
                channels = [sender.get_cluster_info(c)['ch']
                            for c in cluster_ids]
                channels, c_ids = np.unique(channels, return_index=True)
                channels = channels.tolist()

                # Get cluster colors, as CSS rgba() strings with 0.2 opacity.
                colors = [selected_cluster_color(i, alpha=1)
                          for i in range(len(cluster_ids))]
                colors = (np.asarray(colors)[c_ids] * 255).astype(int)
                colors = [('rgba(' + ', '.join(map(str, c[:3])) + ', 0.2)')
                          for c in colors]
                channel_colors = dict(zip(channels, colors))

                # Map every cluster on a selected channel to that channel's color.
                clust = dict()
                for c in sender.clustering.cluster_ids:
                    ch = sender.get_cluster_info(c)['ch']
                    if ch in channel_colors:
                        clust[str(c)] = channel_colors[ch]

                # json.dumps yields a valid JavaScript object literal (Python's repr is
                # only accidentally compatible).
                js = """
                    var ll = """ + json.dumps(clust) + """;
                    var itms = document.getElementsByTagName("tr");
                    var chng = []
                    for (var i = 0; i < itms.length; i++) {
                        var c_id = itms[i].getAttribute('data-_id');

                        // New clusters do not have this attribute
                        if (!c_id) {
                            c_id = itms[i].getElementsByClassName('id')
                            if (c_id.length) {
                                c_id = c_id[0].innerHTML;
                            } else {
                                continue;
                            }
                        };

                        if (Object.keys(ll).indexOf(c_id) >= 0) {
                            itms[i].style.background = ll[c_id];
                            chng.push(c_id);
                        } else {
                            itms[i].style.background = '';
                        }
                    }

                    // Report highlighted clusters to callback function
                    chng
                """

                def report(obj):
                    logger.debug('Highlighted clusters %s.',
                                 ', '.join(obj) if len(obj) > 0 else 'none')

                view.eval_js(js, callback=report)
Example #7
0
 def _get_split_cluster_data(self, bunchs):
     """Annotate a list holding one Bunch per selected cluster and return it.

     Bunches are paired with ``self.cluster_ids`` in order. Each one gets its
     ``cluster_id``, a 2D ``pos`` array (stacked from the 1D ``x`` and ``y`` when
     ``pos`` is missing) and ``color``, the cluster's selection color with alpha .75.
     Every Bunch must carry ``spike_ids``.
     """
     for index, (cluster_id, bunch) in enumerate(zip(self.cluster_ids, bunchs)):
         bunch.cluster_id = cluster_id
         if 'pos' not in bunch:
             # Stack the x and y coordinates into (n_points, 2) positions.
             xs = bunch.x
             assert xs.ndim == 1
             assert xs.shape == bunch.y.shape
             bunch.pos = np.c_[xs, bunch.y]
         assert bunch.pos.ndim == 2
         assert 'spike_ids' in bunch
         bunch.color = selected_cluster_color(index, .75)
     return bunchs
Example #8
0
 def get_clusters_data(self):
     """Return one waveform Bunch per selected cluster, annotated for display.

     Returns None when ``self.waveforms_type`` is not a key of ``self.waveforms``.
     Otherwise each Bunch, obtained by calling ``self.waveforms_types.get()`` on a
     cluster id, gets ``index`` (its position in the selection), ``offset`` and ``n_clu``
     (derived from ``_get_clu_offsets``), and ``color`` (its selection color, with the
     Bunch's own ``alpha`` if present, else .75).
     """
     # NOTE(review): membership is checked in `self.waveforms` while the loader comes
     # from `self.waveforms_types.get()` — confirm the two are kept in sync.
     if self.waveforms_type not in self.waveforms:
         return

     bunchs = []
     for cluster_id in self.cluster_ids:
         loader = self.waveforms_types.get()
         bunchs.append(loader(cluster_id))

     offsets = _get_clu_offsets(bunchs)
     # Presumably offsets run from 0 upwards, so this counts the overlap levels — confirm.
     n_clu = max(offsets) + 1

     # Offset depending on the overlap.
     for index, (bunch, offset) in enumerate(zip(bunchs, offsets)):
         bunch.index = index
         bunch.offset = offset
         bunch.n_clu = n_clu
         alpha = bunch.get('alpha', .75)
         bunch.color = selected_cluster_color(index, alpha)
     return bunchs
Example #9
0
 def get_clusters_data(self, load_all=None):
     """Return one Bunch per correlogram subplot.

     ``self.correlograms`` gives a 3D array indexed by ``(cluster i, cluster j, bin)``.
     For each ``(i, j)`` pair yielded by ``self._iter_subplots``, the Bunch holds the
     correlogram, the firing rate (when ``self.firing_rate`` is set), the data bounds and
     the pair index. The y maximum of the bounds is the global maximum when
     ``self.uniform_normalization`` is true, else the pair's own maximum. Diagonal subplots
     use cluster i's selection color; off-diagonal ones use that color with its HSV
     saturation/value overridden (s=.1, v=1). ``load_all`` is not used here.
     """
     ccg = self.correlograms(self.cluster_ids, self.bin_size, self.window_size)
     if self.firing_rate:
         fr = self.firing_rate(self.cluster_ids, self.bin_size)
     else:
         fr = None
     assert ccg.ndim == 3
     n_bins = ccg.shape[2]
     global_max = ccg.max()

     bunchs = []
     for i, j in self._iter_subplots(len(self.cluster_ids)):
         pair_ccg = ccg[i, j, :]
         # Normalize each subplot on its own unless uniform normalization is requested.
         y_max = global_max if self.uniform_normalization else pair_ccg.max()
         pair_fr = None if fr is None else fr[i, j]
         color = selected_cluster_color(i, 1)
         if i != j:
             color = add_alpha(_override_hsv(color[:3], s=.1, v=1))
         bunchs.append(Bunch(
             correlogram=pair_ccg,
             firing_rate=pair_fr,
             data_bounds=(0, 0, n_bins, y_max),
             pair_index=(i, j),
             color=color,
         ))
     return bunchs
Example #10
0
    def _plot_spike(self, bunch):
        """Append a single spike waveform to the waveform visual's batch.

        The waveform in ``bunch.data`` has shape ``(n_samples, n_channels)`` and starts
        at ``bunch.start_time``. Each of ``bunch.channel_ids`` is drawn in the box given
        by ``self.channel_y_ranks``. The color is the selection color when the spike has a
        ``select_index``, otherwise the current color scheme's color for its cluster.
        """
        n_samples, n_channels = bunch.data.shape
        assert len(bunch.channel_ids) == n_channels

        # Sample times, one row per channel: (n_channels, n_samples).
        offsets = self.dt * np.arange(n_samples)
        times = np.tile(bunch.start_time + offsets, (n_channels, 1))

        select_index = bunch.select_index
        cluster = bunch.spike_cluster
        scheme = self.color_schemes.get()
        if select_index is not None:
            rgba = selected_cluster_color(select_index, alpha=1)
        else:
            rgba = scheme.get(cluster, alpha=1)

        # Box of each sample = rank of its channel.
        ranks = self.channel_y_ranks[bunch.channel_ids]
        boxes = np.repeat(ranks[:, np.newaxis], n_samples, axis=0)
        self.waveform_visual.add_batch_data(
            box_index=boxes,
            x=times,
            y=bunch.data.T,
            color=rgba,
            data_bounds=self.data_bounds,
        )
Example #11
0
                def on_select(sender, cluster_ids=None, **kwargs):
                    """Store the selected clusters' channels and colors on the view, then redraw labels.

                    ``view.ch`` receives the sorted unique channels (``'ch'`` cluster info)
                    of the selected clusters, and ``view.ch_colors`` the selection color of
                    the first selected cluster on each of them. An empty selection clears
                    both before returning, without redrawing.
                    """
                    if not cluster_ids:
                        view.ch = []
                        view.ch_colors = []
                        return

                    # Unique channels, with the index of the first selected cluster on each.
                    selected_channels = [
                        controller.supervisor.get_cluster_info(cid)['ch']
                        for cid in cluster_ids
                    ]
                    unique_channels, first_idx = np.unique(selected_channels, return_index=True)
                    view.ch = unique_channels.tolist()

                    # Colors of all selected clusters, kept for the first one per channel.
                    palette = np.asarray([
                        selected_cluster_color(k, alpha=1)
                        for k in range(len(cluster_ids))
                    ])
                    view.ch_colors = palette[first_idx]

                    # Update highlighting
                    view._plot_labels(None)
Example #12
0
    def on_select(self, cluster_ids=(), **kwargs):
        """Redraw the view for the newly selected clusters.

        The main method to implement in ManualClusteringView is `on_select()`, called whenever
        new clusters are selected.

        *Note*: `cluster_ids` contains the clusters selected in the cluster view, followed
        by clusters selected in the similarity view.

        """
        # This method should always start with these few lines of code.
        self.cluster_ids = cluster_ids
        if not cluster_ids:
            return

        # We update the number of boxes in the stacked layout, which is the number of
        # selected clusters.
        self.canvas.stacked.n_boxes = len(cluster_ids)

        # We obtain the template data.
        # NOTE(review): `.data` is taken here, but the loop below reads `bunch.template`
        # and refers to `bunch.channel_ids`, which suggests the Bunch returned by
        # `self.templates(cluster_id)` itself was intended — confirm against the
        # `templates` callable passed to this view.
        bunchs = {
            cluster_id: self.templates(cluster_id).data
            for cluster_id in cluster_ids
        }

        # For performance reasons, it is best to use as few visuals as possible. In this
        # example, we want 1 waveform template per subplot. We will use a single visual
        # covering all subplots at once. This is the key to achieve good performance with
        # OpenGL in Python. However, this comes with the drawback that the programming
        # interface is more complicated.
        #
        # In principle, we would have to concatenate all data (x and y coordinates) of all
        # subplots to pass it to `self.visual.set_data()` in order to draw all subplots at
        # once. But this is tedious.
        #
        # phy uses the notion of **batch**: for each subplot, we set *partial data* for the
        # subplot which just prepares the data for concatenation *after* we're done with
        # looping through all clusters. The concatenation happens in the special call
        # `self.canvas.update_visual(self.visual)`.
        #
        # We need to call `visual.reset_batch()` before constructing a batch.
        self.visual.reset_batch()

        # We iterate through all selected clusters.
        for idx, cluster_id in enumerate(cluster_ids):
            bunch = bunchs[cluster_id]

            # In this example, we just keep the peak channel. Note that `bunch.template` is
            # a 2D array `(n_samples, n_channels)` where `n_channels` is the number of
            # "best" channels for the cluster. The channels are sorted by decreasing
            # template amplitude, so the first one is the peak channel. The channel ids can
            # be found in `bunch.channel_ids`.
            y = bunch.template[:, 0]

            # We decide to use, on the x axis, values ranging from -1 to 1. This is the
            # standard viewport in OpenGL and phy.
            x = np.linspace(-1., 1., len(y))

            # phy requires you to specify explicitly the x and y range of the plots.
            # The `data_bounds` variable is a `(xmin, ymin, xmax, ymax)` tuple representing
            # the lower-left and upper-right corners of a rectangle. By default, the data
            # bounds of the entire view is (-1, -1, 1, 1), also called normalized device
            # coordinates. Eventually, OpenGL uses this coordinate system for display, but
            # phy provides a transform system to convert from different coordinate
            # systems, both on the CPU and the GPU.
            #
            # Here, the x range is (-1, 1), and the y range is (m, M) where m and M are
            # respectively the min and max of the template.
            # NOTE(review): a flat template (m == M) gives an empty y range.
            m, M = y.min(), y.max()
            data_bounds = (-1, m, +1, M)

            # This function gives the color of the i-th selected cluster. This is a 4-tuple
            # with values between 0 and 1 for RGBA: red, green, blue, alpha channel
            # (transparency, 1 by default).
            color = selected_cluster_color(idx)

            # The plot visual takes as input the x and y coordinates of the points, the
            # color, and the data bounds.
            # There is also a special keyword argument `box_index` which is the subplot
            # index. In the stacked layout, this is just an integer identifying the subplot
            # index, from top to bottom. Note that in the grid view, the box index is a
            # pair (row, col).
            self.visual.add_batch_data(x=x,
                                       y=y,
                                       color=color,
                                       data_bounds=data_bounds,
                                       box_index=idx)

        # After the loop, this special call automatically builds the data to upload to the
        # GPU by concatenating the partial data set in `add_batch_data()`.
        self.canvas.update_visual(self.visual)

        # After updating the data on the GPU, we need to refresh the canvas.
        self.canvas.update()
Example #13
0
 def plot(self, **kwargs):
     """Scatter 100 random points (normal, scaled by .25) per selected cluster, in its color."""
     n_clusters = len(self.cluster_ids)
     for clu_idx in range(n_clusters):
         points = .25 * np.random.randn(100, 2)
         self.canvas.scatter(pos=points, color=selected_cluster_color(clu_idx))
    def plotWaveforms(self):
        """Plot the spike waveforms of the current channel, with summary statistics.

        Draws every waveform of ``self.wavefs[:, :, self.current_channel_idx]`` scaled by
        ``self.gain``. With more than 100 spikes, it also draws the median, the median
        +/- 3 standard deviations and the 1% / 99% quantiles. The y range is symmetric,
        ``(-M, M)``, where ``M`` is the scaled peak absolute amplitude rounded up to a
        multiple of 10, 100 or 1000 depending on its magnitude.
        """
        Nspk, Ntime, _ = self.wavefs.shape
        # (Nspk, Ntime) waveforms of the current channel.
        wavefs = self.wavefs[:, :, self.current_channel_idx]

        self.visual.reset_batch()
        self.text_visual.reset_batch()

        # One time axis per spike, centred on 0.
        # NOTE(review): values are in units of 1 / self.Fs (seconds if Fs is in Hz) while
        # the axis label below says [ms] — confirm the unit of self.Fs.
        x = np.tile(np.linspace(-Ntime/2/self.Fs, Ntime/2/self.Fs, Ntime), (Nspk, 1))

        # Round the scaled peak amplitude UP so that every waveform fits in the y range.
        # The >= 1000 branch previously used np.floor, which clipped the largest waveforms
        # and was inconsistent with the two other branches.
        peak = np.max(np.abs(wavefs)) * self.gain
        if peak < 100:
            M = 10 * np.ceil(peak / 10)
        elif peak < 1000:
            M = 100 * np.ceil(peak / 100)
        else:
            M = 1000 * np.ceil(peak / 1000)
        if M == 0:
            # All-zero waveforms: keep a non-empty y range.
            M = 10
        self.data_bounds = (x[0][0], -M, x[0][-1], M)

        colorwavef = selected_cluster_color(0)
        colormedian = selected_cluster_color(3)
        colorstd = (0, 1, 0, 1)
        colorqtl = (1, 1, 0, 1)

        self.visual.add_batch_data(
                x=x, y=self.gain*wavefs, color=colorwavef, data_bounds=self.data_bounds, box_index=0)

        # Summary statistics, only when there are enough spikes.
        if Nspk > 100:
            medianCl = np.median(wavefs, axis=0)
            stdCl = np.std(wavefs, axis=0)
            # NOTE(review): `interpolation=` is deprecated since NumPy 1.22 in favour of
            # `method=`; kept as-is for compatibility with older NumPy versions.
            q1 = np.quantile(wavefs, .01, axis=0, interpolation='higher')
            q9 = np.quantile(wavefs, .99, axis=0, interpolation='lower')
            x1 = x[0]
            curves = (
                (medianCl, colormedian),
                (medianCl + 3*stdCl, colorstd),
                (medianCl - 3*stdCl, colorstd),
                (q1, colorqtl),
                (q9, colorqtl),
            )
            for curve, color in curves:
                self.visual.add_batch_data(
                        x=x1, y=self.gain*curve, color=color, data_bounds=self.data_bounds, box_index=0)

        # Axes labels.
        self.text_visual.add_batch_data(
                pos=[.9, .98],
                text='[uV]',
                anchor=[-1, -1],
                box_index=0,
            )

        self.text_visual.add_batch_data(
                pos=[-1, -.95],
                text='[ms]',
                anchor=[1, 1],
                box_index=0,
            )

        label = 'Ch {a}'.format(a=self.channel_ids[self.current_channel_idx])
        self.text_visual.add_batch_data(
                pos=[-.98, .98],
                text=label,
                anchor=[1, -1],
                box_index=0,
            )
        self.canvas.update_visual(self.visual)
        self.canvas.update_visual(self.text_visual)
        self.canvas.axes.reset_data_bounds(self.data_bounds)

        self.canvas.update()