def test_can_use_detect_and_triage_after_reload(path_to_tests,
                                                path_to_sample_pipeline_folder,
                                                tmp_folder,
                                                path_to_standarized_data):
    """Train detector and triage nets, reload both and run predictions.

    The detector and the triage network use *separate* checkpoint paths.
    ``detector.predict`` is only called after the triage network has been
    trained, so sharing one path could let the triage model's saved files
    replace the detector's before the detector uses them.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    (x_detect, y_detect,
     x_triage, y_triage,
     x_ae, y_ae) = make_training_data(CONFIG, spike_train, chosen_templates,
                                      min_amplitude, n_spikes,
                                      path_to_sample_pipeline_folder)

    _, waveform_length, n_neighbors = x_detect.shape

    # train the detector, then reload it from its own checkpoint
    path_to_detector = path.join(tmp_folder, 'detect-net.ckpt')
    detector = NeuralNetDetector(path_to_detector, filters,
                                 waveform_length, n_neighbors,
                                 threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=10)
    detector.fit(x_detect, y_detect)
    detector = NeuralNetDetector.load(path_to_detector, threshold=0.5,
                                      channel_index=CONFIG.channel_index)

    # train the triage net on the triage training set, using a distinct
    # checkpoint so the detector's saved model is left untouched
    path_to_triage = path.join(tmp_folder, 'triage-net.ckpt')
    triage = NeuralNetTriage(path_to_triage, filters,
                             waveform_length, n_neighbors,
                             threshold=0.5,
                             n_iter=10)
    triage.fit(x_triage, y_triage)
    triage = NeuralNetTriage.load(path_to_triage, threshold=0.5)

    data = RecordingExplorer(path_to_standarized_data).reader.data

    output_names = ('spike_index', 'waveform', 'probability')
    (spike_index, waveform,
     proba) = detector.predict(data, output_names=output_names)

    # the triage net was built for n_neighbors channels, so only feed those
    triage.predict(waveform[:, :, :n_neighbors])
def test_can_train_triage(path_to_tests, path_to_sample_pipeline_folder,
                          tmp_folder):
    """Train the triage network on the triage training set.

    The network's input shape is taken from ``x_triage``, and it is fit on
    the matching ``(x_triage, y_triage)`` pair rather than the detection
    training data.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    (x_detect, y_detect,
     x_triage, y_triage,
     x_ae, y_ae) = make_training_data(CONFIG, spike_train, chosen_templates,
                                      min_amplitude, n_spikes,
                                      path_to_sample_pipeline_folder)

    _, waveform_length, n_neighbors = x_triage.shape

    path_to_model = path.join(tmp_folder, 'triage-net.ckpt')

    triage = NeuralNetTriage(path_to_model, filters,
                             waveform_length, n_neighbors,
                             threshold=0.5,
                             n_iter=10)

    triage.fit(x_triage, y_triage)
def prepare_nn(channel_index, whiten_filter, threshold_detect,
               threshold_triage, detector_filename, autoencoder_filename,
               triage_filename):
    """Prepare neural net tensors in advance.

    This lets the neural nets run efficiently with the batch processor,
    because the tf tensors do not have to be recreated for every batch.

    Parameters
    ----------
    channel_index: np.array (n_channels, n_neigh)
        Each row indexes its neighboring channels. For example,
        channel_index[c] is the index of neighboring channels (including
        itself). If any value is equal to n_channels, it is nothing but
        a placeholder for a channel that has fewer than n_neigh
        neighboring channels.
    whiten_filter: numpy.ndarray (n_channels, n_neigh, n_neigh)
        Whitening matrix such that whiten_filter[c] is the whitening
        filter of channel c and its neighboring channels determined from
        channel_index. NOTE: currently not used by this function.
    threshold_detect: float
        Threshold for neural net detection.
    threshold_triage: float
        Threshold for neural net triage.
    detector_filename: str
        Location of the trained neural net detector.
    autoencoder_filename: str
        Location of the trained neural net autoencoder.
    triage_filename: str
        Location of the trained neural net triage.

    Returns
    -------
    x_tf: tf.Tensor
        2D float placeholder for the recording (presumably
        (n_observations, n_channels)) to feed when running the tensors.
    output_tf: tuple of tf.Tensor
        (score_tf, spike_index_tf, idx_clean):
        - score_tf: autoencoder scores computed from the detected
          waveforms,
        - spike_index_tf: detected spike indexes with edge spikes
          removed,
        - idx_clean: output of the triage net on those waveforms
          (presumably a mask/index of clean spikes - confirm).
    NND: NeuralNetDetector
        The detector instance (needed to restore its weights).
    NNAE: AutoEncoder
        The autoencoder instance (needed to restore its weights).
    NNT: NeuralNetTriage
        The triage instance (needed to restore its weights).
    """
    # placeholder for input recording
    x_tf = tf.placeholder("float", [None, None])

    # load Neural Nets
    NND = NeuralNetDetector(detector_filename)
    NNAE = AutoEncoder(autoencoder_filename)
    NNT = NeuralNetTriage(triage_filename)

    # make spike_index tensorflow tensor
    spike_index_tf_all = NND.make_detection_tf_tensors(x_tf, channel_index,
                                                       threshold_detect)

    # remove spikes too close to the recording edges to extract a full
    # waveform window
    spike_index_tf = remove_edge_spikes(x_tf, spike_index_tf_all,
                                        NND.filters_dict['size'])

    # make waveform tensorflow tensor
    waveform_tf = make_waveform_tf_tensor(x_tf, spike_index_tf,
                                          channel_index,
                                          NND.filters_dict['size'])

    # make score tensorflow tensor from waveform
    score_tf = NNAE.make_score_tf_tensor(waveform_tf)

    # run neural net triage on the first n_neighbors channels only, which
    # is the input size recorded in the detector's filters_dict
    nneigh = NND.filters_dict['n_neighbors']
    idx_clean = NNT.triage_wf(waveform_tf[:, :, :nneigh], threshold_triage)

    # gather all output tensors
    output_tf = (score_tf, spike_index_tf, idx_clean)

    return x_tf, output_tf, NND, NNAE, NNT
def _make_channel_index(neighbors, geom):
    """Build a padded (n_channels, n_neigh) table of neighboring channels.

    Row c holds the neighbors of channel c in the order returned by
    ``order_channels_by_distance``; unused slots are filled with
    n_channels, which serves as a placeholder index.

    Returns
    -------
    c_idx: numpy.ndarray (n_channels, n_neigh), int32
    n_neigh: int
        Largest number of neighbors of any channel.
    """
    n_channels = neighbors.shape[0]
    # Size by the largest *row* count, because rows (neighbors[c]) are what
    # get written below, so every row is guaranteed to fit. For a symmetric
    # neighbors matrix this equals the largest column count.
    n_neigh = np.max(np.sum(neighbors, axis=1))
    c_idx = np.full((n_channels, n_neigh), n_channels, dtype='int32')
    for c in range(n_channels):
        ch_idx, _ = order_channels_by_distance(c, np.where(neighbors[c])[0],
                                               geom)
        c_idx[c, :ch_idx.shape[0]] = ch_idx
    return c_idx, n_neigh


def nn_detection(recordings, neighbors, geom, temporal_features,
                 temporal_window, th_detect, th_triage, detector_filename,
                 autoencoder_filename, triage_filename):
    """Detect spikes using a neural network

    Parameters
    ----------
    recordings: numpy.ndarray (n_observations, n_channels)
        Neural recordings
    neighbors: numpy.ndarray (n_channels, n_channels)
        Channels neighbors matrix; row c is truthy at the neighbors of
        channel c
    geom: numpy.ndarray (n_channels, 2)
        Cartesian coordinates for the channels
    temporal_features: int
        Size of the last dimension of the score train produced by the
        detector (number of features per channel)
    temporal_window: int
        Window passed to the local-max spike detection and to the
        duplicate-removal step (presumably in samples - confirm)
    th_detect: float
        Detection threshold passed to ``NeuralNetDetector.get_spikes``
    th_triage: float
        Triage threshold: spikes whose triage probability is strictly
        greater than this value are considered clear, the rest are
        considered collisions
    detector_filename: str
        Path to neural network detector
    autoencoder_filename: str
        Path to neural network autoencoder
    triage_filename: str
        Path to triage neural network

    Returns
    -------
    score: numpy.ndarray (n_clear_spikes, n_features, n_channels)
        3D array with the scores for the clear spikes: the first dimension
        is the number of spikes, the second the number of features and the
        third the number of channels
    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        2D array with indexes for clear spikes: the first column contains
        the spike location in the recording and the second the main
        channel (channel whose amplitude is maximum)
    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        2D array with indexes for collided spikes: the first column
        contains the spike location in the recording and the second the
        main channel (channel whose amplitude is maximum)

    Raises
    ------
    ValueError
        If ``neighbors`` is not square or its size does not match the
        number of channels in ``recordings``.
    """
    T, C = recordings.shape
    a, b = neighbors.shape

    # validate inputs before building any network
    if a != b:
        raise ValueError('neighbors is not a square matrix, verify')

    if a != C:
        raise ValueError(
            'Number of channels in recording are {} but the '
            'neighbors matrix has {} elements, they must match'.format(C, a))

    nnd = NeuralNetDetector(detector_filename, autoencoder_filename)
    nnt = NeuralNetTriage(triage_filename)

    # neighboring channel info
    c_idx, nneigh = _make_channel_index(neighbors, geom)

    # input
    x_tf = tf.placeholder("float", [T, C])

    # detect spike index
    local_max_idx_tf = nnd.get_spikes(x_tf, T, nneigh, c_idx,
                                      temporal_window, th_detect)

    # get score train
    score_train_tf = nnd.get_score_train(x_tf)

    # get energy for detected index
    energy_tf = tf.reduce_sum(tf.square(score_train_tf), axis=2)
    energy_val_tf = tf.gather_nd(energy_tf, local_max_idx_tf)

    # get triage probability
    triage_prob_tf = nnt.triage_prob(x_tf, T, nneigh, c_idx)

    # gather all results above
    result = (local_max_idx_tf, score_train_tf, energy_val_tf,
              triage_prob_tf)

    # remove duplicates
    energy_train_tf = tf.placeholder("float", [T, C])
    spike_index_tf = remove_duplicate_spikes_by_energy(energy_train_tf, T,
                                                       c_idx,
                                                       temporal_window)

    # get score
    score_train_placeholder = tf.placeholder("float",
                                             [T, C, temporal_features])
    spike_index_clear_tf = tf.placeholder("int64", [None, 2])
    score_tf = get_score(score_train_placeholder, spike_index_clear_tf, T,
                         temporal_features, c_idx)

    ###############################
    # get values of above tensors #
    ###############################

    with tf.Session() as sess:
        nnd.saver.restore(sess, nnd.path_to_detector_model)
        nnd.saver_ae.restore(sess, nnd.path_to_ae_model)
        nnt.saver.restore(sess, nnt.path_to_triage_model)

        local_max_idx, score_train, energy_val, triage_prob = sess.run(
            result, feed_dict={x_tf: recordings})

        # scatter each local max's energy into a dense (T, C) train
        energy_train = np.zeros((T, C))
        energy_train[local_max_idx[:, 0], local_max_idx[:, 1]] = energy_val

        spike_index = sess.run(spike_index_tf,
                               feed_dict={energy_train_tf: energy_train})

        idx_clean = (triage_prob[spike_index[:, 0], spike_index[:, 1]]
                     > th_triage)

        spike_index_clear = spike_index[idx_clean]
        spike_index_collision = spike_index[~idx_clean]

        score = sess.run(score_tf, feed_dict={
            score_train_placeholder: score_train,
            spike_index_clear_tf: spike_index_clear
        })

    return score, spike_index_clear, spike_index_collision