Example #1
def train(directory, config_train, logger_level):
    """
    Train neural networks. DIRECTORY must be a folder containing the
    output of `yass sort`; CONFIG_TRAIN must be the location of a file
    with the training parameters.
    """
    logging.basicConfig(level=getattr(logging, logger_level))
    logger = logging.getLogger(__name__)

    path_to_spike_train = path.join(directory, 'spike_train.npy')
    spike_train = np.load(path_to_spike_train)

    logger.info(
        'Loaded spike train with: {:,} spikes and {:,} different IDs'.format(
            len(spike_train), len(np.unique(spike_train[:, 1]))))

    path_to_config = path.join(directory, 'config.yaml')
    CONFIG = Config.from_yaml(path_to_config)

    CONFIG_TRAIN = load_yaml(config_train)

    train_neural_networks(CONFIG,
                          CONFIG_TRAIN,
                          spike_train,
                          data_folder=directory)
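A minimal invocation sketch, not taken from the YASS docs; the directory and file names below are hypothetical placeholders:

# Hypothetical paths: 'sort_output/' must contain the spike_train.npy and
# config.yaml written by `yass sort`, and 'train_params.yaml' holds the
# training parameters mentioned in the docstring.
train('sort_output/', 'train_params.yaml', 'INFO')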
Example #2
def set_config(config, output_directory):
    """Set configuration settings

    Parameters
    ----------
    config: str or mapping (such as a dictionary)
        Path to a yaml file or a mapping object

    output_directory: str or pathlib.Path
        Output directory for the project. This makes
        Config.output_directory return
        config.data.root_folder / output_directory, which is a common path
        used throughout the pipeline. If the path is absolute, it is not modified

    Notes
    -----
    This function lets the user configure global settings used
    throughout the execution of the pipeline. The variables set here
    determine the behavior of some steps. For that reason, once the
    values are set they cannot be changed, to avoid mutating the global
    configuration at runtime
    """
    # this variable is accessible in the preprocess and process blocks to
    # avoid having to pass the configuration over and over; to prevent
    # global-state issues, the config cannot be edited after it is loaded
    global CONFIG

    if isinstance(config, str):
        CONFIG = Config.from_yaml(config, output_directory)
    else:
        CONFIG = Config(config, output_directory)

    logger.debug('CONFIG set to: %s', CONFIG._data)

    return CONFIG
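A short sketch of the two dispatch branches, assuming a hypothetical yaml path and an illustrative mapping (the keys shown are assumptions, not the full YASS schema):

# str input goes through Config.from_yaml
cfg = set_config('config.yaml', 'tmp/')

# mapping input goes through Config directly; keys are illustrative only
cfg = set_config({'data': {'root_folder': '/data/project'}}, 'tmp/')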
Example #3
def set_config(config):
    """Set configuration settings

    Parameters
    ----------
    config: str or mapping (such as a dictionary)
        Path to a yaml file or a mapping object

    Notes
    -----
    This function lets the user configure global settings used
    throughout the execution of the pipeline. The variables set here
    determine the behavior of some steps. For that reason, once the
    values are set they cannot be changed, to avoid mutating the global
    configuration at runtime
    """
    # this variable is accessible in the preprocess and process blocks to
    # avoid having to pass the configuration over and over; to prevent
    # global-state issues, the config cannot be edited after it is loaded
    global CONFIG

    if isinstance(config, str):
        CONFIG = Config.from_yaml(config)
    else:
        CONFIG = Config(config)

    logger.debug('CONFIG set to: %s', CONFIG._data)
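Unlike Example #2, this variant stores the result only in the module-level global and returns nothing, so later code reads it back through an accessor. A sketch assuming the `read_config` helper that appears in Example #5:

set_config('config.yaml')  # hypothetical path; populates the global CONFIG
CONFIG = read_config()     # later pipeline steps retrieve the same object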
Example #4
def set_config(path_to_config):
    """Set configuration settings

    This function lets the user configure global settings used
    throughout the execution of the pipeline. The variables set here
    determine the behavior of some steps. For that reason, once the
    values are set they cannot be changed, to avoid mutating the global
    configuration at runtime
    """
    # this variable is accessible in the preprocess and process blocks to
    # avoid having to pass the configuration over and over; to prevent
    # global-state issues, the config cannot be edited after it is loaded
    global CONFIG
    CONFIG = Config.from_yaml(path_to_config)
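The comment promises that the configuration cannot be edited after loading, but nothing in this snippet enforces that. One way to back such a promise, sketched here as an assumption rather than YASS's actual mechanism, is a read-only mapping view:

import types

# Minimal sketch: MappingProxyType turns accidental writes into a TypeError
# instead of silently mutating global state. Note the view is shallow;
# nested dictionaries remain mutable.
frozen = types.MappingProxyType({'detect': {'threshold': 4}})
try:
    frozen['detect'] = {}
except TypeError as e:
    print(e)  # 'mappingproxy' object does not support item assignment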
Example #5
def run(config, logger_level='INFO', clean=False, output_dir='tmp/'):
    """Run YASS built-in pipeline

    Parameters
    ----------
    config: str or mapping (such as a dictionary)
        Path to a YASS configuration file or a mapping object

    logger_level: str
        Logger level

    clean: bool, optional
        Delete CONFIG.data.root_folder/output_dir/ before running

    output_dir: str, optional
        Output directory for the data; defaults to tmp/. Relative paths
        are resolved against CONFIG.data.root_folder; absolute paths are
        left unchanged.

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings (from preprocess)
    * ``filtered.yaml`` - Filtered recordings metadata (from preprocess)
    * ``standardized.bin`` - Standardized recordings (from preprocess)
    * ``standardized.yaml`` - Standardized recordings metadata (from preprocess)
    * ``whitening.npy`` - Whitening filter (from preprocess)


    Returns
    -------
    numpy.ndarray
        Spike train
    """

    # load yass configuration parameters
    CONFIG = Config.from_yaml(config)
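    # adjust a few parameters on the private _data mapping before passing it
    # to set_config(); per set_config's docstring, the values cannot be
    # changed afterwards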
    CONFIG._data['cluster']['min_fr'] = 1
    CONFIG._data['clean_up']['mad']['min_var_gap'] = 1.5
    CONFIG._data['clean_up']['mad']['max_violations'] = 5
    CONFIG._data['neuralnetwork']['apply_nn'] = False
    CONFIG._data['detect']['threshold'] = 4

    set_config(CONFIG._data, output_dir)
    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # remove tmp folder if needed
    if os.path.exists(TMP_FOLDER) and clean:
        shutil.rmtree(TMP_FOLDER)

    # create TMP_FOLDER if needed
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    # load logging config file
    logging_config = load_logging_config_file()
    logging_config['handlers']['file']['filename'] = os.path.join(
        TMP_FOLDER, 'yass.log')
    logging_config['root']['level'] = logger_level

    # configure logging
    logging.config.dictConfig(logging_config)

    # instantiate logger
    logger = logging.getLogger(__name__)

    # print yass version
    logger.info('YASS version: %s', yass.__version__)
    # **********************************************
    # ******** SET ENVIRONMENT VARIABLES ***********
    # **********************************************
    os.environ["OPENBLAS_NUM_THREADS"] = "1"
    os.environ["MKL_NUM_THREADS"] = "1"
    os.environ["GIO_EXTRA_MODULES"] = "/usr/lib/x86_64-linux-gnu/gio/modules/"

    # TODO: if input spike train is None, run yass with threshold detector
    #if fname_spike_train is None:
    #    logger.info('Not available yet. You must input spike train')
    #    return
    # **********************************************
    # ************** PREPROCESS ********************
    # **********************************************
    # preprocess
    start = time.time()
    (standardized_path, standardized_dtype) = preprocess.run(
        os.path.join(TMP_FOLDER, 'preprocess'))

    TMP_FOLDER = os.path.join(TMP_FOLDER, 'nn_train')
    if not os.path.exists(TMP_FOLDER):
        os.makedirs(TMP_FOLDER)

    if CONFIG.neuralnetwork.training.input_spike_train_filname is None:

        # run on 10 minutes of data
        rec_len = np.min(
            (CONFIG.rec_len, CONFIG.recordings.sampling_rate * 10 * 60))
        # detect
        logger.info('DETECTION')
        spike_index_path = detect.run(standardized_path,
                                      standardized_dtype,
                                      os.path.join(TMP_FOLDER, 'detect'),
                                      run_chunk_sec=[0, rec_len])

        logger.info('CLUSTERING')

        # cluster
        fname_templates, fname_spike_train = cluster.run(
            os.path.join(TMP_FOLDER, 'cluster'),
            standardized_path,
            standardized_dtype,
            fname_spike_index=spike_index_path,
            raw_data=True,
            full_run=True)

        methods = [
            'off_center', 'low_ptp', 'high_mad', 'duplicate', 'duplicate_l2'
        ]
        fname_templates, fname_spike_train = postprocess.run(
            methods, os.path.join(TMP_FOLDER,
                                  'cluster_post_process'), standardized_path,
            standardized_dtype, fname_templates, fname_spike_train)

    else:
        # if there is an input spike train, use it
        fname_spike_train = CONFIG.neuralnetwork.training.input_spike_train_filname

    # Get training data maker
    DetectTD, DenoTD = augment.run(standardized_path, standardized_dtype,
                                   fname_spike_train,
                                   os.path.join(TMP_FOLDER, 'augment'))

    # Train Detector
    detector = Detect(CONFIG.neuralnetwork.detect.n_filters,
                      CONFIG.spike_size_nn, CONFIG.channel_index).cuda()
    detector.train(os.path.join(TMP_FOLDER, 'detect.pt'), DetectTD)

    # Train Denoiser
    denoiser = Denoise(CONFIG.neuralnetwork.denoise.n_filters,
                       CONFIG.neuralnetwork.denoise.filter_sizes,
                       CONFIG.spike_size_nn).cuda()
    denoiser.train(os.path.join(TMP_FOLDER, 'denoise.pt'), DenoTD)
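A minimal end-to-end sketch; the config path is a hypothetical placeholder, and a CUDA-capable GPU is required since the Detect and Denoise models are moved to the device with .cuda():

# Hypothetical config path; clean=True wipes any previous output under
# CONFIG.data.root_folder/tmp/ before the pipeline runs.
run('config.yaml', logger_level='INFO', clean=True, output_dir='tmp/')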