"""Example: run only the YASS preprocessing step with the sample config."""
import logging

import yass
from yass import preprocess

# Emit INFO-level messages so preprocessing progress is visible.
logging.basicConfig(level=logging.INFO)

# Load configuration parameters from the sample configuration file.
yass.set_config('config_sample.yaml')

# Run the preprocessor. if_file_exists='skip' is forwarded to yass;
# presumably it reuses outputs that already exist on disk.
preprocess_output = preprocess.run(if_file_exists='skip')
standarized_path, standarized_params, channel_index, whiten_filter = (
    preprocess_output)
import logging

import yass
from yass import preprocess
from yass import process
from yass import deconvolute

# Show INFO-level log messages so each pipeline stage reports progress.
logging.basicConfig(level=logging.INFO)

# Use the neural-network test configuration.
yass.set_config('tests/config_nnet.yaml')

# Stage 1: preprocessing yields three outputs that feed the processor.
preprocessed = preprocess.run()
score, clr_idx, spt = preprocessed

# Stage 2: processing yields the spike train, leftover spikes and templates.
processed = process.run(score, clr_idx, spt)
spike_train, spikes_left, templates = processed

# Stage 3: deconvolution.
spikes = deconvolute.run(spike_train, spikes_left, templates)
def test_splitting_in_batches_does_not_affect_result(path_to_tests):
    """Batched detection/triage/featurization must match a single pass.

    Runs ``neuralnetwork.run_detect_triage_featurize`` once over the whole
    standardized recording, then again through a ``BatchProcessor`` that
    splits the file into 100KB batches. It asserts that the scores, clear
    spikes and collided spikes are identical.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    PATH_TO_DATA = path.join(path_to_tests, 'data/standarized.bin')
    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    # yaml.load() without an explicit Loader is unsafe and raises TypeError
    # on PyYAML >= 6. This assumes the params file holds only plain YAML
    # types; it is indexed below only for dtype, n_channels and data_order.
    with open(path.join(path_to_tests, 'data/standarized.yaml')) as f:
        PARAMS = yaml.safe_load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    # Identity "whitening": one (k x k) identity matrix per row of
    # channel_index, with k = channel_index.shape[1].
    whiten_filter = np.tile(
        np.eye(channel_index.shape[1], dtype='float32')[np.newaxis, :, :],
        [channel_index.shape[0], 1, 1])

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    (x_tf, output_tf, NND, NNAE, NNT) = neuralnetwork.prepare_nn(
        channel_index, whiten_filter, detection_th, triage_th,
        detection_fname, ae_fname, triage_fname,
    )

    def restore_models(sess):
        # Load detector, autoencoder and triage weights into the session.
        NND.saver.restore(sess, NND.path_to_detector_model)
        NNAE.saver_ae.restore(sess, NNAE.path_to_ae_model)
        NNT.saver.restore(sess, NNT.path_to_triage_model)

    # Depends only on the configuration, so it is computed once.
    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

    # Run all at once.
    with tf.Session() as sess:
        restore_models(sess)
        rot = NNAE.load_rotation()
        (scores, clear, collision) = neuralnetwork.run_detect_triage_featurize(
            data, sess, x_tf, output_tf, neighbors, rot)

    # Run in batches. The buffer size makes sure spikes can still be
    # detected if they appear at the end of any batch.
    bp = BatchProcessor(PATH_TO_DATA, PARAMS['dtype'], PARAMS['n_channels'],
                        PARAMS['data_order'], '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        restore_models(sess)
        rot = NNAE.load_rotation()
        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    # Each batch result is a (scores, clear, collision) tuple; stitch the
    # batches back together along the first axis.
    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res], axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)
import logging

import numpy as np

import yass
from yass import (preprocess, detect, cluster, templates, deconvolute)

# Fix NumPy's global RNG so repeated runs of this example are reproducible.
np.random.seed(0)

# Report pipeline progress at INFO level.
logging.basicConfig(level=logging.INFO)

# Sample configuration; outputs go under the 'deconv-example' location.
yass.set_config('config_sample.yaml', 'deconv-example')

# Preprocess the raw recording.
standardized_path, standardized_params, whiten_filter = preprocess.run()

# Detect spikes: clear spikes plus all detected spikes.
detected = detect.run(standardized_path, standardized_params, whiten_filter)
spike_index_clear, spike_index_all = detected

# Cluster the clear spikes.
spike_train_clear, tmp_loc, vbParam = cluster.run(spike_index_clear)

# Build templates from the clustered spike train.
templates_output = templates.run(spike_train_clear, tmp_loc)
templates_, spike_train, groups, idx_good_templates = templates_output

# Deconvolve all detected spikes against the templates.
spike_train = deconvolute.run(spike_index_all, templates_)
def test_splitting_in_batches_does_not_affect(path_to_tests,
                                              path_to_standarized_data,
                                              path_to_sample_pipeline_folder):
    """Batched detection/triage/featurization must match a single pass.

    The detector, triage network and autoencoder are loaded separately.
    ``neuralnetwork.run_detect_triage_featurize`` runs once over the whole
    standardized recording and once through a ``BatchProcessor`` with 100KB
    batches. Scores, clear spikes and collided spikes must be identical.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()
    PATH_TO_DATA = path_to_standarized_data

    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    # yaml.load() without an explicit Loader is unsafe and raises TypeError
    # on PyYAML >= 6. This assumes the params file holds only plain YAML
    # types; it is indexed below only for dtype, n_channels and data_order.
    with open(path.join(path_to_sample_pipeline_folder, 'preprocess',
                        'standarized.yaml')) as f:
        PARAMS = yaml.safe_load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # Instantiate the neural networks. Triage and the autoencoder both take
    # the detector's waveform tensor as input.
    NND = NeuralNetDetector.load(detection_fname, detection_th, channel_index)
    NNT = NeuralNetTriage.load(triage_fname, triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    def restore_models(sess):
        # Load the weights of all three networks into the session.
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

    # Depends only on the configuration, so it is computed once.
    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

    # Run all at once.
    with tf.Session() as sess:
        restore_models(sess)
        rot = NNAE.load_rotation()
        (scores, clear, collision) = neuralnetwork.run_detect_triage_featurize(
            data, sess, NND.x_tf, output_tf, neighbors, rot)

    # Run in batches. The buffer size makes sure spikes can still be
    # detected if they appear at the end of any batch.
    bp = BatchProcessor(PATH_TO_DATA, PARAMS['dtype'], PARAMS['n_channels'],
                        PARAMS['data_order'], '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        restore_models(sess)
        rot = NNAE.load_rotation()
        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=NND.x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    # Each batch result is a (scores, clear, collision) tuple.
    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res], axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)
def test_splitting_in_batches_does_not_affect(path_to_config,
                                              path_to_sample_pipeline_folder,
                                              make_tmp_folder,
                                              path_to_standarized_data):
    """Two runs of the batched pipeline must produce identical results.

    Each run processes the standardized recording through a
    ``BatchProcessor`` (100KB batches). It then applies:

    * detection with ``NeuralNetDetector``
    * triage with the Keras triage model
    * featurization with the autoencoder
    * ``post_processing``

    The clear spike indexes and clear scores of the two runs are compared.
    The original version computed both results but never asserted on them.

    NOTE(review): both runs use the same batched code path, so this checks
    that the batched pipeline is deterministic. The test name suggests the
    first run was meant to be a single-pass reference. TODO confirm.
    """
    yass.set_config(path_to_config, make_tmp_folder)
    CONFIG = yass.read_config()
    PATH_TO_DATA = path_to_standarized_data

    # yaml.load() without an explicit Loader is unsafe and raises TypeError
    # on PyYAML >= 6. This assumes the params file holds only plain YAML
    # types; it is indexed below only for dtype, n_channels and data_order.
    with open(path.join(path_to_sample_pipeline_folder, 'preprocess',
                        'standarized.yaml')) as f:
        PARAMS = yaml.safe_load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # Instantiate the neural networks.
    NND = NeuralNetDetector.load(detection_fname, detection_th, channel_index)
    triage = KerasModel(triage_fname,
                        allow_longer_waveform_length=True,
                        allow_more_channels=True)
    NNAE = AutoEncoder.load(ae_fname, input_tensor=NND.waveform_tf)

    bp = BatchProcessor(PATH_TO_DATA, PARAMS['dtype'], PARAMS['n_channels'],
                        PARAMS['data_order'], '100KB',
                        buffer_size=CONFIG.spike_size)

    # Depends only on the configuration, so it is computed once.
    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

    def detect_triage_featurize():
        # One full batched run; returns (score_clear, spike_index_clear).
        with tf.Session() as sess:
            NND.restore(sess)
            res = bp.multi_channel_apply(
                NND.predict_recording,
                mode='memory',
                sess=sess,
                output_names=('spike_index', 'waveform'),
                cleanup_function=neuralnetwork.apply.fix_indexes_spike_index)

            # Each batch result is a (spike_index, waveform) pair.
            spike_index = np.concatenate([element[0] for element in res],
                                         axis=0)
            wfs = np.concatenate([element[1] for element in res], axis=0)

            idx_clean = triage.predict_with_threshold(x=wfs,
                                                      threshold=triage_th)
            score = NNAE.predict(wfs)
            rot = NNAE.load_rotation()

            return post_processing(score, spike_index, idx_clean, rot,
                                   neighbors)

    (score_clear_new, spike_index_clear_new) = detect_triage_featurize()
    (score_clear_batch, spike_index_clear_batch) = detect_triage_featurize()

    np.testing.assert_array_equal(spike_index_clear_new,
                                  spike_index_clear_batch)
    np.testing.assert_array_equal(score_clear_new, score_clear_batch)
def test_decovnolute_new_pipeline(path_to_config):
    """Smoke test: preprocess -> process -> deconvolute runs end to end."""
    # NOTE(review): the path_to_config fixture is accepted but unused; the
    # configuration path below is hard-coded.
    config_file = 'tests/config_nnet.yaml'
    yass.set_config(config_file)

    (scores, clear_index, spike_times) = preprocess.run()
    (train, leftover, templates_) = process.run(scores, clear_index,
                                                spike_times)

    deconvolute.run(train, leftover, templates_)
""" Detecting spikes """ import logging import yass from yass import preprocess from yass import detect # configure logging module to get useful information logging.basicConfig(level=logging.INFO) # set yass configuration parameters yass.set_config('config.yaml', 'example-detect') # run preprocessor standardized_path, standardized_params, whiten_filter = preprocess.run() # run detection clear, collision = detect.run(standardized_path, standardized_params, whiten_filter, if_file_exists='overwrite')
def run(config, logger_level='INFO', clean=False, output_dir='tmp/',
        complete=False, calculate_rf=False, visualize=False,
        set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (if relative, it makes it relative to
        CONFIG.data.root_folder) to store the output data, defaults to tmp/.
        If absolute, it leaves it as it is.

    complete: bool, optional
        Generates extra files (needed to generate phy files).
        NOTE: this flag is currently not read anywhere in this function.

    calculate_rf: bool, optional
        Call ``rf.run()`` after the final outputs have been saved

    visualize: bool, optional
        Call ``visual.run()`` after the final outputs have been saved

    set_zero_seed: bool, optional
        Seed NumPy's global random number generator with 0 before running
        the pipeline, so repeated runs are reproducible

    Notes
    -----
    This function writes the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``yass.log`` - Log file for this run
    * ``templates.npy`` - Final templates, axes reordered with
      ``transpose(1, 2, 0)``
    * ``spike_train.npy`` - Final spike train
    * ``soft_assignment.npy`` - Soft assignment from the final deconvolution

    Intermediate results go to the ``preprocess/``, ``block_N/``,
    ``pre_final_deconv/`` and ``final_deconv/`` subfolders. The preprocessor
    presumably generates these files under ``preprocess/``:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings
    * ``filtered.yaml`` - Filtered recordings metadata
    * ``standardized.bin`` - Standardized recordings
    * ``standardized.yaml`` - Standardized recordings metadata
    * ``whitening.npy`` - Whitening filter

    Returns
    -------
    numpy.ndarray
        Spike train (the same array saved to ``spike_train.npy``)
    """
    # load yass configuration parameters
    set_config(config, output_dir)
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # remove tmp folder if requested
    if clean and os.path.exists(TMP_FOLDER):
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # Load the logging config file and point its file handler at this
    # run's output folder.
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = os.path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # Configure logging. This relies on ``logging.config`` being imported
    # at module level, which is not visible in this chunk.
    logging.config.dictConfig(logging_config)

    logger = logging.getLogger(__name__)
    logger.info('YASS version: %s', yass.__version__)

    # Previously this parameter was accepted but silently ignored.
    if set_zero_seed:
        np.random.seed(0)

    # --- Environment variables ---
    # NOTE(review): OpenBLAS/MKL usually read their thread counts when the
    # library is first loaded. If numpy is already imported at this point,
    # these may not affect the current process (child processes still
    # inherit them).
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["GIO_EXTRA_MODULES"] = "/usr/lib/x86_64-linux-gnu/gio/modules/"

    # --- Preprocess ---
    start = time.time()
    (standardized_path, standardized_dtype) = preprocess.run(
        os.path.join(TMP_FOLDER, 'preprocess'))

    # --- Block 1: Detection, Clustering, Postprocess ---
    (fname_templates, fname_spike_train) = initial_block(
        os.path.join(TMP_FOLDER, 'block_1'),
        standardized_path,
        standardized_dtype,
        run_chunk_sec=CONFIG.clustering_chunk)

    # Previously a stray debug print() to stdout.
    logger.info('Input templates to block 2: %s', fname_templates)

    # --- Block 2: Deconv, Merge, Residuals, Clustering, Postprocess ---
    n_iterations = 1
    for it in range(n_iterations):
        (fname_templates, fname_spike_train) = iterative_block(
            os.path.join(TMP_FOLDER, 'block_{}'.format(it + 2)),
            standardized_path,
            standardized_dtype,
            fname_templates,
            run_chunk_sec=CONFIG.clustering_chunk)

    # --- Pre-final deconv: Deconvolve, Residual, Merge, kill low fr units ---
    (fname_templates, fname_spike_train) = pre_final_deconv(
        os.path.join(TMP_FOLDER, 'pre_final_deconv'),
        standardized_path,
        standardized_dtype,
        fname_templates,
        run_chunk_sec=CONFIG.clustering_chunk)

    # --- Final deconv: Deconvolve, Residual, soft assignment ---
    (fname_templates, fname_spike_train, fname_soft_assignment) = final_deconv(
        os.path.join(TMP_FOLDER, 'final_deconv'),
        standardized_path,
        standardized_dtype,
        fname_templates,
        update_templates=True,
        run_chunk_sec=CONFIG.final_deconv_chunk)

    # --- Save the final templates, spike train and soft assignment ---
    fname_templates_final = os.path.join(TMP_FOLDER, 'templates.npy')
    fname_spike_train_final = os.path.join(TMP_FOLDER, 'spike_train.npy')
    fname_soft_assignment_final = os.path.join(
        TMP_FOLDER, 'soft_assignment.npy')

    # Reorder template axes before saving.
    templates = np.load(fname_templates).transpose(1, 2, 0)
    # Spike times are saved exactly as produced by final_deconv; no
    # alignment shift is applied here.
    spike_train = np.load(fname_spike_train)
    soft_assignment = np.load(fname_soft_assignment)

    np.save(fname_templates_final, templates)
    np.save(fname_spike_train_final, spike_train)
    np.save(fname_soft_assignment_final, soft_assignment)

    # Total time covers preprocessing through saving; RF/visualization
    # below are not included.
    total_time = time.time() - start

    # --- RF / visualize ---
    if calculate_rf:
        rf.run()

    if visualize:
        visual.run()

    logger.info('Finished YASS execution. Total time: %s',
                human_readable_time(total_time))
    logger.info('Final Templates Location: %s', fname_templates_final)
    logger.info('Final Spike Train Location: %s', fname_spike_train_final)

    return spike_train