Ejemplo n.º 1
0
def _truncate(fn, extension='.dat', offset=None, n_channels=None, itemsize=None, dtype=None, chunk_size=50000):
    """Eventually truncate a file at the end to ensure it has a correct shape.

    If the number of items in ``fn`` (after ``offset`` bytes) is not a
    multiple of ``n_channels``, write an end-truncated copy to
    ``fn + extension`` and return its path; otherwise return ``fn`` itself.

    Returns ``(path, (n_samples, n_channels))``.
    """
    data = np.memmap(fn, dtype=dtype, offset=offset)
    N = data.shape[0]  # total number of items (all channels interleaved)

    if N % n_channels != 0:
        fn_copy = fn + extension
        N = N // n_channels  # number of complete samples per channel

        # Reuse a previously-created truncated copy if it exists.
        if op.exists(fn_copy):
            return fn_copy, (N, n_channels)

        # Create the end-truncated file.
        info("Truncating...")
        n_items = N * n_channels          # items to keep (complete samples only)
        chunk_len = n_channels * chunk_size  # items per written chunk
        # NOTE: open in *binary* mode — we write raw ndarray buffers.
        with open(fn_copy, 'wb') as f:
            for start in range(0, n_items, chunk_len):
                # The last chunk may be partial; clamp so no sample is lost.
                f.write(data[start:min(start + chunk_len, n_items)])
    else:
        fn_copy = fn
        N = N // n_channels
    return fn_copy, (N, n_channels)
Ejemplo n.º 2
0
def _read_filtered(filename, n_channels=None, dtype=None):
    """Open a raw data file, skipping its ASCII header if one is present."""
    # Read the first 4 KB, where the header (if any) terminates with 'EOH'.
    with open(filename, 'rb') as fh:
        head = fh.read(4096)
    head = head.decode('ascii',  'ignore')
    try:
        eoh = head.index('EOH')
    except Exception:
        # No header found: the data starts at the beginning of the file.
        offset = 0
    else:
        # OFFSET = HEADER + EOH (3 bytes) + 2 uint16 samples (4 bytes)
        offset = eoh + 3 + 2 * 2
    info("Header: {} bytes.".format(offset))
    dtype = np.dtype(dtype)
    fn_out, shape = _truncate(filename,
                              offset=offset,
                              n_channels=n_channels,
                              itemsize=dtype.itemsize,
                              dtype=dtype)
    return fn_out, np.memmap(fn_out, dtype=dtype, offset=0, shape=shape)
Ejemplo n.º 3
0
        def on_key_press(e):
            # Step through items with the space bar; Shift+space steps back.
            # NOTE(review): `name`, `w`, `templates` and `masks` are closure
            # variables from the enclosing scope (not visible here) — presumably
            # the view name, the waveform view, and per-item arrays; confirm.
            if e.key == 'space':
                self._n += 1 if ('Shift' not in e.modifiers) else -1
                if name == 'templates':
                    # Show the current template with its mask.
                    info("Template {}.".format(self._n))
                    w.set_data(waveforms=templates[self._n],
                               masks=masks[self._n],
                               )
                elif name == 'waveforms':
                    sample = self.spike_samples[self._n]
                    cluster = self.spike_clusters[self._n]
                    info("Waveform {}, template={}, sample={}.".format(self._n,
                         cluster, sample))

                    # Stack the spike waveform (cluster 0) with its cluster's
                    # template (cluster 1) for side-by-side display.
                    wav = np.vstack((templates[self._n],
                                     self.templates[cluster][:-1][None, ...]))

                    m = np.vstack((masks[self._n],
                                   self.template_masks[cluster][None, ...]))
                    w.set_data(waveforms=wav,
                               masks=m,
                               spike_clusters=[0, 1],
                               )
Ejemplo n.º 4
0
    def create_kwik(self):
        """Convert the loaded dataset to a Kwik file.

        Creates an empty Kwik file, optionally computes PCs/features/masks,
        then adds the clustering and the spikes through a ``KwikCreator``.
        """
        # Create an empty Kwik file.
        info("Starting the conversion to Kwik...")
        create_kwik(kwik_path=self.kwik_path,
                    raw_data_files=[self.file],
                    prb_file=self.prb_file,
                    n_channels=self.n_total_channels,
                    sample_rate=self.sample_rate,
                    dtype=self.dtype,
                    nfeatures_per_channel=self.n_features_per_channel,
                    extract_s_after=self.extract_s_after,
                    extract_s_before=self.extract_s_before,
                    overwrite=True,
                    )

        # Compute PCs and features.
        # NOTE(review): `extract_features` is a free name, presumably a
        # module-level flag — confirm it is defined where this class lives.
        if extract_features:
            info("Computing PCs...")
            self.compute_pcs()

            info("Computing features of all spikes...")
            # WARNING: watch out RAM usage here. We cannot use a generator because
            # the KwiKCreator only accepts lists at the moment.
            # (The previous code wrapped these in generator *expressions*,
            # which contradicted the constraint above — materialize instead.)
            features = list(self.compute_features())
            masks = list(self.compute_masks())
        else:
            info("Skipping PCA...")
            features = None
            masks = None
            self.n_features_per_channel = None

        # Add clusters.
        creator = KwikCreator(self.kwik_path)

        info("Adding the clusters in the kwik file.")
        creator.add_clustering(group=1,
                               name='main',
                               spike_clusters=self.spike_clusters,
                               template_waveforms=self.templates,
                               template_masks=self.template_masks,
                               template_amplitudes=self.amplitudes,
                               )

        # Add spikes.
        info("Adding the spikes in the kwik file.")
        creator.add_spikes(group=1,
                           spike_samples=self.spike_samples,
                           masks=masks,
                           features=features,
                           n_channels=self.n_channels,
                           n_features=self.n_features_per_channel,
                           )

        # Template amplitudes go to the .kwik file (not the .kwx) since
        # they're lightweight enough to delete afterwards.

        info("Kwik file successfully created!")
Ejemplo n.º 5
0
    def __init__(self,
                 basename,
                 filename,
                 N_t,
                 n_channels=None,
                 n_total_channels=None,
                 prb_file=None,
                 dtype=None,
                 sample_rate=None,
                 offset=0,
                 gain=0.01
                 ):
        """Load traces, spikes, templates and amplitudes for conversion.

        Parameters
        ----------
        basename : str
            Prefix used for the output ``.kwik`` file and for the spike /
            template / amplitude files read by the ``_read_*`` helpers.
        filename : str
            Path to the raw data file.
        N_t : int
            Waveform extent in samples; split symmetrically into the
            before/after extraction windows below.
        n_channels : int
            Number of channels actually used by the probe.
        n_total_channels : int
            Total number of channels in the raw file.
        prb_file : str
            Probe file, loaded with ``load_probe``.
        dtype : dtype-like
            Sample dtype of the raw file.
        sample_rate : float
            Sampling rate in Hz.
        offset, gain : numbers
            DC offset and scale factor passed to the ``WaveformLoader``.
        """

        self.n_features_per_channel = 3
        self.n_total_channels = n_total_channels
        # Symmetric extraction window: (N_t - 1) // 2 samples on each side.
        self.extract_s_after = self.extract_s_before = extract_s_before = extract_s_after = int(N_t - 1)//2

        # set to True if your data is already pre-filtered (much quicker)
        filtered_datfile = False

        # Filtering parameters for PCA (these are ignored if filtered_datfile == True)
        filter_low = 500.
        filter_high = 0.95 * .5 * sample_rate  # just below Nyquist
        filter_butter_order = 3

        self.basename = basename
        self.kwik_path = basename + '.kwik'
        self.dtype = dtype
        self.prb_file = prb_file
        self.probe = load_probe(prb_file)

        self.sample_rate = sample_rate
        self.filtered_datfile = filtered_datfile

        self._sd = SpikeDetekt(probe=self.probe,
                               n_features_per_channel=
                               self.n_features_per_channel,
                               pca_n_waveforms_max=10000,
                               extract_s_before=extract_s_before,
                               extract_s_after=extract_s_after,
                               sample_rate=sample_rate,
                               )
        self.n_samples_w = extract_s_before + extract_s_after

        # A xxx.filtered.trunc file may be created if needed.
        self.file, self.traces_f = _read_filtered(filename,
                                       n_channels=n_total_channels,
                                       dtype=dtype,
                                       )
        self.n_samples, self.n_total_channels = self.traces_f.shape
        self.n_channels = n_channels
        assert n_total_channels == self.n_total_channels
        info("Loaded traces: {}.".format(self.traces_f.shape))

        # Load spikes.
        self.spike_samples, self.spike_clusters = _read_spikes(basename)
        self.n_spikes = len(self.spike_samples)
        assert len(self.spike_clusters) == self.n_spikes
        info("Loaded {} spikes.".format(self.n_spikes))

        # Chunks when computing features.
        self.chunk_size = 2500
        self.n_chunks   = int(np.ceil(self.n_spikes / self.chunk_size))

        # Load templates and masks.
        self.templates, self.template_masks = _read_templates(basename, self.probe, self.n_total_channels, self.n_channels)
        self.n_templates = len(self.templates)
        info("Loaded templates: {}.".format(self.templates.shape))

        # Load amplitudes.
        self.amplitudes = _read_amplitudes(basename, self.n_templates, self.n_spikes, self.spike_clusters)

        # NOTE(review): `extract_features` is a free name, presumably a
        # module-level flag — confirm it is defined in the enclosing module.
        if extract_features:
            # The WaveformLoader fetches and filters waveforms from the raw traces dynamically.
            n_samples = (extract_s_before, extract_s_after)
            b_filter = bandpass_filter(rate=self.sample_rate,
                                       low=filter_low,
                                       high=filter_high,
                                       order=filter_butter_order)

            def filter(x):
              return apply_filter(x, b_filter)

            filter_margin = filter_butter_order * 3

            # Collect the channel indices from every channel group of the probe.
            nodes            = []
            for key in self.probe['channel_groups'].keys():
              nodes += self.probe['channel_groups'][key]['channels']
            nodes    = np.array(nodes, dtype=np.int32)

            # Pre-filtered data skips the bandpass filter entirely.
            if filtered_datfile:
              self._wl = WaveformLoader(traces=self.traces_f,
                                        n_samples=self.n_samples_w,
                                        dc_offset=offset,
                                        scale_factor=gain,
                                        channels=nodes
                                        )
            else:
              self._wl = WaveformLoader(traces=self.traces_f,
                                        n_samples=self.n_samples_w,
                                        filter=filter,
                                        filter_margin=filter_margin,
                                        dc_offset=offset,
                                        scale_factor=gain,
                                        channels=nodes
                                        )

            # A virtual (n_spikes, n_samples, n_channels) array that is
            # memmapped to the filtered data file.
            self.waveforms = SpikeLoader(self._wl, self.spike_samples)

            assert self.waveforms.shape == (self.n_spikes,
                                            self.n_samples_w,
                                            self.n_channels)
            assert self.template_masks.shape == (self.n_templates, self.n_channels)