def test_decimate():
    """Check that decimate() downsamples a multi-channel signal sanely.

    run() slices the output of decimate() with ``[i//16:j//16, :]``, so the
    result is expected to stay 2-D, keep the channel axis unchanged, and be
    roughly 16 times shorter than the input along the time axis.
    """
    # Seeded generator so any failure is reproducible.
    rng = np.random.RandomState(0)
    x = rng.randn(16000, 3)
    y = decimate(x)
    assert y.ndim == 2
    assert y.shape[1] == x.shape[1]
    # Downsampled along axis 0 (run() assumes a factor of 16).
    assert 0 < y.shape[0] <= x.shape[0] // 16 + 1
    assert np.all(np.isfinite(y))
def run(raw_data=None, experiment=None, prm=None, probe=None,
        _debug=False, convert_only=False):
    """Run the SpikeDetekt pipeline on raw data and write results to `experiment`.

    This main function takes raw data (either as a RawReader, or a path to a
    filename, or an array) and executes the main algorithm (filtering, spike
    detection, extraction...).

    Parameters
    ----------
    raw_data : BaseRawDataReader, path or array, optional
        Input data. Anything that is not already a BaseRawDataReader is
        wrapped with ``read_raw(raw_data, nchannels=nchannels)``. When None,
        the data is read from `experiment` via ``read_raw(experiment)``.
    experiment : Experiment
        Required. Receives the log file, the raw/high/low traces (depending
        on the ``save_*`` parameters), the detected waveforms and the features.
    prm : dict
        Parameter dictionary. Keys read here: ``chunk_size``, ``chunk_overlap``,
        ``nchannels``, ``diagnostics_path``, ``save_raw``, ``save_high``,
        ``save_low``. It is also forwarded as keyword arguments to the
        filtering, thresholding, detection and feature-extraction helpers.
    probe : Probe
        Must expose ``channels`` (the live channels) and ``adjacency_list``.
    _debug : bool
        If True, stop after processing the first chunk.
    convert_only : bool
        If True, the data is still converted/saved for every chunk, but spike
        detection runs on the first chunk only.

    Raises
    ------
    AssertionError
        If `experiment` is None, or if the computed thresholds contain NaN.
    ValueError
        If the module at ``diagnostics_path`` lacks a ``diagnostics()`` function.
    """
    assert experiment is not None, ("An Experiment instance needs to be "
                                    "provided in order to write the output.")

    # Create file logger for the experiment.
    LOGGER_FILE = create_file_logger(experiment.gen_filename('log'))
    # Close the log file even if processing fails part-way.
    try:
        # Get parameters from the PRM dictionary.
        chunk_size = prm.get('chunk_size', None)
        chunk_overlap = prm.get('chunk_overlap', 0)
        nchannels = prm.get('nchannels', None)

        # Ensure a RawDataReader is instantiated.
        if raw_data is not None:
            if not isinstance(raw_data, BaseRawDataReader):
                raw_data = read_raw(raw_data, nchannels=nchannels)
        else:
            raw_data = read_raw(experiment)

        # Set to True once one chunk has gone through spike detection. In
        # convert_only mode this stops detection on all later chunks
        # (horrible hack: detect spikes on one chunk only so KV doesn't
        # complain). Bound unconditionally so the name always exists.
        first_chunk_detected = False

        # Log.
        if convert_only:
            info("Starting file conversion only. Klusta version {1:s}, on {0:s}".format((str(raw_data)), spikedetekt2.__version__))
            info("Running spike detection on a single chunk of spikes only, so as to have some information")
        else:
            info("Starting SpikeDetekt version {1:s} on {0:s}".format((str(raw_data)), spikedetekt2.__version__))
        debug("Parameters: \n" + (display_params(prm)))

        # Get the bandpass filter. (The name `filter` is kept because it is
        # exposed to the diagnostics module through locals().)
        filter = bandpass_filter(**prm)

        if not (convert_only and first_chunk_detected):
            # Compute the strong threshold across excerpts uniformly scattered
            # across the whole recording.
            threshold = get_threshold(raw_data, filter=filter,
                                      channels=probe.channels, **prm)
            assert not np.isnan(threshold.weak).any()
            assert not np.isnan(threshold.strong).any()
            debug("Threshold: " + str(threshold))

        # Debug module.
        diagnostics_path = prm.get('diagnostics_path', None)
        if diagnostics_path:
            diagnostics_mod = _import_module(diagnostics_path)
            if not hasattr(diagnostics_mod, 'diagnostics'):
                raise ValueError("The diagnostics module must implement a "
                                 "'diagnostics()' function.")
            diagnostics_fun = diagnostics_mod.diagnostics
        else:
            diagnostics_fun = None

        # Progress bar.
        progress_bar = ProgressReporter(period=30.)
        nspikes = 0

        # Loop through all chunks with overlap.
        for chunk in raw_data.chunks(chunk_size=chunk_size,
                                     chunk_overlap=chunk_overlap,):
            # Log. `!s` (not `:s`) so that non-str chunk objects format
            # correctly on Python 3.
            debug("Processing chunk {0!s}...".format(chunk))

            nsamples = chunk.nsamples
            rec = chunk.recording
            nrecs = chunk.nrecordings
            s_end = chunk.s_end

            # Filter the (full) chunk.
            chunk_raw = chunk.data_chunk_full  # shape: (nsamples, nchannels)
            chunk_fil = apply_filter(chunk_raw, filter=filter)

            # Bounds of the non-overlapping ("keep") part, relative to the
            # start of the full chunk.
            i = chunk.keep_start - chunk.s_start
            j = chunk.keep_end - chunk.s_start

            # Add the data to the KWD files.
            if prm.get('save_raw', False):
                # Do not append the raw data to the .kwd file if we're already
                # reading from the .kwd file.
                if not isinstance(raw_data, (KwdRawDataReader,
                                             ExperimentRawDataReader)):
                    # Save raw data.
                    experiment.recordings[chunk.recording].raw.append(
                        convert_dtype(chunk.data_chunk_keep, np.int16))

            if prm.get('save_high', False):
                # Save high-pass filtered data: need to remove the overlapping
                # sections.
                chunk_fil_keep = chunk_fil[i:j, :]
                experiment.recordings[chunk.recording].high.append(
                    convert_dtype(chunk_fil_keep, np.int16))

            if prm.get('save_low', True):
                # Save LFP. The keep-slice assumes decimate() downsamples by
                # a factor of 16 — TODO confirm against decimate().
                chunk_low = decimate(chunk_raw)
                chunk_low_keep = chunk_low[i//16:j//16, :]
                experiment.recordings[chunk.recording].low.append(
                    convert_dtype(chunk_low_keep, np.int16))

            if not (convert_only and first_chunk_detected):
                # Apply thresholds.
                chunk_detect, chunk_threshold = apply_threshold(
                    chunk_fil, threshold=threshold, **prm)

                # Remove dead channels. Fall back to the data's channel count
                # when 'nchannels' is not given in prm.
                dead = np.setdiff1d(
                    np.arange(nchannels if nchannels is not None
                              else chunk_detect.shape[1]),
                    probe.channels)
                chunk_detect[:, dead] = 0
                chunk_threshold.strong[:, dead] = 0
                chunk_threshold.weak[:, dead] = 0

                # Find connected component (strong threshold). Return list of
                # Component instances.
                components = connected_components(
                    chunk_strong=chunk_threshold.strong,
                    chunk_weak=chunk_threshold.weak,
                    probe_adjacency_list=probe.adjacency_list,
                    chunk=chunk, **prm)

                # Now we extract the spike in each component.
                waveforms = extract_waveforms(chunk_detect=chunk_detect,
                                              threshold=threshold,
                                              chunk_fil=chunk_fil,
                                              chunk_raw=chunk_raw,
                                              probe=probe,
                                              components=components,
                                              **prm)

                # DEBUG module.
                # Execute the debug script. All local variables are passed
                # as keyword arguments; a failure is logged, not raised.
                if diagnostics_fun:
                    try:
                        diagnostics_fun(**locals())
                    except Exception as e:
                        # `str(e)`: exceptions have no `.message` on Python 3.
                        warn("The diagnostics module failed: " + str(e))

                # Log number of spikes in the chunk.
                nspikes += len(waveforms)

                # We sort waveforms by increasing order of fractional time.
                for waveform in sorted(waveforms):
                    add_waveform(experiment, waveform)

                first_chunk_detected = True

            # Update the progress bar.
            progress_bar.update(rec/float(nrecs) +
                                (float(s_end) / (nsamples*nrecs)),
                                '%d spikes found.' % (nspikes))

            # DEBUG: keep only the first chunk.
            if _debug:
                break

        # Feature extraction.
        save_features(experiment, **prm)
    finally:
        close_file_logger(LOGGER_FILE)

    progress_bar.finish()