예제 #1
0
    def on_request_split_FTT(self):
        """Return the lassoed spike ids of one cluster chosen by the user.

        A non-editable item dialog lists the currently selected cluster ids;
        only the spikes of the chosen cluster are tested against the lasso.

        Returns
        -------
        numpy.ndarray
            Sorted unique spike ids inside the lasso. An empty int64 array is
            returned when `self.lasso.count` is below 3, when there is no
            selected cluster, or when the dialog is cancelled.
        """
        if self.lasso.count < 3 or len(self.cluster_ids) == 0:  # pragma: no cover
            return np.array([], dtype=np.int64)

        # Ask which of the selected clusters should be split.
        labels = list(map(str, self.cluster_ids))
        prompt = "Which cluster id should be split:"
        dialog = QtGui.QInputDialog(None)
        chosen, accepted = QtGui.QInputDialog.getItem(
            dialog, prompt, prompt, labels, 0, False)
        if not accepted:
            return np.array([], dtype=np.int64)
        which = labels.index(chosen)

        # Keep only the data of the chosen cluster.
        bunch = self._get_data(self.cluster_ids)[which]
        xy = np.vstack((bunch['x'], bunch['y'])).transpose()

        # Bring the points into the coordinate system used by the lasso.
        xy = Range(self.data_bounds).apply(xy)

        # Select the spikes falling inside the polygon, then reset the lasso.
        inside = self.lasso.in_polygon(xy)
        selected = bunch['spike_ids'][inside]
        self.lasso.clear()
        return np.unique(selected)
예제 #2
0
 def on_mouse_press(self, e):
     """Raise `spike_click` for the displayed spike closest in time to the mouse.

     The handler only acts when Control is held or when the currently
     pressed key is one of the digits '0'-'9'. It does nothing when no
     waveform is displayed, or when none of the displayed waveforms covers
     the channel under the mouse.

     Parameters
     ----------
     e : mouse event
         Must expose `modifiers` and `pos`.
     """
     key = self._key_pressed
     if 'Control' not in e.modifiers and key not in map(str, range(10)):
         return
     # Get mouse position in NDC.
     mouse_pos = self.panzoom.get_mouse_pos(e.pos)
     box_id = self.stacked.get_closest_box(mouse_pos)
     channel_id = self._permute_channels(box_id, inv=True)
     # Get the information about the displayed spikes.
     if not self._waveform_times:
         return
     # Keep the spikes shown on the clicked channel. Filtering *before*
     # unpacking matters: zip(*[]) yields nothing, so unpacking it into four
     # names would raise ValueError when no spike is on that channel.
     shown = [wt for wt in self._waveform_times if channel_id in wt[3]]
     if not shown:
         return
     # Get the time coordinate of the mouse position.
     mouse_time = Range(NDC, self._data_bounds).apply(mouse_pos)[0][0]
     # Get the closest spike id.
     times, spike_ids, spike_clusters, _ = zip(*shown)
     i = np.argmin(np.abs(np.array(times) - mouse_time))
     # Raise the spike_click event.
     self.events.spike_click(channel_id=channel_id,
                             spike_id=spike_ids[i],
                             cluster_id=spike_clusters[i])
예제 #3
0
    def __init__(self, amplitudes=None, amplitude_name=None, duration=None):
        """Create the amplitude view and its histogram, scatter and text visuals.

        Parameters
        ----------
        amplitudes : dict or object
            Mapping from amplitude name to amplitude value. Any non-dict value
            is wrapped as `{'amplitude': value}`.
        amplitude_name : str, optional
            Name of the amplitude to show first; falls back to the first key
            of `amplitudes` when falsy. Must be a key of `amplitudes`.
        duration : number, optional
            Stored as `self.duration`; falls back to 1 when falsy.
        """
        super(AmplitudeView, self).__init__()
        self.state_attrs += ('amplitude_name',)

        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        # A lone amplitude is stored under the default 'amplitude' key so that
        # the rest of the view can always work with a dictionary.
        amplitudes = amplitudes if isinstance(amplitudes, dict) else {'amplitude': amplitudes}
        assert amplitudes
        self.amplitudes = amplitudes
        self.amplitude_names = list(amplitudes)

        # Amplitude type displayed initially.
        self.amplitude_name = amplitude_name if amplitude_name else self.amplitude_names[0]
        assert self.amplitude_name in amplitudes

        self.cluster_ids = ()
        self.duration = duration if duration else 1

        # Histogram visual, range-mapped and rotated on the GPU.
        histogram = HistogramVisual()
        self.hist_visual = histogram
        histogram_bounds = (-1, -1, 1, -1 + 2 * self.histogram_scale)
        histogram.transforms.add_on_gpu([Range(NDC, histogram_bounds), Rotate('ccw')])
        self.canvas.add_visual(histogram)

        # Scatter plot of the amplitudes.
        scatter = ScatterVisual()
        self.visual = scatter
        self.canvas.add_visual(scatter)

        # Text showing the amplitude name; the pan-zoom origin is excluded.
        label = TextVisual()
        self.text_visual = label
        self.canvas.add_visual(label, exclude_origins=(self.canvas.panzoom,))
예제 #4
0
파일: trace.py 프로젝트: zsong30/phy
    def on_mouse_click(self, e):
        """Select a spike (Control+click) or a channel (Shift+click).

        The clicked channel is the first index at which
        `self.channel_y_ranks` equals the clicked box id. It is always passed
        to listeners as a plain `int`. Nothing is emitted when no channel
        matches the clicked box.

        Control+click emits `select_spike` with the displayed spike closest in
        time to the mouse on the clicked channel. If no displayed spike is on
        that channel, the handler returns without emitting anything.

        Shift+click emits `select_channel` with the clicked channel and the
        mouse button.

        Parameters
        ----------
        e : mouse event
            Must expose `modifiers`, `pos` and `button`.
        """
        with_control = 'Control' in e.modifiers
        with_shift = 'Shift' in e.modifiers
        if not (with_control or with_shift):
            return

        # Resolve the clicked box to a channel once, for both branches.
        box_id, _ = self.canvas.stacked.box_map(e.pos)
        matches = np.nonzero(self.channel_y_ranks == box_id)[0]
        if len(matches) == 0:
            # The click did not land on any displayed channel.
            return
        channel_id = int(matches[0])

        if with_control:
            # Find the spike and cluster closest to the mouse.
            db = self.data_bounds
            # Get the information about the displayed spikes.
            wt = [(t, s, c, ch) for t, s, c, ch in self._waveform_times if channel_id in ch]
            if not wt:
                return
            # Get the time coordinate of the mouse position.
            mouse_pos = self.canvas.panzoom.window_to_ndc(e.pos)
            mouse_time = Range(NDC, db).apply(mouse_pos)[0][0]
            # Get the closest spike id.
            times, spike_ids, spike_clusters, _ = zip(*wt)
            i = np.argmin(np.abs(np.array(times) - mouse_time))
            # Raise the select_spike event.
            emit('select_spike', self, channel_id=channel_id,
                 spike_id=spike_ids[i], cluster_id=spike_clusters[i])

        if with_shift:
            emit('select_channel', self, channel_id=channel_id, button=e.button)
예제 #5
0
    def on_request_split(self, sender=None):
        """Return the unique ids of the spikes enclosed by the lasso.

        All spikes of every selected cluster are loaded, projected on the two
        dimensions of the active subplot, normalized with the per-axis bounds
        and tested against the lasso polygon. The lasso is cleared afterwards.

        Parameters
        ----------
        sender : object, optional
            Not used by this method.

        Returns
        -------
        numpy.ndarray
            Sorted unique spike ids; an empty int64 array when
            `lasso.count` is below 3 or when no cluster is selected.
        """
        lasso = self.canvas.lasso
        if lasso.count < 3 or len(self.cluster_ids) == 0:  # pragma: no cover
            return np.array([], dtype=np.int64)
        assert len(self.channel_ids)

        # The active subplot determines which two dimensions are compared.
        row, col = self.canvas.layout.active_box
        dim_x, dim_y = self.grid_dim[row][col].split(',')

        # Accumulate normalized points and spike ids, one entry per cluster.
        normalized = []
        ids = []
        for cluster_id in self.cluster_ids:
            # Load all spikes of the cluster.
            bunch = self.features(
                cluster_id, channel_ids=self.channel_ids, load_all=True)
            axis_x = self._get_axis_data(
                bunch, dim_x, cluster_id=cluster_id, load_all=True)
            axis_y = self._get_axis_data(
                bunch, dim_y, cluster_id=cluster_id, load_all=True)
            raw = np.c_[axis_x.data, axis_y.data]

            # Normalize the points with the bounds of each axis.
            xmin, xmax = self._get_axis_bounds(dim_x, axis_x)
            ymin, ymax = self._get_axis_bounds(dim_y, axis_y)
            normalized.append(Range((xmin, ymin, xmax, ymax)).apply(raw))
            ids.append(bunch.spike_ids)

        all_points = np.vstack(normalized)
        all_ids = np.concatenate(ids)

        # Keep the spikes inside the polygon, then reset the lasso.
        inside = lasso.in_polygon(all_points)
        lasso.clear()
        return np.unique(all_ids[inside])
예제 #6
0
파일: amplitude.py 프로젝트: LBHB/phy
 def on_mouse_click(self, e):
     """Emit `select_time` with the time coordinate of a Shift+click.

     Clicks made without the Shift modifier are ignored.
     """
     if 'Shift' not in e.modifiers:
         return
     # Window coordinates -> NDC -> data coordinates; the time is the first
     # component of the first transformed point.
     ndc_pos = self.canvas.panzoom.window_to_ndc(e.pos)
     to_data = Range(NDC, self.data_bounds)
     clicked_time = to_data.apply(ndc_pos)[0][0]
     emit('select_time', self, clicked_time)
예제 #7
0
파일: feature.py 프로젝트: kwikteam/phy
    def on_request_split(self):
        """Return the unique ids of the spikes enclosed by the lasso.

        Every selected cluster is loaded in full and projected on the two
        dimensions of the lassoed subplot. The points are normalized with the
        per-axis bounds before the polygon test. The lasso is cleared before
        returning.

        Returns
        -------
        numpy.ndarray
            Sorted unique spike ids; an empty int64 array when
            `self.lasso.count` is below 3 or when no cluster is selected.
        """
        if self.lasso.count < 3 or len(self.cluster_ids) == 0:  # pragma: no cover
            return np.array([], dtype=np.int64)
        assert len(self.channel_ids)

        # Dimensions shown in the subplot the lasso was drawn in.
        row, col = self.lasso.box
        dim_x, dim_y = self.grid_dim[row][col].split(',')

        # One (normalized points, spike ids) pair per selected cluster.
        per_cluster = []
        for cluster_id in self.cluster_ids:
            # Load all spikes.
            bunch = self.features(
                cluster_id, channel_ids=self.channel_ids, load_all=True)
            data_x = self._get_axis_data(
                bunch, dim_x, cluster_id=cluster_id, load_all=True)
            data_y = self._get_axis_data(
                bunch, dim_y, cluster_id=cluster_id, load_all=True)
            xy = np.c_[data_x.data, data_y.data]

            # Normalize the points.
            xmin, xmax = self._get_axis_bounds(dim_x, data_x)
            ymin, ymax = self._get_axis_bounds(dim_y, data_y)
            scaler = Range((xmin, ymin, xmax, ymax))
            per_cluster.append((scaler.apply(xy), bunch.spike_ids))

        points = np.vstack([xy for xy, _ in per_cluster])
        candidates = np.concatenate([ids for _, ids in per_cluster])

        # Find lassoed spikes.
        mask = self.lasso.in_polygon(points)
        self.lasso.clear()
        return np.unique(candidates[mask])
예제 #8
0
def _iter_channel(positions):
    """Yield one `(x, y)` pair per channel, mapped into a 100-unit square.

    The boxes returned by `_get_boxes(positions, keep_aspect_ratio=False)`
    are reduced to their centers (columns 0 and 2 averaged for x, columns 1
    and 3 averaged for y, with y negated). The centers are then mapped from
    NDC to the square `[5, 95] x [5, 95]`, i.e. a square of side 100 with a
    margin of 5 on each side.

    Parameters
    ----------
    positions : array-like
        Channel positions, passed through to `_get_boxes`.

    Yields
    ------
    tuple
        The transformed `(x, y)` coordinates of each channel.
    """
    size = 100
    margin = 5
    boxes = _get_boxes(positions, keep_aspect_ratio=False)
    # Center of every box; the y coordinate is negated
    # (NOTE(review): presumably to flip the vertical axis -- confirm).
    x = boxes[:, [0, 2]].mean(axis=1)
    y = -boxes[:, [1, 3]].mean(axis=1)
    centers = np.c_[x, y]
    # Target rectangle (xmin, ymin, xmax, ymax) inside the size x size square.
    target = [margin, margin, size - margin, size - margin]
    for cx, cy in Range(NDC, target).apply(centers):
        yield cx, cy
예제 #9
0
    def __init__(self, amplitudes=None, amplitudes_type=None, duration=None):
        """Create the amplitude view with its histogram, interval bar and scatter visuals.

        Parameters
        ----------
        amplitudes : dict or object
            Mapping from amplitude name to amplitude value. Any non-dict value
            is wrapped as `{'amplitude': value}`.
        amplitudes_type : str, optional
            Passed to `RotatingProperty.set()` to choose the current amplitude
            type. `self.amplitudes_type` must afterwards be a key of
            `amplitudes`.
        duration : float, optional
            Stored as `self.duration`; falls back to 1. when falsy.
        """
        super(AmplitudeView, self).__init__()
        # Register the attribute name in `state_attrs`.
        self.state_attrs += ('amplitudes_type', )

        self.canvas.enable_axes()
        self.canvas.enable_lasso()

        # Ensure amplitudes is a dictionary, even if there is a single amplitude.
        if not isinstance(amplitudes, dict):
            amplitudes = {'amplitude': amplitudes}
        assert amplitudes
        self.amplitudes = amplitudes

        # Rotating property amplitudes types.
        self.amplitudes_types = RotatingProperty()
        for name, value in self.amplitudes.items():
            self.amplitudes_types.add(name, value)
        # Current amplitudes type.
        self.amplitudes_types.set(amplitudes_type)
        # NOTE(review): `self.amplitudes_type` is not assigned in this method;
        # presumably a property backed by `self.amplitudes_types` -- confirm.
        assert self.amplitudes_type in self.amplitudes

        self.cluster_ids = ()
        self.duration = duration or 1.

        # Histogram visual.
        self.hist_visual = HistogramVisual()
        # Transform chain applied to the histogram: range mapping, clockwise
        # rotation, vertical flip, then a horizontal translation of 2.05.
        self.hist_visual.transforms.add([
            Range(NDC, (-1, -1, 1, -1 + 2 * self.histogram_scale)),
            Rotate('cw'),
            Scale((1, -1)),
            Translate((2.05, 0)),
        ])
        self.canvas.add_visual(self.hist_visual)
        # Set both the current and the default zoom and pan of the canvas.
        self.canvas.panzoom.zoom = self.canvas.panzoom._default_zoom = (.75, 1)
        self.canvas.panzoom.pan = self.canvas.panzoom._default_pan = (-.25, 0)

        # Yellow vertical bar showing the selected time interval.
        self.patch_visual = PatchVisual(primitive_type='triangle_fan')
        # GLSL declarations inserted in the vertex shader header.
        self.patch_visual.inserter.insert_vert(
            '''
            const float MIN_INTERVAL_SIZE = 0.01;
            uniform float u_interval_size;
        ''', 'header')
        # GLSL code inserted after the transforms: it overrides the vertex y
        # position and offsets x by the bar width.
        self.patch_visual.inserter.insert_vert(
            '''
            gl_Position.y = pos_orig.y;

            // The following is used to ensure that (1) the bar width increases with the zoom level
            // but also (2) there is a minimum absolute width so that the bar remains visible
            // at low zoom levels.
            float w = max(MIN_INTERVAL_SIZE, u_interval_size * u_zoom.x);
            // HACK: the z coordinate is used to store 0 or 1, depending on whether the current
            // vertex is on the left or right edge of the bar.
            gl_Position.x += w * (-1 + 2 * int(a_position.z == 0));

        ''', 'after_transforms')
        self.canvas.add_visual(self.patch_visual)

        # Scatter plot.
        self.visual = ScatterVisual()
        self.canvas.add_visual(self.visual)
        # Constrain pan-zoom to the rectangle (-2, -2, +2, +2).
        self.canvas.panzoom.set_constrain_bounds((-2, -2, +2, +2))