def test_can_use_neural_network_detector(path_to_tests):
    """Smoke test: build the NN graph, restore the models and run detection.

    Passes as long as nothing raises; no output values are asserted.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    recordings_path = path.join(path_to_tests, 'data/standarized.bin')
    data = RecordingsReader(recordings_path, loader='array').data

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    # stack of float32 identity matrices: one (shape[1] x shape[1]) matrix
    # for every row of channel_index
    identity = np.eye(channel_index.shape[1], dtype='float32')
    whiten_filter = np.tile(identity[np.newaxis, :, :],
                            [channel_index.shape[0], 1, 1])

    # thresholds and model files for each of the three networks
    detector_cfg = CONFIG.detect.neural_network_detector
    triage_cfg = CONFIG.detect.neural_network_triage
    autoencoder_cfg = CONFIG.detect.neural_network_autoencoder

    detection_th = detector_cfg.threshold_spike
    triage_th = triage_cfg.threshold_collision
    detection_fname = detector_cfg.filename
    ae_fname = autoencoder_cfg.filename
    triage_fname = triage_cfg.filename

    (x_tf, output_tf, NND, NNAE, NNT) = neuralnetwork.prepare_nn(
        channel_index, whiten_filter, detection_th, triage_th,
        detection_fname, ae_fname, triage_fname)

    with tf.Session() as sess:
        # restore the saved variables of each network into this session
        NND.saver.restore(sess, NND.path_to_detector_model)
        NNAE.saver_ae.restore(sess, NNAE.path_to_ae_model)
        NNT.saver.restore(sess, NNT.path_to_triage_model)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        neuralnetwork.run_detect_triage_featurize(data, sess, x_tf,
                                                  output_tf, neighbors, rot)
def templates_uncropped(path_to_config, make_tmp_folder,
                        path_to_sample_pipeline_folder,
                        path_to_standarized_data):
    """Compute uncropped templates from the sample pipeline's spike train.

    Parameters
    ----------
    path_to_config
        Configuration passed to ``yass.set_config``
    make_tmp_folder
        Output directory passed to ``yass.set_config``
    path_to_sample_pipeline_folder
        Folder containing ``spike_train_post_deconv_post_merge.npy``
    path_to_standarized_data
        Recordings passed to ``get_templates``

    Returns
    -------
    numpy.ndarray
        The first output of ``get_templates`` with its axes reversed
        (``np.transpose(..., (2, 1, 0))``)
    """
    yass.set_config(path_to_config, make_tmp_folder)
    CONFIG = yass.read_config()

    # the spike train always comes from disk; the hard-coded array that used
    # to be assigned first was overwritten before any use, so it is gone
    spike_train = np.load(
        path.join(path_to_sample_pipeline_folder,
                  'spike_train_post_deconv_post_merge.npy'))

    # append a column of ones (a weight of 1 for every spike)
    n_spikes, _ = spike_train.shape
    weighted_spike_train = np.hstack(
        (spike_train, np.ones((n_spikes, 1), 'int32')))

    # local name deliberately differs from the function name to avoid
    # shadowing it
    templates, _ = get_templates(weighted_spike_train,
                                 path_to_standarized_data,
                                 CONFIG.resources.max_memory,
                                 4 * CONFIG.spike_size)

    return np.transpose(templates, (2, 1, 0))
def test_can_make_noise(path_to_tests, path_to_standarized_data):
    """Smoke test: ``make_noise`` runs on clean spikes from cropped templates.

    NOTE(review): ``spike_train`` is not assigned in this function; it is
    presumably a module-level name -- confirm.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    # give every spike a weight of one
    n_spikes, _ = spike_train.shape
    unit_weights = np.ones((n_spikes, 1), 'int32')
    weighted_spike_train = np.hstack((spike_train, unit_weights))

    raw_templates, _ = get_templates(weighted_spike_train,
                                     path_to_standarized_data,
                                     CONFIG.resources.max_memory,
                                     4 * CONFIG.spike_size)
    raw_templates = np.transpose(raw_templates, (2, 1, 0))

    templates = crop_and_align_templates(raw_templates,
                                         CONFIG.spike_size,
                                         CONFIG.neigh_channels,
                                         CONFIG.geom)

    spatial_SIG, temporal_SIG = noise_cov(path_to_standarized_data,
                                          CONFIG.neigh_channels,
                                          CONFIG.geom,
                                          templates.shape[1])

    x_clean = make_clean(templates, min_amp=2, max_amp=10, nk=100)

    make_noise(x_clean, noise_ratio=10, templates=templates,
               spatial_SIG=spatial_SIG, temporal_SIG=temporal_SIG)
def test_can_reload_detector(path_to_tests, path_to_sample_pipeline_folder,
                             tmp_folder):
    """Fit a detector, then load it back from the checkpoint it wrote.

    NOTE(review): ``spike_train``, ``chosen_templates``, ``min_amplitude``,
    ``n_spikes`` and ``filters`` are not assigned in this function; they are
    presumably module-level names -- confirm.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    training_sets = make_training_data(CONFIG, spike_train, chosen_templates,
                                       min_amplitude, n_spikes,
                                       path_to_sample_pipeline_folder)
    (x_detect, y_detect, x_triage, y_triage, x_ae, y_ae) = training_sets

    _, waveform_length, n_neighbors = x_detect.shape
    checkpoint = path.join(tmp_folder, 'detect-net.ckpt')

    fitted = NeuralNetDetector(checkpoint, filters, waveform_length,
                               n_neighbors, threshold=0.5,
                               channel_index=CONFIG.channel_index,
                               n_iter=10)
    fitted.fit(x_detect, y_detect)

    # reloading must not raise
    NeuralNetDetector.load(checkpoint, threshold=0.5,
                           channel_index=CONFIG.channel_index)
def test_can_use_neural_network_detector(path_to_tests,
                                         path_to_standarized_data):
    """Smoke test: load the three networks, restore them and run detection.

    Passes as long as nothing raises; no output values are asserted.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    data = RecordingsReader(path_to_standarized_data, loader='array').data
    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    # thresholds and model files for each network
    detector_cfg = CONFIG.detect.neural_network_detector
    triage_cfg = CONFIG.detect.neural_network_triage
    detection_th = detector_cfg.threshold_spike
    triage_th = triage_cfg.threshold_collision
    detection_fname = detector_cfg.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = triage_cfg.filename

    # triage and autoencoder both take the detector's waveform tensor as
    # their input tensor
    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 channel_index)
    NNT = NeuralNetTriage.load(triage_fname, triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    with tf.Session() as sess:
        for network in (NND, NNAE, NNT):
            network.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        neuralnetwork.run_detect_triage_featurize(data, sess, NND.x_tf,
                                                  output_tf, neighbors, rot)
def test_can_make_misaligned(path_to_tests, path_to_standarized_data):
    """Smoke test: ``make_misaligned`` runs on clean spikes from templates.

    NOTE(review): ``spike_train`` is not assigned in this function; it is
    presumably a module-level name -- confirm.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    # give every spike a weight of one
    n_spikes, _ = spike_train.shape
    unit_weights = np.ones((n_spikes, 1), 'int32')
    weighted_spike_train = np.hstack((spike_train, unit_weights))

    templates, _ = get_templates(weighted_spike_train,
                                 path_to_standarized_data,
                                 CONFIG.resources.max_memory,
                                 4 * CONFIG.spike_size)
    templates = np.transpose(templates, (2, 1, 0))

    x_clean = make_clean(templates, min_amp=2, max_amp=10, nk=100)

    make_misaligned(x_clean, templates,
                    max_shift=2 * CONFIG.spike_size,
                    misalign_ratio=1,
                    misalign_ratio2=1,
                    multi=True,
                    nneigh=templates.shape[2])
def run_neural_network(standardized_path, standardized_dtype,
                       output_directory, run_chunk_sec='full'):
    """Run neural network detection

    Loads the detector and denoiser networks, then starts one process per
    entry of ``CONFIG.torch_devices``; each process receives its own share
    of the reader's batch ids. Returns once every process has finished.

    Parameters
    ----------
    standardized_path
        Recordings location, forwarded to ``READER``
    standardized_dtype
        Recordings dtype, forwarded to ``READER``
    output_directory
        Forwarded to every ``run_nn_detction_batch`` worker
    run_chunk_sec
        ``'full'`` (default) passes ``None`` to ``READER`` as the chunk;
        any other value is forwarded to ``READER`` unchanged
    """
    CONFIG = read_config()

    # load NN detector
    detector = Detect(CONFIG.neuralnetwork.detect.n_filters,
                      CONFIG.spike_size_nn,
                      CONFIG.channel_index)
    detector.load(CONFIG.neuralnetwork.detect.filename)

    # load NN denoiser
    denoiser = Denoise(CONFIG.neuralnetwork.denoise.n_filters,
                       CONFIG.neuralnetwork.denoise.filter_sizes,
                       CONFIG.spike_size_nn)
    denoiser.load(CONFIG.neuralnetwork.denoise.filename)

    # get data reader
    batch_length = (CONFIG.resources.n_sec_chunk *
                    CONFIG.resources.n_processors)
    n_sec_chunk = CONFIG.resources.n_sec_chunk_gpu_detect
    print("   batch length to (sec): ", batch_length,
          " (longer increase speed a bit)")
    print("   length of each seg (sec): ", n_sec_chunk)
    buffer = CONFIG.spike_size_nn
    chunk_sec = None if run_chunk_sec == 'full' else run_chunk_sec

    reader = READER(standardized_path, standardized_dtype, CONFIG,
                    batch_length, buffer, chunk_sec)

    # neighboring channels
    channel_index_dedup = make_channel_index(CONFIG.neigh_channels,
                                             CONFIG.geom, steps=2)

    # threshold for neuralnet detection
    detect_threshold = CONFIG.detect.threshold

    # np.split raises ValueError unless the number of batches is an exact
    # multiple of the number of devices; np.array_split returns the same
    # result in that case and near-equal shares otherwise.
    # NOTE(review): with fewer batches than devices some workers receive an
    # empty id array -- confirm run_nn_detction_batch accepts that.
    batch_ids = np.arange(reader.n_batches)
    batch_ids_split = np.array_split(batch_ids, len(CONFIG.torch_devices))

    # one worker process per device
    processes = []
    for ii, device in enumerate(CONFIG.torch_devices):
        p = mp.Process(target=run_nn_detction_batch,
                       args=(batch_ids_split[ii], output_directory, reader,
                             n_sec_chunk, detector, denoiser,
                             channel_index_dedup, detect_threshold, device))
        p.start()
        processes.append(p)

    # wait for all workers
    for p in processes:
        p.join()
def test_can_make_training_data(path_to_tests,
                                path_to_sample_pipeline_folder):
    """Smoke test: ``make_training_data`` runs without raising.

    NOTE(review): ``spike_train``, ``chosen_templates``, ``min_amplitude``
    and ``n_spikes_to_make`` are not assigned in this function; they are
    presumably module-level names -- confirm.
    """
    config_file = path.join(path_to_tests, 'config_nnet.yaml')
    yass.set_config(config_file)
    CONFIG = yass.read_config()

    make_training_data(CONFIG,
                       spike_train,
                       chosen_templates,
                       min_amplitude,
                       n_spikes_to_make,
                       data_folder=path_to_sample_pipeline_folder)
def run_template_update(output_directory, fname_templates, fname_spike_train,
                        fname_shifts, fname_scales, fname_residual,
                        residual_dtype, residual_offset=0, update_weight=50,
                        units_to_update=None):
    """Write updated templates to ``<output_directory>/templates.npy``.

    Nothing is recomputed when that file already exists. In both cases the
    path of the file is returned.
    """
    fname_templates_out = os.path.join(output_directory, 'templates.npy')

    # result already on disk: skip all the work
    if os.path.exists(fname_templates_out):
        return fname_templates_out

    print('updating templates')

    CONFIG = read_config()

    # output folder
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    # chunk length depends on whether deconvolution is configured for gpu
    n_sec_chunk = (CONFIG.resources.n_sec_chunk_gpu_deconv
                   if CONFIG.deconvolution.deconv_gpu
                   else CONFIG.resources.n_sec_chunk)

    reader_residual = READER(fname_residual,
                             residual_dtype,
                             CONFIG,
                             n_sec_chunk,
                             offset=residual_offset)

    # residual obj that can shift templates in gpu
    residual_comp = RESIDUAL_GPU2(None, CONFIG, None, None, None, None,
                                  None, None, None, None, True)
    residual_comp.load_templates(fname_templates)
    residual_comp.make_bsplines_parallel()

    avg_min_max_vals, weights = get_avg_min_max_vals(
        fname_templates, fname_spike_train, fname_shifts, fname_scales,
        reader_residual, residual_comp, units_to_update)

    templates_updated = update_templates(fname_templates,
                                         weights,
                                         avg_min_max_vals,
                                         update_weight,
                                         units_to_update)

    np.save(fname_templates_out, templates_updated)

    return fname_templates_out
def run(fname_recording, recording_dtype, fname_spike_train,
        output_directory):
    """Build the training-data objects for the detection and denoising nets.

    Steps: compute the noise covariance, extract templates on local
    channels, denoise them, and wrap the resulting files in the two
    training-data classes.

    Parameters
    ----------
    fname_recording
        Recordings location, forwarded to ``READER``
    recording_dtype
        Recordings dtype, forwarded to ``READER``
    fname_spike_train
        Spike train file, forwarded to ``get_templates_on_local_channels``
    output_directory
        Folder that receives the ``noise_cov`` and ``templates``
        sub-folders; created (including missing parents) if absent

    Returns
    -------
    tuple
        ``(Detection_Training_Data, Denoising_Training_Data)`` instances
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    # make output directory if not exist; makedirs rather than mkdir so a
    # nested path whose parents are missing does not fail
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    # get reader
    reader = READER(fname_recording, recording_dtype, CONFIG)
    reader.spike_size = CONFIG.spike_size_nn

    # get noise covariance
    logger.info('Compute Noise Covaraince')
    save_dir = os.path.join(output_directory, 'noise_cov')
    # upper bound is the smaller of 5*60*sampling_rate and reader.end
    chunk = [0, np.min((5*60*reader.sampling_rate, reader.end))]
    fname_spatial_sig, fname_temporal_sig = get_noise_covariance(
        reader, save_dir, CONFIG, chunk)

    # get processed templates
    logger.info('Crop Templates')
    save_dir = os.path.join(output_directory, 'templates')
    fname_templates_snippets = get_templates_on_local_channels(
        reader, save_dir, fname_spike_train, CONFIG)

    # denoise templates
    fname_templates_denoised = denoise_templates(
        fname_templates_snippets, save_dir)

    # make training data
    logger.info('Make Training Data')
    DetectTD = Detection_Training_Data(
        fname_templates_denoised,
        fname_spatial_sig,
        fname_temporal_sig)

    DenoTD = Denoising_Training_Data(
        fname_templates_denoised,
        fname_spatial_sig,
        fname_temporal_sig)

    return DetectTD, DenoTD
def run_voltage_treshold(standardized_path, standardized_dtype,
                         output_directory, run_chunk_sec='full'):
    """Run detection that thresholds on amplitude

    Every reader batch is handed to ``run_voltage_threshold_parallel``,
    through ``parmap`` when ``CONFIG.resources.multi_processing`` is set and
    in a plain loop otherwise.
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    # get data reader
    batch_length = CONFIG.resources.n_sec_chunk
    n_sec_chunk = 0.5
    print("   batch length to (sec): ", batch_length,
          " (longer increase speed a bit)")
    print("   length of each seg (sec): ", n_sec_chunk)
    buffer = CONFIG.spike_size
    chunk_sec = None if run_chunk_sec == 'full' else run_chunk_sec

    reader = READER(standardized_path, standardized_dtype, CONFIG,
                    batch_length, buffer, chunk_sec)

    # number of processed chunks (computed but not used further down)
    n_mini_per_big_batch = int(np.ceil(batch_length / n_sec_chunk))
    total_processing = int(reader.n_batches * n_mini_per_big_batch)

    # neighboring channels
    channel_index = make_channel_index(CONFIG.neigh_channels,
                                       CONFIG.geom, steps=2)

    if not CONFIG.resources.multi_processing:
        # serial path
        for batch_id in range(reader.n_batches):
            run_voltage_threshold_parallel(batch_id,
                                           reader,
                                           n_sec_chunk,
                                           CONFIG.detect.threshold,
                                           channel_index,
                                           output_directory)
    else:
        # one 1-tuple of arguments per batch id
        per_batch_args = list(zip(np.arange(reader.n_batches)))
        parmap.starmap(run_voltage_threshold_parallel,
                       per_batch_args,
                       reader,
                       n_sec_chunk,
                       CONFIG.detect.threshold,
                       channel_index,
                       output_directory,
                       processes=CONFIG.resources.n_processors,
                       pm_pbar=True)
def test_can_use_detect_and_triage_after_reload(
        path_to_tests, path_to_sample_pipeline_folder, tmp_folder,
        path_to_standarized_data):
    """Fit and reload a detector and a triage net, then predict with both.

    NOTE(review): ``spike_train``, ``chosen_templates``, ``min_amplitude``,
    ``n_spikes`` and ``filters`` are not assigned in this function; they are
    presumably module-level names -- confirm.
    NOTE(review): detector and triage are given the same checkpoint path
    and the triage net is fitted on the detection data -- confirm both are
    intended.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    training_sets = make_training_data(CONFIG, spike_train, chosen_templates,
                                       min_amplitude, n_spikes,
                                       path_to_sample_pipeline_folder)
    (x_detect, y_detect, x_triage, y_triage, x_ae, y_ae) = training_sets

    _, waveform_length, n_neighbors = x_detect.shape
    path_to_model = path.join(tmp_folder, 'detect-net.ckpt')

    # detector: fit, then replace with the instance loaded from disk
    detector = NeuralNetDetector(path_to_model, filters, waveform_length,
                                 n_neighbors, threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=10)
    detector.fit(x_detect, y_detect)
    detector = NeuralNetDetector.load(path_to_model, threshold=0.5,
                                      channel_index=CONFIG.channel_index)

    # triage: same fit-then-reload sequence
    triage = NeuralNetTriage(path_to_model, filters, waveform_length,
                             n_neighbors, threshold=0.5, n_iter=10)
    triage.fit(x_detect, y_detect)
    triage = NeuralNetTriage.load(path_to_model, threshold=0.5)

    data = RecordingExplorer(path_to_standarized_data).reader.data

    output_names = ('spike_index', 'waveform', 'probability')
    predictions = detector.predict(data, output_names=output_names)
    (spike_index, waveform, proba) = predictions

    # keep only the first n_neighbors entries along the last axis
    triage.predict(waveform[:, :, :n_neighbors])
def run():
    """RF computation

    Builds the input and output paths from the configuration, then runs the
    three ``RF`` steps in order: STA calculation, multi-RF detection and
    classification.
    """
    CONFIG = read_config()

    root_folder = CONFIG.data.root_folder
    output_folder = CONFIG.path_to_output_directory

    # inputs
    stim_movie_file = os.path.join(root_folder, CONFIG.data.stimulus)
    triggers_fname = os.path.join(root_folder, CONFIG.data.triggers)
    spike_train_fname = os.path.join(output_folder, 'spike_train.npy')

    # output location
    saving_dir = os.path.join(output_folder, 'rf')

    rf = RF(stim_movie_file, triggers_fname, spike_train_fname, saving_dir)
    rf.calculate_STA()
    rf.detect_multi_rf()
    rf.classification()
def covariance(recordings, temporal_size, neigbor_steps):
    """Compute noise spatial and temporal covariance

    Parameters
    ----------
    recordings: matrix
        Multi-channel recordings (n observations x n channels)
    temporal_size:
        Forwarded unchanged to ``util.covariance``
    neigbor_steps: int
        Number of steps from the multi-channel geometry to consider
        two channels as neighbors

    Returns
    -------
    The value returned by ``util.covariance`` for the selected, optionally
    filtered and standardized channels
    """
    CONFIG = read_config()

    # neighbor flags for channels at most "neigbor_steps" steps apart
    reachable = n_steps_neigh_channels(CONFIG.neighChannels, neigbor_steps)

    # summing the flags gives the neighbor count of every channel; pick the
    # channel with the largest count
    # TODO: why are we selecting this one?
    neighbor_counts = np.sum(reachable, 0)
    channel = np.argmax(neighbor_counts)

    # indexes of the neighbors of that channel, ordered by distance
    (neighbor_ids, ) = np.where(reachable[channel])
    neighbor_ids, _ = order_channels_by_distance(channel, neighbor_ids,
                                                 CONFIG.geom)

    # keep only those channels
    # (this includes the channel with the most neighbors itself)
    rec = recordings[:, neighbor_ids]

    # filter recording
    if CONFIG.preprocess.filter == 1:
        rec = butterworth(rec,
                          CONFIG.filter.low_pass_freq,
                          CONFIG.filter.high_factor,
                          CONFIG.filter.order,
                          CONFIG.recordings.sampling_rate)

    # standardize recording
    sd_ = standarize.sd(rec, CONFIG.recordings.sampling_rate)
    rec = standarize.standarize(rec, sd_)

    # compute and return spatial and temporal covariance
    return util.covariance(rec, temporal_size, neigbor_steps,
                           CONFIG.spikeSize)
def test_can_make_clean(path_to_tests, path_to_standarized_data):
    """Smoke test: ``make_clean`` runs on uncropped templates.

    NOTE(review): ``spike_train`` is not assigned in this function; it is
    presumably a module-level name -- confirm.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    # give every spike a weight of one
    n_spikes, _ = spike_train.shape
    unit_weights = np.ones((n_spikes, 1), 'int32')
    weighted_spike_train = np.hstack((spike_train, unit_weights))

    templates, _ = get_templates(weighted_spike_train,
                                 path_to_standarized_data,
                                 CONFIG.resources.max_memory,
                                 4 * CONFIG.spike_size)
    templates = np.transpose(templates, (2, 1, 0))

    make_clean(templates, min_amp=2, max_amp=10, nk=100)
def test_can_use_detector_after_fit(path_to_config,
                                    path_to_sample_pipeline_folder,
                                    make_tmp_folder,
                                    path_to_standardized_data):
    """Fit a detector, then predict on a recording and on the training set.

    Passes as long as nothing raises; no output values are asserted.
    """
    yass.set_config(path_to_config, make_tmp_folder)
    CONFIG = yass.read_config()

    spike_train = np.load(path.join(path_to_sample_pipeline_folder,
                                    'spike_train.npy'))
    # every distinct value in the second column of the spike train
    chosen_templates = np.unique(spike_train[:, 1])

    # parameters for the synthetic training data
    min_amplitude, max_amplitude = 4, 60
    n_spikes_to_make = 100

    templates = make.load_templates(path_to_sample_pipeline_folder,
                                    spike_train, CONFIG, chosen_templates)

    path_to_standardized = path.join(path_to_sample_pipeline_folder,
                                     'preprocess', 'standarized.bin')

    training_sets = make.training_data(CONFIG, templates, min_amplitude,
                                       max_amplitude, n_spikes_to_make,
                                       path_to_standardized)
    (x_detect, y_detect, x_triage, y_triage, x_ae, y_ae) = training_sets

    _, waveform_length, n_neighbors = x_detect.shape
    path_to_model = path.join(make_tmp_folder, 'detect-net.ckpt')

    detector = NeuralNetDetector(path_to_model, [8, 4], waveform_length,
                                 n_neighbors, threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=10)
    detector.fit(x_detect, y_detect)

    data = RecordingExplorer(path_to_standardized_data).reader.data

    output_names = ('spike_index', 'waveform', 'probability')
    predictions = detector.predict_recording(data,
                                             output_names=output_names)
    (spike_index, waveform, proba) = predictions

    detector.predict(x_detect)
def test_can_crop_and_align_templates(path_to_tests,
                                      path_to_standarized_data):
    """Smoke test: ``crop_and_align_templates`` runs on uncropped templates.

    NOTE(review): ``spike_train`` is not assigned in this function; it is
    presumably a module-level name -- confirm.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    # give every spike a weight of one
    n_spikes, _ = spike_train.shape
    unit_weights = np.ones((n_spikes, 1), 'int32')
    weighted_spike_train = np.hstack((spike_train, unit_weights))

    templates, _ = get_templates(weighted_spike_train,
                                 path_to_standarized_data,
                                 CONFIG.resources.max_memory,
                                 4 * CONFIG.spike_size)
    templates = np.transpose(templates, (2, 1, 0))

    crop_and_align_templates(templates,
                             CONFIG.spike_size,
                             CONFIG.neigh_channels,
                             CONFIG.geom)
def test_deconvolution(patch_triage_network, path_to_config,
                       make_tmp_folder):
    """Run preprocess, detect and cluster, then deconvolve their output.

    Passes as long as nothing raises; no output values are asserted.
    """
    yass.set_config(path_to_config, make_tmp_folder)

    preprocessed = preprocess.run()
    (standarized_path, standarized_params, whiten_filter) = preprocessed

    spike_index_all = detect.run(standarized_path, standarized_params,
                                 whiten_filter)

    cluster.run(None, spike_index_all)

    # clustering results are read back from the output directory
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    path_to_spike_train_cluster = path.join(TMP_FOLDER,
                                            'spike_train_cluster.npy')
    spike_train_cluster = np.load(path_to_spike_train_cluster)

    path_to_templates_cluster = path.join(TMP_FOLDER,
                                          'templates_cluster.npy')
    templates_cluster = np.load(path_to_templates_cluster)

    spike_train, postdeconv_templates = deconvolve.run(spike_train_cluster,
                                                       templates_cluster)
def test_can_reload_detector(path_to_config, path_to_sample_pipeline_folder,
                             make_tmp_folder):
    """Fit a detector, then load it back from the checkpoint it wrote."""
    yass.set_config(path_to_config, make_tmp_folder)
    CONFIG = yass.read_config()

    spike_train = np.load(path.join(path_to_sample_pipeline_folder,
                                    'spike_train.npy'))
    # every distinct value in the second column of the spike train
    chosen_templates = np.unique(spike_train[:, 1])

    # parameters for the synthetic training data
    min_amplitude, max_amplitude = 4, 60
    n_spikes_to_make = 100

    templates = make.load_templates(path_to_sample_pipeline_folder,
                                    spike_train, CONFIG, chosen_templates)

    path_to_standarized = path.join(path_to_sample_pipeline_folder,
                                    'preprocess', 'standarized.bin')

    training_sets = make.training_data(CONFIG, templates, min_amplitude,
                                       max_amplitude, n_spikes_to_make,
                                       path_to_standarized)
    (x_detect, y_detect, x_triage, y_triage, x_ae, y_ae) = training_sets

    _, waveform_length, n_neighbors = x_detect.shape
    checkpoint = path.join(make_tmp_folder, 'detect-net.ckpt')

    fitted = NeuralNetDetector(checkpoint, [8, 4], waveform_length,
                               n_neighbors, threshold=0.5,
                               channel_index=CONFIG.channel_index,
                               n_iter=10)
    fitted.fit(x_detect, y_detect)

    # reloading must not raise
    NeuralNetDetector.load(checkpoint, threshold=0.5,
                           channel_index=CONFIG.channel_index)
def get_o_layer(standarized_path, standarized_params,
                output_directory='tmp/', output_dtype='float32',
                output_filename='o_layer.bin', if_file_exists='skip',
                save_partial_results=False):
    """Get the output of NN detector instead of outputting the spike index

    The result is written by ``BatchProcessor.multi_channel_apply`` to
    ``CONFIG.data.root_folder/output_directory/output_filename``.

    NOTE(review): ``if_file_exists`` and ``save_partial_results`` are
    accepted but never read in this function.

    Returns
    -------
    tuple
        ``(o_path, o_params)`` as returned by ``multi_channel_apply``
    """
    CONFIG = read_config()

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom, 1)

    x_tf = tf.placeholder("float", [None, None])

    # load Neural Net's
    detector_cfg = CONFIG.detect.neural_network_detector
    detection_fname = detector_cfg.filename
    detection_th = detector_cfg.threshold_spike

    NND = NeuralNetDetector(detection_fname)
    o_layer_tf = NND.make_o_layer_tf_tensors(x_tf, channel_index,
                                             detection_th)

    bp = BatchProcessor(standarized_path,
                        standarized_params['dtype'],
                        standarized_params['n_channels'],
                        standarized_params['data_format'],
                        CONFIG.resources.max_memory,
                        buffer_size=CONFIG.spike_size)

    # destination file
    _output_path = os.path.join(CONFIG.data.root_folder, output_directory,
                                output_filename)

    o_path, o_params = bp.multi_channel_apply(_get_o_layer,
                                              mode='disk',
                                              cleanup_function=fix_indexes,
                                              output_path=_output_path,
                                              cast_dtype=output_dtype,
                                              x_tf=x_tf,
                                              o_layer_tf=o_layer_tf,
                                              NND=NND)

    return o_path, o_params
def run_deduplication(batch_files_dir, output_directory):
    """Run ``run_deduplication_batch_simple`` once per batch file.

    One batch id is generated for every entry listed in ``batch_files_dir``.
    The batches go through ``parmap`` when
    ``CONFIG.resources.multi_processing`` is set and through a plain loop
    otherwise.
    """
    CONFIG = read_config()

    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)
    w = 5

    n_batch_files = len(os.listdir(batch_files_dir))
    batch_ids = list(np.arange(n_batch_files))

    if not CONFIG.resources.multi_processing:
        # serial path
        for batch_id in batch_ids:
            run_deduplication_batch_simple(batch_id,
                                           batch_files_dir,
                                           output_directory,
                                           neighbors,
                                           w)
        return

    parmap.map(run_deduplication_batch_simple,
               batch_ids,
               batch_files_dir,
               output_directory,
               neighbors,
               w,
               processes=CONFIG.resources.n_processors,
               pm_pbar=True)
def run(template_fname, spike_train_fname, shifts_fname, output_directory,
        residual_fname, residual_dtype):
    """Compute soft assignment and save it as ``soft_assignment.npy``.

    Nothing is recomputed when that file already exists in
    ``output_directory``. In both cases the path of the file is returned.
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    fname_out = os.path.join(output_directory, 'soft_assignment.npy')
    if os.path.exists(fname_out):
        return fname_out

    # output folder
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    # reader for residual
    reader_resid = READER(residual_fname,
                          residual_dtype,
                          CONFIG,
                          CONFIG.resources.n_sec_chunk_gpu_deconv)

    # load NN detector and move it to the gpu
    nn_cfg = CONFIG.neuralnetwork.detect
    detector = Detect(nn_cfg.n_filters,
                      CONFIG.spike_size_nn,
                      CONFIG.channel_index)
    detector.load(nn_cfg.filename)
    detector = detector.cuda()

    # initialize soft assignment calculator
    threshold = CONFIG.deconvolution.threshold / 0.1
    sna = SOFTNOISEASSIGNMENT(spike_train_fname,
                              template_fname,
                              shifts_fname,
                              reader_resid,
                              detector,
                              CONFIG.channel_index,
                              threshold)

    # compute soft assignment
    probs = sna.compute_soft_assignment()
    np.save(fname_out, probs)

    return fname_out
from pathlib import Path
import logging
from datetime import datetime

from memory_profiler import profile

import yass
from yass import deconvolute
import settings

if __name__ == '__main__':
    # memory-profile one deconvolution run using the files stored in
    # <CONFIG.data.root_folder>/profiling
    settings.run()

    start = datetime.now()
    logger = logging.getLogger(__name__)

    CONFIG = yass.read_config()

    logger.info('Deconvolution started at second: %.2f',
                (datetime.now() - start).total_seconds())

    profiling_dir = Path(CONFIG.data.root_folder, 'profiling')

    # inputs are handed over as plain strings
    spike_index_file = str(profiling_dir / 'spike_index_all.npy')
    templates_file = str(profiling_dir / 'templates.npy')

    profiled_run = profile(deconvolute.run)
    profiled_run(spike_index_file, templates_file,
                 output_directory='profiling')

    logger.info('Deconvolution finished at second: %.2f',
                (datetime.now() - start).total_seconds())
def run(config, logger_level='INFO', clean=False, output_dir='tmp/',
        complete=False, set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (if relative, it makes it relative to
        CONFIG.data.root_folder) to store the output data, defaults to tmp/.
        If absolute, it leaves it as it is.

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    set_zero_seed: bool, optional
        Accepted for backward compatibility; not read by this function

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standardized.bin`` - Standarized recordings (from preprocess)
    * ``standardized.yaml`` - Standarized recordings metadata (from preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)

    Returns
    -------
    numpy.ndarray
        Spike train
    """
    # load yass configuration parameters
    set_config(config, output_dir)
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file and point its file handler at TMP_FOLDER
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)

    # **********************************************
    # ******** SET ENVIRONMENT VARIABLES ***********
    # **********************************************
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["GIO_EXTRA_MODULES"] = "/usr/lib/x86_64-linux-gnu/gio/modules/"

    # **********************************************
    # ************** PREPROCESS ********************
    # **********************************************
    start = time.time()
    (standardized_path,
     standardized_params,
     whiten_filter) = preprocess.run(
         if_file_exists=CONFIG.preprocess.if_file_exists)
    time_preprocess = time.time() - start

    # **********************************************
    # ************** DETECT EVENTS *****************
    # **********************************************
    # Cat: This code now runs with open tensorflow calls
    start = time.time()
    spike_index_all = detect.run(standardized_path,
                                 standardized_params,
                                 whiten_filter,
                                 if_file_exists=CONFIG.detect.if_file_exists,
                                 save_results=CONFIG.detect.save_results)
    spike_index_clear = None
    time_detect = time.time() - start

    # **********************************************
    # ***************** CLUSTER ********************
    # **********************************************
    start = time.time()
    path_to_spike_train_cluster = path.join(TMP_FOLDER,
                                            'spike_train_cluster.npy')
    # clustering is skipped when its spike train is already on disk
    if not os.path.exists(path_to_spike_train_cluster):
        cluster.run(spike_index_clear, spike_index_all)
    else:
        print("\nClustering completed previously...\n\n")

    spike_train_cluster = np.load(path_to_spike_train_cluster)
    templates_cluster = np.load(
        os.path.join(TMP_FOLDER, 'templates_cluster.npy'))
    time_cluster = time.time() - start

    # **********************************************
    # ************** DECONVOLUTION *****************
    # **********************************************
    start = time.time()
    spike_train, postdeconv_templates = deconvolve.run(spike_train_cluster,
                                                       templates_cluster)
    time_deconvolution = time.time() - start

    # save spike train
    path_to_spike_train = path.join(TMP_FOLDER,
                                    'spike_train_post_deconv_post_merge.npy')
    np.save(path_to_spike_train, spike_train)
    logger.info('Spike train saved in: {}'.format(path_to_spike_train))

    # save template
    path_to_templates = path.join(TMP_FOLDER,
                                  'templates_post_deconv_post_merge.npy')
    np.save(path_to_templates, postdeconv_templates)
    logger.info('Templates saved in: {}'.format(path_to_templates))

    # **********************************************
    # ************** POST PROCESSING****************
    # **********************************************
    # save metadata in tmp (done once; the identical second save that used
    # to follow was redundant)
    path_to_metadata = path.join(TMP_FOLDER, 'metadata.yaml')
    logger.info('Saving metadata in {}'.format(path_to_metadata))
    save_metadata(path_to_metadata)

    # save config.yaml copy in tmp/: dump a mapping, copy a file
    path_to_config_copy = path.join(TMP_FOLDER, 'config.yaml')

    if isinstance(config, Mapping):
        with open(path_to_config_copy, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
    else:
        shutil.copy2(config, path_to_config_copy)

    logger.info('Saving copy of config: {} in {}'.format(
        config, path_to_config_copy))

    # this part loads waveforms for all spikes in the spike train and scores
    # them, this data is needed to later generate phy files
    if complete:
        STANDARIZED_PATH = path.join(TMP_FOLDER, 'standardized.bin')
        PARAMS = load_yaml(path.join(TMP_FOLDER, 'standardized.yaml'))

        # load waveforms for all spikes in the spike train
        logger.info('Loading waveforms from all spikes in the spike train...')
        explorer = RecordingExplorer(STANDARIZED_PATH,
                                     spike_size=CONFIG.spike_size,
                                     dtype=PARAMS['dtype'],
                                     n_channels=PARAMS['n_channels'],
                                     data_order=PARAMS['data_order'])
        waveforms = explorer.read_waveforms(spike_train[:, 0])

        path_to_waveforms = path.join(TMP_FOLDER, 'spike_train_waveforms.npy')
        np.save(path_to_waveforms, waveforms)
        logger.info('Saved all waveforms from the spike train in {}...'.format(
            path_to_waveforms))

        # score all waveforms
        logger.info('Scoring waveforms from all spikes in the spike train...')
        path_to_rotation = path.join(TMP_FOLDER, 'rotation.npy')
        rotation = np.load(path_to_rotation)

        main_channels = explorer.main_channel_for_waveforms(waveforms)
        path_to_main_channels = path.join(TMP_FOLDER,
                                          'waveforms_main_channel.npy')
        np.save(path_to_main_channels, main_channels)
        # report the file that was actually written
        logger.info('Saved all waveforms main channels in {}...'.format(
            path_to_main_channels))

        waveforms_score = dim_red.score(waveforms, rotation, main_channels,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_waveforms_score = path.join(TMP_FOLDER, 'waveforms_score.npy')
        np.save(path_to_waveforms_score, waveforms_score)
        logger.info('Saved all scores in {}...'.format(
            path_to_waveforms_score))

        # score templates
        # TODO: templates should be returned in the right shape to avoid .T
        # The original read an unassigned local ``templates`` here; the
        # post-deconvolution templates are the only templates this function
        # holds at this point.
        # NOTE(review): confirm their transpose is the layout that
        # ``main_channel_for_waveforms`` and ``dim_red.score`` expect.
        templates_ = postdeconv_templates.T
        main_channels_tmpls = explorer.main_channel_for_waveforms(templates_)
        path_to_templates_main_c = path.join(TMP_FOLDER,
                                             'templates_main_channel.npy')
        np.save(path_to_templates_main_c, main_channels_tmpls)
        logger.info('Saved all templates main channels in {}...'.format(
            path_to_templates_main_c))

        templates_score = dim_red.score(templates_, rotation,
                                        main_channels_tmpls,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_templates_score = path.join(TMP_FOLDER, 'templates_score.npy')
        np.save(path_to_templates_score, templates_score)
        logger.info('Saved all templates scores in {}...'.format(
            path_to_templates_score))

    logger.info('Finished YASS execution. Timing summary:')
    total = (time_preprocess + time_detect + time_cluster +
             time_deconvolution)
    logger.info('\t Preprocess: %s (%.2f %%)',
                human_readable_time(time_preprocess),
                time_preprocess / total * 100)
    logger.info('\t Detection: %s (%.2f %%)',
                human_readable_time(time_detect), time_detect / total * 100)
    logger.info('\t Clustering: %s (%.2f %%)',
                human_readable_time(time_cluster), time_cluster / total * 100)
    logger.info('\t Deconvolution: %s (%.2f %%)',
                human_readable_time(time_deconvolution),
                time_deconvolution / total * 100)

    return spike_train
def run(config, logger_level='INFO', clean=False, output_dir='tmp/'):
    """Train the neural network detector and denoiser

    Preprocesses the recordings, obtains a spike train (either by running
    detection, clustering and post-processing on the data, or by taking the
    one given in the configuration), builds training data from it and
    trains the ``Detect`` and ``Denoise`` networks.

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete the output directory before running

    output_dir: str, optional
        Output directory (if relative, it makes it relative to
        CONFIG.data.root_folder) to store the output data, defaults to tmp/.
        If absolute, it leaves it as it is.

    Returns
    -------
    None
        Nothing is returned, results are written to disk

    Notes
    -----
    The following configuration values are overridden before running:
    ``cluster.min_fr``, ``clean_up.mad.min_var_gap``,
    ``clean_up.mad.max_violations``, ``neuralnetwork.apply_nn`` and
    ``detect.threshold``.

    The environment variables ``OPENBLAS_NUM_THREADS``, ``MKL_NUM_THREADS``
    and ``GIO_EXTRA_MODULES`` are set for the current process.

    Files written by this function inside the output directory:

    * ``yass.log`` - Log file
    * ``preprocess/`` - Output folder passed to the preprocess step
    * ``nn_train/detect/``, ``nn_train/cluster/``,
      ``nn_train/cluster_post_process/`` - Output folders passed to the
      intermediate steps (only when no input spike train is configured)
    * ``nn_train/augment/`` - Output folder passed to the augment step
    * ``nn_train/detect.pt`` - Path passed to the detector training
    * ``nn_train/denoise.pt`` - Path passed to the denoiser training

    Both models are moved to the GPU with ``.cuda()`` before training
    (presumably this requires a CUDA device - confirm).
    """
    # load yass configuration parameters and force the values used for
    # training
    CONFIG = Config.from_yaml(config)
    CONFIG._data['cluster']['min_fr'] = 1
    CONFIG._data['clean_up']['mad']['min_var_gap'] = 1.5
    CONFIG._data['clean_up']['mad']['max_violations'] = 5
    CONFIG._data['neuralnetwork']['apply_nn'] = False
    CONFIG._data['detect']['threshold'] = 4

    set_config(CONFIG._data, output_dir)
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = os.path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)

    # **********************************************
    # ******** SET ENVIRONMENT VARIABLES ***********
    # **********************************************
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["GIO_EXTRA_MODULES"] = "/usr/lib/x86_64-linux-gnu/gio/modules/"

    # TODO: if input spike train is None, run yass with threshold detector

    # **********************************************
    # ************** PREPROCESS ********************
    # **********************************************
    (standardized_path, standardized_dtype) = preprocess.run(
        os.path.join(TMP_FOLDER, 'preprocess'))

    # everything else is stored in a dedicated sub folder
    TMP_FOLDER = os.path.join(TMP_FOLDER, 'nn_train')
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    if CONFIG.neuralnetwork.training.input_spike_train_filname is None:

        # run on 10 minutes of data
        # NOTE(review): the second term is a number of samples
        # (sampling_rate * 600) but the result is passed as
        # ``run_chunk_sec`` - confirm the expected units
        rec_len = np.min(
            (CONFIG.rec_len, CONFIG.recordings.sampling_rate * 10 * 60))

        # detect
        logger.info('DETECTION')
        spike_index_path = detect.run(standardized_path,
                                      standardized_dtype,
                                      os.path.join(TMP_FOLDER, 'detect'),
                                      run_chunk_sec=[0, rec_len])

        # cluster
        # NOTE(review): an unused local ``full_run = False`` used to precede
        # this call while the call itself passes True; confirm which value
        # is intended
        logger.info('CLUSTERING')
        fname_templates, fname_spike_train = cluster.run(
            os.path.join(TMP_FOLDER, 'cluster'),
            standardized_path,
            standardized_dtype,
            fname_spike_index=spike_index_path,
            raw_data=True,
            full_run=True)

        # post process the clustering output
        methods = [
            'off_center', 'low_ptp', 'high_mad', 'duplicate', 'duplicate_l2'
        ]
        fname_templates, fname_spike_train = postprocess.run(
            methods,
            os.path.join(TMP_FOLDER, 'cluster_post_process'),
            standardized_path,
            standardized_dtype,
            fname_templates,
            fname_spike_train)

    else:
        # if there is an input spike train, use it
        fname_spike_train = \
            CONFIG.neuralnetwork.training.input_spike_train_filname

    # Get training data maker
    DetectTD, DenoTD = augment.run(standardized_path,
                                   standardized_dtype,
                                   fname_spike_train,
                                   os.path.join(TMP_FOLDER, 'augment'))

    # Train Detector
    detector = Detect(CONFIG.neuralnetwork.detect.n_filters,
                      CONFIG.spike_size_nn,
                      CONFIG.channel_index).cuda()
    detector.train(os.path.join(TMP_FOLDER, 'detect.pt'), DetectTD)

    # Train Denoiser
    denoiser = Denoise(CONFIG.neuralnetwork.denoise.n_filters,
                       CONFIG.neuralnetwork.denoise.filter_sizes,
                       CONFIG.spike_size_nn).cuda()
    denoiser.train(os.path.join(TMP_FOLDER, 'denoise.pt'), DenoTD)
def run(output_directory, fname_spike_train, fname_shifts, fname_scales,
        fname_templates, fname_soft_assignment, fname_residual,
        residual_dtype):
    """Merge units and save the updated results in output_directory

    Parameters
    ----------
    output_directory: str
        Folder where results are stored, created if it does not exist

    fname_spike_train, fname_shifts, fname_scales, fname_templates,
    fname_soft_assignment: str
        Paths forwarded to ``TemplateMerge``

    fname_residual: str
        Path to the residual recording, passed to ``READER``

    residual_dtype: str
        dtype of the residual recording, passed to ``READER``

    Returns
    -------
    tuple
        Paths to the saved templates, spike train, shifts, scales and soft
        assignment (in that order)

    Notes
    -----
    If every output file already exists, the computation is skipped and the
    existing paths are returned. The following files are written to
    output_directory: ``spike_train.npy``, ``templates.npy``,
    ``soft_assignment.npy``, ``shifts.npy``, ``scales.npy``,
    ``merge_array.npy``, ``spatial_cov.npy`` and ``temporal_cov.npy``
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    # output folder
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    fname_spike_train_out = os.path.join(output_directory,
                                         'spike_train.npy')
    fname_templates_out = os.path.join(output_directory,
                                       'templates.npy')
    fname_soft_assignment_out = os.path.join(output_directory,
                                             'soft_assignment.npy')
    fname_shifts_out = os.path.join(output_directory,
                                    'shifts.npy')
    fname_scales_out = os.path.join(output_directory,
                                    'scales.npy')

    outputs = (fname_templates_out,
               fname_spike_train_out,
               fname_shifts_out,
               fname_scales_out,
               fname_soft_assignment_out)

    # only skip the computation when *every* returned file is on disk,
    # otherwise an interrupted run would hand back paths to missing files
    if all(os.path.exists(fname) for fname in outputs):
        return outputs

    reader_residual = READER(fname_residual,
                             residual_dtype,
                             CONFIG)

    # get whitening filters, TemplateMerge reads them from disk so they
    # only need to be computed when the files are missing
    fname_spatial_cov = os.path.join(output_directory, 'spatial_cov.npy')
    fname_temporal_cov = os.path.join(output_directory, 'temporal_cov.npy')
    if not (os.path.exists(fname_spatial_cov) and
            os.path.exists(fname_temporal_cov)):
        spatial_cov, temporal_cov = get_noise_covariance(
            reader_residual, CONFIG)
        np.save(fname_spatial_cov, spatial_cov)
        np.save(fname_temporal_cov, temporal_cov)

    # initialize merge: find candidates
    logger.info("finding merge candidates")
    tm = TemplateMerge(output_directory,
                       reader_residual,
                       fname_templates,
                       fname_spike_train,
                       fname_shifts,
                       fname_scales,
                       fname_soft_assignment,
                       fname_spatial_cov,
                       fname_temporal_cov,
                       CONFIG.geom,
                       CONFIG.resources.multi_processing,
                       CONFIG.resources.n_processors)

    # find merge pairs
    logger.info("merging pairs")
    tm.get_merge_pairs()

    # update templates and spike train accordingly
    logger.info("udpating templates and spike train")
    (templates_new,
     spike_train_new,
     shifts_new,
     scales_new,
     soft_assignment_new,
     merge_array) = tm.merge_units()

    # save results
    fname_merge_array = os.path.join(output_directory, 'merge_array.npy')
    np.save(fname_merge_array, merge_array)
    np.save(fname_spike_train_out, spike_train_new)
    np.save(fname_templates_out, templates_new)
    np.save(fname_shifts_out, shifts_new)
    np.save(fname_scales_out, scales_new)
    np.save(fname_soft_assignment_out, soft_assignment_new)

    logger.info('Number of units after merge: {}'.format(
        templates_new.shape[0]))

    return outputs
def _run_pipeline(config, output_file, logger_level='INFO', clean=True,
                  output_dir='tmp/', complete=False):
    """
    Run the entire pipeline given a path to a config file
    and output path

    Parameters
    ----------
    config: str
        Path to the configuration file, it is passed to ``set_config`` and
        copied to the output folder

    output_file: str
        Name of the file (inside the output folder) where the spike train
        is saved

    logger_level: str, optional
        Level for the root logger, defaults to 'INFO'

    clean: bool, optional
        Delete the output folder before running, defaults to True

    output_dir: str, optional
        Output folder, relative to CONFIG.data.root_folder, defaults to
        tmp/

    complete: bool, optional
        If True, also load and score the waveforms for all spikes in the
        spike train and score the templates (this data is needed to later
        generate phy files)

    Notes
    -----
    Nothing is returned, results are saved in the output folder:
    ``yass.log``, ``metadata.yaml``, ``config.yaml``, ``templates.npy`` and
    the spike train (``output_file``). When ``complete`` is True, the
    following are saved as well: ``spike_train_waveforms.npy``,
    ``waveforms_main_channel.npy``, ``waveforms_score.npy``,
    ``templates_main_channel.npy`` and ``templates_score.npy``
    """
    # load yass configuration parameters
    set_config(config)
    CONFIG = read_config()
    ROOT_FOLDER = CONFIG.data.root_folder
    TMP_FOLDER = path.join(ROOT_FOLDER, output_dir)

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = path.join(TMP_FOLDER,
                                                               'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # run preprocessor
    (score, spike_index_clear,
     spike_index_collision) = preprocess.run(output_directory=output_dir)

    # run processor
    (spike_train_clear, templates,
     spike_index_collision) = process.run(score, spike_index_clear,
                                          spike_index_collision,
                                          output_directory=output_dir)

    # run deconvolution
    spike_train = deconvolute.run(spike_train_clear, templates,
                                  spike_index_collision,
                                  output_directory=output_dir)

    # save metadata in tmp
    path_to_metadata = path.join(TMP_FOLDER, 'metadata.yaml')
    logger.info('Saving metadata in {}'.format(path_to_metadata))
    save_metadata(path_to_metadata)

    # save config.yaml copy in tmp/
    path_to_config_copy = path.join(TMP_FOLDER, 'config.yaml')
    shutil.copy2(config, path_to_config_copy)
    logger.info('Saving copy of config: {} in {}'.format(
        config, path_to_config_copy))

    # save templates
    path_to_templates = path.join(TMP_FOLDER, 'templates.npy')
    logger.info('Saving templates in {}'.format(path_to_templates))
    np.save(path_to_templates, templates)

    # save spike train
    path_to_spike_train = path.join(TMP_FOLDER, output_file)
    np.save(path_to_spike_train, spike_train)
    logger.info('Spike train saved in: {}'.format(path_to_spike_train))

    # this part loads waveforms for all spikes in the spike train and scores
    # them, this data is needed to later generate phy files
    if complete:
        STANDARIZED_PATH = path.join(TMP_FOLDER, 'standarized.bin')
        PARAMS = load_yaml(path.join(TMP_FOLDER, 'standarized.yaml'))

        # load waveforms for all spikes in the spike train
        logger.info('Loading waveforms from all spikes in the spike train...')
        explorer = RecordingExplorer(STANDARIZED_PATH,
                                     spike_size=CONFIG.spikeSize,
                                     dtype=PARAMS['dtype'],
                                     n_channels=PARAMS['n_channels'],
                                     data_format=PARAMS['data_format'])
        waveforms = explorer.read_waveforms(spike_train[:, 0])

        path_to_waveforms = path.join(TMP_FOLDER, 'spike_train_waveforms.npy')
        np.save(path_to_waveforms, waveforms)
        logger.info('Saved all waveforms from the spike train in {}...'.format(
            path_to_waveforms))

        # score all waveforms
        logger.info('Scoring waveforms from all spikes in the spike train...')
        path_to_rotation = path.join(TMP_FOLDER, 'rotation.npy')
        rotation = np.load(path_to_rotation)

        main_channels = explorer.main_channel_for_waveforms(waveforms)
        path_to_main_channels = path.join(TMP_FOLDER,
                                          'waveforms_main_channel.npy')
        np.save(path_to_main_channels, main_channels)
        # report the file that was actually written
        logger.info('Saved all waveforms main channels in {}...'.format(
            path_to_main_channels))

        waveforms_score = dim_red.score(waveforms, rotation, main_channels,
                                        CONFIG.neighChannels, CONFIG.geom)
        path_to_waveforms_score = path.join(TMP_FOLDER, 'waveforms_score.npy')
        np.save(path_to_waveforms_score, waveforms_score)
        logger.info('Saved all scores in {}...'.format(
            path_to_waveforms_score))

        # score templates
        # TODO: templates should be returned in the right shape to avoid .T
        templates_ = templates.T
        main_channels_tmpls = explorer.main_channel_for_waveforms(templates_)
        path_to_templates_main_c = path.join(TMP_FOLDER,
                                             'templates_main_channel.npy')
        np.save(path_to_templates_main_c, main_channels_tmpls)
        logger.info('Saved all templates main channels in {}...'.format(
            path_to_templates_main_c))

        templates_score = dim_red.score(templates_, rotation,
                                        main_channels_tmpls,
                                        CONFIG.neighChannels, CONFIG.geom)
        path_to_templates_score = path.join(TMP_FOLDER, 'templates_score.npy')
        np.save(path_to_templates_score, templates_score)
        logger.info('Saved all templates scores in {}...'.format(
            path_to_templates_score))
def _neural_network_detection(standarized_path, standarized_params,
                              n_observations, output_directory):
    """Run neural network detection and autoencoder dimensionality reduction

    Parameters
    ----------
    standarized_path: str
        Path to the standarized recordings

    standarized_params: dict
        Parameters for the standarized recordings, the keys 'dtype',
        'n_channels' and 'data_format' are used

    n_observations: int
        Number of observations in the recordings, used to remove spikes for
        which a complete waveform cannot be drawn

    output_directory: str
        Location to store partial results, relative to
        CONFIG.data.root_folder

    Returns
    -------
    scores, clear, collision
        Scores for the clear spikes, spike index for clear spikes and spike
        index for collided spikes

    Notes
    -----
    If ``score_clear.npy``, ``spike_index_clear.npy`` and
    ``spike_index_collision.npy`` all exist in the output folder, they are
    loaded and returned without running the detector. The files on disk
    always match the returned arrays.
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    OUTPUT_DTYPE = CONFIG.preprocess.dtype
    TMP_FOLDER = os.path.join(CONFIG.data.root_folder, output_directory)

    # detect spikes
    bp = BatchProcessor(standarized_path, standarized_params['dtype'],
                        standarized_params['n_channels'],
                        standarized_params['data_format'],
                        CONFIG.resources.max_memory,
                        buffer_size=0)

    # check if all scores, clear and collision spikes exist..
    path_to_score = os.path.join(TMP_FOLDER, 'score_clear.npy')
    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')
    path_to_spike_index_collision = os.path.join(TMP_FOLDER,
                                                 'spike_index_collision.npy')

    if all([os.path.exists(path_to_score),
            os.path.exists(path_to_spike_index_clear),
            os.path.exists(path_to_spike_index_collision)]):
        logger.info('Loading "{}", "{}" and "{}"'.format(
            path_to_score, path_to_spike_index_clear,
            path_to_spike_index_collision))

        scores = np.load(path_to_score)
        clear = np.load(path_to_spike_index_clear)
        collision = np.load(path_to_spike_index_collision)

        return scores, clear, collision

    logger.info('One or more of "{}", "{}" or "{}" files were missing, '
                'computing...'.format(path_to_score,
                                      path_to_spike_index_clear,
                                      path_to_spike_index_collision))

    # apply neural network detector on standarized data
    autoencoder_filename = CONFIG.neural_network_autoencoder.filename
    mc = bp.multi_channel_apply
    res = mc(
        neuralnetwork.nn_detection,
        mode='memory',
        cleanup_function=neuralnetwork.fix_indexes,
        neighbors=CONFIG.neighChannels,
        geom=CONFIG.geom,
        temporal_features=CONFIG.spikes.temporal_features,
        # FIXME: what is this?
        temporal_window=3,
        th_detect=CONFIG.neural_network_detector.threshold_spike,
        th_triage=CONFIG.neural_network_triage.threshold_collision,
        detector_filename=CONFIG.neural_network_detector.filename,
        autoencoder_filename=autoencoder_filename,
        triage_filename=CONFIG.neural_network_triage.filename)

    # save clear spikes
    clear = np.concatenate([element[1] for element in res], axis=0)
    logger.info('Removing clear indexes outside the allowed range to '
                'draw a complete waveform...')
    clear, idx = detect.remove_incomplete_waveforms(
        clear, CONFIG.spikeSize + CONFIG.templatesMaxShift,
        n_observations)
    np.save(path_to_spike_index_clear, clear)
    logger.info('Saved spike index clear in {}...'.format(
        path_to_spike_index_clear))

    # save collided spikes
    collision = np.concatenate([element[2] for element in res], axis=0)
    logger.info('Removing collision indexes outside the allowed range to '
                'draw a complete waveform...')
    collision, _ = detect.remove_incomplete_waveforms(
        collision, CONFIG.spikeSize + CONFIG.templatesMaxShift,
        n_observations)
    np.save(path_to_spike_index_collision, collision)
    logger.info('Saved spike index collision in {}...'.format(
        path_to_spike_index_collision))

    if CONFIG.clustering.clustering_method == 'location':
        #######################
        # Waveform extraction #
        #######################

        # TODO: what should the behaviour be for spike indexes that are
        # when starting/ending the recordings and it is not possible to
        # draw a complete waveform?
        logger.info('Computing whitening matrix...')
        bp = BatchProcessor(standarized_path,
                            standarized_params['dtype'],
                            standarized_params['n_channels'],
                            standarized_params['data_format'],
                            CONFIG.resources.max_memory)
        batches = bp.multi_channel()
        first_batch, _, _ = next(batches)
        Q = whiten.matrix(first_batch, CONFIG.neighChannels,
                          CONFIG.spikeSize)

        path_to_whitening_matrix = os.path.join(TMP_FOLDER, 'whitening.npy')
        np.save(path_to_whitening_matrix, Q)
        logger.info('Saved whitening matrix in {}'.format(
            path_to_whitening_matrix))

        # apply whitening to every batch
        whitened_path, _ = bp.multi_channel_apply(
            np.matmul,
            mode='disk',
            output_path=os.path.join(TMP_FOLDER, 'whitened.bin'),
            if_file_exists='skip',
            cast_dtype=OUTPUT_DTYPE,
            b=Q)

        # load and dump waveforms from clear spikes
        path_to_waveforms_clear = os.path.join(TMP_FOLDER,
                                               'waveforms_clear.npy')

        if os.path.exists(path_to_waveforms_clear):
            logger.info(
                'Found clear waveforms in {}, loading them...'.format(
                    path_to_waveforms_clear))
            waveforms_clear = np.load(path_to_waveforms_clear)
        else:
            logger.info(
                'Did not find clear waveforms in {}, reading them from {}'.
                format(path_to_waveforms_clear, whitened_path))
            explorer = RecordingExplorer(whitened_path,
                                         spike_size=CONFIG.spikeSize)
            waveforms_clear = explorer.read_waveforms(clear[:, 0], 'all')
            np.save(path_to_waveforms_clear, waveforms_clear)
            logger.info('Saved waveform from clear spikes in: {}'.format(
                path_to_waveforms_clear))

        main_channel = clear[:, 1]

        # save rotation
        detector_filename = CONFIG.neural_network_detector.filename
        autoencoder_filename = CONFIG.neural_network_autoencoder.filename
        rotation = neuralnetwork.load_rotation(detector_filename,
                                               autoencoder_filename)
        path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')
        logger.info("rotation_matrix_shape = {}".format(rotation.shape))
        np.save(path_to_rotation, rotation)
        logger.info(
            'Saved rotation matrix in {}...'.format(path_to_rotation))

        logger.info('Denoising...')
        path_to_denoised_waveforms = os.path.join(
            TMP_FOLDER, 'denoised_waveforms.npy')
        if os.path.exists(path_to_denoised_waveforms):
            logger.info(
                'Found denoised waveforms in {}, loading them...'.format(
                    path_to_denoised_waveforms))
            denoised_waveforms = np.load(path_to_denoised_waveforms)
        else:
            logger.info(
                'Did not find denoised waveforms in {}, evaluating them'
                'from {}'.format(path_to_denoised_waveforms,
                                 path_to_waveforms_clear))
            # waveforms_clear is already in memory at this point
            denoised_waveforms = dim_red.denoise(waveforms_clear, rotation,
                                                 CONFIG)
            logger.info('Saving denoised waveforms to {}'.format(
                path_to_denoised_waveforms))
            np.save(path_to_denoised_waveforms, denoised_waveforms)

        isolated_index, x, y = get_isolated_spikes_and_locations(
            denoised_waveforms, main_channel, CONFIG)
        x = (x - np.mean(x)) / np.std(x)
        y = (y - np.mean(y)) / np.std(y)

        # clear spikes that are not isolated become collided spikes
        corrupted_index = np.logical_not(
            np.in1d(np.arange(clear.shape[0]), isolated_index))
        collision = np.concatenate([collision, clear[corrupted_index]],
                                   axis=0)
        clear = clear[isolated_index]
        waveforms_clear = waveforms_clear[isolated_index]

        # the spike indexes saved above were computed before keeping only
        # isolated spikes, overwrite them so the files match the scores
        # saved below (otherwise a later run would load arrays that do not
        # correspond to each other)
        np.save(path_to_spike_index_clear, clear)
        np.save(path_to_spike_index_collision, collision)

        #################################################
        # Dimensionality reduction (Isolated Waveforms) #
        #################################################
        scores = dim_red.main_channel_scores(waveforms_clear, rotation,
                                             clear, CONFIG)
        scores = (scores - np.mean(scores, axis=0)) / np.std(scores)
        scores = np.concatenate([
            x[:, np.newaxis, np.newaxis],
            y[:, np.newaxis, np.newaxis],
            scores[:, :, np.newaxis]
        ], axis=1)

    else:
        # scores computed by the detector
        scores = np.concatenate([element[0] for element in res], axis=0)
        logger.info(
            'Removing scores for indexes outside the allowed range to '
            'draw a complete waveform...')
        scores = scores[idx]

        # compute Q for whitening
        logger.info('Computing whitening matrix...')
        bp = BatchProcessor(standarized_path,
                            standarized_params['dtype'],
                            standarized_params['n_channels'],
                            standarized_params['data_format'],
                            CONFIG.resources.max_memory)
        batches = bp.multi_channel()
        first_batch, _, _ = next(batches)
        Q = whiten.matrix_localized(first_batch, CONFIG.neighChannels,
                                    CONFIG.geom, CONFIG.spikeSize)

        path_to_whitening_matrix = os.path.join(TMP_FOLDER, 'whitening.npy')
        np.save(path_to_whitening_matrix, Q)
        logger.info('Saved whitening matrix in {}'.format(
            path_to_whitening_matrix))

        scores = whiten.score(scores, clear[:, 1], Q)

        # save rotation
        detector_filename = CONFIG.neural_network_detector.filename
        autoencoder_filename = CONFIG.neural_network_autoencoder.filename
        rotation = neuralnetwork.load_rotation(detector_filename,
                                               autoencoder_filename)
        path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')
        np.save(path_to_rotation, rotation)
        logger.info(
            'Saved rotation matrix in {}...'.format(path_to_rotation))

    # save scores (once, for both clustering methods)
    np.save(path_to_score, scores)
    logger.info('Saved spike scores in {}...'.format(path_to_score))

    return scores, clear, collision
def run(output_directory='tmp/'):
    """Execute preprocessing pipeline

    Parameters
    ----------
    output_directory: str, optional
      Location to store partial results, relative to CONFIG.data.root_folder,
      defaults to tmp/

    Returns
    -------
    clear_scores: numpy.ndarray (n_spikes, n_features, n_channels)
        3D array with the scores for the clear spikes, first dimension is
        the number of spikes, second is the number of features and third the
        number of channels

    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        2D array with indexes for clear spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        2D array with indexes for collided spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    Raises
    ------
    ValueError
        If CONFIG.spikes.detection is neither 'threshold' nor 'nn'

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings
    * ``filtered.yaml`` - Filtered recordings metadata
    * ``standarized.bin`` - Standarized recordings
    * ``standarized.yaml`` - Standarized recordings metadata
    * ``whitened.bin`` - Whitened recordings
    * ``whitened.yaml`` - Whitened recordings metadata
    * ``rotation.npy`` - Rotation matrix for dimensionality reduction
    * ``spike_index_clear.npy`` - Same as spike_index_clear returned
    * ``spike_index_collision.npy`` - Same as spike_index_collision returned
    * ``score_clear.npy`` - Scores for clear spikes
    * ``waveforms_clear.npy`` - Waveforms for clear spikes

    Examples
    --------

    .. literalinclude:: ../examples/preprocess.py
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    OUTPUT_DTYPE = CONFIG.preprocess.dtype

    logger.info(
        'Output dtype for transformed data will be {}'.format(OUTPUT_DTYPE))

    TMP = os.path.join(CONFIG.data.root_folder, output_directory)

    if not os.path.exists(TMP):
        logger.info('Creating temporary folder: {}'.format(TMP))
        os.makedirs(TMP)
    else:
        logger.info('Temporary folder {} already exists, output will be '
                    'stored there'.format(TMP))

    # named so that it does not shadow a ``path`` module imported at the
    # top of the file
    path_to_recordings = os.path.join(CONFIG.data.root_folder,
                                      CONFIG.data.recordings)
    dtype = CONFIG.recordings.dtype

    # initialize pipeline object, one batch per channel
    pipeline = BatchPipeline(path_to_recordings, dtype,
                             CONFIG.recordings.n_channels,
                             CONFIG.recordings.format,
                             CONFIG.resources.max_memory, TMP)

    # add filter transformation if necessary
    if CONFIG.preprocess.filter:
        filter_op = Transform(butterworth,
                              'filtered.bin',
                              mode='single_channel_one_batch',
                              keep=True,
                              if_file_exists='skip',
                              cast_dtype=OUTPUT_DTYPE,
                              low_freq=CONFIG.filter.low_pass_freq,
                              high_factor=CONFIG.filter.high_factor,
                              order=CONFIG.filter.order,
                              sampling_freq=CONFIG.recordings.sampling_rate)

        pipeline.add([filter_op])

    # NOTE(review): the unpacking expects exactly one output from the
    # pipeline, confirm what pipeline.run() returns when the filter is
    # disabled and no transformation was added
    (filtered_path, ), (filtered_params, ) = pipeline.run()

    # standarize, the standard deviation is estimated on the first batch
    bp = BatchProcessor(filtered_path,
                        filtered_params['dtype'],
                        filtered_params['n_channels'],
                        filtered_params['data_format'],
                        CONFIG.resources.max_memory)
    batches = bp.multi_channel()
    first_batch, _, _ = next(batches)
    sd = standard_deviation(first_batch, CONFIG.recordings.sampling_rate)

    (standarized_path, standarized_params) = bp.multi_channel_apply(
        standarize,
        mode='disk',
        output_path=os.path.join(TMP, 'standarized.bin'),
        if_file_exists='skip',
        cast_dtype=OUTPUT_DTYPE,
        sd=sd)

    standarized = RecordingsReader(standarized_path)
    n_observations = standarized.observations

    if CONFIG.spikes.detection == 'threshold':
        return _threshold_detection(standarized_path,
                                    standarized_params,
                                    n_observations,
                                    output_directory)
    elif CONFIG.spikes.detection == 'nn':
        return _neural_network_detection(standarized_path,
                                         standarized_params,
                                         n_observations,
                                         output_directory)

    # fail with a clear message instead of silently returning None
    raise ValueError('Unknown spike detection method "{}", valid options '
                     'are "threshold" and "nn"'.format(
                         CONFIG.spikes.detection))
def _threshold_detection(standarized_path, standarized_params,
                         n_observations, output_directory):
    """Run threshold detector and dimensionality reduction using PCA

    Whitens the standarized recordings, detects spikes with the threshold
    detector, extracts the waveforms for those spikes, computes the PCA
    rotation matrix and uses it to score the waveforms. Spike index,
    collision index, waveforms and denoised waveforms are loaded from the
    output folder when the corresponding file already exists.

    Returns
    -------
    scores, spike_index_clear, spike_index_collision
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    OUTPUT_DTYPE = CONFIG.preprocess.dtype
    TMP_FOLDER = os.path.join(CONFIG.data.root_folder, output_directory)

    spike_size = CONFIG.spikeSize
    neighbors = CONFIG.neighChannels
    # arguments shared by both batch processors
    reader_args = (standarized_path,
                   standarized_params['dtype'],
                   standarized_params['n_channels'],
                   standarized_params['data_format'],
                   CONFIG.resources.max_memory)

    # ---------------------------------------------------------------
    # Whitening: estimate Q on the first batch, then apply everywhere
    # ---------------------------------------------------------------
    logger.info('Computing whitening matrix...')
    bp = BatchProcessor(*reader_args)
    first_batch, _, _ = next(bp.multi_channel())
    Q = whiten.matrix(first_batch, neighbors, spike_size)

    path_to_whitening_matrix = os.path.join(TMP_FOLDER, 'whitening.npy')
    np.save(path_to_whitening_matrix, Q)
    logger.info(
        'Saved whitening matrix in {}'.format(path_to_whitening_matrix))

    (whitened_path, whitened_params) = bp.multi_channel_apply(
        np.matmul,
        mode='disk',
        output_path=os.path.join(TMP_FOLDER, 'whitened.bin'),
        if_file_exists='skip',
        cast_dtype=OUTPUT_DTYPE,
        b=Q)

    # ---------------------------------------------------------------
    # Spike detection (clear spikes)
    # ---------------------------------------------------------------
    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')
    bp = BatchProcessor(*reader_args, buffer_size=0)

    if not os.path.exists(path_to_spike_index_clear):
        # nothing on disk, run the threshold detector on standarized data
        logger.info('Did not find file in {}, finding spikes using threshold'
                    ' detector...'.format(path_to_spike_index_clear))
        detected = bp.multi_channel_apply(detect.threshold,
                                          mode='memory',
                                          cleanup_function=detect.fix_indexes,
                                          neighbors=neighbors,
                                          spike_size=spike_size,
                                          std_factor=CONFIG.stdFactor)
        spike_index_clear = np.vstack(detected)

        logger.info('Removing clear indexes outside the allowed range to '
                    'draw a complete waveform...')
        spike_index_clear, _ = detect.remove_incomplete_waveforms(
            spike_index_clear,
            spike_size + CONFIG.templatesMaxShift,
            n_observations)

        logger.info('Saving spikes in {}...'.format(path_to_spike_index_clear))
        np.save(path_to_spike_index_clear, spike_index_clear)
    else:
        # reuse the previous result
        logger.info('Found file in {}, loading it...'.format(
            path_to_spike_index_clear))
        spike_index_clear = np.load(path_to_spike_index_clear)

    # ---------------------------------------------------------------
    # Collided spikes (always empty for the threshold detector)
    # ---------------------------------------------------------------
    path_to_spike_index_collision = os.path.join(TMP_FOLDER,
                                                 'spike_index_collision.npy')

    if not os.path.exists(path_to_spike_index_collision):
        # triage is not implemented on threshold detector, return empty array
        logger.info('Creating empty array for'
                    ' collided spikes (collision detection is not implemented'
                    ' with threshold detector. Saving them in {}'.format(
                        path_to_spike_index_collision))
        spike_index_collision = np.zeros((0, 2), 'int32')
        np.save(path_to_spike_index_collision, spike_index_collision)
    else:
        logger.info('Found collided spikes in {}, loading them...'.format(
            path_to_spike_index_collision))
        spike_index_collision = np.load(path_to_spike_index_collision)

        if spike_index_collision.shape[0] != 0:
            raise ValueError('Found non-empty collision spike index in {}, '
                             'but threshold detector is selected, collision '
                             'detection is not implemented for threshold '
                             'detector so array must have dimensios (0, 2) '
                             'but had ({}, {})'.format(
                                 path_to_spike_index_collision,
                                 *spike_index_collision.shape))

    # ---------------------------------------------------------------
    # Waveform extraction for clear spikes
    # ---------------------------------------------------------------
    path_to_waveforms_clear = os.path.join(TMP_FOLDER, 'waveforms_clear.npy')

    if not os.path.exists(path_to_waveforms_clear):
        logger.info(
            'Did not find clear waveforms in {}, reading them from {}'.format(
                path_to_waveforms_clear, standarized_path))
        explorer = RecordingExplorer(standarized_path, spike_size=spike_size)
        waveforms_clear = explorer.read_waveforms(spike_index_clear[:, 0])
        np.save(path_to_waveforms_clear, waveforms_clear)
        logger.info('Saved waveform from clear spikes in: {}'.format(
            path_to_waveforms_clear))
    else:
        logger.info('Found clear waveforms in {}, loading them...'.format(
            path_to_waveforms_clear))
        waveforms_clear = np.load(path_to_waveforms_clear)

    # ---------------------------------------------------------------
    # PCA: rotation matrix from per-batch sufficient statistics
    # ---------------------------------------------------------------
    logger.info('Computing PCA sufficient statistics...')
    per_batch = bp.multi_channel_apply(dim_red.suff_stat,
                                       mode='memory',
                                       spike_index=spike_index_clear,
                                       spike_size=spike_size)
    # add up the contribution of every batch
    suff_stats = reduce(np.add, [batch[0] for batch in per_batch])
    spikes_per_channel = reduce(np.add, [batch[1] for batch in per_batch])

    logger.info('Computing PCA projection matrix...')
    rotation = dim_red.project(suff_stats, spikes_per_channel,
                               CONFIG.spikes.temporal_features,
                               neighbors)
    path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')
    np.save(path_to_rotation, rotation)
    logger.info('Saved rotation matrix in {}...'.format(path_to_rotation))

    main_channel = spike_index_clear[:, 1]

    # ---------------------------------------------------------------
    # PCA: waveform dimensionality reduction
    # ---------------------------------------------------------------
    if CONFIG.clustering.clustering_method != 'location':
        logger.info('Reducing spikes dimensionality with PCA matrix...')
        scores = dim_red.score(waveforms_clear, rotation,
                               spike_index_clear[:, 1],
                               neighbors, CONFIG.geom)
    else:
        logger.info('Denoising...')
        path_to_denoised_waveforms = os.path.join(TMP_FOLDER,
                                                  'denoised_waveforms.npy')

        if not os.path.exists(path_to_denoised_waveforms):
            logger.info(
                'Did not find denoised waveforms in {}, evaluating them'
                'from {}'.format(path_to_denoised_waveforms,
                                 path_to_waveforms_clear))
            waveforms_clear = np.load(path_to_waveforms_clear)
            denoised_waveforms = dim_red.denoise(waveforms_clear, rotation,
                                                 CONFIG)
            logger.info('Saving denoised waveforms to {}'.format(
                path_to_denoised_waveforms))
            np.save(path_to_denoised_waveforms, denoised_waveforms)
        else:
            logger.info(
                'Found denoised waveforms in {}, loading them...'.format(
                    path_to_denoised_waveforms))
            denoised_waveforms = np.load(path_to_denoised_waveforms)

        isolated_index, x, y = get_isolated_spikes_and_locations(
            denoised_waveforms, main_channel, CONFIG)
        x = (x - np.mean(x)) / np.std(x)
        y = (y - np.mean(y)) / np.std(y)

        # clear spikes that are not isolated are moved to the collision set
        every_spike = np.arange(spike_index_clear.shape[0])
        corrupted_index = np.logical_not(np.in1d(every_spike, isolated_index))
        spike_index_collision = np.concatenate(
            [spike_index_collision, spike_index_clear[corrupted_index]],
            axis=0)
        spike_index_clear = spike_index_clear[isolated_index]
        waveforms_clear = waveforms_clear[isolated_index]

        # dimensionality reduction on the isolated waveforms, the
        # standardized locations are prepended as two extra features
        scores = dim_red.main_channel_scores(waveforms_clear, rotation,
                                             spike_index_clear, CONFIG)
        scores = (scores - np.mean(scores, axis=0)) / np.std(scores)
        scores = np.concatenate([
            x[:, np.newaxis, np.newaxis],
            y[:, np.newaxis, np.newaxis],
            scores[:, :, np.newaxis]
        ], axis=1)

    # save scores
    path_to_score = os.path.join(TMP_FOLDER, 'score_clear.npy')
    np.save(path_to_score, scores)
    logger.info('Saved spike scores in {}...'.format(path_to_score))

    return scores, spike_index_clear, spike_index_collision