def multi_channel_apply(self, function, mode, cleanup_function=None,
                        output_path=None, from_time=None, to_time=None,
                        channels='all', if_file_exists='overwrite',
                        cast_dtype=None, **kwargs):
    """
    Apply a function where each batch has observations from more than
    one channel

    Parameters
    ----------
    function: callable
        Function to be applied, must accept a 2D numpy array in 'long'
        format as its first parameter (number of observations,
        number of channels)

    mode: str
        'disk' or 'memory', if 'disk', a binary file is created at the
        beginning of the operation and each partial result is saved
        (using the numpy.ndarray.tofile function), at the end of the
        operation two files are generated: the binary file and a yaml
        file with some file parameters (useful if you want to later use
        RecordingsReader to read the file). If 'memory', partial results
        are kept in memory and returned as a list

    cleanup_function: callable, optional
        A function to be executed after `function` and before adding the
        partial result to the list of results (in 'memory' mode) or to
        the binary file (in 'disk' mode)

    output_path: str, optional
        Where to save the output, required if 'disk' mode

    from_time: int, optional
        Starting time, defaults to None

    to_time: int, optional
        Ending time, defaults to None

    channels: int, tuple or str, optional
        A tuple with the channel indexes or 'all' to traverse all
        channels, defaults to 'all'

    if_file_exists: str, optional
        One of 'overwrite', 'abort', 'skip'. If 'overwrite' it replaces
        the file if it exists, if 'abort' it raises a ValueError
        exception if the file exists, if 'skip' it skips the operation
        if the file exists. Only valid when mode = 'disk'

    cast_dtype: str, optional
        Output dtype, defaults to None which means no cast is done

    **kwargs
        kwargs to pass to function

    Returns
    -------
    output_path
        Path to output binary file

    params
        Binary file params

    Examples
    --------

    .. literalinclude:: ../../examples/batch/multi_channel_apply.py

    Notes
    -----
    Applying functions will incur a memory overhead that depends on the
    function implementation. This is important to consider if the
    transformation changes the data's dtype (e.g. converts int16 to
    float64): a 1MB chunk in int16 takes 4MB in float64. Take that into
    account when setting max_memory

    For performance reasons, this method outputs data in 'long' format.
    """
    if mode not in ['disk', 'memory']:
        raise ValueError('Mode should be disk or memory, received: {}'
                         .format(mode))

    if mode == 'disk' and output_path is None:
        raise ValueError('output_path is required in "disk" mode')

    if (mode == 'disk' and if_file_exists == 'abort' and
            os.path.exists(output_path)):
        raise ValueError('{} already exists'.format(output_path))

    self.logger.info('Applying function {}...'
                     .format(function_path(function)))

    if (mode == 'disk' and if_file_exists == 'skip' and
            os.path.exists(output_path)):
        # load params...
        path_to_yaml = output_path.replace('.bin', '.yaml')

        if not os.path.exists(path_to_yaml):
            raise ValueError("if_file_exists = 'skip', but {}"
                             " is missing, aborting..."
                             .format(path_to_yaml))

        with open(path_to_yaml) as f:
            params = yaml.safe_load(f)

        self.logger.info('{} exists, skipping...'.format(output_path))

        return output_path, params

    if mode == 'disk':
        fn = self._multi_channel_apply_disk

        start = time.time()
        res = fn(function, cleanup_function, output_path, from_time,
                 to_time, channels, cast_dtype, **kwargs)
        elapsed = time.time() - start
        self.logger.info('{} took {}'
                         .format(function_path(function),
                                 human_readable_time(elapsed)))
        return res
    else:
        fn = self._multi_channel_apply_memory

        start = time.time()
        res = fn(function, cleanup_function, from_time, to_time,
                 channels, cast_dtype, **kwargs)
        elapsed = time.time() - start
        self.logger.info('{} took {}'
                         .format(function_path(function),
                                 human_readable_time(elapsed)))
        return res
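# --- Usage sketch (not part of the library) ---
# A minimal, hedged example of calling multi_channel_apply in both modes.
# The BatchProcessor constructor arguments are assumptions based on the
# docstring (dtype/n_channels/data_order describe the binary file), and
# 'data/recordings.bin' is a hypothetical path.
import numpy as np

from yass.batch import BatchProcessor

bp = BatchProcessor('data/recordings.bin', dtype='int16', n_channels=64,
                    data_order='samples', max_memory='100MB')

# 'memory' mode: returns a list with one entry per batch
results = bp.multi_channel_apply(np.square, mode='memory')

# 'disk' mode: writes squared.bin plus squared.yaml with its params;
# if_file_exists='skip' reuses the existing file on a second run
path, params = bp.multi_channel_apply(np.square, mode='disk',
                                      output_path='data/squared.bin',
                                      if_file_exists='skip')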
def run(config, logger_level='INFO', clean=False, output_dir='tmp/',
        complete=False, set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (if relative, it is made relative to
        CONFIG.data.root_folder) to store the output data, defaults to
        tmp/. If absolute, it is left as is.

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standardized.bin`` - Standardized recordings (from preprocess)
    * ``standardized.yaml`` - Standardized recordings metadata (from
      preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)

    Returns
    -------
    numpy.ndarray
        Spike train
    """
    # load yass configuration parameters
    set_config(config, output_dir)
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)

    '''
    **********************************************
    ******** SET ENVIRONMENT VARIABLES ***********
    **********************************************
    '''
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["GIO_EXTRA_MODULES"] = "/usr/lib/x86_64-linux-gnu/gio/modules/"

    '''
    **********************************************
    ************** PREPROCESS ********************
    **********************************************
    '''
    # preprocess
    start = time.time()

    (standardized_path,
     standardized_params,
     whiten_filter) = preprocess.run(
        if_file_exists=CONFIG.preprocess.if_file_exists)

    time_preprocess = time.time() - start

    '''
    **********************************************
    ************** DETECT EVENTS *****************
    **********************************************
    '''
    # detect
    # Cat: this code now runs with open tensorflow calls
    start = time.time()

    spike_index_all = detect.run(
        standardized_path,
        standardized_params,
        whiten_filter,
        if_file_exists=CONFIG.detect.if_file_exists,
        save_results=CONFIG.detect.save_results)
    spike_index_clear = None

    time_detect = time.time() - start

    '''
    **********************************************
    ***************** CLUSTER ********************
    **********************************************
    '''
    # cluster
    start = time.time()

    path_to_spike_train_cluster = path.join(TMP_FOLDER,
                                            'spike_train_cluster.npy')
    if not os.path.exists(path_to_spike_train_cluster):
        cluster.run(spike_index_clear, spike_index_all)
    else:
        print("\nClustering completed previously...\n\n")

    spike_train_cluster = np.load(path_to_spike_train_cluster)
    templates_cluster = np.load(
        os.path.join(TMP_FOLDER, 'templates_cluster.npy'))

    time_cluster = time.time() - start
    # print("Spike train clustered: ", spike_index_cluster.shape,
    #       "spike train clear: ", spike_train_clear.shape,
    #       " templates: ", templates.shape)

    '''
    **********************************************
    ************** DECONVOLUTION *****************
    **********************************************
    '''
    # run deconvolution
    start = time.time()

    spike_train, postdeconv_templates = deconvolve.run(
        spike_train_cluster, templates_cluster)

    time_deconvolution = time.time() - start

    # save spike train
    path_to_spike_train = path.join(
        TMP_FOLDER, 'spike_train_post_deconv_post_merge.npy')
    np.save(path_to_spike_train, spike_train)
    logger.info('Spike train saved in: {}'.format(path_to_spike_train))

    # save templates
    path_to_templates = path.join(
        TMP_FOLDER, 'templates_post_deconv_post_merge.npy')
    np.save(path_to_templates, postdeconv_templates)
    logger.info('Templates saved in: {}'.format(path_to_templates))

    '''
    **********************************************
    ************** POST PROCESSING ***************
    **********************************************
    '''
    # save metadata in tmp
    path_to_metadata = path.join(TMP_FOLDER, 'metadata.yaml')
    logging.info('Saving metadata in {}'.format(path_to_metadata))
    save_metadata(path_to_metadata)

    # save config.yaml copy in tmp/
    path_to_config_copy = path.join(TMP_FOLDER, 'config.yaml')

    if isinstance(config, Mapping):
        with open(path_to_config_copy, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
    else:
        shutil.copy2(config, path_to_config_copy)

    logging.info('Saving copy of config: {} in {}'.format(
        config, path_to_config_copy))

    # this part loads waveforms for all spikes in the spike train and scores
    # them, this data is needed to later generate phy files
    if complete:
        STANDARIZED_PATH = path.join(TMP_FOLDER, 'standardized.bin')
        PARAMS = load_yaml(path.join(TMP_FOLDER, 'standardized.yaml'))

        # load waveforms for all spikes in the spike train
        logger.info('Loading waveforms from all spikes in the spike train...')
        explorer = RecordingExplorer(STANDARIZED_PATH,
                                     spike_size=CONFIG.spike_size,
                                     dtype=PARAMS['dtype'],
                                     n_channels=PARAMS['n_channels'],
                                     data_order=PARAMS['data_order'])
        waveforms = explorer.read_waveforms(spike_train[:, 0])

        path_to_waveforms = path.join(TMP_FOLDER, 'spike_train_waveforms.npy')
        np.save(path_to_waveforms, waveforms)
        logger.info('Saved all waveforms from the spike train in {}...'
                    .format(path_to_waveforms))

        # score all waveforms
        logger.info('Scoring waveforms from all spikes in the spike train...')
        path_to_rotation = path.join(TMP_FOLDER, 'rotation.npy')
        rotation = np.load(path_to_rotation)

        main_channels = explorer.main_channel_for_waveforms(waveforms)
        path_to_main_channels = path.join(TMP_FOLDER,
                                          'waveforms_main_channel.npy')
        np.save(path_to_main_channels, main_channels)
        logger.info('Saved all waveforms main channels in {}...'
                    .format(path_to_main_channels))

        waveforms_score = dim_red.score(waveforms, rotation, main_channels,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_waveforms_score = path.join(TMP_FOLDER, 'waveforms_score.npy')
        np.save(path_to_waveforms_score, waveforms_score)
        logger.info('Saved all scores in {}...'
                    .format(path_to_waveforms_score))

        # score templates
        # TODO: templates should be returned in the right shape to avoid .T
        templates_ = postdeconv_templates.T
        main_channels_tmpls = explorer.main_channel_for_waveforms(templates_)
        path_to_templates_main_c = path.join(TMP_FOLDER,
                                             'templates_main_channel.npy')
        np.save(path_to_templates_main_c, main_channels_tmpls)
        logger.info('Saved all templates main channels in {}...'
                    .format(path_to_templates_main_c))

        templates_score = dim_red.score(templates_, rotation,
                                        main_channels_tmpls,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_templates_score = path.join(TMP_FOLDER, 'templates_score.npy')
        np.save(path_to_templates_score, templates_score)
        logger.info('Saved all templates scores in {}...'
                    .format(path_to_templates_score))

    logger.info('Finished YASS execution. Timing summary:')
    total = (time_preprocess + time_detect + time_cluster +
             time_deconvolution)
    logger.info('\t Preprocess: %s (%.2f %%)',
                human_readable_time(time_preprocess),
                time_preprocess / total * 100)
    logger.info('\t Detection: %s (%.2f %%)',
                human_readable_time(time_detect),
                time_detect / total * 100)
    logger.info('\t Clustering: %s (%.2f %%)',
                human_readable_time(time_cluster),
                time_cluster / total * 100)
    logger.info('\t Deconvolution: %s (%.2f %%)',
                human_readable_time(time_deconvolution),
                time_deconvolution / total * 100)

    return spike_train
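# --- Usage sketch (not part of the pipeline) ---
# Hedged example: load the arrays this run() saves. The paths come from
# the code above; that column 0 holds spike times follows from
# explorer.read_waveforms(spike_train[:, 0]), while column 1 holding the
# unit id is an assumption.
import numpy as np

spike_train = np.load('tmp/spike_train_post_deconv_post_merge.npy')
templates = np.load('tmp/templates_post_deconv_post_merge.npy')

spike_times = spike_train[:, 0]  # sample index of each spike
unit_ids = spike_train[:, 1]     # assumed: unit id of each spike
print('spikes:', spike_train.shape, 'templates:', templates.shape)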
def run(config, logger_level='INFO', clean=False, output_dir='tmp/',
        complete=False, calculate_rf=False, visualize=False,
        set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (if relative, it is made relative to
        CONFIG.data.root_folder) to store the output data, defaults to
        tmp/. If absolute, it is left as is.

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standardized.bin`` - Standardized recordings (from preprocess)
    * ``standardized.yaml`` - Standardized recordings metadata (from
      preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)

    Returns
    -------
    numpy.ndarray
        Spike train
    """
    # load yass configuration parameters
    set_config(config, output_dir)
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = os.path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)

    '''
    **********************************************
    ******** SET ENVIRONMENT VARIABLES ***********
    **********************************************
    '''
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["GIO_EXTRA_MODULES"] = "/usr/lib/x86_64-linux-gnu/gio/modules/"

    '''
    **********************************************
    ************** PREPROCESS ********************
    **********************************************
    '''
    # run preprocess and spike detect

    # preprocess
    start = time.time()

    (standardized_path,
     standardized_params) = preprocess.run(
        os.path.join(TMP_FOLDER, 'preprocess'))

    # run the entire pipeline multiple times on multiple chunks:
    run_blocks = []
    for k in range(0, 600, 60):
        run_blocks.append([k, k + 60])

    for blk_ctr, run_block in enumerate(run_blocks):

        BLK_FOLDER = os.path.join(
            TMP_FOLDER, str(run_block[0]) + "_" + str(run_block[1]))
        if not os.path.exists(BLK_FOLDER):
            os.makedirs(BLK_FOLDER)

        print(" RUNNING ON: ", run_block)

        # Block 1: Detection, Clustering, Postprocess
        (fname_templates,
         fname_spike_train) = initial_block(
            os.path.join(BLK_FOLDER, 'block_1'),
            standardized_path,
            standardized_params,
            run_chunk_sec=run_block)

        print(" input to next block: ", fname_templates)

        # Block 3: Deconvolve, Residual, Merge
        (fname_templates, fname_spike_train,
         fname_templates_up, fname_spike_train_up,
         fname_residual, residual_dtype) = single_block(
            os.path.join(BLK_FOLDER, 'final_deconv'),
            standardized_path,
            standardized_params,
            fname_templates,
            run_block)

        # save the final templates and spike train
        fname_templates_final = os.path.join(BLK_FOLDER, 'templates.npy')
        fname_spike_train_final = os.path.join(BLK_FOLDER, 'spike_train.npy')

        # transpose axes
        templates = np.load(fname_templates).transpose(1, 2, 0)
        # align spike times to the beginning
        spike_train = np.load(fname_spike_train)
        spike_train[:, 0] -= CONFIG.spike_size // 2

        np.save(fname_templates_final, templates)
        np.save(fname_spike_train_final, spike_train)

    total_time = time.time() - start

    logger.info('Finished YASS execution. Total time: {}'.format(
        human_readable_time(total_time)))
    logger.info('Final Templates Location: ' + fname_templates_final)
    logger.info('Final Spike Train Location: ' + fname_spike_train_final)
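# --- Chunking sketch (illustration only) ---
# The loop above hard-codes ten 60-second blocks over a 600-second
# recording. The equivalent one-liner and the resulting per-block folder
# layout ('tmp' is a hypothetical output directory):
import os

run_blocks = [[k, k + 60] for k in range(0, 600, 60)]
# -> [[0, 60], [60, 120], ..., [540, 600]]

TMP_FOLDER = 'tmp'
block_folders = [os.path.join(TMP_FOLDER, '{}_{}'.format(s, e))
                 for s, e in run_blocks]
# each block folder gets its own block_1/, final_deconv/, templates.npy
# and spike_train.npy (see the loop body above)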
def run(config, logger_level='INFO', clean=False, output_dir='tmp/',
        complete=False, set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (relative to CONFIG.data.root_folder) to store
        the output data, defaults to tmp/

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    set_zero_seed: bool, optional
        Set the numpy random seed to zero before running, defaults to
        False

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standarized.bin`` - Standardized recordings (from preprocess)
    * ``standarized.yaml`` - Standardized recordings metadata (from
      preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)

    Returns
    -------
    numpy.ndarray
        Spike train
    """
    # load yass configuration parameters
    set_config(config)
    CONFIG = read_config()
    ROOT_FOLDER = CONFIG.data.root_folder
    TMP_FOLDER = path.join(ROOT_FOLDER, output_dir)

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = path.join(TMP_FOLDER,
                                                               'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger and start coloredlogs
    logger = logging.getLogger(__name__)
    coloredlogs.install(logger=logger)

    if set_zero_seed:
        logger.warning('Set numpy seed to zero')
        np.random.seed(0)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)

    # preprocess
    start = time.time()

    (standarized_path, standarized_params, channel_index,
     whiten_filter) = preprocess.run(
        output_directory=output_dir,
        if_file_exists=CONFIG.preprocess.if_file_exists)

    time_preprocess = time.time() - start

    # detect
    start = time.time()

    (score, spike_index_clear,
     spike_index_all) = detect.run(
        standarized_path, standarized_params, channel_index, whiten_filter,
        output_directory=output_dir,
        if_file_exists=CONFIG.detect.if_file_exists,
        save_results=CONFIG.detect.save_results)

    time_detect = time.time() - start

    # cluster
    start = time.time()

    spike_train_clear, tmp_loc, vbParam = cluster.run(
        score, spike_index_clear,
        output_directory=output_dir,
        if_file_exists=CONFIG.cluster.if_file_exists,
        save_results=CONFIG.cluster.save_results)

    time_cluster = time.time() - start

    # get templates
    start = time.time()

    (templates, spike_train_clear_after_templates,
     groups, idx_good_templates) = get_templates.run(
        spike_train_clear, tmp_loc,
        output_directory=output_dir,
        if_file_exists=CONFIG.templates.if_file_exists,
        save_results=CONFIG.templates.save_results)

    time_templates = time.time() - start

    # run deconvolution
    start = time.time()

    spike_train = deconvolute.run(spike_index_all, templates,
                                  output_directory=output_dir)

    time_deconvolution = time.time() - start

    # save metadata in tmp
    path_to_metadata = path.join(TMP_FOLDER, 'metadata.yaml')
    logging.info('Saving metadata in {}'.format(path_to_metadata))
    save_metadata(path_to_metadata)

    # save config.yaml copy in tmp/
    path_to_config_copy = path.join(TMP_FOLDER, 'config.yaml')

    if isinstance(config, Mapping):
        with open(path_to_config_copy, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
    else:
        shutil.copy2(config, path_to_config_copy)

    logging.info('Saving copy of config: {} in {}'.format(
        config, path_to_config_copy))

    # TODO: complete flag saves other files needed for integrating phy
    # with yass, the integration hasn't been completed yet
    # this part loads waveforms for all spikes in the spike train and scores
    # them, this data is needed to later generate phy files
    if complete:
        STANDARIZED_PATH = path.join(TMP_FOLDER, 'standarized.bin')
        PARAMS = load_yaml(path.join(TMP_FOLDER, 'standarized.yaml'))

        # load waveforms for all spikes in the spike train
        logger.info('Loading waveforms from all spikes in the spike train...')
        explorer = RecordingExplorer(STANDARIZED_PATH,
                                     spike_size=CONFIG.spike_size,
                                     dtype=PARAMS['dtype'],
                                     n_channels=PARAMS['n_channels'],
                                     data_order=PARAMS['data_order'])
        waveforms = explorer.read_waveforms(spike_train[:, 0])

        path_to_waveforms = path.join(TMP_FOLDER, 'spike_train_waveforms.npy')
        np.save(path_to_waveforms, waveforms)
        logger.info('Saved all waveforms from the spike train in {}...'
                    .format(path_to_waveforms))

        # score all waveforms
        logger.info('Scoring waveforms from all spikes in the spike train...')
        path_to_rotation = path.join(TMP_FOLDER, 'rotation.npy')
        rotation = np.load(path_to_rotation)

        main_channels = explorer.main_channel_for_waveforms(waveforms)
        path_to_main_channels = path.join(TMP_FOLDER,
                                          'waveforms_main_channel.npy')
        np.save(path_to_main_channels, main_channels)
        logger.info('Saved all waveforms main channels in {}...'
                    .format(path_to_main_channels))

        waveforms_score = dim_red.score(waveforms, rotation, main_channels,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_waveforms_score = path.join(TMP_FOLDER, 'waveforms_score.npy')
        np.save(path_to_waveforms_score, waveforms_score)
        logger.info('Saved all scores in {}...'
                    .format(path_to_waveforms_score))

        # score templates
        # TODO: templates should be returned in the right shape to avoid .T
        templates_ = templates.T
        main_channels_tmpls = explorer.main_channel_for_waveforms(templates_)
        path_to_templates_main_c = path.join(TMP_FOLDER,
                                             'templates_main_channel.npy')
        np.save(path_to_templates_main_c, main_channels_tmpls)
        logger.info('Saved all templates main channels in {}...'
                    .format(path_to_templates_main_c))

        templates_score = dim_red.score(templates_, rotation,
                                        main_channels_tmpls,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_templates_score = path.join(TMP_FOLDER, 'templates_score.npy')
        np.save(path_to_templates_score, templates_score)
        logger.info('Saved all templates scores in {}...'
                    .format(path_to_templates_score))

    logger.info('Finished YASS execution. Timing summary:')
    total = (time_preprocess + time_detect + time_cluster + time_templates +
             time_deconvolution)
    logger.info('\t Preprocess: %s (%.2f %%)',
                human_readable_time(time_preprocess),
                time_preprocess / total * 100)
    logger.info('\t Detection: %s (%.2f %%)',
                human_readable_time(time_detect),
                time_detect / total * 100)
    logger.info('\t Clustering: %s (%.2f %%)',
                human_readable_time(time_cluster),
                time_cluster / total * 100)
    logger.info('\t Templates: %s (%.2f %%)',
                human_readable_time(time_templates),
                time_templates / total * 100)
    logger.info('\t Deconvolution: %s (%.2f %%)',
                human_readable_time(time_deconvolution),
                time_deconvolution / total * 100)

    return spike_train
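# --- Usage sketch (hedged) ---
# Calling this older pipeline entry point; the module path yass.pipeline
# is an assumption, and 'config.yaml' is a hypothetical file.
from yass import pipeline

spike_train = pipeline.run('config.yaml', logger_level='DEBUG',
                           clean=True, output_dir='tmp/', complete=False)
# spike_train is the numpy.ndarray documented in Returns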
def multi_channel_apply(self, function, mode, cleanup_function=None,
                        output_path=None, from_time=None, to_time=None,
                        channels='all', if_file_exists='overwrite',
                        cast_dtype=None, pass_batch_info=False,
                        pass_batch_results=False, processes=1, **kwargs):
    """
    Apply a function where each batch has observations from more than
    one channel

    Parameters
    ----------
    function: callable
        Function to be applied, first parameter passed will be a 2D
        numpy array in 'long' shape (number of observations, number of
        channels). If pass_batch_info is True, another two keyword
        parameters will be passed to function: 'idx_local' is the slice
        object with the limits of the data in
        [observations, channels] format (excluding the buffer), 'idx'
        is the absolute index of the data, again in
        [observations, channels] format

    mode: str
        'disk' or 'memory', if 'disk', a binary file is created at the
        beginning of the operation and each partial result is saved
        (using the numpy.ndarray.tofile function), at the end of the
        operation two files are generated: the binary file and a yaml
        file with some file parameters (useful if you want to later use
        RecordingsReader to read the file). If 'memory', partial results
        are kept in memory and returned as a list

    cleanup_function: callable, optional
        A function to be executed after `function` and before adding the
        partial result to the list of results (in 'memory' mode) or to
        the binary file (in 'disk' mode). `cleanup_function` will be
        called with the following parameters (in that order): result
        from applying `function` to the batch, slice object with the idx
        where the data is located (excludes buffer), slice object with
        the absolute location of the data, and buffer size

    output_path: str, optional
        Where to save the output, required if 'disk' mode

    from_time: int, optional
        Starting time, defaults to None

    to_time: int, optional
        Ending time, defaults to None

    channels: int, tuple or str, optional
        A tuple with the channel indexes or 'all' to traverse all
        channels, defaults to 'all'

    if_file_exists: str, optional
        One of 'overwrite', 'abort', 'skip'. If 'overwrite' it replaces
        the file if it exists, if 'abort' it raises a ValueError
        exception if the file exists, if 'skip' it skips the operation
        if the file exists. Only valid when mode = 'disk'

    cast_dtype: str, optional
        Output dtype, defaults to None which means no cast is done

    pass_batch_info: bool, optional
        Whether to call the function with batch info or just call it
        with the batch data (see description in the function parameter)

    pass_batch_results: bool, optional
        Whether to pass results from the previous batch to the next one,
        defaults to False. Only relevant when mode='memory'. If True,
        function will be called with the keyword parameter
        'previous_batch' which contains the computation for the last
        batch, it is set to None in the first batch

    processes: int, optional
        Number of processes to use, defaults to 1 (serial). Values
        greater than one process batches in parallel. Only relevant
        when mode = 'disk'

    **kwargs
        kwargs to pass to function

    Returns
    -------
    output_path, params (when mode is 'disk')
        Path to output binary file, binary file params

    list (when mode is 'memory' and pass_batch_results is False)
        List where every element is the result of applying the function
        to one batch. When pass_batch_results is True, it returns the
        output of the function for the last batch

    Examples
    --------

    .. literalinclude:: ../../examples/batch/multi_channel_apply_disk.py

    .. literalinclude:: ../../examples/batch/multi_channel_apply_memory.py

    Notes
    -----
    Applying functions will incur a memory overhead that depends on the
    function implementation. This is important to consider if the
    transformation changes the data's dtype (e.g. converts int16 to
    float64): a 1MB chunk in int16 takes 4MB in float64. Take that into
    account when setting max_memory

    For performance reasons, this method outputs data in 'samples' order.
    """
    if mode not in ['disk', 'memory']:
        raise ValueError(
            'Mode should be disk or memory, received: {}'.format(mode))

    if mode == 'disk' and output_path is None:
        raise ValueError('output_path is required in "disk" mode')

    if (mode == 'disk' and if_file_exists == 'abort' and
            os.path.exists(output_path)):
        raise ValueError('{} already exists'.format(output_path))

    self.logger.info('Applying function {}...'.format(
        function_path(function)))

    if (mode == 'disk' and if_file_exists == 'skip' and
            os.path.exists(output_path)):
        # load params...
        path_to_yaml = output_path.replace('.bin', '.yaml')

        if not os.path.exists(path_to_yaml):
            raise ValueError(
                "if_file_exists = 'skip', but {}"
                " is missing, aborting...".format(path_to_yaml))

        with open(path_to_yaml) as f:
            params = yaml.safe_load(f)

        self.logger.info('{} exists, skipping...'.format(output_path))

        return output_path, params

    if mode == 'disk':
        if processes == 1:
            fn = self._multi_channel_apply_disk
        else:
            fn = partial(self._multi_channel_apply_disk_parallel,
                         processes=processes)

        start = time.time()
        res = fn(function, cleanup_function, output_path, from_time,
                 to_time, channels, cast_dtype, pass_batch_info,
                 pass_batch_results, **kwargs)
        elapsed = time.time() - start
        self.logger.info('{} took {}'.format(function_path(function),
                                             human_readable_time(elapsed)))
        return res
    else:
        fn = self._multi_channel_apply_memory

        start = time.time()
        res = fn(function, cleanup_function, from_time, to_time,
                 channels, cast_dtype, pass_batch_info,
                 pass_batch_results, **kwargs)
        elapsed = time.time() - start
        self.logger.info('{} took {}'.format(function_path(function),
                                             human_readable_time(elapsed)))
        return res
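# --- Usage sketch (not part of the library) ---
# Hedged example of pass_batch_results: per the docstring above, the
# function receives 'previous_batch' (None on the first batch) and only
# the last batch's return value comes back. The BatchProcessor setup
# repeats the hypothetical reader from the earlier sketch.
import numpy as np

from yass.batch import BatchProcessor

bp = BatchProcessor('data/recordings.bin', dtype='int16', n_channels=64,
                    data_order='samples', max_memory='100MB')


def running_max(batch, previous_batch=None):
    # carry a per-channel running maximum across batches
    cur = batch.max(axis=0)
    return cur if previous_batch is None else np.maximum(cur, previous_batch)


channel_max = bp.multi_channel_apply(running_max, mode='memory',
                                     pass_batch_results=True)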
def run(config, logger_level='INFO', clean=False, output_dir='tmp/',
        complete=False, calculate_rf=False, visualize=False,
        set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (if relative, it is made relative to
        CONFIG.data.root_folder) to store the output data, defaults to
        tmp/. If absolute, it is left as is.

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    calculate_rf: bool, optional
        Run receptive field calculation (rf.run()) after sorting,
        defaults to False

    visualize: bool, optional
        Run visualization (visual.run()) after sorting, defaults to
        False

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standardized.bin`` - Standardized recordings (from preprocess)
    * ``standardized.yaml`` - Standardized recordings metadata (from
      preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)

    Returns
    -------
    numpy.ndarray
        Spike train
    """
    # load yass configuration parameters
    set_config(config, output_dir)
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = os.path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)

    '''
    **********************************************
    ******** SET ENVIRONMENT VARIABLES ***********
    **********************************************
    '''
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["GIO_EXTRA_MODULES"] = "/usr/lib/x86_64-linux-gnu/gio/modules/"

    '''
    **********************************************
    ************** PREPROCESS ********************
    **********************************************
    '''
    # preprocess
    start = time.time()

    (standardized_path,
     standardized_dtype) = preprocess.run(
        os.path.join(TMP_FOLDER, 'preprocess'))

    # Block 1: Detection, Clustering, Postprocess
    # print("CLUSTERING DEFAULT LENGTH: ", CONFIG.rec_len,
    #       " currently set to 300 sec")
    (fname_templates,
     fname_spike_train) = initial_block(
        os.path.join(TMP_FOLDER, 'block_1'),
        standardized_path,
        standardized_dtype,
        run_chunk_sec=CONFIG.clustering_chunk)
        # alternatives: run_chunk_sec=[0, 600 * 20000], [0, 300]

    print(" input to block2: ", fname_templates)

    # Block 2: Deconv, Merge, Residuals, Clustering, Postprocess
    n_iterations = 1
    for it in range(n_iterations):
        (fname_templates,
         fname_spike_train) = iterative_block(
            os.path.join(TMP_FOLDER, 'block_{}'.format(it + 2)),
            standardized_path,
            standardized_dtype,
            fname_templates,
            run_chunk_sec=CONFIG.clustering_chunk)

    # Pre-final deconv: Deconvolve, Residual, Merge, kill low-firing-rate
    # units
    (fname_templates,
     fname_spike_train) = pre_final_deconv(
        os.path.join(TMP_FOLDER, 'pre_final_deconv'),
        standardized_path,
        standardized_dtype,
        fname_templates,
        run_chunk_sec=CONFIG.clustering_chunk)

    # Final deconv: Deconvolve, Residual, soft assignment
    (fname_templates,
     fname_spike_train,
     fname_soft_assignment) = final_deconv(
        os.path.join(TMP_FOLDER, 'final_deconv'),
        standardized_path,
        standardized_dtype,
        fname_templates,
        update_templates=True,
        run_chunk_sec=CONFIG.final_deconv_chunk)

    # save the final templates and spike train
    fname_templates_final = os.path.join(TMP_FOLDER, 'templates.npy')
    fname_spike_train_final = os.path.join(TMP_FOLDER, 'spike_train.npy')
    fname_soft_assignment_final = os.path.join(TMP_FOLDER,
                                               'soft_assignment.npy')

    # transpose axes
    templates = np.load(fname_templates).transpose(1, 2, 0)
    # align spike times to the beginning
    spike_train = np.load(fname_spike_train)
    # spike_train[:, 0] -= CONFIG.spike_size // 2
    soft_assignment = np.load(fname_soft_assignment)

    np.save(fname_templates_final, templates)
    np.save(fname_spike_train_final, spike_train)
    np.save(fname_soft_assignment_final, soft_assignment)

    total_time = time.time() - start

    '''
    **********************************************
    ************** RF / VISUALIZE ****************
    **********************************************
    '''
    if calculate_rf:
        rf.run()

    if visualize:
        visual.run()

    logger.info('Finished YASS execution. Total time: {}'.format(
        human_readable_time(total_time)))
    logger.info('Final Templates Location: ' + fname_templates_final)
    logger.info('Final Spike Train Location: ' + fname_spike_train_final)
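# --- Usage sketch (hedged) ---
# Load the final outputs saved above. After the transpose(1, 2, 0) the
# template axis order is assumed to be (n_times, n_channels, n_units);
# 'tmp/' is the default output directory.
import numpy as np

templates = np.load('tmp/templates.npy')
spike_train = np.load('tmp/spike_train.npy')
soft_assignment = np.load('tmp/soft_assignment.npy')

print('templates', templates.shape,
      'spike_train', spike_train.shape,
      'soft_assignment', soft_assignment.shape)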