Example #1
0
    def __init__(self, **kwargs):
        """Set hard-coded display parameters and start the Qt GUI process.

        Keyword arguments are accepted but not used: every parameter below
        is fixed in this method.
        """

        # Fixed acquisition / display parameters.
        self.nb_samples = 1024
        self.dtype = 'float32'
        self.sampling_rate = 20000
        self.probe_path = 'probe.prb'
        self.probe = load_probe(self.probe_path)
        self.nb_channels = self.probe.nb_channels
        self.export_peaks = True

        # One pipe per kind of message exchanged with the GUI process.
        (self._params_pipe,
         self._number_pipe,
         self._data_pipe,
         self._mads_pipe,
         self._peaks_pipe) = (Pipe() for _ in range(5))

        self._qt_process = GUIProcess(
            self._params_pipe, self._number_pipe, self._data_pipe,
            self._mads_pipe, self._peaks_pipe,
            probe_path=self.probe_path,
        )

        self._is_mad_reception_blocking = False
        self._is_peak_reception_blocking = False
        self._qt_process.start()
        self.number = 0

        # Initial parameters go through the second connection of the params
        # pipe (the GUI side is presumably listening on the first one).
        initial_params = {
            'nb_samples': self.nb_samples,
            'sampling_rate': self.sampling_rate,
        }
        self._params_pipe[1].send(initial_params)
    def __init__(self, **kwargs):
        """Initialize the block and declare its 'data' output.

        Parameters:
            data_path: string
            dtype: string
            nb_channels: integer
            nb_samples: integer
            sampling_rate: float
            is_realistic: boolean
            speed_factor: float
            nb_replay: integer
            offset: value not interpreted in this method
            probe_path: none | string
                When given, the probe layout is loaded into ``probe``;
                otherwise ``probe`` is None.
            zero_channels: none | scalar | iterable
                Stored as a numpy array (a scalar becomes a 1-element array).
            gain: number
                Stored as the quantum size.

        See also:
            circusort.block.Block
        """

        Block.__init__(self, **kwargs)
        self.add_output('data', structure='dict')

        # Self-assignments whose only purpose is to silence PyCharm warnings.
        self.data_path = self.data_path
        self.dtype = self.dtype
        self.nb_channels = self.nb_channels
        self.nb_samples = self.nb_samples
        self.sampling_rate = self.sampling_rate
        self.is_realistic = self.is_realistic
        self.speed_factor = self.speed_factor
        self.nb_replay = self.nb_replay
        self.offset = self.offset
        self.probe_path = self.probe_path
        self.zero_channels = self.zero_channels

        # Output format and quantization constants.
        self._output_dtype = 'float32'
        self._quantum_size = self.gain
        self._quantum_offset = float(np.iinfo('int16').min)
        # nb_samples / sampling_rate, i.e. the length of one buffer.
        self._buffer_rate = float(self.nb_samples) / self.sampling_rate

        self._absolute_start_time = self._absolute_end_time = None

        if self.probe_path is None:
            self.probe = None
        else:
            self.probe = load_probe(self.probe_path, logger=self.log)
            self.log.info("{} reads the probe layout".format(self.name))

        if self.zero_channels is not None:
            zero_channels = self.zero_channels
            if not np.iterable(zero_channels):
                # Wrap a single channel so the result is always array-like.
                zero_channels = [zero_channels]
            self.zero_channels = np.array(zero_channels)
    def __init__(self, probe_path=None, params=None):
        """Set up the "Vispy canvas2" rate canvas.

        Loads the probe, builds the bounding-box program and the rates
        program, then configures the GL viewport and blending once.

        Parameters:
            probe_path: string
                Path of the probe file passed to ``load_probe``.
            params: dict | None
                Accepted but not used here.
        """
        app.Canvas.__init__(self, title="Vispy canvas2")

        self.probe = load_probe(probe_path)
        self.nb_channels = self.probe.nb_channels
        self.init_time = 0

        # Outline of the drawing area: closed square (first corner repeated).
        box_corner_positions = np.array(
            [[+0.9, +0.9], [-0.9, +0.9], [-0.9, -0.9], [+0.9, -0.9],
             [+0.9, +0.9]],
            dtype=np.float32)

        self._box_program = gloo.Program(vert=BOX_VERT_SHADER,
                                         frag=BOX_FRAG_SHADER)
        self._box_program['a_position'] = box_corner_positions

        # Rates state (no cells known yet). Each attribute is set exactly once;
        # the original re-initialised nb_cells/time_max a second time.
        self.nb_cells, self.nb_cells_selected, self.time_max = 0, 0, 0
        self.rate_vector = np.zeros(100).astype(np.float32)
        self.color_rates, self.index_bar, self.index_time = 0, 0, 0
        self.rate_mat = np.zeros((self.nb_cells, 30), dtype=np.float32)
        self.scale_x = 20
        self.index_cell, self.index_cell_selec = 0, 0
        self.list_unselected_cells = list(range(self.nb_cells))
        self.list_selected_cells = []
        self.initialized = True

        # Only the rate values are bound here; the other shader attributes
        # and uniforms are not set in this method.
        self.rates_program = gloo.Program(vert=RATES_VERT_SHADER,
                                          frag=RATES_FRAG_SHADER)
        self.rates_program['a_rate_value'] = self.rate_vector

        # Final details (previously issued twice in a row; once is enough).
        gloo.set_viewport(0, 0, *self.physical_size)
        gloo.set_state(clear_color='black',
                       blend=True,
                       blend_func=('src_alpha', 'one_minus_src_alpha'))
Example #4
0
    def __init__(self, **kwargs):
        """Initialize the peak display block.

        The ``probe`` parameter (a probe file path) is replaced by the loaded
        probe layout and its electrode positions are copied to ``positions``.
        If no probe is given, an error is logged and both ``probe`` and
        ``positions`` are left as None instead of crashing on ``None``.
        """

        Block.__init__(self, **kwargs)

        if self.probe is None:
            self.log.error(
                '{n}: the probe file must be specified!'.format(n=self.name))
            # Previously this fell through to `self.probe.positions` and
            # raised AttributeError right after logging the error.
            self.positions = None
        else:
            self.probe = load_probe(self.probe, radius=None, logger=self.log)
            self.log.info('{n} reads the probe layout'.format(n=self.name))
            self.positions = self.probe.positions

        self.add_input('peaks')
        # NOTE(review): presumably requests plotting from the main thread, as
        # in the other peak display block — confirm.
        self.mpl_display = True
    def __init__(self, **kwargs):
        """Initialize the peak display block.

        The ``probe`` parameter (a probe file path) is replaced by the loaded
        probe layout.

        Raises:
            NotImplementedError: if no probe is given (running without a probe
                layout is not supported).
        """

        Block.__init__(self, **kwargs)

        if self.probe is None:
            # Keep the exception type callers may rely on, but say why.
            message = "{}: the probe file must be specified!".format(self.name)
            raise NotImplementedError(message)
        self.probe = load_probe(self.probe, radius=None, logger=self.log)

        self.add_input('peaks')

        # Flag to call plot from the main thread (c.f. circusort.cli.process).
        self.mpl_display = True
    def __init__(self, **kwargs):
        """Initialize the template updater block.

        Arguments:
            probe_path: none | string
                Probe file; if missing, an error is logged and ``probe`` is None.
            radius: none | float (optional)
            cc_merge: float (optional)
            cc_mixture: none | float (optional)
            templates_path: none | string (optional)
            overlaps_path: none | string (optional)
            precomputed_template_paths: none | list (optional)
            sampling_rate: float (optional)
            nb_samples: integer (optional)
            skip_overlaps: (optional) read here but not interpreted

        Declares a 'templates' input and an 'updater' output (both dicts).
        """

        Block.__init__(self, **kwargs)

        # Self-assignments whose only purpose is to silence PyCharm warnings.
        self.probe_path = self.probe_path
        self.radius = self.radius
        self.cc_merge = self.cc_merge
        self.cc_mixture = self.cc_mixture
        self.templates_path = self.templates_path
        self.overlaps_path = self.overlaps_path
        self.precomputed_template_paths = self.precomputed_template_paths
        self.sampling_rate = self.sampling_rate
        self.nb_samples = self.nb_samples
        self.skip_overlaps = self.skip_overlaps

        if self.probe_path is not None:
            self.probe = load_probe(self.probe_path, radius=self.radius, logger=self.log)
            self.log.info("{} reads the probe layout".format(self.name))
        else:
            self.probe = None
            self.log.error("{}: the probe file must be specified!".format(self.name))

        # Private state, all empty (None) at construction time.
        self._template_store = self._template_dictionary = None
        self._overlap_store = self._two_components = None

        self.add_input('templates', structure='dict')
        self.add_output('updater', structure='dict')
Example #7
0
    def __init__(self, probe_path=None, params=None):
        """Set up the "Rate view" canvas: probe, bounding box and rates program.

        Parameters:
            probe_path: string
                Path of the probe file passed to ``load_probe``.
            params: dict | None
                Accepted but not used here.
        """
        app.Canvas.__init__(self, title="Rate view")

        self.probe = load_probe(probe_path)
        self.nb_channels = self.probe.nb_channels
        self.init_time = 0

        # Outline of the drawing area: closed square (first corner repeated).
        corners = np.array([[+0.9, +0.9], [-0.9, +0.9], [-0.9, -0.9],
                            [+0.9, -0.9], [+0.9, +0.9]], dtype=np.float32)
        self._box_program = gloo.Program(vert=BOX_VERT_SHADER, frag=BOX_FRAG_SHADER)
        self._box_program['a_position'] = corners

        # Rates state: no cells yet, so the rate matrix starts with 0 rows.
        self.x_value = 0
        self.nb_cells = 0
        self.rate_mat = np.zeros((self.nb_cells, 30), dtype=np.float32)
        self.rate_vector = np.zeros(100, dtype=np.float32)
        self.index_time = 0
        self.index_cell = 0
        self.color_rates = np.array([[1, 1, 1]])

        # Display / selection settings.
        self.time_window = 50
        self.time_window_from_start = True
        self.list_selected_cells = []
        self.selected_cells_vector = 0
        self.rate_mat_cum = 0
        self.rate_vector_cum = 0
        self.u_scale = np.array([[1.0, 1.0]], dtype=np.float32)
        self.initialized = False
        self.cum_plots = False

        self.rates_program = gloo.Program(vert=RATES_VERT_SHADER, frag=RATES_FRAG_SHADER)
        self.rates_program['a_rate_value'] = self.rate_vector

        # GL viewport and alpha blending.
        gloo.set_viewport(0, 0, *self.physical_size)
        gloo.set_state(clear_color='black', blend=True,
                       blend_func=('src_alpha', 'one_minus_src_alpha'))
    def __init__(self,
                 data_path,
                 probe_path,
                 sampling_rate=20e+3,
                 dtype='int16',
                 gain=0.1042):
        """Load the probe layout and the data file it describes.

        Parameters:
            data_path: string
                Path of the data file passed to ``load_datafile``.
            probe_path: string
                Path of the probe file passed to ``load_probe``.
            sampling_rate: float (default 20e+3)
            dtype: string (default 'int16')
                Sample type of the data file.
            gain: float (default 0.1042)
                Forwarded to ``load_datafile``.
        """

        self._data_path = data_path
        self._probe_path = probe_path
        self._sampling_rate = sampling_rate
        self._dtype = dtype
        self._gain = gain

        self._probe = load_probe(self._probe_path)
        # The channel count comes from the probe. It used to be read below
        # without ever being assigned, which raised AttributeError.
        self._nb_channels = self._probe.nb_channels

        self._data = load_datafile(self._data_path,
                                   self._sampling_rate,
                                   self._nb_channels,
                                   self._dtype,
                                   gain=self._gain)
Example #9
0
    def __init__(self, probe_path=None, params=None):
        """Set up the "ISI view" canvas: probe, bounding box and ISI program.

        Parameters:
            probe_path: string
                Path of the probe file passed to ``load_probe``.
            params: dict | None
                Accepted but not used here.
        """
        app.Canvas.__init__(self, title="ISI view")

        self.probe = load_probe(probe_path)
        self.nb_channels = self.probe.nb_channels
        self.init_time = 0

        # Outline of the drawing area: closed square (first corner repeated).
        corners = np.array([[+0.9, +0.9], [-0.9, +0.9], [-0.9, -0.9],
                            [+0.9, -0.9], [+0.9, +0.9]], dtype=np.float32)
        self._box_program = gloo.Program(vert=BOX_VERT_SHADER, frag=BOX_FRAG_SHADER)
        self._box_program['a_position'] = corners

        # ISI state, all placeholders until data arrives.
        self.initialized = False
        self.nb_points = 0
        self.nb_cells = 0
        self.list_selected_isi = []
        self.selected_isi_vector = 0
        self.list_isi = 0
        self.isi_vector = 5
        self.index_x = 1
        self.index_cell = 0
        self.color_isi = 0
        self.isi_mat = 0
        self.isi_smooth = 0
        self.u_scale = [1.0, 1.0]

        self.isi_program = program = gloo.Program(vert=ISI_VERT_SHADER,
                                                  frag=ISI_FRAG_SHADER)
        program['a_isi_value'] = self.isi_vector
        program['a_index_x'] = self.index_x
        program['u_scale'] = self.u_scale

        # GL viewport and alpha blending.
        gloo.set_viewport(0, 0, *self.physical_size)
        gloo.set_state(clear_color='black', blend=True,
                       blend_func=('src_alpha', 'one_minus_src_alpha'))
    def __init__(self, probe_path=None, params=None):
        """Build the GLSL programs drawing the signals, the MADs and the boxes.

        Parameters:
            probe_path: string
                Path of the probe file passed to ``load_probe``.
            params: dict
                Keys read here: ``nb_samples``, ``sampling_rate``,
                ``time`` (``min``, ``max``, ``init``; in ms) and
                ``voltage`` (``init``).
        """

        app.Canvas.__init__(self, title="Vispy canvas", keys="interactive")

        probe = load_probe(probe_path)
        nb_samples_per_buffer = params['nb_samples']
        sampling_rate = params['sampling_rate']

        # Whole buffers needed to cover the maximal time window (ms -> s).
        nb_buffers_per_signal = int(np.ceil(
            (params['time']['max'] * 1e-3) * sampling_rate
            / float(nb_samples_per_buffer)))
        # Time span actually covered by those buffers, back in ms.
        self._time_max = (float(nb_buffers_per_signal * nb_samples_per_buffer)
                          / sampling_rate) * 1e+3
        self._time_min = params['time']['min']
        self._nb_samples_per_buffer = nb_samples_per_buffer

        nb_signals = probe.nb_channels
        nb_samples_per_signal = nb_buffers_per_signal * nb_samples_per_buffer
        x_positions = probe.x.astype(np.float32)
        y_positions = probe.y.astype(np.float32)
        t_scale = self._time_max / params['time']['init']
        v_scale = params['voltage']['init']

        def set_layout_uniforms(program):
            # Probe extent and electrode spacing, shared by all three programs.
            program['u_x_min'] = probe.x_limits[0]
            program['u_x_max'] = probe.x_limits[1]
            program['u_y_min'] = probe.y_limits[0]
            program['u_y_max'] = probe.y_limits[1]
            program['u_d_scale'] = probe.minimum_interelectrode_distance

        # --- Signals: one vertex per (signal, sample). ---
        self._signal_values = np.zeros((nb_signals, nb_samples_per_signal),
                                       dtype=np.float32)
        # TODO: make it more efficient by using a GLSL-based color map and the index.
        signal_colors = np.repeat(0.75 * np.ones((nb_signals, 3), dtype=np.float32),
                                  repeats=nb_samples_per_signal, axis=0)
        signal_indices = np.repeat(np.arange(0, nb_signals, dtype=np.float32),
                                   repeats=nb_samples_per_signal)
        signal_positions = np.c_[
            np.repeat(x_positions, repeats=nb_samples_per_signal),
            np.repeat(y_positions, repeats=nb_samples_per_signal)]
        signal_sample_indices = np.tile(
            np.arange(0, nb_samples_per_signal, dtype=np.float32),
            reps=nb_signals)

        signal_program = gloo.Program(vert=SIGNAL_VERT_SHADER,
                                      frag=SIGNAL_FRAG_SHADER)
        self._signal_program = signal_program
        signal_program['a_signal_index'] = signal_indices
        signal_program['a_signal_position'] = signal_positions
        signal_program['a_signal_value'] = self._signal_values.reshape(-1, 1)
        signal_program['a_signal_color'] = signal_colors
        signal_program['a_sample_index'] = signal_sample_indices
        signal_program['u_nb_samples_per_signal'] = nb_samples_per_signal
        set_layout_uniforms(signal_program)
        signal_program['u_t_scale'] = t_scale
        signal_program['u_v_scale'] = v_scale

        # --- MADs: two vertices per buffer boundary, per signal. ---
        nb_mads_points = 2 * (nb_buffers_per_signal + 1)
        mads_indices = np.repeat(np.arange(0, nb_signals, dtype=np.float32),
                                 repeats=nb_mads_points)
        mads_positions = np.c_[
            np.repeat(x_positions, repeats=nb_mads_points),
            np.repeat(y_positions, repeats=nb_mads_points)]
        self._mads_values = np.zeros((nb_signals, nb_mads_points),
                                     dtype=np.float32)
        mads_colors = np.repeat(
            np.tile(np.array([0.75, 0.0, 0.0], dtype=np.float32),
                    reps=(nb_signals, 1)),
            repeats=nb_mads_points, axis=0)
        # Sample index of every buffer boundary, each listed twice.
        boundaries = np.repeat(
            np.arange(0, nb_buffers_per_signal + 1, dtype=np.float32),
            repeats=2)
        mads_sample_indices = np.tile(self._nb_samples_per_buffer * boundaries,
                                      reps=nb_signals)

        mads_program = gloo.Program(vert=MADS_VERT_SHADER,
                                    frag=MADS_FRAG_SHADER)
        self._mads_program = mads_program
        mads_program['a_mads_index'] = mads_indices
        mads_program['a_mads_position'] = mads_positions
        mads_program['a_mads_value'] = self._mads_values.reshape(-1, 1)
        mads_program['a_mads_color'] = mads_colors
        mads_program['a_sample_index'] = mads_sample_indices
        mads_program['u_nb_samples_per_signal'] = nb_samples_per_signal
        set_layout_uniforms(mads_program)
        mads_program['u_t_scale'] = t_scale
        mads_program['u_v_scale'] = v_scale

        # --- Peaks: TODO, no GLSL program yet (PEAKS_VERT_SHADER / PEAKS_FRAG_SHADER). ---

        # --- Boxes: a closed square (5 corners) around each signal. ---
        box_indices = np.repeat(np.arange(0, nb_signals, dtype=np.float32),
                                repeats=5)
        box_positions = np.c_[np.repeat(x_positions, repeats=5),
                              np.repeat(y_positions, repeats=5)]
        corner_xs = np.array([+1.0, -1.0, -1.0, +1.0, +1.0], dtype=np.float32)
        corner_ys = np.array([+1.0, +1.0, -1.0, -1.0, +1.0], dtype=np.float32)
        corner_positions = np.c_[np.tile(corner_xs, reps=nb_signals),
                                 np.tile(corner_ys, reps=nb_signals)]

        box_program = gloo.Program(vert=BOX_VERT_SHADER, frag=BOX_FRAG_SHADER)
        self._box_program = box_program
        box_program['a_box_index'] = box_indices
        box_program['a_box_position'] = box_positions
        box_program['a_corner_position'] = corner_positions
        set_layout_uniforms(box_program)

        # GL viewport and alpha blending.
        gloo.set_viewport(0, 0, *self.physical_size)
        gloo.set_state(clear_color='black', blend=True,
                       blend_func=('src_alpha', 'one_minus_src_alpha'))
Example #11
0
    def __init__(self, **kwargs):
        """Initialize the block that builds templates from data, PCs and peaks.

        Declares 'data', 'pcs' and 'peaks' inputs and a 'templates' output
        (all dicts). Loads the probe from ``probe_path``. If it is missing, an
        error is logged and ``probe``/``inodes`` are None; previously this
        path crashed right after logging, on ``self.probe.total_nb_channels``.
        Existing ``debug_plots``/``debug_data`` directories are deleted and
        recreated empty.
        """

        Block.__init__(self, **kwargs)

        # Self-assignments whose only purpose is to silence PyCharm warnings.
        self.alignment = self.alignment
        self.sampling_rate = self.sampling_rate
        self.spike_width = self.spike_width
        self.spike_jitter = self.spike_jitter
        self.spike_sigma = self.spike_sigma
        self.nb_waveforms = self.nb_waveforms
        self.nb_waveforms_tracking = self.nb_waveforms_tracking
        self.channels = self.channels
        self.probe_path = self.probe_path
        self.radius = self.radius
        self.m_ratio = self.m_ratio
        self.noise_thr = self.noise_thr
        self.dispersion = self.dispersion
        self.two_components = self.two_components
        # Rescaled by the sampling rate.
        self.decay_factor = self.decay_factor / float(self.sampling_rate)
        self.mu = self.mu
        self.sub_dim = self.sub_dim
        self.epsilon = self.epsilon
        self.theta = self.theta
        # 0.2e-3 * sampling_rate samples (0.2 ms if sampling_rate is in Hz),
        # forced to be odd.
        self.smoothing_factor = int(0.2e-3 * self.sampling_rate)
        if np.mod(self.smoothing_factor, 2) == 0:
            self.smoothing_factor += 1
        self.tracking = self.tracking
        self.safety_time = self.safety_time
        self.compression = self.compression
        self.local_merges = self.local_merges
        self.debug_plots = self.debug_plots
        self.debug_ground_truth_templates = self.debug_ground_truth_templates
        self.debug_file_format = self.debug_file_format
        self.debug_data = self.debug_data
        self.smart_select = self.smart_select
        self.hanning_filtering = self.hanning_filtering

        if self.probe_path is None:
            self.probe = None
            # Log error message.
            string = "{}: the probe file must be specified!"
            message = string.format(self.name)
            self.log.error(message)
        else:
            self.probe = load_probe(self.probe_path,
                                    radius=self.radius,
                                    logger=self.log)
            # Log info message.
            string = "{} reads the probe layout"
            message = string.format(self.name)
            self.log.info(message)

        # Debug directories are recreated from scratch (old contents deleted).
        for directory in [self.debug_plots, self.debug_data]:
            if directory is not None:
                if os.path.exists(directory):
                    shutil.rmtree(directory)
                os.makedirs(directory)
                # Log info message.
                string = "{} creates directory {}"
                message = string.format(self.name, directory)
                self.log.info(message)

        self.add_input('data', structure='dict')
        self.add_input('pcs', structure='dict')
        self.add_input('peaks', structure='dict')
        self.add_output('templates', structure='dict')

        self.thresholds = None

        self._dtype = None
        self._nb_channels = None
        self._nb_samples = None

        if self.probe is None:
            # No probe layout, hence no channel mapping to build.
            self.inodes = None
        else:
            # inodes[probe.nodes] = argsort(probe.nodes); other channels stay 0.
            # When probe.nodes is sorted, this maps each node to its index in
            # probe.nodes.
            self.inodes = np.zeros(self.probe.total_nb_channels, dtype=np.int32)
            self.inodes[self.probe.nodes] = np.argsort(self.probe.nodes)
    def __init__(self, params_pipe, number_pipe, data_pipe, mads_pipe, peaks_pipe,
                 probe_path=None, screen_resolution=None):
        """Build the trace-display main window and start the reception thread.

        Parameters:
            params_pipe: pair of connections
                A dict with at least ``nb_samples`` and ``sampling_rate`` is
                received from ``params_pipe[0]``.
            number_pipe, data_pipe, mads_pipe, peaks_pipe:
                Handed to the reception ``Thread``.
            probe_path: string
                Probe file, loaded here and shown in the "Info" dock.
            screen_resolution: none | object with ``width()``/``height()``
                When given, the window is resized to these dimensions.
        """

        QMainWindow.__init__(self)

        # Receive parameters.
        params = params_pipe[0].recv()
        self.probe = load_probe(probe_path)
        self._nb_samples = params['nb_samples']
        self._sampling_rate = params['sampling_rate']
        self._display_list = list(range(self.probe.nb_channels))

        self._params = {
            'nb_samples': self._nb_samples,
            'sampling_rate': self._sampling_rate,
            'time': {
                'min': 10.0,  # ms
                'max': 1000.0,  # ms
                'init': 100.0,  # ms
            },
            'voltage': {
                'min': 10.0,  # µV
                'max': 10e+3,  # µV
                'init': 20.0,  # µV
            },
            'mads': {
                'min': 0.0,  # µV
                'max': 100,  # µV
                'init': 3,  # µV
            },
            'channels': self._display_list
        }

        self._canvas = TraceCanvas(probe_path=probe_path, params=self._params)

        central_widget = self._canvas.native

        # Create controls widgets.
        label_time = QLabel()
        label_time.setText(u"time")
        label_time_unit = QLabel()
        label_time_unit.setText(u"ms")

        self._dsp_time = QDoubleSpinBox()
        self._dsp_time.setMinimum(self._params['time']['min'])
        self._dsp_time.setMaximum(self._params['time']['max'])
        self._dsp_time.setValue(self._params['time']['init'])
        self._dsp_time.valueChanged.connect(self._on_time_changed)

        label_display_mads = QLabel()
        label_display_mads.setText(u"Display Mads")
        self._display_mads = QCheckBox()
        self._display_mads.stateChanged.connect(self._on_mads_display)

        # NOTE(review): these peak widgets are created and connected but never
        # added to a layout, so they are not shown (grid row 2 is also unused).
        # Confirm whether they were meant to be placed there.
        label_display_peaks = QLabel()
        label_display_peaks.setText(u"Display Peaks")
        self._display_peaks = QCheckBox()
        self._display_peaks.stateChanged.connect(self._on_peaks_display)

        label_mads = QLabel()
        label_mads.setText(u"Mads")
        label_mads_unit = QLabel()
        label_mads_unit.setText(u"unit")
        self._dsp_mads = QDoubleSpinBox()
        self._dsp_mads.setMinimum(self._params['mads']['min'])
        self._dsp_mads.setMaximum(self._params['mads']['max'])
        self._dsp_mads.setValue(self._params['mads']['init'])
        self._dsp_mads.valueChanged.connect(self._on_mads_changed)

        label_voltage = QLabel()
        label_voltage.setText(u"voltage")
        label_voltage_unit = QLabel()
        label_voltage_unit.setText(u"µV")
        self._dsp_voltage = QDoubleSpinBox()
        self._dsp_voltage.setMinimum(self._params['voltage']['min'])
        self._dsp_voltage.setMaximum(self._params['voltage']['max'])
        self._dsp_voltage.setValue(self._params['voltage']['init'])
        self._dsp_voltage.valueChanged.connect(self._on_voltage_changed)

        # Color spikes
        self._color_spikes = QCheckBox()
        self._color_spikes.setText('See Spikes color')
        self._color_spikes.setCheckState(Qt.Checked)
        self._color_spikes.stateChanged.connect(self.display_spikes_color)

        spacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)

        # Create controls grid.
        grid = QGridLayout()
        # # Add time row.
        grid.addWidget(label_time, 0, 0)
        grid.addWidget(self._dsp_time, 0, 1)
        grid.addWidget(label_time_unit, 0, 2)
        # # Add voltage row.
        grid.addWidget(label_voltage, 1, 0)
        grid.addWidget(self._dsp_voltage, 1, 1)
        grid.addWidget(label_voltage_unit, 1, 2)
        # # Add Mads widgets
        grid.addWidget(label_display_mads, 3, 0)
        grid.addWidget(self._display_mads, 3, 1)

        grid.addWidget(label_mads, 4, 0)
        grid.addWidget(self._dsp_mads, 4, 1)
        grid.addWidget(label_mads_unit, 4, 2)

        grid.addWidget(self._color_spikes, 5, 0)

        # # Add spacer.
        grid.addItem(spacer)

        # # Create controls group.
        controls_group = QGroupBox()
        controls_group.setLayout(grid)

        # Channel list: one entry per probe channel, all selected initially.
        self._selection_channels = QListWidget()
        self._selection_channels.setSelectionMode(
            QAbstractItemView.ExtendedSelection
        )

        for i in range(self.probe.nb_channels):
            item = QListWidgetItem("Channel %i" % i)
            self._selection_channels.addItem(item)
            self._selection_channels.item(i).setSelected(True)

        nb_channel = self.probe.nb_channels
        self._selection_channels.itemSelectionChanged.connect(
            lambda: self.selected_channels(nb_channel))

        # A layout takes ownership of the items added to it, so the channels
        # grid gets its own spacer. The controls grid's spacer was previously
        # added to both layouts.
        channels_spacer = QSpacerItem(20, 40, QSizePolicy.Minimum,
                                      QSizePolicy.Expanding)

        # Create channels grid.
        channels_grid = QGridLayout()
        channels_grid.addWidget(self._selection_channels, 0, 1)
        channels_grid.addItem(channels_spacer)

        # Create channels group.
        channels_group = QGroupBox()
        channels_group.setLayout(channels_grid)

        # # Create channels dock.
        channels_dock = QDockWidget()
        channels_dock.setWidget(channels_group)
        channels_dock.setWindowTitle("Channels selection")

        # # Create controls dock.
        control_dock = QDockWidget()
        control_dock.setWidget(controls_group)
        control_dock.setWindowTitle("Controls")

        # Create info widgets.
        label_time = QLabel()
        label_time.setText(u"time")
        self._label_time_value = QLineEdit()
        self._label_time_value.setText(u"0")
        self._label_time_value.setReadOnly(True)
        self._label_time_value.setAlignment(Qt.AlignRight)
        label_time_unit = QLabel()
        label_time_unit.setText(u"s")
        info_buffer_label = QLabel()
        info_buffer_label.setText(u"buffer")
        self._info_buffer_value_label = QLineEdit()
        self._info_buffer_value_label.setText(u"0")
        self._info_buffer_value_label.setReadOnly(True)
        self._info_buffer_value_label.setAlignment(Qt.AlignRight)
        info_buffer_unit_label = QLabel()
        info_buffer_unit_label.setText(u"")
        info_probe_label = QLabel()
        info_probe_label.setText(u"probe")
        info_probe_value_label = QLineEdit()
        info_probe_value_label.setText(u"{}".format(probe_path))
        info_probe_value_label.setReadOnly(True)
        # TODO place the following info in another grid?
        info_probe_unit_label = QLabel()
        info_probe_unit_label.setText(u"")

        info_spacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)

        # Create info grid.
        info_grid = QGridLayout()
        # # Time row.
        info_grid.addWidget(label_time, 0, 0)
        info_grid.addWidget(self._label_time_value, 0, 1)
        info_grid.addWidget(label_time_unit, 0, 2)
        # # Buffer row.
        info_grid.addWidget(info_buffer_label, 1, 0)
        info_grid.addWidget(self._info_buffer_value_label, 1, 1)
        info_grid.addWidget(info_buffer_unit_label, 1, 2)
        # # Probe row.
        info_grid.addWidget(info_probe_label, 2, 0)
        info_grid.addWidget(info_probe_value_label, 2, 1)
        info_grid.addWidget(info_probe_unit_label, 2, 2)
        # # Spacer.
        info_grid.addItem(info_spacer)

        # Create info group.
        info_group = QGroupBox()
        info_group.setLayout(info_grid)

        # Create info dock.
        info_dock = QDockWidget()
        info_dock.setWidget(info_group)
        info_dock.setWindowTitle("Info")

        # Create the thread that forwards pipe messages to the callbacks.
        thread = Thread(number_pipe, data_pipe, mads_pipe, peaks_pipe)
        thread.number_signal.connect(self._number_callback)
        thread.reception_signal.connect(self._reception_callback)
        thread.start()

        # Add dockable windows.
        self.addDockWidget(Qt.LeftDockWidgetArea, control_dock)
        self.addDockWidget(Qt.LeftDockWidgetArea, info_dock)
        self.addDockWidget(Qt.LeftDockWidgetArea, channels_dock)
        # Set central widget.
        self.setCentralWidget(central_widget)
        # Set window size.
        if screen_resolution is not None:
            screen_width = screen_resolution.width()
            screen_height = screen_resolution.height()
            self.resize(screen_width, screen_height)
        # Set window title.
        self.setWindowTitle("SpyKING Circus ORT - Read 'n' Qt display")

        print(" ")  # TODO remove?
Example #13
0
    def __init__(self, params_pipe, number_pipe, templates_pipe, spikes_pipe,
                 probe_path=None, screen_resolution=None):
        """Build the main window: controls, template table, info panel and canvases.

        Parameters:
            params_pipe: pipe pair
                ``params_pipe[0].recv()`` must return a dict with at least the
                keys 'nb_samples' and 'sampling_rate'.
            number_pipe: pipe
                Forwarded to the reception ``Thread``.
            templates_pipe: pipe
                Forwarded to the reception ``Thread``.
            spikes_pipe: pipe
                Forwarded to the reception ``Thread``.
            probe_path: string (optional)
                Probe file, loaded here and passed to every canvas.
            screen_resolution: object (optional)
                Anything exposing ``width()`` and ``height()``; when given, the
                window is resized to that size.
        """

        QMainWindow.__init__(self)

        # Receive parameters.
        params = params_pipe[0].recv()
        self.probe = load_probe(probe_path)
        self._nb_samples = params['nb_samples']
        self._sampling_rate = params['sampling_rate']
        self._display_list = []

        self._params = {
            'nb_samples': self._nb_samples,
            'sampling_rate': self._sampling_rate,
            'time': {
                'min': 10.0,  # ms
                'max': 100.0,  # ms
                'init': 100.0,  # ms
            },
            'voltage': {
                'min': -200,  # µV
                'max': 200.0,  # µV
                'init': 50.0,  # µV
            },
            # The list object itself is stored, so in-place changes to
            # self._display_list are visible through self._params.
            'templates': self._display_list
        }

        self._canvas_mea = MEACanvas(probe_path=probe_path, params=self._params)
        self._canvas_template = TemplateCanvas(probe_path=probe_path, params=self._params)
        self._canvas_rate = RateCanvas(probe_path=probe_path, params=self._params)
        self._canvas_isi = ISICanvas(probe_path=probe_path, params=self._params)

        self.cells = Cells({})
        self._nb_buffer = 0

        # TODO ISI
        self.isi_bin_width, self.isi_x_max = 2, 25.0

        canvas_template_widget = self._canvas_template.native
        canvas_mea = self._canvas_mea.native
        canvas_rate = self._canvas_rate.native
        canvas_isi = self._canvas_isi.native

        # Create controls widgets.
        label_time = QLabel()
        label_time.setText(u"time")
        label_time_unit = QLabel()
        label_time_unit.setText(u"ms")

        self._dsp_time = QDoubleSpinBox()
        self._dsp_time.setMinimum(self._params['time']['min'])
        self._dsp_time.setMaximum(self._params['time']['max'])
        self._dsp_time.setValue(self._params['time']['init'])
        self._dsp_time.valueChanged.connect(self._on_time_changed)

        label_voltage = QLabel()
        label_voltage.setText(u"voltage")
        label_voltage_unit = QLabel()
        label_voltage_unit.setText(u"µV")
        self._dsp_voltage = QDoubleSpinBox()
        self._dsp_voltage.setMinimum(self._params['voltage']['min'])
        self._dsp_voltage.setMaximum(self._params['voltage']['max'])
        self._dsp_voltage.setValue(self._params['voltage']['init'])
        self._dsp_voltage.valueChanged.connect(self._on_voltage_changed)

        label_binsize = QLabel()
        label_binsize.setText(u"Bin size")
        label_binsize_unit = QLabel()
        label_binsize_unit.setText(u"second")
        self._dsp_binsize = QDoubleSpinBox()
        self._dsp_binsize.setRange(0.1, 10)
        self._dsp_binsize.setSingleStep(0.1)
        self.bin_size = 1
        self._dsp_binsize.setValue(self.bin_size)
        self._dsp_binsize.valueChanged.connect(self._on_binsize_changed)

        label_zoomrates = QLabel()
        label_zoomrates.setText(u'Zoom rates')
        self._zoom_rates = QDoubleSpinBox()
        self._zoom_rates.setRange(1, 50)
        self._zoom_rates.setSingleStep(0.1)
        self._zoom_rates.setValue(1)
        self._zoom_rates.valueChanged.connect(self._on_zoomrates_changed)

        label_time_window = QLabel()
        label_time_window.setText(u'Time window rates')
        label_time_window_unit = QLabel()
        label_time_window_unit.setText(u'second')
        self._dsp_tw_rate = QDoubleSpinBox()
        self._dsp_tw_rate.setRange(1, 50)
        self._dsp_tw_rate.setSingleStep(self.bin_size)
        self._dsp_tw_rate.setValue(50 * self.bin_size)
        self._dsp_tw_rate.valueChanged.connect(self._on_time_window_changed)

        label_tw_from_start = QLabel()
        label_tw_from_start.setText('Time scale from start')
        self._tw_from_start = QCheckBox()
        self._tw_from_start.setChecked(True)

        # Template selection table.
        self._selection_templates = QTableWidget()
        self._selection_templates.setSelectionMode(
            QAbstractItemView.ExtendedSelection
        )
        self._selection_templates.setColumnCount(3)
        # Column titles go on the horizontal header (the vertical header
        # labels rows, and the table has no rows at this point).
        self._selection_templates.setHorizontalHeaderLabels(['Nb template', 'Channel', 'Amplitude'])
        # Row 0 also carries the column titles. It is kept because other
        # methods of this class may index rows relative to it (not visible here).
        self._selection_templates.insertRow(0)
        self._selection_templates.setItem(0, 0, QTableWidgetItem('Nb template'))
        self._selection_templates.setItem(0, 1, QTableWidgetItem('Channel'))
        self._selection_templates.setItem(0, 2, QTableWidgetItem('Amplitude'))

        # Create controls grid.
        grid = QGridLayout()
        # # Add time row.
        grid.addWidget(label_time, 0, 0)
        grid.addWidget(self._dsp_time, 0, 1)
        grid.addWidget(label_time_unit, 0, 2)
        # # Add voltage row.
        grid.addWidget(label_voltage, 1, 0)
        grid.addWidget(self._dsp_voltage, 1, 1)
        grid.addWidget(label_voltage_unit, 1, 2)
        # # Add binsize row.
        grid.addWidget(label_binsize, 2, 0)
        grid.addWidget(self._dsp_binsize, 2, 1)
        grid.addWidget(label_binsize_unit, 2, 2)
        # # Add zoom rate row.
        grid.addWidget(label_zoomrates, 3, 0)
        grid.addWidget(self._zoom_rates, 3, 1)
        # # Add time window row.
        grid.addWidget(label_time_window, 4, 0)
        grid.addWidget(self._dsp_tw_rate, 4, 1)
        grid.addWidget(label_time_window_unit, 4, 2)
        # # Add checkbox to display the rates from start.
        grid.addWidget(label_tw_from_start, 5, 0)
        grid.addWidget(self._tw_from_start, 5, 1)
        # # Add spacer. A layout takes ownership of the items added to it, so
        # # each layout gets its own spacer instance.
        controls_spacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
        grid.addItem(controls_spacer)

        # Create controls group.
        controls_group = QGroupBox()
        controls_group.setLayout(grid)

        # Create templates grid.
        templates_grid = QGridLayout()
        templates_grid.addWidget(self._selection_templates, 0, 1)

        # Template selection signals.
        # NOTE(review): self.nb_templates is not set in this method; it is
        # read when the lambda fires, so it presumably gets set elsewhere
        # before any selection happens. Verify.
        self._selection_templates.itemSelectionChanged.connect(lambda: self.selected_templates(
            self.nb_templates))

        # Checkbox to display all the rates.
        self._tw_from_start.stateChanged.connect(self.time_window_rate_full)

        # # Add spacer (distinct instance, see above).
        templates_spacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)
        templates_grid.addItem(templates_spacer)

        # Create templates group.
        templates_group = QGroupBox()
        templates_group.setLayout(templates_grid)

        # # Create templates dock.
        templates_dock = QDockWidget()
        templates_dock.setWidget(templates_group)
        templates_dock.setWindowTitle("Channels selection")

        # # Create controls dock.
        control_dock = QDockWidget()
        control_dock.setWidget(controls_group)
        control_dock.setWindowTitle("Controls")

        # Create info widgets.
        label_time = QLabel()
        label_time.setText(u"time")
        self._label_time_value = QLineEdit()
        self._label_time_value.setText(u"0")
        self._label_time_value.setReadOnly(True)
        self._label_time_value.setAlignment(Qt.AlignRight)
        label_time_unit = QLabel()
        label_time_unit.setText(u"s")
        info_buffer_label = QLabel()
        info_buffer_label.setText(u"buffer")
        self._info_buffer_value_label = QLineEdit()
        self._info_buffer_value_label.setText(u"0")
        self._info_buffer_value_label.setReadOnly(True)
        self._info_buffer_value_label.setAlignment(Qt.AlignRight)
        info_buffer_unit_label = QLabel()
        info_buffer_unit_label.setText(u"")
        info_probe_label = QLabel()
        info_probe_label.setText(u"probe")
        info_probe_value_label = QLineEdit()
        info_probe_value_label.setText(u"{}".format(probe_path))
        info_probe_value_label.setReadOnly(True)
        # TODO place the following info in another grid?
        info_probe_unit_label = QLabel()
        info_probe_unit_label.setText(u"")

        info_spacer = QSpacerItem(20, 40, QSizePolicy.Minimum, QSizePolicy.Expanding)

        # Create info grid.
        info_grid = QGridLayout()
        # # Time row.
        info_grid.addWidget(label_time, 0, 0)
        info_grid.addWidget(self._label_time_value, 0, 1)
        info_grid.addWidget(label_time_unit, 0, 2)
        # # Buffer row.
        info_grid.addWidget(info_buffer_label, 1, 0)
        info_grid.addWidget(self._info_buffer_value_label, 1, 1)
        info_grid.addWidget(info_buffer_unit_label, 1, 2)
        # # Probe row.
        info_grid.addWidget(info_probe_label, 2, 0)
        info_grid.addWidget(info_probe_value_label, 2, 1)
        info_grid.addWidget(info_probe_unit_label, 2, 2)
        # # Spacer.
        info_grid.addItem(info_spacer)

        # Create info group.
        info_group = QGroupBox()
        info_group.setLayout(info_grid)

        # Create info dock.
        info_dock = QDockWidget()
        info_dock.setWidget(info_group)
        info_dock.setWindowTitle("Info")

        # Create the reception thread. Keep a reference on the window so the
        # running thread object is not garbage collected when __init__ returns.
        thread = Thread(number_pipe, templates_pipe, spikes_pipe)
        thread.number_signal.connect(self._number_callback)
        thread.reception_signal.connect(self._reception_callback)
        self._reception_thread = thread
        thread.start()

        # Add dockable windows.
        self.addDockWidget(Qt.LeftDockWidgetArea, control_dock)
        self.addDockWidget(Qt.LeftDockWidgetArea, info_dock)
        self.addDockWidget(Qt.LeftDockWidgetArea, templates_dock)

        # Lay the four canvases out in a 2x2 grid:
        # template | MEA
        # ISI      | rate
        canvas_grid = QGridLayout()

        group_canv_temp = QDockWidget()
        group_canv_temp.setWidget(canvas_template_widget)
        group_canv_mea = QDockWidget()
        group_canv_mea.setWidget(canvas_mea)
        group_canv_rate = QDockWidget()
        group_canv_rate.setWidget(canvas_rate)
        group_canv_isi = QDockWidget()
        group_canv_isi.setWidget(canvas_isi)

        canvas_grid.addWidget(group_canv_temp, 0, 0)
        canvas_grid.addWidget(group_canv_mea, 0, 1)
        canvas_grid.addWidget(group_canv_rate, 1, 1)
        canvas_grid.addWidget(group_canv_isi, 1, 0)
        canvas_group = QGroupBox()
        canvas_group.setLayout(canvas_grid)

        # Set central widget.
        self.setCentralWidget(canvas_group)
        # Set window size.
        if screen_resolution is not None:
            screen_width = screen_resolution.width()
            screen_height = screen_resolution.height()
            self.resize(screen_width, screen_height)
        # Set window title.
        self.setWindowTitle("SpyKING Circus ORT - Read 'n' Qt display")

        print(" ")  # TODO remove?
    def __init__(self, probe_path=None, params=None):
        """Initialize the template canvas and its GLSL programs.

        Parameters:
            probe_path: string
                Probe file, passed to ``load_probe``.
            params: dict
                Must provide 'nb_samples', 'sampling_rate', 'time' (with keys
                'max', 'min' and 'init') and 'voltage' (with key 'init').
                The time values are presumably in ms (see the 1e-3 / 1e+3
                factors below).

        No template is known at construction time: every per-template array
        below is created empty (nb_templates == 0), presumably to be refilled
        when templates are received.
        """

        app.Canvas.__init__(self, title="Vispy canvas")

        self.probe = load_probe(probe_path)

        # Number of whole buffers needed to cover the maximal time window.
        nb_buffers_per_signal = int(
            np.ceil((params['time']['max'] * 1e-3) * params['sampling_rate'] /
                    float(params['nb_samples'])))
        self.nb_buffers_per_signal = nb_buffers_per_signal
        # Duration covered by those whole buffers (same unit as params['time']).
        self._time_max = (float(nb_buffers_per_signal * params['nb_samples']) /
                          params['sampling_rate']) * 1e+3
        self._time_min = params['time']['min']
        self.initialized = False

        self.cells = None

        # Reception state (empty until templates arrive).
        self.nb_templates = 0
        self.nb_samples_per_template = 0
        self.nb_electrode = 0
        self.nb_channels = self.probe.nb_channels
        self.template_values = np.zeros((1, 1), dtype=np.float32)

        # One vertex per (template, channel, sample) triple.
        nb_template_vertices = (self.nb_channels * self.nb_samples_per_template *
                                self.nb_templates)

        self.templates = np.zeros(shape=(nb_template_vertices, ), dtype=np.float32)

        # Per-vertex template index, electrode index and sample index.
        self.templates_index = np.repeat(
            np.arange(0, self.nb_templates, dtype=np.float32),
            repeats=self.nb_channels * self.nb_samples_per_template)
        self.electrode_index = np.tile(
            np.repeat(np.arange(0, self.nb_channels, dtype=np.float32),
                      repeats=self.nb_samples_per_template),
            reps=self.nb_templates)
        self.template_sample_index = np.tile(
            np.arange(0, self.nb_samples_per_template, dtype=np.float32),
            reps=self.nb_templates * self.nb_channels)

        self.template_selected = np.ones(nb_template_vertices, dtype=np.float32)

        # Signals.

        # Number of signals.
        self.nb_signals = self.probe.nb_channels
        # Number of samples per buffer.
        self._nb_samples_per_buffer = params['nb_samples']
        # Number of samples per signal.
        nb_samples_per_signal = nb_buffers_per_signal * self._nb_samples_per_buffer
        self._nb_samples_per_signal = nb_samples_per_signal
        # Signal values.
        self._template_values = np.zeros(
            (self.nb_signals, nb_samples_per_signal), dtype=np.float32)

        # Electrode position of each (channel, sample) vertex, tiled per template.
        template_positions = np.c_[
            np.repeat(self.probe.x.astype(np.float32),
                      repeats=self.nb_samples_per_template),
            np.repeat(self.probe.y.astype(np.float32),
                      repeats=self.nb_samples_per_template), ]
        self.template_position = np.tile(template_positions,
                                         (self.nb_templates, 1))
        # Seeds NumPy's global RNG, so any later np.random call is deterministic too.
        np.random.seed(12)
        self.template_colors = np.repeat(
            np.random.uniform(size=(self.nb_templates, 3), low=.3, high=.9),
            self.nb_channels * self.nb_samples_per_template,
            axis=0).astype(np.float32)
        self.list_selected_templates = []

        # Define GLSL program.
        self._template_program = gloo.Program(vert=TEMPLATE_VERT_SHADER,
                                              frag=TEMPLATE_FRAG_SHADER)
        self._template_program['a_template_index'] = self.electrode_index
        self._template_program['a_template_position'] = self.template_position
        self._template_program['a_template_value'] = self.templates
        self._template_program['a_template_color'] = self.template_colors
        self._template_program['a_sample_index'] = self.template_sample_index
        self._template_program['a_template_selected'] = self.template_selected
        self._template_program[
            'u_nb_samples_per_signal'] = self.nb_samples_per_template
        self._template_program['u_x_min'] = self.probe.x_limits[0]
        self._template_program['u_x_max'] = self.probe.x_limits[1]
        self._template_program['u_y_min'] = self.probe.y_limits[0]
        self._template_program['u_y_max'] = self.probe.y_limits[1]
        self._template_program[
            'u_d_scale'] = self.probe.minimum_interelectrode_distance
        self._template_program[
            'u_t_scale'] = self._time_max / params['time']['init']
        self._template_program['u_v_scale'] = params['voltage']['init']

        # Boxes: 5 vertices per channel (closed rectangle outline).

        box_indices = np.repeat(np.arange(0,
                                          self.nb_channels,
                                          dtype=np.float32),
                                repeats=5)
        box_positions = np.c_[
            np.repeat(self.probe.x.astype(np.float32), repeats=5),
            np.repeat(self.probe.y.astype(np.float32), repeats=5), ]
        corner_positions = np.c_[
            np.tile(np.array([+1.0, -1.0, -1.0, +1.0, +1.0], dtype=np.float32),
                    reps=self.nb_channels),
            np.tile(np.array([+1.0, +1.0, -1.0, -1.0, +1.0], dtype=np.float32),
                    reps=self.nb_channels), ]
        # Define GLSL program.
        self._box_program = gloo.Program(vert=BOX_VERT_SHADER,
                                         frag=BOX_FRAG_SHADER)
        self._box_program['a_box_index'] = box_indices
        self._box_program['a_box_position'] = box_positions
        self._box_program['a_corner_position'] = corner_positions
        self._box_program['u_x_min'] = self.probe.x_limits[0]
        self._box_program['u_x_max'] = self.probe.x_limits[1]
        self._box_program['u_y_min'] = self.probe.y_limits[0]
        self._box_program['u_y_max'] = self.probe.y_limits[1]
        self._box_program[
            'u_d_scale'] = self.probe.minimum_interelectrode_distance

        # Final details.

        gloo.set_viewport(0, 0, *self.physical_size)

        gloo.set_state(clear_color='black',
                       blend=True,
                       blend_func=('src_alpha', 'one_minus_src_alpha'))
Example #15
0
import os
# import numpy as np

from circusort.io.spikes import load_spikes, spikes2cells
from circusort.io.cells import load_cells
from circusort.io.template_store import load_template_store
# from circusort.utils.validation import get_fp_fn_rate
from circusort.io.datafile import load_datafile
from circusort.plt.cells import *
from circusort.plt.template import *
from circusort.io.probe import load_probe

# Benchmark to analyse, stored under ~/.spyking-circus-ort/benchmarks/<data_path>.
data_path = "rates_manipulation"

# Expand "~" once so that every derived path is absolute and works for any user.
generation_directory = os.path.expanduser(
    os.path.join("~", ".spyking-circus-ort", "benchmarks", data_path))
probe_path = os.path.join(generation_directory, "probe.prb")
sorting_directory = os.path.join(generation_directory, "sorting")

# Probe file of the generation step, resolved relative to the current user's home.
p = load_probe(os.path.join(generation_directory, 'generation', 'probe.prb'))

similarity_thresh = 0.9

print('Loading data...')
injected_cells = load_cells(os.path.join(generation_directory, 'generation'))
fitted_spikes = load_spikes(os.path.join(sorting_directory, 'spikes.h5'))
found_templates = load_template_store(
    os.path.join(sorting_directory, 'templates.h5'))
    def __init__(self, probe_path=None, params=None):
        """Initialize the probe-view canvas and its three GLSL programs.

        The programs are the probe boundary, one point per channel, and one
        point per template barycenter (none at construction time).

        Parameters:
            probe_path: string
                Probe file, passed to ``load_probe``.
            params: dict (optional)
                Not read by this method.
        """
        app.Canvas.__init__(self, title="Probe view")

        self.probe = load_probe(probe_path)
        self.nb_channels = self.probe.nb_channels
        self.initialized = False

        x_limits = self.probe.x_limits
        y_limits = self.probe.y_limits
        d_scale = self.probe.minimum_interelectrode_distance

        def set_view_uniforms(program):
            # Uniforms shared by every program of this canvas.
            program['u_x_min'] = x_limits[0]
            program['u_x_max'] = x_limits[1]
            program['u_y_min'] = y_limits[0]
            program['u_y_max'] = y_limits[1]
            program['u_d_scale'] = d_scale
            program['u_scale'] = (1.0, 1.0)
            program['u_pan'] = (0.0, 0.0)

        # Boundary: the probe limits shrunk by one inter-electrode distance.
        # TODO Add method to probe file to extract minimum coordinates without the interelectrode dist
        x_min, x_max = x_limits[0] + d_scale, x_limits[1] - d_scale
        y_min, y_max = y_limits[0] + d_scale, y_limits[1] - d_scale

        # Closed outline: 5 vertices, the last repeating the first.
        probe_corner = np.array(
            [[x_max, y_max], [x_min, y_max], [x_min, y_min], [x_max, y_min],
             [x_max, y_max]],
            dtype=np.float32)

        corner_bound_positions = np.array(
            [[+1.0, +1.0], [-1.0, +1.0], [-1.0, -1.0], [+1.0, -1.0],
             [+1.0, +1.0]],
            dtype=np.float32)

        # Define GLSL program.
        self._boundary_program = gloo.Program(vert=BOUNDARY_VERT_SHADER,
                                              frag=BOUNDARY_FRAG_SHADER)
        self._boundary_program['a_pos_probe'] = probe_corner
        self._boundary_program['a_corner_position'] = corner_bound_positions
        set_view_uniforms(self._boundary_program)

        # Channels: one point per electrode, all initially selected.
        channel_pos = np.c_[
            self.probe.x.astype(np.float32),
            self.probe.y.astype(np.float32), ]
        selected_channels = np.ones(self.nb_channels, dtype=np.float32)

        self._channel_program = gloo.Program(vert=CHANNELS_VERT_SHADER,
                                             frag=CHANNELS_FRAG_SHADER)
        self._channel_program['a_channel_position'] = channel_pos
        self._channel_program['a_selected_channel'] = selected_channels
        self._channel_program['radius'] = 10
        set_view_uniforms(self._channel_program)

        # Barycenters: none known at construction time.
        self.nb_templates = 0
        self.selected_bary = 0
        temp_selected = np.ones(self.nb_templates, dtype=np.float32)
        self.barycenter = np.zeros((self.nb_templates, 2), dtype=np.float32)
        self.list_selected_templates = []
        # Seeds NumPy's global RNG, so any later np.random call is deterministic too.
        np.random.seed(12)
        self.bary_color = np.random.uniform(size=(self.nb_templates, 3),
                                            low=.5,
                                            high=.9).astype(np.float32)

        self._barycenter_program = gloo.Program(vert=BARYCENTER_VERT_SHADER,
                                                frag=BARYCENTER_FRAG_SHADER)
        self._barycenter_program['a_barycenter_position'] = self.barycenter
        self._barycenter_program['a_selected_template'] = temp_selected
        self._barycenter_program['a_color'] = self.bary_color
        self._barycenter_program['radius'] = 5
        set_view_uniforms(self._barycenter_program)

        # Final details.
        gloo.set_viewport(0, 0, *self.physical_size)
        gloo.set_state(clear_color='black',
                       blend=True,
                       blend_func=('src_alpha', 'one_minus_src_alpha'))
Example #17
0
    def __init__(self,
                 file_name,
                 probe_file=None,
                 mode='r+',
                 compression=None):
        """Open (or create) an HDF5 template store and load its channel mappings.

        Parameters:
            file_name: string
                Path of the store; '~' is expanded, then the path is made absolute.
            probe_file: string (optional)
                Probe file. Required in 'w' mode. In 'r'/'r+' mode it defaults
                to the path recorded in the file's 'probe_file' attribute.
            mode: string
                Passed to ``self._open``. 'w' initializes the datasets; 'r'
                and 'r+' read the existing mappings.
            compression: (optional)
                Compression filter passed to ``create_dataset``.

        Raises:
            ValueError: if ``mode`` is 'w' and ``probe_file`` is None.

        The file is closed again (``self._close()``) before returning, even
        if an error occurs after it was opened.
        """
        # TODO use compression='gzip' instead.
        # TODO i.e. fix the ValueError: Compression filter "gzip" is unavailable.
        from circusort.io.probe import load_probe

        # Validate before opening: opening in 'w' mode would otherwise already
        # have touched the file.
        if mode in ['w'] and probe_file is None:
            raise ValueError("probe_file is required when mode is 'w'")

        # Expand '~' first: abspath() would turn '~/x' into '<cwd>/~/x',
        # which expanduser() no longer recognizes.
        self.file_name = os.path.abspath(os.path.expanduser(file_name))
        self.probe_file = probe_file
        self.mode = mode
        self._index = -1
        self.mappings = {}
        self._2_components = False
        self._temporal_width = None
        self.h5_file = None
        self._first_creation = None
        self._last_creation = None
        self._channels = None
        self.compression = compression
        self._similarities = {}

        self._open(self.mode)
        try:
            if self.mode in ['w']:

                self.probe = load_probe(self.probe_file)
                # One mapping dataset per channel, holding its probe edges.
                for channel, indices in self.probe.edges.items():
                    self.h5_file.create_dataset('mapping/%d' % channel,
                                                data=indices,
                                                chunks=True,
                                                maxshape=(None, ),
                                                compression=self.compression)
                    self.mappings[channel] = indices
                # Empty, resizable int32 datasets filled as templates are added.
                for dataset_name in ('indices', 'times', 'channels'):
                    self.h5_file.create_dataset(dataset_name,
                                                data=np.zeros(0, dtype=np.int32),
                                                chunks=True,
                                                maxshape=(None, ),
                                                compression=self.compression)
                self.h5_file.attrs['probe_file'] = os.path.abspath(
                    os.path.expanduser(probe_file))

            elif self.mode in ['r', 'r+']:

                indices = self.h5_file['indices'][:]
                if len(indices) > 0:
                    self._index = indices.max()

                self.mappings = {}
                for key, value in self.h5_file['mapping'].items():
                    self.mappings[int(key)] = value[:]

                if self._index >= 0:
                    self._2_components = '2' in self.h5_file['waveforms/%d' %
                                                             self._index]

                if self.probe_file is None:
                    self.probe_file = self.h5_file.attrs['probe_file']
                self.probe = load_probe(self.probe_file)

            self.nb_channels = len(self.mappings)
        finally:
            self._close()