예제 #1
0
    def on_request_split_FTT(self):
        """Ask which cluster to split and return its spikes inside the lasso.

        A modal, non-editable list dialog lets the user pick one of the
        currently selected cluster ids (the first one is preselected). Only
        the spikes of the chosen cluster are tested against the lasso
        polygon. Their (x, y) positions are normalized with
        ``self.data_bounds`` before the test. The lasso is cleared once a
        cluster has been processed.

        Returns
        -------
        numpy.ndarray
            Sorted, unique spike ids of the chosen cluster that lie inside
            the lasso. An empty int64 array is returned when the lasso has
            fewer than 3 points, when no cluster is selected, or when the
            dialog is cancelled. In the cancel case the lasso is not cleared.
        """
        lasso = self.lasso
        if lasso.count < 3 or len(self.cluster_ids) == 0:  # pragma: no cover
            return np.array([], dtype=np.int64)

        # Present the selected cluster ids as strings; the picked string is
        # mapped back to its position in this list.
        choices = list(map(str, self.cluster_ids))
        prompt = "Which cluster id should be split:"
        # A throwaway dialog is passed as the parent of the static picker.
        parent = QtGui.QInputDialog(None)
        chosen, accepted = QtGui.QInputDialog.getItem(
            parent, prompt, prompt, choices, 0, False)
        if not accepted:
            return np.array([], dtype=np.int64)
        position = choices.index(chosen)

        # NOTE(review): assumes _get_data returns one bunch per entry of
        # self.cluster_ids, in the same order -- confirm against _get_data.
        bunch = self._get_data(self.cluster_ids)[position]
        xy = np.vstack((bunch['x'], bunch['y'])).T

        # Normalize with the view's data bounds before the polygon test
        # (presumably the lasso stores normalized coordinates -- confirm).
        xy = Range(self.data_bounds).apply(xy)

        inside = lasso.in_polygon(xy)
        lassoed = bunch['spike_ids'][inside]
        lasso.clear()
        return np.unique(lassoed)
예제 #2
0
    def on_request_split(self, sender=None):
        """Select the spikes that fall inside the lasso polygon.

        The lasso is drawn in one subplot of the feature grid. The pair of
        feature dimensions shown in that subplot is looked up. For every
        selected cluster, its spikes are loaded with ``load_all=True`` and
        projected onto those two dimensions. The positions are normalized
        per cluster using the axis bounds and then tested against the
        polygon. The lasso is cleared afterwards.

        Parameters
        ----------
        sender : object, optional
            Accepted for the caller's convenience; not used here.

        Returns
        -------
        numpy.ndarray
            Sorted, unique ids of the lassoed spikes, or an empty int64
            array when the lasso has fewer than 3 points or no cluster is
            selected.
        """
        lasso = self.canvas.lasso
        if lasso.count < 3 or len(self.cluster_ids) == 0:  # pragma: no cover
            return np.array([], dtype=np.int64)
        assert len(self.channel_ids)

        # The active box identifies the lassoed subplot; its grid entry is
        # a "dim_x,dim_y" string naming the two plotted dimensions.
        row, col = self.canvas.layout.active_box
        dim_x, dim_y = self.grid_dim[row][col].split(',')

        point_chunks = []
        id_chunks = []
        for cluster_id in self.cluster_ids:
            features = self.features(cluster_id,
                                     channel_ids=self.channel_ids,
                                     load_all=True)
            axis_x = self._get_axis_data(features, dim_x,
                                         cluster_id=cluster_id,
                                         load_all=True)
            axis_y = self._get_axis_data(features, dim_y,
                                         cluster_id=cluster_id,
                                         load_all=True)
            raw = np.c_[axis_x.data, axis_y.data]

            # Bring this cluster's points into the normalized range used
            # for the polygon test.
            x_lo, x_hi = self._get_axis_bounds(dim_x, axis_x)
            y_lo, y_hi = self._get_axis_bounds(dim_y, axis_y)
            normalizer = Range((x_lo, y_lo, x_hi, y_hi))

            point_chunks.append(normalizer.apply(raw))
            id_chunks.append(features.spike_ids)

        positions = np.vstack(point_chunks)
        all_ids = np.concatenate(id_chunks)

        inside = lasso.in_polygon(positions)
        lasso.clear()
        return np.unique(all_ids[inside])
예제 #3
0
파일: feature.py 프로젝트: kwikteam/phy
    def on_request_split(self):
        """Select the spikes that fall inside the lasso polygon.

        The lasso is drawn in one subplot of the feature grid. The pair of
        feature dimensions shown in that subplot is looked up. For every
        selected cluster, its spikes are loaded with ``load_all=True`` and
        projected onto those two dimensions. The positions are normalized
        per cluster using the axis bounds and then tested against the
        polygon. The lasso is cleared afterwards.

        Returns
        -------
        numpy.ndarray
            Sorted, unique ids of the lassoed spikes, or an empty int64
            array when the lasso has fewer than 3 points or no cluster is
            selected.
        """
        lasso = self.lasso
        if lasso.count < 3 or len(self.cluster_ids) == 0:  # pragma: no cover
            return np.array([], dtype=np.int64)
        assert len(self.channel_ids)

        # The lasso's box indexes the subplot it was drawn in; the grid
        # entry there is a "dim_x,dim_y" string naming the two dimensions.
        row, col = lasso.box
        dim_x, dim_y = self.grid_dim[row][col].split(',')

        point_chunks = []
        id_chunks = []
        for cluster_id in self.cluster_ids:
            features = self.features(cluster_id,
                                     channel_ids=self.channel_ids,
                                     load_all=True)
            axis_x = self._get_axis_data(features, dim_x,
                                         cluster_id=cluster_id,
                                         load_all=True)
            axis_y = self._get_axis_data(features, dim_y,
                                         cluster_id=cluster_id,
                                         load_all=True)
            raw = np.c_[axis_x.data, axis_y.data]

            # Bring this cluster's points into the normalized range used
            # for the polygon test.
            x_lo, x_hi = self._get_axis_bounds(dim_x, axis_x)
            y_lo, y_hi = self._get_axis_bounds(dim_y, axis_y)
            normalizer = Range((x_lo, y_lo, x_hi, y_hi))

            point_chunks.append(normalizer.apply(raw))
            id_chunks.append(features.spike_ids)

        positions = np.vstack(point_chunks)
        all_ids = np.concatenate(id_chunks)

        inside = lasso.in_polygon(positions)
        lasso.clear()
        return np.unique(all_ids[inside])