Esempio n. 1
0
File: views.py Progetto: mspacek/phy
    def on_select(self, cluster_ids=None):
        """Redraw the scatter plot for the currently selected clusters.

        ``self.coords`` returns one bunch per cluster carrying ``spike_ids``,
        ``x``, ``y`` and optionally ``data_bounds`` (default ``'auto'``).
        Each bunch is drawn with the colormap entry of its position in the
        selection at alpha .5. If ``self.coords`` returns None the view is
        cleared instead.
        """
        super(ScatterView, self).on_select(cluster_ids)
        selected = self.cluster_ids
        if len(selected) == 0:
            return

        # Fetch the per-cluster coordinates; None means nothing to show.
        bunches = self.coords(selected)
        if bunches is None:
            self.clear()
            return
        assert isinstance(bunches, list)

        with self.building():
            # zip stops at the shorter sequence, so clusters without a
            # matching bunch are simply not drawn.
            for idx, bunch in zip(range(len(selected)), bunches):
                n_spikes = len(bunch.spike_ids)
                xs, ys = bunch.x, bunch.y
                bounds = bunch.get('data_bounds', 'auto')
                assert n_spikes > 0
                assert xs.shape == (n_spikes,)
                assert ys.shape == (n_spikes,)

                rgba = tuple(_colormap(idx)) + (.5,)
                self.scatter(
                    x=xs,
                    y=ys,
                    color=rgba,
                    size=self._default_marker_size,
                    data_bounds=bounds,
                )
Esempio n. 2
0
def _get_point_color(clu_idx=None):
    if clu_idx is not None:
        color = tuple(_colormap(clu_idx)) + (.5, )
    else:
        color = (.5, ) * 4
    assert len(color) == 4
    return color
Esempio n. 3
0
def _get_point_color(clu_idx=None):
    if clu_idx is not None:
        color = tuple(_colormap(clu_idx)) + (.5,)
    else:
        color = (.5,) * 4
    assert len(color) == 4
    return color
Esempio n. 4
0
    def _plot_features(self,
                       i,
                       j,
                       x_dim,
                       y_dim,
                       x,
                       y,
                       masks=None,
                       clu_idx=None):
        """Plot the features in a subplot.

        Parameters
        ----------
        i, j : int
            Subplot row and column.
        x_dim, y_dim : indexable by ``[i, j]``
            Per-subplot axis descriptors. A tuple entry ``(channel, feature)``
            in ``x_dim`` selects the mask column ``masks[:, channel]``.
        x, y : ndarray
            Point coordinates; must have the same shape.
        masks : ndarray, optional
            Masks indexed as ``masks[:, channel]``. When None, every point is
            treated as unmasked (previously this raised a TypeError whenever
            ``x_dim[i, j]`` was a tuple).
        clu_idx : int, optional
            Relative cluster index in the selection. None draws background
            points in white with marker size 1.
        """
        assert x.shape == y.shape
        n_spikes = x.shape[0]

        if clu_idx is not None:
            color = tuple(_colormap(clu_idx)) + (.5, )
        else:
            color = (1., 1., 1., .5)
        assert len(color) == 4

        # Find the masks for the given subplot channel.
        # NOTE: we add the cluster relative index for the computation
        # of the depth on the GPU; the .999 factor presumably keeps the
        # fractional (mask) part below 1, as in the waveform views.
        subplot_dim = x_dim[i, j]
        if isinstance(subplot_dim, tuple) and masks is not None:
            ch, _ = subplot_dim
            m = masks[:, ch] * .999 + (clu_idx or 0)
        else:
            # No per-channel mask available: treat all points as unmasked.
            m = np.ones(n_spikes) * .999 + (clu_idx or 0)

        # Marker size, smaller for background features.
        size = self._default_marker_size if clu_idx is not None else 1.

        self[i, j].scatter(
            x=x,
            y=y,
            color=color,
            masks=m,
            size=size,
            data_bounds=None,
            uniform=True,
        )
        if i == 0:
            # HACK: call this when i=0 (first line) but plot the text
            # in the last subplot line. This is because we skip i > j
            # in the subplot loop.
            i0 = (self.n_cols - 1)
            dim = x_dim[i0, j] if j < (self.n_cols - 1) else x_dim[i0, 0]
            self[i0, j].text(
                pos=[0., -1.],
                text=str(dim),
                anchor=[0., -1.04],
                data_bounds=None,
            )
        if j == 0:
            self[i, j].text(
                pos=[-1., 0.],
                text=str(y_dim[i, j]),
                anchor=[-1.03, 0.],
                data_bounds=None,
            )
Esempio n. 5
0
    def _plot_points(self, bunchs, data_bounds):
        """Scatter-plot the ``x``/``y`` points of every bunch.

        Bunch number ``k`` is drawn with colormap entry ``k`` at alpha .5,
        the default marker size and the shared ``data_bounds``. Both
        coordinate arrays must be 1-D and of equal shape.
        """
        for k, bunch in enumerate(bunchs):
            xs = bunch.x
            ys = bunch.y
            assert xs.ndim == ys.ndim == 1
            assert xs.shape == ys.shape

            rgba = tuple(_colormap(k)) + (.5,)
            self.scatter(
                x=xs,
                y=ys,
                color=rgba,
                size=self._default_marker_size,
                data_bounds=data_bounds,
            )
Esempio n. 6
0
    def _plot_waveforms(self, bunchs, bunchs_set, channel_ids, cluster_ids):
        """Plot the waveform bunches, then hand extra sets to the overlay.

        Parameters
        ----------
        bunchs : list
            Bunches with ``data`` of shape (n_spikes, n_samples, n_channels),
            ``channel_ids``, and optionally ``alpha`` (default .5) and
            ``masks`` of shape (n_spikes, n_channels) (default all ones).
            Input masks are no longer modified in place.
        bunchs_set : list
            Bunches whose ``data`` is passed to ``plot_wvf_feat_vispy`` with
            the colormap color of their index at alpha 0.3.
        channel_ids : sequence
            Channel ids of the view, used to map each bunch's channels to
            box indices.
        cluster_ids : sequence
            Currently unused; kept for interface compatibility.
        """
        # Initialize the box scaling the first time.
        if self.box_scaling[1] == 1.:
            M = np.max([np.max(np.abs(b.data)) for b in bunchs])
            self.box_scaling[1] = 1. / M
            self._update_boxes()
        clu_offsets = _get_clu_offsets(bunchs)
        max_clu_offsets = max(clu_offsets) + 1
        for i, d in enumerate(bunchs):
            wave = d.data
            alpha = d.get('alpha', .5)
            channel_ids_loc = d.channel_ids

            n_channels = len(channel_ids_loc)
            masks = d.get('masks', np.ones((wave.shape[0], n_channels)))

            n_spikes_clu, n_samples = wave.shape[:2]
            assert wave.shape[2] == n_channels
            assert masks.shape == (n_spikes_clu, n_channels)

            # Find the x coordinates.
            t = _get_linear_x(n_spikes_clu * n_channels, n_samples)
            if not self.overlap:
                # Determine the cluster offset.
                offset = clu_offsets[i]
                t = t + 2.5 * (offset - (max_clu_offsets - 1) / 2.)
                # The total width should not depend on the number of
                # clusters.
                t /= max_clu_offsets

            # Get the spike masks.
            # HACK: on the GPU, we get the actual masks with fract(masks)
            # since we add the relative cluster index. We need to ensure
            # that the masks is never 1.0, otherwise it is interpreted as
            # 0. The cluster index added here is used for the depth.
            # Build a new array: the previous in-place `*=`/`+=` mutated
            # the bunch's own masks, corrupting them on re-plot.
            m = masks * .99999 + i

            color = tuple(_colormap(i)) + (alpha,)
            assert len(color) == 4

            # Generate the box index (one number per channel).
            box_index = _index_of(channel_ids_loc, channel_ids)
            box_index = np.repeat(box_index, n_samples)
            box_index = np.tile(box_index, n_spikes_clu)
            assert box_index.shape == (n_spikes_clu *
                                       n_channels *
                                       n_samples,)

            # Generate the waveform array.
            wave = np.transpose(wave, (0, 2, 1))
            wave = wave.reshape((n_spikes_clu * n_channels, n_samples))

            self.uplot(x=t,
                       y=wave,
                       color=color,
                       masks=m,
                       box_index=box_index,
                       data_bounds=None,
                       )

        # Extra waveform sets: fixed alpha of 0.3 (the original built a color
        # from `alpha` leaked out of the loop above, then discarded it).
        for i, d in enumerate(bunchs_set):
            color = tuple(_colormap(i))[:3] + (0.3,)
            plot_wvf_feat_vispy(self, d.data, color, i)
Esempio n. 7
0
File: views.py Progetto: mspacek/phy
    def on_select(self, cluster_ids=None):
        """Redraw the waveforms of the selected clusters.

        Every bunch returned by ``self.waveforms`` is plotted, except those
        whose ``tag`` is not in ``self.filtered_tags`` when that filter is
        set. Each spike is drawn on each of the bunch's channels (``channels``,
        default all) inside that channel's box. Bunches are shifted
        horizontally by their cluster's position unless ``self.overlap``.
        With ``self.do_show_labels`` each channel label is drawn once.
        Finally the view zooms on ``self.best_channels`` when
        ``self.do_zoom_on_channels`` is set.
        """
        super(WaveformView, self).on_select(cluster_ids)
        cluster_ids = self.cluster_ids
        n_clusters = len(cluster_ids)
        if n_clusters == 0:
            return

        bunches = self.waveforms(cluster_ids)

        with self.building():
            labelled = set()
            for bunch in bunches:
                if (self.filtered_tags and
                        bunch.get('tag') not in self.filtered_tags):
                    continue  # pragma: no cover
                alpha = bunch.get('alpha', .5)
                waves = bunch.data
                spike_masks = bunch.masks
                # Position of the bunch's cluster in the selection; several
                # bunches can map to the same cluster.
                rank = cluster_ids.index(bunch.cluster_id)

                n_spk, n_samp, n_chan = waves.shape
                assert spike_masks.shape[0] == n_spk

                unmasked_chans = bunch.get('channels',
                                           np.arange(self.n_channels))
                assert n_chan == len(unmasked_chans)
                assert n_chan > 0

                # Horizontal coordinates; shrink and shift per cluster so the
                # overall width does not depend on the number of clusters.
                t = _get_linear_x(n_spk * n_chan, n_samp)
                if not self.overlap:
                    t = t + 2.5 * (rank - (n_clusters - 1) / 2.)
                    t /= n_clusters

                # GPU encoding: fract() recovers the mask, so keep it below
                # 1.0; the integer part carries the cluster index for depth.
                m = spike_masks[:, unmasked_chans].reshape((-1, 1))
                m *= .999
                m += rank

                color = tuple(_colormap(rank)) + (alpha,)
                assert len(color) == 4

                # One box number per sample, repeated for every spike.
                box_index = np.tile(np.repeat(unmasked_chans, n_samp), n_spk)
                assert box_index.shape == (n_spk * n_chan * n_samp,)

                # Rows ordered spike-major, then channel.
                y = np.transpose(waves, (0, 2, 1)).reshape(
                    (n_spk * n_chan, n_samp))

                self.plot(
                    x=t,
                    y=y,
                    color=color,
                    masks=m,
                    box_index=box_index,
                    data_bounds=None,
                    uniform=True,
                )

                if not self.do_show_labels:
                    continue
                for ch in unmasked_chans:
                    if ch not in labelled:
                        labelled.add(ch)
                        label = '%d' % self.channel_order[ch]
                        self[ch].text(
                            pos=[t[0, 0], 0.],
                            text=label,
                            anchor=[-1.01, -.25],
                            data_bounds=None,
                        )

        best = self.best_channels(cluster_ids)
        if best is not None and self.do_zoom_on_channels:
            self.zoom_on_channels(best)
Esempio n. 8
0
    def _make_plots(self, clu_selection_idx, cluster_id):
        """
        Generate all plots for one cluster.

        Subplot (0, 0) shows spike positions in the cluster color and (0, 1)
        shows them colored by head direction, both over the tracking path.
        Subplot (1, 0) shows rate-map contours with increasing alpha, and
        (1, 1) shows the head-direction tuning curve on a polar frame. The
        path and the polar frame are drawn only when ``clu_selection_idx``
        is 0.

        :param clu_selection_idx: index of cluster in the current selection;
            nothing is drawn when None
        :param cluster_id: ID of cluster
        :return: None
        """
        if clu_selection_idx is None:
            return
        rgb = tuple(_colormap(clu_selection_idx))
        rgba_half = rgb + (0.5, )
        rgba_full = rgb + (1.0, )
        assert len(rgb) == 3

        track = self.tracking_data_shifted
        xs = track[self.valid_tracking, 1]
        ys = track[self.valid_tracking, 2]

        # Samples of this cluster's valid spikes, mapped to tracking rows;
        # negative results from the search are dropped.
        in_cluster = np.isin(self.spike_clusters, cluster_id)
        samples = self.spike_samples[in_cluster & self.valid_spikes]
        rows = _binary_search(track[:, 0], samples)
        rows = rows[rows >= 0]
        hd_curve, spike_hd = self._hd_tuning_curve(rows)

        # Square bounds centred on the tracked area.
        lo_x, lo_y = np.min(xs), np.min(ys)
        hi_x, hi_y = np.max(xs), np.max(ys)
        cx = (lo_x + hi_x) / 2
        cy = (lo_y + hi_y) / 2
        half = max(hi_x - lo_x, hi_y - lo_y) / 2
        bounds = (cx - half, cy - half, cx + half, cy + half)

        # Tracking path, shared by both spike-location subplots.
        if clu_selection_idx == 0:
            for col in (0, 1):
                self[0, col].uplot(
                    x=track[self.valid_tracking_time, 1].reshape((1, -1)),
                    y=track[self.valid_tracking_time, 2].reshape((1, -1)),
                    color=(1, 1, 1, 0.2),
                    data_bounds=bounds)

        spike_xy = track[rows, :]
        self[0, 0].scatter(x=spike_xy[:, 1],
                           y=spike_xy[:, 2],
                           color=rgba_half,
                           size=self.spike_dot_size,
                           data_bounds=bounds)

        hd_colors = _vector_to_rgb(spike_hd)
        self[0, 1].scatter(x=spike_xy[:, 1],
                           y=spike_xy[:, 2],
                           color=hd_colors,
                           size=self.spike_dot_size,
                           data_bounds=bounds)

        # Contour levels fade in from alpha 0 to alpha 1.
        contours = self._2d_rate_map_contours(rows)
        alphas = np.linspace(0, 1, len(contours))
        for level_alpha, contour in zip(alphas, contours):
            level_color = tuple(rgb) + (level_alpha, )
            for line in contour:
                self[1, 0].plot(x=line[:, 0],
                                y=line[:, 1],
                                color=level_color,
                                data_bounds=bounds)

        # Polar frame: a circle of radius 1.1 sampled at 1000 angles.
        pol_x, pol_y = _pol2cart(1.1, np.linspace(0, 2 * math.pi, 1000))
        px_lo, py_lo = np.min(pol_x), np.min(pol_y)
        px_hi, py_hi = np.max(pol_x), np.max(pol_y)
        polar_bounds = (px_lo, py_lo, px_hi, py_hi)

        if clu_selection_idx == 0:
            self[1, 1].uplot(x=pol_x,
                             y=pol_y,
                             color=(1, 1, 1, 0.5),
                             data_bounds=polar_bounds)
            # Horizontal axis through the circle centre.
            self[1, 1].uplot(x=np.array([px_lo, px_hi]),
                             y=np.array([0, 0]) + (py_lo + py_hi) / 2,
                             color=(1, 1, 1, 0.3),
                             data_bounds=polar_bounds)
            # Vertical axis through the circle centre.
            self[1, 1].uplot(y=np.array([py_lo, py_hi]),
                             x=np.array([0, 0]) + (px_lo + px_hi) / 2,
                             color=(1, 1, 1, 0.3),
                             data_bounds=polar_bounds)

        # Tuning curve evaluated at the head-direction bin centres.
        hd_edges = self.bins['hd']
        bin_width = hd_edges[1] - hd_edges[0]
        centres = hd_edges[:-1] + bin_width / 2
        curve_x, curve_y = _pol2cart(hd_curve, centres)
        self[1, 1].plot(x=curve_x, y=curve_y, color=rgba_full,
                        data_bounds=polar_bounds)
Esempio n. 9
0
    def on_select(self, cluster_ids=None):
        """Redraw the waveforms of the selected clusters.

        One bunch is taken from ``self.waveforms(cluster_ids)``, cycled by
        ``self.data_index``. Its spikes are split per cluster via
        ``spike_clusters``, and each cluster is drawn on all
        ``self.n_channels`` channels. Nothing is drawn when the selection or
        the returned waveform list is empty.
        """
        super(WaveformView, self).on_select(cluster_ids)
        cluster_ids = self.cluster_ids
        n_clusters = len(cluster_ids)
        if n_clusters == 0:
            return

        # Load the waveform subset.
        data = self.waveforms(cluster_ids)
        # Guard: `self.data_index % len(data)` would divide by zero.
        if len(data) == 0:
            return
        # Take one element in the list.
        data = data[self.data_index % len(data)]
        alpha = data.get('alpha', .5)
        spike_ids = data.spike_ids
        spike_clusters = data.spike_clusters
        w = data.data
        masks = data.masks
        n_spikes = len(spike_ids)
        assert w.ndim == 3
        n_samples = w.shape[1]
        assert w.shape == (n_spikes, n_samples, self.n_channels)
        assert masks.shape == (n_spikes, self.n_channels)

        # Plot all waveforms.
        # OPTIM: avoid the loop.
        with self.building():
            for i, cl in enumerate(cluster_ids):

                # Select the spikes corresponding to a given cluster.
                idx = spike_clusters == cl
                n_spikes_clu = idx.sum()  # number of spikes in the cluster.

                # Find the x coordinates.
                t = _get_linear_x(n_spikes_clu * self.n_channels, n_samples)
                if not self.overlap:
                    t = t + 2.5 * (i - (n_clusters - 1) / 2.)
                    # The total width should not depend on the number of
                    # clusters.
                    t /= n_clusters

                # Get the spike masks (boolean indexing returns a copy, so the
                # in-place updates below do not touch `masks`).
                m = masks[idx, :].reshape((n_spikes_clu * self.n_channels, 1))
                # HACK: on the GPU, we get the actual masks with fract(masks)
                # since we add the relative cluster index. We need to ensure
                # that the masks is never 1.0, otherwise it is interpreted as
                # 0.
                m *= .999
                # NOTE: we add the cluster index which is used for the
                # computation of the depth on the GPU.
                m += i

                color = tuple(_colormap(i)) + (alpha, )
                assert len(color) == 4

                # Generate the box index (one number per channel).
                box_index = np.arange(self.n_channels)
                box_index = np.repeat(box_index, n_samples)
                box_index = np.tile(box_index, n_spikes_clu)
                # Trailing comma matters: without it the right-hand side is
                # an int and this assertion could never pass.
                assert box_index.shape == (n_spikes_clu * self.n_channels *
                                           n_samples,)

                # Generate the waveform array.
                wave = w[idx, :, :]
                wave = np.transpose(wave, (0, 2, 1))
                wave = wave.reshape(
                    (n_spikes_clu * self.n_channels, n_samples))

                self.plot(
                    x=t,
                    y=wave,
                    color=color,
                    masks=m,
                    box_index=box_index,
                    data_bounds=None,
                    uniform=True,
                )
                # Add channel labels (once, with the first cluster).
                if self.do_show_labels and i == 0:
                    for ch in range(self.n_channels):
                        self[ch].text(
                            pos=[t[0, 0], 0.],
                            # TODO: use real channel labels.
                            text=str(ch),
                            anchor=[-1.01, -.25],
                            data_bounds=None,
                        )

        # Zoom on the best channels when selecting clusters.
        channels = self.best_channels(cluster_ids)
        if channels is not None and self.do_zoom_on_channels:
            self.zoom_on_channels(channels)
Esempio n. 10
0
    def _plot_waveforms(self, bunchs, channel_ids):
        """Plot each waveform bunch in the channel boxes of the view.

        Parameters
        ----------
        bunchs : list
            Bunches with ``data`` of shape (n_spikes, n_samples, n_channels),
            ``channel_ids``, and optionally ``alpha`` (default .5) and
            ``masks`` of shape (n_spikes, n_channels) (default all ones).
            The bunch's masks are not modified in place.
        channel_ids : sequence
            Channel ids of the view, used to map each bunch's channels to
            box indices.
        """
        # Initialize the box scaling the first time.
        if self.box_scaling[1] == 1.:
            M = np.max([np.max(np.abs(b.data)) for b in bunchs])
            self.box_scaling[1] = 1. / M
            self._update_boxes()
        clu_offsets = _get_clu_offsets(bunchs)
        max_clu_offsets = max(clu_offsets) + 1
        for i, d in enumerate(bunchs):
            wave = d.data
            alpha = d.get('alpha', .5)
            channel_ids_loc = d.channel_ids

            n_channels = len(channel_ids_loc)
            masks = d.get('masks', np.ones((wave.shape[0], n_channels)))

            n_spikes_clu, n_samples = wave.shape[:2]
            assert wave.shape[2] == n_channels
            assert masks.shape == (n_spikes_clu, n_channels)

            # Find the x coordinates.
            t = _get_linear_x(n_spikes_clu * n_channels, n_samples)
            if not self.overlap:
                # Determine the cluster offset.
                offset = clu_offsets[i]
                t = t + 2.5 * (offset - (max_clu_offsets - 1) / 2.)
                # The total width should not depend on the number of
                # clusters.
                t /= max_clu_offsets

            # Get the spike masks.
            # HACK: on the GPU, we get the actual masks with fract(masks)
            # since we add the relative cluster index. We need to ensure
            # that the masks is never 1.0, otherwise it is interpreted as
            # 0. The added cluster index is used for the depth on the GPU.
            # Build a new array: the previous `m = masks; m *= ...; m += i`
            # mutated the bunch's own masks in place.
            m = masks * .99999 + i

            color = tuple(_colormap(i)) + (alpha,)
            assert len(color) == 4

            # Generate the box index (one number per channel).
            box_index = _index_of(channel_ids_loc, channel_ids)
            box_index = np.repeat(box_index, n_samples)
            box_index = np.tile(box_index, n_spikes_clu)
            assert box_index.shape == (n_spikes_clu *
                                       n_channels *
                                       n_samples,)

            # Generate the waveform array.
            wave = np.transpose(wave, (0, 2, 1))
            wave = wave.reshape((n_spikes_clu * n_channels, n_samples))

            self.uplot(x=t,
                       y=wave,
                       color=color,
                       masks=m,
                       box_index=box_index,
                       data_bounds=None,
                       )
                       )