Example #1
def waveform_loader(model, filter_wave):
    """Create a waveform loader."""
    n_samples = (model._metadata['extract_s_before'],
                 model._metadata['extract_s_after'])
    order = model._metadata['filter_butter_order']
    rate = model._metadata['sample_rate']
    low = model._metadata['filter_low']
    high = model._metadata['filter_high_factor'] * rate
    b_filter = bandpass_filter(rate=rate,
                               low=low,
                               high=high,
                               order=order)

    if filter_wave:
        def the_filter(x):
            return apply_filter(x, b_filter)

        filter_margin = order * 3
    else:
        the_filter = None
        filter_margin = 0

    dc_offset = model._metadata.get('waveform_dc_offset', None)
    scale_factor = model._metadata.get('waveform_scale_factor', None)
    return WaveformLoader(n_samples=n_samples,
                          filter=the_filter,
                          filter_margin=filter_margin,
                          dc_offset=dc_offset,
                          scale_factor=scale_factor,
                          )
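A minimal sketch of calling this function, with a stub model whose _metadata
dict carries the keys read above (the stub and its values are hypothetical,
not taken from the library):

class FakeModel(object):
    _metadata = {
        'extract_s_before': 16,
        'extract_s_after': 16,
        'filter_butter_order': 3,
        'sample_rate': 20000.,
        'filter_low': 500.,
        'filter_high_factor': 0.475,
    }

loader = waveform_loader(FakeModel(), filter_wave=True)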
Example #2
    def __init__(
        self,
        traces=None,
        sample_rate=None,
        spike_samples=None,
        masks=None,
        mask_threshold=None,
        filter_order=None,
        n_samples_waveforms=None,
    ):

        # Traces.
        if traces is not None:
            self.traces = traces
            self.n_samples_trace, self.n_channels = traces.shape
        else:
            self._traces = None
            self.n_samples_trace = self.n_channels = 0

        assert spike_samples is not None
        self._spike_samples = spike_samples
        self.n_spikes = len(spike_samples)

        self._masks = masks
        if masks is not None:
            assert self._masks.shape == (self.n_spikes, self.n_channels)
        self._mask_threshold = mask_threshold

        # Define filter.
        if filter_order:
            filter_margin = filter_order * 3
            b_filter = bandpass_filter(
                rate=sample_rate,
                low=500.,
                high=sample_rate * .475,
                order=filter_order,
            )
            self._filter = lambda x, axis=0: apply_filter(
                x, b_filter, axis=axis)
        else:
            filter_margin = 0
            self._filter = lambda x, axis=0: x

        # Number of samples to return, can be an int or a
        # tuple (before, after).
        assert n_samples_waveforms is not None
        self.n_samples_before_after = _before_after(n_samples_waveforms)
        self.n_samples_waveforms = sum(self.n_samples_before_after)
        # Number of additional samples to use for filtering.
        self._filter_margin = _before_after(filter_margin)
        # Number of samples in the extracted raw data chunk.
        self._n_samples_extract = (self.n_samples_waveforms +
                                   sum(self._filter_margin))

        self.dtype = np.float32
        self.shape = (self.n_spikes, self._n_samples_extract, self.n_channels)
        self.ndim = 3
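The _before_after helper is not shown in this example; a minimal sketch of the
assumed behavior (an int is split into a (before, after) pair, a pair is
passed through unchanged) could look like this:

def _before_after(n_samples):
    # Assumed semantics: normalize an int or a (before, after) pair.
    if not isinstance(n_samples, (tuple, list)):
        before = n_samples // 2
        after = n_samples - before
    else:
        before, after = n_samples
    return int(before), int(after)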
Example #3
    def __init__(self,
                 traces=None,
                 sample_rate=None,
                 spike_samples=None,
                 filter_order=None,
                 n_samples_waveforms=None,
                 ):

        # Traces.
        if traces is not None:
            self.traces = traces
            self.n_samples_trace, self.n_channels = traces.shape
        else:
            self._traces = None
            self.n_samples_trace = self.n_channels = 0

        assert spike_samples is not None
        self._spike_samples = spike_samples
        self.n_spikes = len(spike_samples)

        # Define filter.
        if filter_order:
            filter_margin = filter_order * 3
            b_filter = bandpass_filter(rate=sample_rate,
                                       low=500.,
                                       high=sample_rate * .475,
                                       order=filter_order,
                                       )
            self._filter = lambda x, axis=0: apply_filter(x, b_filter,
                                                          axis=axis)
        else:
            filter_margin = 0
            self._filter = lambda x, axis=0: x

        # Number of samples to return, can be an int or a
        # tuple (before, after).
        assert n_samples_waveforms is not None
        self.n_samples_before_after = _before_after(n_samples_waveforms)
        self.n_samples_waveforms = sum(self.n_samples_before_after)
        # Number of additional samples to use for filtering.
        self._filter_margin = _before_after(filter_margin)
        # Number of samples in the extracted raw data chunk.
        self._n_samples_extract = (self.n_samples_waveforms +
                                   sum(self._filter_margin))

        self.dtype = np.float32
        self.shape = (self.n_spikes, self._n_samples_extract, self.n_channels)
        self.ndim = 3
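Assuming this __init__ belongs to the same waveform-loading class as in the
previous example, a hypothetical construction (fake data, assumed class name
WaveformLoader) could be:

import numpy as np

traces = np.zeros((20000, 32), dtype=np.float32)  # fake raw recording
spike_samples = np.array([1000, 5000, 12000])
loader = WaveformLoader(traces=traces,
                        sample_rate=20000.,
                        spike_samples=spike_samples,
                        filter_order=3,
                        n_samples_waveforms=(16, 16))
# With filter_order=3, the filter margin is 9 extra samples, so
# loader.shape == (3, 32 + 9, 32).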
Example #4
    def _init_data(self):
        if op.exists(self.dat_path):
            logger.debug("Loading traces at `%s`.", self.dat_path)
            traces = _dat_to_traces(self.dat_path,
                                    n_channels=self.n_channels_dat,
                                    dtype=self.dtype or np.int16,
                                    offset=self.offset,
                                    )
            n_samples_t, n_ch_dat = traces.shape
            assert n_ch_dat == self.n_channels_dat
        else:
            traces = None
            n_samples_t = 0

        logger.debug("Loading amplitudes.")
        amplitudes = read_array('amplitudes').squeeze()
        n_spikes, = amplitudes.shape
        self.n_spikes = n_spikes

        # Create spike_clusters if the file doesn't exist.
        if not op.exists(filenames['spike_clusters']):
            shutil.copy(filenames['spike_templates'],
                        filenames['spike_clusters'])
        logger.debug("Loading %d spike clusters.", self.n_spikes)
        spike_clusters = read_array('spike_clusters').squeeze()
        spike_clusters = spike_clusters.astype(np.int32)
        assert spike_clusters.shape == (n_spikes,)
        self.spike_clusters = spike_clusters

        logger.debug("Loading spike templates.")
        spike_templates = read_array('spike_templates').squeeze()
        spike_templates = spike_templates.astype(np.int32)
        assert spike_templates.shape == (n_spikes,)
        self.spike_templates = spike_templates

        logger.debug("Loading spike samples.")
        spike_samples = read_array('spike_samples').squeeze()
        assert spike_samples.shape == (n_spikes,)

        logger.debug("Loading templates.")
        templates = read_array('templates')
        templates[np.isnan(templates)] = 0
        # templates = np.transpose(templates, (2, 1, 0))

        # Unwhiten the templates.
        logger.debug("Loading the whitening matrix.")
        self.whitening_matrix = read_array('whitening_matrix')

        if op.exists(filenames['templates_unw']):
            logger.debug("Loading unwhitened templates.")
            templates_unw = read_array('templates_unw')
            templates_unw[np.isnan(templates_unw)] = 0
        else:
            logger.debug("Couldn't find unwhitened templates, computing them.")
            logger.debug("Inversing the whitening matrix %s.",
                         self.whitening_matrix.shape)
            wmi = np.linalg.inv(self.whitening_matrix)
            logger.debug("Unwhitening the templates %s.",
                         templates.shape)
            templates_unw = np.dot(np.ascontiguousarray(templates),
                                   np.ascontiguousarray(wmi))
            # Save the unwhitened templates.
            write_array('templates_unw.npy', templates_unw)

        n_templates, n_samples_templates, n_channels = templates.shape
        self.n_templates = n_templates

        logger.debug("Loading similar templates.")
        self.similar_templates = read_array('similar_templates')
        assert self.similar_templates.shape == (self.n_templates,
                                                self.n_templates)

        logger.debug("Loading channel mapping.")
        channel_mapping = read_array('channel_mapping').squeeze()
        channel_mapping = channel_mapping.astype(np.int32)
        assert channel_mapping.shape == (n_channels,)
        # Ensure that the mapping maps to valid columns in the dat file.
        assert np.all(channel_mapping <= self.n_channels_dat - 1)

        logger.debug("Loading channel positions.")
        channel_positions = read_array('channel_positions')
        assert channel_positions.shape == (n_channels, 2)

        if op.exists(filenames['features']):
            logger.debug("Loading features.")
            all_features = np.load(filenames['features'], mmap_mode='r')
            features_ind = read_array('features_ind').astype(np.int32)
            # Feature subset.
            if op.exists(filenames['features_spike_ids']):
                features_spike_ids = read_array('features_spike_ids') \
                    .astype(np.int32)
                assert len(features_spike_ids) == len(all_features)
                self.features_spike_ids = features_spike_ids
                ns = len(features_spike_ids)
            else:
                ns = self.n_spikes
                self.features_spike_ids = None

            assert all_features.ndim == 3
            n_loc_chan = all_features.shape[2]
            self.n_features_per_channel = all_features.shape[1]
            assert all_features.shape == (ns,
                                          self.n_features_per_channel,
                                          n_loc_chan,
                                          )
            # Check sparse features arrays shapes.
            assert features_ind.shape == (self.n_templates, n_loc_chan)
        else:
            all_features = None
            features_ind = None

        self.all_features = all_features
        self.features_ind = features_ind

        if op.exists(filenames['template_features']):
            logger.debug("Loading template features.")
            template_features = np.load(filenames['template_features'],
                                        mmap_mode='r')
            template_features_ind = read_array('template_features_ind'). \
                astype(np.int32)
            template_features_ind = template_features_ind.copy()
            n_sim_tem = template_features.shape[1]
            assert template_features.shape == (n_spikes, n_sim_tem)
            assert template_features_ind.shape == (n_templates, n_sim_tem)
        else:
            template_features = None
            template_features_ind = None

        self.template_features_ind = template_features_ind
        self.template_features = template_features

        self.n_channels = n_channels
        # Take dead channels into account.
        if traces is not None:
            # Find the scaling factor for the traces.
            scaling = 1. / self._data_lim(traces[:10000])
            traces = _concatenate_virtual_arrays([traces],
                                                 channel_mapping,
                                                 scaling=scaling,
                                                 )
        else:
            scaling = 1.

        # Amplitudes
        self.all_amplitudes = amplitudes
        self.amplitudes_lim = self.all_amplitudes.max()

        # Templates
        self.templates = templates
        self.templates_unw = templates_unw
        assert self.templates.shape == self.templates_unw.shape
        self.n_samples_templates = n_samples_templates
        self.n_samples_waveforms = n_samples_templates
        self.template_lim = np.max(np.abs(self.templates))

        self.duration = n_samples_t / float(self.sample_rate)

        self.spike_times = spike_samples / float(self.sample_rate)
        assert np.all(np.diff(self.spike_times) >= 0)

        self.cluster_ids = np.unique(self.spike_clusters)
        # n_clusters = len(self.cluster_ids)

        self.channel_positions = channel_positions
        self.all_traces = traces

        # Filter the waveforms.
        order = 3
        filter_margin = order * 3
        b_filter = bandpass_filter(rate=self.sample_rate,
                                   low=500.,
                                   high=self.sample_rate * .475,
                                   order=order)

        # Only filter the data for the waveforms if the traces
        # are not already filtered.
        if not getattr(self, 'hp_filtered', False):
            logger.debug("HP filtering the data for waveforms")

            def the_filter(x, axis=0):
                return apply_filter(x, b_filter, axis=axis)
        else:
            the_filter = None

        # Fetch waveforms from traces.
        nsw = self.n_samples_waveforms
        if traces is not None:
            waveforms = WaveformLoader(traces=traces,
                                       n_samples_waveforms=nsw,
                                       filter=the_filter,
                                       filter_margin=filter_margin,
                                       )
            waveforms = SpikeLoader(waveforms, spike_samples)
        else:
            waveforms = None
        self.all_waveforms = waveforms

        self.template_masks = get_masks(self.templates)
        self.all_masks = MaskLoader(self.template_masks, self.spike_templates)

        # Read the cluster groups.
        logger.debug("Loading the cluster groups.")
        self.cluster_groups = {}
        if op.exists(filenames['cluster_groups']):
            with open(filenames['cluster_groups'], 'r') as f:
                reader = csv.reader(f, delimiter='\t')
                # Skip the header.
                next(reader, None)
                for row in reader:
                    cluster, group = row
                    cluster = int(cluster)
                    self.cluster_groups[cluster] = group
        for cluster_id in self.cluster_ids:
            if cluster_id not in self.cluster_groups:
                self.cluster_groups[cluster_id] = None
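The cluster groups file parsed at the end is a tab-separated table with one
header row; an illustrative (hypothetical) content:

cluster_id	group
0	good
1	mua
2	noise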
Example #5
    def _init_data(self):
        if op.exists(self.dat_path):
            logger.debug("Loading traces at `%s`.", self.dat_path)
            traces = _dat_to_traces(
                self.dat_path,
                n_channels=self.n_channels_dat,
                dtype=self.dtype or np.int16,
                offset=self.offset,
            )
            n_samples_t, n_ch_dat = traces.shape
            assert n_ch_dat == self.n_channels_dat
        else:
            traces = None
            n_samples_t = 0

        logger.debug("Loading amplitudes.")
        amplitudes = read_array('amplitudes').squeeze()
        n_spikes, = amplitudes.shape
        self.n_spikes = n_spikes

        # Create spike_clusters if the file doesn't exist.
        if not op.exists(filenames['spike_clusters']):
            shutil.copy(filenames['spike_templates'],
                        filenames['spike_clusters'])
        logger.debug("Loading spike clusters.")
        spike_clusters = read_array('spike_clusters').squeeze()
        spike_clusters = spike_clusters.astype(np.int32)
        assert spike_clusters.shape == (n_spikes, )
        self.spike_clusters = spike_clusters

        logger.debug("Loading spike templates.")
        spike_templates = read_array('spike_templates').squeeze()
        spike_templates = spike_templates.astype(np.int32)
        assert spike_templates.shape == (n_spikes, )
        self.spike_templates = spike_templates

        logger.debug("Loading spike samples.")
        spike_samples = read_array('spike_samples').squeeze()
        assert spike_samples.shape == (n_spikes, )

        logger.debug("Loading templates.")
        templates = read_array('templates')
        templates[np.isnan(templates)] = 0
        # templates = np.transpose(templates, (2, 1, 0))

        # Unwhiten the templates.
        logger.debug("Loading the whitening matrix.")
        self.whitening_matrix = read_array('whitening_matrix')

        if op.exists(filenames['templates_unw']):
            logger.debug("Loading unwhitened templates.")
            templates_unw = read_array('templates_unw')
            templates_unw[np.isnan(templates_unw)] = 0
        else:
            logger.debug("Couldn't find unwhitened templates, computing them.")
            logger.debug("Inversing the whitening matrix %s.",
                         self.whitening_matrix.shape)
            wmi = np.linalg.inv(self.whitening_matrix)
            logger.debug("Unwhitening the templates %s.", templates.shape)
            templates_unw = np.dot(templates, wmi)

        n_templates, n_samples_templates, n_channels = templates.shape
        self.n_templates = n_templates

        logger.debug("Loading similar templates.")
        self.similar_templates = read_array('similar_templates')
        assert self.similar_templates.shape == (self.n_templates,
                                                self.n_templates)

        logger.debug("Loading channel mapping.")
        channel_mapping = read_array('channel_mapping').squeeze()
        channel_mapping = channel_mapping.astype(np.int32)
        assert channel_mapping.shape == (n_channels, )
        # Ensure that the mapping maps to valid columns in the dat file.
        assert np.all(channel_mapping <= self.n_channels_dat - 1)

        logger.debug("Loading channel positions.")
        channel_positions = read_array('channel_positions')
        assert channel_positions.shape == (n_channels, 2)

        if op.exists(filenames['features']):
            logger.debug("Loading features.")
            all_features = np.load(filenames['features'], mmap_mode='r')
            features_ind = read_array('features_ind').astype(np.int32)
            # Feature subset.
            if op.exists(filenames['features_spike_ids']):
                features_spike_ids = read_array('features_spike_ids') \
                    .astype(np.int32)
                assert len(features_spike_ids) == len(all_features)
                self.features_spike_ids = features_spike_ids
                ns = len(features_spike_ids)
            else:
                ns = self.n_spikes
                self.features_spike_ids = None

            assert all_features.ndim == 3
            n_loc_chan = all_features.shape[2]
            self.n_features_per_channel = all_features.shape[1]
            assert all_features.shape == (
                ns,
                self.n_features_per_channel,
                n_loc_chan,
            )
            # Check sparse features arrays shapes.
            assert features_ind.shape == (self.n_templates, n_loc_chan)
        else:
            all_features = None
            features_ind = None

        self.all_features = all_features
        self.features_ind = features_ind

        if op.exists(filenames['template_features']):
            logger.debug("Loading template features.")
            template_features = np.load(filenames['template_features'],
                                        mmap_mode='r')
            template_features_ind = read_array('template_features_ind'). \
                astype(np.int32)
            template_features_ind = template_features_ind.copy()
            n_sim_tem = template_features.shape[1]
            assert template_features.shape == (n_spikes, n_sim_tem)
            assert template_features_ind.shape == (n_templates, n_sim_tem)
        else:
            template_features = None
            template_features_ind = None

        self.template_features_ind = template_features_ind
        self.template_features = template_features

        self.n_channels = n_channels
        # Take dead channels into account.
        if traces is not None:
            # Find the scaling factor for the traces.
            scaling = 1. / self._data_lim(traces[:10000])
            traces = _concatenate_virtual_arrays(
                [traces],
                channel_mapping,
                scaling=scaling,
            )
        else:
            scaling = 1.

        # Amplitudes
        self.all_amplitudes = amplitudes
        self.amplitudes_lim = self.all_amplitudes.max()

        # Templates
        self.templates = templates
        self.templates_unw = templates_unw
        assert self.templates.shape == self.templates_unw.shape
        self.n_samples_templates = n_samples_templates
        self.n_samples_waveforms = n_samples_templates
        self.template_lim = np.max(np.abs(self.templates))

        self.duration = n_samples_t / float(self.sample_rate)

        self.spike_times = spike_samples / float(self.sample_rate)
        assert np.all(np.diff(self.spike_times) >= 0)

        self.cluster_ids = np.unique(self.spike_clusters)
        # n_clusters = len(self.cluster_ids)

        self.channel_positions = channel_positions
        self.all_traces = traces

        # Filter the waveforms.
        order = 3
        filter_margin = order * 3
        b_filter = bandpass_filter(rate=self.sample_rate,
                                   low=500.,
                                   high=self.sample_rate * .475,
                                   order=order)

        # Only filter the data for the waveforms if the traces
        # are not already filtered.
        if not getattr(self, 'hp_filtered', False):
            logger.debug("HP filtering the data for waveforms")

            def the_filter(x, axis=0):
                return apply_filter(x, b_filter, axis=axis)
        else:
            the_filter = None

        # Fetch waveforms from traces.
        nsw = self.n_samples_waveforms
        if traces is not None:
            waveforms = WaveformLoader(
                traces=traces,
                n_samples_waveforms=nsw,
                filter=the_filter,
                filter_margin=filter_margin,
            )
            waveforms = SpikeLoader(waveforms, spike_samples)
        else:
            waveforms = None
        self.all_waveforms = waveforms

        self.template_masks = get_masks(self.templates)
        self.all_masks = MaskLoader(self.template_masks, self.spike_templates)

        # Read the cluster groups.
        logger.debug("Loading the cluster groups.")
        self.cluster_groups = {}
        if op.exists(filenames['cluster_groups']):
            with open(filenames['cluster_groups'], 'r') as f:
                reader = csv.reader(f, delimiter='\t')
                # Skip the header.
                next(reader, None)
                for row in reader:
                    cluster, group = row
                    cluster = int(cluster)
                    self.cluster_groups[cluster] = group
        for cluster_id in self.cluster_ids:
            if cluster_id not in self.cluster_groups:
                self.cluster_groups[cluster_id] = None
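After _init_data has run, all_waveforms acts as a virtual
(n_spikes, n_samples, n_channels) array backed by the raw traces; a
hypothetical access pattern (the model name is assumed):

# Hypothetical: extract and filter spike 42's waveform on demand.
w = model.all_waveforms[42]
assert w.shape[-1] == model.n_channels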
Example #6
    def __init__(self,
                 basename,
                 filename,
                 N_t,
                 n_channels=None,
                 n_total_channels=None,
                 prb_file=None,
                 dtype=None,
                 sample_rate=None,
                 offset=0,
                 gain=0.01
                 ):

        self.n_features_per_channel = 3
        self.n_total_channels = n_total_channels
        extract_s_before = extract_s_after = int(N_t - 1) // 2
        self.extract_s_before = extract_s_before
        self.extract_s_after = extract_s_after

        # Set to True if your data is already pre-filtered (much quicker).
        filtered_datfile = False
        # Whether to build the waveform loader for feature extraction below.
        extract_features = True

        # Filtering parameters for PCA (ignored if filtered_datfile is True).
        filter_low = 500.
        filter_high = 0.95 * .5 * sample_rate
        filter_butter_order = 3

        self.basename = basename
        self.kwik_path = basename + '.kwik'
        self.dtype = dtype
        self.prb_file = prb_file
        self.probe = load_probe(prb_file)

        self.sample_rate = sample_rate
        self.filtered_datfile = filtered_datfile

        self._sd = SpikeDetekt(probe=self.probe,
                               n_features_per_channel=self.n_features_per_channel,
                               pca_n_waveforms_max=10000,
                               extract_s_before=extract_s_before,
                               extract_s_after=extract_s_after,
                               sample_rate=sample_rate,
                               )
        self.n_samples_w = extract_s_before + extract_s_after

        # A xxx.filtered.trunc file may be created if needed.
        self.file, self.traces_f = _read_filtered(filename,
                                                  n_channels=n_total_channels,
                                                  dtype=dtype,
                                                  )
        self.n_samples, self.n_total_channels = self.traces_f.shape
        self.n_channels = n_channels
        assert n_total_channels == self.n_total_channels
        info("Loaded traces: {}.".format(self.traces_f.shape))

        # Load spikes.
        self.spike_samples, self.spike_clusters = _read_spikes(basename)
        self.n_spikes = len(self.spike_samples)
        assert len(self.spike_clusters) == self.n_spikes
        info("Loaded {} spikes.".format(self.n_spikes))

        # Chunks when computing features.
        self.chunk_size = 2500
        self.n_chunks = int(np.ceil(self.n_spikes / self.chunk_size))

        # Load templates and masks.
        self.templates, self.template_masks = _read_templates(
            basename, self.probe, self.n_total_channels, self.n_channels)
        self.n_templates = len(self.templates)
        info("Loaded templates: {}.".format(self.templates.shape))

        # Load amplitudes.
        self.amplitudes = _read_amplitudes(basename, self.n_templates,
                                           self.n_spikes, self.spike_clusters)

        if extract_features:
            # The WaveformLoader fetches and filters waveforms
            # from the raw traces dynamically.
            n_samples = (extract_s_before, extract_s_after)
            b_filter = bandpass_filter(rate=self.sample_rate,
                                       low=filter_low,
                                       high=filter_high,
                                       order=filter_butter_order)

            def the_filter(x):
                return apply_filter(x, b_filter)

            filter_margin = filter_butter_order * 3

            # Gather the channels of all channel groups in the probe.
            nodes = []
            for key in self.probe['channel_groups'].keys():
                nodes += self.probe['channel_groups'][key]['channels']
            nodes = np.array(nodes, dtype=np.int32)

            if filtered_datfile:
                self._wl = WaveformLoader(traces=self.traces_f,
                                          n_samples=self.n_samples_w,
                                          dc_offset=offset,
                                          scale_factor=gain,
                                          channels=nodes,
                                          )
            else:
                self._wl = WaveformLoader(traces=self.traces_f,
                                          n_samples=self.n_samples_w,
                                          filter=the_filter,
                                          filter_margin=filter_margin,
                                          dc_offset=offset,
                                          scale_factor=gain,
                                          channels=nodes,
                                          )

            # A virtual (n_spikes, n_samples, n_channels) array that is
            # memmapped to the filtered data file.
            self.waveforms = SpikeLoader(self._wl, self.spike_samples)

            assert self.waveforms.shape == (self.n_spikes,
                                            self.n_samples_w,
                                            self.n_channels)
            assert self.template_masks.shape == (self.n_templates,
                                                 self.n_channels)
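The probe loaded with load_probe(prb_file) is expected to expose a
channel_groups mapping, as iterated above; a minimal hypothetical .prb layout:

# Hypothetical minimal .prb contents consistent with the
# probe['channel_groups'] loop above.
channel_groups = {
    0: {
        'channels': [0, 1, 2, 3],
        'graph': [(0, 1), (1, 2), (2, 3)],
        'geometry': {0: (0, 0), 1: (0, 25), 2: (0, 50), 3: (0, 75)},
    },
}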