def test_cluster(path_to_threshold_config):
    yass.set_config(path_to_threshold_config)

    (standarized_path,
     standarized_params,
     whiten_filter) = preprocess.run()

    (score,
     spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   whiten_filter)

    cluster.run(score, spike_index_clear)

    clean_tmp()
def test_cluster_nnet(path_to_config, make_tmp_folder):
    yass.set_config(path_to_config, make_tmp_folder)

    (standarized_path,
     standarized_params,
     whiten_filter) = preprocess.run()

    spike_index_all = detect.run(standarized_path,
                                 standarized_params,
                                 whiten_filter)

    cluster.run(None, spike_index_all)
def test_templates_returns_expected_results(path_to_config,
                                             path_to_output_reference,
                                             make_tmp_folder):
    yass.set_config(path_to_config, make_tmp_folder)

    (standarized_path,
     standarized_params,
     whiten_filter) = preprocess.run()

    (spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   whiten_filter)

    (spike_train_clear,
     tmp_loc,
     vbParam) = cluster.run(spike_index_clear)

    (templates_, spike_train,
     groups, idx_good_templates) = templates.run(spike_train_clear,
                                                 tmp_loc,
                                                 save_results=True)

    path_to_templates = path.join(path_to_output_reference, 'templates.npy')

    ReferenceTesting.assert_array_equal(templates_, path_to_templates)
def test_templates_loads_from_disk_if_files_exist(caplog,
                                                  path_to_threshold_config):
    yass.set_config(path_to_threshold_config)

    (standarized_path,
     standarized_params,
     channel_index,
     whiten_filter) = preprocess.run()

    (score,
     spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   channel_index,
                                   whiten_filter)

    spike_train_clear, tmp_loc, vbParam = cluster.run(score,
                                                      spike_index_clear)

    # save results
    templates.run(spike_train_clear, tmp_loc, save_results=True)
    assert templates.run.executed

    # next time this should not run and just load from files
    templates.run(spike_train_clear, tmp_loc, save_results=True)
    assert not templates.run.executed
def test_templates_returns_expected_results(path_to_threshold_config,
                                             path_to_data_folder):
    np.random.seed(0)

    yass.set_config(path_to_threshold_config)

    (standarized_path,
     standarized_params,
     channel_index,
     whiten_filter) = preprocess.run()

    (score,
     spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   channel_index,
                                   whiten_filter)

    spike_train_clear, tmp_loc, vbParam = cluster.run(score,
                                                      spike_index_clear)

    (templates_, spike_train,
     groups, idx_good_templates) = templates.run(spike_train_clear, tmp_loc)

    path_to_templates = path.join(path_to_data_folder,
                                  'output_reference',
                                  'templates.npy')

    ReferenceTesting.assert_array_equal(templates_, path_to_templates)

    clean_tmp()
def initial_block(TMP_FOLDER,
                  standardized_path,
                  standardized_params,
                  run_chunk_sec):

    logger = logging.getLogger(__name__)

    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    ''' **********************************************
        ************** DETECT EVENTS *****************
        **********************************************
    '''
    # detect
    logger.info('INITIAL DETECTION')
    spike_index_path = detect.run(standardized_path,
                                  standardized_params,
                                  os.path.join(TMP_FOLDER, 'detect'),
                                  run_chunk_sec=run_chunk_sec)

    logger.info('INITIAL CLUSTERING')

    # cluster
    raw_data = True
    full_run = True
    fname_templates, fname_spike_train = cluster.run(
        spike_index_path,
        standardized_path,
        standardized_params['dtype'],
        os.path.join(TMP_FOLDER, 'cluster'),
        raw_data,
        full_run)

    methods = ['duplicate', 'high_mad', 'collision']
    fname_templates, fname_spike_train = postprocess.run(
        methods,
        fname_templates,
        fname_spike_train,
        os.path.join(TMP_FOLDER, 'cluster_post_process'),
        standardized_path,
        standardized_params['dtype'])

    return fname_templates, fname_spike_train
def test_cluster_nnet(path_to_config, make_tmp_folder):
    yass.set_config(path_to_config, make_tmp_folder)

    (standardized_path,
     standardized_params) = preprocess.run(
        os.path.join(make_tmp_folder, 'preprocess'))

    spike_index_path = detect.run(
        standardized_path,
        standardized_params,
        os.path.join(make_tmp_folder, 'detect'))

    cluster.run(
        spike_index_path,
        standardized_path,
        standardized_params['dtype'],
        os.path.join(make_tmp_folder, 'cluster'),
        True,
        True)
def test_deconvolution(patch_triage_network, path_to_config,
                       make_tmp_folder):
    yass.set_config(path_to_config, make_tmp_folder)

    (standarized_path,
     standarized_params,
     whiten_filter) = preprocess.run()

    spike_index_all = detect.run(standarized_path,
                                 standarized_params,
                                 whiten_filter)

    cluster.run(None, spike_index_all)

    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    path_to_spike_train_cluster = path.join(TMP_FOLDER,
                                            'spike_train_cluster.npy')
    spike_train_cluster = np.load(path_to_spike_train_cluster)
    templates_cluster = np.load(path.join(TMP_FOLDER,
                                          'templates_cluster.npy'))

    spike_train, postdeconv_templates = deconvolve.run(spike_train_cluster,
                                                       templates_cluster)
def test_cluster_loads_from_disk_if_all_files_exist(caplog,
                                                    path_to_threshold_config):
    yass.set_config(path_to_threshold_config)

    (standarized_path,
     standarized_params,
     whiten_filter) = preprocess.run()

    (score,
     spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   whiten_filter)

    # save results
    cluster.run(score, spike_index_clear, save_results=True)
    assert cluster.run.executed

    # next time this should not run and just load from files
    cluster.run(score, spike_index_clear, save_results=True)
    assert not cluster.run.executed
def main():
    settings.run()
    start = datetime.now()
    logger = logging.getLogger(__name__)

    logger.info('Preprocessing started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # preprocessing
    (standarized_path,
     standarized_params,
     channel_index,
     whiten_filter) = preprocess.run(output_directory='profiling',
                                     if_file_exists='overwrite')

    logger.info('Preprocessing finished and detection started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # detection
    (score,
     spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   channel_index,
                                   whiten_filter,
                                   output_directory='profiling',
                                   if_file_exists='overwrite',
                                   save_results=True)

    logger.info('Detection finished and clustering started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # clustering
    spike_train_clear = cluster.run(score, spike_index_clear,
                                    output_directory='profiling',
                                    if_file_exists='overwrite',
                                    save_results=True)

    logger.info('Clustering finished and templates started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # templates
    the_templates = templates.run(spike_train_clear,
                                  output_directory='profiling',
                                  if_file_exists='overwrite',
                                  save_results=True)

    logger.info('templates finished and deconvolution started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # deconvolution
    deconvolute.run(spike_index_all, the_templates,
                    output_directory='profiling')

    logger.info('Deconvolution finished at second: %.2f',
                (datetime.now() - start).total_seconds())
def test_deconvolution(path_to_threshold_config):
    yass.set_config('tests/config_nnet.yaml')

    (standarized_path,
     standarized_params,
     whiten_filter) = preprocess.run()

    (score,
     spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   whiten_filter)

    spike_train_clear, tmp_loc, vbParam = cluster.run(score,
                                                      spike_index_clear)

    (templates_, spike_train,
     groups, idx_good_templates) = templates.run(spike_train_clear, tmp_loc)

    deconvolute.run(spike_index_all, templates_)

    clean_tmp()
def test_deconvolution(patch_triage_network, path_to_config,
                       make_tmp_folder):
    yass.set_config(path_to_config, make_tmp_folder)

    (standardized_path,
     standardized_params) = preprocess.run(
        os.path.join(make_tmp_folder, 'preprocess'))

    spike_index_path = detect.run(standardized_path,
                                  standardized_params,
                                  os.path.join(make_tmp_folder, 'detect'))

    fname_templates, fname_spike_train = cluster.run(
        spike_index_path,
        standardized_path,
        standardized_params['dtype'],
        os.path.join(make_tmp_folder, 'cluster'),
        True,
        True)

    (fname_templates, fname_spike_train,
     fname_templates_up, fname_spike_train_up) = deconvolve.run(
        fname_templates,
        os.path.join(make_tmp_folder, 'deconv'),
        standardized_path,
        standardized_params['dtype'])
def test_templates(path_to_threshold_config):
    yass.set_config(path_to_threshold_config)

    (standarized_path,
     standarized_params,
     channel_index,
     whiten_filter) = preprocess.run()

    (score,
     spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   channel_index,
                                   whiten_filter)

    spike_train_clear, tmp_loc, vbParam = cluster.run(score,
                                                      spike_index_clear)

    templates.run(spike_train_clear, tmp_loc)

    clean_tmp()
def main():
    settings.run()
    start = datetime.now()
    logger = logging.getLogger(__name__)

    logger.info('Preprocessing started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # preprocessing
    (standarized_path,
     standarized_params,
     channel_index,
     whiten_filter) = preprocess.run()

    logger.info('Preprocessing finished and detection started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # detection
    (score,
     spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   channel_index,
                                   whiten_filter)

    logger.info('Detection finished and clustering started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # clustering
    spike_train_clear = cluster.run(score, spike_index_clear)

    logger.info('Clustering finished and templates started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # templates
    the_templates = templates.run(spike_train_clear)

    logger.info('templates finished and deconvolution started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # deconvolution
    deconvolute.run(spike_index_all, the_templates)

    logger.info('Deconvolution finished at second: %.2f',
                (datetime.now() - start).total_seconds())
def test_cluster_returns_expected_results(path_to_threshold_config,
                                          path_to_data_folder):
    np.random.seed(0)

    yass.set_config(path_to_threshold_config)

    (standarized_path,
     standarized_params,
     whiten_filter) = preprocess.run()

    (score,
     spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   whiten_filter)

    spike_train, tmp_loc, vbParam = cluster.run(score, spike_index_clear)

    path_to_spike_train = path.join(path_to_data_folder,
                                    'output_reference',
                                    'cluster_spike_train.npy')
    path_to_tmp_loc = path.join(path_to_data_folder,
                                'output_reference',
                                'cluster_tmp_loc.npy')

    ReferenceTesting.assert_array_equal(spike_train, path_to_spike_train)
    ReferenceTesting.assert_array_equal(tmp_loc, path_to_tmp_loc)

    clean_tmp()
def run(config, logger_level='INFO', clean=False, output_dir='tmp/'):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (if relative, it makes it relative to
        CONFIG.data.root_folder) to store the output data, defaults to
        tmp/. If absolute, it leaves it as it is.

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standardized.bin`` - Standardized recordings (from preprocess)
    * ``standardized.yaml`` - Standardized recordings metadata (from preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)

    Returns
    -------
    numpy.ndarray
        Spike train
    """
    # load yass configuration parameters
    CONFIG = Config.from_yaml(config)
    CONFIG._data['cluster']['min_fr'] = 1
    CONFIG._data['clean_up']['mad']['min_var_gap'] = 1.5
    CONFIG._data['clean_up']['mad']['max_violations'] = 5
    CONFIG._data['neuralnetwork']['apply_nn'] = False
    CONFIG._data['detect']['threshold'] = 4

    set_config(CONFIG._data, output_dir)
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = os.path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)

    ''' **********************************************
        ******** SET ENVIRONMENT VARIABLES ***********
        **********************************************
    '''
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["GIO_EXTRA_MODULES"] = "/usr/lib/x86_64-linux-gnu/gio/modules/"

    # TODO: if input spike train is None, run yass with threshold detector
    # if fname_spike_train is None:
    #     logger.info('Not available yet. You must input spike train')
    #     return

    ''' **********************************************
        ************** PREPROCESS ********************
        **********************************************
    '''
    # preprocess
    start = time.time()

    (standardized_path,
     standardized_dtype) = preprocess.run(
        os.path.join(TMP_FOLDER, 'preprocess'))

    TMP_FOLDER = os.path.join(TMP_FOLDER, 'nn_train')
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    if CONFIG.neuralnetwork.training.input_spike_train_filname is None:

        # run on 10 minutes of data
        rec_len = np.min(
            (CONFIG.rec_len, CONFIG.recordings.sampling_rate * 10 * 60))

        # detect
        logger.info('DETECTION')
        spike_index_path = detect.run(standardized_path,
                                      standardized_dtype,
                                      os.path.join(TMP_FOLDER, 'detect'),
                                      run_chunk_sec=[0, rec_len])

        logger.info('CLUSTERING')

        # cluster
        raw_data = True
        full_run = False
        fname_templates, fname_spike_train = cluster.run(
            os.path.join(TMP_FOLDER, 'cluster'),
            standardized_path,
            standardized_dtype,
            fname_spike_index=spike_index_path,
            raw_data=True,
            full_run=True)

        methods = [
            'off_center', 'low_ptp', 'high_mad', 'duplicate', 'duplicate_l2'
        ]
        fname_templates, fname_spike_train = postprocess.run(
            methods,
            os.path.join(TMP_FOLDER, 'cluster_post_process'),
            standardized_path,
            standardized_dtype,
            fname_templates,
            fname_spike_train)

    else:
        # if there is an input spike train, use it
        fname_spike_train = CONFIG.neuralnetwork.training.input_spike_train_filname

    # Get training data maker
    DetectTD, DenoTD = augment.run(standardized_path,
                                   standardized_dtype,
                                   fname_spike_train,
                                   os.path.join(TMP_FOLDER, 'augment'))

    # Train Detector
    detector = Detect(CONFIG.neuralnetwork.detect.n_filters,
                      CONFIG.spike_size_nn,
                      CONFIG.channel_index).cuda()
    detector.train(os.path.join(TMP_FOLDER, 'detect.pt'), DetectTD)

    # Train Denoiser
    denoiser = Denoise(CONFIG.neuralnetwork.denoise.n_filters,
                       CONFIG.neuralnetwork.denoise.filter_sizes,
                       CONFIG.spike_size_nn).cuda()
    denoiser.train(os.path.join(TMP_FOLDER, 'denoise.pt'), DenoTD)
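A minimal usage sketch for the training entry point above. Calling ``run`` directly in the defining module and the ``config.yaml`` file name are assumptions for illustration, not part of the source.

```python
# Hedged sketch: 'config.yaml' is a placeholder path; run() is the training
# entry point defined above, assumed to be importable from its module or
# invoked through the yass CLI.
run('config.yaml', logger_level='INFO', clean=True, output_dir='tmp/')

# Per the os.path.join calls above, the trained weights should end up at
# <root_folder>/tmp/nn_train/detect.pt and <root_folder>/tmp/nn_train/denoise.pt
```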
import logging

import numpy as np

import yass
from yass import preprocess
from yass import detect
from yass import cluster

np.random.seed(0)

# configure logging module to get useful information
logging.basicConfig(level=logging.INFO)

# set yass configuration parameters
yass.set_config('config_sample.yaml', 'preprocess-example/')

standarized_path, standarized_params, whiten_filter = preprocess.run()

(spike_index_clear,
 spike_index_all) = detect.run(standarized_path,
                               standarized_params,
                               whiten_filter)

spike_train_clear, tmp_loc, vbParam = cluster.run(spike_index_clear)
def run(config, logger_level='INFO', clean=False, output_dir='tmp/',
        complete=False, set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (if relative, it makes it relative to
        CONFIG.data.root_folder) to store the output data, defaults to
        tmp/. If absolute, it leaves it as it is.

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standardized.bin`` - Standardized recordings (from preprocess)
    * ``standardized.yaml`` - Standardized recordings metadata (from preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)

    Returns
    -------
    numpy.ndarray
        Spike train
    """
    # load yass configuration parameters
    set_config(config, output_dir)
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)

    ''' **********************************************
        ******** SET ENVIRONMENT VARIABLES ***********
        **********************************************
    '''
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["GIO_EXTRA_MODULES"] = "/usr/lib/x86_64-linux-gnu/gio/modules/"

    ''' **********************************************
        ************** PREPROCESS ********************
        **********************************************
    '''
    # preprocess
    start = time.time()

    (standardized_path,
     standardized_params,
     whiten_filter) = (preprocess.run(
        if_file_exists=CONFIG.preprocess.if_file_exists))

    time_preprocess = time.time() - start

    ''' **********************************************
        ************** DETECT EVENTS *****************
        **********************************************
    '''
    # detect
    # Cat: This code now runs with open tensorflow calls
    start = time.time()

    spike_index_all = detect.run(standardized_path,
                                 standardized_params,
                                 whiten_filter,
                                 if_file_exists=CONFIG.detect.if_file_exists,
                                 save_results=CONFIG.detect.save_results)
    spike_index_clear = None

    time_detect = time.time() - start

    ''' **********************************************
        ***************** CLUSTER ********************
        **********************************************
    '''
    # cluster
    start = time.time()

    path_to_spike_train_cluster = path.join(TMP_FOLDER,
                                            'spike_train_cluster.npy')
    if not os.path.exists(path_to_spike_train_cluster):
        cluster.run(spike_index_clear, spike_index_all)
    else:
        print("\nClustering completed previously...\n\n")

    spike_train_cluster = np.load(path_to_spike_train_cluster)
    templates_cluster = np.load(
        os.path.join(TMP_FOLDER, 'templates_cluster.npy'))

    time_cluster = time.time() - start
    # print("Spike train clustered: ", spike_index_cluster.shape,
    #       "spike train clear: ", spike_train_clear.shape,
    #       " templates: ", templates.shape)

    ''' **********************************************
        ************** DECONVOLUTION *****************
        **********************************************
    '''
    # run deconvolution
    start = time.time()

    spike_train, postdeconv_templates = deconvolve.run(spike_train_cluster,
                                                       templates_cluster)
    time_deconvolution = time.time() - start

    # save spike train
    path_to_spike_train = path.join(TMP_FOLDER,
                                    'spike_train_post_deconv_post_merge.npy')
    np.save(path_to_spike_train, spike_train)
    logger.info('Spike train saved in: {}'.format(path_to_spike_train))

    # save templates
    path_to_templates = path.join(TMP_FOLDER,
                                  'templates_post_deconv_post_merge.npy')
    np.save(path_to_templates, postdeconv_templates)
    logger.info('Templates saved in: {}'.format(path_to_templates))

    ''' **********************************************
        ************** POST PROCESSING ***************
        **********************************************
    '''
    # save metadata in tmp
    path_to_metadata = path.join(TMP_FOLDER, 'metadata.yaml')
    logging.info('Saving metadata in {}'.format(path_to_metadata))
    save_metadata(path_to_metadata)

    # save config.yaml copy in tmp/
    path_to_config_copy = path.join(TMP_FOLDER, 'config.yaml')

    if isinstance(config, Mapping):
        with open(path_to_config_copy, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
    else:
        shutil.copy2(config, path_to_config_copy)

    logging.info('Saving copy of config: {} in {}'.format(
        config, path_to_config_copy))

    # this part loads waveforms for all spikes in the spike train and scores
    # them, this data is needed to later generate phy files
    if complete:
        STANDARIZED_PATH = path.join(TMP_FOLDER, 'standardized.bin')
        PARAMS = load_yaml(path.join(TMP_FOLDER, 'standardized.yaml'))

        # load waveforms for all spikes in the spike train
        logger.info('Loading waveforms from all spikes in the spike train...')
        explorer = RecordingExplorer(STANDARIZED_PATH,
                                     spike_size=CONFIG.spike_size,
                                     dtype=PARAMS['dtype'],
                                     n_channels=PARAMS['n_channels'],
                                     data_order=PARAMS['data_order'])
        waveforms = explorer.read_waveforms(spike_train[:, 0])

        path_to_waveforms = path.join(TMP_FOLDER, 'spike_train_waveforms.npy')
        np.save(path_to_waveforms, waveforms)
        logger.info('Saved all waveforms from the spike train in {}...'.format(
            path_to_waveforms))

        # score all waveforms
        logger.info('Scoring waveforms from all spikes in the spike train...')
        path_to_rotation = path.join(TMP_FOLDER, 'rotation.npy')
        rotation = np.load(path_to_rotation)

        main_channels = explorer.main_channel_for_waveforms(waveforms)
        path_to_main_channels = path.join(TMP_FOLDER,
                                          'waveforms_main_channel.npy')
        np.save(path_to_main_channels, main_channels)
        logger.info('Saved all waveforms main channels in {}...'.format(
            path_to_waveforms))

        waveforms_score = dim_red.score(waveforms, rotation, main_channels,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_waveforms_score = path.join(TMP_FOLDER, 'waveforms_score.npy')
        np.save(path_to_waveforms_score, waveforms_score)
        logger.info('Saved all scores in {}...'.format(path_to_waveforms))

        # score templates
        # TODO: templates should be returned in the right shape to avoid .T
        templates_ = templates.T
        main_channels_tmpls = explorer.main_channel_for_waveforms(templates_)
        path_to_templates_main_c = path.join(TMP_FOLDER,
                                             'templates_main_channel.npy')
        np.save(path_to_templates_main_c, main_channels_tmpls)
        logger.info('Saved all templates main channels in {}...'.format(
            path_to_templates_main_c))

        templates_score = dim_red.score(templates_, rotation,
                                        main_channels_tmpls,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_templates_score = path.join(TMP_FOLDER, 'templates_score.npy')
        np.save(path_to_templates_score, templates_score)
        logger.info(
            'Saved all templates scores in {}...'.format(path_to_waveforms))

    logger.info('Finished YASS execution. Timing summary:')
    total = (time_preprocess + time_detect + time_cluster +
             time_deconvolution)

    logger.info('\t Preprocess: %s (%.2f %%)',
                human_readable_time(time_preprocess),
                time_preprocess / total * 100)

    logger.info('\t Detection: %s (%.2f %%)',
                human_readable_time(time_detect),
                time_detect / total * 100)

    logger.info('\t Clustering: %s (%.2f %%)',
                human_readable_time(time_cluster),
                time_cluster / total * 100)

    logger.info('\t Deconvolution: %s (%.2f %%)',
                human_readable_time(time_deconvolution),
                time_deconvolution / total * 100)

    return spike_train
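For context, a hedged usage sketch of the pipeline entry point above. The ``yass.pipeline`` import path and the ``config.yaml`` file name are assumptions for illustration, not taken from this snippet.

```python
# Hedged sketch: import path and config file name are assumptions.
from yass import pipeline  # assumed module exposing the run() defined above

# runs preprocess -> detect -> cluster -> deconvolve; intermediate files are
# written under CONFIG.path_to_output_directory and the spike train returned
spike_train = pipeline.run('config.yaml', clean=True)
```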
def test_new_process_shows_error_if_empty_config():
    with pytest.raises(ValueError):
        cluster.run(None, None)
def iterative_block(TMP_FOLDER,
                    standardized_path,
                    standardized_params,
                    fname_templates,
                    run_chunk_sec):

    logger = logging.getLogger(__name__)

    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # run deconvolution
    logger.info('DECONV')
    (fname_templates,
     fname_spike_train,
     fname_templates_up,
     fname_spike_train_up,
     fname_shifts) = deconvolve.run(
        fname_templates,
        os.path.join(TMP_FOLDER, 'deconv'),
        standardized_path,
        standardized_params['dtype'],
        run_chunk_sec=run_chunk_sec)

    # compute residual
    logger.info('RESIDUAL COMPUTATION')
    fname_residual, residual_dtype = residual.run(
        fname_shifts,
        fname_templates_up,
        fname_spike_train_up,
        os.path.join(TMP_FOLDER, 'residual'),
        standardized_path,
        standardized_params['dtype'],
        dtype_out='float32',
        run_chunk_sec=run_chunk_sec)

    logger.info('BLOCK1 MERGE')
    fname_templates_up, fname_spike_train_up = merge.run(
        os.path.join(TMP_FOLDER, 'post_deconv_merge'),
        False,
        fname_spike_train,
        fname_templates,
        fname_spike_train_up,
        fname_templates_up,
        standardized_path,
        standardized_params['dtype'],
        fname_residual,
        residual_dtype)

    fname_templates = fname_templates_up
    fname_spike_train = fname_spike_train_up

    # cluster
    logger.info('RECLUSTERING')
    raw_data = False
    full_run = True
    fname_templates, fname_spike_train = cluster.run(
        fname_spike_train,
        standardized_path,
        standardized_params['dtype'],
        os.path.join(TMP_FOLDER, 'cluster'),
        raw_data,
        full_run,
        fname_residual=fname_residual,
        residual_dtype=residual_dtype,
        fname_templates_up=fname_templates_up,
        fname_spike_train_up=fname_spike_train_up)

    methods = ['duplicate', 'high_mad', 'collision']
    fname_templates, fname_spike_train = postprocess.run(
        methods,
        fname_templates,
        fname_spike_train,
        os.path.join(TMP_FOLDER, 'cluster_post_process'),
        standardized_path,
        standardized_params['dtype'])

    return fname_templates, fname_spike_train
def test_nn_output(path_to_tests):
    """Test that the pipeline using the neural network detector returns the
    same results as the saved reference output
    """
    logger = logging.getLogger(__name__)

    yass.set_config(path.join(path_to_tests, 'config_nn_49.yaml'))
    CONFIG = read_config()
    TMP = Path(CONFIG.data.root_folder, 'tmp')

    logger.info('Removing %s', TMP)
    shutil.rmtree(str(TMP))

    PATH_TO_REF = '/home/Edu/data/nnet'

    np.random.seed(0)

    # run preprocess
    (standarized_path,
     standarized_params,
     whiten_filter) = preprocess.run()

    # load preprocess output
    path_to_standarized = path.join(PATH_TO_REF, 'preprocess',
                                    'standarized.bin')
    path_to_whitening = path.join(PATH_TO_REF, 'preprocess', 'whitening.npy')

    whitening_saved = np.load(path_to_whitening)
    standarized_saved = RecordingsReader(path_to_standarized,
                                         loader='array').data
    standarized = RecordingsReader(standarized_path, loader='array').data

    # test preprocess
    np.testing.assert_array_equal(whitening_saved, whiten_filter)
    np.testing.assert_array_equal(standarized_saved, standarized)

    # run detect
    (score,
     spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   whiten_filter)

    # load detect output
    path_to_scores = path.join(PATH_TO_REF, 'detect', 'scores_clear.npy')
    path_to_spike_index_clear = path.join(PATH_TO_REF, 'detect',
                                          'spike_index_clear.npy')
    path_to_spike_index_all = path.join(PATH_TO_REF, 'detect',
                                        'spike_index_all.npy')

    scores_saved = np.load(path_to_scores)
    spike_index_clear_saved = np.load(path_to_spike_index_clear)
    spike_index_all_saved = np.load(path_to_spike_index_all)

    # test detect output
    np.testing.assert_array_equal(scores_saved, score)
    np.testing.assert_array_equal(spike_index_clear_saved, spike_index_clear)
    np.testing.assert_array_equal(spike_index_all_saved, spike_index_all)

    # run cluster
    (spike_train_clear,
     tmp_loc,
     vbParam) = cluster.run(score, spike_index_clear)

    # load cluster output
    path_to_spike_train_cluster = path.join(PATH_TO_REF, 'cluster',
                                            'spike_train_cluster.npy')
    spike_train_cluster_saved = np.load(path_to_spike_train_cluster)

    # test cluster
    # np.testing.assert_array_equal(spike_train_cluster_saved,
    #                               spike_train_clear)

    # run templates
    (templates_, spike_train,
     groups, idx_good_templates) = templates.run(spike_train_clear, tmp_loc,
                                                 save_results=True)

    # load templates output
    path_to_templates = path.join(PATH_TO_REF, 'templates', 'templates.npy')
    templates_saved = np.load(path_to_templates)

    # test templates
    np.testing.assert_array_almost_equal(templates_saved, templates_,
                                         decimal=4)

    # run deconvolution
    spike_train = deconvolute.run(spike_index_all, templates_)

    # load deconvolution output
    path_to_spike_train = path.join(PATH_TO_REF, 'spike_train.npy')
    spike_train_saved = np.load(path_to_spike_train)

    # test deconvolution
    np.testing.assert_array_equal(spike_train_saved, spike_train)
def iterative_block(TMP_FOLDER,
                    standardized_path,
                    standardized_dtype,
                    fname_templates,
                    run_chunk_sec):

    logger = logging.getLogger(__name__)

    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # run deconvolution
    logger.info('DECONV')
    (fname_templates,
     fname_spike_train,
     fname_templates_up,
     fname_spike_train_up,
     fname_shifts) = deconvolve.run(
        fname_templates,
        os.path.join(TMP_FOLDER, 'deconv'),
        standardized_path,
        standardized_dtype,
        run_chunk_sec=run_chunk_sec)

    # compute residual
    logger.info('RESIDUAL COMPUTATION')
    fname_residual, residual_dtype = residual.run(
        fname_shifts,
        fname_templates,
        fname_spike_train,
        os.path.join(TMP_FOLDER, 'residual'),
        standardized_path,
        standardized_dtype,
        dtype_out='float32',
        run_chunk_sec=run_chunk_sec)

    # logger.info('KILL NOISE')
    # fname_spike_train2 = noise.run(
    #     fname_templates,
    #     fname_spike_train,
    #     fname_shifts,
    #     os.path.join(TMP_FOLDER, 'noise_kill'),
    #     fname_residual,
    #     residual_dtype)

    if False:
        logger.info('SOFT NOISE ASSIGNMENT')
        fname_soft_assignment = noise.run(
            fname_templates,
            fname_spike_train,
            fname_shifts,
            os.path.join(TMP_FOLDER, 'soft_assignment'),
            fname_residual,
            residual_dtype)

        logger.info('BLOCK1 MERGE')
        _, _, _ = merge.run(
            os.path.join(TMP_FOLDER, 'post_deconv_merge'),
            fname_spike_train,
            fname_templates,
            fname_soft_assignment,
            fname_residual,
            residual_dtype)

    # cluster
    logger.info('RECLUSTERING')
    fname_templates, fname_spike_train = cluster.run(
        os.path.join(TMP_FOLDER, 'cluster'),
        standardized_path,
        standardized_dtype,
        fname_residual=fname_residual,
        residual_dtype=residual_dtype,
        fname_spike_index=None,
        fname_templates=fname_templates,
        fname_spike_train=fname_spike_train,
        raw_data=False,
        full_run=True)

    methods = ['off_center', 'high_mad', 'duplicate_l2', 'duplicate']
    fname_templates, fname_spike_train = postprocess.run(
        methods,
        fname_templates,
        fname_spike_train,
        os.path.join(TMP_FOLDER, 'cluster_post_process'),
        standardized_path,
        standardized_dtype)

    return fname_templates, fname_spike_train
def run(config, logger_level='INFO', clean=False, output_dir='tmp/',
        complete=False, set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (relative to CONFIG.data.root_folder) to store the
        output data, defaults to tmp/

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standarized.bin`` - Standardized recordings (from preprocess)
    * ``standarized.yaml`` - Standardized recordings metadata (from preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)

    Returns
    -------
    numpy.ndarray
        Spike train
    """
    # load yass configuration parameters
    set_config(config)
    CONFIG = read_config()
    ROOT_FOLDER = CONFIG.data.root_folder
    TMP_FOLDER = path.join(ROOT_FOLDER, output_dir)

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = path.join(TMP_FOLDER,
                                                               'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger and start coloredlogs
    logger = logging.getLogger(__name__)
    coloredlogs.install(logger=logger)

    if set_zero_seed:
        logger.warning('Set numpy seed to zero')
        np.random.seed(0)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)

    # preprocess
    start = time.time()
    (standarized_path,
     standarized_params,
     channel_index,
     whiten_filter) = (preprocess
                       .run(output_directory=output_dir,
                            if_file_exists=CONFIG.preprocess.if_file_exists))
    time_preprocess = time.time() - start

    # detect
    start = time.time()
    (score,
     spike_index_clear,
     spike_index_all) = detect.run(
        standarized_path,
        standarized_params,
        channel_index,
        whiten_filter,
        output_directory=output_dir,
        if_file_exists=CONFIG.detect.if_file_exists,
        save_results=CONFIG.detect.save_results)
    time_detect = time.time() - start

    # cluster
    start = time.time()
    spike_train_clear, tmp_loc, vbParam = cluster.run(
        score,
        spike_index_clear,
        output_directory=output_dir,
        if_file_exists=CONFIG.cluster.if_file_exists,
        save_results=CONFIG.cluster.save_results)
    time_cluster = time.time() - start

    # get templates
    start = time.time()
    (templates,
     spike_train_clear_after_templates,
     groups,
     idx_good_templates) = get_templates.run(
        spike_train_clear,
        tmp_loc,
        output_directory=output_dir,
        if_file_exists=CONFIG.templates.if_file_exists,
        save_results=CONFIG.templates.save_results)
    time_templates = time.time() - start

    # run deconvolution
    start = time.time()
    spike_train = deconvolute.run(spike_index_all, templates,
                                  output_directory=output_dir)
    time_deconvolution = time.time() - start

    # save metadata in tmp
    path_to_metadata = path.join(TMP_FOLDER, 'metadata.yaml')
    logging.info('Saving metadata in {}'.format(path_to_metadata))
    save_metadata(path_to_metadata)

    # save config.yaml copy in tmp/
    path_to_config_copy = path.join(TMP_FOLDER, 'config.yaml')

    if isinstance(config, Mapping):
        with open(path_to_config_copy, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
    else:
        shutil.copy2(config, path_to_config_copy)

    logging.info('Saving copy of config: {} in {}'.format(
        config, path_to_config_copy))

    # TODO: complete flag saves other files needed for integrating phy
    # with yass, the integration hasn't been completed yet

    # this part loads waveforms for all spikes in the spike train and scores
    # them, this data is needed to later generate phy files
    if complete:
        STANDARIZED_PATH = path.join(TMP_FOLDER, 'standarized.bin')
        PARAMS = load_yaml(path.join(TMP_FOLDER, 'standarized.yaml'))

        # load waveforms for all spikes in the spike train
        logger.info('Loading waveforms from all spikes in the spike train...')
        explorer = RecordingExplorer(STANDARIZED_PATH,
                                     spike_size=CONFIG.spike_size,
                                     dtype=PARAMS['dtype'],
                                     n_channels=PARAMS['n_channels'],
                                     data_order=PARAMS['data_order'])
        waveforms = explorer.read_waveforms(spike_train[:, 0])

        path_to_waveforms = path.join(TMP_FOLDER, 'spike_train_waveforms.npy')
        np.save(path_to_waveforms, waveforms)
        logger.info('Saved all waveforms from the spike train in {}...'
                    .format(path_to_waveforms))

        # score all waveforms
        logger.info('Scoring waveforms from all spikes in the spike train...')
        path_to_rotation = path.join(TMP_FOLDER, 'rotation.npy')
        rotation = np.load(path_to_rotation)

        main_channels = explorer.main_channel_for_waveforms(waveforms)
        path_to_main_channels = path.join(TMP_FOLDER,
                                          'waveforms_main_channel.npy')
        np.save(path_to_main_channels, main_channels)
        logger.info('Saved all waveforms main channels in {}...'
                    .format(path_to_waveforms))

        waveforms_score = dim_red.score(waveforms, rotation, main_channels,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_waveforms_score = path.join(TMP_FOLDER, 'waveforms_score.npy')
        np.save(path_to_waveforms_score, waveforms_score)
        logger.info('Saved all scores in {}...'.format(path_to_waveforms))

        # score templates
        # TODO: templates should be returned in the right shape to avoid .T
        templates_ = templates.T
        main_channels_tmpls = explorer.main_channel_for_waveforms(templates_)
        path_to_templates_main_c = path.join(TMP_FOLDER,
                                             'templates_main_channel.npy')
        np.save(path_to_templates_main_c, main_channels_tmpls)
        logger.info('Saved all templates main channels in {}...'
                    .format(path_to_templates_main_c))

        templates_score = dim_red.score(templates_, rotation,
                                        main_channels_tmpls,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_templates_score = path.join(TMP_FOLDER, 'templates_score.npy')
        np.save(path_to_templates_score, templates_score)
        logger.info('Saved all templates scores in {}...'
                    .format(path_to_waveforms))

    logger.info('Finished YASS execution. Timing summary:')
    total = (time_preprocess + time_detect + time_cluster + time_templates +
             time_deconvolution)
    logger.info('\t Preprocess: %s (%.2f %%)',
                human_readable_time(time_preprocess),
                time_preprocess / total * 100)
    logger.info('\t Detection: %s (%.2f %%)',
                human_readable_time(time_detect),
                time_detect / total * 100)
    logger.info('\t Clustering: %s (%.2f %%)',
                human_readable_time(time_cluster),
                time_cluster / total * 100)
    logger.info('\t Templates: %s (%.2f %%)',
                human_readable_time(time_templates),
                time_templates / total * 100)
    logger.info('\t Deconvolution: %s (%.2f %%)',
                human_readable_time(time_deconvolution),
                time_deconvolution / total * 100)

    return spike_train
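Since ``config`` accepts either a path or a mapping, a hedged sketch of the mapping form follows; loading the dict from a YAML file is an assumption for illustration, and ``config.yaml`` is a placeholder path.

```python
# Hedged sketch: 'config.yaml' is a placeholder; run() is the pipeline entry
# point defined above. Passing a mapping exercises the isinstance(config,
# Mapping) branch, which dumps the dict to tmp/config.yaml instead of copying
# a file.
import yaml

with open('config.yaml') as f:
    config_mapping = yaml.safe_load(f)

spike_train = run(config_mapping, output_dir='tmp/', clean=True)
```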