Example #1
def test_templates_loads_from_disk_if_files_exist(caplog,
                                                  path_to_threshold_config):

    yass.set_config(path_to_threshold_config)

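    # run the earlier pipeline stages (preprocess -> detect -> cluster)
    # to produce the inputs that templates.run needs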
    (standarized_path, standarized_params, channel_index,
     whiten_filter) = preprocess.run()

    (score, spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   channel_index,
                                   whiten_filter)

    spike_train_clear, tmp_loc, vbParam = cluster.run(
        score, spike_index_clear)

    # save results
    templates.run(spike_train_clear, tmp_loc, save_results=True)

    assert templates.run.executed

    # next time this should not run and just load from files
    templates.run(spike_train_clear, tmp_loc, save_results=True)

    assert not templates.run.executed
Example #2
def test_templates_returns_expected_results(path_to_config,
                                            path_to_output_reference,
                                            make_tmp_folder):

    yass.set_config(path_to_config, make_tmp_folder)

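    # run preprocess, detect and cluster to obtain a clear spike train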
    (standarized_path,
     standarized_params,
     whiten_filter) = preprocess.run()

    (spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   whiten_filter)

    (spike_train_clear,
     tmp_loc,
     vbParam) = cluster.run(spike_index_clear)

    (templates_, spike_train,
     groups,
     idx_good_templates) = templates.run(spike_train_clear, tmp_loc,
                                         save_results=True)

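    # compare the computed templates against the stored reference file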
    path_to_templates = path.join(path_to_output_reference,
                                  'templates.npy')

    ReferenceTesting.assert_array_equal(templates_, path_to_templates)
Example #3
def test_templates_returns_expected_results(path_to_threshold_config,
                                            path_to_data_folder):
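    # fix the random seed so the run is deterministic and comparable
    # against the stored reference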
    np.random.seed(0)

    yass.set_config(path_to_threshold_config)

    (standarized_path, standarized_params, channel_index,
     whiten_filter) = preprocess.run()

    (score, spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   channel_index,
                                   whiten_filter)

    spike_train_clear, tmp_loc, vbParam = cluster.run(score, spike_index_clear)

    (templates_, spike_train,
     groups, idx_good_templates) = templates.run(spike_train_clear, tmp_loc)

    path_to_templates = path.join(path_to_data_folder,
                                  'output_reference',
                                  'templates.npy')

    ReferenceTesting.assert_array_equal(templates_, path_to_templates)

    clean_tmp()
Example #4
def test_templates(path_to_threshold_config):
    yass.set_config(path_to_threshold_config)

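    # smoke test: run the full pipeline up to the templates step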
    (standarized_path, standarized_params, channel_index,
     whiten_filter) = preprocess.run()

    (score, spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   channel_index,
                                   whiten_filter)

    spike_train_clear, tmp_loc, vbParam = cluster.run(
        score, spike_index_clear)

    templates.run(spike_train_clear, tmp_loc)

    clean_tmp()
Example #5
def main():
    settings.run()
    start = datetime.now()
    logger = logging.getLogger(__name__)

    logger.info('Preprocessing started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # preprocessing
    (standarized_path, standarized_params, channel_index,
     whiten_filter) = preprocess.run(output_directory='profiling',
                                     if_file_exists='overwrite')

    logger.info('Preprocessing finished and detection started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # detection
    (score, spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   channel_index,
                                   whiten_filter,
                                   output_directory='profiling',
                                   if_file_exists='overwrite',
                                   save_results=True)

    logger.info('Detection finished and clustering started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # clustering
    spike_train_clear = cluster.run(score, spike_index_clear,
                                    output_directory='profiling',
                                    if_file_exists='overwrite',
                                    save_results=True)

    logger.info('Clustering finished and templates started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # templates
    the_templates = templates.run(spike_train_clear,
                                  output_directory='profiling',
                                  if_file_exists='overwrite',
                                  save_results=True)

    logger.info('templates finished and deconvolution started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # deconvolution
    deconvolute.run(spike_index_all, the_templates,
                    output_directory='profiling')

    logger.info('Deconvolution finished at second: %.2f',
                (datetime.now() - start).total_seconds())
Example #6
def test_deconvolution(path_to_threshold_config):
    yass.set_config(path_to_threshold_config)

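    # run the full pipeline (preprocess -> detect -> cluster -> templates)
    # and then deconvolve the detected spikes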
    (standarized_path, standarized_params, whiten_filter) = preprocess.run()

    (score, spike_index_clear,
     spike_index_all) = detect.run(standarized_path, standarized_params,
                                   whiten_filter)

    spike_train_clear, tmp_loc, vbParam = cluster.run(score, spike_index_clear)

    (templates_, spike_train, groups,
     idx_good_templates) = templates.run(spike_train_clear, tmp_loc)

    deconvolute.run(spike_index_all, templates_)

    clean_tmp()
Example #7
def main():
    settings.run()
    start = datetime.now()
    logger = logging.getLogger(__name__)

    logger.info('Preprocessing started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # preprocessing
    (standarized_path, standarized_params, channel_index,
     whiten_filter) = preprocess.run()

    logger.info('Preprocessing finished and detection started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # detection
    (score, spike_index_clear,
     spike_index_all) = detect.run(standarized_path, standarized_params,
                                   channel_index, whiten_filter)

    logger.info('Detection finished and clustering started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # clustering
    spike_train_clear = cluster.run(score, spike_index_clear)

    logger.info('Clustering finished and templates started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # templates
    the_templates = templates.run(spike_train_clear)

    logger.info('templates finished and deconvolution started at second: %.2f',
                (datetime.now() - start).total_seconds())

    # deconvolution
    deconvolute.run(spike_index_all, the_templates)

    logger.info('Deconvolution finished at second: %.2f',
                (datetime.now() - start).total_seconds())
Example #8
File: templates.py  Project: Nomow/yass
import numpy as np
import logging

import yass
from yass import preprocess
from yass import detect
from yass import cluster
from yass import templates

np.random.seed(0)

# configure logging module to get useful information
logging.basicConfig(level=logging.INFO)

# set yass configuration parameters
yass.set_config('config_sample.yaml', 'templates-example')

standarized_path, standarized_params, whiten_filter = preprocess.run()

(spike_index_clear, spike_index_all) = detect.run(standarized_path,
                                                  standarized_params,
                                                  whiten_filter)

spike_train_clear, tmp_loc, vbParam = cluster.run(spike_index_clear)

(templates_, spike_train, groups,
 idx_good_templates) = templates.run(spike_train_clear, tmp_loc)
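
# the returned templates_ array can be persisted for later inspection;
# a minimal sketch, not part of the original example (the filename is
# arbitrary):
np.save('templates.npy', templates_)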
Example #9
File: pipeline.py  Project: kathefter/yass
def run(config, logger_level='INFO', clean=False, output_dir='tmp/',
        complete=False, set_zero_seed=False):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (relative to CONFIG.data.root_folder) to store the
        output data, defaults to tmp/

    complete: bool, optional
        Generates extra files (needed to generate phy files)

    set_zero_seed: bool, optional
        If True, sets the numpy random seed to zero before running, so
        results are reproducible

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_dir/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standarized.bin`` - Standarized recordings (from preprocess)
    * ``standarized.yaml`` - Standarized recordings metadata (from preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)


    Returns
    -------
    numpy.ndarray
        Spike train
    """
    # load yass configuration parameters
    set_config(config)
    CONFIG = read_config()
    ROOT_FOLDER = CONFIG.data.root_folder
    TMP_FOLDER = path.join(ROOT_FOLDER, output_dir)

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = path.join(TMP_FOLDER,
                                                               'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger and start coloredlogs
    logger = logging.getLogger(__name__)
    coloredlogs.install(logger=logger)

    if set_zero_seed:
        logger.warning('Set numpy seed to zero')
        np.random.seed(0)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)

    # preprocess
    start = time.time()
    (standarized_path,
     standarized_params,
     channel_index,
     whiten_filter) = (preprocess
                       .run(output_directory=output_dir,
                            if_file_exists=CONFIG.preprocess.if_file_exists))
    time_preprocess = time.time() - start

    # detect
    start = time.time()
    (score, spike_index_clear,
     spike_index_all) = detect.run(standarized_path,
                                   standarized_params,
                                   channel_index,
                                   whiten_filter,
                                   output_directory=output_dir,
                                   if_file_exists=CONFIG.detect.if_file_exists,
                                   save_results=CONFIG.detect.save_results)
    time_detect = time.time() - start

    # cluster
    start = time.time()
    spike_train_clear, tmp_loc, vbParam = cluster.run(
        score,
        spike_index_clear,
        output_directory=output_dir,
        if_file_exists=CONFIG.cluster.if_file_exists,
        save_results=CONFIG.cluster.save_results)
    time_cluster = time.time() - start

    # get templates
    start = time.time()
    (templates,
     spike_train_clear_after_templates,
     groups,
     idx_good_templates) = get_templates.run(
        spike_train_clear, tmp_loc,
        output_directory=output_dir,
        if_file_exists=CONFIG.templates.if_file_exists,
        save_results=CONFIG.templates.save_results)
    time_templates = time.time() - start

    # run deconvolution
    start = time.time()
    spike_train = deconvolute.run(spike_index_all, templates,
                                  output_directory=output_dir)
    time_deconvolution = time.time() - start

    # save metadata in tmp
    path_to_metadata = path.join(TMP_FOLDER, 'metadata.yaml')
    logger.info('Saving metadata in {}'.format(path_to_metadata))
    save_metadata(path_to_metadata)

    # save config.yaml copy in tmp/
    path_to_config_copy = path.join(TMP_FOLDER, 'config.yaml')

    if isinstance(config, Mapping):
        with open(path_to_config_copy, 'w') as f:
            yaml.dump(config, f, default_flow_style=False)
    else:
        shutil.copy2(config, path_to_config_copy)

    logger.info('Saving copy of config: {} in {}'.format(config,
                                                         path_to_config_copy))

    # TODO: complete flag saves other files needed for integrating phy
    # with yass, the integration hasn't been completed yet
    # this part loads waveforms for all spikes in the spike train and scores
    # them, this data is needed to later generate phy files
    if complete:
        STANDARIZED_PATH = path.join(TMP_FOLDER, 'standarized.bin')
        PARAMS = load_yaml(path.join(TMP_FOLDER, 'standarized.yaml'))

        # load waveforms for all spikes in the spike train
        logger.info('Loading waveforms from all spikes in the spike train...')
        explorer = RecordingExplorer(STANDARIZED_PATH,
                                     spike_size=CONFIG.spike_size,
                                     dtype=PARAMS['dtype'],
                                     n_channels=PARAMS['n_channels'],
                                     data_order=PARAMS['data_order'])
        waveforms = explorer.read_waveforms(spike_train[:, 0])

        path_to_waveforms = path.join(TMP_FOLDER, 'spike_train_waveforms.npy')
        np.save(path_to_waveforms, waveforms)
        logger.info('Saved all waveforms from the spike train in {}...'
                    .format(path_to_waveforms))

        # score all waveforms
        logger.info('Scoring waveforms from all spikes in the spike train...')
        path_to_rotation = path.join(TMP_FOLDER, 'rotation.npy')
        rotation = np.load(path_to_rotation)

        main_channels = explorer.main_channel_for_waveforms(waveforms)
        path_to_main_channels = path.join(TMP_FOLDER,
                                          'waveforms_main_channel.npy')
        np.save(path_to_main_channels, main_channels)
        logger.info('Saved all waveforms main channels in {}...'
                    .format(path_to_main_channels))

        waveforms_score = dim_red.score(waveforms, rotation, main_channels,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_waveforms_score = path.join(TMP_FOLDER, 'waveforms_score.npy')
        np.save(path_to_waveforms_score, waveforms_score)
        logger.info('Saved all scores in {}...'
                    .format(path_to_waveforms_score))

        # score templates
        # TODO: templates should be returned in the right shape to avoid .T
        templates_ = templates.T
        main_channels_tmpls = explorer.main_channel_for_waveforms(templates_)
        path_to_templates_main_c = path.join(TMP_FOLDER,
                                             'templates_main_channel.npy')
        np.save(path_to_templates_main_c, main_channels_tmpls)
        logger.info('Saved all templates main channels in {}...'
                    .format(path_to_templates_main_c))

        templates_score = dim_red.score(templates_, rotation,
                                        main_channels_tmpls,
                                        CONFIG.neigh_channels, CONFIG.geom)
        path_to_templates_score = path.join(TMP_FOLDER, 'templates_score.npy')
        np.save(path_to_templates_score, templates_score)
        logger.info('Saved all templates scores in {}...'
                    .format(path_to_templates_score))

    logger.info('Finished YASS execution. Timing summary:')
    total = (time_preprocess + time_detect + time_cluster + time_templates
             + time_deconvolution)
    logger.info('\t Preprocess: %s (%.2f %%)',
                human_readable_time(time_preprocess),
                time_preprocess/total*100)
    logger.info('\t Detection: %s (%.2f %%)',
                human_readable_time(time_detect),
                time_detect/total*100)
    logger.info('\t Clustering: %s (%.2f %%)',
                human_readable_time(time_cluster),
                time_cluster/total*100)
    logger.info('\t Templates: %s (%.2f %%)',
                human_readable_time(time_templates),
                time_templates/total*100)
    logger.info('\t Deconvolution: %s (%.2f %%)',
                human_readable_time(time_deconvolution),
                time_deconvolution/total*100)

    return spike_train
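
# A minimal usage sketch for the pipeline entry point above (the import path
# and the configuration filename are assumptions, not taken from the
# original file):
from yass import pipeline

# runs preprocess, detect, cluster, templates and deconvolution in order,
# returning the final spike train
spike_train = pipeline.run('config.yaml', logger_level='INFO',
                           clean=True, output_dir='tmp/')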
Example #10
def test_nn_output(path_to_tests):
    """Test that pipeline using threshold detector returns the same results
    """
    logger = logging.getLogger(__name__)

    yass.set_config(path.join(path_to_tests, 'config_nn_49.yaml'))

    CONFIG = read_config()
    TMP = Path(CONFIG.data.root_folder, 'tmp')

    logger.info('Removing %s', TMP)
    shutil.rmtree(str(TMP))

    PATH_TO_REF = '/home/Edu/data/nnet'

    np.random.seed(0)

    # run preprocess
    (standarized_path, standarized_params, whiten_filter) = preprocess.run()

    # load preprocess output
    path_to_standarized = path.join(PATH_TO_REF, 'preprocess',
                                    'standarized.bin')
    path_to_whitening = path.join(PATH_TO_REF, 'preprocess', 'whitening.npy')

    whitening_saved = np.load(path_to_whitening)
    standarized_saved = RecordingsReader(path_to_standarized,
                                         loader='array').data
    standarized = RecordingsReader(standarized_path, loader='array').data

    # test preprocess
    np.testing.assert_array_equal(whitening_saved, whiten_filter)
    np.testing.assert_array_equal(standarized_saved, standarized)

    # run detect
    (score, spike_index_clear,
     spike_index_all) = detect.run(standarized_path, standarized_params,
                                   whiten_filter)
    # load detect output
    path_to_scores = path.join(PATH_TO_REF, 'detect', 'scores_clear.npy')
    path_to_spike_index_clear = path.join(PATH_TO_REF, 'detect',
                                          'spike_index_clear.npy')
    path_to_spike_index_all = path.join(PATH_TO_REF, 'detect',
                                        'spike_index_all.npy')

    scores_saved = np.load(path_to_scores)
    spike_index_clear_saved = np.load(path_to_spike_index_clear)
    spike_index_all_saved = np.load(path_to_spike_index_all)

    # test detect output
    np.testing.assert_array_equal(scores_saved, score)
    np.testing.assert_array_equal(spike_index_clear_saved, spike_index_clear)
    np.testing.assert_array_equal(spike_index_all_saved, spike_index_all)

    # run cluster
    (spike_train_clear, tmp_loc,
     vbParam) = cluster.run(score, spike_index_clear)

    # load cluster output
    path_to_spike_train_cluster = path.join(PATH_TO_REF, 'cluster',
                                            'spike_train_cluster.npy')
    spike_train_cluster_saved = np.load(path_to_spike_train_cluster)

    # test cluster (assertion currently disabled)
    # np.testing.assert_array_equal(spike_train_cluster_saved,
    #                               spike_train_clear)

    # run templates
    (templates_, spike_train, groups,
     idx_good_templates) = templates.run(spike_train_clear,
                                         tmp_loc,
                                         save_results=True)

    # load templates output
    path_to_templates = path.join(PATH_TO_REF, 'templates', 'templates.npy')
    templates_saved = np.load(path_to_templates)

    # test templates
    np.testing.assert_array_almost_equal(templates_saved,
                                         templates_,
                                         decimal=4)

    # run deconvolution
    spike_train = deconvolute.run(spike_index_all, templates_)

    # load deconvolution output
    path_to_spike_train = path.join(PATH_TO_REF, 'spike_train.npy')
    spike_train_saved = np.load(path_to_spike_train)

    # test deconvolution
    np.testing.assert_array_equal(spike_train_saved, spike_train)