Example #1
    def multi_channel_apply(self, function, mode, cleanup_function=None,
                            output_path=None, from_time=None, to_time=None,
                            channels='all', if_file_exists='overwrite',
                            cast_dtype=None, **kwargs):
        """
        Apply a function where each batch has observations from more than
        one channel

        Parameters
        ----------
        function: callable
            Function to be applied, must accept a 2D numpy array in 'long'
            format as its first parameter (number of observations, number of
            channels)
        mode: str
            'disk' or 'memory'. If 'disk', a binary file is created at the
            beginning of the operation and each partial result is saved to it
            (using the numpy.ndarray.tofile function); at the end of the
            operation two files are generated: the binary file and a yaml
            file with some file parameters (useful if you want to later use
            RecordingsReader to read the file). If 'memory', partial results
            are kept in memory and returned as a list
        cleanup_function: callable, optional
            A function to be executed after `function` and before adding the
            partial result to the list of results (if in 'memory' mode) or to
            the binary file (if in 'disk' mode)
        output_path: str, optional
            Where to save the output, required if 'disk' mode
        force_complete_channel_batch: bool, optional
            If True, every index generated will correspond to all the
            observations in a single channel, hence
            n_batches = n_selected_channels, defaults to True. If True,
            from_time and to_time must be None
        from_time: int, optional
            Starting time, defaults to None
        to_time: int, optional
            Ending time, defaults to None
        channels: int, tuple or str, optional
            A tuple with the channel indexes or 'all' to traverse all channels,
            defaults to 'all'
        if_file_exists: str, optional
            One of 'overwrite', 'abort', 'skip'. If 'overwrite', it replaces
            the file if it exists; if 'abort', it raises a ValueError
            exception if the file exists; if 'skip', it skips the operation
            if the file exists. Only valid when mode = 'disk'
        cast_dtype: str, optional
            Output dtype, defaults to None which means no cast is done
        **kwargs
            kwargs to pass to function

        Returns
        -------
        output_path
            Path to output binary file
        params
            Binary file params

        Examples
        --------

        .. literalinclude:: ../../examples/batch/multi_channel_apply.py

        Notes
        -----
        Applying functions incurs memory overhead, which depends on the
        function implementation; this is important to consider if the
        transformation changes the data's dtype (e.g. converting int16 to
        float64 means a 1MB chunk in int16 becomes 4MB in float64). Take
        that into account when setting max_memory

        For performance reasons, outputs data in 'long' format.
        """
        if mode not in ['disk', 'memory']:
            raise ValueError('Mode should be disk or memory, received: {}'
                             .format(mode))

        if mode == 'disk' and output_path is None:
            raise ValueError('output_path is required in "disk" mode')

        if (mode == 'disk' and if_file_exists == 'abort' and
           os.path.exists(output_path)):
            raise ValueError('{} already exists'.format(output_path))

        self.logger.info('Applying function {}...'
                         .format(function_path(function)))

        if (mode == 'disk' and if_file_exists == 'skip' and
           os.path.exists(output_path)):
            # load params...
            path_to_yaml = output_path.replace('.bin', '.yaml')

            if not os.path.exists(path_to_yaml):
                raise ValueError("if_file_exists = 'skip', but {}"
                                 " is missing, aborting..."
                                 .format(path_to_yaml))

            with open(path_to_yaml) as f:
                params = yaml.load(f, Loader=yaml.SafeLoader)

            self.logger.info('{} exists, skipping...'.format(output_path))

            return output_path, params

        if mode == 'disk':
            fn = self._multi_channel_apply_disk

            start = time.time()
            res = fn(function, cleanup_function, output_path, from_time,
                     to_time, channels, cast_dtype, **kwargs)
            elapsed = time.time() - start
            self.logger.info('{} took {}'
                             .format(function_path(function),
                                     human_readable_time(elapsed)))
            return res
        else:
            fn = self._multi_channel_apply_memory

            start = time.time()
            res = fn(function, cleanup_function, from_time, to_time, channels,
                     cast_dtype, **kwargs)
            elapsed = time.time() - start
            self.logger.info('{} took {}'
                             .format(function_path(function),
                                     human_readable_time(elapsed)))
            return res
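
A minimal usage sketch (not part of the source above). The batching object
`bp` and how it is constructed are assumptions made only to illustrate the
call; the parameter names follow the docstring.

import numpy as np

def center(data):
    # `data` arrives in 'long' shape: (observations, channels)
    return data - data.mean(axis=0)

# 'memory' mode keeps each partial result in a list:
# results = bp.multi_channel_apply(center, mode='memory', channels='all')

# 'disk' mode writes a binary file plus a yaml file describing it:
# path, params = bp.multi_channel_apply(center, mode='disk',
#                                       output_path='centered.bin',
#                                       if_file_exists='skip',
#                                       cast_dtype='float32')
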
Example #2
File: pipeline.py  Project: Nomow/yass
def run(config,
        logger_level='INFO',
        clean=False,
        output_dir='tmp/',
        complete=False,
        set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (if relative, it makes it relative to
        CONFIG.data.root_folder) to store the output data, defaults to tmp/.
        If absolute, it leaves it as it is.

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standardized.bin`` - Standardized recordings (from preprocess)
    * ``standardized.yaml`` - Standardized recordings metadata (from preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)


    Returns
    -------
    numpy.ndarray
        Spike train
    """

    # load yass configuration parameters
    set_config(config, output_dir)
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)
    ''' **********************************************
        ******** SET ENVIRONMENT VARIABLES ***********
        **********************************************
    '''
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["GIO_EXTRA_MODULES"] = "/usr/lib/x86_64-linux-gnu/gio/modules/"
    ''' **********************************************
        ************** PREPROCESS ********************
        **********************************************
    '''
    # preprocess
    start = time.time()
    (standardized_path, standardized_params, whiten_filter) = (preprocess.run(
        if_file_exists=CONFIG.preprocess.if_file_exists))

    time_preprocess = time.time() - start
    ''' **********************************************
        ************** DETECT EVENTS *****************
        **********************************************
    '''
    # detect
    # Cat: This code now runs with open tensorflow calls
    start = time.time()
    (spike_index_all) = detect.run(standardized_path,
                                   standardized_params,
                                   whiten_filter,
                                   if_file_exists=CONFIG.detect.if_file_exists,
                                   save_results=CONFIG.detect.save_results)
    spike_index_clear = None
    time_detect = time.time() - start
    ''' **********************************************
        ***************** CLUSTER ********************
        **********************************************
    '''

    # cluster
    start = time.time()
    path_to_spike_train_cluster = path.join(TMP_FOLDER,
                                            'spike_train_cluster.npy')
    if not os.path.exists(path_to_spike_train_cluster):
        cluster.run(spike_index_clear, spike_index_all)
    else:
        print("\nClustering completed previously...\n\n")

    spike_train_cluster = np.load(path_to_spike_train_cluster)
    templates_cluster = np.load(
        os.path.join(TMP_FOLDER, 'templates_cluster.npy'))

    time_cluster = time.time() - start
    #print ("Spike train clustered: ", spike_index_cluster.shape, "spike train clear: ",
    #        spike_train_clear.shape, " templates: ", templates.shape)
    ''' **********************************************
        ************** DECONVOLUTION *****************
        **********************************************
    '''

    # run deconvolution
    start = time.time()
    spike_train, postdeconv_templates = deconvolve.run(spike_train_cluster,
                                                       templates_cluster)
    time_deconvolution = time.time() - start

    # save spike train
    path_to_spike_train = path.join(TMP_FOLDER,
                                    'spike_train_post_deconv_post_merge.npy')
    np.save(path_to_spike_train, spike_train)
    logger.info('Spike train saved in: {}'.format(path_to_spike_train))

    # save template
    path_to_templates = path.join(TMP_FOLDER,
                                  'templates_post_deconv_post_merge.npy')
    np.save(path_to_templates, postdeconv_templates)
    logger.info('Templates saved in: {}'.format(path_to_templates))
    ''' **********************************************
        ************** POST PROCESSING****************
        **********************************************
    '''

    # save metadata in tmp
    path_to_metadata = path.join(TMP_FOLDER, 'metadata.yaml')
    logging.info('Saving metadata in {}'.format(path_to_metadata))
    save_metadata(path_to_metadata)

    # save config.yaml copy in tmp/
    path_to_config_copy = path.join(TMP_FOLDER, 'config.yaml')

    if isinstance(config, Mapping):
        with open(path_to_config_copy, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
    else:
        shutil.copy2(config, path_to_config_copy)

    logging.info('Saving copy of config: {} in {}'.format(
        config, path_to_config_copy))

    # this part loads waveforms for all spikes in the spike train and scores
    # them, this data is needed to later generate phy files
    if complete:
        STANDARIZED_PATH = path.join(TMP_FOLDER, 'standardized.bin')
        PARAMS = load_yaml(path.join(TMP_FOLDER, 'standardized.yaml'))

        # load waveforms for all spikes in the spike train
        logger.info('Loading waveforms from all spikes in the spike train...')
        explorer = RecordingExplorer(STANDARIZED_PATH,
                                     spike_size=CONFIG.spike_size,
                                     dtype=PARAMS['dtype'],
                                     n_channels=PARAMS['n_channels'],
                                     data_order=PARAMS['data_order'])
        waveforms = explorer.read_waveforms(spike_train[:, 0])

        path_to_waveforms = path.join(TMP_FOLDER, 'spike_train_waveforms.npy')
        np.save(path_to_waveforms, waveforms)
        logger.info('Saved all waveforms from the spike train in {}...'.format(
            path_to_waveforms))

        # score all waveforms
        logger.info('Scoring waveforms from all spikes in the spike train...')
        path_to_rotation = path.join(TMP_FOLDER, 'rotation.npy')
        rotation = np.load(path_to_rotation)

        main_channels = explorer.main_channel_for_waveforms(waveforms)
        path_to_main_channels = path.join(TMP_FOLDER,
                                          'waveforms_main_channel.npy')
        np.save(path_to_main_channels, main_channels)
        logger.info('Saved all waveforms main channels in {}...'.format(
            path_to_main_channels))

        waveforms_score = dim_red.score(waveforms, rotation, main_channels,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_waveforms_score = path.join(TMP_FOLDER, 'waveforms_score.npy')
        np.save(path_to_waveforms_score, waveforms_score)
        logger.info(
            'Saved all scores in {}...'.format(path_to_waveforms_score))

        # score templates
        # TODO: templates should be returned in the right shape to avoid .T
        # NOTE: `templates` is not defined in this function; the
        # post-deconvolution templates saved above are presumably intended
        templates_ = postdeconv_templates.T
        main_channels_tmpls = explorer.main_channel_for_waveforms(templates_)
        path_to_templates_main_c = path.join(TMP_FOLDER,
                                             'templates_main_channel.npy')
        np.save(path_to_templates_main_c, main_channels_tmpls)
        logger.info('Saved all templates main channels in {}...'.format(
            path_to_templates_main_c))

        templates_score = dim_red.score(templates_, rotation,
                                        main_channels_tmpls,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_templates_score = path.join(TMP_FOLDER, 'templates_score.npy')
        np.save(path_to_templates_score, templates_score)
        logger.info('Saved all templates scores in {}...'.format(
            path_to_templates_score))

    logger.info('Finished YASS execution. Timing summary:')
    total = (time_preprocess + time_detect + time_cluster + time_deconvolution)
    logger.info('\t Preprocess: %s (%.2f %%)',
                human_readable_time(time_preprocess),
                time_preprocess / total * 100)
    logger.info('\t Detection: %s (%.2f %%)', human_readable_time(time_detect),
                time_detect / total * 100)
    logger.info('\t Clustering: %s (%.2f %%)',
                human_readable_time(time_cluster), time_cluster / total * 100)
    logger.info('\t Deconvolution: %s (%.2f %%)',
                human_readable_time(time_deconvolution),
                time_deconvolution / total * 100)

    return spike_train
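
A hedged invocation sketch for this pipeline (the import path and the config
file name are assumptions for illustration; the return value and saved files
follow the docstring and the code above).

# from yass import pipeline
#
# spike_train = pipeline.run('config.yaml', clean=True, output_dir='tmp/')
#
# Spike times are in column 0 of the returned array (as used by
# explorer.read_waveforms above); the pipeline also writes
# spike_train_post_deconv_post_merge.npy and
# templates_post_deconv_post_merge.npy to the output directory.
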
Example #3
def run(config,
        logger_level='INFO',
        clean=False,
        output_dir='tmp/',
        complete=False,
        calculate_rf=False,
        visualize=False,
        set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (if relative, it makes it relative to
        CONFIG.data.root_folder) to store the output data, defaults to tmp/.
        If absolute, it leaves it as it is.

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standardized.bin`` - Standardized recordings (from preprocess)
    * ``standardized.yaml`` - Standardized recordings metadata (from preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)


    Returns
    -------
    numpy.ndarray
        Spike train
    """

    # load yass configuration parameters
    set_config(config, output_dir)
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = os.path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)
    ''' **********************************************
        ******** SET ENVIRONMENT VARIABLES ***********
        **********************************************
    '''
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["GIO_EXTRA_MODULES"] = "/usr/lib/x86_64-linux-gnu/gio/modules/"
    ''' **********************************************
        ************** PREPROCESS ********************
        **********************************************
    '''

    # run preprocess and spike detect
    # preprocess
    start = time.time()
    (standardized_path, standardized_params) = preprocess.run(
        os.path.join(TMP_FOLDER, 'preprocess'))

    # run the entire pipeline multiple times on multiple chunks:
    run_blocks = []
    for k in range(0, 600, 60):
        run_blocks.append([k, k + 60])

    for blk_ctr, run_block in enumerate(run_blocks):

        BLK_FOLDER = os.path.join(TMP_FOLDER,
                                  '{}_{}'.format(run_block[0], run_block[1]))
        if not os.path.exists(BLK_FOLDER):
            os.makedirs(BLK_FOLDER)

        print("  RUNNING ON: ", run_block)
        #### Block 1: Detection, Clustering, Postprocess
        (fname_templates, fname_spike_train) = initial_block(
            os.path.join(BLK_FOLDER, 'block_1'),
            standardized_path,
            standardized_params,
            run_chunk_sec=run_block)

        print(" inpput to next block: ", fname_templates)

        ### Block 3: Deconvolve, Residual, Merge
        (fname_templates, fname_spike_train, fname_templates_up,
         fname_spike_train_up, fname_residual, residual_dtype) = single_block(
             os.path.join(BLK_FOLDER, 'final_deconv'), standardized_path,
             standardized_params, fname_templates, run_block)

        ## save the final templates and spike train
        fname_templates_final = os.path.join(BLK_FOLDER, 'templates.npy')
        fname_spike_train_final = os.path.join(BLK_FOLDER, 'spike_train.npy')
        # transpose axes
        templates = np.load(fname_templates).transpose(1, 2, 0)
        # align spike time to the beginning
        spike_train = np.load(fname_spike_train)
        spike_train[:, 0] -= CONFIG.spike_size // 2
        np.save(fname_templates_final, templates)
        np.save(fname_spike_train_final, spike_train)

        total_time = time.time() - start

    logger.info('Finished YASS execution. Total time: {}'.format(
        human_readable_time(total_time)))
    logger.info('Final Templates Location: ' + fname_templates_final)
    logger.info('Final Spike Train Location: ' + fname_spike_train_final)
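
The chunking above is hard-coded; the sketch below isolates that logic so the
block boundaries are easy to verify (the 600 s total and 60 s block length
mirror the values used in the loop).

def make_run_blocks(total_sec=600, block_sec=60):
    """Split [0, total_sec) into consecutive [start, end] blocks."""
    return [[k, k + block_sec] for k in range(0, total_sec, block_sec)]

assert make_run_blocks()[:2] == [[0, 60], [60, 120]]
assert len(make_run_blocks()) == 10
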
Example #4
def run(config, logger_level='INFO', clean=False, output_dir='tmp/',
        complete=False, set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (relative to CONFIG.data.root_folder) to store the
        output data, defaults to tmp/

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standarized.bin`` - Standardized recordings (from preprocess)
    * ``standarized.yaml`` - Standardized recordings metadata (from preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)


    Returns
    -------
    numpy.ndarray
        Spike train
    """
    # load yass configuration parameters
    set_config(config)
    CONFIG = read_config()
    ROOT_FOLDER = CONFIG.data.root_folder
    TMP_FOLDER = path.join(ROOT_FOLDER, output_dir)

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = path.join(TMP_FOLDER,
                                                               'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger and start coloredlogs
    logger = logging.getLogger(__name__)
    coloredlogs.install(logger=logger)

    if set_zero_seed:
        logger.warning('Set numpy seed to zero')
        np.random.seed(0)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)

    # preprocess
    start = time.time()
    (standarized_path,
     standarized_params,
     channel_index,
     whiten_filter) = (preprocess
                       .run(output_directory=output_dir,
                            if_file_exists=CONFIG.preprocess.if_file_exists))
    time_preprocess = time.time() - start

    # detect
    start = time.time()
    (score, spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   channel_index,
                                   whiten_filter,
                                   output_directory=output_dir,
                                   if_file_exists=CONFIG.detect.if_file_exists,
                                   save_results=CONFIG.detect.save_results)
    time_detect = time.time() - start

    # cluster
    start = time.time()
    spike_train_clear, tmp_loc, vbParam = cluster.run(
        score,
        spike_index_clear,
        output_directory=output_dir,
        if_file_exists=CONFIG.cluster.if_file_exists,
        save_results=CONFIG.cluster.save_results)
    time_cluster = time.time() - start

    # get templates
    start = time.time()
    (templates,
     spike_train_clear_after_templates,
     groups,
     idx_good_templates) = get_templates.run(
        spike_train_clear, tmp_loc,
        output_directory=output_dir,
        if_file_exists=CONFIG.templates.if_file_exists,
        save_results=CONFIG.templates.save_results)
    time_templates = time.time() - start

    # run deconvolution
    start = time.time()
    spike_train = deconvolute.run(spike_index_all, templates,
                                  output_directory=output_dir)
    time_deconvolution = time.time() - start

    # save metadata in tmp
    path_to_metadata = path.join(TMP_FOLDER, 'metadata.yaml')
    logging.info('Saving metadata in {}'.format(path_to_metadata))
    save_metadata(path_to_metadata)

    # save config.yaml copy in tmp/
    path_to_config_copy = path.join(TMP_FOLDER, 'config.yaml')

    if isinstance(config, Mapping):
        with open(path_to_config_copy, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
    else:
        shutil.copy2(config, path_to_config_copy)

    logging.info('Saving copy of config: {} in {}'.format(config,
                                                          path_to_config_copy))

    # TODO: complete flag saves other files needed for integrating phy
    # with yass, the integration hasn't been completed yet
    # this part loads waveforms for all spikes in the spike train and scores
    # them, this data is needed to later generate phy files
    if complete:
        STANDARIZED_PATH = path.join(TMP_FOLDER, 'standarized.bin')
        PARAMS = load_yaml(path.join(TMP_FOLDER, 'standarized.yaml'))

        # load waveforms for all spikes in the spike train
        logger.info('Loading waveforms from all spikes in the spike train...')
        explorer = RecordingExplorer(STANDARIZED_PATH,
                                     spike_size=CONFIG.spike_size,
                                     dtype=PARAMS['dtype'],
                                     n_channels=PARAMS['n_channels'],
                                     data_order=PARAMS['data_order'])
        waveforms = explorer.read_waveforms(spike_train[:, 0])

        path_to_waveforms = path.join(TMP_FOLDER, 'spike_train_waveforms.npy')
        np.save(path_to_waveforms, waveforms)
        logger.info('Saved all waveforms from the spike train in {}...'
                    .format(path_to_waveforms))

        # score all waveforms
        logger.info('Scoring waveforms from all spikes in the spike train...')
        path_to_rotation = path.join(TMP_FOLDER, 'rotation.npy')
        rotation = np.load(path_to_rotation)

        main_channels = explorer.main_channel_for_waveforms(waveforms)
        path_to_main_channels = path.join(TMP_FOLDER,
                                          'waveforms_main_channel.npy')
        np.save(path_to_main_channels, main_channels)
        logger.info('Saved all waveforms main channels in {}...'
                    .format(path_to_main_channels))

        waveforms_score = dim_red.score(waveforms, rotation, main_channels,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_waveforms_score = path.join(TMP_FOLDER, 'waveforms_score.npy')
        np.save(path_to_waveforms_score, waveforms_score)
        logger.info('Saved all scores in {}...'
                    .format(path_to_waveforms_score))

        # score templates
        # TODO: templates should be returned in the right shape to avoid .T
        templates_ = templates.T
        main_channels_tmpls = explorer.main_channel_for_waveforms(templates_)
        path_to_templates_main_c = path.join(TMP_FOLDER,
                                             'templates_main_channel.npy')
        np.save(path_to_templates_main_c, main_channels_tmpls)
        logger.info('Saved all templates main channels in {}...'
                    .format(path_to_templates_main_c))

        templates_score = dim_red.score(templates_, rotation,
                                        main_channels_tmpls,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_templates_score = path.join(TMP_FOLDER, 'templates_score.npy')
        np.save(path_to_templates_score, templates_score)
        logger.info('Saved all templates scores in {}...'
                    .format(path_to_templates_score))

    logger.info('Finished YASS execution. Timing summary:')
    total = (time_preprocess + time_detect + time_cluster + time_templates
             + time_deconvolution)
    logger.info('\t Preprocess: %s (%.2f %%)',
                human_readable_time(time_preprocess),
                time_preprocess/total*100)
    logger.info('\t Detection: %s (%.2f %%)',
                human_readable_time(time_detect),
                time_detect/total*100)
    logger.info('\t Clustering: %s (%.2f %%)',
                human_readable_time(time_cluster),
                time_cluster/total*100)
    logger.info('\t Templates: %s (%.2f %%)',
                human_readable_time(time_templates),
                time_templates/total*100)
    logger.info('\t Deconvolution: %s (%.2f %%)',
                human_readable_time(time_deconvolution),
                time_deconvolution/total*100)

    return spike_train
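
A small sketch of the timing summary computed at the end of this pipeline:
each stage's share of the total wall time. The numbers are invented; only the
elapsed / total * 100 formula comes from the code above.

stage_times = {'Preprocess': 12.0, 'Detection': 30.0, 'Clustering': 45.0,
               'Templates': 8.0, 'Deconvolution': 25.0}  # seconds, illustrative
total = sum(stage_times.values())
for name, elapsed in stage_times.items():
    print('\t {}: {:.1f} s ({:.2f} %)'.format(name, elapsed,
                                              elapsed / total * 100))
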
Example #5
File: batch.py  Project: kathefter/yass
    def multi_channel_apply(self,
                            function,
                            mode,
                            cleanup_function=None,
                            output_path=None,
                            from_time=None,
                            to_time=None,
                            channels='all',
                            if_file_exists='overwrite',
                            cast_dtype=None,
                            pass_batch_info=False,
                            pass_batch_results=False,
                            processes=1,
                            **kwargs):
        """
        Apply a function where each batch has observations from more than
        one channel

        Parameters
        ----------
        function: callable
            Function to be applied, first parameter passed will be a 2D numpy
            array in 'long' shape (number of observations, number of
            channels). If pass_batch_info is True, another two keyword
            parameters will be passed to function: 'idx_local' is the slice
            object with the limits of the data in [observations, channels]
            format (excluding the buffer), 'idx' is the absolute index of
            the data again in [observations, channels] format

        mode: str
            'disk' or 'memory'. If 'disk', a binary file is created at the
            beginning of the operation and each partial result is saved to it
            (using the numpy.ndarray.tofile function); at the end of the
            operation two files are generated: the binary file and a yaml
            file with some file parameters (useful if you want to later use
            RecordingsReader to read the file). If 'memory', partial results
            are kept in memory and returned as a list

        cleanup_function: callable, optional
            A function to be executed after `function` and before adding the
            partial result to the list of results (if in 'memory' mode) or to
            the binary file (if in 'disk' mode). `cleanup_function` will be
            called with the following parameters (in that order): result from
            applying `function` to the batch, slice object with the idx where
            the data is located (excludes buffer), slice object with the
            absolute location of the data, and buffer size

        output_path: str, optional
            Where to save the output, required if 'disk' mode

        force_complete_channel_batch: bool, optional
            If True, every index generated will correspond to all the
            observations in a single channel, hence
            n_batches = n_selected_channels, defaults to True. If True,
            from_time and to_time must be None

        from_time: int, optional
            Starting time, defaults to None

        to_time: int, optional
            Ending time, defaults to None

        channels: int, tuple or str, optional
            A tuple with the channel indexes or 'all' to traverse all channels,
            defaults to 'all'

        if_file_exists: str, optional
            One of 'overwrite', 'abort', 'skip'. If 'overwrite', it replaces
            the file if it exists; if 'abort', it raises a ValueError
            exception if the file exists; if 'skip', it skips the operation
            if the file exists. Only valid when mode = 'disk'

        cast_dtype: str, optional
            Output dtype, defaults to None which means no cast is done

        pass_batch_info: bool, optional
            Whether to call the function with batch info or just with the
            batch data (see the description of the `function` parameter)

        pass_batch_results: bool, optional
            Whether to pass results from the previous batch to the next one,
            defaults to False. Only relevant when mode='memory'. If True,
            function will be called with the keyword parameter
            'previous_batch', which contains the result computed for the
            previous batch; it is set to None on the first batch

        **kwargs
            kwargs to pass to function

        Returns
        -------
        output_path, params (when mode is 'disk')
            Path to output binary file, Binary file params

        list (when mode is 'memory' and pass_batch_results is False)
            List where every element is the result of applying the function
            to one batch. When pass_batch_results is True, it returns the
            output of the function for the last batch

        Examples
        --------

        .. literalinclude:: ../../examples/batch/multi_channel_apply_disk.py

        .. literalinclude:: ../../examples/batch/multi_channel_apply_memory.py

        Notes
        -----
        Applying functions incurs memory overhead, which depends on the
        function implementation; this is important to consider if the
        transformation changes the data's dtype (e.g. converting int16 to
        float64 means a 1MB chunk in int16 becomes 4MB in float64). Take
        that into account when setting max_memory

        For performance reasons, outputs data in 'samples' order.
        """
        if mode not in ['disk', 'memory']:
            raise ValueError(
                'Mode should be disk or memory, received: {}'.format(mode))

        if mode == 'disk' and output_path is None:
            raise ValueError('output_path is required in "disk" mode')

        if (mode == 'disk' and if_file_exists == 'abort'
                and os.path.exists(output_path)):
            raise ValueError('{} already exists'.format(output_path))

        self.logger.info('Applying function {}...'.format(
            function_path(function)))

        if (mode == 'disk' and if_file_exists == 'skip'
                and os.path.exists(output_path)):
            # load params...
            path_to_yaml = output_path.replace('.bin', '.yaml')

            if not os.path.exists(path_to_yaml):
                raise ValueError(
                    "if_file_exists = 'skip', but {}"
                    " is missing, aborting...".format(path_to_yaml))

            with open(path_to_yaml) as f:
                params = yaml.load(f, Loader=yaml.SafeLoader)

            self.logger.info('{} exists, skipping...'.format(output_path))

            return output_path, params

        if mode == 'disk':

            if processes == 1:
                fn = self._multi_channel_apply_disk
            else:
                fn = partial(self._multi_channel_apply_disk_parallel,
                             processes=processes)

            start = time.time()
            res = fn(function, cleanup_function, output_path, from_time,
                     to_time, channels, cast_dtype, pass_batch_info,
                     pass_batch_results, **kwargs)
            elapsed = time.time() - start
            self.logger.info('{} took {}'.format(function_path(function),
                                                 human_readable_time(elapsed)))
            return res
        else:
            fn = self._multi_channel_apply_memory

            start = time.time()
            res = fn(function, cleanup_function, from_time, to_time, channels,
                     cast_dtype, pass_batch_info, pass_batch_results, **kwargs)
            elapsed = time.time() - start
            self.logger.info('{} took {}'.format(function_path(function),
                                                 human_readable_time(elapsed)))
            return res
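
A usage sketch for pass_batch_info and pass_batch_results (the batching
object `bp` is an assumption; the keyword names idx_local, idx and
previous_batch are taken from the docstring above).

def running_total(data, idx_local=None, idx=None, previous_batch=None):
    # idx_local bounds the batch excluding the buffer; idx is its absolute
    # location, both in (observations, channels) format
    partial = float(data.sum())
    return partial if previous_batch is None else previous_batch + partial

# total = bp.multi_channel_apply(running_total, mode='memory',
#                                pass_batch_info=True,
#                                pass_batch_results=True)
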
Example #6
def run(config, logger_level='INFO', clean=False, output_dir='tmp/',
        complete=False, calculate_rf=False, visualize=False, set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (if relative, it makes it relative to
        CONFIG.data.root_folder) to store the output data, defaults to tmp/.
        If absolute, it leaves it as it is.

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standardized.bin`` - Standardized recordings (from preprocess)
    * ``standardized.yaml`` - Standardized recordings metadata (from preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)


    Returns
    -------
    numpy.ndarray
        Spike train
    """

    # load yass configuration parameters
    set_config(config, output_dir)
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = os.path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)

    ''' **********************************************
        ******** SET ENVIRONMENT VARIABLES ***********
        **********************************************
    '''
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["GIO_EXTRA_MODULES"] = "/usr/lib/x86_64-linux-gnu/gio/modules/"

    ''' **********************************************
        ************** PREPROCESS ********************
        **********************************************
    '''

    # preprocess
    start = time.time()
    (standardized_path,
     standardized_dtype) = preprocess.run(
        os.path.join(TMP_FOLDER, 'preprocess'))

    #### Block 1: Detection, Clustering, Postprocess
    #print ("CLUSTERING DEFAULT LENGTH: ", CONFIG.rec_len, " current set to 300 sec")
    (fname_templates,
     fname_spike_train) = initial_block(
        os.path.join(TMP_FOLDER, 'block_1'),
        standardized_path,
        standardized_dtype,
        run_chunk_sec=CONFIG.clustering_chunk)
        # run_chunk_sec=[0, 600*20000])
        # run_chunk_sec=[0, 300])

    print(" input to block 2: ", fname_templates)
    
    #### Block 2: Deconv, Merge, Residuals, Clustering, Postprocess
    n_iterations = 1
    for it in range(n_iterations):
        (fname_templates,
         fname_spike_train) = iterative_block(
            os.path.join(TMP_FOLDER, 'block_{}'.format(it+2)),
            standardized_path,
            standardized_dtype,
            fname_templates,
            run_chunk_sec=CONFIG.clustering_chunk)

    ### Pre-final deconv: Deconvolve, Residual, Merge, kill low fr units
    (fname_templates,
     fname_spike_train) = pre_final_deconv(
        os.path.join(TMP_FOLDER, 'pre_final_deconv'),
        standardized_path,
        standardized_dtype,
        fname_templates,
        run_chunk_sec=CONFIG.clustering_chunk)
    
    ### Final deconv: Deconvolve, Residual, soft assignment
    (fname_templates,
     fname_spike_train,
     fname_soft_assignment) = final_deconv(
        os.path.join(TMP_FOLDER, 'final_deconv'),
        standardized_path,
        standardized_dtype,
        fname_templates,
        update_templates=True,
        run_chunk_sec=CONFIG.final_deconv_chunk)

    ## save the final templates and spike train
    fname_templates_final = os.path.join(
        TMP_FOLDER, 'templates.npy')
    fname_spike_train_final = os.path.join(
        TMP_FOLDER, 'spike_train.npy')
    fname_soft_assignment_final = os.path.join(
        TMP_FOLDER, 'soft_assignment.npy')

    # transpose axes
    templates = np.load(fname_templates).transpose(1, 2, 0)
    # align spike time to the beginning
    spike_train = np.load(fname_spike_train)
    #spike_train[:,0] -= CONFIG.spike_size//2
    soft_assignment = np.load(fname_soft_assignment)
    np.save(fname_templates_final, templates)
    np.save(fname_spike_train_final, spike_train)
    np.save(fname_soft_assignment_final, soft_assignment)

    total_time = time.time() - start

    ''' **********************************************
        ************** RF / VISUALIZE ****************
        **********************************************
    '''

    if calculate_rf:
        rf.run() 

    if visualize:
        visual.run()
    
    logger.info('Finished YASS execution. Total time: {}'.format(
        human_readable_time(total_time)))
    logger.info('Final Templates Location: '+fname_templates_final)
    logger.info('Final Spike Train Location: '+fname_spike_train_final)
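
A hedged invocation sketch for this variant (the import path is an
assumption; the keyword arguments mirror this function's signature, and the
listed output files are the ones saved above).

# from yass import pipeline
#
# pipeline.run('config.yaml', clean=True, output_dir='tmp/',
#              calculate_rf=True, visualize=True)
#
# After completion the output directory contains templates.npy,
# spike_train.npy and soft_assignment.npy.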