def waveform_loader(model, filter_wave):
    """Create a waveform loader."""
    n_samples = (model._metadata['extract_s_before'],
                 model._metadata['extract_s_after'])
    order = model._metadata['filter_butter_order']
    rate = model._metadata['sample_rate']
    low = model._metadata['filter_low']
    high = model._metadata['filter_high_factor'] * rate
    b_filter = bandpass_filter(rate=rate, low=low, high=high, order=order)
    if filter_wave:
        def filter(x):
            return apply_filter(x, b_filter)
        filter_margin = order * 3
    else:
        filter = None
        filter_margin = 0
    dc_offset = model._metadata.get('waveform_dc_offset', None)
    scale_factor = model._metadata.get('waveform_scale_factor', None)
    return WaveformLoader(n_samples=n_samples,
                          filter=filter,
                          filter_margin=filter_margin,
                          dc_offset=dc_offset,
                          scale_factor=scale_factor,
                          )
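# A hypothetical usage sketch (not from the source): `model` is assumed to
# expose a `_metadata` dict with the keys read above (extract_s_before,
# extract_s_after, filter_butter_order, sample_rate, filter_low,
# filter_high_factor, plus the optional waveform_dc_offset and
# waveform_scale_factor entries).
#
# loader = waveform_loader(model, filter_wave=True)       # band-pass filtered
# raw_loader = waveform_loader(model, filter_wave=False)  # unfiltered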
def __init__(self,
             traces=None,
             sample_rate=None,
             spike_samples=None,
             masks=None,
             mask_threshold=None,
             filter_order=None,
             n_samples_waveforms=None,
             ):
    # Traces.
    if traces is not None:
        self.traces = traces
        self.n_samples_trace, self.n_channels = traces.shape
    else:
        self._traces = None
        self.n_samples_trace = self.n_channels = 0

    assert spike_samples is not None
    self._spike_samples = spike_samples
    self.n_spikes = len(spike_samples)

    self._masks = masks
    if masks is not None:
        assert self._masks.shape == (self.n_spikes, self.n_channels)
    self._mask_threshold = mask_threshold

    # Define filter.
    if filter_order:
        filter_margin = filter_order * 3
        b_filter = bandpass_filter(rate=sample_rate,
                                   low=500.,
                                   high=sample_rate * .475,
                                   order=filter_order,
                                   )
        self._filter = lambda x, axis=0: apply_filter(x, b_filter, axis=axis)
    else:
        filter_margin = 0
        self._filter = lambda x, axis=0: x

    # Number of samples to return, can be an int or a
    # tuple (before, after).
    assert n_samples_waveforms is not None
    self.n_samples_before_after = _before_after(n_samples_waveforms)
    self.n_samples_waveforms = sum(self.n_samples_before_after)
    # Number of additional samples to use for filtering.
    self._filter_margin = _before_after(filter_margin)
    # Number of samples in the extracted raw data chunk.
    self._n_samples_extract = (self.n_samples_waveforms +
                               sum(self._filter_margin))

    self.dtype = np.float32
    self.shape = (self.n_spikes, self._n_samples_extract, self.n_channels)
    self.ndim = 3
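# The constructor above (and the variant that follows) relies on a
# `_before_after` helper that is not shown in this section. A minimal sketch
# of its assumed behavior: normalize an int or a (before, after) pair into a
# pair of ints whose sum is the total number of samples.
def _before_after(n_samples):
    """Assumed helper: turn `n_samples` into a (before, after) tuple."""
    if not isinstance(n_samples, (tuple, list)):
        # An int is split roughly in half around the spike sample.
        before = n_samples // 2
        after = n_samples - before
    else:
        assert len(n_samples) == 2
        before, after = n_samples
    return int(before), int(after)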
def __init__(self,
             traces=None,
             sample_rate=None,
             spike_samples=None,
             filter_order=None,
             n_samples_waveforms=None,
             ):
    # Traces.
    if traces is not None:
        self.traces = traces
        self.n_samples_trace, self.n_channels = traces.shape
    else:
        self._traces = None
        self.n_samples_trace = self.n_channels = 0

    assert spike_samples is not None
    self._spike_samples = spike_samples
    self.n_spikes = len(spike_samples)

    # Define filter.
    if filter_order:
        filter_margin = filter_order * 3
        b_filter = bandpass_filter(rate=sample_rate,
                                   low=500.,
                                   high=sample_rate * .475,
                                   order=filter_order,
                                   )
        self._filter = lambda x, axis=0: apply_filter(x, b_filter, axis=axis)
    else:
        filter_margin = 0
        self._filter = lambda x, axis=0: x

    # Number of samples to return, can be an int or a
    # tuple (before, after).
    assert n_samples_waveforms is not None
    self.n_samples_before_after = _before_after(n_samples_waveforms)
    self.n_samples_waveforms = sum(self.n_samples_before_after)
    # Number of additional samples to use for filtering.
    self._filter_margin = _before_after(filter_margin)
    # Number of samples in the extracted raw data chunk.
    self._n_samples_extract = (self.n_samples_waveforms +
                               sum(self._filter_margin))

    self.dtype = np.float32
    self.shape = (self.n_spikes, self._n_samples_extract, self.n_channels)
    self.ndim = 3
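# A minimal construction sketch with synthetic data, assuming the constructor
# above belongs to a class named `WaveformLoader` and that `_before_after`
# behaves as sketched earlier; the argument values are arbitrary.
import numpy as np

traces = np.random.randn(20000, 32).astype(np.float32)  # (n_samples, n_channels)
spike_samples = np.sort(np.random.randint(100, 19900, size=50))

loader = WaveformLoader(traces=traces,
                        sample_rate=20000.,
                        spike_samples=spike_samples,
                        filter_order=3,
                        n_samples_waveforms=(20, 40),
                        )
# shape is (n_spikes, n_samples_extract, n_channels); the middle dimension
# includes the extra filter margin (order * 3 = 9 samples) around the spike.
assert loader.shape == (50, 20 + 40 + 9, 32)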
def _init_data(self):
    if op.exists(self.dat_path):
        logger.debug("Loading traces at `%s`.", self.dat_path)
        traces = _dat_to_traces(self.dat_path,
                                n_channels=self.n_channels_dat,
                                dtype=self.dtype or np.int16,
                                offset=self.offset,
                                )
        n_samples_t, _ = traces.shape
        assert _ == self.n_channels_dat
    else:
        traces = None
        n_samples_t = 0

    logger.debug("Loading amplitudes.")
    amplitudes = read_array('amplitudes').squeeze()
    n_spikes, = amplitudes.shape
    self.n_spikes = n_spikes

    # Create spike_clusters if the file doesn't exist.
    if not op.exists(filenames['spike_clusters']):
        shutil.copy(filenames['spike_templates'],
                    filenames['spike_clusters'])
    logger.debug("Loading %d spike clusters.", self.n_spikes)
    spike_clusters = read_array('spike_clusters').squeeze()
    spike_clusters = spike_clusters.astype(np.int32)
    assert spike_clusters.shape == (n_spikes,)
    self.spike_clusters = spike_clusters

    logger.debug("Loading spike templates.")
    spike_templates = read_array('spike_templates').squeeze()
    spike_templates = spike_templates.astype(np.int32)
    assert spike_templates.shape == (n_spikes,)
    self.spike_templates = spike_templates

    logger.debug("Loading spike samples.")
    spike_samples = read_array('spike_samples').squeeze()
    assert spike_samples.shape == (n_spikes,)

    logger.debug("Loading templates.")
    templates = read_array('templates')
    templates[np.isnan(templates)] = 0
    # templates = np.transpose(templates, (2, 1, 0))

    # Unwhiten the templates.
    logger.debug("Loading the whitening matrix.")
    self.whitening_matrix = read_array('whitening_matrix')

    if op.exists(filenames['templates_unw']):
        logger.debug("Loading unwhitened templates.")
        templates_unw = read_array('templates_unw')
        templates_unw[np.isnan(templates_unw)] = 0
    else:
        logger.debug("Couldn't find unwhitened templates, computing them.")
        logger.debug("Inverting the whitening matrix %s.",
                     self.whitening_matrix.shape)
        wmi = np.linalg.inv(self.whitening_matrix)
        logger.debug("Unwhitening the templates %s.", templates.shape)
        templates_unw = np.dot(np.ascontiguousarray(templates),
                               np.ascontiguousarray(wmi))
        # Save the unwhitened templates.
        write_array('templates_unw.npy', templates_unw)

    n_templates, n_samples_templates, n_channels = templates.shape
    self.n_templates = n_templates

    logger.debug("Loading similar templates.")
    self.similar_templates = read_array('similar_templates')
    assert self.similar_templates.shape == (self.n_templates,
                                            self.n_templates)

    logger.debug("Loading channel mapping.")
    channel_mapping = read_array('channel_mapping').squeeze()
    channel_mapping = channel_mapping.astype(np.int32)
    assert channel_mapping.shape == (n_channels,)
    # Ensure that the mapping maps to valid columns in the dat file.
    assert np.all(channel_mapping <= self.n_channels_dat - 1)

    logger.debug("Loading channel positions.")
    channel_positions = read_array('channel_positions')
    assert channel_positions.shape == (n_channels, 2)

    if op.exists(filenames['features']):
        logger.debug("Loading features.")
        all_features = np.load(filenames['features'], mmap_mode='r')
        features_ind = read_array('features_ind').astype(np.int32)

        # Feature subset.
        if op.exists(filenames['features_spike_ids']):
            features_spike_ids = read_array('features_spike_ids') \
                .astype(np.int32)
            assert len(features_spike_ids) == len(all_features)
            self.features_spike_ids = features_spike_ids
            ns = len(features_spike_ids)
        else:
            ns = self.n_spikes
            self.features_spike_ids = None

        assert all_features.ndim == 3
        n_loc_chan = all_features.shape[2]
        self.n_features_per_channel = all_features.shape[1]
        assert all_features.shape == (ns,
                                      self.n_features_per_channel,
                                      n_loc_chan,
                                      )
        # Check sparse features arrays shapes.
        assert features_ind.shape == (self.n_templates, n_loc_chan)
    else:
        all_features = None
        features_ind = None

    self.all_features = all_features
    self.features_ind = features_ind

    if op.exists(filenames['template_features']):
        logger.debug("Loading template features.")
        template_features = np.load(filenames['template_features'],
                                    mmap_mode='r')
        template_features_ind = read_array('template_features_ind') \
            .astype(np.int32)
        template_features_ind = template_features_ind.copy()
        n_sim_tem = template_features.shape[1]
        assert template_features.shape == (n_spikes, n_sim_tem)
        assert template_features_ind.shape == (n_templates, n_sim_tem)
    else:
        template_features = None
        template_features_ind = None

    self.template_features_ind = template_features_ind
    self.template_features = template_features

    self.n_channels = n_channels
    # Take dead channels into account.
    if traces is not None:
        # Find the scaling factor for the traces.
        scaling = 1. / self._data_lim(traces[:10000])
        traces = _concatenate_virtual_arrays([traces],
                                             channel_mapping,
                                             scaling=scaling,
                                             )
    else:
        scaling = 1.

    # Amplitudes.
    self.all_amplitudes = amplitudes
    self.amplitudes_lim = self.all_amplitudes.max()

    # Templates.
    self.templates = templates
    self.templates_unw = templates_unw
    assert self.templates.shape == self.templates_unw.shape
    self.n_samples_templates = n_samples_templates
    self.n_samples_waveforms = n_samples_templates
    self.template_lim = np.max(np.abs(self.templates))

    self.duration = n_samples_t / float(self.sample_rate)
    self.spike_times = spike_samples / float(self.sample_rate)
    assert np.all(np.diff(self.spike_times) >= 0)

    self.cluster_ids = np.unique(self.spike_clusters)
    # n_clusters = len(self.cluster_ids)

    self.channel_positions = channel_positions
    self.all_traces = traces

    # Filter the waveforms.
    order = 3
    filter_margin = order * 3
    b_filter = bandpass_filter(rate=self.sample_rate,
                               low=500.,
                               high=self.sample_rate * .475,
                               order=order)

    # Only filter the data for the waveforms if the traces
    # are not already filtered.
    if not getattr(self, 'hp_filtered', False):
        logger.debug("HP filtering the data for waveforms")

        def the_filter(x, axis=0):
            return apply_filter(x, b_filter, axis=axis)
    else:
        the_filter = None

    # Fetch waveforms from traces.
    nsw = self.n_samples_waveforms
    if traces is not None:
        waveforms = WaveformLoader(traces=traces,
                                   n_samples_waveforms=nsw,
                                   filter=the_filter,
                                   filter_margin=filter_margin,
                                   )
        waveforms = SpikeLoader(waveforms, spike_samples)
    else:
        waveforms = None
    self.all_waveforms = waveforms

    self.template_masks = get_masks(self.templates)
    self.all_masks = MaskLoader(self.template_masks, self.spike_templates)

    # Read the cluster groups.
    logger.debug("Loading the cluster groups.")
    self.cluster_groups = {}
    if op.exists(filenames['cluster_groups']):
        with open(filenames['cluster_groups'], 'r') as f:
            reader = csv.reader(f, delimiter='\t')
            # Skip the header.
            next(reader, None)
            for row in reader:
                cluster, group = row
                cluster = int(cluster)
                self.cluster_groups[cluster] = group
    for cluster_id in self.cluster_ids:
        if cluster_id not in self.cluster_groups:
            self.cluster_groups[cluster_id] = None
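# The unwhitening step above right-multiplies every (n_samples, n_channels)
# template by the inverse of the (n_channels, n_channels) whitening matrix;
# np.dot broadcasts this over the template axis. A small self-contained check
# of that operation, with made-up shapes:
import numpy as np

n_templates, n_samples, n_channels = 4, 61, 8
templates = np.random.randn(n_templates, n_samples, n_channels)
W = np.eye(n_channels) + 0.01 * np.random.randn(n_channels, n_channels)
templates_unw = np.dot(templates, np.linalg.inv(W))
assert templates_unw.shape == templates.shape
# Re-whitening recovers the original templates (up to float error).
assert np.allclose(np.dot(templates_unw, W), templates)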
def _init_data(self):
    if op.exists(self.dat_path):
        logger.debug("Loading traces at `%s`.", self.dat_path)
        traces = _dat_to_traces(
            self.dat_path,
            n_channels=self.n_channels_dat,
            dtype=self.dtype or np.int16,
            offset=self.offset,
        )
        n_samples_t, _ = traces.shape
        assert _ == self.n_channels_dat
    else:
        traces = None
        n_samples_t = 0

    logger.debug("Loading amplitudes.")
    amplitudes = read_array('amplitudes').squeeze()
    n_spikes, = amplitudes.shape
    self.n_spikes = n_spikes

    # Create spike_clusters if the file doesn't exist.
    if not op.exists(filenames['spike_clusters']):
        shutil.copy(filenames['spike_templates'],
                    filenames['spike_clusters'])
    logger.debug("Loading spike clusters.")
    spike_clusters = read_array('spike_clusters').squeeze()
    spike_clusters = spike_clusters.astype(np.int32)
    assert spike_clusters.shape == (n_spikes,)
    self.spike_clusters = spike_clusters

    logger.debug("Loading spike templates.")
    spike_templates = read_array('spike_templates').squeeze()
    spike_templates = spike_templates.astype(np.int32)
    assert spike_templates.shape == (n_spikes,)
    self.spike_templates = spike_templates

    logger.debug("Loading spike samples.")
    spike_samples = read_array('spike_samples').squeeze()
    assert spike_samples.shape == (n_spikes,)

    logger.debug("Loading templates.")
    templates = read_array('templates')
    templates[np.isnan(templates)] = 0
    # templates = np.transpose(templates, (2, 1, 0))

    # Unwhiten the templates.
    logger.debug("Loading the whitening matrix.")
    self.whitening_matrix = read_array('whitening_matrix')

    if op.exists(filenames['templates_unw']):
        logger.debug("Loading unwhitened templates.")
        templates_unw = read_array('templates_unw')
        templates_unw[np.isnan(templates_unw)] = 0
    else:
        logger.debug("Couldn't find unwhitened templates, computing them.")
        logger.debug("Inverting the whitening matrix %s.",
                     self.whitening_matrix.shape)
        wmi = np.linalg.inv(self.whitening_matrix)
        logger.debug("Unwhitening the templates %s.", templates.shape)
        templates_unw = np.dot(templates, wmi)

    n_templates, n_samples_templates, n_channels = templates.shape
    self.n_templates = n_templates

    logger.debug("Loading similar templates.")
    self.similar_templates = read_array('similar_templates')
    assert self.similar_templates.shape == (self.n_templates,
                                            self.n_templates)

    logger.debug("Loading channel mapping.")
    channel_mapping = read_array('channel_mapping').squeeze()
    channel_mapping = channel_mapping.astype(np.int32)
    assert channel_mapping.shape == (n_channels,)
    # Ensure that the mapping maps to valid columns in the dat file.
    assert np.all(channel_mapping <= self.n_channels_dat - 1)

    logger.debug("Loading channel positions.")
    channel_positions = read_array('channel_positions')
    assert channel_positions.shape == (n_channels, 2)

    if op.exists(filenames['features']):
        logger.debug("Loading features.")
        all_features = np.load(filenames['features'], mmap_mode='r')
        features_ind = read_array('features_ind').astype(np.int32)

        # Feature subset.
        if op.exists(filenames['features_spike_ids']):
            features_spike_ids = read_array('features_spike_ids') \
                .astype(np.int32)
            assert len(features_spike_ids) == len(all_features)
            self.features_spike_ids = features_spike_ids
            ns = len(features_spike_ids)
        else:
            ns = self.n_spikes
            self.features_spike_ids = None

        assert all_features.ndim == 3
        n_loc_chan = all_features.shape[2]
        self.n_features_per_channel = all_features.shape[1]
        assert all_features.shape == (
            ns,
            self.n_features_per_channel,
            n_loc_chan,
        )
        # Check sparse features arrays shapes.
        assert features_ind.shape == (self.n_templates, n_loc_chan)
    else:
        all_features = None
        features_ind = None

    self.all_features = all_features
    self.features_ind = features_ind

    if op.exists(filenames['template_features']):
        logger.debug("Loading template features.")
        template_features = np.load(filenames['template_features'],
                                    mmap_mode='r')
        template_features_ind = read_array('template_features_ind') \
            .astype(np.int32)
        template_features_ind = template_features_ind.copy()
        n_sim_tem = template_features.shape[1]
        assert template_features.shape == (n_spikes, n_sim_tem)
        assert template_features_ind.shape == (n_templates, n_sim_tem)
    else:
        template_features = None
        template_features_ind = None

    self.template_features_ind = template_features_ind
    self.template_features = template_features

    self.n_channels = n_channels
    # Take dead channels into account.
    if traces is not None:
        # Find the scaling factor for the traces.
        scaling = 1. / self._data_lim(traces[:10000])
        traces = _concatenate_virtual_arrays(
            [traces],
            channel_mapping,
            scaling=scaling,
        )
    else:
        scaling = 1.

    # Amplitudes.
    self.all_amplitudes = amplitudes
    self.amplitudes_lim = self.all_amplitudes.max()

    # Templates.
    self.templates = templates
    self.templates_unw = templates_unw
    assert self.templates.shape == self.templates_unw.shape
    self.n_samples_templates = n_samples_templates
    self.n_samples_waveforms = n_samples_templates
    self.template_lim = np.max(np.abs(self.templates))

    self.duration = n_samples_t / float(self.sample_rate)
    self.spike_times = spike_samples / float(self.sample_rate)
    assert np.all(np.diff(self.spike_times) >= 0)

    self.cluster_ids = np.unique(self.spike_clusters)
    # n_clusters = len(self.cluster_ids)

    self.channel_positions = channel_positions
    self.all_traces = traces

    # Filter the waveforms.
    order = 3
    filter_margin = order * 3
    b_filter = bandpass_filter(rate=self.sample_rate,
                               low=500.,
                               high=self.sample_rate * .475,
                               order=order)

    # Only filter the data for the waveforms if the traces
    # are not already filtered.
    if not getattr(self, 'hp_filtered', False):
        logger.debug("HP filtering the data for waveforms")

        def the_filter(x, axis=0):
            return apply_filter(x, b_filter, axis=axis)
    else:
        the_filter = None

    # Fetch waveforms from traces.
    nsw = self.n_samples_waveforms
    if traces is not None:
        waveforms = WaveformLoader(
            traces=traces,
            n_samples_waveforms=nsw,
            filter=the_filter,
            filter_margin=filter_margin,
        )
        waveforms = SpikeLoader(waveforms, spike_samples)
    else:
        waveforms = None
    self.all_waveforms = waveforms

    self.template_masks = get_masks(self.templates)
    self.all_masks = MaskLoader(self.template_masks, self.spike_templates)

    # Read the cluster groups.
    logger.debug("Loading the cluster groups.")
    self.cluster_groups = {}
    if op.exists(filenames['cluster_groups']):
        with open(filenames['cluster_groups'], 'r') as f:
            reader = csv.reader(f, delimiter='\t')
            # Skip the header.
            next(reader, None)
            for row in reader:
                cluster, group = row
                cluster = int(cluster)
                self.cluster_groups[cluster] = group
    for cluster_id in self.cluster_ids:
        if cluster_id not in self.cluster_groups:
            self.cluster_groups[cluster_id] = None
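# `bandpass_filter` and `apply_filter` are called throughout this section but
# not defined in it. A plausible minimal sketch, assuming scipy.signal: a
# Butterworth band-pass design plus zero-phase (forward-backward) filtering.
# The names and signatures follow the call sites above; the internals are an
# assumption, not the library's confirmed implementation.
from scipy import signal

def bandpass_filter(rate=None, low=None, high=None, order=None):
    """Design a Butterworth band-pass filter; returns (b, a) coefficients."""
    # Cutoff frequencies are normalized by the Nyquist frequency (rate / 2).
    return signal.butter(order,
                         (low / (rate / 2.), high / (rate / 2.)),
                         'bandpass')

def apply_filter(x, filter=None, axis=0):
    """Apply the filter forward and backward (zero phase) along `axis`."""
    b, a = filter
    return signal.filtfilt(b, a, x, axis=axis)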
def __init__(self,
             basename,
             filename,
             N_t,
             n_channels=None,
             n_total_channels=None,
             prb_file=None,
             dtype=None,
             sample_rate=None,
             offset=0,
             gain=0.01,
             extract_features=True,  # NOTE: this flag was referenced but
                                     # never defined in the original; it is
                                     # exposed as a parameter here so the
                                     # feature-extraction block below runs.
             ):
    self.n_features_per_channel = 3
    self.n_total_channels = n_total_channels
    extract_s_before = extract_s_after = int(N_t - 1) // 2
    self.extract_s_before = extract_s_before
    self.extract_s_after = extract_s_after

    # Set to True if your data is already pre-filtered (much quicker).
    filtered_datfile = False
    # Filtering parameters for PCA (ignored if filtered_datfile is True).
    filter_low = 500.
    filter_high = 0.95 * .5 * sample_rate
    filter_butter_order = 3

    self.basename = basename
    self.kwik_path = basename + '.kwik'
    self.dtype = dtype
    self.prb_file = prb_file
    self.probe = load_probe(prb_file)
    self.sample_rate = sample_rate
    self.filtered_datfile = filtered_datfile

    self._sd = SpikeDetekt(probe=self.probe,
                           n_features_per_channel=self.n_features_per_channel,
                           pca_n_waveforms_max=10000,
                           extract_s_before=extract_s_before,
                           extract_s_after=extract_s_after,
                           sample_rate=sample_rate,
                           )
    self.n_samples_w = extract_s_before + extract_s_after

    # A xxx.filtered.trunc file may be created if needed.
    self.file, self.traces_f = _read_filtered(filename,
                                              n_channels=n_total_channels,
                                              dtype=dtype,
                                              )
    self.n_samples, self.n_total_channels = self.traces_f.shape
    self.n_channels = n_channels
    assert n_total_channels == self.n_total_channels
    info("Loaded traces: {}.".format(self.traces_f.shape))

    # Load spikes.
    self.spike_samples, self.spike_clusters = _read_spikes(basename)
    self.n_spikes = len(self.spike_samples)
    assert len(self.spike_clusters) == self.n_spikes
    info("Loaded {} spikes.".format(self.n_spikes))

    # Chunks when computing features.
    self.chunk_size = 2500
    self.n_chunks = int(np.ceil(self.n_spikes / self.chunk_size))

    # Load templates and masks.
    self.templates, self.template_masks = _read_templates(
        basename, self.probe, self.n_total_channels, self.n_channels)
    self.n_templates = len(self.templates)
    info("Loaded templates: {}.".format(self.templates.shape))

    # Load amplitudes.
    self.amplitudes = _read_amplitudes(basename, self.n_templates,
                                       self.n_spikes, self.spike_clusters)

    if extract_features:
        # The WaveformLoader fetches and filters waveforms from the raw
        # traces dynamically.
        n_samples = (extract_s_before, extract_s_after)
        b_filter = bandpass_filter(rate=self.sample_rate,
                                   low=filter_low,
                                   high=filter_high,
                                   order=filter_butter_order)

        def filter(x):
            return apply_filter(x, b_filter)

        filter_margin = filter_butter_order * 3

        nodes = []
        for key in self.probe['channel_groups'].keys():
            nodes += self.probe['channel_groups'][key]['channels']
        nodes = np.array(nodes, dtype=np.int32)

        if filtered_datfile:
            self._wl = WaveformLoader(traces=self.traces_f,
                                      n_samples=self.n_samples_w,
                                      dc_offset=offset,
                                      scale_factor=gain,
                                      channels=nodes,
                                      )
        else:
            self._wl = WaveformLoader(traces=self.traces_f,
                                      n_samples=self.n_samples_w,
                                      filter=filter,
                                      filter_margin=filter_margin,
                                      dc_offset=offset,
                                      scale_factor=gain,
                                      channels=nodes,
                                      )
        # A virtual (n_spikes, n_samples, n_channels) array that is
        # memmapped to the filtered data file.
        self.waveforms = SpikeLoader(self._wl, self.spike_samples)
        assert self.waveforms.shape == (self.n_spikes,
                                        self.n_samples_w,
                                        self.n_channels)
        assert self.template_masks.shape == (self.n_templates,
                                             self.n_channels)
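# A hypothetical access pattern for the resulting virtual `waveforms` array,
# assuming `m` is an instance of the class above: indexing triggers an
# on-the-fly fetch (and, if configured, filtering) of just the requested
# spikes from the traces, so the full (n_spikes, n_samples, n_channels)
# array is never materialized in memory.
#
# wf = m.waveforms[:10]
# assert wf.shape == (10, m.n_samples_w, m.n_channels)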