Example #1
0
def test_spikecache_1():
    """SpikeCache.load_features_masks: fraction, cluster and empty selections."""
    nspikes = 100000
    nclusters = 100
    nchannels = 8
    # Seeded generator so that a failing run can be reproduced exactly.
    rng = np.random.RandomState(0)
    spike_clusters = rng.randint(size=nspikes, low=0, high=nclusters)

    sc = SpikeCache(spike_clusters=spike_clusters,
                    cache_fraction=.1,
                    features_masks=np.zeros((nspikes, 3*nchannels, 2)),
                    waveforms_raw=np.zeros((nspikes, 20, nchannels)),
                    waveforms_filtered=np.zeros((nspikes, 20, nchannels)))

    # Loading a fraction of .1 (with cache_fraction=.1) yields nspikes // 100
    # spikes, with one features/masks row per returned index.
    ind, fm = sc.load_features_masks(.1)

    assert len(ind) == nspikes // 100
    assert fm.shape[0] == nspikes // 100

    # Selecting clusters returns exactly the spikes belonging to them.
    ind, fm = sc.load_features_masks(clusters=[10, 20])

    assert len(ind) == fm.shape[0]
    # Integer indices: compare exactly (np.allclose is meant for floats);
    # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
    expected = np.nonzero(np.isin(spike_clusters, (10, 20)))[0]
    assert np.array_equal(ind, expected)

    # A cluster id that does not occur selects nothing.
    ind, fm = sc.load_features_masks(clusters=[1000])
    assert len(ind) == 0
    assert len(fm) == 0
Example #2
0
 def init_cache(self):
     """Create ``self._spikecache``, a SpikeCache over this object's data.

     The cache is built from the main clustering, the features/masks and
     the raw and filtered waveforms, with ``cache_fraction=1.``.
     """
     # TODO: handle multiple clusterings in the spike cache here
     # TODO: put the cache fraction in the parameters
     cache_options = dict(
         spike_clusters=self.clusters.main,
         features_masks=self.features_masks,
         waveforms_raw=self.waveforms_raw,
         waveforms_filtered=self.waveforms_filtered,
         cache_fraction=1.,
     )
     self._spikecache = SpikeCache(**cache_options)
Example #3
0
 def init_cache(self):
     """Build the SpikeCache used for features, masks and waveforms.

     Stores the result in ``self._spikecache``; ``cache_fraction`` is fixed
     to ``1.``.
     """
     # TODO: handle multiple clusterings in the spike cache here
     main_clusters = self.clusters.main
     features_masks = self.features_masks
     raw = self.waveforms_raw
     filtered = self.waveforms_filtered
     # TODO: put this value in the parameters
     fraction = 1.
     self._spikecache = SpikeCache(spike_clusters=main_clusters,
                                   features_masks=features_masks,
                                   waveforms_raw=raw,
                                   waveforms_filtered=filtered,
                                   cache_fraction=fraction)
Example #4
0
def test_spikecache_2():
    """SpikeCache.load_waveforms returns one waveform per selected index."""
    nspikes = 100000
    nclusters = 100
    nchannels = 8
    # Seeded generator so that a failing run can be reproduced exactly.
    rng = np.random.RandomState(0)
    spike_clusters = rng.randint(size=nspikes, low=0, high=nclusters)

    sc = SpikeCache(spike_clusters=spike_clusters,
                    cache_fraction=.1,
                    features_masks=np.zeros((nspikes, 3*nchannels, 2)),
                    waveforms_raw=np.zeros((nspikes, 20, nchannels)),
                    waveforms_filtered=np.zeros((nspikes, 20, nchannels)))

    # One cluster, count=10: at least 10 waveforms, matching the indices.
    ind, waveforms = sc.load_waveforms(clusters=[10], count=10)
    assert len(ind) == waveforms.shape[0]
    assert len(ind) >= 10

    # Two clusters, count=10: at least 20 waveforms, matching the indices.
    ind, waveforms = sc.load_waveforms(clusters=[10, 20], count=10)
    assert len(ind) == waveforms.shape[0]
    assert len(ind) >= 20

    # A cluster id that does not occur selects nothing.
    ind, waveforms = sc.load_waveforms(clusters=[1000], count=10)
    assert len(ind) == 0
Example #5
0
class Spikes(Node):
    """Spikes of one channel group.

    Exposes the per-spike arrays stored under the spikes node (time samples,
    fractional times, recording indices, cluster assignments), the
    features/masks array read from the KWX file, and waveforms extracted
    from the raw traces through a ``SpikeLoader``.
    """

    def __init__(self, files, node=None, root=None):
        super(Spikes, self).__init__(files, node, root=root)

        self.time_samples = self._node.time_samples
        self.time_fractional = self._node.time_fractional
        self.recording = self._node.recording
        self.clusters = Clusters(self._files,
                                 self._node.clusters,
                                 root=self._root)

        # Time samples shifted by the start offset of their recording
        # (start_time * sample_rate): one time axis across all recordings.
        self.concatenated_time_samples = (
            self._compute_concatenated_time_samples())

        self.channel_group_id = self._node._v_parent._v_name

        # Load features masks directly from KWX; without a KWX file this is
        # None and is replaced by a zero placeholder below.
        path = '/channel_groups/{}/features_masks'.format(
            self.channel_group_id)
        if files['kwx']:
            self.features_masks = files['kwx'].getNode(path)
        else:
            self.features_masks = None

        # Load waveforms directly from the raw data.
        traces = _read_traces(files)

        sd = self._root.application_data.spikedetekt
        # Number of samples extracted before / after each spike.
        b = sd._f_getAttr('extract_s_before')
        a = sd._f_getAttr('extract_s_after')

        order = sd._f_getAttr('filter_butter_order')
        rate = sd._f_getAttr('sample_rate')
        low = sd._f_getAttr('filter_low')
        if 'filter_high_factor' in sd._v_attrs:
            high = sd._f_getAttr('filter_high_factor') * rate
        else:
            # NOTE: old format
            high = sd._f_getAttr('filter_high')
        b_filter = bandpass_filter(rate=rate, low=low, high=high, order=order)

        debug("Enable waveform filter.")

        def the_filter(x, axis=0):
            return apply_filter(x, b_filter, axis=axis)

        filter_margin = order * 3

        channels = self._root.channel_groups._f_getChild(
            self.channel_group_id)._f_getAttr('channel_order')
        # Pass (before, after) rather than their sum: per phy's
        # WaveformLoader, an int n_samples is split evenly around the spike,
        # which misplaces the window when extract_s_before != extract_s_after.
        _waveform_loader = WaveformLoader(
            n_samples=(b, a),
            traces=traces,
            filter=the_filter,
            filter_margin=filter_margin,
            scale_factor=.01,
            channels=channels,
        )
        # _read_traces presumably returns all recordings concatenated
        # (TODO confirm), so spikes are located with the concatenated times
        # rather than with times relative to their own recording.
        self.waveforms_raw = SpikeLoader(_waveform_loader,
                                         self.concatenated_time_samples)
        self.waveforms_filtered = self.waveforms_raw

        nspikes = len(self.time_samples)

        if self.waveforms_raw is not None:
            self.nsamples, self.nchannels = self.waveforms_raw.shape[1:]

        if self.features_masks is None:
            # Placeholder: one zero feature/mask per spike.
            self.features_masks = np.zeros((nspikes, 1, 1), dtype=np.float32)

        if len(self.features_masks.shape) == 3:
            # col 0 is exposed as features, col 1 as masks.
            self.features = ArrayProxy(self.features_masks, col=0)
            self.masks = ArrayProxy(self.features_masks, col=1)
        elif len(self.features_masks.shape) == 2:
            # Features only; no masks available.
            self.features = self.features_masks
            self.masks = None
        self.nfeatures = self.features.shape[1]

    def _compute_concatenated_time_samples(self):
        """Return spike times offset by the start sample of their recording.

        Each recording's offset is ``int(start_time * sample_rate)`` (a
        missing/zero start_time counts as 0). When no per-spike recording
        index is stored, every spike is attributed to recording 0. Without
        any recording group, the relative times are returned unchanged.
        """
        t_rel = self.time_samples[:]
        recordings = self.recording[:]
        if len(recordings) == 0 and len(t_rel) > 0:
            recordings = np.zeros_like(t_rel)
        # Get list of recordings.
        recs = self._root.recordings
        recs = sorted([int(_._v_name) for _ in recs._f_listNodes()])
        # Get their start times.
        if not recs:
            return t_rel
        start_times = np.zeros(max(recs) + 1, dtype=np.uint64)
        for r in recs:
            recgrp = getattr(self._root.recordings, str(r))
            sample_rate = recgrp._f_getAttr('sample_rate')
            start_time = recgrp._f_getAttr('start_time') or 0.
            start_times[r] = int(start_time * sample_rate)
        return t_rel + start_times[recordings]

    def add(self, **kwargs):
        """Add a spike. Only `time_samples` is mandatory."""
        add_spikes(self._files,
                   channel_group_id=self.channel_group_id,
                   **kwargs)

    def init_cache(self):
        """Initialize the cache for the features & masks and waveforms."""
        self._spikecache = SpikeCache(
            # TODO: handle multiple clusterings in the spike cache here
            spike_clusters=self.clusters.main,
            features_masks=self.features_masks,
            waveforms_raw=self.waveforms_raw,
            waveforms_filtered=self.waveforms_filtered,
            # TODO: put this value in the parameters
            cache_fraction=1.,
        )

    def load_features_masks_bg(self, *args, **kwargs):
        """Delegate to the spike cache (requires init_cache())."""
        return self._spikecache.load_features_masks_bg(*args, **kwargs)

    def load_features_masks(self, *args, **kwargs):
        """Delegate to the spike cache (requires init_cache())."""
        return self._spikecache.load_features_masks(*args, **kwargs)

    def load_waveforms(self, *args, **kwargs):
        """Delegate to the spike cache (requires init_cache())."""
        return self._spikecache.load_waveforms(*args, **kwargs)

    def __getitem__(self, item):
        raise NotImplementedError("It is not possible to select entire spikes"
                                  "\n            yet.")

    def __len__(self):
        """Number of spikes."""
        return self.time_samples.shape[0]
Example #6
0
class Spikes(Node):
    """Spike-level data for a single channel group.

    Wraps the spikes node (time samples, fractional times, recording
    indices, cluster assignments) together with the large per-spike
    datasets (features/masks, raw and filtered waveforms) obtained through
    ``_get_child``.
    """

    def __init__(self, files, node=None, root=None):
        super(Spikes, self).__init__(files, node, root=root)

        spikes_node = self._node
        self.time_samples = spikes_node.time_samples
        self.time_fractional = spikes_node.time_fractional
        self.recording = spikes_node.recording
        self.clusters = Clusters(self._files, spikes_node.clusters,
                                 root=self._root)

        # Spike times on a single time axis spanning all recordings.
        self.concatenated_time_samples = \
            self._compute_concatenated_time_samples()

        self.channel_group_id = spikes_node._v_parent._v_name

        # Large datasets, possibly stored in external files.
        self.features_masks = self._get_child('features_masks')
        self.waveforms_raw = self._get_child('waveforms_raw')
        self.waveforms_filtered = self._get_child('waveforms_filtered')

        n_spikes = len(self.time_samples)

        if self.waveforms_raw is not None:
            self.nsamples, self.nchannels = self.waveforms_raw.shape[1:]

        if self.features_masks is None:
            # Placeholder: one zero feature/mask per spike.
            self.features_masks = np.zeros((n_spikes, 1, 1), dtype=np.float32)

        ndim = len(self.features_masks.shape)
        if ndim == 3:
            # col 0 is exposed as features, col 1 as masks.
            self.features = ArrayProxy(self.features_masks, col=0)
            self.masks = ArrayProxy(self.features_masks, col=1)
        elif ndim == 2:
            # Features only; masks unavailable.
            self.features = self.features_masks
            self.masks = None
        self.nfeatures = self.features.shape[1]

    def _compute_concatenated_time_samples(self):
        """Shift each spike time by its recording's start sample.

        The offset of recording ``r`` is ``int(start_time * sample_rate)``
        read from ``/recordings/<r>``; with no recordings the relative times
        are returned as-is.
        """
        relative = self.time_samples[:]
        rec_of_spike = self.recording[:]
        if len(relative) > 0 and len(rec_of_spike) == 0:
            # No per-spike recording index stored: use recording 0.
            rec_of_spike = np.zeros_like(relative)
        group = self._root.recordings
        rec_ids = sorted(int(child._v_name) for child in group._f_listNodes())
        if not rec_ids:
            return relative
        offsets = np.zeros(rec_ids[-1] + 1, dtype=np.uint64)
        for rec_id in rec_ids:
            rec_node = getattr(group, str(rec_id))
            rate = rec_node._f_getAttr('sample_rate')
            t0 = rec_node._f_getAttr('start_time') or 0.
            offsets[rec_id] = int(t0 * rate)
        return relative + offsets[rec_of_spike]

    def add(self, **kwargs):
        """Add a spike. Only `time_samples` is mandatory."""
        add_spikes(self._files,
                   channel_group_id=self.channel_group_id,
                   **kwargs)

    def init_cache(self):
        """Create the SpikeCache behind the load_* methods."""
        # TODO: handle multiple clusterings in the spike cache here
        # TODO: put the cache fraction in the parameters
        options = dict(spike_clusters=self.clusters.main,
                       features_masks=self.features_masks,
                       waveforms_raw=self.waveforms_raw,
                       waveforms_filtered=self.waveforms_filtered,
                       cache_fraction=1.)
        self._spikecache = SpikeCache(**options)

    def load_features_masks_bg(self, *args, **kwargs):
        """Forward to SpikeCache.load_features_masks_bg."""
        cache = self._spikecache
        return cache.load_features_masks_bg(*args, **kwargs)

    def load_features_masks(self, *args, **kwargs):
        """Forward to SpikeCache.load_features_masks."""
        cache = self._spikecache
        return cache.load_features_masks(*args, **kwargs)

    def load_waveforms(self, *args, **kwargs):
        """Forward to SpikeCache.load_waveforms."""
        cache = self._spikecache
        return cache.load_waveforms(*args, **kwargs)

    def __getitem__(self, item):
        message = "It is not possible to select entire spikes \n            yet."
        raise NotImplementedError(message)

    def __len__(self):
        """Number of spikes."""
        n_spikes = self.time_samples.shape[0]
        return n_spikes
Example #7
0
class Spikes(Node):
    """Spikes of one channel group, with features and on-the-fly waveforms.

    Features/masks are read from the KWX file (when present); waveforms are
    extracted from the raw traces via WaveformLoader/SpikeLoader using the
    concatenated spike times.
    """

    def __init__(self, files, node=None, root=None):
        super(Spikes, self).__init__(files, node, root=root)

        spikes_node = self._node
        self.time_samples = spikes_node.time_samples
        self.time_fractional = spikes_node.time_fractional
        self.recording = spikes_node.recording
        self.clusters = Clusters(self._files, spikes_node.clusters,
                                 root=self._root)

        # Spike times on a single time axis spanning all recordings.
        self.concatenated_time_samples = \
            self._compute_concatenated_time_samples()

        self.channel_group_id = spikes_node._v_parent._v_name
        group_id = self.channel_group_id

        # Features/masks come straight from the KWX file, if any.
        kwx_path = '/channel_groups/{}/features_masks'.format(group_id)
        self.features_masks = (files['kwx'].getNode(kwx_path)
                               if files['kwx'] else None)

        # Waveforms are extracted from the raw traces.
        traces = _read_traces(files)

        detekt = self._root.application_data.spikedetekt
        param = detekt._f_getAttr
        n_before = param('extract_s_before')
        n_after = param('extract_s_after')

        butter_order = param('filter_butter_order')
        sample_rate = param('sample_rate')
        f_low = param('filter_low')
        if 'filter_high_factor' in detekt._v_attrs:
            f_high = param('filter_high_factor') * sample_rate
        else:
            # NOTE: old format
            f_high = param('filter_high')
        coefficients = bandpass_filter(rate=sample_rate,
                                       low=f_low,
                                       high=f_high,
                                       order=butter_order)

        debug("Enable waveform filter.")

        def _filter_waveforms(x, axis=0):
            return apply_filter(x, coefficients, axis=axis)

        channel_order = self._root.channel_groups._f_getChild(
            group_id)._f_getAttr('channel_order')
        loader = WaveformLoader(n_samples=(n_before, n_after),
                                traces=traces,
                                filter=_filter_waveforms,
                                filter_margin=butter_order * 3,
                                scale_factor=.01,
                                channels=channel_order)
        self.waveforms_raw = SpikeLoader(loader,
                                         self.concatenated_time_samples)
        self.waveforms_filtered = self.waveforms_raw

        n_spikes = len(self.time_samples)

        if self.waveforms_raw is not None:
            self.nsamples, self.nchannels = self.waveforms_raw.shape[1:]

        if self.features_masks is None:
            # Placeholder: one zero feature/mask per spike.
            self.features_masks = np.zeros((n_spikes, 1, 1), dtype=np.float32)

        ndim = len(self.features_masks.shape)
        if ndim == 3:
            # col 0 is exposed as features, col 1 as masks.
            self.features = ArrayProxy(self.features_masks, col=0)
            self.masks = ArrayProxy(self.features_masks, col=1)
        elif ndim == 2:
            # Features only; masks unavailable.
            self.features = self.features_masks
            self.masks = None
        self.nfeatures = self.features.shape[1]

    def _compute_concatenated_time_samples(self):
        """Shift each spike time by its recording's start sample.

        The offset of recording ``r`` is ``int(start_time * sample_rate)``;
        with no recordings the relative times are returned unchanged.
        """
        relative = self.time_samples[:]
        rec_of_spike = self.recording[:]
        if len(relative) > 0 and len(rec_of_spike) == 0:
            # No per-spike recording index stored: use recording 0.
            rec_of_spike = np.zeros_like(relative)
        group = self._root.recordings
        rec_ids = sorted(int(child._v_name) for child in group._f_listNodes())
        if not rec_ids:
            return relative
        offsets = np.zeros(rec_ids[-1] + 1, dtype=np.uint64)
        for rec_id in rec_ids:
            rec_node = getattr(group, str(rec_id))
            rate = rec_node._f_getAttr('sample_rate')
            t0 = rec_node._f_getAttr('start_time') or 0.
            offsets[rec_id] = int(t0 * rate)
        return relative + offsets[rec_of_spike]

    def add(self, **kwargs):
        """Add a spike. Only `time_samples` is mandatory."""
        add_spikes(self._files,
                   channel_group_id=self.channel_group_id,
                   **kwargs)

    def init_cache(self):
        """Create the SpikeCache behind the load_* methods."""
        # TODO: handle multiple clusterings in the spike cache here
        # TODO: put the cache fraction in the parameters
        options = dict(spike_clusters=self.clusters.main,
                       features_masks=self.features_masks,
                       waveforms_raw=self.waveforms_raw,
                       waveforms_filtered=self.waveforms_filtered,
                       cache_fraction=1.)
        self._spikecache = SpikeCache(**options)

    def load_features_masks_bg(self, *args, **kwargs):
        """Forward to SpikeCache.load_features_masks_bg."""
        cache = self._spikecache
        return cache.load_features_masks_bg(*args, **kwargs)

    def load_features_masks(self, *args, **kwargs):
        """Forward to SpikeCache.load_features_masks."""
        cache = self._spikecache
        return cache.load_features_masks(*args, **kwargs)

    def load_waveforms(self, *args, **kwargs):
        """Forward to SpikeCache.load_waveforms."""
        cache = self._spikecache
        return cache.load_waveforms(*args, **kwargs)

    def __getitem__(self, item):
        message = "It is not possible to select entire spikes\n            yet."
        raise NotImplementedError(message)

    def __len__(self):
        """Number of spikes."""
        n_spikes = self.time_samples.shape[0]
        return n_spikes
Example #8
0
class Spikes(Node):
    """Spikes of one channel group.

    Exposes the per-spike arrays stored under the spikes node and the
    features/masks array read from the KWX file. Waveform loading from the
    raw traces is not implemented yet: ``waveforms_raw`` and
    ``waveforms_filtered`` are None.
    """

    def __init__(self, files, node=None, root=None):
        super(Spikes, self).__init__(files, node, root=root)

        self.time_samples = self._node.time_samples
        self.time_fractional = self._node.time_fractional
        self.recording = self._node.recording
        self.clusters = Clusters(self._files, self._node.clusters, root=self._root)

        # Time samples shifted by the start offset of their recording.
        self.concatenated_time_samples = self._compute_concatenated_time_samples()

        self.channel_group_id = self._node._v_parent._v_name

        # Load features masks directly from KWX. Without a KWX file this is
        # None and is replaced by a zero placeholder below (instead of
        # crashing on None.getNode).
        path = '/channel_groups/{}/features_masks'.format(self.channel_group_id)
        if files['kwx']:
            self.features_masks = files['kwx'].getNode(path)
        else:
            self.features_masks = None

        # TODO: load waveforms from the raw traces (WaveformLoader /
        # SpikeLoader with extract_s_before / extract_s_after). Until then no
        # waveforms are available; the attributes are set explicitly because
        # they are read below and by init_cache().
        self.waveforms_raw = None
        self.waveforms_filtered = None

        nspikes = len(self.time_samples)

        if self.waveforms_raw is not None:
            self.nsamples, self.nchannels = self.waveforms_raw.shape[1:]

        if self.features_masks is None:
            # Placeholder: one zero feature/mask per spike.
            self.features_masks = np.zeros((nspikes, 1, 1), dtype=np.float32)

        if len(self.features_masks.shape) == 3:
            # col 0 is exposed as features, col 1 as masks.
            self.features = ArrayProxy(self.features_masks, col=0)
            self.masks = ArrayProxy(self.features_masks, col=1)
        elif len(self.features_masks.shape) == 2:
            # Features only; no masks available.
            self.features = self.features_masks
            self.masks = None
        self.nfeatures = self.features.shape[1]

    def _compute_concatenated_time_samples(self):
        """Return spike times offset by the start sample of their recording.

        Each recording's offset is ``int(start_time * sample_rate)`` (a
        missing/zero start_time counts as 0). When no per-spike recording
        index is stored, every spike is attributed to recording 0. Without
        any recording group, the relative times are returned unchanged.
        """
        t_rel = self.time_samples[:]
        recordings = self.recording[:]
        if len(recordings) == 0 and len(t_rel) > 0:
            recordings = np.zeros_like(t_rel)
        # Get list of recordings.
        recs = self._root.recordings
        recs = sorted([int(_._v_name) for _ in recs._f_listNodes()])
        # Get their start times.
        if not recs:
            return t_rel
        start_times = np.zeros(max(recs) + 1, dtype=np.uint64)
        for r in recs:
            recgrp = getattr(self._root.recordings, str(r))
            sample_rate = recgrp._f_getAttr('sample_rate')
            start_time = recgrp._f_getAttr('start_time') or 0.
            start_times[r] = int(start_time * sample_rate)
        return t_rel + start_times[recordings]

    def add(self, **kwargs):
        """Add a spike. Only `time_samples` is mandatory."""
        add_spikes(self._files, channel_group_id=self.channel_group_id, **kwargs)

    def init_cache(self):
        """Initialize the cache for the features & masks."""
        self._spikecache = SpikeCache(
            # TODO: handle multiple clusterings in the spike cache here
            spike_clusters=self.clusters.main,
            features_masks=self.features_masks,
            waveforms_raw=self.waveforms_raw,
            waveforms_filtered=self.waveforms_filtered,
            # TODO: put this value in the parameters
            cache_fraction=1.,)

    def load_features_masks_bg(self, *args, **kwargs):
        """Delegate to the spike cache (requires init_cache())."""
        return self._spikecache.load_features_masks_bg(*args, **kwargs)

    def load_features_masks(self, *args, **kwargs):
        """Delegate to the spike cache (requires init_cache())."""
        return self._spikecache.load_features_masks(*args, **kwargs)

    def load_waveforms(self, *args, **kwargs):
        """Delegate to the spike cache (requires init_cache())."""
        return self._spikecache.load_waveforms(*args, **kwargs)

    def __getitem__(self, item):
        raise NotImplementedError("It is not possible to select entire spikes"
                                  "\n            yet.")

    def __len__(self):
        """Number of spikes."""
        return self.time_samples.shape[0]
Example #9
0
class Spikes(Node):
    """Per-spike data of one channel group.

    Holds the spike times (relative and concatenated across recordings),
    recording indices, cluster assignments and the large per-spike datasets
    fetched with ``_get_child``: features/masks and raw/filtered waveforms.
    """

    # Names of the large per-spike datasets, loaded in this order.
    _LARGE_DATASETS = ('features_masks', 'waveforms_raw', 'waveforms_filtered')

    def __init__(self, files, node=None, root=None):
        super(Spikes, self).__init__(files, node, root=root)

        src = self._node
        self.time_samples = src.time_samples
        self.time_fractional = src.time_fractional
        self.recording = src.recording
        self.clusters = Clusters(self._files, src.clusters, root=self._root)

        # One time axis across every recording.
        self.concatenated_time_samples = (
            self._compute_concatenated_time_samples())

        self.channel_group_id = src._v_parent._v_name

        # Large datasets, that may live in external files.
        for dataset_name in self._LARGE_DATASETS:
            setattr(self, dataset_name, self._get_child(dataset_name))

        count = len(self.time_samples)

        if self.waveforms_raw is not None:
            self.nsamples, self.nchannels = self.waveforms_raw.shape[1:]

        if self.features_masks is None:
            # Placeholder: one zero feature/mask per spike.
            self.features_masks = np.zeros((count, 1, 1), dtype=np.float32)

        rank = len(self.features_masks.shape)
        if rank == 2:
            # Features only; masks unavailable.
            self.features, self.masks = self.features_masks, None
        elif rank == 3:
            # col 0 is exposed as features, col 1 as masks.
            self.features = ArrayProxy(self.features_masks, col=0)
            self.masks = ArrayProxy(self.features_masks, col=1)
        self.nfeatures = self.features.shape[1]

    def _compute_concatenated_time_samples(self):
        """Add each recording's start sample to the spike times it contains.

        Recording ``r`` starts at ``int(start_time * sample_rate)``; if the
        file has no recordings, the relative times are returned unchanged.
        """
        times = self.time_samples[:]
        rec_index = self.recording[:]
        if not len(rec_index) and len(times):
            # No per-spike recording index stored: use recording 0.
            rec_index = np.zeros_like(times)
        recordings_group = self._root.recordings
        ids = sorted(int(n._v_name) for n in recordings_group._f_listNodes())
        if len(ids) == 0:
            return times
        starts = np.zeros(max(ids) + 1, dtype=np.uint64)
        for rec_id in ids:
            grp = getattr(recordings_group, str(rec_id))
            rate = grp._f_getAttr('sample_rate')
            begin = grp._f_getAttr('start_time') or 0.
            starts[rec_id] = int(begin * rate)
        return times + starts[rec_index]

    def add(self, **kwargs):
        """Add a spike. Only `time_samples` is mandatory."""
        add_spikes(self._files, channel_group_id=self.channel_group_id,
                   **kwargs)

    def init_cache(self):
        """Create ``self._spikecache`` for features, masks and waveforms."""
        # TODO: handle multiple clusterings in the spike cache here
        main_clusters = self.clusters.main
        fm = self.features_masks
        raw = self.waveforms_raw
        filtered = self.waveforms_filtered
        # TODO: put this value in the parameters
        self._spikecache = SpikeCache(spike_clusters=main_clusters,
                                      features_masks=fm,
                                      waveforms_raw=raw,
                                      waveforms_filtered=filtered,
                                      cache_fraction=1.)

    def load_features_masks_bg(self, *args, **kwargs):
        """See SpikeCache.load_features_masks_bg."""
        loader = self._spikecache.load_features_masks_bg
        return loader(*args, **kwargs)

    def load_features_masks(self, *args, **kwargs):
        """See SpikeCache.load_features_masks."""
        loader = self._spikecache.load_features_masks
        return loader(*args, **kwargs)

    def load_waveforms(self, *args, **kwargs):
        """See SpikeCache.load_waveforms."""
        loader = self._spikecache.load_waveforms
        return loader(*args, **kwargs)

    def __getitem__(self, item):
        raise NotImplementedError(
            "It is not possible to select entire spikes "
            "\n            yet.")

    def __len__(self):
        """Number of spikes."""
        return self.time_samples.shape[0]
Example #10
0
class Spikes(Node):
    """Spikes of one channel group.

    Exposes the per-spike arrays stored under the spikes node and the
    features/masks array read from the KWX file. Waveform loading from the
    raw traces is not implemented yet: ``waveforms_raw`` and
    ``waveforms_filtered`` are None.
    """

    def __init__(self, files, node=None, root=None):
        super(Spikes, self).__init__(files, node, root=root)

        self.time_samples = self._node.time_samples
        self.time_fractional = self._node.time_fractional
        self.recording = self._node.recording
        self.clusters = Clusters(self._files,
                                 self._node.clusters,
                                 root=self._root)

        # Time samples shifted by the start offset of their recording.
        self.concatenated_time_samples = (
            self._compute_concatenated_time_samples())

        self.channel_group_id = self._node._v_parent._v_name

        # Load features masks directly from KWX. Without a KWX file this is
        # None and is replaced by a zero placeholder below (instead of
        # crashing on None.getNode).
        path = '/channel_groups/{}/features_masks'.format(
            self.channel_group_id)
        if files['kwx']:
            self.features_masks = files['kwx'].getNode(path)
        else:
            self.features_masks = None

        # TODO: load waveforms from the raw traces (WaveformLoader /
        # SpikeLoader with extract_s_before / extract_s_after). Until then no
        # waveforms are available; the attributes are set explicitly because
        # they are read below and by init_cache().
        self.waveforms_raw = None
        self.waveforms_filtered = None

        nspikes = len(self.time_samples)

        if self.waveforms_raw is not None:
            self.nsamples, self.nchannels = self.waveforms_raw.shape[1:]

        if self.features_masks is None:
            # Placeholder: one zero feature/mask per spike.
            self.features_masks = np.zeros((nspikes, 1, 1), dtype=np.float32)

        if len(self.features_masks.shape) == 3:
            # col 0 is exposed as features, col 1 as masks.
            self.features = ArrayProxy(self.features_masks, col=0)
            self.masks = ArrayProxy(self.features_masks, col=1)
        elif len(self.features_masks.shape) == 2:
            # Features only; no masks available.
            self.features = self.features_masks
            self.masks = None
        self.nfeatures = self.features.shape[1]

    def _compute_concatenated_time_samples(self):
        """Return spike times offset by the start sample of their recording.

        Each recording's offset is ``int(start_time * sample_rate)`` (a
        missing/zero start_time counts as 0). When no per-spike recording
        index is stored, every spike is attributed to recording 0. Without
        any recording group, the relative times are returned unchanged.
        """
        t_rel = self.time_samples[:]
        recordings = self.recording[:]
        if len(recordings) == 0 and len(t_rel) > 0:
            recordings = np.zeros_like(t_rel)
        # Get list of recordings.
        recs = self._root.recordings
        recs = sorted([int(_._v_name) for _ in recs._f_listNodes()])
        # Get their start times.
        if not recs:
            return t_rel
        start_times = np.zeros(max(recs) + 1, dtype=np.uint64)
        for r in recs:
            recgrp = getattr(self._root.recordings, str(r))
            sample_rate = recgrp._f_getAttr('sample_rate')
            start_time = recgrp._f_getAttr('start_time') or 0.
            start_times[r] = int(start_time * sample_rate)
        return t_rel + start_times[recordings]

    def add(self, **kwargs):
        """Add a spike. Only `time_samples` is mandatory."""
        add_spikes(self._files,
                   channel_group_id=self.channel_group_id,
                   **kwargs)

    def init_cache(self):
        """Initialize the cache for the features & masks."""
        self._spikecache = SpikeCache(
            # TODO: handle multiple clusterings in the spike cache here
            spike_clusters=self.clusters.main,
            features_masks=self.features_masks,
            waveforms_raw=self.waveforms_raw,
            waveforms_filtered=self.waveforms_filtered,
            # TODO: put this value in the parameters
            cache_fraction=1.,
        )

    def load_features_masks_bg(self, *args, **kwargs):
        """Delegate to the spike cache (requires init_cache())."""
        return self._spikecache.load_features_masks_bg(*args, **kwargs)

    def load_features_masks(self, *args, **kwargs):
        """Delegate to the spike cache (requires init_cache())."""
        return self._spikecache.load_features_masks(*args, **kwargs)

    def load_waveforms(self, *args, **kwargs):
        """Delegate to the spike cache (requires init_cache())."""
        return self._spikecache.load_waveforms(*args, **kwargs)

    def __getitem__(self, item):
        raise NotImplementedError("It is not possible to select entire spikes"
                                  "\n            yet.")

    def __len__(self):
        """Number of spikes."""
        return self.time_samples.shape[0]