예제 #1
0
def test_can_use_detect_and_triage_after_reload(path_to_tests,
                                                path_to_sample_pipeline_folder,
                                                tmp_folder,
                                                path_to_standarized_data):
    """Train detector and triage nets, reload both from disk, then predict.

    Fits a NeuralNetDetector and a NeuralNetTriage, reloads each from its
    own checkpoint via ``.load``, runs the detector on the standardized
    sample recording and feeds the detected waveforms to the triage net.

    ``spike_train``, ``chosen_templates``, ``min_amplitude``, ``n_spikes``
    and ``filters`` are not parameters or locals here; they are expected to
    be defined at module level elsewhere in this test module.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    (x_detect, y_detect, x_triage, y_triage, x_ae,
     y_ae) = make_training_data(CONFIG, spike_train, chosen_templates,
                                min_amplitude, n_spikes,
                                path_to_sample_pipeline_folder)

    _, waveform_length, n_neighbors = x_detect.shape

    # Each network needs its own checkpoint. Sharing one path lets
    # triage.fit overwrite the detector checkpoint the reloaded detector
    # was loaded from, so the "after reload" scenario would no longer
    # exercise the detector's own saved weights.
    path_to_detector_model = path.join(tmp_folder, 'detect-net.ckpt')
    path_to_triage_model = path.join(tmp_folder, 'triage-net.ckpt')

    detector = NeuralNetDetector(path_to_detector_model,
                                 filters,
                                 waveform_length,
                                 n_neighbors,
                                 threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=10)

    detector.fit(x_detect, y_detect)

    detector = NeuralNetDetector.load(path_to_detector_model,
                                      threshold=0.5,
                                      channel_index=CONFIG.channel_index)

    # The triage net is sized from x_detect and later applied to detector
    # output truncated to n_neighbors channels, so it is trained on the
    # detector training set as well.
    triage = NeuralNetTriage(path_to_triage_model,
                             filters,
                             waveform_length,
                             n_neighbors,
                             threshold=0.5,
                             n_iter=10)

    triage.fit(x_detect, y_detect)

    triage = NeuralNetTriage.load(path_to_triage_model, threshold=0.5)

    data = RecordingExplorer(path_to_standarized_data).reader.data

    output_names = ('spike_index', 'waveform', 'probability')

    (spike_index, waveform,
     proba) = detector.predict(data, output_names=output_names)

    triage.predict(waveform[:, :, :n_neighbors])
예제 #2
0
def test_can_train_triage(path_to_tests, path_to_sample_pipeline_folder,
                          tmp_folder):
    """Check that a NeuralNetTriage can be built and fit on triage data.

    ``spike_train``, ``chosen_templates``, ``min_amplitude``, ``n_spikes``
    and ``filters`` are not parameters or locals here; they are expected to
    be defined at module level elsewhere in this test module.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    (x_detect, y_detect, x_triage, y_triage, x_ae,
     y_ae) = make_training_data(CONFIG, spike_train, chosen_templates,
                                min_amplitude, n_spikes,
                                path_to_sample_pipeline_folder)

    # The network is sized from the triage training set...
    _, waveform_length, n_neighbors = x_triage.shape

    path_to_model = path.join(tmp_folder, 'triage-net.ckpt')

    triage = NeuralNetTriage(path_to_model,
                             filters,
                             waveform_length,
                             n_neighbors,
                             threshold=0.5,
                             n_iter=10)

    # ...so it must be trained on that same set. Fitting on the detector
    # data only works if both sets happen to share a shape.
    triage.fit(x_triage, y_triage)
예제 #3
0
파일: prepare.py 프로젝트: kathefter/yass
def prepare_nn(channel_index, whiten_filter, threshold_detect,
               threshold_triage, detector_filename, autoencoder_filename,
               triage_filename):
    """Prepare neural net tensors in advance.

    Building the TensorFlow graph once lets the batch processor run the
    networks efficiently, because the tf tensors do not have to be
    recreated for every batch.

    Parameters
    ----------
    channel_index: np.array (n_channels, n_neigh)
        Each row indexes its neighboring channels.
        For example, channel_index[c] is the index of the
        neighboring channels of c (including c itself).
        A value equal to n_channels is only a placeholder, used when
        a channel has fewer than n_neigh neighboring channels.

    whiten_filter: numpy.ndarray (n_channels, n_neigh, n_neigh)
        Whitening matrix such that whiten_filter[c] is the whitening
        filter of channel c and its neighboring channels determined from
        channel_index. NOTE(review): this argument is not referenced in the
        function body; it is kept for interface compatibility.

    threshold_detect: number
        Threshold for neural net detection, passed to
        ``NND.make_detection_tf_tensors``.

    threshold_triage: number
        Threshold for neural net triage, passed to ``NNT.triage_wf``.

    detector_filename: str
        Location of the trained neural net detector.

    autoencoder_filename: str
        Location of the trained neural net autoencoder.

    triage_filename: str
        Location of the trained neural net triage.

    Returns
    -------
    x_tf: tf.placeholder ("float", shape [None, None])
        Placeholder for the input recording; presumably
        (n_observations, n_channels).

    output_tf: tuple of tf.tensors
        ``(score_tf, spike_index_tf, idx_clean)``: autoencoder scores of the
        detected waveforms, detected spike indexes with edge spikes removed,
        and the triage result from ``NNT.triage_wf`` (presumably marking
        the clean spikes among spike_index_tf; confirm against triage_wf).

    NND: NeuralNetDetector
        The detector instance loaded from ``detector_filename``.

    NNAE: AutoEncoder
        The autoencoder instance loaded from ``autoencoder_filename``.

    NNT: NeuralNetTriage
        The triage instance loaded from ``triage_filename``.
    """

    # placeholder for input recording
    x_tf = tf.placeholder("float", [None, None])

    # load Neural Net's
    NND = NeuralNetDetector(detector_filename)
    NNAE = AutoEncoder(autoencoder_filename)
    NNT = NeuralNetTriage(triage_filename)

    # make spike_index tensorflow tensor
    spike_index_tf_all = NND.make_detection_tf_tensors(x_tf, channel_index,
                                                       threshold_detect)

    # remove edge spike time
    spike_index_tf = remove_edge_spikes(x_tf, spike_index_tf_all,
                                        NND.filters_dict['size'])

    # make waveform tensorflow tensor
    waveform_tf = make_waveform_tf_tensor(x_tf, spike_index_tf, channel_index,
                                          NND.filters_dict['size'])

    # make score tensorflow tensor from waveform
    score_tf = NNAE.make_score_tf_tensor(waveform_tf)

    # run neural net triage only on the first n_neighbors channels the
    # detector was configured with
    nneigh = NND.filters_dict['n_neighbors']
    idx_clean = NNT.triage_wf(waveform_tf[:, :, :nneigh], threshold_triage)

    # gather all output tensors
    output_tf = (score_tf, spike_index_tf, idx_clean)

    return x_tf, output_tf, NND, NNAE, NNT
예제 #4
0
파일: detect.py 프로젝트: hooshmandshr/yass
def _make_channel_index(neighbors, geom):
    """Build a (C, nneigh) int32 table of each channel's neighbors by distance.

    Row c lists the channels where ``neighbors[c]`` is non-zero, ordered by
    ``order_channels_by_distance``. Unused slots are padded with C (an
    out-of-range channel id).

    Returns
    -------
    nneigh: int
        Largest number of neighbors of any channel (row-wise count).
    c_idx: numpy.ndarray (C, nneigh), int32
        The padded neighbor table.
    """
    n_channels = neighbors.shape[0]

    # Count neighbors row-wise, matching the row-wise fill below. The
    # previous column-wise count could be too narrow for an asymmetric
    # neighbors matrix. It is identical for symmetric ones.
    nneigh = np.max(np.sum(neighbors, axis=1))
    c_idx = np.full((n_channels, nneigh), n_channels, dtype='int32')

    for c in range(n_channels):
        ch_idx, _ = order_channels_by_distance(c,
                                               np.where(neighbors[c])[0],
                                               geom)
        c_idx[c, :ch_idx.shape[0]] = ch_idx

    return nneigh, c_idx


def nn_detection(recordings, neighbors, geom, temporal_features,
                 temporal_window, th_detect, th_triage, detector_filename,
                 autoencoder_filename, triage_filename):
    """Detect spikes using a neural network.

    Parameters
    ----------
    recordings: numpy.ndarray (n_observations, n_channels)
        Neural recordings.

    neighbors: numpy.ndarray (n_channels, n_channels)
        Channel neighbors matrix; non-zero entries in row c mark the
        neighbors of channel c.

    geom: numpy.ndarray (n_channels, 2)
        Cartesian coordinates for the channels.

    temporal_features: int
        Number of features per channel in the detector's score train (last
        dimension of the score-train placeholder).

    temporal_window: int
        Window passed to ``nnd.get_spikes`` and
        ``remove_duplicate_spikes_by_energy``; presumably in samples.

    th_detect: float
        Detection threshold, passed to ``nnd.get_spikes``.

    th_triage: float
        Triage threshold. Spikes whose triage probability is strictly
        greater than this are classified as clear; the rest as collisions.

    detector_filename: str
        Path to the neural network detector.

    autoencoder_filename: str
        Path to the neural network autoencoder.

    triage_filename: str
        Path to the triage neural network.

    Returns
    -------
    score: numpy.ndarray (n_clear_spikes, n_features, n_channels)
        Scores for the clear spikes. The first dimension is the number of
        spikes, the second the number of features and the third the number
        of channels.

    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        Indexes of clear spikes. The first column is the spike location in
        the recording and the second is the main channel.

    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        Indexes of collided spikes, in the same layout as
        spike_index_clear.

    Raises
    ------
    ValueError
        If ``neighbors`` is not square or its size does not match the
        number of channels in ``recordings``.
    """
    T, C = recordings.shape
    a, b = neighbors.shape

    # Validate shapes before the networks are built/loaded so bad input
    # fails fast.
    if a != b:
        raise ValueError('neighbors is not a square matrix, verify')

    if a != C:
        raise ValueError(
            'Number of channels in recording are {} but the '
            'neighbors matrix has {} elements, they must match'.format(C, a))

    nnd = NeuralNetDetector(detector_filename, autoencoder_filename)
    nnt = NeuralNetTriage(triage_filename)

    # neighboring channel info
    nneigh, c_idx = _make_channel_index(neighbors, geom)

    # input
    x_tf = tf.placeholder("float", [T, C])

    # detect spike index
    local_max_idx_tf = nnd.get_spikes(x_tf, T, nneigh, c_idx, temporal_window,
                                      th_detect)

    # get score train
    score_train_tf = nnd.get_score_train(x_tf)

    # get energy (sum of squared scores) for detected index
    energy_tf = tf.reduce_sum(tf.square(score_train_tf), axis=2)
    energy_val_tf = tf.gather_nd(energy_tf, local_max_idx_tf)

    # get triage probability
    triage_prob_tf = nnt.triage_prob(x_tf, T, nneigh, c_idx)

    # gather all results above
    result = (local_max_idx_tf, score_train_tf, energy_val_tf, triage_prob_tf)

    # remove duplicates
    energy_train_tf = tf.placeholder("float", [T, C])
    spike_index_tf = remove_duplicate_spikes_by_energy(energy_train_tf, T,
                                                       c_idx, temporal_window)

    # get score
    score_train_placeholder = tf.placeholder("float",
                                             [T, C, temporal_features])
    spike_index_clear_tf = tf.placeholder("int64", [None, 2])
    score_tf = get_score(score_train_placeholder, spike_index_clear_tf, T,
                         temporal_features, c_idx)

    ###############################
    # get values of above tensors #
    ###############################

    with tf.Session() as sess:

        nnd.saver.restore(sess, nnd.path_to_detector_model)
        nnd.saver_ae.restore(sess, nnd.path_to_ae_model)
        nnt.saver.restore(sess, nnt.path_to_triage_model)

        local_max_idx, score_train, energy_val, triage_prob = sess.run(
            result, feed_dict={x_tf: recordings})

        # scatter the per-spike energies into a dense (T, C) train so the
        # duplicate-removal graph can compare neighboring detections
        energy_train = np.zeros((T, C))
        energy_train[local_max_idx[:, 0], local_max_idx[:, 1]] = energy_val
        spike_index = sess.run(spike_index_tf,
                               feed_dict={energy_train_tf: energy_train})

        spike_times = spike_index[:, 0]
        spike_channels = spike_index[:, 1]
        idx_clean = triage_prob[spike_times, spike_channels] > th_triage

        spike_index_clear = spike_index[idx_clean]
        spike_index_collision = spike_index[~idx_clean]

        score = sess.run(score_tf,
                         feed_dict={
                             score_train_placeholder: score_train,
                             spike_index_clear_tf: spike_index_clear
                         })

    return score, spike_index_clear, spike_index_collision