Example No. 1
def get_o_layer(standarized_path,
                standarized_params,
                output_directory='tmp/',
                output_dtype='float32',
                output_filename='o_layer.bin',
                if_file_exists='skip',
                save_partial_results=False):
    """Get the output of NN detector instead of outputting the spike index
    """

    CONFIG = read_config()
    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom, 1)

    x_tf = tf.placeholder("float", [None, None])

    # load the neural network detector
    detection_fname = CONFIG.detect.neural_network_detector.filename
    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    NND = NeuralNetDetector(detection_fname)
    o_layer_tf = NND.make_o_layer_tf_tensors(x_tf, channel_index, detection_th)

    bp = BatchProcessor(standarized_path,
                        standarized_params['dtype'],
                        standarized_params['n_channels'],
                        standarized_params['data_format'],
                        CONFIG.resources.max_memory,
                        buffer_size=CONFIG.spike_size)

    TMP = os.path.join(CONFIG.data.root_folder, output_directory)
    _output_path = os.path.join(TMP, output_filename)
    (o_path, o_params) = bp.multi_channel_apply(_get_o_layer,
                                                mode='disk',
                                                cleanup_function=fix_indexes,
                                                output_path=_output_path,
                                                cast_dtype=output_dtype,
                                                x_tf=x_tf,
                                                o_layer_tf=o_layer_tf,
                                                NND=NND)

    return o_path, o_params
Example No. 2
def test_can_pass_information_between_batches(path_to_data):
    bp = BatchProcessor(path_to_data,
                        dtype='int64',
                        n_channels=2,
                        data_order='samples',
                        max_memory='160B')

    def col_sums(data, previous_batch):
        current = np.sum(data, axis=0)

        if previous_batch is None:
            return current
        else:
            return np.array([
                previous_batch[0] + current[0], previous_batch[1] + current[1]
            ])

    res = bp.multi_channel_apply(col_sums,
                                 mode='memory',
                                 channels='all',
                                 pass_batch_results=True)

    assert res[0] == 4950 and res[1] == 4950
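
4950 is the sum of the integers 0 through 99, so the fixture presumably holds those values in each of the two channels. A minimal pure-numpy sketch of the same accumulation pattern (no BatchProcessor involved; the batch size is chosen arbitrarily):

import numpy as np

data = np.arange(100).repeat(2).reshape(100, 2)  # both channels hold 0..99

previous = None
for start in range(0, 100, 10):                  # ten batches of ten rows
    batch = data[start:start + 10]
    current = np.sum(batch, axis=0)
    previous = current if previous is None else previous + current

assert previous[0] == 4950 and previous[1] == 4950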
Example No. 3
def _threshold_detection(standarized_path, standarized_params, n_observations,
                         output_directory):
    """Run threshold detector and dimensionality reduction using PCA
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    OUTPUT_DTYPE = CONFIG.preprocess.dtype
    TMP_FOLDER = os.path.join(CONFIG.data.root_folder, output_directory)

    ###############
    # Whiten data #
    ###############

    # compute Q for whitening
    logger.info('Computing whitening matrix...')
    bp = BatchProcessor(standarized_path, standarized_params['dtype'],
                        standarized_params['n_channels'],
                        standarized_params['data_format'],
                        CONFIG.resources.max_memory)
    batches = bp.multi_channel()
    first_batch, _, _ = next(batches)
    Q = whiten.matrix(first_batch, CONFIG.neighChannels, CONFIG.spikeSize)

    path_to_whitening_matrix = os.path.join(TMP_FOLDER, 'whitening.npy')
    np.save(path_to_whitening_matrix, Q)
    logger.info(
        'Saved whitening matrix in {}'.format(path_to_whitening_matrix))

    # apply whitening to every batch
    (whitened_path, whitened_params) = bp.multi_channel_apply(
        np.matmul,
        mode='disk',
        output_path=os.path.join(TMP_FOLDER, 'whitened.bin'),
        if_file_exists='skip',
        cast_dtype=OUTPUT_DTYPE,
        b=Q)

    ###################
    # Spike detection #
    ###################

    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')

    bp = BatchProcessor(standarized_path,
                        standarized_params['dtype'],
                        standarized_params['n_channels'],
                        standarized_params['data_format'],
                        CONFIG.resources.max_memory,
                        buffer_size=0)

    # clear spikes
    if os.path.exists(path_to_spike_index_clear):
        # if it exists, load it...
        logger.info('Found file in {}, loading it...'.format(
            path_to_spike_index_clear))
        spike_index_clear = np.load(path_to_spike_index_clear)
    else:
        # if it doesn't, detect spikes...
        logger.info('Did not find file in {}, finding spikes using threshold'
                    ' detector...'.format(path_to_spike_index_clear))

        # apply threshold detector on standarized data
        spikes = bp.multi_channel_apply(detect.threshold,
                                        mode='memory',
                                        cleanup_function=detect.fix_indexes,
                                        neighbors=CONFIG.neighChannels,
                                        spike_size=CONFIG.spikeSize,
                                        std_factor=CONFIG.stdFactor)
        spike_index_clear = np.vstack(spikes)

        logger.info('Removing clear indexes outside the allowed range to '
                    'draw a complete waveform...')
        spike_index_clear, _ = (detect.remove_incomplete_waveforms(
            spike_index_clear, CONFIG.spikeSize + CONFIG.templatesMaxShift,
            n_observations))

        logger.info('Saving spikes in {}...'.format(path_to_spike_index_clear))
        np.save(path_to_spike_index_clear, spike_index_clear)

    path_to_spike_index_collision = os.path.join(TMP_FOLDER,
                                                 'spike_index_collision.npy')

    # collided spikes
    if os.path.exists(path_to_spike_index_collision):
        # if it exists, load it...
        logger.info('Found collided spikes in {}, loading them...'.format(
            path_to_spike_index_collision))
        spike_index_collision = np.load(path_to_spike_index_collision)

        if spike_index_collision.shape[0] != 0:
            raise ValueError('Found non-empty collision spike index in {}, '
                             'but threshold detector is selected, collision '
                             'detection is not implemented for threshold '
                             'detector so the array must have dimensions (0, 2) '
                             'but had ({}, {})'.format(
                                 path_to_spike_index_collision,
                                 *spike_index_collision.shape))
    else:
        # triage is not implemented on threshold detector, return empty array
        logger.info('Creating empty array for'
                    ' collided spikes (collision detection is not implemented'
                    ' with the threshold detector). Saving them in {}'.format(
                        path_to_spike_index_collision))
        spike_index_collision = np.zeros((0, 2), 'int32')
        np.save(path_to_spike_index_collision, spike_index_collision)

    #######################
    # Waveform extraction #
    #######################

    # load and dump waveforms from clear spikes
    path_to_waveforms_clear = os.path.join(TMP_FOLDER, 'waveforms_clear.npy')

    if os.path.exists(path_to_waveforms_clear):
        logger.info('Found clear waveforms in {}, loading them...'.format(
            path_to_waveforms_clear))
        waveforms_clear = np.load(path_to_waveforms_clear)
    else:
        logger.info(
            'Did not find clear waveforms in {}, reading them from {}'.format(
                path_to_waveforms_clear, standarized_path))
        explorer = RecordingExplorer(standarized_path,
                                     spike_size=CONFIG.spikeSize)
        waveforms_clear = explorer.read_waveforms(spike_index_clear[:, 0])
        np.save(path_to_waveforms_clear, waveforms_clear)
        logger.info('Saved waveform from clear spikes in: {}'.format(
            path_to_waveforms_clear))

    #########################
    # PCA - rotation matrix #
    #########################

    # compute per-batch sufficient statistics for PCA on standarized data
    logger.info('Computing PCA sufficient statistics...')
    stats = bp.multi_channel_apply(dim_red.suff_stat,
                                   mode='memory',
                                   spike_index=spike_index_clear,
                                   spike_size=CONFIG.spikeSize)

    suff_stats = reduce(lambda x, y: np.add(x, y), [e[0] for e in stats])

    spikes_per_channel = reduce(lambda x, y: np.add(x, y),
                                [e[1] for e in stats])

    # compute rotation matrix
    logger.info('Computing PCA projection matrix...')
    rotation = dim_red.project(suff_stats, spikes_per_channel,
                               CONFIG.spikes.temporal_features,
                               CONFIG.neighChannels)
    path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')
    np.save(path_to_rotation, rotation)
    logger.info('Saved rotation matrix in {}...'.format(path_to_rotation))

    main_channel = spike_index_clear[:, 1]
    ###########################################
    # PCA - waveform dimensionality reduction #
    ###########################################
    if CONFIG.clustering.clustering_method == 'location':
        logger.info('Denoising...')
        path_to_denoised_waveforms = os.path.join(TMP_FOLDER,
                                                  'denoised_waveforms.npy')
        if os.path.exists(path_to_denoised_waveforms):
            logger.info(
                'Found denoised waveforms in {}, loading them...'.format(
                    path_to_denoised_waveforms))
            denoised_waveforms = np.load(path_to_denoised_waveforms)
        else:
            logger.info(
                'Did not find denoised waveforms in {}, evaluating them '
                'from {}'.format(path_to_denoised_waveforms,
                                 path_to_waveforms_clear))
            waveforms_clear = np.load(path_to_waveforms_clear)
            denoised_waveforms = dim_red.denoise(waveforms_clear, rotation,
                                                 CONFIG)
            logger.info('Saving denoised waveforms to {}'.format(
                path_to_denoised_waveforms))
            np.save(path_to_denoised_waveforms, denoised_waveforms)

        isolated_index, x, y = get_isolated_spikes_and_locations(
            denoised_waveforms, main_channel, CONFIG)
        x = (x - np.mean(x)) / np.std(x)
        y = (y - np.mean(y)) / np.std(y)
        corrupted_index = np.logical_not(
            np.in1d(np.arange(spike_index_clear.shape[0]), isolated_index))
        spike_index_collision = np.concatenate(
            [spike_index_collision, spike_index_clear[corrupted_index]],
            axis=0)
        spike_index_clear = spike_index_clear[isolated_index]
        waveforms_clear = waveforms_clear[isolated_index]

        #################################################
        # Dimensionality reduction (Isolated Waveforms) #
        #################################################

        scores = dim_red.main_channel_scores(waveforms_clear, rotation,
                                             spike_index_clear, CONFIG)
        scores = (scores - np.mean(scores, axis=0)) / np.std(scores)
        scores = np.concatenate([x[:, np.newaxis, np.newaxis],
                                 y[:, np.newaxis, np.newaxis],
                                 scores[:, :, np.newaxis]], axis=1)
    else:
        logger.info('Reducing spikes dimensionality with PCA matrix...')
        scores = dim_red.score(waveforms_clear, rotation,
                               spike_index_clear[:, 1],
                               CONFIG.neighChannels, CONFIG.geom)

    # save scores
    path_to_score = os.path.join(TMP_FOLDER, 'score_clear.npy')
    np.save(path_to_score, scores)
    logger.info('Saved spike scores in {}...'.format(path_to_score))

    return scores, spike_index_clear, spike_index_collision
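
The whitening step above can run in batches because applying np.matmul(batch, Q) to each batch and writing the results back to back is equivalent to multiplying the full recording by Q at once; a small standalone numpy check of that property (shapes chosen arbitrarily):

import numpy as np

rng = np.random.RandomState(0)
recording = rng.randn(1000, 4)           # (observations, channels)
Q = rng.randn(4, 4)                      # stand-in for the whitening matrix

# apply the matrix batch by batch and stitch the results together
batched = np.concatenate(
    [np.matmul(recording[i:i + 100], Q) for i in range(0, 1000, 100)])

np.testing.assert_allclose(batched, np.matmul(recording, Q))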
Example No. 4
File: detect.py Project: Nomow/yass
def threshold(path_to_data,
              dtype,
              n_channels,
              data_order,
              max_memory,
              neighbors,
              spike_size,
              minimum_half_waveform_size,
              threshold,
              output_path=None,
              spike_index_clear_filename='spike_index_clear.npy',
              if_file_exists='skip'):
    """Threshold spike detection in batches

    Parameters
    ----------
    path_to_data: str
        Path to recordings in binary format

    dtype: str
        Recordings dtype

    n_channels: int
        Number of channels in the recordings

    data_order: str
        Recordings order, one of ('channels', 'samples'). In a dataset with k
        observations per channel and j channels: 'channels' means the first k
        contiguous observations come from channel 0, then channel 1, and so
        on; 'samples' means the first j contiguous values are the first
        observations from all channels, then the second observations from all
        channels, and so on

    max_memory: str
        Max memory to use in each batch (e.g. 100MB, 1GB)

    neighbors: numpy.ndarray (n_channels, n_channels)
        Boolean 2-D numpy array where the (i, j) entry is True if channel i
        is considered a neighbor of channel j

    spike_size: int
        Spike size

    minimum_half_waveform_size: int
        This is used to remove spikes that are either at the beginning or end
        of the recordings and whose location does not allow drawing a
        waveform of size at least 2 * minimum_half_waveform_size + 1

    threshold: float
        Threshold used on amplitude for detection

    output_path: str, optional
        Directory to save spike indexes, if None, results won't be stored, but
        only returned by the function

    spike_index_clear_filename: str, optional
        Filename for spike_index_clear, used as the filename for the file
        (relative to output_path); if None, results won't be saved, only
        returned

    if_file_exists: str, optional
        What to do if there is already a file in the spike_index_clear path.
        One of 'overwrite', 'abort', 'skip'. If 'overwrite' it replaces the
        file if it exists, if 'abort' it raises a ValueError exception if the
        file exists, if 'skip' it skips the operation if the file exists and
        loads the existing file from disk

    Returns
    -------
    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        2D array with indexes for clear spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    spike_index_collision: numpy.ndarray (0, 2)
        Empty array, collision is not implemented in the threshold detector
    """
    # instantiate batch processor
    bp = BatchProcessor(path_to_data,
                        dtype,
                        n_channels,
                        data_order,
                        max_memory,
                        buffer_size=spike_size)

    # run threshold detector
    spikes = bp.multi_channel_apply(_threshold,
                                    mode='memory',
                                    cleanup_function=fix_indexes,
                                    neighbors=neighbors,
                                    spike_size=spike_size,
                                    threshold=threshold)

    # no collision detection implemented, all spikes are marked as clear
    spike_index_clear = np.vstack(spikes)

    # remove spikes whose location won't let us draw a complete waveform
    logger.info('Removing clear indexes outside the allowed range to '
                'draw a complete waveform...')
    spike_index_clear, _ = (remove_incomplete_waveforms(
        spike_index_clear, minimum_half_waveform_size, bp.reader.observations))

    if output_path and spike_index_clear_filename:
        path = os.path.join(output_path, spike_index_clear_filename)
        save_numpy_object(spike_index_clear,
                          path,
                          if_file_exists=if_file_exists,
                          name='Spike index clear')

    return spike_index_clear
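
The data_order convention described in the docstring ('channels' vs. 'samples') corresponds to the two ways a 2-D recording can be flattened to disk; a small numpy illustration with k = 3 observations and j = 2 channels:

import numpy as np

k, j = 3, 2                                   # observations per channel, channels
recording = np.arange(k * j).reshape(k, j)    # rows = observations, columns = channels

# 'samples': observation-major, the j values of each time step are contiguous
samples_order = recording.ravel()             # [0 1 2 3 4 5]

# 'channels': channel-major, the k observations of each channel are contiguous
channels_order = recording.T.ravel()          # [0 2 4 1 3 5]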
Example No. 5
big_long[:, 1] = x_long

bp_long = BatchProcessor(path_to_long,
                         dtype='int64',
                         n_channels=50,
                         data_format='long',
                         max_memory='500MB')

path = bp_long.single_channel_apply(dummy, path_to_out)

out = RecordingsReader(path)
out

bp_wide = BatchProcessor(path_to_wide,
                         dtype='int64',
                         n_channels=50,
                         data_format='wide',
                         max_memory='500MB')

path = bp_wide.single_channel_apply(dummy, path_to_out)
out = RecordingsReader(path)
out

path = bp_long.multi_channel_apply(dummy, path_to_out)
out = RecordingsReader(path)
out

path = bp_wide.multi_channel_apply(dummy, path_to_out)
out = RecordingsReader(path)
out
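
The dummy transformation used above is not shown in this fragment; a minimal stand-in consistent with how it is called (a hypothetical identity transform that takes the batch as its only data argument) would be:

import numpy as np

def dummy(data):
    """Identity transformation: return the batch unchanged."""
    return np.asarray(data)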
Example No. 6
def pca(path_to_data,
        dtype,
        n_channels,
        data_order,
        spike_index,
        spike_size,
        temporal_features,
        neighbors_matrix,
        channel_index,
        max_memory,
        output_path=None,
        scores_filename='scores.npy',
        rotation_matrix_filename='rotation.npy',
        spike_index_clear_filename='spike_index_clear_pca.npy',
        if_file_exists='skip'):
    """Apply PCA in batches

    Parameters
    ----------
    path_to_data: str
        Path to recordings in binary format

    dtype: str
        Recordings dtype

    n_channels: int
        Number of channels in the recordings

    data_order: str
        Recordings order, one of ('channels', 'samples'). In a dataset with k
        observations per channel and j channels: 'channels' means first k
        contiguous observations come from channel 0, then channel 1, and so
        on. 'sample' means first j contiguous data are the first observations
        from all channels, then the second observations from all channels and
        so on

    spike_index: numpy.ndarray
        A 2D numpy array, first column is spike time, second column is main
        channel (the channel where spike has the biggest amplitude)

    spike_size: int
        Spike size

    temporal_features: int
        Number of output features

    neighbors_matrix: numpy.ndarray (n_channels, n_channels)
        Boolean 2-D numpy array where the (i, j) entry is True if channel i
        is considered a neighbor of channel j

    channel_index: np.array (n_channels, n_neigh)
        Each row indexes the neighboring channels of a channel. For example,
        channel_index[c] contains the indexes of the channels neighboring
        channel c (including c itself). Any value equal to n_channels is a
        placeholder used when a channel has fewer than n_neigh neighboring
        channels

    max_memory: str
        Max memory to use in each batch (e.g. 100MB, 1GB)

    output_path: str, optional
        Directory to store the scores and rotation matrix, if None, previous
        results on disk are ignored, operations are computed and results
        aren't saved to disk

    scores_filename: str, optional
        File name for the scores; if False, data is not saved

    rotation_matrix_filename: str, optional
        File name for the rotation matrix; if False, data is not saved

    spike_index_clear_filename: str, optional
        File name for spike index clear

    if_file_exists: str, optional
        What to do if there is already a file in the rotation matrix and/or
        scores location. One of 'overwrite', 'abort', 'skip'. If 'overwrite'
        it replaces the file if it exists, if 'abort' it raises a ValueError
        exception if the file exists, if 'skip' it skips the operation if the
        file exists

    Returns
    -------
    scores: numpy.ndarray
        Numpy 3D array of size (n_waveforms, n_reduced_features,
        n_neighboring_channels). Scores for every waveform; the second
        dimension is reduced from n_temporal_features to n_reduced_features,
        the third dimension depends on the number of neighboring channels

    rotation_matrix: numpy.ndarray
        3D array (window_size, n_features, n_channels)
    """

    ###########################
    # compute rotation matrix #
    ###########################

    bp = BatchProcessor(path_to_data,
                        dtype,
                        n_channels,
                        data_order,
                        max_memory,
                        buffer_size=spike_size)

    # compute PCA sufficient statistics
    logger.info('Computing PCA sufficient statistics...')
    stats = bp.multi_channel_apply(suff_stat,
                                   mode='memory',
                                   spike_index=spike_index,
                                   spike_size=spike_size)
    suff_stats = reduce(lambda x, y: np.add(x, y), [e[0] for e in stats])
    spikes_per_channel = reduce(lambda x, y: np.add(x, y),
                                [e[1] for e in stats])

    # compute PCA projection matrix
    logger.info('Computing PCA projection matrix...')
    rotation = project(suff_stats, spikes_per_channel, temporal_features,
                       neighbors_matrix)

    #####################################
    # waveform dimensionality reduction #
    #####################################

    logger.info('Reducing spikes dimensionality with PCA matrix...')
    res = bp.multi_channel_apply(score,
                                 mode='memory',
                                 pass_batch_info=True,
                                 rot=rotation,
                                 channel_index=channel_index,
                                 spike_index=spike_index)

    scores = np.concatenate([element[0] for element in res], axis=0)
    spike_index = np.concatenate([element[1] for element in res], axis=0)

    # save scores
    if output_path and scores_filename:
        path_to_score = Path(output_path) / scores_filename
        save_numpy_object(scores,
                          path_to_score,
                          if_file_exists=if_file_exists,
                          name='scores')

    if output_path and spike_index_clear_filename:
        path_to_spike_index = Path(output_path) / spike_index_clear_filename
        save_numpy_object(spike_index,
                          path_to_spike_index,
                          if_file_exists=if_file_exists,
                          name='Spike index PCA')

    if output_path and rotation_matrix_filename:
        path_to_rotation = Path(output_path) / rotation_matrix_filename
        save_numpy_object(rotation,
                          path_to_rotation,
                          if_file_exists=if_file_exists,
                          name='rotation matrix')

    return scores, spike_index, rotation
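
The reduce(np.add, ...) pattern above works because the per-batch sufficient statistics are additive across batches; a small standalone numpy sketch of the same idea using per-batch X^T X sums and observation counts:

import numpy as np
from functools import reduce

rng = np.random.RandomState(0)
batches = [rng.randn(50, 3) for _ in range(4)]

# per-batch sufficient statistics: (X^T X, number of observations)
stats = [(b.T @ b, b.shape[0]) for b in batches]

xtx = reduce(lambda x, y: np.add(x, y), [e[0] for e in stats])
n_obs = reduce(lambda x, y: np.add(x, y), [e[1] for e in stats])

full = np.concatenate(batches)
np.testing.assert_allclose(xtx, full.T @ full)
assert n_obs == full.shape[0]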
Example No. 7
def test_splitting_in_batches_does_not_affect(path_to_config,
                                              path_to_sample_pipeline_folder,
                                              make_tmp_folder,
                                              path_to_standardized_data):
    yass.set_config(path_to_config, make_tmp_folder)
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standardized_data

    with open(path.join(path_to_sample_pipeline_folder, 'preprocess',
                        'standardized.yaml')) as f:
        PARAMS = yaml.load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels,
                                       CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 channel_index)
    triage = KerasModel(triage_fname,
                        allow_longer_waveform_length=True,
                        allow_more_channels=True)
    NNAE = AutoEncoder.load(ae_fname, input_tensor=NND.waveform_tf)

    bp = BatchProcessor(PATH_TO_DATA, PARAMS['dtype'], PARAMS['n_channels'],
                        PARAMS['data_order'], '100KB',
                        buffer_size=CONFIG.spike_size)

    out = ('spike_index', 'waveform')
    fn = neuralnetwork.apply.fix_indexes_spike_index

    # detector
    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)

        res = bp.multi_channel_apply(NND.predict_recording,
                                     mode='memory',
                                     sess=sess,
                                     output_names=out,
                                     cleanup_function=fn)

    spike_index_new = np.concatenate([element[0] for element in res], axis=0)
    wfs = np.concatenate([element[1] for element in res], axis=0)

    idx_clean = triage.predict_with_threshold(wfs, triage_th)
    score = NNAE.predict(wfs)
    rot = NNAE.load_rotation()
    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

    (score_clear_new,
        spike_index_clear_new) = post_processing(score,
                                                 spike_index_new,
                                                 idx_clean,
                                                 rot,
                                                 neighbors)

    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)

        res = bp.multi_channel_apply(NND.predict_recording,
                                     mode='memory',
                                     sess=sess,
                                     output_names=('spike_index',
                                                   'waveform'),
                                     cleanup_function=fn)

    spike_index_batch, wfs = zip(*res)

    spike_index_batch = np.concatenate(spike_index_batch, axis=0)
    wfs = np.concatenate(wfs, axis=0)

    idx_clean = triage.predict_with_threshold(x=wfs,
                                              threshold=triage_th)

    score = NNAE.predict(wfs)
    rot = NNAE.load_rotation()
    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

    (score_clear_batch,
        spike_index_clear_batch) = post_processing(score,
                                                   spike_index_batch,
                                                   idx_clean,
                                                   rot,
                                                   neighbors)
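
The test collects the batch results in two equivalent ways: one list comprehension per output, and zip(*res) followed by concatenation. A tiny standalone check of that equivalence with mock per-batch tuples:

import numpy as np

# res mimics multi_channel_apply output in 'memory' mode: one tuple per batch
res = [(np.array([i, i + 1]), np.array([[i, -i]])) for i in range(3)]

spike_index_a = np.concatenate([element[0] for element in res], axis=0)
wfs_a = np.concatenate([element[1] for element in res], axis=0)

spike_index_parts, wfs_parts = zip(*res)
np.testing.assert_array_equal(spike_index_a,
                              np.concatenate(spike_index_parts, axis=0))
np.testing.assert_array_equal(wfs_a, np.concatenate(wfs_parts, axis=0))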
Example No. 8
def test_splitting_in_batches_does_not_affect_result(path_to_tests):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    PATH_TO_DATA = path.join(path_to_tests, 'data/standarized.bin')

    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    with open(path.join(path_to_tests, 'data/standarized.yaml')) as f:
        PARAMS = yaml.load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    whiten_filter = np.tile(
        np.eye(channel_index.shape[1], dtype='float32')[np.newaxis, :, :],
        [channel_index.shape[0], 1, 1])

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename
    (x_tf, output_tf, NND, NNAE, NNT) = neuralnetwork.prepare_nn(
        channel_index,
        whiten_filter,
        detection_th,
        triage_th,
        detection_fname,
        ae_fname,
        triage_fname,
    )

    # run all at once
    with tf.Session() as sess:
        # get values of above tensors
        NND.saver.restore(sess, NND.path_to_detector_model)
        NNAE.saver_ae.restore(sess, NNAE.path_to_ae_model)
        NNT.saver.restore(sess, NNT.path_to_triage_model)
        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        (scores, clear, collision) = neuralnetwork.run_detect_triage_featurize(
            data, sess, x_tf, output_tf, neighbors, rot)

    # run in batches - buffer size makes sure we can detect spikes if they
    # appear at the end of any batch
    bp = BatchProcessor(PATH_TO_DATA,
                        PARAMS['dtype'],
                        PARAMS['n_channels'],
                        PARAMS['data_order'],
                        '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        # get values of above tensors
        NND.saver.restore(sess, NND.path_to_detector_model)
        NNAE.saver_ae.restore(sess, NNAE.path_to_ae_model)
        NNT.saver.restore(sess, NNT.path_to_triage_model)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res], axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)
Example No. 9

# create batch processor for the data
bp = BatchProcessor(path_to_neuropixel_data,
                    dtype='int16',
                    n_channels=385,
                    data_format='wide',
                    max_memory='500MB')

# apply a multi channel transformation, each batch will be a temporal
# subset with observations from all selected n_channels, the size
# of the subset is calculated depending on max_memory. Each batch is
# processed and when done, results are saved to disk, the next batch is
# then loaded and so on
bp.multi_channel_apply(sum_one,
                       mode='disk',
                       output_path=path_to_modified_data,
                       channels=[0, 1, 2])

# let's visualize the results
raw = RecordingsReader(path_to_neuropixel_data,
                       dtype='int16',
                       n_channels=385,
                       data_format='wide')

# you do not need to specify the format since multi_channel_apply
# saves a yaml file with such parameters
filtered = RecordingsReader(path_to_modified_data)

fig, (ax1, ax2) = plt.subplots(2, 1)
ax1.plot(raw[:2000, 0])
ax2.plot(filtered[:2000, 0])
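
sum_one is not defined in this fragment; a minimal stand-in consistent with its name and with how multi_channel_apply passes the batch as the first argument could be:

def sum_one(batch):
    """Add one to every element in the batch (hypothetical stand-in)."""
    return batch + 1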
Example No. 10
def butterworth(path_to_data,
                dtype,
                n_channels,
                data_order,
                low_frequency,
                high_factor,
                order,
                sampling_frequency,
                max_memory,
                output_path,
                output_dtype,
                standarize=False,
                output_filename='filtered.bin',
                if_file_exists='skip',
                processes='max'):
    """Filter (butterworth) recordings in batches

    Parameters
    ----------
    path_to_data: str
        Path to recordings in binary format

    dtype: str
        Recordings dtype

    n_channels: int
        Number of channels in the recordings

    data_order: str
        Recordings order, one of ('channels', 'samples'). In a dataset with k
        observations per channel and j channels: 'channels' means the first k
        contiguous observations come from channel 0, then channel 1, and so
        on; 'samples' means the first j contiguous values are the first
        observations from all channels, then the second observations from all
        channels, and so on

    low_frequency: int
        Low pass frequency (Hz)

    high_factor: float
        High pass factor (proportion of sampling rate)

    order: int
        Order of Butterworth filter

    sampling_frequency: int
        Recordings sampling frequency in Hz

    max_memory: str
        Max memory to use in each batch (e.g. 100MB, 1GB)

    output_path: str
        Folder to store the filtered recordings

    output_filename: str, optional
        How to name the file, defaults to filtered.bin

    output_dtype: str
        dtype for filtered data

    standarize: bool
        Whether to standarize the data after the filtering step

    if_file_exists: str, optional
        One of 'overwrite', 'abort', 'skip'. If 'overwrite' it replaces the
        file if it exists, if 'abort' it raises a ValueError exception if
        the file exists, if 'skip' it skips the operation if the file
        exists

    processes: str or int, optional
        Number of processes to use; if 'max', it uses all cores in the
        machine, if a number, it uses that number of cores

    Returns
    -------
    standarized_path: str
        Path to the filtered recordings

    standarized_params: dict
        A dictionary with the parameters for the filtered recordings
        (dtype, n_channels, data_order)
    """
    processes = multiprocess.cpu_count() if processes == 'max' else processes

    # init batch processor
    bp = BatchProcessor(path_to_data,
                        dtype,
                        n_channels,
                        data_order,
                        max_memory,
                        buffer_size=200)

    if standarize:
        bp_ = BatchProcessor(path_to_data,
                             dtype,
                             n_channels,
                             data_order,
                             max_memory,
                             buffer_size=0)

        filtering = partial(_butterworth,
                            low_frequency=low_frequency,
                            high_factor=high_factor,
                            order=order,
                            sampling_frequency=sampling_frequency)

        # if standarize, estimate sd from first batch and use
        # _butterworth_scale function, pass filtering to estimate sd from the
        # filtered data
        sd = standard_deviation(bp_,
                                sampling_frequency,
                                preprocess_fn=filtering)
        fn = partial(_butterworth_scale, denominator=sd)
        # add a __name__ to the partial object, since functools.partial does not set one
        fn.__name__ = _butterworth_scale.__name__
    else:
        # otherwise use _butterworth function
        fn = _butterworth

    _output_path = os.path.join(output_path, output_filename)

    (path,
     params) = bp.multi_channel_apply(fn,
                                      mode='disk',
                                      cleanup_function=fix_indexes,
                                      output_path=_output_path,
                                      if_file_exists=if_file_exists,
                                      cast_dtype=output_dtype,
                                      low_frequency=low_frequency,
                                      high_factor=high_factor,
                                      order=order,
                                      sampling_frequency=sampling_frequency,
                                      processes=processes)

    return path, params
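
The _butterworth helper itself is not shown here. Below is a rough sketch of a band-pass Butterworth step with the same parameters (low cutoff in Hz, high cutoff expressed as a fraction of the sampling rate); this is an illustrative assumption, not the project's actual implementation:

from scipy.signal import butter, lfilter

def butterworth_sketch(data, low_frequency, high_factor, order,
                       sampling_frequency):
    """Band-pass Butterworth filter applied along the time axis."""
    nyquist = sampling_frequency / 2.0
    low = low_frequency / nyquist                       # Hz -> fraction of Nyquist
    high = high_factor * sampling_frequency / nyquist   # fraction of fs -> fraction of Nyquist
    b, a = butter(order, [low, high], btype='band')
    return lfilter(b, a, data, axis=0)

# e.g. butterworth_sketch(recording, low_frequency=300, high_factor=0.1,
#                         order=3, sampling_frequency=30000)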
Example No. 11
path_to_neuropixel_data = (os.path.expanduser('~/data/ucl-neuropixel'
                           '/rawDataSample.bin'))
path_to_filtered_data = (os.path.expanduser('~/data/ucl-neuropixel'
                         '/tmp/filtered_multi.bin'))

# create batch processor for the data
bp = BatchProcessor(path_to_neuropixel_data,
                    dtype='int16', n_channels=385, data_format='wide',
                    max_memory='500MB')

# apply a multi channel transformation, each batch will be a temporal
# subset with observations from all selected n_channels, the size
# of the subset is calculated depending on max_memory
bp.multi_channel_apply(butterworth, path_to_filtered_data,
                       channels='all',
                       low_freq=300, high_factor=0.1,
                       order=3, sampling_freq=30000)

# let's visualize the results
raw = RecordingsReader(path_to_neuropixel_data, dtype='int16',
                       n_channels=385, data_format='wide')

# you do not need to specify the format since multi_channel_apply
# saves a yaml file with such parameters
filtered = RecordingsReader(path_to_filtered_data)

fig, (ax1, ax2) = plt.subplots(2, 1)
ax1.plot(raw[:2000, 0])
ax2.plot(filtered[:2000, 0])
plt.show()
Example No. 12
def pca(path_to_data,
        dtype,
        n_channels,
        data_order,
        recordings,
        spike_index,
        spike_size,
        temporal_features,
        neighbors_matrix,
        channel_index,
        max_memory,
        gmm_params,
        output_path=None,
        scores_filename='scores.npy',
        rotation_matrix_filename='rotation.npy',
        spike_index_clear_filename='spike_index_clear_pca.npy',
        if_file_exists='skip'):
    """Apply PCA in batches

    Parameters
    ----------
    path_to_data: str
        Path to recordings in binary format

    dtype: str
        Recordings dtype

    n_channels: int
        Number of channels in the recordings

    data_order: str
        Recordings order, one of ('channels', 'samples'). In a dataset with k
        observations per channel and j channels: 'channels' means the first k
        contiguous observations come from channel 0, then channel 1, and so
        on; 'samples' means the first j contiguous values are the first
        observations from all channels, then the second observations from all
        channels, and so on

    recordings: np.ndarray (n_observations, n_channels)
        Multi-channel recordings

    spike_index: numpy.ndarray
        A 2D numpy array, first column is spike time, second column is main
        channel (the channel where spike has the biggest amplitude)

    spike_size: int
        Spike size

    temporal_features: int
        Number of output features

    neighbors_matrix: numpy.ndarray (n_channels, n_channels)
        Boolean 2-D numpy array where the (i, j) entry is True if channel i
        is considered a neighbor of channel j

    channel_index: np.array (n_channels, n_neigh)
        Each row indexes the neighboring channels of a channel. For example,
        channel_index[c] contains the indexes of the channels neighboring
        channel c (including c itself). Any value equal to n_channels is a
        placeholder used when a channel has fewer than n_neigh neighboring
        channels

    max_memory: str
        Max memory to use in each batch (e.g. 100MB, 1GB)

    gmm_params: dict
        Dictionary with the parameters of the Gaussian mixture model

    output_path: str, optional
        Directory to store the scores and rotation matrix, if None, previous
        results on disk are ignored, operations are computed and results
        aren't saved to disk

    scores_filename: str, optional
        File name for the scores; if False, data is not saved

    rotation_matrix_filename: str, optional
        File name for the rotation matrix; if False, data is not saved

    spike_index_clear_filename: str, optional
        File name for spike index clear

    if_file_exists: str, optional
        What to do if there is already a file in the rotation matrix and/or
        scores location. One of 'overwrite', 'abort', 'skip'. If 'overwrite'
        it replaces the file if it exists, if 'abort' it raises a ValueError
        exception if the file exists, if 'skip' it skips the operation if the
        file exists

    Returns
    -------
    scores: numpy.ndarray
        Numpy 3D array of size (n_waveforms, n_reduced_features,
        n_neighboring_channels). Scores for every waveform; the second
        dimension is reduced from n_temporal_features to n_reduced_features,
        the third dimension depends on the number of neighboring channels

    rotation_matrix: numpy.ndarray
        3D array (window_size, n_features, n_channels)
    """

    ###########################
    # compute rotation matrix #
    ###########################

    bp = BatchProcessor(path_to_data,
                        dtype,
                        n_channels,
                        data_order,
                        max_memory,
                        buffer_size=spike_size)

    # compute WPCA
    WAVE, FEATURE, CH = 0, 1, 2

    logger.info('Performing WPCA')

    logger.info('Computing wavelets...')
    feature = bp.multi_channel_apply(wavedec,
                                     mode='memory',
                                     pass_batch_info=True,
                                     spike_index=spike_index,
                                     spike_size=spike_size,
                                     wvtype='haar')

    features = reduce(lambda x, y: np.concatenate((x, y)),
                      [f for f in feature])

    logger.info('Computing weights..')

    # Weighting the features using metric defined in gmtype
    weights = gmm_weight(features, gmm_params, spike_index)
    wfeatures = features * weights

    n_features = wfeatures.shape[FEATURE]
    wfeatures_lin = np.reshape(
        wfeatures, (wfeatures.shape[WAVE] * n_features, wfeatures.shape[CH]))
    feature_index = np.arange(0, wfeatures.shape[WAVE] * n_features,
                              n_features)

    TMP_FOLDER, _ = os.path.split(path_to_data)
    feature_path = os.path.join(TMP_FOLDER, 'features.bin')
    feature_params = writefile(wfeatures_lin, feature_path)

    bp_feat = BatchProcessor(feature_path,
                             feature_params['dtype'],
                             feature_params['n_channels'],
                             feature_params['data_order'],
                             max_memory,
                             buffer_size=n_features)

    # compute PCA sufficient statistics from extracted features

    logger.info('Computing PCA sufficient statistics...')
    stats = bp_feat.multi_channel_apply(suff_stat_features,
                                        mode='memory',
                                        pass_batch_info=True,
                                        spike_index=spike_index,
                                        spike_size=spike_size,
                                        feature_index=feature_index,
                                        feature_size=n_features)

    suff_stats = reduce(lambda x, y: np.add(x, y), [e[0] for e in stats])
    spikes_per_channel = reduce(lambda x, y: np.add(x, y),
                                [e[1] for e in stats])

    # compute PCA projection matrix
    logger.info('Computing PCA projection matrix...')
    rotation = project(suff_stats, spikes_per_channel, temporal_features,
                       neighbors_matrix)

    #####################################
    # waveform dimensionality reduction #
    #####################################

    logger.info('Reducing spikes dimensionality with PCA matrix...')

    # using a new BatchProcessor to read the feature file
    res = bp_feat.multi_channel_apply(score_features,
                                      mode='memory',
                                      pass_batch_info=True,
                                      rot=rotation,
                                      channel_index=channel_index,
                                      spike_index=spike_index,
                                      feature_index=feature_index)

    scores = np.concatenate([element[0] for element in res], axis=0)
    spike_index = np.concatenate([element[1] for element in res], axis=0)
    feature_index = np.concatenate([element[2] for element in res], axis=0)

    # renormalizing PC projections to similar unitary variance
    scores = st.zscore(scores, axis=0)

    # save scores
    if output_path and scores_filename:
        path_to_score = Path(output_path) / scores_filename
        save_numpy_object(scores,
                          path_to_score,
                          if_file_exists=if_file_exists,
                          name='scores')

    if output_path and spike_index_clear_filename:
        path_to_spike_index = Path(output_path) / spike_index_clear_filename
        save_numpy_object(spike_index,
                          path_to_spike_index,
                          if_file_exists=if_file_exists,
                          name='Spike index PCA')

    if output_path and rotation_matrix_filename:
        path_to_rotation = Path(output_path) / rotation_matrix_filename
        save_numpy_object(rotation,
                          path_to_rotation,
                          if_file_exists=if_file_exists,
                          name='rotation matrix')

    return scores, spike_index, rotation
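
The st.zscore call used to renormalize the projections is equivalent to subtracting the per-column mean and dividing by the per-column standard deviation; a quick standalone check:

import numpy as np
import scipy.stats as st

rng = np.random.RandomState(0)
scores = rng.randn(200, 5)

manual = (scores - scores.mean(axis=0)) / scores.std(axis=0)
np.testing.assert_allclose(st.zscore(scores, axis=0), manual)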
Example No. 13
    """Add one to every element in the batch
    """
    return np.max(batch, axis=0)


# create batch processor for the data
bp = BatchProcessor(path_to_neuropixel_data,
                    dtype='int16', n_channels=385, data_format='wide',
                    max_memory='10MB')

# apply a multi channel transformation, each batch will be a temporal
# subset with observations from all selected n_channels, the size
# of the subset is calculated depending on max_memory. Results
# from every batch are returned in a list
res = bp.multi_channel_apply(max_in_channel,
                             mode='memory',
                             channels=[0, 1, 2])


# we have one element per batch
len(res)


# output for the first batch
res[0]


# stack results from every batch
arr = np.stack(res, axis=0)
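
Since max_in_channel returns one value per selected channel per batch, arr has shape (n_batches, n_selected_channels); the overall per-channel maximum still needs one more reduction:

# reduce the per-batch maxima to a single maximum per channel
global_max = arr.max(axis=0)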

Example No. 14
def run(output_directory='tmp/'):
    """Execute preprocessing pipeline

    Parameters
    ----------
    output_directory: str, optional
      Location to store partial results, relative to CONFIG.data.root_folder,
      defaults to tmp/

    Returns
    -------
    clear_scores: numpy.ndarray (n_spikes, n_features, n_channels)
        3D array with the scores for the clear spikes, first dimension is
        the number of spikes, second is the number of features and third the
        number of channels

    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        2D array with indexes for clear spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        2D array with indexes for collided spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings
    * ``filtered.yaml`` - Filtered recordings metadata
    * ``standarized.bin`` - Standarized recordings
    * ``standarized.yaml`` - Standarized recordings metadata
    * ``whitened.bin`` - Whitened recordings
    * ``whitened.yaml`` - Whitened recordings metadata
    * ``rotation.npy`` - Rotation matrix for dimensionality reduction
    * ``spike_index_clear.npy`` - Same as spike_index_clear returned
    * ``spike_index_collision.npy`` - Same as spike_index_collision returned
    * ``score_clear.npy`` - Scores for clear spikes
    * ``waveforms_clear.npy`` - Waveforms for clear spikes

    Examples
    --------

    .. literalinclude:: ../examples/preprocess.py
    """

    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    OUTPUT_DTYPE = CONFIG.preprocess.dtype

    logger.info(
        'Output dtype for transformed data will be {}'.format(OUTPUT_DTYPE))

    TMP = os.path.join(CONFIG.data.root_folder, output_directory)

    if not os.path.exists(TMP):
        logger.info('Creating temporary folder: {}'.format(TMP))
        os.makedirs(TMP)
    else:
        logger.info('Temporary folder {} already exists, output will be '
                    'stored there'.format(TMP))

    path = os.path.join(CONFIG.data.root_folder, CONFIG.data.recordings)
    dtype = CONFIG.recordings.dtype

    # initialize pipeline object, one batch per channel
    pipeline = BatchPipeline(path, dtype, CONFIG.recordings.n_channels,
                             CONFIG.recordings.format,
                             CONFIG.resources.max_memory, TMP)

    # add filter transformation if necessary
    if CONFIG.preprocess.filter:
        filter_op = Transform(butterworth,
                              'filtered.bin',
                              mode='single_channel_one_batch',
                              keep=True,
                              if_file_exists='skip',
                              cast_dtype=OUTPUT_DTYPE,
                              low_freq=CONFIG.filter.low_pass_freq,
                              high_factor=CONFIG.filter.high_factor,
                              order=CONFIG.filter.order,
                              sampling_freq=CONFIG.recordings.sampling_rate)

        pipeline.add([filter_op])

    (filtered_path, ), (filtered_params, ) = pipeline.run()

    # standarize
    bp = BatchProcessor(filtered_path, filtered_params['dtype'],
                        filtered_params['n_channels'],
                        filtered_params['data_format'],
                        CONFIG.resources.max_memory)
    batches = bp.multi_channel()
    first_batch, _, _ = next(batches)
    sd = standard_deviation(first_batch, CONFIG.recordings.sampling_rate)

    (standarized_path, standarized_params) = bp.multi_channel_apply(
        standarize,
        mode='disk',
        output_path=os.path.join(TMP, 'standarized.bin'),
        if_file_exists='skip',
        cast_dtype=OUTPUT_DTYPE,
        sd=sd)

    standarized = RecordingsReader(standarized_path)
    n_observations = standarized.observations

    if CONFIG.spikes.detection == 'threshold':
        return _threshold_detection(standarized_path, standarized_params,
                                    n_observations, output_directory)
    elif CONFIG.spikes.detection == 'nn':
        return _neural_network_detection(standarized_path, standarized_params,
                                         n_observations, output_directory)
Example No. 15
def standarize(path_to_data, dtype, n_channels, data_order,
               sampling_frequency, max_memory, output_path,
               output_dtype, output_filename='standarized.bin',
               if_file_exists='skip', processes='max'):
    """
    Standarize recordings in batches and write results to disk. Standard
    deviation is estimated using the first batch

    Parameters
    ----------
    path_to_data: str
        Path to recordings in binary format

    dtype: str
        Recordings dtype

    n_channels: int
        Number of channels in the recordings

    data_order: str
        Recordings order, one of ('channels', 'samples'). In a dataset with k
        observations per channel and j channels: 'channels' means the first k
        contiguous observations come from channel 0, then channel 1, and so
        on; 'samples' means the first j contiguous values are the first
        observations from all channels, then the second observations from all
        channels, and so on

    sampling_frequency: int
        Recordings sampling frequency in Hz

    max_memory: str
        Max memory to use in each batch (e.g. 100MB, 1GB)

    output_path: str
        Where to store the standarized recordings

    output_dtype: str
        dtype for standarized data

    output_filename: str, optional
        Filename for the output data, defaults to standarized.bin

    if_file_exists: str, optional
        One of 'overwrite', 'abort', 'skip'. If 'overwrite' it replaces the
        standarized data if it exists, if 'abort' it raises a ValueError
        exception if the file exists, if 'skip' it skips the operation if the
        file exists

    processes: str or int, optional
        Number of processes to use; if 'max', it uses all cores in the
        machine, if a number, it uses that number of cores

    Returns
    -------
    standarized_path: str
        Path to standarized recordings

    standarized_params: dict
        A dictionary with the parameters for the standarized recordings
        (dtype, n_channels, data_order)
    """
    processes = multiprocess.cpu_count() if processes == 'max' else processes
    _output_path = os.path.join(output_path, output_filename)

    # init batch processor
    bp = BatchProcessor(path_to_data, dtype, n_channels, data_order,
                        max_memory)

    sd = standard_deviation(bp, sampling_frequency)

    def divide(rec):
        return np.divide(rec, sd)

    # apply transformation
    (standarized_path,
     standarized_params) = bp.multi_channel_apply(divide,
                                                  mode='disk',
                                                  output_path=_output_path,
                                                  cast_dtype=output_dtype,
                                                  processes=processes)

    return standarized_path, standarized_params
Example No. 16
def _neural_network_detection(standarized_path, standarized_params,
                              n_observations, output_directory):
    """Run neural network detection and autoencoder dimensionality reduction
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    OUTPUT_DTYPE = CONFIG.preprocess.dtype
    TMP_FOLDER = os.path.join(CONFIG.data.root_folder, output_directory)

    # detect spikes
    bp = BatchProcessor(standarized_path,
                        standarized_params['dtype'],
                        standarized_params['n_channels'],
                        standarized_params['data_format'],
                        CONFIG.resources.max_memory,
                        buffer_size=0)

    # check if all scores, clear and collision spikes exist..
    path_to_score = os.path.join(TMP_FOLDER, 'score_clear.npy')
    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')
    path_to_spike_index_collision = os.path.join(TMP_FOLDER,
                                                 'spike_index_collision.npy')

    if all([
            os.path.exists(path_to_score),
            os.path.exists(path_to_spike_index_clear),
            os.path.exists(path_to_spike_index_collision)
    ]):
        logger.info('Loading "{}", "{}" and "{}"'.format(
            path_to_score, path_to_spike_index_clear,
            path_to_spike_index_collision))

        scores = np.load(path_to_score)
        clear = np.load(path_to_spike_index_clear)
        collision = np.load(path_to_spike_index_collision)

    else:
        logger.info('One or more of "{}", "{}" or "{}" files were missing, '
                    'computing...'.format(path_to_score,
                                          path_to_spike_index_clear,
                                          path_to_spike_index_collision))

        # apply threshold detector on standarized data
        autoencoder_filename = CONFIG.neural_network_autoencoder.filename
        mc = bp.multi_channel_apply
        res = mc(
            neuralnetwork.nn_detection,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            neighbors=CONFIG.neighChannels,
            geom=CONFIG.geom,
            temporal_features=CONFIG.spikes.temporal_features,
            # FIXME: what is this?
            temporal_window=3,
            th_detect=CONFIG.neural_network_detector.threshold_spike,
            th_triage=CONFIG.neural_network_triage.threshold_collision,
            detector_filename=CONFIG.neural_network_detector.filename,
            autoencoder_filename=autoencoder_filename,
            triage_filename=CONFIG.neural_network_triage.filename)

        # save clear spikes
        clear = np.concatenate([element[1] for element in res], axis=0)
        logger.info('Removing clear indexes outside the allowed range to '
                    'draw a complete waveform...')
        clear, idx = detect.remove_incomplete_waveforms(
            clear, CONFIG.spikeSize + CONFIG.templatesMaxShift, n_observations)
        np.save(path_to_spike_index_clear, clear)
        logger.info('Saved spike index clear in {}...'.format(
            path_to_spike_index_clear))

        # save collided spikes
        collision = np.concatenate([element[2] for element in res], axis=0)
        logger.info('Removing collision indexes outside the allowed range to '
                    'draw a complete waveform...')
        collision, _ = detect.remove_incomplete_waveforms(
            collision, CONFIG.spikeSize + CONFIG.templatesMaxShift,
            n_observations)
        np.save(path_to_spike_index_collision, collision)
        logger.info('Saved spike index collision in {}...'.format(
            path_to_spike_index_collision))

        if CONFIG.clustering.clustering_method == 'location':
            #######################
            # Waveform extraction #
            #######################

            # TODO: what should the behaviour be for spike indexes that occur
            # at the start/end of the recordings, where it is not possible to
            # draw a complete waveform?
            logger.info('Computing whitening matrix...')
            bp = BatchProcessor(standarized_path, standarized_params['dtype'],
                                standarized_params['n_channels'],
                                standarized_params['data_format'],
                                CONFIG.resources.max_memory)
            batches = bp.multi_channel()
            first_batch, _, _ = next(batches)
            Q = whiten.matrix(first_batch, CONFIG.neighChannels,
                              CONFIG.spikeSize)

            path_to_whitening_matrix = os.path.join(TMP_FOLDER,
                                                    'whitening.npy')
            np.save(path_to_whitening_matrix, Q)
            logger.info('Saved whitening matrix in {}'.format(
                path_to_whitening_matrix))

            # apply whitening to every batch
            (whitened_path, whitened_params) = bp.multi_channel_apply(
                np.matmul,
                mode='disk',
                output_path=os.path.join(TMP_FOLDER, 'whitened.bin'),
                if_file_exists='skip',
                cast_dtype=OUTPUT_DTYPE,
                b=Q)

            # load and dump waveforms from clear spikes

            path_to_waveforms_clear = os.path.join(TMP_FOLDER,
                                                   'waveforms_clear.npy')

            if os.path.exists(path_to_waveforms_clear):
                logger.info(
                    'Found clear waveforms in {}, loading them...'.format(
                        path_to_waveforms_clear))
                waveforms_clear = np.load(path_to_waveforms_clear)
            else:
                logger.info(
                    'Did not find clear waveforms in {}, reading them from {}'.
                    format(path_to_waveforms_clear, whitened_path))
                explorer = RecordingExplorer(whitened_path,
                                             spike_size=CONFIG.spikeSize)
                waveforms_clear = explorer.read_waveforms(clear[:, 0], 'all')
                np.save(path_to_waveforms_clear, waveforms_clear)
                logger.info('Saved waveform from clear spikes in: {}'.format(
                    path_to_waveforms_clear))

            main_channel = clear[:, 1]

            # save rotation
            detector_filename = CONFIG.neural_network_detector.filename
            autoencoder_filename = CONFIG.neural_network_autoencoder.filename
            rotation = neuralnetwork.load_rotation(detector_filename,
                                                   autoencoder_filename)
            path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')
            logger.info("rotation_matrix_shape = {}".format(rotation.shape))
            np.save(path_to_rotation, rotation)
            logger.info(
                'Saved rotation matrix in {}...'.format(path_to_rotation))

            logger.info('Denoising...')
            path_to_denoised_waveforms = os.path.join(
                TMP_FOLDER, 'denoised_waveforms.npy')
            if os.path.exists(path_to_denoised_waveforms):
                logger.info(
                    'Found denoised waveforms in {}, loading them...'.format(
                        path_to_denoised_waveforms))
                denoised_waveforms = np.load(path_to_denoised_waveforms)
            else:
                logger.info(
                    'Did not find denoised waveforms in {}, computing them '
                    'from {}'.format(path_to_denoised_waveforms,
                                     path_to_waveforms_clear))
                waveforms_clear = np.load(path_to_waveforms_clear)
                denoised_waveforms = dim_red.denoise(waveforms_clear, rotation,
                                                     CONFIG)
                logger.info('Saving denoised waveforms to {}'.format(
                    path_to_denoised_waveforms))
                np.save(path_to_denoised_waveforms, denoised_waveforms)

            isolated_index, x, y = get_isolated_spikes_and_locations(
                denoised_waveforms, main_channel, CONFIG)
            x = (x - np.mean(x)) / np.std(x)
            y = (y - np.mean(y)) / np.std(y)
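            # spikes whose denoised waveform is not isolated are moved to the
            # collision set below; only isolated spikes keep a location/score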
            corrupted_index = np.logical_not(
                np.in1d(np.arange(clear.shape[0]), isolated_index))
            collision = np.concatenate([collision, clear[corrupted_index]],
                                       axis=0)
            clear = clear[isolated_index]
            waveforms_clear = waveforms_clear[isolated_index]
            #################################################
            # Dimensionality reduction (Isolated Waveforms) #
            #################################################

            scores = dim_red.main_channel_scores(waveforms_clear, rotation,
                                                 clear, CONFIG)
            scores = (scores - np.mean(scores, axis=0)) / np.std(scores)
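            # the final scores stack the standardized (x, y) location
            # estimates with the main-channel features along axis 1, giving an
            # array of shape (n_spikes, n_features + 2, 1)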
            scores = np.concatenate([
                x[:, np.newaxis, np.newaxis], y[:, np.newaxis, np.newaxis],
                scores[:, :, np.newaxis]
            ],
                                    axis=1)

        else:

            # save scores
            scores = np.concatenate([element[0] for element in res], axis=0)

            logger.info(
                'Removing scores for indexes outside the allowed range to '
                'draw a complete waveform...')
            scores = scores[idx]

            # compute Q for whitening
            logger.info('Computing whitening matrix...')
            bp = BatchProcessor(standarized_path, standarized_params['dtype'],
                                standarized_params['n_channels'],
                                standarized_params['data_format'],
                                CONFIG.resources.max_memory)
            batches = bp.multi_channel()
            first_batch, _, _ = next(batches)
            Q = whiten.matrix_localized(first_batch, CONFIG.neighChannels,
                                        CONFIG.geom, CONFIG.spikeSize)

            path_to_whitening_matrix = os.path.join(TMP_FOLDER,
                                                    'whitening.npy')
            np.save(path_to_whitening_matrix, Q)
            logger.info('Saved whitening matrix in {}'.format(
                path_to_whitening_matrix))

            scores = whiten.score(scores, clear[:, 1], Q)

            np.save(path_to_score, scores)
            logger.info('Saved spike scores in {}...'.format(path_to_score))

            # save rotation
            detector_filename = CONFIG.neural_network_detector.filename
            autoencoder_filename = CONFIG.neural_network_autoencoder.filename
            rotation = neuralnetwork.load_rotation(detector_filename,
                                                   autoencoder_filename)
            path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')
            np.save(path_to_rotation, rotation)
            logger.info(
                'Saved rotation matrix in {}...'.format(path_to_rotation))

        np.save(path_to_score, scores)
        logger.info('Saved spike scores in {}...'.format(path_to_score))
    return scores, clear, collision
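
A note on the arrays returned above (a hedged sketch, not part of the original
example): clear and collision are two-column spike indexes, with column 0
holding the spike time in samples and column 1 the main channel, while scores
holds one feature array per clear spike. Assuming standarized_path,
standarized_params and n_observations come from the preprocessing step, a
hypothetical caller could unpack the result like this:

# 'tmp/' is an illustrative output directory
scores, clear, collision = _neural_network_detection(
    standarized_path, standarized_params, n_observations, 'tmp/')

spike_times = clear[:, 0]     # sample index of each clear spike
main_channels = clear[:, 1]   # channel on which each spike was detected
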
Example No. 17
0
def test_splitting_in_batches_does_not_affect(path_to_tests,
                                              path_to_standarized_data,
                                              path_to_sample_pipeline_folder):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standarized_data

    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    with open(
            path.join(path_to_sample_pipeline_folder, 'preprocess',
                      'standarized.yaml')) as f:
        PARAMS = yaml.safe_load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th, channel_index)
    NNT = NeuralNetTriage.load(triage_fname,
                               triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)
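    # (score_tf: autoencoder feature scores, spike_index_tf: detector spike
    # indexes, idx_clean: triage output identifying clean spikes)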

    # run all at once
    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        (scores, clear, collision) = neuralnetwork.run_detect_triage_featurize(
            data, sess, NND.x_tf, output_tf, neighbors, rot)

    # run in batches - buffer size makes sure we can detect spikes if they
    # appear at the end of any batch
    bp = BatchProcessor(PATH_TO_DATA,
                        PARAMS['dtype'],
                        PARAMS['n_channels'],
                        PARAMS['data_order'],
                        '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=NND.x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res], axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)