def test_can_use_neural_network_detector(path_to_tests,
                                         path_to_standarized_data):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    data = RecordingsReader(path_to_standarized_data, loader='array').data

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 channel_index)
    NNT = NeuralNetTriage.load(triage_fname, triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    with tf.Session() as sess:
        # restore weights for the three networks
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        neuralnetwork.run_detect_triage_featurize(data, sess, NND.x_tf,
                                                  output_tf, neighbors, rot)
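# The tests in this section assume pytest fixtures such as `path_to_tests`
# and `path_to_standarized_data`. A minimal sketch of how such fixtures
# might be defined in a conftest.py; the file layout below is a hypothetical
# illustration of the expected contract (both fixtures return filesystem
# paths), not the project's actual layout.
import os

import pytest


@pytest.fixture
def path_to_tests():
    # directory that holds config_nnet.yaml and the other test assets
    return os.path.dirname(os.path.abspath(__file__))


@pytest.fixture
def path_to_standarized_data(path_to_tests):
    # standarized binary recording produced by the preprocess step
    # (hypothetical location)
    return os.path.join(path_to_tests, 'data', 'standarized.bin')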
def test_can_use_detect_and_triage_after_reload(path_to_tests,
                                                path_to_sample_pipeline_folder,
                                                tmp_folder,
                                                path_to_standarized_data):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    # spike_train, chosen_templates, min_amplitude, n_spikes and filters
    # are module-level parameters defined elsewhere in the test suite
    (x_detect, y_detect,
     x_triage, y_triage,
     x_ae, y_ae) = make_training_data(CONFIG, spike_train, chosen_templates,
                                      min_amplitude, n_spikes,
                                      path_to_sample_pipeline_folder)

    _, waveform_length, n_neighbors = x_detect.shape

    path_to_model = path.join(tmp_folder, 'detect-net.ckpt')

    detector = NeuralNetDetector(path_to_model, filters,
                                 waveform_length, n_neighbors,
                                 threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=10)
    detector.fit(x_detect, y_detect)

    detector = NeuralNetDetector.load(path_to_model, threshold=0.5,
                                      channel_index=CONFIG.channel_index)

    # train the triage network on the triage training set and give it its
    # own checkpoint so it does not overwrite the detector's
    # (x_triage is assumed to share waveform_length and n_neighbors
    # with x_detect)
    path_to_triage_model = path.join(tmp_folder, 'triage-net.ckpt')

    triage = NeuralNetTriage(path_to_triage_model, filters,
                             waveform_length, n_neighbors,
                             threshold=0.5,
                             n_iter=10)
    triage.fit(x_triage, y_triage)

    triage = NeuralNetTriage.load(path_to_triage_model, threshold=0.5)

    data = RecordingExplorer(path_to_standarized_data).reader.data

    output_names = ('spike_index', 'waveform', 'probability')

    (spike_index, waveform,
     proba) = detector.predict(data, output_names=output_names)

    triage.predict(waveform[:, :, :n_neighbors])
def test_can_reload_triage(path_to_tests, path_to_sample_pipeline_folder,
                           tmp_folder):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    (x_detect, y_detect,
     x_triage, y_triage,
     x_ae, y_ae) = make_training_data(CONFIG, spike_train, chosen_templates,
                                      min_amplitude, n_spikes,
                                      path_to_sample_pipeline_folder)

    _, waveform_length, n_neighbors = x_triage.shape

    path_to_model = path.join(tmp_folder, 'triage-net.ckpt')

    triage = NeuralNetTriage(path_to_model, filters,
                             waveform_length, n_neighbors,
                             threshold=0.5,
                             n_iter=10)
    # fit on the triage training set (the network dimensions above come
    # from x_triage)
    triage.fit(x_triage, y_triage)

    NeuralNetTriage.load(path_to_model, threshold=0.5)
def run_neural_network(standarized_path, standarized_params, whiten_filter,
                       output_directory, if_file_exists, save_results):
    """Run neural network detection and autoencoder dimensionality reduction

    Returns
    -------
    scores
        Scores for clear spikes

    spike_index_clear
        Spike indexes for clear spikes

    spike_index_all
        Spike indexes for all spikes
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    folder = Path(CONFIG.data.root_folder, output_directory, 'detect')
    folder.mkdir(exist_ok=True)
    TMP_FOLDER = str(folder)

    # check if all scores, clear and collision spikes exist
    path_to_score = os.path.join(TMP_FOLDER, 'scores_clear.npy')
    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')
    path_to_spike_index_all = os.path.join(TMP_FOLDER, 'spike_index_all.npy')
    path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')

    paths = [path_to_score, path_to_spike_index_clear, path_to_spike_index_all]
    exists = [os.path.exists(p) for p in paths]

    if (if_file_exists == 'overwrite'
            or (if_file_exists == 'abort' and not any(exists))
            or (if_file_exists == 'skip' and not all(exists))):

        max_memory = (CONFIG.resources.max_memory_gpu if GPU_ENABLED else
                      CONFIG.resources.max_memory)

        # instantiate batch processor
        bp = BatchProcessor(standarized_path,
                            standarized_params['dtype'],
                            standarized_params['n_channels'],
                            standarized_params['data_order'],
                            max_memory,
                            buffer_size=CONFIG.spike_size)

        # load parameters
        detection_th = CONFIG.detect.neural_network_detector.threshold_spike
        triage_th = CONFIG.detect.neural_network_triage.threshold_collision
        detection_fname = CONFIG.detect.neural_network_detector.filename
        ae_fname = CONFIG.detect.neural_network_autoencoder.filename
        triage_fname = CONFIG.detect.neural_network_triage.filename

        # instantiate neural networks
        NND = NeuralNetDetector.load(detection_fname, detection_th,
                                     CONFIG.channel_index)
        NNT = NeuralNetTriage.load(triage_fname, triage_th,
                                   input_tensor=NND.waveform_tf)
        NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)
        rotation = NNAE.load_rotation()

        # gather all output tensors
        output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

        # run detection
        with tf.Session() as sess:
            # restore weights for the three networks
            NND.restore(sess)
            NNAE.restore(sess)
            NNT.restore(sess)

            mc = bp.multi_channel_apply
            res = mc(neuralnetwork.run_detect_triage_featurize,
                     mode='memory',
                     cleanup_function=neuralnetwork.fix_indexes,
                     sess=sess,
                     x_tf=NND.x_tf,
                     output_tf=output_tf,
                     rot=rotation,
                     neighbors=neighbors)

        # get clear spikes
        clear = np.concatenate([element[1] for element in res], axis=0)
        logger.info('Removing clear indexes outside the allowed range to '
                    'draw a complete waveform...')
        clear, idx = detect.remove_incomplete_waveforms(
            clear, CONFIG.spike_size + CONFIG.templates.max_shift,
            bp.reader._n_observations)

        # get all spikes
        spikes_all = np.concatenate([element[2] for element in res], axis=0)
        logger.info('Removing indexes outside the allowed range to '
                    'draw a complete waveform...')
        spikes_all, _ = detect.remove_incomplete_waveforms(
            spikes_all, CONFIG.spike_size + CONFIG.templates.max_shift,
            bp.reader._n_observations)

        # get scores
        scores = np.concatenate([element[0] for element in res], axis=0)
        logger.info('Removing scores for indexes outside the allowed range '
                    'to draw a complete waveform...')
        scores = scores[idx]

        # transform scores to location + shape feature space
        # TODO: move this to another place
        if CONFIG.cluster.method == 'location':
            threshold = 2
            scores = get_locations_features(scores, rotation, clear[:, 1],
                                            CONFIG.channel_index,
                                            CONFIG.geom, threshold)
            idx_nan = np.where(np.isnan(np.sum(scores, axis=(1, 2))))[0]
            scores = np.delete(scores, idx_nan, 0)
            clear = np.delete(clear, idx_nan, 0)

        # save partial results if required
        if save_results:
            # save clear spikes
            np.save(path_to_spike_index_clear, clear)
            logger.info('Saved spike index clear in {}...'
                        .format(path_to_spike_index_clear))

            # save all spikes
            np.save(path_to_spike_index_all, spikes_all)
            logger.info('Saved spike index all in {}...'
                        .format(path_to_spike_index_all))

            # save rotation
            np.save(path_to_rotation, rotation)
            logger.info('Saved rotation matrix in {}...'
                        .format(path_to_rotation))

            # save scores
            np.save(path_to_score, scores)
            logger.info('Saved spike scores in {}...'.format(path_to_score))

    elif if_file_exists == 'abort' and any(exists):
        conflict = [p for p, e in zip(paths, exists) if e]
        message = reduce(lambda x, y: str(x) + ', ' + str(y), conflict)
        raise ValueError('if_file_exists was set to abort, the '
                         'program halted since the following files '
                         'already exist: {}'.format(message))
    elif if_file_exists == 'skip' and all(exists):
        logger.warning('Skipped execution. All output files exist, '
                       'loading them...')
        scores = np.load(path_to_score)
        clear = np.load(path_to_spike_index_clear)
        spikes_all = np.load(path_to_spike_index_all)
    else:
        raise ValueError('Invalid value for if_file_exists {}, '
                         'must be one of overwrite, abort or skip'
                         .format(if_file_exists))

    return scores, clear, spikes_all
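# The branching above implements three `if_file_exists` policies:
# 'overwrite' always recomputes, 'abort' raises if any output file already
# exists, and 'skip' loads existing outputs only when all of them are
# present (otherwise it recomputes). A minimal standalone sketch of that
# dispatch; `resolve_policy` is an illustrative helper, not part of yass.
def resolve_policy(if_file_exists, exists):
    """Return 'run' or 'load', mirroring the branching above.

    `exists` is a list of booleans, one per expected output file.
    """
    if (if_file_exists == 'overwrite'
            or (if_file_exists == 'abort' and not any(exists))
            or (if_file_exists == 'skip' and not all(exists))):
        return 'run'
    elif if_file_exists == 'abort' and any(exists):
        raise ValueError('some output files already exist')
    elif if_file_exists == 'skip' and all(exists):
        return 'load'
    else:
        raise ValueError('must be one of overwrite, abort or skip')


# e.g. resolve_policy('skip', [True, False, True]) == 'run' (recompute,
# since not every output exists), while
# resolve_policy('skip', [True, True, True]) == 'load'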
def test_splitting_in_batches_does_not_affect(path_to_tests,
                                              path_to_standarized_data,
                                              path_to_sample_pipeline_folder):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standarized_data

    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    with open(path.join(path_to_sample_pipeline_folder, 'preprocess',
                        'standarized.yaml')) as f:
        PARAMS = yaml.load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 channel_index)
    NNT = NeuralNetTriage.load(triage_fname, triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    # run all at once
    with tf.Session() as sess:
        # restore network weights
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        (scores, clear,
         collision) = neuralnetwork.run_detect_triage_featurize(
            data, sess, NND.x_tf, output_tf, neighbors, rot)

    # run in batches - buffer size makes sure we can detect spikes if they
    # appear at the end of any batch
    bp = BatchProcessor(PATH_TO_DATA,
                        PARAMS['dtype'], PARAMS['n_channels'],
                        PARAMS['data_order'], '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        # restore network weights
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=NND.x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res], axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)
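# The equality asserted above holds because each batch carries a buffer of
# CONFIG.spike_size observations on either side, so waveforms near batch
# boundaries are still complete, and the `fix_indexes` cleanup shifts
# per-batch spike indexes back into the global time frame. A self-contained
# numpy sketch of the same idea, using a toy local-maximum detector (not the
# yass implementation): without the buffer, peaks at batch edges would lose
# the context they need; with it, batched and single-pass results agree.
import numpy as np


def detect_peaks(x, th=2.5):
    # local maxima above the threshold; needs one sample of context per side
    idx = np.where((x[1:-1] > th)
                   & (x[1:-1] > x[:-2])
                   & (x[1:-1] > x[2:]))[0] + 1
    return idx


def detect_in_batches(x, batch_size, buffer_size, th=2.5):
    found = []
    for start in range(0, len(x), batch_size):
        # pad the batch with buffer samples on each side
        lo = max(0, start - buffer_size)
        hi = min(len(x), start + batch_size + buffer_size)
        # detect in the padded window, then shift to global indexes
        # (this is the role fix_indexes plays above)
        idx = detect_peaks(x[lo:hi], th) + lo
        # keep only detections that belong to this batch, not its buffer
        idx = idx[(idx >= start) & (idx < start + batch_size)]
        found.append(idx)
    return np.concatenate(found)


x = np.random.RandomState(0).normal(size=1000)
np.testing.assert_array_equal(detect_peaks(x),
                              detect_in_batches(x, batch_size=100,
                                                buffer_size=10))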