Ejemplo n.º 1
0
def test_templates_returns_expected_results(path_to_threshold_config,
                                            path_to_data_folder):
    """Compare the computed templates against the stored reference array.

    Runs preprocess, detect, cluster and templates with the threshold
    configuration and checks ``templates_`` against
    ``output_reference/templates.npy`` under ``path_to_data_folder``.
    """
    # seed numpy's global RNG before running the pipeline
    np.random.seed(0)

    yass.set_config(path_to_threshold_config)

    # try/finally guarantees clean_tmp() is called even when a stage raises
    # or the reference comparison fails
    try:
        (standarized_path, standarized_params, channel_index,
         whiten_filter) = preprocess.run()

        (score, spike_index_clear,
         spike_index_all) = detect.run(standarized_path,
                                       standarized_params,
                                       channel_index,
                                       whiten_filter)

        spike_train_clear, tmp_loc, vbParam = cluster.run(score,
                                                          spike_index_clear)

        (templates_, spike_train,
         groups, idx_good_templates) = templates.run(spike_train_clear,
                                                     tmp_loc)

        path_to_templates = path.join(path_to_data_folder,
                                      'output_reference',
                                      'templates.npy')

        ReferenceTesting.assert_array_equal(templates_, path_to_templates)
    finally:
        clean_tmp()
Ejemplo n.º 2
0
def test_can_preprocess_without_filtering(path_to_threshold_config):
    """Smoke test: preprocess runs with ``apply_filter`` set to False."""
    settings = load_yaml(path_to_threshold_config)
    settings['preprocess']['apply_filter'] = False
    yass.set_config(settings)

    # unpacking into three names also checks the number of returned values
    outputs = preprocess.run()
    rec_path, rec_params, whitening = outputs
Ejemplo n.º 3
0
def test_can_preprocess_in_parallel(path_to_threshold_config):
    """Smoke test: preprocess runs with ``resources.processes`` = 'max'."""
    settings = load_yaml(path_to_threshold_config)
    settings['resources']['processes'] = 'max'
    yass.set_config(settings)

    # unpacking into three names also checks the number of returned values
    outputs = preprocess.run()
    rec_path, rec_params, whitening = outputs
Ejemplo n.º 4
0
def test_can_detect_with_nnet(path_to_nnet_config):
    """Smoke test: preprocess and detect run with the nnet configuration."""
    yass.set_config(path_to_nnet_config)

    # try/finally guarantees clean_tmp() is called even when a stage raises
    try:
        standarized_path, standarized_params, whiten_filter = preprocess.run()

        scores, clear, collision = detect.run(standarized_path,
                                              standarized_params,
                                              whiten_filter)
    finally:
        clean_tmp()
Ejemplo n.º 5
0
def test_templates_loads_from_disk_if_files_exist(caplog,
                                                  path_to_threshold_config):
    """Check the ``executed`` flag of ``templates.run`` across two calls.

    The first call with ``save_results=True`` must set
    ``templates.run.executed``; an identical second call must leave it
    falsy.
    """
    yass.set_config(path_to_threshold_config)

    preprocessed = preprocess.run()
    rec_path, rec_params, neighbors, whitening = preprocessed

    detected = detect.run(rec_path, rec_params, neighbors, whitening)
    score, clear_index, all_index = detected

    clustered = cluster.run(score, clear_index)
    clear_train, tmp_loc, vb_param = clustered

    # first call: results are computed and saved
    templates.run(clear_train, tmp_loc, save_results=True)
    assert templates.run.executed

    # second call: expected to skip the computation
    templates.run(clear_train, tmp_loc, save_results=True)
    assert not templates.run.executed
Ejemplo n.º 6
0
def test_nnet_detector_returns_expected_results(path_to_nnet_config,
                                                path_to_output_reference):
    """Compare the nnet detector outputs against stored reference arrays.

    ``scores`` is compared to 4 decimals; ``clear`` and ``collision`` must
    match their references exactly.
    """
    # seed numpy's global RNG before running the pipeline
    np.random.seed(0)

    yass.set_config(path_to_nnet_config)

    # try/finally guarantees clean_tmp() is called even when a stage raises
    # or one of the reference comparisons fails
    try:
        (standarized_path, standarized_params, channel_index,
         whiten_filter) = preprocess.run()

        scores, clear, collision = detect.run(standarized_path,
                                              standarized_params,
                                              channel_index, whiten_filter)

        path_to_scores = path.join(path_to_output_reference,
                                   'detect_nnet_scores.npy')
        path_to_clear = path.join(path_to_output_reference,
                                  'detect_nnet_clear.npy')
        path_to_collision = path.join(path_to_output_reference,
                                      'detect_nnet_collision.npy')

        ReferenceTesting.assert_array_almost_equal(scores,
                                                   path_to_scores,
                                                   decimal=4)
        ReferenceTesting.assert_array_equal(clear, path_to_clear)
        ReferenceTesting.assert_array_equal(collision, path_to_collision)
    finally:
        clean_tmp()
Ejemplo n.º 7
0
def test_templates_returns_expected_results(path_to_config,
                                            path_to_output_reference,
                                            make_tmp_folder):
    """Compare the computed templates with ``templates.npy`` in the reference.

    Chains preprocess, detect, cluster and templates (saving results) and
    asserts the templates equal the stored reference array.
    """
    yass.set_config(path_to_config, make_tmp_folder)

    rec_path, rec_params, whitening = preprocess.run()

    clear_index, all_index = detect.run(rec_path, rec_params, whitening)

    clear_train, tmp_loc, vb_param = cluster.run(clear_index)

    computed, full_train, groups, good_idx = templates.run(
        clear_train, tmp_loc, save_results=True)

    reference = path.join(path_to_output_reference, 'templates.npy')
    ReferenceTesting.assert_array_equal(computed, reference)
Ejemplo n.º 8
0
def test_decovnolution(path_to_config):
    """Smoke test: preprocess, process and deconvolute run end to end."""
    # NOTE(review): ``path_to_config`` is not used; the configuration path is
    # hard-coded below - confirm whether the fixture should be passed instead
    yass.set_config('tests/config_nnet.yaml')

    # try/finally guarantees clean_tmp() is called even when a stage raises
    try:
        (clear_scores, spike_index_clear,
         spike_index_collision) = preprocess.run()

        # ``templates_`` avoids reusing the bare name ``templates`` locally
        (spike_train_clear, templates_,
         spike_index_collision) = process.run(clear_scores,
                                              spike_index_clear,
                                              spike_index_collision)

        deconvolute.run(spike_train_clear, templates_, spike_index_collision)
    finally:
        clean_tmp()
Ejemplo n.º 9
0
def test_can_preprocess_in_parallel(path_to_config, make_tmp_folder):
    """Smoke test: preprocess runs with ``resources.processes`` = 'max'."""
    settings = load_yaml(path_to_config)
    settings['resources']['processes'] = 'max'
    yass.set_config(settings, make_tmp_folder)

    # preprocess writes under <tmp>/preprocess and must return two values
    target_dir = os.path.join(make_tmp_folder, 'preprocess')
    rec_path, rec_params = preprocess.run(target_dir)
Ejemplo n.º 10
0
def test_can_preprocess_without_filtering(path_to_config, make_tmp_folder):
    """Smoke test: preprocess runs when the filter is disabled.

    The whole ``preprocess`` section is replaced by
    ``{'apply_filter': False}``.
    """
    settings = load_yaml(path_to_config)
    settings['preprocess'] = dict(apply_filter=False)
    yass.set_config(settings, make_tmp_folder)

    # preprocess writes under <tmp>/preprocess and must return two values
    target_dir = os.path.join(make_tmp_folder, 'preprocess')
    rec_path, rec_params = preprocess.run(target_dir)
Ejemplo n.º 11
0
def test_can_detect_with_threshold(path_to_threshold_config):
    """Smoke test: preprocess and detect run with the threshold config."""
    yass.set_config(path_to_threshold_config)

    # try/finally guarantees clean_tmp() is called even when a stage raises
    try:
        (standarized_path, standarized_params, channel_index,
         whiten_filter) = preprocess.run()

        scores, clear, collision = detect.run(standarized_path,
                                              standarized_params,
                                              channel_index, whiten_filter)
    finally:
        clean_tmp()
Ejemplo n.º 12
0
def test_can_preprocess_without_filtering(path_to_config,
                                          make_tmp_folder):
    """Smoke test: preprocess runs when the filter is disabled.

    The whole ``preprocess`` section is replaced by
    ``{'apply_filter': False}``.
    """
    settings = load_yaml(path_to_config)
    settings['preprocess'] = dict(apply_filter=False)
    yass.set_config(settings, make_tmp_folder)

    # unpacking into three names also checks the number of returned values
    outputs = preprocess.run()
    rec_path, rec_params, whitening = outputs
Ejemplo n.º 13
0
def test_process(path_to_config):
    """Smoke test: the outputs of preprocess can be fed to ``process.run``."""
    yass.set_config(path_to_config)

    preprocessed = preprocess.run()
    clear_scores, clear_index, collision_index = preprocessed

    processed = process.run(clear_scores, clear_index, collision_index)
    clear_train, found_templates, collision_index = processed
Ejemplo n.º 14
0
def test_can_detect_with_nnet(path_to_config, make_tmp_folder):
    """Smoke test: detect runs on the output of preprocess.

    Each stage receives its own sub-folder of ``make_tmp_folder``.
    """
    yass.set_config(path_to_config, make_tmp_folder)

    preprocess_dir = os.path.join(make_tmp_folder, 'preprocess')
    rec_path, rec_params = preprocess.run(preprocess_dir)

    detect_dir = os.path.join(make_tmp_folder, 'detect')
    detect.run(rec_path, rec_params, detect_dir)
Ejemplo n.º 15
0
def test_preprocess_saves_result_in_the_right_folder(path_to_config,
                                                     make_tmp_folder):
    """Check where preprocess writes the standarized recordings.

    The returned path must equal ``<tmp>/preprocess/standarized.bin`` and
    that file must exist.
    """
    yass.set_config(path_to_config, make_tmp_folder)

    outputs = preprocess.run()
    reported_path, rec_params, _ = outputs

    target = Path(make_tmp_folder, 'preprocess', 'standarized.bin')

    assert str(target) == reported_path
    assert target.is_file()
Ejemplo n.º 16
0
def main():
    """Run every pipeline stage, logging the elapsed seconds between them.

    Stages run in order: preprocess, detect, cluster, templates and
    deconvolute. All of them receive ``output_directory='profiling'``.
    """
    settings.run()
    start = datetime.now()
    logger = logging.getLogger(__name__)

    def elapsed():
        # seconds since ``start``
        return (datetime.now() - start).total_seconds()

    # keyword options shared by several stages
    overwrite = dict(output_directory='profiling',
                     if_file_exists='overwrite')
    overwrite_and_save = dict(overwrite, save_results=True)

    logger.info('Preprocessing started at second: %.2f', elapsed())

    # preprocessing
    preprocessed = preprocess.run(**overwrite)
    (standarized_path, standarized_params, channel_index,
     whiten_filter) = preprocessed

    logger.info('Preprocessing finished and detection started at second: %.2f',
                elapsed())

    # detection
    detected = detect.run(standarized_path, standarized_params,
                          channel_index, whiten_filter,
                          **overwrite_and_save)
    score, spike_index_clear, spike_index_all = detected

    logger.info('Detection finished and clustering started at second: %.2f',
                elapsed())

    # clustering
    spike_train_clear = cluster.run(score, spike_index_clear,
                                    **overwrite_and_save)

    logger.info('Clustering finished and templates started at second: %.2f',
                elapsed())

    # templates
    the_templates = templates.run(spike_train_clear, **overwrite_and_save)

    logger.info('templates finished and deconvolution started at second: %.2f',
                elapsed())

    # deconvolution
    deconvolute.run(spike_index_all, the_templates,
                    output_directory='profiling')

    logger.info('Deconvolution finished at second: %.2f', elapsed())
Ejemplo n.º 17
0
def test_cluster_nnet(path_to_config, make_tmp_folder):
    """Smoke test: cluster runs on the detect output, with no scores."""
    yass.set_config(path_to_config, make_tmp_folder)

    rec_path, rec_params, whitening = preprocess.run()

    all_index = detect.run(rec_path, rec_params, whitening)

    # the first positional argument is passed as None
    cluster.run(None, all_index)
Ejemplo n.º 18
0
def test_cluster(path_to_threshold_config):
    """Smoke test: preprocess, detect and cluster run with threshold config."""
    yass.set_config(path_to_threshold_config)

    # try/finally guarantees clean_tmp() is called even when a stage raises
    try:
        (standarized_path, standarized_params,
         whiten_filter) = preprocess.run()

        (score, spike_index_clear,
         spike_index_all) = detect.run(standarized_path, standarized_params,
                                       whiten_filter)

        cluster.run(score, spike_index_clear)
    finally:
        clean_tmp()
Ejemplo n.º 19
0
def test_threshold_detector_returns_expected_results(path_to_config_threshold,
                                                     path_to_output_reference,
                                                     make_tmp_folder):
    """Compare the threshold detector output with the stored reference.

    The result of ``detect.run`` must equal ``detect_threshold_clear.npy``
    in the reference folder.
    """
    util.seed(0)

    yass.set_config(path_to_config_threshold, make_tmp_folder)

    preprocessed = preprocess.run(output_directory=make_tmp_folder)
    rec_path, rec_params, whitening = preprocessed

    detected_clear = detect.run(rec_path, rec_params, whitening)

    reference = path.join(path_to_output_reference,
                          'detect_threshold_clear.npy')
    ReferenceTesting.assert_array_equal(detected_clear, reference)
Ejemplo n.º 20
0
def test_decovnolution(path_to_threshold_config):
    """Smoke test: the full pipeline up to deconvolution runs end to end."""
    # NOTE(review): ``path_to_threshold_config`` is not used; the
    # configuration path is hard-coded below - confirm whether the fixture
    # should be passed instead
    yass.set_config('tests/config_nnet.yaml')

    # try/finally guarantees clean_tmp() is called even when a stage raises
    try:
        (standarized_path, standarized_params,
         whiten_filter) = preprocess.run()

        (score, spike_index_clear,
         spike_index_all) = detect.run(standarized_path, standarized_params,
                                       whiten_filter)

        spike_train_clear, tmp_loc, vbParam = cluster.run(score,
                                                          spike_index_clear)

        (templates_, spike_train, groups,
         idx_good_templates) = templates.run(spike_train_clear, tmp_loc)

        deconvolute.run(spike_index_all, templates_)
    finally:
        clean_tmp()
Ejemplo n.º 21
0
def test_deconvolution(patch_triage_network, path_to_config, make_tmp_folder):
    """Smoke test: preprocess, detect, cluster and deconvolve run in a chain.

    Each stage receives its own sub-folder of ``make_tmp_folder``.
    """
    def stage_dir(stage):
        # output folder for one pipeline stage
        return os.path.join(make_tmp_folder, stage)

    yass.set_config(path_to_config, make_tmp_folder)

    rec_path, rec_params = preprocess.run(stage_dir('preprocess'))

    index_path = detect.run(rec_path, rec_params, stage_dir('detect'))

    clustered = cluster.run(index_path, rec_path, rec_params['dtype'],
                            stage_dir('cluster'), True, True)
    templates_file, train_file = clustered

    deconvolved = deconvolve.run(templates_file, stage_dir('deconv'),
                                 rec_path, rec_params['dtype'])
    # four outputs are expected from deconvolve.run
    (templates_file, train_file,
     templates_up_file, train_up_file) = deconvolved
Ejemplo n.º 22
0
def test_templates(path_to_threshold_config):
    """Smoke test: the pipeline runs up to and including templates."""
    yass.set_config(path_to_threshold_config)

    # try/finally guarantees clean_tmp() is called even when a stage raises
    try:
        (standarized_path, standarized_params, channel_index,
         whiten_filter) = preprocess.run()

        (score, spike_index_clear,
         spike_index_all) = detect.run(standarized_path,
                                       standarized_params,
                                       channel_index,
                                       whiten_filter)

        spike_train_clear, tmp_loc, vbParam = cluster.run(
            score, spike_index_clear)

        templates.run(spike_train_clear, tmp_loc)
    finally:
        clean_tmp()
Ejemplo n.º 23
0
def test_cluster_nnet(path_to_config, make_tmp_folder):
    """Smoke test: cluster runs on the outputs of preprocess and detect.

    Each stage receives its own sub-folder of ``make_tmp_folder``.
    """
    def stage_dir(stage):
        # output folder for one pipeline stage
        return os.path.join(make_tmp_folder, stage)

    yass.set_config(path_to_config, make_tmp_folder)

    rec_path, rec_params = preprocess.run(stage_dir('preprocess'))

    index_path = detect.run(rec_path, rec_params, stage_dir('detect'))

    cluster.run(index_path, rec_path, rec_params['dtype'],
                stage_dir('cluster'), True, True)
Ejemplo n.º 24
0
def test_deconvolution(patch_triage_network, path_to_config, make_tmp_folder):
    """Smoke test: deconvolve runs on the arrays saved by the cluster step.

    After clustering, ``spike_train_cluster.npy`` and
    ``templates_cluster.npy`` are loaded from the configured output
    directory and passed to ``deconvolve.run``.
    """
    yass.set_config(path_to_config, make_tmp_folder)

    rec_path, rec_params, whitening = preprocess.run()

    all_index = detect.run(rec_path, rec_params, whitening)

    cluster.run(None, all_index)

    output_dir = read_config().path_to_output_directory

    train_file = path.join(output_dir, 'spike_train_cluster.npy')
    cluster_train = np.load(train_file)
    templates_file = path.join(output_dir, 'templates_cluster.npy')
    cluster_templates = np.load(templates_file)

    # two outputs are expected from deconvolve.run
    deconvolved = deconvolve.run(cluster_train, cluster_templates)
    final_train, final_templates = deconvolved
Ejemplo n.º 25
0
def test_preprocess_returns_expected_results(path_to_threshold_config,
                                             path_to_output_reference):
    """Compare the preprocess outputs against stored reference arrays.

    The standarized recordings (read back from disk) and the whitening
    filter are each compared to their reference file.
    """
    yass.set_config(path_to_threshold_config)

    # try/finally guarantees clean_tmp() is called even when preprocess
    # raises or one of the reference comparisons fails
    try:
        standarized_path, standarized_params, whiten_filter = preprocess.run()

        # load standarized data
        standarized = np.fromfile(standarized_path,
                                  dtype=standarized_params['dtype'])

        path_to_standarized = path.join(path_to_output_reference,
                                        'preprocess_standarized.npy')
        path_to_whiten_filter = path.join(path_to_output_reference,
                                          'preprocess_whiten_filter.npy')

        ReferenceTesting.assert_array_almost_equal(standarized,
                                                   path_to_standarized)
        ReferenceTesting.assert_array_almost_equal(whiten_filter,
                                                   path_to_whiten_filter)
    finally:
        clean_tmp()
Ejemplo n.º 26
0
def test_cluster_loads_from_disk_if_all_files_exist(caplog,
                                                    path_to_threshold_config):
    """Check the ``executed`` flag of ``cluster.run`` across two calls.

    The first call with ``save_results=True`` must set
    ``cluster.run.executed``; an identical second call must leave it falsy.
    """
    yass.set_config(path_to_threshold_config)

    rec_path, rec_params, whitening = preprocess.run()

    detected = detect.run(rec_path, rec_params, whitening)
    score, clear_index, all_index = detected

    # first call: results are computed and saved
    cluster.run(score, clear_index, save_results=True)
    assert cluster.run.executed

    # second call: expected to skip the computation
    cluster.run(score, clear_index, save_results=True)
    assert not cluster.run.executed
Ejemplo n.º 27
0
def main():
    """Run every pipeline stage, logging the elapsed seconds between them.

    Stages run in order: preprocess, detect, cluster, templates and
    deconvolute, each with its default options.
    """
    settings.run()
    start = datetime.now()
    logger = logging.getLogger(__name__)

    def elapsed():
        # seconds since ``start``
        return (datetime.now() - start).total_seconds()

    logger.info('Preprocessing started at second: %.2f', elapsed())

    # preprocessing
    preprocessed = preprocess.run()
    (standarized_path, standarized_params, channel_index,
     whiten_filter) = preprocessed

    logger.info('Preprocessing finished and detection started at second: %.2f',
                elapsed())

    # detection
    detected = detect.run(standarized_path, standarized_params,
                          channel_index, whiten_filter)
    score, spike_index_clear, spike_index_all = detected

    logger.info('Detection finished and clustering started at second: %.2f',
                elapsed())

    # clustering
    spike_train_clear = cluster.run(score, spike_index_clear)

    logger.info('Clustering finished and templates started at second: %.2f',
                elapsed())

    # templates
    the_templates = templates.run(spike_train_clear)

    logger.info('templates finished and deconvolution started at second: %.2f',
                elapsed())

    # deconvolution
    deconvolute.run(spike_index_all, the_templates)

    logger.info('Deconvolution finished at second: %.2f', elapsed())
Ejemplo n.º 28
0
def test_cluster_returns_expected_results(path_to_threshold_config,
                                          path_to_data_folder):
    """Compare the cluster outputs against stored reference arrays.

    ``spike_train`` and ``tmp_loc`` must equal the arrays stored in
    ``output_reference`` under ``path_to_data_folder``.
    """
    # seed numpy's global RNG before running the pipeline
    np.random.seed(0)

    yass.set_config(path_to_threshold_config)

    # try/finally guarantees clean_tmp() is called even when a stage raises
    # or one of the reference comparisons fails
    try:
        (standarized_path, standarized_params,
         whiten_filter) = preprocess.run()

        (score, spike_index_clear,
         spike_index_all) = detect.run(standarized_path, standarized_params,
                                       whiten_filter)

        spike_train, tmp_loc, vbParam = cluster.run(score, spike_index_clear)

        path_to_spike_train = path.join(path_to_data_folder,
                                        'output_reference',
                                        'cluster_spike_train.npy')
        path_to_tmp_loc = path.join(path_to_data_folder, 'output_reference',
                                    'cluster_tmp_loc.npy')

        ReferenceTesting.assert_array_equal(spike_train, path_to_spike_train)
        ReferenceTesting.assert_array_equal(tmp_loc, path_to_tmp_loc)
    finally:
        clean_tmp()
Ejemplo n.º 29
0
def run(config, logger_level='INFO', clean=False, output_dir='tmp/'):
    """Train the neural network detector and denoiser

    Preprocesses the recordings, obtains a spike train (either from the
    configuration or by running detection, clustering and post-processing
    on the first part of the recording), builds training data from it with
    ``augment.run`` and trains the ``Detect`` and ``Denoise`` networks.

    Parameters
    ----------
    config: str or mapping (such as dictionary)
        Path to YASS configuration file or mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory (if relative, it makes it relative to
        CONFIG.data.root_folder) to store the output data, defaults to tmp/.
        If absolute, it leaves it as it is.

    Notes
    -----
    Several configuration values are overridden before running:
    ``cluster.min_fr``, ``clean_up.mad.min_var_gap``,
    ``clean_up.mad.max_violations``, ``neuralnetwork.apply_nn`` and
    ``detect.threshold``.

    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standardized.bin`` - Standardized recordings (from preprocess)
    * ``standardized.yaml`` - Standardized recordings metadata (from
      preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)

    Everything after preprocessing uses the ``nn_train`` sub-folder of the
    output directory; the training calls receive the paths
    ``nn_train/detect.pt`` and ``nn_train/denoise.pt``.

    Both networks are moved to the GPU with ``.cuda()``.

    Returns
    -------
    None
    """

    # load yass configuration parameters and apply the overrides that are
    # hard-coded for training
    CONFIG = Config.from_yaml(config)
    CONFIG._data['cluster']['min_fr'] = 1
    CONFIG._data['clean_up']['mad']['min_var_gap'] = 1.5
    CONFIG._data['clean_up']['mad']['max_violations'] = 5
    CONFIG._data['neuralnetwork']['apply_nn'] = False
    CONFIG._data['detect']['threshold'] = 4

    set_config(CONFIG._data, output_dir)
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = os.path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)

    # ---------------- SET ENVIRONMENT VARIABLES ----------------
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["GIO_EXTRA_MODULES"] = "/usr/lib/x86_64-linux-gnu/gio/modules/"

    # TODO: if input spike train is None, run yass with threshold detector

    # ---------------- PREPROCESS ----------------
    (standardized_path, standardized_dtype) = preprocess.run(
        os.path.join(TMP_FOLDER, 'preprocess'))

    # from here on every output goes to the nn_train sub-folder
    TMP_FOLDER = os.path.join(TMP_FOLDER, 'nn_train')
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    if CONFIG.neuralnetwork.training.input_spike_train_filname is None:

        # run on 10 minutes of data
        # NOTE(review): the value is derived from the sampling rate but is
        # passed as ``run_chunk_sec`` - confirm the expected unit
        rec_len = np.min(
            (CONFIG.rec_len, CONFIG.recordings.sampling_rate * 10 * 60))

        # detect
        logger.info('DETECTION')
        spike_index_path = detect.run(standardized_path,
                                      standardized_dtype,
                                      os.path.join(TMP_FOLDER, 'detect'),
                                      run_chunk_sec=[0, rec_len])

        logger.info('CLUSTERING')

        # cluster
        # NOTE(review): unused locals ``raw_data = True`` and
        # ``full_run = False`` were removed; the call has always passed the
        # literals below - confirm ``full_run=True`` is the intended value
        fname_templates, fname_spike_train = cluster.run(
            os.path.join(TMP_FOLDER, 'cluster'),
            standardized_path,
            standardized_dtype,
            fname_spike_index=spike_index_path,
            raw_data=True,
            full_run=True)

        methods = [
            'off_center', 'low_ptp', 'high_mad', 'duplicate', 'duplicate_l2'
        ]
        fname_templates, fname_spike_train = postprocess.run(
            methods, os.path.join(TMP_FOLDER,
                                  'cluster_post_process'), standardized_path,
            standardized_dtype, fname_templates, fname_spike_train)

    else:
        # if there is an input spike train, use it
        fname_spike_train = (
            CONFIG.neuralnetwork.training.input_spike_train_filname)

    # Get training data maker
    DetectTD, DenoTD = augment.run(standardized_path, standardized_dtype,
                                   fname_spike_train,
                                   os.path.join(TMP_FOLDER, 'augment'))

    # Train Detector
    detector = Detect(CONFIG.neuralnetwork.detect.n_filters,
                      CONFIG.spike_size_nn, CONFIG.channel_index).cuda()
    detector.train(os.path.join(TMP_FOLDER, 'detect.pt'), DetectTD)

    # Train Denoiser
    denoiser = Denoise(CONFIG.neuralnetwork.denoise.n_filters,
                       CONFIG.neuralnetwork.denoise.filter_sizes,
                       CONFIG.spike_size_nn).cuda()
    denoiser.train(os.path.join(TMP_FOLDER, 'denoise.pt'), DenoTD)
Ejemplo n.º 30
0
import numpy as np
import logging

import yass
from yass import preprocess
from yass import detect
from yass import cluster

# Example script: run preprocess, detect and cluster in sequence.

# seed numpy's global RNG before running the pipeline
np.random.seed(0)

# configure logging module to get useful information
logging.basicConfig(level=logging.INFO)

# set yass configuration parameters
# (second argument presumably is the output directory - TODO confirm)
yass.set_config('config_sample.yaml', 'preprocess-example/')

# preprocess: the three returned values feed the detection step
standarized_path, standarized_params, whiten_filter = preprocess.run()

# detect: returns two values, only the first is used below
(spike_index_clear,
 spike_index_all) = detect.run(standarized_path,
                               standarized_params,
                               whiten_filter)

# cluster: runs on the first value returned by detect
spike_train_clear, tmp_loc, vbParam = cluster.run(spike_index_clear)