def train(CONFIG, CONFIG_TRAIN, spike_train, data_folder):
    """Train the detector, triage and autoencoder neural networks

    Parameters
    ----------
    CONFIG
        YASS configuration file

    CONFIG_TRAIN
        YASS Neural Network configuration file

    spike_train: numpy.ndarray
        Spike train, first column is spike index and second is main channel

    data_folder: str
        Path to the folder with the data used to build the training set
    """
    logger = logging.getLogger(__name__)

    chosen_templates = CONFIG_TRAIN['templates']['ids']
    min_amp = CONFIG_TRAIN['templates']['minimum_amplitude']
    nspikes = CONFIG_TRAIN['training']['n_spikes']

    n_filters_detect = CONFIG_TRAIN['network_detector']['n_filters']
    n_iter = CONFIG_TRAIN['training']['n_iterations']
    n_batch = CONFIG_TRAIN['training']['n_batch']
    l2_reg_scale = CONFIG_TRAIN['training']['l2_regularization_scale']
    train_step_size = CONFIG_TRAIN['training']['step_size']
    detectnet_name = './' + CONFIG_TRAIN['network_detector']['name'] + '.ckpt'

    n_filters_triage = CONFIG_TRAIN['network_triage']['n_filters']
    triagenet_name = './' + CONFIG_TRAIN['network_triage']['name'] + '.ckpt'

    n_features = CONFIG_TRAIN['network_autoencoder']['n_features']
    ae_name = './' + CONFIG_TRAIN['network_autoencoder']['name'] + '.ckpt'

    # generate training data for the detector, triage and autoencoder
    logger.info('Generating training data...')
    (x_detect, y_detect,
     x_triage, y_triage,
     x_ae, y_ae) = make_training_data(CONFIG, spike_train, chosen_templates,
                                      min_amp, nspikes,
                                      data_folder=data_folder)

    # train detector
    NeuralNetDetector.train(x_detect, y_detect, n_filters_detect, n_iter,
                            n_batch, l2_reg_scale, train_step_size,
                            detectnet_name)

    # train triage
    NeuralNetTriage.train(x_triage, y_triage, n_filters_triage, n_iter,
                          n_batch, l2_reg_scale, train_step_size,
                          triagenet_name)

    # train autoencoder
    AutoEncoder.train(x_ae, y_ae, n_features, n_iter, n_batch,
                      train_step_size, ae_name)
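A minimal usage sketch for `train`, assuming a YAML training config with the keys read above; the file names and paths are illustrative, not part of the documented API:

import logging

import numpy as np
import yaml

import yass

logging.basicConfig(level=logging.INFO)

# hypothetical paths; any files with the right layout work
yass.set_config('config_nnet.yaml')
CONFIG = yass.read_config()

with open('config_train.yaml') as f:
    # expected keys (inferred from the lookups in train()):
    #   templates:           ids, minimum_amplitude
    #   training:            n_spikes, n_iterations, n_batch,
    #                        l2_regularization_scale, step_size
    #   network_detector:    name, n_filters
    #   network_triage:      name, n_filters
    #   network_autoencoder: name, n_features
    CONFIG_TRAIN = yaml.safe_load(f)

# spike train with one row per spike: (spike index, main channel)
spike_train = np.load('spike_train.npy')

train(CONFIG, CONFIG_TRAIN, spike_train, data_folder='data/')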
def test_can_use_neural_network_detector(path_to_tests,
                                         path_to_standarized_data):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    data = RecordingsReader(path_to_standarized_data, loader='array').data

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 channel_index)
    NNT = NeuralNetTriage.load(triage_fname, triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    with tf.Session() as sess:
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        neuralnetwork.run_detect_triage_featurize(data, sess, NND.x_tf,
                                                  output_tf, neighbors, rot)
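The test above discards the featurize outputs; a fragment (belonging inside the session) showing how they are unpacked, with shapes following the docstrings elsewhere in this section:

    # fragment, inside the session above: unpack the three outputs
    scores, clear, collision = neuralnetwork.run_detect_triage_featurize(
        data, sess, NND.x_tf, output_tf, neighbors, rot)

    # scores: (n_clear_spikes, n_features, n_channels) per-spike features
    # clear, collision: (n_spikes, 2) arrays of (spike index, main channel)
    print(scores.shape, clear.shape, collision.shape)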
def test_can_train_triage(path_to_tests, path_to_sample_pipeline_folder,
                          tmp_folder):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    # spike_train, chosen_templates, min_amplitude, n_spikes and filters
    # are module-level test parameters defined elsewhere in this file
    (x_detect, y_detect,
     x_triage, y_triage,
     x_ae, y_ae) = make_training_data(CONFIG, spike_train, chosen_templates,
                                      min_amplitude, n_spikes,
                                      path_to_sample_pipeline_folder)

    _, waveform_length, n_neighbors = x_triage.shape

    path_to_model = path.join(tmp_folder, 'triage-net.ckpt')

    triage = NeuralNetTriage(path_to_model, filters, waveform_length,
                             n_neighbors, threshold=0.5, n_iter=10)

    # train the triage network on the triage examples
    triage.fit(x_triage, y_triage)
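A small follow-on sketch: reloading the checkpoint written by `fit`. The `load` signature mirrors the reload test below; treating `predict` as returning one clean-vs-collision decision per example is an assumption:

# reload the trained triage net from the checkpoint written by fit()
triage = NeuralNetTriage.load(path_to_model, threshold=0.5)

# evaluate on a held-out slice of the triage examples (assumed semantics:
# one clean-vs-collision decision per waveform)
decisions = triage.predict(x_triage[:100])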
def test_can_use_detect_and_triage_after_reload(path_to_tests,
                                                path_to_sample_pipeline_folder,
                                                tmp_folder,
                                                path_to_standarized_data):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    (x_detect, y_detect,
     x_triage, y_triage,
     x_ae, y_ae) = make_training_data(CONFIG, spike_train, chosen_templates,
                                      min_amplitude, n_spikes,
                                      path_to_sample_pipeline_folder)

    _, waveform_length, n_neighbors = x_detect.shape

    path_to_detector_model = path.join(tmp_folder, 'detect-net.ckpt')
    # give the triage net its own checkpoint so it does not overwrite
    # the detector weights
    path_to_triage_model = path.join(tmp_folder, 'triage-net.ckpt')

    detector = NeuralNetDetector(path_to_detector_model, filters,
                                 waveform_length, n_neighbors,
                                 threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=10)
    detector.fit(x_detect, y_detect)

    detector = NeuralNetDetector.load(path_to_detector_model, threshold=0.5,
                                      channel_index=CONFIG.channel_index)

    triage = NeuralNetTriage(path_to_triage_model, filters, waveform_length,
                             n_neighbors, threshold=0.5, n_iter=10)
    triage.fit(x_detect, y_detect)

    triage = NeuralNetTriage.load(path_to_triage_model, threshold=0.5)

    data = RecordingExplorer(path_to_standarized_data).reader.data

    output_names = ('spike_index', 'waveform', 'probability')

    (spike_index, waveform,
     proba) = detector.predict(data, output_names=output_names)

    triage.predict(waveform[:, :, :n_neighbors])
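A fragment showing how the triage prediction is typically consumed, mirroring the thresholding logic in `nn_detection` further below (assuming `predict` yields a boolean mask with one entry per detected spike):

    # keep only the spikes the triage net marks as clean (assumed boolean mask)
    is_clear = triage.predict(waveform[:, :, :n_neighbors])

    spike_index_clear = spike_index[is_clear]
    spike_index_collision = spike_index[~is_clear]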
def run_neural_network(standarized_path, standarized_params, whiten_filter,
                       output_directory, if_file_exists, save_results):
    """Run neural network detection and autoencoder dimensionality reduction

    Returns
    -------
    scores
        Scores for all spikes

    spike_index_clear
        Spike indexes for clear spikes

    spike_index_all
        Spike indexes for all spikes
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    folder = Path(CONFIG.data.root_folder, output_directory, 'detect')
    folder.mkdir(exist_ok=True)

    TMP_FOLDER = str(folder)

    # check if all scores, clear and collision spikes exist
    path_to_score = os.path.join(TMP_FOLDER, 'scores_clear.npy')
    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')
    path_to_spike_index_all = os.path.join(TMP_FOLDER, 'spike_index_all.npy')
    path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')

    paths = [path_to_score, path_to_spike_index_clear,
             path_to_spike_index_all]
    exists = [os.path.exists(p) for p in paths]

    if (if_file_exists == 'overwrite'
            or (if_file_exists == 'abort' and not any(exists))
            or (if_file_exists == 'skip' and not all(exists))):
        max_memory = (CONFIG.resources.max_memory_gpu if GPU_ENABLED else
                      CONFIG.resources.max_memory)

        # instantiate batch processor
        bp = BatchProcessor(standarized_path, standarized_params['dtype'],
                            standarized_params['n_channels'],
                            standarized_params['data_order'], max_memory,
                            buffer_size=CONFIG.spike_size)

        # load parameters
        detection_th = CONFIG.detect.neural_network_detector.threshold_spike
        triage_th = CONFIG.detect.neural_network_triage.threshold_collision
        detection_fname = CONFIG.detect.neural_network_detector.filename
        ae_fname = CONFIG.detect.neural_network_autoencoder.filename
        triage_fname = CONFIG.detect.neural_network_triage.filename

        # instantiate neural networks
        NND = NeuralNetDetector.load(detection_fname, detection_th,
                                     CONFIG.channel_index)
        NNT = NeuralNetTriage.load(triage_fname, triage_th,
                                   input_tensor=NND.waveform_tf)
        NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)
        rotation = NNAE.load_rotation()

        # gather all output tensors
        output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

        # run detection
        with tf.Session() as sess:
            # get values of above tensors
            NND.restore(sess)
            NNAE.restore(sess)
            NNT.restore(sess)

            mc = bp.multi_channel_apply
            res = mc(neuralnetwork.run_detect_triage_featurize,
                     mode='memory',
                     cleanup_function=neuralnetwork.fix_indexes,
                     sess=sess,
                     x_tf=NND.x_tf,
                     output_tf=output_tf,
                     rot=rotation,
                     neighbors=neighbors)

        # get clear spikes
        clear = np.concatenate([element[1] for element in res], axis=0)
        logger.info('Removing clear indexes outside the allowed range to '
                    'draw a complete waveform...')
        clear, idx = detect.remove_incomplete_waveforms(
            clear, CONFIG.spike_size + CONFIG.templates.max_shift,
            bp.reader._n_observations)

        # get all spikes
        spikes_all = np.concatenate([element[2] for element in res], axis=0)
        logger.info('Removing indexes outside the allowed range to '
                    'draw a complete waveform...')
        spikes_all, _ = detect.remove_incomplete_waveforms(
            spikes_all, CONFIG.spike_size + CONFIG.templates.max_shift,
            bp.reader._n_observations)

        # get scores
        scores = np.concatenate([element[0] for element in res], axis=0)
        logger.info('Removing scores for indexes outside the allowed range '
                    'to draw a complete waveform...')
        scores = scores[idx]

        # transform scores to location + shape feature space
        # TODO: move this to another place
        if CONFIG.cluster.method == 'location':
            threshold = 2
            scores = get_locations_features(scores, rotation, clear[:, 1],
                                            CONFIG.channel_index,
                                            CONFIG.geom, threshold)
            idx_nan = np.where(np.isnan(np.sum(scores, axis=(1, 2))))[0]
            scores = np.delete(scores, idx_nan, 0)
            clear = np.delete(clear, idx_nan, 0)

        # save partial results if required
        if save_results:
            # save clear spikes
            np.save(path_to_spike_index_clear, clear)
            logger.info('Saved spike index clear in {}...'.format(
                path_to_spike_index_clear))

            # save all spikes
            np.save(path_to_spike_index_all, spikes_all)
            logger.info('Saved spike index all in {}...'.format(
                path_to_spike_index_all))

            # save rotation
            np.save(path_to_rotation, rotation)
            logger.info(
                'Saved rotation matrix in {}...'.format(path_to_rotation))

            # save scores
            np.save(path_to_score, scores)
            logger.info('Saved spike scores in {}...'.format(path_to_score))

    elif if_file_exists == 'abort' and any(exists):
        conflict = [p for p, e in zip(paths, exists) if e]
        message = ', '.join(conflict)
        raise ValueError('if_file_exists was set to abort, the '
                         'program halted since the following files '
                         'already exist: {}'.format(message))
    elif if_file_exists == 'skip' and all(exists):
        logger.warning('Skipped execution. All output files exist, '
                       'loading them...')
        scores = np.load(path_to_score)
        clear = np.load(path_to_spike_index_clear)
        spikes_all = np.load(path_to_spike_index_all)
    else:
        raise ValueError('Invalid value for if_file_exists {}, '
                         'must be one of overwrite, abort or skip'
                         .format(if_file_exists))

    return scores, clear, spikes_all
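A usage sketch for `run_neural_network`; the paths and parameter values are illustrative, and the params dict mirrors the `standarized.yaml` file loaded in the batch test at the end of this section:

import yass

yass.set_config('config_nnet.yaml')    # hypothetical config path

# same keys as the standarized.yaml written by the preprocess step
standarized_params = {'dtype': 'float32',
                      'n_channels': 49,
                      'data_order': 'samples'}

scores, clear, spikes_all = run_neural_network(
    standarized_path='tmp/preprocess/standarized.bin',
    standarized_params=standarized_params,
    whiten_filter=None,      # accepted but not referenced in the body above
    output_directory='tmp',
    if_file_exists='overwrite',
    save_results=True)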
def prepare_nn(channel_index, whiten_filter, threshold_detect,
               threshold_triage, detector_filename, autoencoder_filename,
               triage_filename):
    """Prepare neural net tensors in advance. This is to efficiently run
    the neural nets with the batch processor, since we do not have to
    recreate the tf tensors on every batch

    Parameters
    ----------
    channel_index: np.array (n_channels, n_neigh)
        Each row indexes its neighboring channels. For example,
        channel_index[c] is the index of the neighboring channels of
        channel c (including itself). If any value is equal to n_channels,
        it is nothing but a placeholder for the case where a channel has
        fewer than n_neigh neighboring channels

    whiten_filter: numpy.ndarray (n_channels, n_neigh, n_neigh)
        Whitening matrix such that whiten_filter[c] is the whitening
        filter of channel c and its neighboring channels determined
        from channel_index

    threshold_detect: int
        Threshold for neural net detection

    threshold_triage: int
        Threshold for neural net triage

    detector_filename: str
        Location of the trained neural net detector

    autoencoder_filename: str
        Location of the trained neural net autoencoder

    triage_filename: str
        Location of the trained neural net triage

    Returns
    -------
    x_tf: tf.tensor (n_observations, n_channels)
        Placeholder for the recording used to run tensorflow

    output_tf: tuple of tf.tensors
        A tuple of tensorflow tensors that produce score,
        spike_index_clear, and spike_index_collision

    NND: NeuralNetDetector
        An instance of the NeuralNetDetector class

    NNAE: AutoEncoder
        An instance of the AutoEncoder class

    NNT: NeuralNetTriage
        An instance of the NeuralNetTriage class
    """
    # placeholder for the input recording
    x_tf = tf.placeholder("float", [None, None])

    # load neural nets
    NND = NeuralNetDetector(detector_filename)
    NNAE = AutoEncoder(autoencoder_filename)
    NNT = NeuralNetTriage(triage_filename)

    # make spike_index tensorflow tensor
    spike_index_tf_all = NND.make_detection_tf_tensors(x_tf, channel_index,
                                                       threshold_detect)

    # remove spikes at the edges of the recording
    spike_index_tf = remove_edge_spikes(x_tf, spike_index_tf_all,
                                        NND.filters_dict['size'])

    # make waveform tensorflow tensor
    waveform_tf = make_waveform_tf_tensor(x_tf, spike_index_tf,
                                          channel_index,
                                          NND.filters_dict['size'])

    # make score tensorflow tensor from the waveform
    score_tf = NNAE.make_score_tf_tensor(waveform_tf)

    # run neural net triage
    nneigh = NND.filters_dict['n_neighbors']
    idx_clean = NNT.triage_wf(waveform_tf[:, :, :nneigh], threshold_triage)

    # gather all output tensors
    output_tf = (score_tf, spike_index_tf, idx_clean)

    return x_tf, output_tf, NND, NNAE, NNT
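A sketch of wiring `prepare_nn` into a session. The config fields match the ones read by the other snippets; the restore idiom and `recording_batch` (a hypothetical (n_observations, n_channels) array) are assumptions:

x_tf, output_tf, NND, NNAE, NNT = prepare_nn(
    channel_index,
    whiten_filter,
    CONFIG.detect.neural_network_detector.threshold_spike,
    CONFIG.detect.neural_network_triage.threshold_collision,
    CONFIG.detect.neural_network_detector.filename,
    CONFIG.detect.neural_network_autoencoder.filename,
    CONFIG.detect.neural_network_triage.filename)

with tf.Session() as sess:
    # restore trained weights (assumed .restore(sess) API, as in the
    # load-based snippets above; the saver-based snippets restore via
    # NND.saver.restore(sess, <checkpoint path>) instead)
    NND.restore(sess)
    NNAE.restore(sess)
    NNT.restore(sess)

    # evaluate the prepared graph on one batch of recordings
    score, spike_index_clear, idx_clean = sess.run(
        output_tf, feed_dict={x_tf: recording_batch})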
def nn_detection(recordings, neighbors, geom, temporal_features,
                 temporal_window, th_detect, th_triage,
                 detector_filename, autoencoder_filename, triage_filename):
    """Detect spikes using a neural network

    Parameters
    ----------
    recordings: numpy.ndarray (n_observations, n_channels)
        Neural recordings

    neighbors: numpy.ndarray (n_channels, n_channels)
        Channel neighbors matrix; entry (i, j) is nonzero when channels
        i and j are neighbors

    geom: numpy.ndarray (n_channels, 2)
        Cartesian coordinates for the channels

    temporal_features: int
        Number of features per waveform produced by the autoencoder
        (third dimension of the score tensor)

    temporal_window: int
        Size (in observations) of the temporal window used for
        local-maxima detection and duplicate removal

    th_detect: float
        Detection threshold applied to the detector network output

    th_triage: float
        Triage threshold; detected spikes with triage probability above
        it are kept as clear

    detector_filename: str
        Path to neural network detector

    autoencoder_filename: str
        Path to neural network autoencoder

    triage_filename: str
        Path to triage neural network

    Returns
    -------
    score: numpy.ndarray (n_clear_spikes, n_features, n_channels)
        3D array with the scores for the clear spikes, first dimension is
        the number of spikes, second is the number of features and third
        the number of channels

    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        2D array with indexes for clear spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        2D array with indexes for collided spikes, first column contains
        the spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)
    """
    nnd = NeuralNetDetector(detector_filename, autoencoder_filename)
    nnt = NeuralNetTriage(triage_filename)

    T, C = recordings.shape

    a, b = neighbors.shape

    if a != b:
        raise ValueError('neighbors is not a square matrix, verify')

    if a != C:
        raise ValueError(
            'Number of channels in recording are {} but the '
            'neighbors matrix has {} elements, they must match'.format(C, a))

    # neighboring channel info
    nneigh = np.max(np.sum(neighbors, 0))
    c_idx = np.ones((C, nneigh), 'int32') * C

    for c in range(C):
        ch_idx, temp = order_channels_by_distance(c,
                                                  np.where(neighbors[c])[0],
                                                  geom)
        c_idx[c, :ch_idx.shape[0]] = ch_idx

    # input placeholder
    x_tf = tf.placeholder("float", [T, C])

    # detect spike index
    local_max_idx_tf = nnd.get_spikes(x_tf, T, nneigh, c_idx,
                                      temporal_window, th_detect)

    # get score train
    score_train_tf = nnd.get_score_train(x_tf)

    # get energy for detected index
    energy_tf = tf.reduce_sum(tf.square(score_train_tf), axis=2)
    energy_val_tf = tf.gather_nd(energy_tf, local_max_idx_tf)

    # get triage probability
    triage_prob_tf = nnt.triage_prob(x_tf, T, nneigh, c_idx)

    # gather all results above
    result = (local_max_idx_tf, score_train_tf, energy_val_tf,
              triage_prob_tf)

    # remove duplicates
    energy_train_tf = tf.placeholder("float", [T, C])
    spike_index_tf = remove_duplicate_spikes_by_energy(energy_train_tf, T,
                                                       c_idx,
                                                       temporal_window)

    # get score
    score_train_placeholder = tf.placeholder("float",
                                             [T, C, temporal_features])
    spike_index_clear_tf = tf.placeholder("int64", [None, 2])
    score_tf = get_score(score_train_placeholder, spike_index_clear_tf,
                         T, temporal_features, c_idx)

    ###############################
    # get values of above tensors #
    ###############################

    with tf.Session() as sess:
        nnd.saver.restore(sess, nnd.path_to_detector_model)
        nnd.saver_ae.restore(sess, nnd.path_to_ae_model)
        nnt.saver.restore(sess, nnt.path_to_triage_model)

        local_max_idx, score_train, energy_val, triage_prob = sess.run(
            result, feed_dict={x_tf: recordings})

        energy_train = np.zeros((T, C))
        energy_train[local_max_idx[:, 0], local_max_idx[:, 1]] = energy_val
        spike_index = sess.run(spike_index_tf,
                               feed_dict={energy_train_tf: energy_train})

        idx_clean = (triage_prob[spike_index[:, 0], spike_index[:, 1]]
                     > th_triage)

        spike_index_clear = spike_index[idx_clean]
        spike_index_collision = spike_index[~idx_clean]

        score = sess.run(score_tf,
                         feed_dict={score_train_placeholder: score_train,
                                    spike_index_clear_tf: spike_index_clear})

    return score, spike_index_clear, spike_index_collision
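A usage sketch for `nn_detection`, assuming a boolean neighbors matrix built from the channel geometry; the distance radius, feature/window sizes, input paths and checkpoint names are illustrative assumptions, not values from this codebase:

import numpy as np
from scipy.spatial.distance import cdist

# recordings: (n_observations, n_channels), geom: (n_channels, 2)
recordings = np.load('standarized.npy')      # hypothetical input
geom = np.loadtxt('geometry.txt')

# channels within 70 microns of each other are neighbors (illustrative radius)
neighbors = cdist(geom, geom) <= 70

score, spike_index_clear, spike_index_collision = nn_detection(
    recordings, neighbors, geom,
    temporal_features=3,       # autoencoder feature dimension (assumed)
    temporal_window=7,         # local-max window in observations (assumed)
    th_detect=0.5,
    th_triage=0.5,
    detector_filename='detect-net.ckpt',
    autoencoder_filename='ae-net.ckpt',
    triage_filename='triage-net.ckpt')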
def test_splitting_in_batches_does_not_affect(path_to_tests,
                                              path_to_standarized_data,
                                              path_to_sample_pipeline_folder):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standarized_data

    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    with open(path.join(path_to_sample_pipeline_folder, 'preprocess',
                        'standarized.yaml')) as f:
        PARAMS = yaml.safe_load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 channel_index)
    NNT = NeuralNetTriage.load(triage_fname, triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    # run all at once
    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        (scores, clear,
         collision) = neuralnetwork.run_detect_triage_featurize(
            data, sess, NND.x_tf, output_tf, neighbors, rot)

    # run in batches - buffer size makes sure we can detect spikes if they
    # appear at the end of any batch
    bp = BatchProcessor(PATH_TO_DATA, PARAMS['dtype'], PARAMS['n_channels'],
                        PARAMS['data_order'], '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=NND.x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res], axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)