Example #1
def test_can_use_neural_network_detector(path_to_tests):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    data = RecordingsReader(path.join(path_to_tests, 'data/standarized.bin'),
                            loader='array').data

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    whiten_filter = np.tile(
        np.eye(channel_index.shape[1], dtype='float32')[np.newaxis, :, :],
        [channel_index.shape[0], 1, 1])

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    (x_tf, output_tf, NND, NNAE,
     NNT) = neuralnetwork.prepare_nn(channel_index, whiten_filter,
                                     detection_th, triage_th, detection_fname,
                                     ae_fname, triage_fname)

    with tf.Session() as sess:
        # get values of above tensors
        NND.saver.restore(sess, NND.path_to_detector_model)
        NNAE.saver_ae.restore(sess, NNAE.path_to_ae_model)
        NNT.saver.restore(sess, NNT.path_to_triage_model)
        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        neuralnetwork.run_detect_triage_featurize(data, sess, x_tf, output_tf,
                                                  neighbors, rot)
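Every example in this listing calls n_steps_neigh_channels on a boolean channel-adjacency matrix (CONFIG.neigh_channels). As a rough point of reference, the sketch below shows what a 2-step neighbor expansion could look like for such a matrix; it is only an illustration of the idea, not the YASS implementation, and the helper name is made up.

import numpy as np

def two_step_neighbors_sketch(neigh_channels):
    # illustrative only: treat channels as 2-step neighbors when they are
    # direct neighbors or share at least one common neighbor
    adj = neigh_channels.astype(int)
    expanded = (adj + adj @ adj) > 0
    np.fill_diagonal(expanded, True)
    return expanded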
Example #2
def test_can_use_neural_network_detector(path_to_tests,
                                         path_to_standarized_data):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    data = RecordingsReader(path_to_standarized_data, loader='array').data

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th, channel_index)
    NNT = NeuralNetTriage.load(triage_fname,
                               triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    with tf.Session() as sess:
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        neuralnetwork.run_detect_triage_featurize(data, sess, NND.x_tf,
                                                  output_tf, neighbors, rot)
Example #3
def pc_feature_ind(n_spikes, n_templates, n_channels, geom, neigh_channels,
                   spike_train, templates):
    """
    pc_feature_ind.npy - [nTemplates, nPCFeatures] uint32 matrix specifying
    which pcFeatures are included in the pc_features matrix.
    """

    # get main channel for each template
    templates_mainc = np.argmax(np.max(templates, axis=1), axis=0)

    # main channel for each spike based on templates_mainc
    spikes_mainc = np.zeros(n_spikes, 'int32')

    for j in range(n_spikes):
        spikes_mainc[j] = templates_mainc[spike_train[j, 1]]

    # number of neighbors to consider
    neighbors = n_steps_neigh_channels(neigh_channels, 2)
    nneigh = np.max(np.sum(neighbors, 0))

    # ordered neighboring channels w.r.t. each channel
    c_idx = np.zeros((n_channels, nneigh), 'int32')

    for c in range(n_channels):
        c_idx[c] = (np.argsort(np.sum(np.square(geom - geom[c]), axis=1))
                    [:nneigh])

    pc_feature_ind = np.zeros((n_templates, nneigh), 'int32')

    for k in range(n_templates):
        pc_feature_ind[k] = c_idx[templates_mainc[k]]

    return pc_feature_ind
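The c_idx block above orders, for every channel, the nneigh geometrically closest channels. A self-contained sketch of that ordering on a made-up 4-channel linear probe:

import numpy as np

geom = np.array([[0., 0.], [0., 20.], [0., 40.], [0., 60.]])  # toy geometry
nneigh = 3

c_idx = np.zeros((len(geom), nneigh), 'int32')
for c in range(len(geom)):
    dist2 = np.sum(np.square(geom - geom[c]), axis=1)
    c_idx[c] = np.argsort(dist2)[:nneigh]

# c_idx[0] is [0, 1, 2]: channel 0 itself, then its two nearest channels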
Example #4
def matrix(ts, neighbors, spike_size):
    """
    Compute spatial whitening matrix for time series, used only in threshold
    detection

    Parameters
    ----------
    ts: numpy.ndarray (n_observations, n_channels)
        Recordings
    neighbors: numpy.ndarray (n_channels, n_channels)
        Boolean numpy 2-D array where an (i, j) entry is True if i is
        considered a neighbor of j
    spike_size: int
        Spike size

    Returns
    -------
    numpy.ndarray (n_channels, n_channels)
        whitening matrix
    """
    # get all necessary parameters from param
    [T, C] = ts.shape
    R = spike_size * 2 + 1
    th = 4
    neighChannels = n_steps_neigh_channels(neighbors, steps=2)

    chanRange = np.arange(0, C)
    spikes_rec = np.ones(ts.shape)

    for i in range(0, C):
        idxCrossing = np.where(ts[:, i] < -th)[0]
        idxCrossing = idxCrossing[np.logical_and(idxCrossing >= (R + 1),
                                                 idxCrossing <= (T - R - 1))]
        spike_time = idxCrossing[np.logical_and(
            ts[idxCrossing, i] <= ts[idxCrossing - 1, i],
            ts[idxCrossing, i] <= ts[idxCrossing + 1, i])]

        # the portion of the recording where spikes are present is masked out
        for j in np.arange(-spike_size, spike_size + 1):
            spikes_rec[spike_time + j, i] = 0

    blanked_rec = ts * spikes_rec
    M = np.matmul(blanked_rec.transpose(), blanked_rec) / \
        np.matmul(spikes_rec.transpose(), spikes_rec)
    invhalf_var = np.diag(np.power(np.diag(M), -0.5))
    M = np.matmul(np.matmul(invhalf_var, M), invhalf_var)
    Q = np.zeros((C, C))

    for c in range(0, C):
        ch_idx = chanRange[neighChannels[c, :]]
        V, D, _ = np.linalg.svd(M[ch_idx, :][:, ch_idx])
        eps = 1e-6
        Epsilon = np.diag(1 / np.power((D + eps), 0.5))
        Q_small = np.matmul(np.matmul(V, Epsilon), V.transpose())
        Q[c, ch_idx] = Q_small[ch_idx == c, :]

    return Q.transpose()
Example #5
def whitening(ts, neighbors, spike_size):
    """Spatial whitening filter for time series
    Parameters
    ----------
    ts: np.array
        T x C numpy array, where T is the number of time samples and
        C is the number of channels
    """
    # get all necessary parameters from param
    [T, C] = ts.shape
    R = spike_size*2 + 1
    th = 4
    neighChannels = n_steps_neigh_channels(neighbors, steps=2)

    chanRange = np.arange(0, C)
    # timeRange = np.arange(0, T)
    # masked recording
    spikes_rec = np.ones(ts.shape)

    for i in range(0, C):
        # idxCrossing = timeRange[ts[:, i] < -th[i]]
        idxCrossing = np.where(ts[:, i] < -th)[0]
        idxCrossing = idxCrossing[np.logical_and(
            idxCrossing >= (R+1), idxCrossing <= (T-R-1))]
        spike_time = idxCrossing[np.logical_and(ts[idxCrossing, i] <=
                                                ts[idxCrossing-1, i],
                                                ts[idxCrossing, i] <=
                                                ts[idxCrossing+1, i])]

        # the portion of the recording where spikes are present is masked out
        for j in np.arange(-spike_size, spike_size+1):
            spikes_rec[spike_time + j, i] = 0

    blanked_rec = ts*spikes_rec
    M = np.matmul(blanked_rec.transpose(), blanked_rec) / \
        np.matmul(spikes_rec.transpose(), spikes_rec)
    invhalf_var = np.diag(np.power(np.diag(M), -0.5))
    M = np.matmul(np.matmul(invhalf_var, M), invhalf_var)
    Q = np.zeros((C, C))

    for c in range(0, C):
        ch_idx = chanRange[neighChannels[c, :]]
        V, D, _ = np.linalg.svd(M[ch_idx, :][:, ch_idx])
        Epsilon = np.diag(1/np.power((D), 0.5))
        Q_small = np.matmul(np.matmul(V, Epsilon), V.transpose())
        Q[c, ch_idx] = Q_small[ch_idx == c, :]

    return np.matmul(ts, Q.transpose())
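A hypothetical call into the matrix and whitening helpers above, assuming they and n_steps_neigh_channels are importable from YASS; the geometry, radius, spike size and recording are synthetic and chosen only for illustration.

import numpy as np

# toy geometry: 16 channels on a line, 20 um apart
geom = np.stack([np.zeros(16), np.arange(16) * 20.0], axis=1)

# boolean adjacency: channels within 50 um of each other are neighbors
dists = np.linalg.norm(geom[:, None] - geom[None, :], axis=2)
neighbors = dists <= 50

ts = np.random.randn(20000, 16).astype('float32')   # toy standardized recording

Q = matrix(ts, neighbors, spike_size=15)             # spatial whitening matrix
whitened = whitening(ts, neighbors, spike_size=15)   # roughly ts @ Q.T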
Example #6
def covariance(recordings, temporal_size, neigbor_steps):
    """Compute noise spatial and temporal covariance

    Parameters
    ----------
    recordings: matrix
        Multi-channel recordings (n observations x n channels)
    temporal_size:

    neigbor_steps: int
        Number of steps from the multi-channel geometry to consider two
        channels as neighbors
    """
    CONFIG = read_config()

    # get the neighbor channels at a max "neigbor_steps" steps
    neigh_channels = n_steps_neigh_channels(CONFIG.neighChannels,
                                            neigbor_steps)

    # sum neighbor flags for every channel, this gives the number of neighbors
    # per channel, then find the channel with the most neighbors
    # TODO: why are we selecting this one?
    channel = np.argmax(np.sum(neigh_channels, 0))

    # get the neighbor channels for "channel"
    (neighbords_idx, ) = np.where(neigh_channels[channel])

    # order neighbors by distance
    neighbords_idx, temp = order_channels_by_distance(channel, neighbords_idx,
                                                      CONFIG.geom)

    # from the multi-channel recordings, get the neighbor channels
    # (this includes the channel with the most neighbors itself)
    rec = recordings[:, neighbords_idx]

    # filter recording
    if CONFIG.preprocess.filter == 1:
        rec = butterworth(rec, CONFIG.filter.low_pass_freq,
                          CONFIG.filter.high_factor, CONFIG.filter.order,
                          CONFIG.recordings.sampling_rate)

    # standardize recording
    sd_ = standarize.sd(rec, CONFIG.recordings.sampling_rate)
    rec = standarize.standarize(rec, sd_)

    # compute and return spatial and temporal covariance
    return util.covariance(rec, temporal_size, neigbor_steps, CONFIG.spikeSize)
Example #7
def template_features(n_spikes, n_channels, n_templates, spike_train,
                      templates_main_channel, neigh_channels,
                      geom, templates_low_dim, template_feature_ind_,
                      waveforms_score):
    """
    template_features.npy - [nSpikes, nTempFeatures] single matrix giving the
    magnitude of the projection of each spike onto nTempFeatures other
    features; the features used are specified in template_feature_ind.npy
    """
    # TODO: fix this, assume you can receive the waveforms from the spike
    # train as a parameter and templates, main channel for templates templates
    # score and waveforms score for every waveform in the spike train, for
    # scoring other waveforms, use function in dimensionality_reduction.score

    k_neigh = np.min((5, n_templates))

    template_features_ = np.zeros((n_spikes, k_neigh))

    spikes_mainc = np.zeros(n_spikes, 'int32')

    for j in range(n_spikes):
        spikes_mainc[j] = templates_main_channel[spike_train[j, 1]]

    # number of neighbors to consider
    neighbors = n_steps_neigh_channels(neigh_channels, 2)
    nneigh = np.max(np.sum(neighbors, 0))

    # ordered neighboring channels w.r.t. each channel
    c_idx = np.zeros((n_channels, nneigh), 'int32')

    for c in range(n_channels):
        c_idx[c] = (np.argsort(np.sum(np.square(geom - geom[c]), axis=1))
                    [:nneigh])

    for j in range(n_spikes):

        ch_idx = c_idx[spikes_mainc[j]]
        kk = spike_train[j, 1]

        for k in range(k_neigh):
            template_features_[j, k] = np.sum(
                np.multiply(waveforms_score[j].T,
                            templates_low_dim[ch_idx]
                            [:, :, template_feature_ind_[kk, k]]))

    return template_features_
Example #8
def run_deduplication(batch_files_dir, output_directory):

    CONFIG = read_config()

    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)
    w = 5

    batch_ids = list(np.arange(len(os.listdir(batch_files_dir))))
    if CONFIG.resources.multi_processing:
        #if False:
        parmap.map(run_deduplication_batch_simple,
                   batch_ids,
                   batch_files_dir,
                   output_directory,
                   neighbors,
                   w,
                   processes=CONFIG.resources.n_processors,
                   pm_pbar=True)
    else:
        for batch_id in batch_ids:
            run_deduplication_batch_simple(batch_id, batch_files_dir,
                                           output_directory, neighbors, w)
Example #9
def test_can_compute_n_steps_neighbors(data_info, path_to_geometry):
    geometry = parse(path_to_geometry, data_info['n_channels'])
    neighbors = find_channel_neighbors(geometry, radius=70)
    n_steps_neigh_channels(neighbors, steps=2)
Example #10
def run_neural_network(standardized_path, standardized_params, whiten_filter,
                       output_directory, if_file_exists, save_results):
    """Run neural network detection and autoencoder dimensionality reduction

    Returns
    -------
    scores
      Scores for all spikes

    spike_index_clear
      Spike indexes for clear spikes

    spike_index_all
      Spike indexes for all spikes
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # check if all scores, clear and collision spikes exist..
    path_to_score = os.path.join(TMP_FOLDER, 'scores_clear.npy')
    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')
    path_to_spike_index_all = os.path.join(TMP_FOLDER, 'spike_index_all.npy')
    path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')

    path_to_standardized = os.path.join(TMP_FOLDER, 'preprocess',
                                        'standarized.bin')

    paths = [path_to_score, path_to_spike_index_clear, path_to_spike_index_all]
    exists = [os.path.exists(p) for p in paths]

    if (if_file_exists == 'overwrite' or not any(exists)):

        max_memory = (CONFIG.resources.max_memory_gpu
                      if GPU_ENABLED else CONFIG.resources.max_memory)

        # make tensorflow tensors and neural net classes
        detection_th = CONFIG.detect.neural_network_detector.threshold_spike
        triage_th = CONFIG.detect.neural_network_triage.threshold_collision
        detection_fname = CONFIG.detect.neural_network_detector.filename
        ae_fname = CONFIG.detect.neural_network_autoencoder.filename
        triage_fname = CONFIG.detect.neural_network_triage.filename
        n_channels = CONFIG.recordings.n_channels

        # open tensorflow for every chunk
        NND = NeuralNetDetector.load(detection_fname, detection_th,
                                     CONFIG.channel_index)
        NNAE = AutoEncoder.load(ae_fname, input_tensor=NND.waveform_tf)

        # run nn preprocess batch-wise
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        # compute len of recording
        filename_dat = os.path.join(CONFIG.data.root_folder,
                                    CONFIG.data.recordings)
        fp = np.memmap(filename_dat, dtype='int16', mode='r')
        fp_len = fp.shape[0] / n_channels

        # compute batch indexes
        buffer_size = 200  # Cat: to set this in CONFIG file
        sampling_rate = CONFIG.recordings.sampling_rate
        #n_sec_chunk = CONFIG.resources.n_sec_chunk

        # Cat: TODO: Set a different size chunk for clustering vs. detection
        n_sec_chunk = 60

        # take chunks
        indexes = np.arange(0, fp_len, sampling_rate * n_sec_chunk)

        # add last bit of recording if it's shorter
        if indexes[-1] != fp_len:
            indexes = np.hstack((indexes, fp_len))

        idx_list = []
        for k in range(len(indexes) - 1):
            idx_list.append([
                indexes[k], indexes[k + 1], buffer_size,
                indexes[k + 1] - indexes[k] + buffer_size
            ])

        idx_list = np.int64(np.vstack(idx_list))[:20]

        #idx_list = idx_list

        logger.info("# of chunks: %i", len(idx_list))

        logger.info(idx_list)

        # run tensorflow
        processing_ctr = 0
        #chunk_ctr = 0

        # total number of 1-sec sub-chunks to process across all chunks
        total_processing = len(idx_list) * n_sec_chunk

        # keep tensorflow open
        # save iteratively
        fname_detection = os.path.join(CONFIG.path_to_output_directory,
                                       'detect')
        if not os.path.exists(fname_detection):
            os.mkdir(fname_detection)

        # set tensorflow verbosity level
        tf.logging.set_verbosity(tf.logging.ERROR)

        # open tensorflow session
        with tf.Session() as sess:
            #K.set_session(sess)
            NND.restore(sess)

            #triage = KerasModel(triage_fname,
            #                    allow_longer_waveform_length=True,
            #                    allow_more_channels=True)

            # read chunks of data first:
            # read chunk of raw standardized data
            # Cat: TODO: don't save to lists, might want to use numpy arrays directly
            #print (os.path.join(fname_detection,"detect_"+
            #                      str(chunk_ctr).zfill(5)+'.npz'))

            # loop over 10sec or 60 sec chunks
            for chunk_ctr, idx in enumerate(idx_list):
                if os.path.exists(
                        os.path.join(
                            fname_detection,
                            "detect_" + str(chunk_ctr).zfill(5) + '.npz')):
                    continue

                # reset lists
                spike_index_list = []
                #idx_clean_list = []
                energy_list = []
                TC_list = []
                offset_list = []

                # load chunk of data
                standardized_recording = binary_reader(idx, buffer_size,
                                                       path_to_standardized,
                                                       n_channels,
                                                       CONFIG.data.root_folder)

                # run detection on smaller chunks of data, e.g. 1 sec
                # Cat: TODO: add last bit at end in case short
                indexes = np.arange(0, standardized_recording.shape[0],
                                    sampling_rate)

                # run tensorflow over 1sec chunks in general
                for ctr, index in enumerate(indexes[:-1]):

                    # save absolute index of each subchunk
                    offset_list.append(idx[0] + indexes[ctr])

                    data_temp = standardized_recording[
                        indexes[ctr]:indexes[ctr + 1]]

                    # store size of recordings in case at end of dataset.
                    TC_list.append(data_temp.shape)

                    # run detect nn
                    res = NND.predict_recording(data_temp,
                                                sess=sess,
                                                output_names=('spike_index',
                                                              'waveform'))
                    spike_index, wfs = res

                    #idx_clean = (triage
                    #             .predict_with_threshold(x=wfs,
                    #                                     threshold=triage_th))

                    score = NNAE.predict(wfs, sess)
                    rot = NNAE.load_rotation(sess)
                    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels,
                                                       2)

                    # idx_clean is the indexes of clear spikes in all_spikes
                    spike_index_list.append(spike_index)
                    #idx_clean_list.append(idx_clean)

                    # run AE nn; required for remove_axon function
                    # Cat: TODO: Do we really need this: can we get energy list faster?
                    #rot = NNAE.load_rotation()
                    energy_ = np.ptp(np.matmul(score[:, :, 0], rot.T), axis=1)
                    energy_list.append(energy_)

                    logger.info('processed chunk: %s/%s,  # spikes: %s',
                                str(processing_ctr), str(total_processing),
                                spike_index.shape)

                    processing_ctr += 1

                # save chunk of data in case crashes occur
                np.savez(os.path.join(fname_detection,
                                      "detect_" + str(chunk_ctr).zfill(5)),
                         spike_index_list=spike_index_list,
                         energy_list=energy_list,
                         TC_list=TC_list,
                         offset_list=offset_list)

        # load all saved data;
        spike_index_list = []
        energy_list = []
        TC_list = []
        offset_list = []
        for ctr, idx in enumerate(idx_list):
            data = np.load(fname_detection + '/detect_' + str(ctr).zfill(5) +
                           '.npz')
            spike_index_list.extend(data['spike_index_list'])
            energy_list.extend(data['energy_list'])
            TC_list.extend(data['TC_list'])
            offset_list.extend(data['offset_list'])

        # save all detected spikes pre axon_kill
        spike_index_all_pre_deduplication = fix_indexes_firstbatch_3(
            spike_index_list, offset_list, buffer_size, sampling_rate)
        np.save(
            os.path.join(TMP_FOLDER, 'spike_index_all_pre_deduplication.npy'),
            spike_index_all_pre_deduplication)

        # remove axons - compute axons to be killed
        logger.info(' removing axons in parallel')
        multi_procesing = 1
        if CONFIG.resources.multi_processing:
            keep = parmap.map(deduplicate,
                              list(
                                  zip(spike_index_list, energy_list, TC_list,
                                      np.arange(len(energy_list)))),
                              neighbors,
                              processes=CONFIG.resources.n_processors,
                              pm_pbar=True)
        else:
            keep = []
            for k in range(len(energy_list)):
                keep.append(
                    deduplicate(
                        (spike_index_list[k], energy_list[k], TC_list[k], k),
                        neighbors))

        # Cat: TODO Note that we're killing spike_index_all as well.
        # remove axons from clear spikes - keep only non-killed+clean events
        spike_index_all_postkill = []
        for k in range(len(spike_index_list)):
            spike_index_all_postkill.append(spike_index_list[k][keep[k][0]])

        logger.info(' fixing indexes from batching')
        spike_index_all = fix_indexes_firstbatch_3(spike_index_all_postkill,
                                                   offset_list, buffer_size,
                                                   sampling_rate)

        # get and clean all spikes
        spikes_all = spike_index_all

        #logger.info('Removing all indexes outside the allowed range to '
        #            'draw a complete waveform...')
        _n_observations = fp_len
        spikes_all, _ = detect.remove_incomplete_waveforms(
            spikes_all, CONFIG.spike_size + CONFIG.templates.max_shift,
            _n_observations)

        np.save(os.path.join(TMP_FOLDER, 'spike_index_all.npy'), spikes_all)

    else:
        spikes_all = np.load(os.path.join(TMP_FOLDER, 'spike_index_all.npy'))

    return spikes_all
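The idx_list construction above packs, for every chunk, [start, end, buffer size, chunk length + buffer]. A standalone sketch of the same indexing scheme with made-up numbers:

import numpy as np

fp_len = 250000          # toy recording length (samples)
sampling_rate = 20000
n_sec_chunk = 2
buffer_size = 200

indexes = np.arange(0, fp_len, sampling_rate * n_sec_chunk)
if indexes[-1] != fp_len:
    indexes = np.hstack((indexes, fp_len))

idx_list = []
for k in range(len(indexes) - 1):
    idx_list.append([indexes[k], indexes[k + 1], buffer_size,
                     indexes[k + 1] - indexes[k] + buffer_size])
idx_list = np.int64(np.vstack(idx_list))

# each row reads [chunk start, chunk end, buffer, chunk length + buffer]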
Example #11
def test_splitting_in_batches_does_not_affect(path_to_tests,
                                              path_to_standarized_data,
                                              path_to_sample_pipeline_folder):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standarized_data

    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    with open(
            path.join(path_to_sample_pipeline_folder, 'preprocess',
                      'standarized.yaml')) as f:
        PARAMS = yaml.load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th, channel_index)
    NNT = NeuralNetTriage.load(triage_fname,
                               triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    # run all at once
    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        (scores, clear, collision) = neuralnetwork.run_detect_triage_featurize(
            data, sess, NND.x_tf, output_tf, neighbors, rot)

    # run in batches - buffer size makes sure we can detect spikes if they
    # appear at the end of any batch
    bp = BatchProcessor(PATH_TO_DATA,
                        PARAMS['dtype'],
                        PARAMS['n_channels'],
                        PARAMS['data_order'],
                        '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=NND.x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res], axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)
Example #12
def _threshold(rec, neighbors, spike_size, threshold):
    """Run Threshold-based spike detection

    Parameters
    ----------
    rec: np.ndarray (n_observations, n_channels)
        numpy 2-D array with the recordings, first dimension must be
        n_observations and second n_channels

    neighbors: np.ndarray (n_channels, n_channels)
        Boolean numpy 2-D array where an (i, j) entry is True if i is
        considered a neighbor of j

    spike_size: int
        Spike size

    threshold: float
        Threshold used on amplitude for detection

    Notes
    -----
    Any value below -threshold is considered a spike and its location is
    saved and returned

    Returns
    -------
    index: np.ndarray (number of spikes, 2)
        First column is spike time, second column is main channel (the channel
        where the spike has the biggest amplitude)
    """
    T, C = rec.shape
    R = spike_size
    th = threshold
    neigh_channels_big = n_steps_neigh_channels(neighbors, steps=2)

    index = np.zeros((0, 2), 'int32')

    for c in range(C):
        # For each channel, mark down location where it crosses the threshold
        idx = np.logical_and(
            rec[:, c] < -th, np.r_[True, rec[1:, c] < rec[:-1, c]]
            & np.r_[rec[:-1, c] < rec[1:, c], True])
        nc = np.sum(idx)

        if nc > 0:
            # location where it crosses the threshold
            spt_c = np.where(idx)[0]

            # remove an index if it is too close to the edge of the recording
            spt_c = spt_c[np.logical_and(spt_c > 2 * R, spt_c < T - 2 * R)]
            nc = spt_c.shape[0]

            # get neighboring channels
            ch_idx = np.where(neigh_channels_big[c])[0]
            c_main = np.where(ch_idx == c)[0]

            # look at the temporal-spatial window around the spike location;
            # if the spike being looked at has the biggest amplitude within
            # its spatial and temporal window, keep it.
            # Otherwise, disregard it
            idx_keep = np.zeros(nc, 'bool')

            for j in range(nc):
                # get waveforms
                wf_temp = rec[spt_c[j] + np.arange(-2 * R, 2 * R + 1)][:,
                                                                       ch_idx]

                # location with the biggest amplitude: (t_min, c_min)
                c_min = np.argmin(np.amin(wf_temp, axis=0))
                t_min = np.argmin(wf_temp[:, c_min])

                if t_min == 2 * R and c_min == c_main:
                    idx_keep[j] = 1

            to_keep = spt_c[idx_keep]
            new_spikes = np.zeros((len(to_keep), 2), 'int32')
            new_spikes[:, 0] = to_keep
            new_spikes[:, 1] = c
            index = np.append(index, new_spikes, axis=0)

    return index
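A synthetic usage sketch for _threshold, assuming n_steps_neigh_channels is importable from YASS as in the snippet above; the recording contains a single injected negative deflection and the adjacency matrix is a toy identity (every channel is only its own neighbor).

import numpy as np

T, C, R = 2000, 4, 15
rec = np.random.randn(T, C).astype('float32') * 0.1
# inject one spike-like dip on channel 1, centered at t = 1000
rec[1000 - R:1000 + R + 1, 1] -= 8 * np.hanning(2 * R + 1)

neighbors = np.eye(C, dtype=bool)
index = _threshold(rec, neighbors, spike_size=R, threshold=4)
# expected: a single row close to [1000, 1]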
Example #13
def post_process(output_directory, fname_templates, fname_spike_train,
                 fname_weights, fname_recording, recording_dtype, units_in,
                 method, ctr):
    ''' 
    Run a single post process
    method: strings.
        Options are 'low_ptp', 'duplicate', 'collision',
        'high_mad', 'low_fr', 'high_fr', 'off_center',
        'duplicate_l2'
    '''

    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    if method == 'low_ptp':

        # Cat: TODO: move parameter to CONFIG
        threshold = CONFIG.clean_up.min_ptp

        # load templates
        templates = np.load(fname_templates)

        # remove low ptp
        units_out = remove_small_units(templates, threshold, units_in)

        logger.info("{} units after removing low ptp units".format(
            len(units_out)))

    elif method == 'off_center':

        threshold = CONFIG.clean_up.off_center

        # load templates
        templates = np.load(fname_templates)

        # remove off centered units
        units_out = remove_off_centered_units(templates, threshold, units_in)

        logger.info("{} units after removing off centered units".format(
            len(units_out)))

    elif method == 'duplicate':

        # tmp saving dir
        save_dir = os.path.join(output_directory, 'duplicates_{}'.format(ctr))

        # remove duplicates
        units_out = remove_duplicates(fname_templates, fname_weights, save_dir,
                                      CONFIG, units_in,
                                      CONFIG.resources.multi_processing,
                                      CONFIG.resources.n_processors)

        logger.info("{} units after removing duplicate units".format(
            len(units_out)))

    elif method == 'duplicate_l2':

        # tmp saving dir
        save_dir = os.path.join(output_directory,
                                'duplicates_l2_{}'.format(ctr))

        # remove duplicates
        n_spikes_big = 100
        min_ptp = 2
        units_out = duplicate_l2(fname_templates, fname_spike_train,
                                 CONFIG.neigh_channels, save_dir, n_spikes_big,
                                 min_ptp, units_in)

        logger.info("{} units after removing L2 duplicate units".format(
            len(units_out)))

    elif method == 'collision':
        # save folder
        save_dir = os.path.join(output_directory, 'collision_{}'.format(ctr))

        # find collision units and remove
        units_out = remove_collision(fname_templates, save_dir, CONFIG,
                                     units_in,
                                     CONFIG.resources.multi_processing,
                                     CONFIG.resources.n_processors)

        logger.info("{} units after removing collision units".format(
            len(units_out)))

    elif method == 'high_mad':

        # get data reader
        reader = READER(fname_recording, recording_dtype, CONFIG)

        # save folder
        save_dir = os.path.join(output_directory, 'mad_{}'.format(ctr))

        # neighboring channels
        neigh_channels = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        max_violations = CONFIG.clean_up.mad.max_violations
        min_var_gap = CONFIG.clean_up.mad.min_var_gap

        # find high mad units and remove
        units_out = remove_high_mad(fname_templates, fname_spike_train,
                                    fname_weights, reader, neigh_channels,
                                    save_dir, min_var_gap, max_violations,
                                    units_in,
                                    CONFIG.resources.multi_processing,
                                    CONFIG.resources.n_processors)

        logger.info("{} units after removing high mad units".format(
            len(units_out)))

    elif method == 'low_fr':

        threshold = CONFIG.clean_up.min_fr

        # length of recording in seconds
        rec_len = np.load(fname_spike_train)[:, 0].ptp()
        rec_len_sec = float(rec_len) / CONFIG.recordings.sampling_rate

        # load templates
        weights = np.load(fname_weights)

        # remove low ptp
        units_out = remove_low_fr_units(weights, rec_len_sec, threshold,
                                        units_in)

        logger.info("{} units after removing low fr units".format(
            len(units_out)))

    elif method == 'high_fr':

        # TODO: move parameter to config?
        threshold = 70

        # length of recording in seconds
        rec_len = np.load(fname_spike_train)[:, 0].ptp()
        rec_len_sec = float(rec_len) / CONFIG.recordings.sampling_rate

        # load templates
        weights = np.load(fname_weights)

        # remove low ptp
        units_out = remove_high_fr_units(weights, rec_len_sec, threshold,
                                         units_in)

        logger.info("{} units after removing high fr units".format(
            len(units_out)))

    else:
        units_out = np.copy(units_in)
        logger.info("Method not recognized. Nothing removed")

    return units_out
Example #14
def test_can_compute_n_steps_neighbors(path_to_geometry):
    geometry = parse(path_to_geometry, n_channels)
    neighbors = find_channel_neighbors(geometry, radius=70)
    n_steps_neigh_channels(neighbors, steps=2)
Example #15
def test_splitting_in_batches_does_not_affect_result(path_to_tests):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    PATH_TO_DATA = path.join(path_to_tests, 'data/standarized.bin')

    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    with open(path.join(path_to_tests, 'data/standarized.yaml')) as f:
        PARAMS = yaml.load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    whiten_filter = np.tile(
        np.eye(channel_index.shape[1], dtype='float32')[np.newaxis, :, :],
        [channel_index.shape[0], 1, 1])

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename
    (x_tf, output_tf, NND, NNAE, NNT) = neuralnetwork.prepare_nn(
        channel_index,
        whiten_filter,
        detection_th,
        triage_th,
        detection_fname,
        ae_fname,
        triage_fname,
    )

    # run all at once
    with tf.Session() as sess:
        # get values of above tensors
        NND.saver.restore(sess, NND.path_to_detector_model)
        NNAE.saver_ae.restore(sess, NNAE.path_to_ae_model)
        NNT.saver.restore(sess, NNT.path_to_triage_model)
        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        (scores, clear, collision) = neuralnetwork.run_detect_triage_featurize(
            data, sess, x_tf, output_tf, neighbors, rot)

    # run in batches - buffer size makes sure we can detect spikes if they
    # appear at the end of any batch
    bp = BatchProcessor(PATH_TO_DATA,
                        PARAMS['dtype'],
                        PARAMS['n_channels'],
                        PARAMS['data_order'],
                        '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        # get values of above tensors
        NND.saver.restore(sess, NND.path_to_detector_model)
        NNAE.saver_ae.restore(sess, NNAE.path_to_ae_model)
        NNT.saver.restore(sess, NNT.path_to_triage_model)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res], axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)
Example #16
    def fullMPMU(self):

        start_time = dt.datetime.now()

        neighchan = n_steps_neigh_channels(self.config.neighChannels, steps=3)
        C = self.config.recordings.n_channels
        R = self.config.spikeSize
        shift = 3  # int(R/2)
        K = self.templates.shape[2]
        nrank = self.config.deconvolution.rank
        lam = self.config.deconvolution.lam
        Th = self.config.deconvolution.threshold

        # used to have 3
        iter_max = 1

        amps = np.max(np.abs(self.templates), axis=0)
        amps_max = np.max(amps, axis=0)

        templatesMask = np.zeros((K, C), 'bool')

        for k in range(K):
            templatesMask[k] = amps[:, k] > amps_max[k] * 0.5

        W_all, U_all, mu_all = decompose_dWU(self.templates, nrank)

        spiketime_all = np.zeros(0, 'int32')
        assignment_all = np.zeros(0, 'int32')

        for c in range(C):
            nmax = 1000000
            ids = np.zeros(nmax, 'int32')
            sts = np.zeros(nmax, 'int32')
            ns = np.zeros(nmax, 'int32')

            idx_c = self.spike_index[:, 1] == c
            nc = np.sum(idx_c)

            if nc > 0:
                spt_c = self.spike_index[idx_c, 0]
                ch_idx = np.where(neighchan[c])[0]

                # this line is in the old pipeline
                ch_idx = np.arange(C)

                k_idx = np.where(templatesMask[:, c])[0]
                tt = self.templates[:, ch_idx][:, :, k_idx]
                Kc = k_idx.shape[0]

                if nc > 0 and Kc > 0:
                    mu = np.reshape(mu_all[k_idx], [1, 1, Kc])

                    # commented out since it is unused
                    # lam1 = lam/np.square(mu)

                    wf = np.zeros((nc, 2 * (R + shift) + 1, ch_idx.shape[0]))

                    for j in range(nc):
                        wf[j] = self.wrec[spt_c[j] +
                                          np.arange(-(R + shift), R + shift +
                                                    1)][:, ch_idx]

                    n = np.arange(nc)
                    i0 = 0
                    it = 0

                    while it < iter_max:
                        nc = n.shape[0]
                        wf_projs = np.zeros(
                            (nc, 2 * (R + shift) + 1, nrank, Kc))

                        for k in range(Kc):
                            wf_projs[:, :, :, k] = np.reshape(
                                np.matmul(
                                    np.reshape(wf[n], [-1, ch_idx.shape[0]]),
                                    U_all[ch_idx, k_idx[k]]), [nc, -1, nrank])

                        obj = np.zeros((nc, 2 * shift + 1, Kc))

                        for j in range(2 * shift + 1):
                            obj[:, j, :] = np.sum(
                                (wf_projs[:, j:(j + 2 * R + 1)] *
                                 np.transpose(W_all[:, k_idx],
                                              [0, 2, 1])[np.newaxis, :]),
                                axis=(1, 2))

                        # this block is commented out in the old pipeline
                        # Ci = obj+(mu*lam1)
                        # Ci = np.square(Ci)/(1+lam1)
                        # Ci = Ci - lam1*np.square(mu)

                        # this block is in the old pipeline
                        scale = np.abs((obj - mu) / np.sqrt(mu / lam)) - 3
                        scale = np.minimum(np.maximum(scale, 0), 1)
                        scale[scale < 0] = 0
                        scale[scale > 1] = 1
                        Ci = np.multiply(np.square(obj), (1 - scale))

                        mX = np.max(Ci, axis=1)
                        st = np.argmax(Ci, axis=1)
                        idd = np.argmax(mX, axis=1)
                        st = st[np.arange(st.shape[0]), idd]

                        idx_keep = np.max(mX, axis=1) > Th * Th
                        st = st[idx_keep]
                        idd = idd[idx_keep]
                        n = n[idx_keep]

                        n_detected = np.sum(idx_keep)

                        if it > 0:
                            idx_keep2 = np.zeros(n_detected, 'bool')

                            for j in range(n_detected):

                                if (np.sum(ids[:i0][ns[:i0] == n[j]] == idd[j])
                                        == 0):
                                    idx_keep2[j] = 1

                            st = st[idx_keep2]
                            idd = idd[idx_keep2]
                            n = n[idx_keep2]
                            n_detected = np.sum(idx_keep2)

                        if not st.any():
                            it = iter_max
                        else:
                            sts[i0:(i0 + n_detected)] = st
                            ids[i0:(i0 + n_detected)] = idd
                            ns[i0:(i0 + n_detected)] = n
                            i0 = i0 + n_detected
                            it += 1

                            if it < iter_max:
                                for j in range(st.shape[0]):
                                    wf[n[j],
                                       st[j]:(st[j] + 2 * R + 1)] -= tt[:, :,
                                                                        idd[j]]

                    ids = k_idx[ids[:i0]]
                    sts = sts[:i0] - shift - 1 + spt_c[ns[:i0]]

                    spiketime_all = np.concatenate((spiketime_all, sts))
                    assignment_all = np.concatenate((assignment_all, ids))

        current_time = dt.datetime.now()
        self.logger.info("Deconvolution done in {0} seconds.".format(
            (current_time - start_time).seconds))

        return np.concatenate(
            (spiketime_all[:, np.newaxis], assignment_all[:, np.newaxis]),
            axis=1)
Example #17
def test_splitting_in_batches_does_not_affect(path_to_config,
                                              path_to_sample_pipeline_folder,
                                              make_tmp_folder,
                                              path_to_standardized_data):
    yass.set_config(path_to_config, make_tmp_folder)
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standardized_data

    with open(path.join(path_to_sample_pipeline_folder, 'preprocess',
                        'standardized.yaml')) as f:
        PARAMS = yaml.load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels,
                                       CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 channel_index)
    triage = KerasModel(triage_fname,
                        allow_longer_waveform_length=True,
                        allow_more_channels=True)
    NNAE = AutoEncoder.load(ae_fname, input_tensor=NND.waveform_tf)

    bp = BatchProcessor(PATH_TO_DATA, PARAMS['dtype'], PARAMS['n_channels'],
                        PARAMS['data_order'], '100KB',
                        buffer_size=CONFIG.spike_size)

    out = ('spike_index', 'waveform')
    fn = neuralnetwork.apply.fix_indexes_spike_index

    # detector
    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)

        res = bp.multi_channel_apply(NND.predict_recording,
                                     mode='memory',
                                     sess=sess,
                                     output_names=out,
                                     cleanup_function=fn)

    spike_index_new = np.concatenate([element[0] for element in res], axis=0)
    wfs = np.concatenate([element[1] for element in res], axis=0)

    idx_clean = triage.predict_with_threshold(wfs, triage_th)
    score = NNAE.predict(wfs)
    rot = NNAE.load_rotation()
    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

    (score_clear_new,
        spike_index_clear_new) = post_processing(score,
                                                 spike_index_new,
                                                 idx_clean,
                                                 rot,
                                                 neighbors)

    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)

        res = bp.multi_channel_apply(NND.predict_recording,
                                     mode='memory',
                                     sess=sess,
                                     output_names=('spike_index',
                                                   'waveform'),
                                     cleanup_function=fn)

    spike_index_batch, wfs = zip(*res)

    spike_index_batch = np.concatenate(spike_index_batch, axis=0)
    wfs = np.concatenate(wfs, axis=0)

    idx_clean = triage.predict_with_threshold(x=wfs,
                                              threshold=triage_th)

    score = NNAE.predict(wfs)
    rot = NNAE.load_rotation()
    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

    (score_clear_batch,
        spike_index_clear_batch) = post_processing(score,
                                                   spike_index_batch,
                                                   idx_clean,
                                                   rot,
                                                   neighbors)
Example #18
def denoise_then_estimate_template(fname_template,
                                   fname_spike_train,
                                   reader,
                                   denoiser,
                                   CONFIG,
                                   n_max_spikes=100):

    templates = np.load(fname_template)
    spike_train = np.load(fname_spike_train)

    n_units, n_times, n_chans = templates.shape

    ptps = templates.ptp(1)
    mcs = ptps.argmax(1)

    templates_mc = np.zeros((n_units, n_times))
    for j in range(n_units):
        templates_mc[j] = templates[j, :, mcs[j]]
    min_time_all = int(np.median(np.argmin(templates_mc, 1)))

    n_spikes = np.zeros(n_units)
    a, b = np.unique(spike_train[:, 1], return_counts=True)
    n_spikes[a] = b

    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)
    r2 = CONFIG.spike_size_nn // 2

    k_idx = np.where(n_spikes < n_max_spikes)[0]
    #k_idx = np.arange(n_units)
    for k in k_idx:
        # step 1: get min points that are valid (can be connected from the max channel)
        min_times, min_channels = argrelmin(templates[k], axis=0, order=5)
        min_val = templates[k][min_times, min_channels]

        th = np.max((-0.5, np.min(templates[k])))
        min_times = min_times[min_val <= th]
        min_channels = min_channels[min_val <= th]

        if np.sum(min_channels == mcs[k]) > 1:
            idx = np.where(min_channels == mcs[k])[0]
            idx = idx[np.argmin(np.abs(min_times[idx] - min_time_all))]
        else:
            idx = np.where(min_channels == mcs[k])[0][0]

        keep = connecting_points(min_times, min_channels, idx, neighbors)

        # step 2: get longer waveforms
        chans_in = min_channels[keep]
        times_in = min_times[keep] - min_time_all
        R2 = np.max(np.abs(times_in)) + r2
        times_in = R2 + times_in

        spt_ = spike_train[spike_train[:, 1] == k, 0]
        wfs_long = reader.read_waveforms(spt_,
                                         n_times=R2 * 2 + 1,
                                         channels=chans_in)[0]

        # step 3: cut it such that the size works for nn denoiser and then denoise
        wfs = np.zeros((len(wfs_long), 2 * r2 + 1, len(chans_in)))
        for j in range(len(chans_in)):
            wfs[:, :, j] = wfs_long[:, :,
                                    j][:,
                                       times_in[j] - r2:times_in[j] + r2 + 1]

        wfs_reshaped = wfs.transpose(0, 2, 1).reshape(-1, 2 * r2 + 1)
        wfs_denoised = denoiser(torch.from_numpy(
            wfs_reshaped).float().cuda())[0].data.cpu().numpy().reshape(
                wfs.shape[0], len(chans_in), -1).transpose(0, 2, 1)

        # step 4: put back
        temp_cut = wfs_denoised.mean(0)
        temp = np.zeros((R2 * 2 + 1, n_chans))
        for j in range(len(chans_in)):
            temp[times_in[j] - r2:times_in[j] + r2 + 1,
                 chans_in[j]] = temp_cut[:, j]

        if R2 * 2 + 1 > n_times:
            templates[k] = temp[R2 - n_times // 2:R2 + n_times // 2 + 1]
        else:
            templates[k, (n_times // 2) - R2:(n_times // 2) + R2 + 1] = temp
            templates[k, :(n_times // 2) - R2] = 0
            templates[k, (n_times // 2) + R2 + 1:] = 0

    #fname_templates_denoised = os.path.join(save_dir, 'templates_denoised.npy')
    #np.save(fname_templates_denoised, templates)

    np.save(fname_template, templates)

    return fname_template
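Step 3 above flattens the (n_spikes, n_times, n_channels) waveform array into one row per (spike, channel) pair for the single-channel denoiser, then reshapes the result back. A numpy-only sketch of that round trip, with the denoiser replaced by an identity stand-in:

import numpy as np

n_spikes, n_times, n_chans = 5, 61, 3
wfs = np.random.randn(n_spikes, n_times, n_chans)

# one row per (spike, channel) pair, as fed to the denoiser
rows = wfs.transpose(0, 2, 1).reshape(-1, n_times)

denoised_rows = rows  # stand-in for denoiser(torch.from_numpy(rows).float())[0]

# back to (n_spikes, n_times, n_chans)
wfs_denoised = denoised_rows.reshape(n_spikes, n_chans, -1).transpose(0, 2, 1)
assert np.allclose(wfs_denoised, wfs)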
Example #19
def run_neural_network(standarized_path, standarized_params, whiten_filter,
                       output_directory, if_file_exists, save_results):
    """Run neural network detection and autoencoder dimensionality reduction

    Returns
    -------
    scores
      Scores for all spikes

    spike_index_clear
      Spike indexes for clear spikes

    spike_index_all
      Spike indexes for all spikes
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    folder = Path(CONFIG.data.root_folder, output_directory, 'detect')
    folder.mkdir(exist_ok=True)

    TMP_FOLDER = str(folder)

    # check if all scores, clear and collision spikes exist..
    path_to_score = os.path.join(TMP_FOLDER, 'scores_clear.npy')
    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')
    path_to_spike_index_all = os.path.join(TMP_FOLDER, 'spike_index_all.npy')
    path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')

    paths = [path_to_score, path_to_spike_index_clear, path_to_spike_index_all]
    exists = [os.path.exists(p) for p in paths]

    if (if_file_exists == 'overwrite'
            or if_file_exists == 'abort' and not any(exists)
            or if_file_exists == 'skip' and not all(exists)):
        max_memory = (CONFIG.resources.max_memory_gpu
                      if GPU_ENABLED else CONFIG.resources.max_memory)

        # instantiate batch processor
        bp = BatchProcessor(standarized_path,
                            standarized_params['dtype'],
                            standarized_params['n_channels'],
                            standarized_params['data_order'],
                            max_memory,
                            buffer_size=CONFIG.spike_size)

        # load parameters
        detection_th = CONFIG.detect.neural_network_detector.threshold_spike
        triage_th = CONFIG.detect.neural_network_triage.threshold_collision
        detection_fname = CONFIG.detect.neural_network_detector.filename
        ae_fname = CONFIG.detect.neural_network_autoencoder.filename
        triage_fname = CONFIG.detect.neural_network_triage.filename

        # instantiate neural networks
        NND = NeuralNetDetector.load(detection_fname, detection_th,
                                     CONFIG.channel_index)
        NNT = NeuralNetTriage.load(triage_fname,
                                   triage_th,
                                   input_tensor=NND.waveform_tf)
        NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)
        rotation = NNAE.load_rotation()

        # gather all output tensors
        output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

        # run detection
        with tf.Session() as sess:

            # get values of above tensors
            NND.restore(sess)
            NNAE.restore(sess)
            NNT.restore(sess)

            mc = bp.multi_channel_apply
            res = mc(neuralnetwork.run_detect_triage_featurize,
                     mode='memory',
                     cleanup_function=neuralnetwork.fix_indexes,
                     sess=sess,
                     x_tf=NND.x_tf,
                     output_tf=output_tf,
                     rot=rotation,
                     neighbors=neighbors)

        # get clear spikes
        clear = np.concatenate([element[1] for element in res], axis=0)
        logger.info('Removing clear indexes outside the allowed range to '
                    'draw a complete waveform...')
        clear, idx = detect.remove_incomplete_waveforms(
            clear, CONFIG.spike_size + CONFIG.templates.max_shift,
            bp.reader._n_observations)

        # get all spikes
        spikes_all = np.concatenate([element[2] for element in res], axis=0)
        logger.info('Removing indexes outside the allowed range to '
                    'draw a complete waveform...')
        spikes_all, _ = detect.remove_incomplete_waveforms(
            spikes_all, CONFIG.spike_size + CONFIG.templates.max_shift,
            bp.reader._n_observations)

        # get scores
        scores = np.concatenate([element[0] for element in res], axis=0)
        logger.info('Removing scores for indexes outside the allowed range to '
                    'draw a complete waveform...')
        scores = scores[idx]

        # transform scores to location + shape feature space
        # TODO: move this to another place

        if CONFIG.cluster.method == 'location':
            threshold = 2
            scores = get_locations_features(scores, rotation, clear[:, 1],
                                            CONFIG.channel_index, CONFIG.geom,
                                            threshold)
            idx_nan = np.where(np.isnan(np.sum(scores, axis=(1, 2))))[0]
            scores = np.delete(scores, idx_nan, 0)
            clear = np.delete(clear, idx_nan, 0)

        # save partial results if required
        if save_results:
            # save clear spikes
            np.save(path_to_spike_index_clear, clear)
            logger.info('Saved spike index clear in {}...'.format(
                path_to_spike_index_clear))

            # save all spikes
            np.save(path_to_spike_index_all, spikes_all)
            logger.info('Saved spike index all in {}...'.format(
                path_to_spike_index_all))

            # save rotation
            np.save(path_to_rotation, rotation)
            logger.info(
                'Saved rotation matrix in {}...'.format(path_to_rotation))

            # saves scores
            np.save(path_to_score, scores)
            logger.info('Saved spike scores in {}...'.format(path_to_score))

    elif if_file_exists == 'abort' and any(exists):
        conflict = [p for p, e in zip(paths, exists) if e]
        message = reduce(lambda x, y: str(x) + ', ' + str(y), conflict)
        raise ValueError('if_file_exists was set to abort, the '
                         'program halted since the following files '
                         'already exist: {}'.format(message))
    elif if_file_exists == 'skip' and all(exists):
        logger.warning('Skipped execution. All output files exist'
                       ', loading them...')
        scores = np.load(path_to_score)
        clear = np.load(path_to_spike_index_clear)
        spikes_all = np.load(path_to_spike_index_all)

    else:
        raise ValueError(
            'Invalid value for if_file_exists {}'
            'must be one of overwrite, abort or skip'.format(if_file_exists))

    return scores, clear, spikes_all
Example #20
def threshold(rec, neighbors, spike_size, std_factor):
    """Threshold-based spike detection

    Parameters
    ----------
    rec: np.ndarray (n_observations, n_channels)
        numpy 2-D array with the recordings, first dimension must be
        n_observations and second n_channels
    neighbors: np.ndarray (n_channels, n_channels)
        Boolean numpy 2-D array where an (i, j) entry is True if i is
        considered a neighbor of j
    spike_size: int
        Spike size
    std_factor: float
        Amplitude threshold; values below -std_factor are considered spike
        candidates

    Notes
    -----
    Marks negative threshold crossings that are local minima and keeps a
    crossing only if it is the minimum within its spatial-temporal window

    Returns
    -------
    index: np.ndarray (number of spikes, 2)
        First column is spike time, second column is main channel (the channel
        where the spike has the biggest amplitude)
    """
    T, C = rec.shape
    R = spike_size
    th = std_factor
    neighChannels_big = n_steps_neigh_channels(neighbors, steps=2)

    # FIXME: is this a safe thing to do?
    index = np.zeros((1000000, 2), 'int32')
    count = 0

    for c in range(C):
        idx = np.logical_and(
            rec[:, c] < -th, np.r_[True, rec[1:, c] < rec[:-1, c]]
            & np.r_[rec[:-1, c] < rec[1:, c], True])
        nc = np.sum(idx)

        if nc > 0:
            spt_c = np.where(idx)[0]
            spt_c = spt_c[np.logical_and(spt_c > 2 * R, spt_c < T - 2 * R)]
            nc = spt_c.shape[0]
            ch_idx = np.where(neighChannels_big[c])[0]
            c_main = np.where(ch_idx == c)[0]
            idx_keep = np.zeros(nc, 'bool')

            for j in range(nc):
                wf_temp = rec[spt_c[j] + np.arange(-2 * R, 2 * R + 1)][:,
                                                                       ch_idx]
                c_min = np.argmin(np.amin(wf_temp, axis=0))
                t_min = np.argmin(wf_temp[:, c_min])
                if t_min == 2 * R and c_min == c_main:
                    idx_keep[j] = 1

            nc = np.sum(idx_keep)
            index[count:(count + nc), 0] = spt_c[idx_keep]
            index[count:(count + nc), 1] = c
            count += nc

    return index[:count]