def test_can_use_neural_network_detector(path_to_tests,
                                         path_to_standarized_data):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    data = RecordingsReader(path_to_standarized_data, loader='array').data

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 channel_index)
    NNT = NeuralNetTriage.load(triage_fname, triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    with tf.Session() as sess:
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        neuralnetwork.run_detect_triage_featurize(data, sess, NND.x_tf,
                                                  output_tf, neighbors, rot)
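# Note (grounded in the batching test at the end of this section):
# run_detect_triage_featurize returns a (scores, clear, collision) triple;
# the smoke test above only checks that the call runs and discards the result.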
def train(CONFIG, CONFIG_TRAIN, spike_train, data_folder):
    """Train detector, triage and autoencoder neural networks

    Parameters
    ----------
    CONFIG
        YASS configuration file
    CONFIG_TRAIN
        YASS Neural Network configuration file
    spike_train: numpy.ndarray
        Spike train, first column is spike index and second is main channel
    """
    logger = logging.getLogger(__name__)

    chosen_templates = CONFIG_TRAIN['templates']['ids']
    min_amp = CONFIG_TRAIN['templates']['minimum_amplitude']
    nspikes = CONFIG_TRAIN['training']['n_spikes']

    n_filters_detect = CONFIG_TRAIN['network_detector']['n_filters']
    n_iter = CONFIG_TRAIN['training']['n_iterations']
    n_batch = CONFIG_TRAIN['training']['n_batch']
    l2_reg_scale = CONFIG_TRAIN['training']['l2_regularization_scale']
    train_step_size = CONFIG_TRAIN['training']['step_size']
    detectnet_name = './' + CONFIG_TRAIN['network_detector']['name'] + '.ckpt'

    n_filters_triage = CONFIG_TRAIN['network_triage']['n_filters']
    triagenet_name = './' + CONFIG_TRAIN['network_triage']['name'] + '.ckpt'

    n_features = CONFIG_TRAIN['network_autoencoder']['n_features']
    ae_name = './' + CONFIG_TRAIN['network_autoencoder']['name'] + '.ckpt'

    # generate training data for detection, triage and autoencoder
    logger.info('Generating training data...')
    (x_detect, y_detect,
     x_triage, y_triage,
     x_ae, y_ae) = make_training_data(CONFIG, spike_train, chosen_templates,
                                      min_amp, nspikes,
                                      data_folder=data_folder)

    # train detector
    NeuralNetDetector.train(x_detect, y_detect, n_filters_detect, n_iter,
                            n_batch, l2_reg_scale, train_step_size,
                            detectnet_name)

    # train triage
    NeuralNetTriage.train(x_triage, y_triage, n_filters_triage, n_iter,
                          n_batch, l2_reg_scale, train_step_size,
                          triagenet_name)

    # train autoencoder
    AutoEncoder.train(x_ae, y_ae, n_features, n_iter, n_batch,
                      train_step_size, ae_name)
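# A minimal sketch (not from the source) of a CONFIG_TRAIN dictionary with
# the keys train() reads above; all values are illustrative placeholders.
EXAMPLE_CONFIG_TRAIN = {
    'templates': {'ids': [0, 1, 2], 'minimum_amplitude': 4},
    'training': {'n_spikes': 10000, 'n_iterations': 5000, 'n_batch': 512,
                 'l2_regularization_scale': 5e-8, 'step_size': 1e-3},
    'network_detector': {'name': 'detect_nn', 'n_filters': [16, 8, 8]},
    'network_triage': {'name': 'triage_nn', 'n_filters': [16, 8]},
    'network_autoencoder': {'name': 'ae_nn', 'n_features': 3},
}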
def run_neural_network(standardized_path, standardized_params, whiten_filter,
                       output_directory, if_file_exists, save_results):
    """Run neural network detection and autoencoder dimensionality reduction

    Returns
    -------
    spike_index_all
        Spike indexes for all detected spikes
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # check if all scores, clear and collision spikes exist
    path_to_score = os.path.join(TMP_FOLDER, 'scores_clear.npy')
    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')
    path_to_spike_index_all = os.path.join(TMP_FOLDER, 'spike_index_all.npy')
    path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')
    path_to_standardized = os.path.join(TMP_FOLDER, 'preprocess',
                                        'standarized.bin')

    paths = [path_to_score, path_to_spike_index_clear,
             path_to_spike_index_all]
    exists = [os.path.exists(p) for p in paths]

    if if_file_exists == 'overwrite' or not any(exists):
        max_memory = (CONFIG.resources.max_memory_gpu if GPU_ENABLED
                      else CONFIG.resources.max_memory)

        # make tensorflow tensors and neural net classes
        detection_th = CONFIG.detect.neural_network_detector.threshold_spike
        triage_th = CONFIG.detect.neural_network_triage.threshold_collision
        detection_fname = CONFIG.detect.neural_network_detector.filename
        ae_fname = CONFIG.detect.neural_network_autoencoder.filename
        triage_fname = CONFIG.detect.neural_network_triage.filename
        n_channels = CONFIG.recordings.n_channels

        # open tensorflow for every chunk
        NND = NeuralNetDetector.load(detection_fname, detection_th,
                                     CONFIG.channel_index)
        NNAE = AutoEncoder.load(ae_fname, input_tensor=NND.waveform_tf)

        # run nn preprocess batch-wise
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        # compute length of recording
        filename_dat = os.path.join(CONFIG.data.root_folder,
                                    CONFIG.data.recordings)
        fp = np.memmap(filename_dat, dtype='int16', mode='r')
        fp_len = fp.shape[0] / n_channels

        # compute batch indexes
        buffer_size = 200  # Cat: to set this in CONFIG file
        sampling_rate = CONFIG.recordings.sampling_rate
        #n_sec_chunk = CONFIG.resources.n_sec_chunk
        # Cat: TODO: Set a different size chunk for clustering vs. detection
        n_sec_chunk = 60

        # take chunks
        indexes = np.arange(0, fp_len, sampling_rate * n_sec_chunk)

        # add last bit of recording if it's shorter
        if indexes[-1] != fp_len:
            indexes = np.hstack((indexes, fp_len))

        idx_list = []
        for k in range(len(indexes) - 1):
            idx_list.append([
                indexes[k], indexes[k + 1], buffer_size,
                indexes[k + 1] - indexes[k] + buffer_size
            ])

        # NOTE: the slice below limits processing to the first 20 chunks
        idx_list = np.int64(np.vstack(idx_list))[:20]

        logger.info("# of chunks: %i", len(idx_list))
        logger.info(idx_list)

        # run tensorflow
        processing_ctr = 0

        # chunks to cycle over are 10x as much as initially chosen data
        total_processing = len(idx_list) * n_sec_chunk

        # keep tensorflow open; save iteratively
        fname_detection = os.path.join(CONFIG.path_to_output_directory,
                                       'detect')
        if not os.path.exists(fname_detection):
            os.mkdir(fname_detection)

        # set tensorflow verbosity level
        tf.logging.set_verbosity(tf.logging.ERROR)

        # open tensorflow session
        with tf.Session() as sess:
            NND.restore(sess)

            #triage = KerasModel(triage_fname,
            #                    allow_longer_waveform_length=True,
            #                    allow_more_channels=True)

            # read chunk of raw standardized data
            # Cat: TODO: don't save to lists, might want to use numpy
            # arrays directly

            # loop over 10-sec or 60-sec chunks
            for chunk_ctr, idx in enumerate(idx_list):
                if os.path.exists(
                        os.path.join(fname_detection,
                                     "detect_" + str(chunk_ctr).zfill(5)
                                     + '.npz')):
                    continue

                # reset lists
                spike_index_list = []
                #idx_clean_list = []
                energy_list = []
                TC_list = []
                offset_list = []

                # load chunk of data
                standardized_recording = binary_reader(
                    idx, buffer_size, path_to_standardized, n_channels,
                    CONFIG.data.root_folder)

                # run detection on smaller chunks of data, e.g. 1 sec
                # Cat: TODO: add last bit at end in case short
                indexes = np.arange(0, standardized_recording.shape[0],
                                    sampling_rate)

                # run tensorflow over 1-sec subchunks
                for ctr, index in enumerate(indexes[:-1]):

                    # save absolute index of each subchunk
                    offset_list.append(idx[0] + indexes[ctr])

                    data_temp = standardized_recording[
                        indexes[ctr]:indexes[ctr + 1]]

                    # store size of recordings in case at end of dataset
                    TC_list.append(data_temp.shape)

                    # run detect nn
                    res = NND.predict_recording(
                        data_temp, sess=sess,
                        output_names=('spike_index', 'waveform'))
                    spike_index, wfs = res

                    #idx_clean = (triage
                    #             .predict_with_threshold(x=wfs,
                    #                                     threshold=triage_th))

                    score = NNAE.predict(wfs, sess)
                    rot = NNAE.load_rotation(sess)
                    neighbors = n_steps_neigh_channels(
                        CONFIG.neigh_channels, 2)

                    # idx_clean is the indexes of clear spikes in all_spikes
                    spike_index_list.append(spike_index)
                    #idx_clean_list.append(idx_clean)

                    # run AE nn; required for remove_axon function
                    # Cat: TODO: Do we really need this: can we get energy
                    # list faster?
                    energy_ = np.ptp(np.matmul(score[:, :, 0], rot.T), axis=1)
                    energy_list.append(energy_)

                    logger.info('processed chunk: %s/%s, # spikes: %s',
                                str(processing_ctr), str(total_processing),
                                spike_index.shape)
                    processing_ctr += 1

                # save chunk of data in case crashes occur
                np.savez(os.path.join(fname_detection,
                                      "detect_" + str(chunk_ctr).zfill(5)),
                         spike_index_list=spike_index_list,
                         energy_list=energy_list,
                         TC_list=TC_list,
                         offset_list=offset_list)

        # load all saved data
        spike_index_list = []
        energy_list = []
        TC_list = []
        offset_list = []
        for ctr, idx in enumerate(idx_list):
            data = np.load(fname_detection + '/detect_' + str(ctr).zfill(5)
                           + '.npz')
            spike_index_list.extend(data['spike_index_list'])
            energy_list.extend(data['energy_list'])
            TC_list.extend(data['TC_list'])
            offset_list.extend(data['offset_list'])

        # save all detected spikes pre deduplication
        spike_index_all_pre_deduplication = fix_indexes_firstbatch_3(
            spike_index_list, offset_list, buffer_size, sampling_rate)
        np.save(os.path.join(TMP_FOLDER,
                             'spike_index_all_pre_deduplication.npy'),
                spike_index_all_pre_deduplication)

        # remove axons - compute axons to be killed
        logger.info(' removing axons in parallel')
        if CONFIG.resources.multi_processing:
            keep = parmap.map(deduplicate,
                              list(zip(spike_index_list, energy_list,
                                       TC_list,
                                       np.arange(len(energy_list)))),
                              neighbors,
                              processes=CONFIG.resources.n_processors,
                              pm_pbar=True)
        else:
            keep = []
            for k in range(len(energy_list)):
                keep.append(deduplicate((spike_index_list[k], energy_list[k],
                                         TC_list[k], k),
                                        neighbors))

        # Cat: TODO: Note that we're killing spike_index_all as well.
        # remove axons from clear spikes - keep only non-killed+clean events
        spike_index_all_postkill = []
        for k in range(len(spike_index_list)):
            spike_index_all_postkill.append(spike_index_list[k][keep[k][0]])

        logger.info(' fixing indexes from batching')
        spike_index_all = fix_indexes_firstbatch_3(spike_index_all_postkill,
                                                   offset_list, buffer_size,
                                                   sampling_rate)

        # get and clean all spikes
        spikes_all = spike_index_all

        #logger.info('Removing all indexes outside the allowed range to '
        #            'draw a complete waveform...')
        _n_observations = fp_len
        spikes_all, _ = detect.remove_incomplete_waveforms(
            spikes_all, CONFIG.spike_size + CONFIG.templates.max_shift,
            _n_observations)

        np.save(os.path.join(TMP_FOLDER, 'spike_index_all.npy'), spikes_all)
    else:
        spikes_all = np.load(os.path.join(TMP_FOLDER, 'spike_index_all.npy'))

    return spikes_all
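# Worked toy example (illustrative, not part of the pipeline) of the chunk
# index construction above: a 150,000-sample recording at 30 kHz with
# n_sec_chunk = 2 and buffer_size = 200 yields rows of
# [start, end, buffer, chunk_length + buffer].
import numpy as np

fp_len, sampling_rate, n_sec_chunk, buffer_size = 150000, 30000, 2, 200
indexes = np.arange(0, fp_len, sampling_rate * n_sec_chunk)
if indexes[-1] != fp_len:
    indexes = np.hstack((indexes, fp_len))
idx_list = np.int64([[indexes[k], indexes[k + 1], buffer_size,
                      indexes[k + 1] - indexes[k] + buffer_size]
                     for k in range(len(indexes) - 1)])
# idx_list ->
# [[     0  60000    200  60200]
#  [ 60000 120000    200  60200]
#  [120000 150000    200  30200]]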
def run_neural_network(standarized_path, standarized_params, whiten_filter,
                       output_directory, if_file_exists, save_results):
    """Run neural network detection and autoencoder dimensionality reduction

    Returns
    -------
    scores
        Scores for all spikes
    spike_index_clear
        Spike indexes for clear spikes
    spike_index_all
        Spike indexes for all spikes
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    folder = Path(CONFIG.data.root_folder, output_directory, 'detect')
    folder.mkdir(exist_ok=True)
    TMP_FOLDER = str(folder)

    # check if all scores, clear and collision spikes exist
    path_to_score = os.path.join(TMP_FOLDER, 'scores_clear.npy')
    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')
    path_to_spike_index_all = os.path.join(TMP_FOLDER, 'spike_index_all.npy')
    path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')

    paths = [path_to_score, path_to_spike_index_clear,
             path_to_spike_index_all]
    exists = [os.path.exists(p) for p in paths]

    if (if_file_exists == 'overwrite'
            or (if_file_exists == 'abort' and not any(exists))
            or (if_file_exists == 'skip' and not all(exists))):

        max_memory = (CONFIG.resources.max_memory_gpu if GPU_ENABLED
                      else CONFIG.resources.max_memory)

        # instantiate batch processor
        bp = BatchProcessor(standarized_path, standarized_params['dtype'],
                            standarized_params['n_channels'],
                            standarized_params['data_order'], max_memory,
                            buffer_size=CONFIG.spike_size)

        # load parameters
        detection_th = CONFIG.detect.neural_network_detector.threshold_spike
        triage_th = CONFIG.detect.neural_network_triage.threshold_collision
        detection_fname = CONFIG.detect.neural_network_detector.filename
        ae_fname = CONFIG.detect.neural_network_autoencoder.filename
        triage_fname = CONFIG.detect.neural_network_triage.filename

        # instantiate neural networks
        NND = NeuralNetDetector.load(detection_fname, detection_th,
                                     CONFIG.channel_index)
        NNT = NeuralNetTriage.load(triage_fname, triage_th,
                                   input_tensor=NND.waveform_tf)
        NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)
        rotation = NNAE.load_rotation()

        # gather all output tensors
        output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

        # run detection
        with tf.Session() as sess:
            # get values of above tensors
            NND.restore(sess)
            NNAE.restore(sess)
            NNT.restore(sess)

            mc = bp.multi_channel_apply
            res = mc(neuralnetwork.run_detect_triage_featurize,
                     mode='memory',
                     cleanup_function=neuralnetwork.fix_indexes,
                     sess=sess,
                     x_tf=NND.x_tf,
                     output_tf=output_tf,
                     rot=rotation,
                     neighbors=neighbors)

        # get clear spikes
        clear = np.concatenate([element[1] for element in res], axis=0)
        logger.info('Removing clear indexes outside the allowed range to '
                    'draw a complete waveform...')
        clear, idx = detect.remove_incomplete_waveforms(
            clear, CONFIG.spike_size + CONFIG.templates.max_shift,
            bp.reader._n_observations)

        # get all spikes
        spikes_all = np.concatenate([element[2] for element in res], axis=0)
        logger.info('Removing indexes outside the allowed range to '
                    'draw a complete waveform...')
        spikes_all, _ = detect.remove_incomplete_waveforms(
            spikes_all, CONFIG.spike_size + CONFIG.templates.max_shift,
            bp.reader._n_observations)

        # get scores
        scores = np.concatenate([element[0] for element in res], axis=0)
        logger.info('Removing scores for indexes outside the allowed range '
                    'to draw a complete waveform...')
        scores = scores[idx]

        # transform scores to location + shape feature space
        # TODO: move this to another place
        if CONFIG.cluster.method == 'location':
            threshold = 2
            scores = get_locations_features(scores, rotation, clear[:, 1],
                                            CONFIG.channel_index,
                                            CONFIG.geom, threshold)
            idx_nan = np.where(np.isnan(np.sum(scores, axis=(1, 2))))[0]
            scores = np.delete(scores, idx_nan, 0)
            clear = np.delete(clear, idx_nan, 0)

        # save partial results if required
        if save_results:
            # save clear spikes
            np.save(path_to_spike_index_clear, clear)
            logger.info('Saved spike index clear in {}...'.format(
                path_to_spike_index_clear))

            # save all spikes
            np.save(path_to_spike_index_all, spikes_all)
            logger.info('Saved spike index all in {}...'.format(
                path_to_spike_index_all))

            # save rotation
            np.save(path_to_rotation, rotation)
            logger.info('Saved rotation matrix in {}...'.format(
                path_to_rotation))

            # save scores
            np.save(path_to_score, scores)
            logger.info('Saved spike scores in {}...'.format(path_to_score))

    elif if_file_exists == 'abort' and any(exists):
        conflict = [p for p, e in zip(paths, exists) if e]
        message = reduce(lambda x, y: str(x) + ', ' + str(y), conflict)
        raise ValueError('if_file_exists was set to abort, the '
                         'program halted since the following files '
                         'already exist: {}'.format(message))
    elif if_file_exists == 'skip' and all(exists):
        logger.warning('Skipped execution. All output files exist, '
                       'loading them...')
        scores = np.load(path_to_score)
        clear = np.load(path_to_spike_index_clear)
        spikes_all = np.load(path_to_spike_index_all)
    else:
        raise ValueError('Invalid value for if_file_exists {}, must be one '
                         'of overwrite, abort or skip'
                         .format(if_file_exists))

    return scores, clear, spikes_all
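# Hypothetical invocation sketch (paths, dtype and channel count are
# placeholders, not from the source): run detection once with
# save_results=True, then re-run with if_file_exists='skip' to load the
# cached .npy files instead of recomputing.
scores, clear, spikes_all = run_neural_network(
    standarized_path='tmp/preprocess/standarized.bin',
    standarized_params={'dtype': 'float32', 'n_channels': 49,
                        'data_order': 'samples'},
    whiten_filter=None,
    output_directory='tmp/',
    if_file_exists='skip',
    save_results=True)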
def prepare_nn(channel_index, whiten_filter, threshold_detect,
               threshold_triage, detector_filename, autoencoder_filename,
               triage_filename):
    """Prepare neural net tensors in advance. This is to efficiently run
    the neural nets with the batch processor, as we don't have to recreate
    tf tensors in every batch

    Parameters
    ----------
    channel_index: np.array (n_channels, n_neigh)
        Each row indexes its neighboring channels. For example,
        channel_index[c] is the index of neighboring channels (including
        itself). If any value is equal to n_channels, it is nothing but a
        placeholder in case a channel has fewer than n_neigh neighboring
        channels
    whiten_filter: numpy.ndarray (n_channels, n_neigh, n_neigh)
        whitening matrix such that whiten_filter[c] is the whitening filter
        of channel c and its neighboring channels determined from
        channel_index
    threshold_detect: int
        threshold for neural net detection
    threshold_triage: int
        threshold for neural net triage
    detector_filename: str
        location of trained neural net detector
    autoencoder_filename: str
        location of trained neural net autoencoder
    triage_filename: str
        location of trained neural net triage

    Returns
    -------
    x_tf: tf.tensors (n_observations, n_channels)
        placeholder of recording for running tensorflow
    output_tf: tuple of tf.tensors
        a tuple of tensorflow tensors that produce score, spike_index_clear,
        and spike_index_collision
    NND: class
        an instance of class NeuralNetDetector
    NNAE: class
        an instance of class AutoEncoder
    NNT: class
        an instance of class NeuralNetTriage
    """
    # placeholder for input recording
    x_tf = tf.placeholder("float", [None, None])

    # load neural nets
    NND = NeuralNetDetector(detector_filename)
    NNAE = AutoEncoder(autoencoder_filename)
    NNT = NeuralNetTriage(triage_filename)

    # make spike_index tensorflow tensor
    spike_index_tf_all = NND.make_detection_tf_tensors(x_tf, channel_index,
                                                       threshold_detect)

    # remove edge spike times
    spike_index_tf = remove_edge_spikes(x_tf, spike_index_tf_all,
                                        NND.filters_dict['size'])

    # make waveform tensorflow tensor
    waveform_tf = make_waveform_tf_tensor(x_tf, spike_index_tf,
                                          channel_index,
                                          NND.filters_dict['size'])

    # make score tensorflow tensor from waveform
    score_tf = NNAE.make_score_tf_tensor(waveform_tf)

    # run neural net triage
    nneigh = NND.filters_dict['n_neighbors']
    idx_clean = NNT.triage_wf(waveform_tf[:, :, :nneigh], threshold_triage)

    # gather all output tensors
    output_tf = (score_tf, spike_index_tf, idx_clean)

    return x_tf, output_tf, NND, NNAE, NNT
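# Minimal usage sketch for prepare_nn (checkpoint names are hypothetical,
# and channel_index, whiten_filter and recording_chunk are assumed to be in
# scope; weight restoration is omitted since it depends on the
# NeuralNetDetector/AutoEncoder/NeuralNetTriage APIs):
x_tf, output_tf, NND, NNAE, NNT = prepare_nn(
    channel_index, whiten_filter,
    threshold_detect=0.5, threshold_triage=0.5,
    detector_filename='detect_nn.ckpt',
    autoencoder_filename='ae_nn.ckpt',
    triage_filename='triage_nn.ckpt')

with tf.Session() as sess:
    # ... restore trained weights into the session here ...
    # evaluate all output tensors on one (n_observations, n_channels)
    # standardized recording chunk
    score, spike_index_clear, idx_clean = sess.run(
        output_tf, feed_dict={x_tf: recording_chunk})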
def test_splitting_in_batches_does_not_affect(path_to_config,
                                              path_to_sample_pipeline_folder,
                                              make_tmp_folder,
                                              path_to_standardized_data):
    yass.set_config(path_to_config, make_tmp_folder)
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standardized_data

    with open(path.join(path_to_sample_pipeline_folder, 'preprocess',
                        'standardized.yaml')) as f:
        PARAMS = yaml.load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 channel_index)
    triage = KerasModel(triage_fname,
                        allow_longer_waveform_length=True,
                        allow_more_channels=True)
    NNAE = AutoEncoder.load(ae_fname, input_tensor=NND.waveform_tf)

    bp = BatchProcessor(PATH_TO_DATA, PARAMS['dtype'], PARAMS['n_channels'],
                        PARAMS['data_order'], '100KB',
                        buffer_size=CONFIG.spike_size)

    out = ('spike_index', 'waveform')
    fn = neuralnetwork.apply.fix_indexes_spike_index

    # detector
    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)

        res = bp.multi_channel_apply(NND.predict_recording,
                                     mode='memory',
                                     sess=sess,
                                     output_names=out,
                                     cleanup_function=fn)

    spike_index_new = np.concatenate([element[0] for element in res],
                                     axis=0)
    wfs = np.concatenate([element[1] for element in res], axis=0)

    idx_clean = triage.predict_with_threshold(wfs, triage_th)
    score = NNAE.predict(wfs)
    rot = NNAE.load_rotation()
    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

    (score_clear_new,
     spike_index_clear_new) = post_processing(score, spike_index_new,
                                              idx_clean, rot, neighbors)

    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)

        res = bp.multi_channel_apply(NND.predict_recording,
                                     mode='memory',
                                     sess=sess,
                                     output_names=('spike_index',
                                                   'waveform'),
                                     cleanup_function=fn)

    spike_index_batch, wfs = zip(*res)

    spike_index_batch = np.concatenate(spike_index_batch, axis=0)
    wfs = np.concatenate(wfs, axis=0)

    idx_clean = triage.predict_with_threshold(x=wfs, threshold=triage_th)
    score = NNAE.predict(wfs)
    rot = NNAE.load_rotation()
    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

    (score_clear_batch,
     spike_index_clear_batch) = post_processing(score, spike_index_batch,
                                                idx_clean, rot, neighbors)
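# A hedged sketch of the comparison this test's name implies (mirroring the
# assertions in the batching test below); these lines are not in the
# original:
#     np.testing.assert_array_equal(spike_index_clear_batch,
#                                   spike_index_clear_new)
#     np.testing.assert_array_equal(score_clear_batch, score_clear_new)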
def test_splitting_in_batches_does_not_affect(path_to_tests,
                                              path_to_standarized_data,
                                              path_to_sample_pipeline_folder):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standarized_data
    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    with open(path.join(path_to_sample_pipeline_folder, 'preprocess',
                        'standarized.yaml')) as f:
        PARAMS = yaml.load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 channel_index)
    NNT = NeuralNetTriage.load(triage_fname, triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    # run all at once
    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        (scores, clear,
         collision) = neuralnetwork.run_detect_triage_featurize(
            data, sess, NND.x_tf, output_tf, neighbors, rot)

    # run in batches - buffer size makes sure we can detect spikes if they
    # appear at the end of any batch
    bp = BatchProcessor(PATH_TO_DATA, PARAMS['dtype'], PARAMS['n_channels'],
                        PARAMS['data_order'], '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=NND.x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res],
                                     axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)