Example #1
def run_threshold(standardized_path,
                  standardized_params,
                  whiten_filter,
                  output_directory,
                  if_file_exists,
                  save_results,
                  temporal_features=3,
                  std_factor=4):
    """Run threshold detector and dimensionality reduction using PCA
    Returns
    -------
    scores
      Scores for all spikes
    spike_index_clear
      Spike indexes for clear spikes
    spike_index_all
      Spike indexes for all spikes
    """
    logger = logging.getLogger(__name__)

    logger.debug('Running threshold detector...')

    CONFIG = read_config()

    folder = Path(CONFIG.path_to_output_directory, 'detect')
    folder.mkdir(exist_ok=True)

    TMP_FOLDER = str(folder)

    # files that will be saved if enabled by the if_file_exists option
    filename_index_clear = 'spike_index_clear.npy'
    filename_index_clear_pca = 'spike_index_clear_pca.npy'
    filename_scores_clear = 'scores_clear.npy'
    filename_spike_index_all = 'spike_index_all.npy'
    filename_rotation = 'rotation.npy'

    ###################
    # Spike detection #
    ###################

    # run threshold detection, save clear indexes in TMP/filename_index_clear
    clear = threshold(standardized_path,
                      standardized_params['dtype'],
                      standardized_params['n_channels'],
                      standardized_params['data_order'],
                      CONFIG.resources.max_memory,
                      CONFIG.neigh_channels,
                      CONFIG.spike_size,
                      CONFIG.spike_size + CONFIG.templates.max_shift,
                      std_factor,
                      TMP_FOLDER,
                      spike_index_clear_filename=filename_index_clear,
                      if_file_exists=if_file_exists)

    #######
    # PCA #
    #######

    # run PCA, save rotation matrix and pca scores under TMP_FOLDER
    # TODO: remove clear as input for PCA and create an independent function
    pca_scores, clear, _ = pca(standardized_path, standardized_params['dtype'],
                               standardized_params['n_channels'],
                               standardized_params['data_order'], clear,
                               CONFIG.spike_size, temporal_features,
                               CONFIG.neigh_channels, CONFIG.channel_index,
                               CONFIG.resources.max_memory, TMP_FOLDER,
                               'scores_pca.npy', filename_rotation,
                               filename_index_clear_pca, if_file_exists)

    #################
    # Whiten scores #
    #################

    # apply whitening to scores
    scores = whiten.score(pca_scores, clear[:, 1], whiten_filter)

    if save_results:
        # save spike_index_all (same as spike_index_clear for threshold
        # detector)
        path_to_spike_index_all = os.path.join(TMP_FOLDER,
                                               filename_spike_index_all)
        save_numpy_object(clear,
                          path_to_spike_index_all,
                          if_file_exists,
                          name='Spike index all')

    # FIXME: scores are always saved because they are loaded by the clustering
    # step; we need to find a better way to do this, but since the current
    # clustering code is going away soon this is a temporary solution
    # save whitened scores
    path_to_scores = os.path.join(TMP_FOLDER, filename_scores_clear)
    save_numpy_object(scores, path_to_scores, if_file_exists, name='scores')

    return clear  #, np.copy(clear)
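
A minimal usage sketch for run_threshold above. The recording path, parameter dict and whitening filter are hypothetical placeholders, and a YASS config is assumed to have been loaded beforehand so that read_config() inside the function succeeds.

# hypothetical inputs; the standardized recording and whitening filter would
# normally come from the preprocess step (see Examples #5 and #7)
import numpy as np

standardized_params = {'dtype': 'float32',
                       'n_channels': 49,
                       'data_order': 'samples'}
whiten_filter = np.load('tmp/preprocess/whitening.npy')

spike_index_clear = run_threshold('tmp/preprocess/standarized.bin',
                                  standardized_params,
                                  whiten_filter,
                                  'tmp/',
                                  if_file_exists='skip',
                                  save_results=True,
                                  temporal_features=3,
                                  std_factor=4)
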
Example #2
File: detect.py Project: Nomow/yass
def threshold(path_to_data,
              dtype,
              n_channels,
              data_order,
              max_memory,
              neighbors,
              spike_size,
              minimum_half_waveform_size,
              threshold,
              output_path=None,
              spike_index_clear_filename='spike_index_clear.npy',
              if_file_exists='skip'):
    """Threshold spike detection in batches

    Parameters
    ----------
    path_to_data: str
        Path to recordings in binary format

    dtype: str
        Recordings dtype

    n_channels: int
        Number of channels in the recordings

    data_order: str
        Recordings order, one of ('channels', 'samples'). In a dataset with k
        observations per channel and j channels: 'channels' means first k
        contiguous observations come from channel 0, then channel 1, and so
        on. 'samples' means the first j contiguous data points are the first
        observations from all channels, then the second observations from all
        channels, and so on

    max_memory:
        Max memory to use in each batch (e.g. 100MB, 1GB)

    neighbors: numpy.ndarray (n_channels, n_channels)
        Boolean 2D array where an (i, j) entry is True if channel i is
        considered a neighbor of channel j

    spike_size: int
        Spike size

    minimum_half_waveform_size: int
        This is used to remove spikes that are either at the beginning or end
        of the recordings and whose location does not allow drawing a
        waveform of size at least 2 * minimum_half_waveform_size + 1

    threshold: float
        Threshold used on amplitude for detection

    output_path: str, optional
        Directory to save spike indexes, if None, results won't be stored, but
        only returned by the function

    spike_index_clear_filename: str, optional
        Filename for spike_index_clear, it is used as the filename for the
        file (relative to output_path), if None, results won't be saved, only
        returned

    if_file_exists: str, optional
        What to do if there is already a file in the spike_index_clear path.
        One of 'overwrite', 'abort', 'skip'. If 'overwrite', it replaces the
        file if it exists; if 'abort', it raises a ValueError exception if
        the file exists; if 'skip', it skips the computation if the file
        exists and loads it from disk

    Returns
    -------
    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        2D array with indexes for clear spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    spike_index_collision: numpy.ndarray (0, 2)
        Empty array, collision is not implemented in the threshold detector
    """
    # instantiate batch processor
    bp = BatchProcessor(path_to_data,
                        dtype,
                        n_channels,
                        data_order,
                        max_memory,
                        buffer_size=spike_size)

    # run threshold detector
    spikes = bp.multi_channel_apply(_threshold,
                                    mode='memory',
                                    cleanup_function=fix_indexes,
                                    neighbors=neighbors,
                                    spike_size=spike_size,
                                    threshold=threshold)

    # no collision detection implemented, all spikes are marked as clear
    spike_index_clear = np.vstack(spikes)

    # remove spikes whose location won't let us draw a complete waveform
    logger.info('Removing clear indexes outside the allowed range to '
                'draw a complete waveform...')
    spike_index_clear, _ = (remove_incomplete_waveforms(
        spike_index_clear, minimum_half_waveform_size, bp.reader.observations))

    if output_path and spike_index_clear_filename:
        path = os.path.join(output_path, spike_index_clear_filename)
        save_numpy_object(spike_index_clear,
                          path,
                          if_file_exists=if_file_exists,
                          name='Spike index clear')

    return spike_index_clear
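
A hedged usage sketch for the threshold detector above; the file path, neighbors matrix and size parameters are placeholders, not values taken from the project.

# hypothetical call to threshold(); a real neighbors matrix would come from
# the probe geometry (CONFIG.neigh_channels)
import numpy as np

n_channels = 49
neighbors = np.eye(n_channels, dtype=bool)  # placeholder adjacency matrix

spike_index_clear = threshold('tmp/standarized.bin',
                              dtype='float32',
                              n_channels=n_channels,
                              data_order='samples',
                              max_memory='1GB',
                              neighbors=neighbors,
                              spike_size=15,
                              minimum_half_waveform_size=40,
                              threshold=4,
                              output_path='tmp/detect',
                              if_file_exists='overwrite')
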
Example #3
def matrix(path_to_data,
           dtype,
           n_channels,
           data_order,
           channel_index,
           spike_size,
           max_memory,
           output_path,
           output_filename='whitening.npy',
           if_file_exists='skip'):
    """Compute whitening filter using the first batch of the data

    Parameters
    ----------
    path_to_data: str
        Path to recordings in binary format

    dtype: str
        Recordings dtype

    n_channels: int
        Number of channels in the recordings

    data_order: str
        Recordings order, one of ('channels', 'samples'). In a dataset with k
        observations per channel and j channels: 'channels' means first k
        contiguous observations come from channel 0, then channel 1, and so
        on. 'samples' means the first j contiguous data points are the first
        observations from all channels, then the second observations from all
        channels, and so on

    channel_index: np.array
        A matrix of size [n_channels, n_neighbors] with neighboring channel
        information

    spike_size: int
        Spike size

    max_memory: str
        Max memory to use in each batch (e.g. 100MB, 1GB)

    output_path: str
        Where to store the whitening filter

    output_filename: str, optional
        Filename for the output data, defaults to whitening.npy

    if_file_exists: str, optional
        One of 'overwrite', 'abort', 'skip'. If 'overwrite', it replaces the
        whitening filter if it exists; if 'abort', it raises a ValueError
        exception if the file exists; if 'skip', it skips the operation if the
        file exists

    Returns
    -------
    whiten_filter: numpy.ndarray
        Whitening filter computed from the first batch of the data
    """
    logger = logging.getLogger(__name__)

    # compute whitening matrix Q using the first batch
    logger.info('Computing whitening matrix...')

    bp = BatchProcessor(path_to_data, dtype, n_channels, data_order,
                        max_memory)

    batches = bp.multi_channel()
    first_batch = next(batches)
    whiten_filter = _matrix(first_batch, channel_index, spike_size)

    path_to_whitening_matrix = Path(output_path, output_filename)
    save_numpy_object(whiten_filter,
                      path_to_whitening_matrix,
                      if_file_exists='overwrite',
                      name='whitening filter')

    return whiten_filter
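
A short, hypothetical call to matrix() above; in the pipeline the channel_index argument is built with make_channel_index from the probe geometry, here it is only a placeholder.

# hypothetical usage of matrix(); all sizes and paths are placeholders
import numpy as np

n_channels, n_neigh = 49, 7
channel_index = np.tile(np.arange(n_neigh), (n_channels, 1))  # placeholder

whiten_filter = matrix('tmp/standarized.bin',
                       dtype='float32',
                       n_channels=n_channels,
                       data_order='samples',
                       channel_index=channel_index,
                       spike_size=15,
                       max_memory='500MB',
                       output_path='tmp/preprocess',
                       output_filename='whitening.npy',
                       if_file_exists='skip')
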
def pca(path_to_data,
        dtype,
        n_channels,
        data_order,
        spike_index,
        spike_size,
        temporal_features,
        neighbors_matrix,
        channel_index,
        max_memory,
        output_path=None,
        scores_filename='scores.npy',
        rotation_matrix_filename='rotation.npy',
        spike_index_clear_filename='spike_index_clear_pca.npy',
        if_file_exists='skip'):
    """Apply PCA in batches

    Parameters
    ----------
    path_to_data: str
        Path to recordings in binary format

    dtype: str
        Recordings dtype

    n_channels: int
        Number of channels in the recordings

    data_order: str
        Recordings order, one of ('channels', 'samples'). In a dataset with k
        observations per channel and j channels: 'channels' means first k
        contiguous observations come from channel 0, then channel 1, and so
        on. 'samples' means the first j contiguous data points are the first
        observations from all channels, then the second observations from all
        channels, and so on

    spike_index: numpy.ndarray
        A 2D numpy array, first column is spike time, second column is main
        channel (the channel where spike has the biggest amplitude)

    spike_size: int
        Spike size

    temporal_features: int
        Number of output features

    neighbors_matrix: numpy.ndarray (n_channels, n_channels)
        Boolean 2D array where an (i, j) entry is True if channel i is
        considered a neighbor of channel j

    channel_index: np.array (n_channels, n_neigh)
        Each row indexes its neighboring channels. For example,
        channel_index[c] contains the indexes of the channels neighboring
        channel c (including c itself). A value equal to n_channels is just a
        placeholder for channels that have fewer than n_neigh neighboring
        channels

    max_memory:
        Max memory to use in each batch (e.g. 100MB, 1GB)

    output_path: str, optional
        Directory to store the scores and rotation matrix. If None, previous
        results on disk are ignored, everything is recomputed, and results
        are not saved to disk

    scores_filename: str, optional
        File name for the scores; if False, the scores are not saved

    rotation_matrix_filename: str, optional
        File name for the rotation matrix; if False, the rotation matrix is
        not saved

    spike_index_clear_filename: str, optional
        File name for spike index clear

    if_file_exists: str, optional
        What to do if there is already a file in the rotation matrix and/or
        scores location. One of 'overwrite', 'abort', 'skip'. If 'overwrite',
        it replaces the file if it exists; if 'abort', it raises a ValueError
        exception if the file exists; if 'skip', it skips the operation if the
        file exists

    Returns
    -------
    scores: numpy.ndarray
        3D array of size (n_waveforms, n_reduced_features,
        n_neighboring_channels). Scores for every waveform; the second
        dimension is reduced from n_temporal_features to n_reduced_features
        and the third dimension depends on the number of neighboring channels

    rotation_matrix: numpy.ndarray
        3D array (window_size, n_features, n_channels)
    """

    ###########################
    # compute rotation matrix #
    ###########################

    bp = BatchProcessor(path_to_data,
                        dtype,
                        n_channels,
                        data_order,
                        max_memory,
                        buffer_size=spike_size)

    # compute PCA sufficient statistics
    logger.info('Computing PCA sufficient statistics...')
    stats = bp.multi_channel_apply(suff_stat,
                                   mode='memory',
                                   spike_index=spike_index,
                                   spike_size=spike_size)
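    # each element of `stats` is expected to be a pair of
    # (sufficient statistics, spikes per channel) for one batch; the reduce()
    # calls below sum them elementwise across batches so the PCA projection
    # can be fit from the whole recording without holding it in memory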
    suff_stats = reduce(lambda x, y: np.add(x, y), [e[0] for e in stats])
    spikes_per_channel = reduce(lambda x, y: np.add(x, y),
                                [e[1] for e in stats])

    # compute PCA projection matrix
    logger.info('Computing PCA projection matrix...')
    rotation = project(suff_stats, spikes_per_channel, temporal_features,
                       neighbors_matrix)

    #####################################
    # waveform dimensionality reduction #
    #####################################

    logger.info('Reducing spikes dimensionality with PCA matrix...')
    res = bp.multi_channel_apply(score,
                                 mode='memory',
                                 pass_batch_info=True,
                                 rot=rotation,
                                 channel_index=channel_index,
                                 spike_index=spike_index)

    scores = np.concatenate([element[0] for element in res], axis=0)
    spike_index = np.concatenate([element[1] for element in res], axis=0)

    # save scores
    if output_path and scores_filename:
        path_to_score = Path(output_path) / scores_filename
        save_numpy_object(scores,
                          path_to_score,
                          if_file_exists=if_file_exists,
                          name='scores')

    if output_path and spike_index_clear_filename:
        path_to_spike_index = Path(output_path) / spike_index_clear_filename
        save_numpy_object(spike_index,
                          path_to_spike_index,
                          if_file_exists=if_file_exists,
                          name='Spike index PCA')

    if output_path and rotation_matrix_filename:
        path_to_rotation = Path(output_path) / rotation_matrix_filename
        save_numpy_object(rotation,
                          path_to_rotation,
                          if_file_exists=if_file_exists,
                          name='rotation matrix')

    return scores, spike_index, rotation
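
A hedged usage sketch for pca() above; spike_index_clear is assumed to come from the threshold detector (Example #2) or to be loaded from its saved output, and all other arguments are placeholders.

# hypothetical usage of pca(); sizes, paths and matrices are placeholders
import numpy as np

n_channels = 49
neighbors = np.eye(n_channels, dtype=bool)               # placeholder
channel_index = np.tile(np.arange(7), (n_channels, 1))   # placeholder
spike_index_clear = np.load('tmp/detect/spike_index_clear.npy')

scores, spike_index, rotation = pca('tmp/standarized.bin', 'float32',
                                    n_channels, 'samples', spike_index_clear,
                                    15, 3, neighbors, channel_index, '1GB',
                                    output_path='tmp/detect',
                                    if_file_exists='overwrite')
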
Example #5
def run(if_file_exists='skip'):
    """Preprocess pipeline: filtering, standarization and whitening filter

    This step (optionally) performs filtering on the data, standarizes it
    and computes a whitening filter. Filtering and standarized data are
    processed in chunks and written to disk.

    Parameters
    ----------
    if_file_exists: str, optional
        One of 'overwrite', 'abort', 'skip'. Controls the behavior for every
        generated file. If 'overwrite' it replaces the files if any exist,
        if 'abort' it raises a ValueError exception if any file exists,
        if 'skip' it skips the operation (and loads the files) if any of them
        exist

    Returns
    -------
    standarized_path: str
        Path to standarized data binary file

    standarized_params: str
        Path to standarized data parameters

    channel_index: numpy.ndarray
        Channel indexes

    whiten_filter: numpy.ndarray
        Whiten matrix

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``filtered.bin`` - Filtered recordings
    * ``filtered.yaml`` - Filtered recordings metadata
    * ``standarized.bin`` - Standarized recordings
    * ``standarized.yaml`` - Standarized recordings metadata
    * ``whitening.npy`` - Whitening filter

    Everything is run on CPU.

    Examples
    --------

    .. literalinclude:: ../../examples/pipeline/preprocess.py
    """

    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    OUTPUT_DTYPE = CONFIG.preprocess.dtype
    output_directory = os.path.join(CONFIG.path_to_output_directory,
                                    'preprocess')

    logger.info(
        'Output dtype for transformed data will be {}'.format(OUTPUT_DTYPE))

    if not os.path.exists(output_directory):
        logger.info('Creating temporary folder: {}'.format(output_directory))
        os.makedirs(output_directory)
    else:
        logger.info('Temporary folder {} already exists, output will be '
                    'stored there'.format(output_directory))

    params = dict(dtype=CONFIG.recordings.dtype,
                  n_channels=CONFIG.recordings.n_channels,
                  data_order=CONFIG.recordings.order)

    # Generate params:
    standarized_path = os.path.join(output_directory, "standarized.bin")
    standarized_params = params
    standarized_params['dtype'] = 'float32'

    # Check if data already saved to disk and skip:
    if if_file_exists == 'skip':
        if os.path.exists(standarized_path):

            channel_index = make_channel_index(CONFIG.neigh_channels,
                                               CONFIG.geom, 2)

            # Cat: this is redundant, should save to disk/not recompute
            whiten_filter = whiten.matrix(standarized_path,
                                          standarized_params['dtype'],
                                          standarized_params['n_channels'],
                                          standarized_params['data_order'],
                                          channel_index,
                                          CONFIG.spike_size,
                                          CONFIG.resources.max_memory,
                                          output_directory,
                                          output_filename='whitening.npy',
                                          if_file_exists=if_file_exists)

            path_to_channel_index = os.path.join(output_directory,
                                                 "channel_index.npy")

            return str(standarized_path), standarized_params, whiten_filter

    # read config params
    multi_processing = CONFIG.resources.multi_processing
    n_processors = CONFIG.resources.n_processors
    n_sec_chunk = CONFIG.resources.n_sec_chunk
    n_channels = CONFIG.recordings.n_channels
    sampling_rate = CONFIG.recordings.sampling_rate

    # Read filter params
    low_frequency = CONFIG.preprocess.filter.low_pass_freq
    high_factor = CONFIG.preprocess.filter.high_factor
    order = CONFIG.preprocess.filter.order
    buffer_size = 200

    # compute len of recording
    filename_dat = os.path.join(CONFIG.data.root_folder,
                                CONFIG.data.recordings)
    fp = np.memmap(filename_dat, dtype='int16', mode='r')
    fp_len = fp.shape[0]

    # compute batch indexes
    indexes = np.arange(0, fp_len / n_channels, sampling_rate * n_sec_chunk)
    if indexes[-1] != fp_len / n_channels:
        indexes = np.hstack((indexes, fp_len / n_channels))

    idx_list = []
    for k in range(len(indexes) - 1):
        idx_list.append([
            indexes[k], indexes[k + 1], buffer_size,
            indexes[k + 1] - indexes[k] + buffer_size
        ])

    idx_list = np.int64(np.vstack(idx_list))
    proc_indexes = np.arange(len(idx_list))
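    # each row of idx_list is [chunk_start, chunk_end, buffer_size,
    # chunk_length + buffer_size] (in samples); the buffer gives
    # filter_standardize extra padding around each chunk for filtering edges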

    logger.info("# of chunks: %i", len(idx_list))

    # Make directory to hold filtered batch files:
    filtered_location = os.path.join(output_directory, "filtered_files")
    logger.info(filtered_location)
    if not os.path.exists(filtered_location):
        os.makedirs(filtered_location)

    # filter and standardize in one step
    if multi_processing:
        parmap.map(filter_standardize,
                   list(zip(idx_list, proc_indexes)),
                   low_frequency,
                   high_factor,
                   order,
                   sampling_rate,
                   buffer_size,
                   filename_dat,
                   n_channels,
                   output_directory,
                   processes=n_processors,
                   pm_pbar=True)
    else:
        for k in range(len(idx_list)):
            filter_standardize([idx_list[k], k], low_frequency, high_factor,
                               order, sampling_rate, buffer_size, filename_dat,
                               n_channels, output_directory)

    # Merge the chunk filtered files and delete the individual chunks
    merge_filtered_files(output_directory)

    # save yaml file with params
    path_to_yaml = standarized_path.replace('.bin', '.yaml')

    params = dict(dtype=standarized_params['dtype'],
                  n_channels=standarized_params['n_channels'],
                  data_order=standarized_params['data_order'])

    with open(path_to_yaml, 'w') as f:
        logger.info('Saving params...')
        yaml.dump(params, f)

    # TODO: this shouldn't be done here; it would be better to compute
    # this when initializing the config object and then access it from there
    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom, 2)

    # logger.info CONFIG.resources.max_memory
    # quit()
    # Cat: TODO: need to make this much smaller in size, don't need such
    # large batches
    # OLD CODE: compute whiten filter using batch processor
    # TODO: remove whiten_filter out of output argument

    whiten_filter = whiten.matrix(
        standarized_path,
        standarized_params['dtype'],
        standarized_params['n_channels'],
        standarized_params['data_order'],
        channel_index,
        CONFIG.spike_size,
        # CONFIG.resources.max_memory,
        '50MB',
        output_directory,
        output_filename='whitening.npy',
        if_file_exists=if_file_exists)

    path_to_channel_index = os.path.join(output_directory, 'channel_index.npy')
    save_numpy_object(channel_index,
                      path_to_channel_index,
                      if_file_exists=if_file_exists,
                      name='Channel index')

    return str(standarized_path), standarized_params, whiten_filter
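
A minimal usage sketch for run() above. It assumes a YASS configuration has already been loaded beforehand so that read_config() inside the function succeeds.

# hypothetical driver code for the preprocess step defined above
standarized_path, standarized_params, whiten_filter = run(
    if_file_exists='overwrite')

print(standarized_path)       # .../preprocess/standarized.bin
print(standarized_params)     # {'dtype': 'float32', 'n_channels': ..., ...}
print(whiten_filter.shape)    # whitening filter computed from the first batch
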
Example #6
def run_threshold(standarized_path, standarized_params, whiten_filter,
                  output_directory, if_file_exists, save_results):
    """Run threshold detector and dimensionality reduction using PCA


    Returns
    -------
    scores
      Scores for all spikes

    spike_index_clear
      Spike indexes for clear spikes

    spike_index_all
      Spike indexes for all spikes
    """
    logger = logging.getLogger(__name__)

    logger.debug('Running threshold detector...')

    CONFIG = read_config()

    folder = Path(CONFIG.data.root_folder, output_directory, 'detect')
    folder.mkdir(exist_ok=True)

    # Set TMP_FOLDER to None if save_results is False; this disables
    # saving results in every function below
    TMP_FOLDER = (str(folder) if save_results else None)

    # files that will be saved if enabled by the if_file_exists option
    filename_index_clear = 'spike_index_clear.npy'
    filename_index_clear_pca = 'spike_index_clear_pca.npy'
    filename_scores_clear = 'scores_clear.npy'
    filename_spike_index_all = 'spike_index_all.npy'
    filename_rotation = 'rotation.npy'

    ###################
    # Spike detection #
    ###################

    # run threshold detection, save clear indexes in TMP/filename_index_clear
    clear = threshold(standarized_path,
                      standarized_params['dtype'],
                      standarized_params['n_channels'],
                      standarized_params['data_order'],
                      CONFIG.resources.max_memory,
                      CONFIG.neigh_channels,
                      CONFIG.spike_size,
                      CONFIG.spike_size + CONFIG.templates.max_shift,
                      CONFIG.detect.threshold_detector.std_factor,
                      TMP_FOLDER,
                      spike_index_clear_filename=filename_index_clear,
                      if_file_exists=if_file_exists)

    #######
    # PCA #
    #######

    # run PCA, save rotation matrix and pca scores under TMP_FOLDER
    # TODO: remove clear as input for PCA and create an independent function
    pca_scores, clear, _ = pca(
        standarized_path, standarized_params['dtype'],
        standarized_params['n_channels'], standarized_params['data_order'],
        clear, CONFIG.spike_size, CONFIG.detect.temporal_features,
        CONFIG.neigh_channels, CONFIG.channel_index,
        CONFIG.resources.max_memory, TMP_FOLDER, 'scores_pca.npy',
        filename_rotation, filename_index_clear_pca, if_file_exists)

    #################
    # Whiten scores #
    #################

    # apply whitening to scores
    scores_clear = whiten.score(pca_scores, clear[:, 1], whiten_filter)

    # TODO: this shouldn't be here
    # transform scores to location + shape feature space
    if CONFIG.cluster.method == 'location':
        scores = get_locations_features_threshold(scores_clear, clear[:, 1],
                                                  CONFIG.channel_index,
                                                  CONFIG.geom)
    else:
        # keep the whitened scores so `scores` is defined for saving/returning
        scores = scores_clear

    if TMP_FOLDER is not None:
        # saves whiten scores
        path_to_scores = os.path.join(TMP_FOLDER, filename_scores_clear)
        save_numpy_object(scores,
                          path_to_scores,
                          if_file_exists,
                          name='scores')

        # save spike_index_all (same as spike_index_clear for threshold
        # detector)
        path_to_spike_index_all = os.path.join(TMP_FOLDER,
                                               filename_spike_index_all)
        save_numpy_object(clear,
                          path_to_spike_index_all,
                          if_file_exists,
                          name='Spike index all')

    return scores, clear, np.copy(clear)
Example #7
def run(output_directory='tmp/', if_file_exists='skip'):
    """Preprocess pipeline: filtering, standarization and whitening filter

    This step (optionally) performs filtering on the data, standarizes it
    and computes a whitening filter. Filtering and standarized data are
    processed in chunks and written to disk.

    Parameters
    ----------
    output_directory: str, optional
        Location to store results, relative to CONFIG.data.root_folder,
        defaults to tmp/. See list of files in Notes section below.

    if_file_exists: str, optional
        One of 'overwrite', 'abort', 'skip'. Controls the behavior for every
        generated file. If 'overwrite' it replaces the files if any exist,
        if 'abort' it raises a ValueError exception if any file exists,
        if 'skip' it skips the operation (and loads the files) if any of them
        exist

    Returns
    -------
    standarized_path: str
        Path to standarized data binary file

    standarized_params: str
        Path to standarized data parameters

    channel_index: numpy.ndarray
        Channel indexes

    whiten_filter: numpy.ndarray
        Whiten matrix

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``preprocess/filtered.bin`` - Filtered recordings
    * ``preprocess/filtered.yaml`` - Filtered recordings metadata
    * ``preprocess/standarized.bin`` - Standarized recordings
    * ``preprocess/standarized.yaml`` - Standarized recordings metadata
    * ``preprocess/whitening.npy`` - Whitening filter

    Everything is run on CPU.

    Examples
    --------

    .. literalinclude:: ../../examples/pipeline/preprocess.py
    """

    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    OUTPUT_DTYPE = CONFIG.preprocess.dtype
    PROCESSES = CONFIG.resources.processes

    logger.info(
        'Output dtype for transformed data will be {}'.format(OUTPUT_DTYPE))

    TMP = Path(CONFIG.data.root_folder, output_directory, 'preprocess/')
    TMP.mkdir(parents=True, exist_ok=True)
    TMP = str(TMP)

    path = os.path.join(CONFIG.data.root_folder, CONFIG.data.recordings)
    params = dict(dtype=CONFIG.recordings.dtype,
                  n_channels=CONFIG.recordings.n_channels,
                  data_order=CONFIG.recordings.order)

    # filter and standarize
    if CONFIG.preprocess.apply_filter:
        filter_params = CONFIG.preprocess.filter

        (standarized_path,
         standarized_params) = butterworth(path,
                                           params['dtype'],
                                           params['n_channels'],
                                           params['data_order'],
                                           filter_params.low_pass_freq,
                                           filter_params.high_factor,
                                           filter_params.order,
                                           CONFIG.recordings.sampling_rate,
                                           CONFIG.resources.max_memory,
                                           TMP,
                                           OUTPUT_DTYPE,
                                           standarize=True,
                                           output_filename='standarized.bin',
                                           if_file_exists=if_file_exists,
                                           processes=PROCESSES)
    # just standarize
    else:
        (standarized_path,
         standarized_params) = standarize(path,
                                          params['dtype'],
                                          params['n_channels'],
                                          params['data_order'],
                                          CONFIG.recordings.sampling_rate,
                                          CONFIG.resources.max_memory,
                                          TMP,
                                          OUTPUT_DTYPE,
                                          output_filename='standarized.bin',
                                          if_file_exists=if_file_exists,
                                          processes=PROCESSES)

    # TODO: this shouldn't be done here; it would be better to compute
    # this when initializing the config object and then access it from there
    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom, 2)

    # TODO: remove whiten_filter out of output argument
    whiten_filter = whiten.matrix(standarized_path,
                                  standarized_params['dtype'],
                                  standarized_params['n_channels'],
                                  standarized_params['data_order'],
                                  channel_index,
                                  CONFIG.spike_size,
                                  CONFIG.resources.max_memory,
                                  TMP,
                                  output_filename='whitening.npy',
                                  if_file_exists=if_file_exists)

    path_to_channel_index = os.path.join(TMP, 'channel_index.npy')
    save_numpy_object(channel_index,
                      path_to_channel_index,
                      if_file_exists=if_file_exists,
                      name='Channel index')

    return (str(standarized_path), standarized_params, channel_index,
            whiten_filter)
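
A brief, hypothetical call to this run() variant; unlike Example #5 it also returns the channel index, and a config is again assumed to be loaded beforehand.

# hypothetical usage of run() with an explicit output directory
(standarized_path, standarized_params,
 channel_index, whiten_filter) = run(output_directory='tmp/',
                                     if_file_exists='skip')
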
def pca(path_to_data,
        dtype,
        n_channels,
        data_order,
        recordings,
        spike_index,
        spike_size,
        temporal_features,
        neighbors_matrix,
        channel_index,
        max_memory,
        gmm_params,
        output_path=None,
        scores_filename='scores.npy',
        rotation_matrix_filename='rotation.npy',
        spike_index_clear_filename='spike_index_clear_pca.npy',
        if_file_exists='skip'):
    """Apply PCA in batches

    Parameters
    ----------
    path_to_data: str
        Path to recordings in binary format

    dtype: str
        Recordings dtype

    n_channels: int
        Number of channels in the recordings

    data_order: str
        Recordings order, one of ('channels', 'samples'). In a dataset with k
        observations per channel and j channels: 'channels' means first k
        contiguous observations come from channel 0, then channel 1, and so
        on. 'samples' means the first j contiguous data points are the first
        observations from all channels, then the second observations from all
        channels, and so on

    recordings: np.ndarray (n_observations, n_channels)
        Multi-channel recordings

    spike_index: numpy.ndarray
        A 2D numpy array, first column is spike time, second column is main
        channel (the channel where spike has the biggest amplitude)

    spike_size: int
        Spike size

    temporal_features: int
        Number of output features

    neighbors_matrix: numpy.ndarray (n_channels, n_channels)
        Boolean 2D array where an (i, j) entry is True if channel i is
        considered a neighbor of channel j

    channel_index: np.array (n_channels, n_neigh)
        Each row indexes its neighboring channels. For example,
        channel_index[c] contains the indexes of the channels neighboring
        channel c (including c itself). A value equal to n_channels is just a
        placeholder for channels that have fewer than n_neigh neighboring
        channels

    max_memory:
        Max memory to use in each batch (e.g. 100MB, 1GB)

    gmm_params: dict
        Dictionary with the parameters of the Gaussian mixture model

    output_path: str, optional
        Directory to store the scores and rotation matrix. If None, previous
        results on disk are ignored, everything is recomputed, and results
        are not saved to disk

    scores_filename: str, optional
        File name for the scores; if False, the scores are not saved

    rotation_matrix_filename: str, optional
        File name for the rotation matrix; if False, the rotation matrix is
        not saved

    spike_index_clear_filename: str, optional
        File name for spike index clear

    if_file_exists: str, optional
        What to do if there is already a file in the rotation matrix and/or
        scores location. One of 'overwrite', 'abort', 'skip'. If 'overwrite',
        it replaces the file if it exists; if 'abort', it raises a ValueError
        exception if the file exists; if 'skip', it skips the operation if the
        file exists

    Returns
    -------
    scores: numpy.ndarray
        3D array of size (n_waveforms, n_reduced_features,
        n_neighboring_channels). Scores for every waveform; the second
        dimension is reduced from n_temporal_features to n_reduced_features
        and the third dimension depends on the number of neighboring channels

    rotation_matrix: numpy.ndarray
        3D array (window_size, n_features, n_channels)
    """

    ###########################
    # compute rotation matrix #
    ###########################

    bp = BatchProcessor(path_to_data,
                        dtype,
                        n_channels,
                        data_order,
                        max_memory,
                        buffer_size=spike_size)

    # compute WPCA
    WAVE, FEATURE, CH = 0, 1, 2

    logger.info('Performing WPCA')

    logger.info('Computing Wavelets ...')
    feature = bp.multi_channel_apply(wavedec,
                                     mode='memory',
                                     pass_batch_info=True,
                                     spike_index=spike_index,
                                     spike_size=spike_size,
                                     wvtype='haar')

    features = reduce(lambda x, y: np.concatenate((x, y)),
                      [f for f in feature])

    logger.info('Computing weights...')

    # Weighting the features using the metric defined in gmtype
    weights = gmm_weight(features, gmm_params, spike_index)
    wfeatures = features * weights

    n_features = wfeatures.shape[FEATURE]
    wfeatures_lin = np.reshape(
        wfeatures, (wfeatures.shape[WAVE] * n_features, wfeatures.shape[CH]))
    feature_index = np.arange(0, wfeatures.shape[WAVE] * n_features,
                              n_features)

    TMP_FOLDER, _ = os.path.split(path_to_data)
    feature_path = os.path.join(TMP_FOLDER, 'features.bin')
    feature_params = writefile(wfeatures_lin, feature_path)

    bp_feat = BatchProcessor(feature_path,
                             feature_params['dtype'],
                             feature_params['n_channels'],
                             feature_params['data_order'],
                             max_memory,
                             buffer_size=n_features)

    # compute PCA sufficient statistics from extracted features

    logger.info('Computing PCA sufficient statistics...')
    stats = bp_feat.multi_channel_apply(suff_stat_features,
                                        mode='memory',
                                        pass_batch_info=True,
                                        spike_index=spike_index,
                                        spike_size=spike_size,
                                        feature_index=feature_index,
                                        feature_size=n_features)

    suff_stats = reduce(lambda x, y: np.add(x, y), [e[0] for e in stats])
    spikes_per_channel = reduce(lambda x, y: np.add(x, y),
                                [e[1] for e in stats])

    # compute PCA projection matrix
    logger.info('Computing PCA projection matrix...')
    rotation = project(suff_stats, spikes_per_channel, temporal_features,
                       neighbors_matrix)

    #####################################
    # waveform dimensionality reduction #
    #####################################

    logger.info('Reducing spikes dimensionality with PCA matrix...')

    # use a new BatchProcessor to read the feature file
    res = bp_feat.multi_channel_apply(score_features,
                                      mode='memory',
                                      pass_batch_info=True,
                                      rot=rotation,
                                      channel_index=channel_index,
                                      spike_index=spike_index,
                                      feature_index=feature_index)

    scores = np.concatenate([element[0] for element in res], axis=0)
    spike_index = np.concatenate([element[1] for element in res], axis=0)
    feature_index = np.concatenate([element[2] for element in res], axis=0)

    # renormalizing PC projections to similar unitary variance
    scores = st.zscore(scores, axis=0)

    # save scores
    if output_path and scores_filename:
        path_to_score = Path(output_path) / scores_filename
        save_numpy_object(scores,
                          path_to_score,
                          if_file_exists=if_file_exists,
                          name='scores')

    if output_path and spike_index_clear_filename:
        path_to_spike_index = Path(output_path) / spike_index_clear_filename
        save_numpy_object(spike_index,
                          path_to_spike_index,
                          if_file_exists=if_file_exists,
                          name='Spike index PCA')

    if output_path and rotation_matrix_filename:
        path_to_rotation = Path(output_path) / rotation_matrix_filename
        save_numpy_object(rotation,
                          path_to_rotation,
                          if_file_exists=if_file_exists,
                          name='rotation matrix')

    return scores, spike_index, rotation
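
A hedged usage sketch for the WPCA variant of pca() above; compared with Example #3 it additionally takes the raw recordings and a gmm_params dict. Every value below, including the gmm_params keys, is a placeholder rather than a value from the project.

# hypothetical usage of the WPCA pca(); exact gmm_params keys depend on
# gmm_weight, so the dict below is only illustrative
import numpy as np

n_channels = 49
recordings = np.fromfile('tmp/standarized.bin',
                         dtype='float32').reshape(-1, n_channels)
spike_index = np.load('tmp/detect/spike_index_clear.npy')
neighbors = np.eye(n_channels, dtype=bool)               # placeholder
channel_index = np.tile(np.arange(7), (n_channels, 1))   # placeholder
gmm_params = {'n_components': 3}                         # placeholder

scores, spike_index, rotation = pca('tmp/standarized.bin', 'float32',
                                    n_channels, 'samples', recordings,
                                    spike_index, 15, 3, neighbors,
                                    channel_index, '1GB', gmm_params,
                                    output_path='tmp/detect')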