def train(CONFIG, CONFIG_TRAIN, spike_train, data_folder):
    """Train the detector, triage and autoencoder neural networks.

    Parameters
    ----------
    CONFIG
        YASS configuration file
    CONFIG_TRAIN
        YASS Neural Network configuration file
    spike_train: numpy.ndarray
        Spike train, first column is spike index and second is main channel
    data_folder: str
        Forwarded to ``make_training_data`` as its ``data_folder`` argument
    """
    logger = logging.getLogger(__name__)

    def checkpoint_path(section):
        # every network is saved in the working directory as <name>.ckpt
        return './' + CONFIG_TRAIN[section]['name'] + '.ckpt'

    template_cfg = CONFIG_TRAIN['templates']
    template_ids = template_cfg['ids']
    min_amplitude = template_cfg['minimum_amplitude']
    n_spikes_to_make = CONFIG_TRAIN['training']['n_spikes']

    detector_filters = CONFIG_TRAIN['network_detector']['n_filters']

    # optimisation hyper-parameters shared by all three networks
    hyper = CONFIG_TRAIN['training']
    n_iterations = hyper['n_iterations']
    batch_size = hyper['n_batch']
    l2_scale = hyper['l2_regularization_scale']
    step_size = hyper['step_size']
    detector_ckpt = checkpoint_path('network_detector')

    triage_filters = CONFIG_TRAIN['network_triage']['n_filters']
    triage_ckpt = checkpoint_path('network_triage')

    ae_features = CONFIG_TRAIN['network_autoencoder']['n_features']
    ae_ckpt = checkpoint_path('network_autoencoder')

    logger.info('Generating training data...')
    (x_detect, y_detect,
     x_triage, y_triage,
     x_ae, y_ae) = make_training_data(CONFIG, spike_train, template_ids,
                                      min_amplitude, n_spikes_to_make,
                                      data_folder=data_folder)

    NeuralNetDetector.train(x_detect, y_detect, detector_filters,
                            n_iterations, batch_size, l2_scale, step_size,
                            detector_ckpt)

    NeuralNetTriage.train(x_triage, y_triage, triage_filters,
                          n_iterations, batch_size, l2_scale, step_size,
                          triage_ckpt)

    # the autoencoder takes no l2 regularization scale
    AutoEncoder.train(x_ae, y_ae, ae_features, n_iterations, batch_size,
                      step_size, ae_ckpt)
def test_can_use_neural_network_detector(path_to_tests,
                                         path_to_standarized_data):
    """Detector, triage and autoencoder can run together on real data."""
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    recording = RecordingsReader(path_to_standarized_data,
                                 loader='array').data

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    nets_cfg = CONFIG.detect
    detector_cfg = nets_cfg.neural_network_detector
    triage_cfg = nets_cfg.neural_network_triage
    ae_cfg = nets_cfg.neural_network_autoencoder

    # triage and autoencoder are built on the detector's waveform tensor
    detector = NeuralNetDetector.load(detector_cfg.filename,
                                      detector_cfg.threshold_spike,
                                      channel_index)
    triage = NeuralNetTriage.load(triage_cfg.filename,
                                  triage_cfg.threshold_collision,
                                  input_tensor=detector.waveform_tf)
    autoencoder = AutoEncoder(ae_cfg.filename,
                              input_tensor=detector.waveform_tf)

    outputs = (autoencoder.score_tf, detector.spike_index_tf,
               triage.idx_clean)

    with tf.Session() as sess:
        for network in (detector, autoencoder, triage):
            network.restore(sess)

        rotation = autoencoder.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        neuralnetwork.run_detect_triage_featurize(recording, sess,
                                                  detector.x_tf, outputs,
                                                  neighbors, rotation)
def test_can_use_detector_and_triage_after_fit(path_to_tests,
                                               path_to_sample_pipeline_folder,
                                               tmp_folder,
                                               path_to_standarized_data):
    """Fit a detector and a triage network, then run both on real data.

    NOTE(review): ``spike_train``, ``chosen_templates``, ``min_amplitude``,
    ``n_spikes`` and ``filters`` are not defined in this function; they are
    presumably module-level fixtures of this test file — confirm.
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    (x_detect, y_detect,
     x_triage, y_triage,
     x_ae, y_ae) = make_training_data(CONFIG, spike_train, chosen_templates,
                                      min_amplitude, n_spikes,
                                      path_to_sample_pipeline_folder)

    _, waveform_length, n_neighbors = x_detect.shape

    # each network gets its own checkpoint; sharing one path would let the
    # triage model presumably overwrite the detector's saved weights
    path_to_detector = path.join(tmp_folder, 'detect-net.ckpt')
    path_to_triage = path.join(tmp_folder, 'triage-net.ckpt')

    detector = NeuralNetDetector(path_to_detector, filters,
                                 waveform_length, n_neighbors,
                                 threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=10)
    detector.fit(x_detect, y_detect)

    # NOTE(review): triage is fit on the detection training set; x_triage /
    # y_triage may be the intended inputs — confirm
    triage = NeuralNetTriage(path_to_triage, filters,
                             waveform_length, n_neighbors,
                             threshold=0.5,
                             n_iter=10)
    triage.fit(x_detect, y_detect)

    data = RecordingExplorer(path_to_standarized_data).reader.data

    output_names = ('spike_index', 'waveform', 'probability')

    (spike_index, waveform,
     proba) = detector.predict(data, output_names=output_names)

    triage.predict(waveform[:, :, :n_neighbors])
def get_o_layer(standarized_path, standarized_params,
                output_directory='tmp/', output_dtype='float32',
                output_filename='o_layer.bin', if_file_exists='skip',
                save_partial_results=False):
    """Get the output layer of the NN detector instead of the spike index.

    Parameters
    ----------
    standarized_path: str
        Path to the standarized recording
    standarized_params: dict
        Recording parameters; ``dtype``, ``n_channels`` and ``data_order``
        are read (the legacy key ``data_format`` is accepted in place of
        ``data_order``)
    output_directory: str
        Folder, relative to ``CONFIG.data.root_folder``, where the output
        is written; created if it does not exist
    output_dtype: str
        dtype the output is cast to
    output_filename: str
        Name of the output binary file inside ``output_directory``
    if_file_exists: str
        Currently unused
    save_partial_results: bool
        Currently unused

    Returns
    -------
    o_path, o_params
        The pair returned by ``BatchProcessor.multi_channel_apply`` in
        ``mode='disk'`` (presumably output path and its parameters)
    """
    CONFIG = read_config()

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom, 1)

    # placeholder for a (n_observations, n_channels) batch of recordings
    x_tf = tf.placeholder("float", [None, None])

    # load Neural Net's
    detection_fname = CONFIG.detect.neural_network_detector.filename
    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    NND = NeuralNetDetector(detection_fname)
    o_layer_tf = NND.make_o_layer_tf_tensors(x_tf,
                                             channel_index,
                                             detection_th)

    # the other BatchProcessor callers use the 'data_order' key; fall back
    # to 'data_format' so older params dicts keep working
    data_order = standarized_params.get('data_order')
    if data_order is None:
        data_order = standarized_params['data_format']

    bp = BatchProcessor(standarized_path,
                        standarized_params['dtype'],
                        standarized_params['n_channels'],
                        data_order,
                        CONFIG.resources.max_memory,
                        buffer_size=CONFIG.spike_size)

    TMP = os.path.join(CONFIG.data.root_folder, output_directory)
    # make sure the destination folder exists before writing to disk
    os.makedirs(TMP, exist_ok=True)
    _output_path = os.path.join(TMP, output_filename)

    (o_path,
     o_params) = bp.multi_channel_apply(_get_o_layer,
                                        mode='disk',
                                        cleanup_function=fix_indexes,
                                        output_path=_output_path,
                                        cast_dtype=output_dtype,
                                        x_tf=x_tf,
                                        o_layer_tf=o_layer_tf,
                                        NND=NND)

    return o_path, o_params
def test_can_reload_detector(path_to_tests, path_to_sample_pipeline_folder,
                             tmp_folder):
    """A fitted detector checkpoint can be reloaded with ``load``."""
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    (x_detect, y_detect,
     _, _,
     _, _) = make_training_data(CONFIG, spike_train, chosen_templates,
                                min_amplitude, n_spikes,
                                path_to_sample_pipeline_folder)

    _, n_samples, n_neighbor_channels = x_detect.shape

    checkpoint = path.join(tmp_folder, 'detect-net.ckpt')

    trained = NeuralNetDetector(checkpoint, filters,
                                n_samples, n_neighbor_channels,
                                threshold=0.5,
                                channel_index=CONFIG.channel_index,
                                n_iter=10)
    trained.fit(x_detect, y_detect)

    # reload from the checkpoint written by fit
    NeuralNetDetector.load(checkpoint, threshold=0.5,
                           channel_index=CONFIG.channel_index)
def load_rotation(detector_filename, autoencoder_filename):
    """Load the neural network rotation matrix.

    Restores the autoencoder weights in a fresh tensorflow session and
    returns the evaluated ``W_ae`` tensor.
    """
    # FIXME: this function should not ask for detector_filename, it is not
    # needed
    detector = NeuralNetDetector(detector_filename, autoencoder_filename)

    with tf.Session() as session:
        detector.saver_ae.restore(session, detector.path_to_ae_model)
        return session.run(detector.W_ae)
def test_can_train_detector(path_to_config, path_to_sample_pipeline_folder,
                            make_tmp_folder):
    """The detector can be fit on data generated from the sample pipeline."""
    yass.set_config(path_to_config, make_tmp_folder)
    CONFIG = yass.read_config()

    spike_train = np.load(path.join(path_to_sample_pipeline_folder,
                                    'spike_train.npy'))
    template_ids = np.unique(spike_train[:, 1])

    templates = make.load_templates(path_to_sample_pipeline_folder,
                                    spike_train, CONFIG, template_ids)

    standardized = path.join(path_to_sample_pipeline_folder,
                             'preprocess', 'standarized.bin')

    # positional args: min amplitude 4, max amplitude 60, 100 spikes
    (x_detect, y_detect,
     _, _,
     _, _) = make.training_data(CONFIG, templates, 4, 60, 100, standardized)

    _, n_samples, n_neighbor_channels = x_detect.shape

    checkpoint = path.join(make_tmp_folder, 'detect-net.ckpt')

    detector = NeuralNetDetector(checkpoint, [8, 4],
                                 n_samples, n_neighbor_channels,
                                 threshold=0.5,
                                 channel_index=CONFIG.channel_index,
                                 n_iter=10)
    detector.fit(x_detect, y_detect)
def _detection_chunk_path(folder, chunk_ctr):
    """Return the .npz file used to checkpoint detection chunk ``chunk_ctr``."""
    return os.path.join(folder, 'detect_' + str(chunk_ctr).zfill(5) + '.npz')


def _as_object_array(items):
    """Pack a list of (possibly ragged) arrays into a 1D object array.

    Recent NumPy versions refuse to turn ragged lists into arrays
    implicitly, so each element is stored in its own slot explicitly.
    """
    packed = np.empty(len(items), dtype=object)
    for i, item in enumerate(items):
        packed[i] = item
    return packed


def run_neural_network(standardized_path, standardized_params,
                       whiten_filter, output_directory, if_file_exists,
                       save_results):
    """Run neural network detection chunk by chunk and deduplicate spikes.

    The recording is processed in 60 second chunks, each split into 1
    second sub-chunks for the detector. Per-chunk results are checkpointed
    to ``<CONFIG.path_to_output_directory>/detect/detect_XXXXX.npz``;
    chunks whose file already exists are skipped, so an interrupted run
    resumes where it stopped.

    Parameters
    ----------
    standardized_path, standardized_params, whiten_filter, output_directory,
    save_results
        Currently unused: the standardized data are read from
        ``<CONFIG.path_to_output_directory>/preprocess/standarized.bin``
    if_file_exists: str
        If 'overwrite', detection always runs. Otherwise it runs only when
        none of scores_clear.npy, spike_index_clear.npy and
        spike_index_all.npy exist, and spike_index_all.npy is loaded instead

    Returns
    -------
    spikes_all: numpy.ndarray
        Spike indexes kept after deduplication, with spikes too close to
        the recording edges to draw a complete waveform removed
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    TMP_FOLDER = CONFIG.path_to_output_directory

    # check if all scores, clear and collision spikes exist..
    path_to_score = os.path.join(TMP_FOLDER, 'scores_clear.npy')
    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')
    path_to_spike_index_all = os.path.join(TMP_FOLDER, 'spike_index_all.npy')
    path_to_standardized = os.path.join(TMP_FOLDER, 'preprocess',
                                        'standarized.bin')

    paths = [path_to_score, path_to_spike_index_clear,
             path_to_spike_index_all]
    exists = [os.path.exists(p) for p in paths]

    if if_file_exists != 'overwrite' and any(exists):
        return np.load(path_to_spike_index_all)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    n_channels = CONFIG.recordings.n_channels

    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 CONFIG.channel_index)
    NNAE = AutoEncoder.load(ae_fname, input_tensor=NND.waveform_tf)

    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

    # number of samples per channel in the raw recording; integer division
    # keeps every derived chunk index integral
    filename_dat = os.path.join(CONFIG.data.root_folder,
                                CONFIG.data.recordings)
    fp = np.memmap(filename_dat, dtype='int16', mode='r')
    fp_len = fp.shape[0] // n_channels
    del fp

    buffer_size = 200  # Cat: to set this in CONFIG file
    sampling_rate = CONFIG.recordings.sampling_rate
    # Cat: TODO: Set a different size chunk for clustering vs. detection
    n_sec_chunk = 60

    # chunk boundaries; add last bit of recording if it's shorter
    indexes = np.arange(0, fp_len, sampling_rate * n_sec_chunk)
    if indexes[-1] != fp_len:
        indexes = np.hstack((indexes, fp_len))

    # each row: [start, end, buffer_size, end - start + buffer_size]
    idx_list = np.int64(np.vstack([
        [start, end, buffer_size, end - start + buffer_size]
        for start, end in zip(indexes[:-1], indexes[1:])]))

    logger.info("# of chunks: %i", len(idx_list))
    logger.info(idx_list)

    processing_ctr = 0
    # number of 1 second sub-chunks to process, used for progress logging
    total_processing = len(idx_list) * n_sec_chunk

    fname_detection = os.path.join(CONFIG.path_to_output_directory, 'detect')
    if not os.path.exists(fname_detection):
        os.mkdir(fname_detection)

    tf.logging.set_verbosity(tf.logging.ERROR)

    with tf.Session() as sess:
        NND.restore(sess)

        # the rotation matrix does not change between sub-chunks: load it
        # once, right after the first autoencoder prediction
        rot = None

        for chunk_ctr, idx in enumerate(idx_list):
            chunk_path = _detection_chunk_path(fname_detection, chunk_ctr)
            if os.path.exists(chunk_path):
                continue

            spike_index_list = []
            energy_list = []
            TC_list = []
            offset_list = []

            standardized_recording = binary_reader(idx, buffer_size,
                                                   path_to_standardized,
                                                   n_channels,
                                                   CONFIG.data.root_folder)

            # run detection on 1 second sub-chunks
            # Cat: TODO: add last bit at end in case short
            sub_indexes = np.arange(0, standardized_recording.shape[0],
                                    sampling_rate)

            for start, end in zip(sub_indexes[:-1], sub_indexes[1:]):
                # absolute index of each sub-chunk
                offset_list.append(idx[0] + start)

                data_temp = standardized_recording[start:end]

                # store size of recordings in case at end of dataset.
                TC_list.append(data_temp.shape)

                spike_index, wfs = NND.predict_recording(
                    data_temp, sess=sess,
                    output_names=('spike_index', 'waveform'))

                score = NNAE.predict(wfs, sess)
                if rot is None:
                    rot = NNAE.load_rotation(sess)

                spike_index_list.append(spike_index)

                # energy per spike, used by deduplicate to remove axons
                # Cat: TODO: Do we really need this: can we get energy
                # list faster?
                energy_list.append(
                    np.ptp(np.matmul(score[:, :, 0], rot.T), axis=1))

                logger.info('processed chunk: %s/%s, # spikes: %s',
                            str(processing_ctr), str(total_processing),
                            spike_index.shape)

                processing_ctr += 1

            # save chunk of data in case crashes occur; the per-sub-chunk
            # lists are ragged, so they are stored as object arrays
            np.savez(chunk_path,
                     spike_index_list=_as_object_array(spike_index_list),
                     energy_list=_as_object_array(energy_list),
                     TC_list=TC_list,
                     offset_list=offset_list)

    # load all saved data
    spike_index_list = []
    energy_list = []
    TC_list = []
    offset_list = []
    for chunk_ctr in range(len(idx_list)):
        # allow_pickle is needed for the object arrays; these files are
        # written by this function, not taken from external input
        with np.load(_detection_chunk_path(fname_detection, chunk_ctr),
                     allow_pickle=True) as data:
            spike_index_list.extend(data['spike_index_list'])
            energy_list.extend(data['energy_list'])
            TC_list.extend(data['TC_list'])
            offset_list.extend(data['offset_list'])

    # save all detected spikes pre axon_kill
    spike_index_all_pre_deduplication = fix_indexes_firstbatch_3(
        spike_index_list, offset_list, buffer_size, sampling_rate)
    np.save(os.path.join(TMP_FOLDER,
                         'spike_index_all_pre_deduplication.npy'),
            spike_index_all_pre_deduplication)

    # remove axons - compute axons to be killed
    logger.info(' removing axons in parallel')
    if CONFIG.resources.multi_processing:
        keep = parmap.map(deduplicate,
                          list(zip(spike_index_list, energy_list, TC_list,
                                   np.arange(len(energy_list)))),
                          neighbors,
                          processes=CONFIG.resources.n_processors,
                          pm_pbar=True)
    else:
        keep = [deduplicate((spikes, energy, tc, k), neighbors)
                for k, (spikes, energy, tc)
                in enumerate(zip(spike_index_list, energy_list, TC_list))]

    # Cat: TODO Note that we're killing spike_index_all as well.
    # keep only non-killed events
    spike_index_all_postkill = [spikes[kept[0]] for spikes, kept
                                in zip(spike_index_list, keep)]

    logger.info(' fixing indexes from batching')
    spikes_all = fix_indexes_firstbatch_3(spike_index_all_postkill,
                                          offset_list, buffer_size,
                                          sampling_rate)

    # drop spikes too close to the edges to draw a complete waveform
    spikes_all, _ = detect.remove_incomplete_waveforms(
        spikes_all, CONFIG.spike_size + CONFIG.templates.max_shift, fp_len)

    np.save(path_to_spike_index_all, spikes_all)

    return spikes_all
def run_neural_network(standarized_path, standarized_params,
                       whiten_filter, output_directory, if_file_exists,
                       save_results):
    """Run neural network detection, triage and autoencoder featurization

    Parameters
    ----------
    standarized_path: str
        Path to the standarized recording
    standarized_params: dict
        Recording parameters; ``dtype``, ``n_channels`` and ``data_order``
        are read
    whiten_filter
        Currently unused
    output_directory: str
        Folder (relative to ``CONFIG.data.root_folder``) whose ``detect``
        subfolder holds the results; created, with parents, if missing
    if_file_exists: str
        One of 'overwrite', 'abort' or 'skip'
    save_results: bool
        Whether to save clear/all spike indexes, rotation and scores

    Returns
    -------
    scores
      Scores for clear spikes
    spike_index_clear
      Spike indexes for clear spikes
    spike_index_all
      Spike indexes for all spikes

    Raises
    ------
    ValueError
        If if_file_exists is 'abort' and some output files already exist,
        or if if_file_exists is not one of the accepted values
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    folder = Path(CONFIG.data.root_folder, output_directory, 'detect')
    # parents=True: output_directory itself may not exist yet
    folder.mkdir(parents=True, exist_ok=True)
    TMP_FOLDER = str(folder)

    # check if all scores, clear and collision spikes exist..
    path_to_score = os.path.join(TMP_FOLDER, 'scores_clear.npy')
    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')
    path_to_spike_index_all = os.path.join(TMP_FOLDER, 'spike_index_all.npy')
    path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')

    paths = [path_to_score, path_to_spike_index_clear,
             path_to_spike_index_all]
    exists = [os.path.exists(p) for p in paths]

    if (if_file_exists == 'overwrite' or
            if_file_exists == 'abort' and not any(exists) or
            if_file_exists == 'skip' and not all(exists)):
        max_memory = (CONFIG.resources.max_memory_gpu if GPU_ENABLED else
                      CONFIG.resources.max_memory)

        # instantiate batch processor
        bp = BatchProcessor(standarized_path,
                            standarized_params['dtype'],
                            standarized_params['n_channels'],
                            standarized_params['data_order'],
                            max_memory,
                            buffer_size=CONFIG.spike_size)

        # load parameters
        detection_th = CONFIG.detect.neural_network_detector.threshold_spike
        triage_th = CONFIG.detect.neural_network_triage.threshold_collision
        detection_fname = CONFIG.detect.neural_network_detector.filename
        ae_fname = CONFIG.detect.neural_network_autoencoder.filename
        triage_fname = CONFIG.detect.neural_network_triage.filename

        # instantiate neural networks; triage and autoencoder read the
        # detector's waveform tensor
        NND = NeuralNetDetector.load(detection_fname, detection_th,
                                     CONFIG.channel_index)
        NNT = NeuralNetTriage.load(triage_fname, triage_th,
                                   input_tensor=NND.waveform_tf)
        NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        rotation = NNAE.load_rotation()

        # gather all output tensors
        output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

        # run detection
        with tf.Session() as sess:
            NND.restore(sess)
            NNAE.restore(sess)
            NNT.restore(sess)

            mc = bp.multi_channel_apply
            res = mc(
                neuralnetwork.run_detect_triage_featurize,
                mode='memory',
                cleanup_function=neuralnetwork.fix_indexes,
                sess=sess,
                x_tf=NND.x_tf,
                output_tf=output_tf,
                rot=rotation,
                neighbors=neighbors)

        # get clear spikes
        clear = np.concatenate([element[1] for element in res], axis=0)
        logger.info('Removing clear indexes outside the allowed range to '
                    'draw a complete waveform...')
        clear, idx = detect.remove_incomplete_waveforms(
            clear, CONFIG.spike_size + CONFIG.templates.max_shift,
            bp.reader._n_observations)

        # get all spikes
        spikes_all = np.concatenate([element[2] for element in res], axis=0)
        logger.info('Removing indexes outside the allowed range to '
                    'draw a complete waveform...')
        spikes_all, _ = detect.remove_incomplete_waveforms(
            spikes_all, CONFIG.spike_size + CONFIG.templates.max_shift,
            bp.reader._n_observations)

        # get scores; keep only those of the clear spikes retained above
        scores = np.concatenate([element[0] for element in res], axis=0)
        logger.info('Removing scores for indexes outside the allowed range to '
                    'draw a complete waveform...')
        scores = scores[idx]

        # transform scores to location + shape feature space
        # TODO: move this to another place
        if CONFIG.cluster.method == 'location':
            threshold = 2
            scores = get_locations_features(scores, rotation, clear[:, 1],
                                            CONFIG.channel_index,
                                            CONFIG.geom, threshold)
            # drop spikes whose location features came out as NaN
            idx_nan = np.where(np.isnan(np.sum(scores, axis=(1, 2))))[0]
            scores = np.delete(scores, idx_nan, 0)
            clear = np.delete(clear, idx_nan, 0)

        # save partial results if required
        if save_results:
            # save clear spikes
            np.save(path_to_spike_index_clear, clear)
            logger.info('Saved spike index clear in {}...'.format(
                path_to_spike_index_clear))

            # save all spikes
            np.save(path_to_spike_index_all, spikes_all)
            logger.info('Saved spike index all in {}...'.format(
                path_to_spike_index_all))

            # save rotation
            np.save(path_to_rotation, rotation)
            logger.info(
                'Saved rotation matrix in {}...'.format(path_to_rotation))

            # saves scores
            np.save(path_to_score, scores)
            logger.info('Saved spike scores in {}...'.format(path_to_score))

    elif if_file_exists == 'abort' and any(exists):
        conflict = [p for p, e in zip(paths, exists) if e]
        message = ', '.join(str(p) for p in conflict)
        raise ValueError('if_file_exists was set to abort, the '
                         'program halted since the following files '
                         'already exist: {}'.format(message))

    elif if_file_exists == 'skip' and all(exists):
        logger.warning('Skipped execution. All output files exist'
                       ', loading them...')
        scores = np.load(path_to_score)
        clear = np.load(path_to_spike_index_clear)
        spikes_all = np.load(path_to_spike_index_all)

    else:
        raise ValueError(
            'Invalid value for if_file_exists {}, '
            'must be one of overwrite, abort or skip'.format(if_file_exists))

    return scores, clear, spikes_all
def prepare_nn(channel_index, whiten_filter,
               threshold_detect, threshold_triage,
               detector_filename,
               autoencoder_filename,
               triage_filename):
    """Prepare neural net tensors in advance.

    This lets the neural nets run efficiently with the batch processor,
    since the tf tensors do not have to be recreated for every batch.

    Parameters
    ----------
    channel_index: np.array (n_channels, n_neigh)
        Each row indexes its neighboring channels.
        For example, channel_index[c] is the index of
        neighboring channels (including itself).
        If any value is equal to n_channels, it is nothing but
        a space holder in a case that a channel has less than
        n_neigh neighboring channels
    whiten_filter: numpy.ndarray (n_channels, n_neigh, n_neigh)
        whitening matrix such that whiten_filter[c] is the whitening
        filter of channel c and its neighboring channel determined from
        channel_index. NOTE: not used by this function.
    threshold_detect: int
        threshold for neural net detection
    threshold_triage: int
        threshold for neural net triage
    detector_filename: str
        location of trained neural net detector
    autoencoder_filename: str
        location of trained neural net autoencoder
    triage_filename: str
        location of trained neural net triage

    Returns
    -------
    x_tf: tf.tensors (n_observations, n_channels)
        placeholder of recording for running tensorflow
    output_tf: tuple of tf.tensors
        (score_tf, spike_index_tf, idx_clean): autoencoder scores of the
        detected waveforms, detected spike indexes with edge spikes removed,
        and the triage output for those spikes
    NND: NeuralNetDetector
        the detector instance
    NNAE: AutoEncoder
        the autoencoder instance
    NNT: NeuralNetTriage
        the triage instance
    """
    # placeholder for input recording
    x_tf = tf.placeholder("float", [None, None])

    # load Neural Net's
    NND = NeuralNetDetector(detector_filename)
    NNAE = AutoEncoder(autoencoder_filename)
    NNT = NeuralNetTriage(triage_filename)

    # make spike_index tensorflow tensor
    spike_index_tf_all = NND.make_detection_tf_tensors(x_tf,
                                                       channel_index,
                                                       threshold_detect)

    # remove edge spike time
    spike_index_tf = remove_edge_spikes(x_tf, spike_index_tf_all,
                                        NND.filters_dict['size'])

    # make waveform tensorflow tensor
    waveform_tf = make_waveform_tf_tensor(x_tf, spike_index_tf,
                                          channel_index,
                                          NND.filters_dict['size'])

    # make score tensorflow tensor from waveform
    score_tf = NNAE.make_score_tf_tensor(waveform_tf)

    # run neural net triage on the first n_neighbors channels only
    nneigh = NND.filters_dict['n_neighbors']
    idx_clean = NNT.triage_wf(waveform_tf[:, :, :nneigh], threshold_triage)

    # gather all output tensors
    output_tf = (score_tf, spike_index_tf, idx_clean)

    return x_tf, output_tf, NND, NNAE, NNT
def nn_detection(recordings, neighbors, geom, temporal_features,
                 temporal_window, th_detect, th_triage, detector_filename,
                 autoencoder_filename, triage_filename):
    """Detect spikes using a neural network

    Parameters
    ----------
    recordings: numpy.ndarray (n_observations, n_channels)
        Neural recordings
    neighbors: numpy.ndarray (n_channels, n_channels)
        Channels neighbors matrix
    geom: numpy.ndarray (n_channels, 2)
        Cartesian coordinates for the channels
    temporal_features: int
        ?
    temporal_window: int
        ?
    th_detect: float?
        Spike threshold [improve this explanation]
    th_triage: float?
        Triage threshold [improve this explanation]
    detector_filename: str
        Path to neural network detector
    autoencoder_filename: str
        Path to neural network autoencoder
    triage_filename: str
        Path to triage neural network

    Returns
    -------
    clear_scores: numpy.ndarray (n_spikes, n_features, n_channels)
        3D array with the scores for the clear spikes, first dimension is
        the number of spikes, second is the number of features and third
        the number of channels
    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        2D array with indexes for clear spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)
    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        2D array with indexes for collided spikes, first column contains
        the spike location in the recording and the second the main
        channel (channel whose amplitude is maximum)

    Raises
    ------
    ValueError
        If neighbors is not square, or its size does not match the number
        of channels in recordings
    """
    T, C = recordings.shape
    a, b = neighbors.shape

    # validate inputs before loading the networks and building the graph
    if a != b:
        raise ValueError('neighbors is not a square matrix, verify')

    if a != C:
        raise ValueError(
            'Number of channels in recording are {} but the '
            'neighbors matrix has {} elements, they must match'.format(C, a))

    nnd = NeuralNetDetector(detector_filename, autoencoder_filename)
    nnt = NeuralNetTriage(triage_filename)

    # neighboring channel info; rows with fewer neighbors are padded with C
    nneigh = np.max(np.sum(neighbors, 0))
    c_idx = np.ones((C, nneigh), 'int32') * C
    for c in range(C):
        ch_idx, _ = order_channels_by_distance(c, np.where(neighbors[c])[0],
                                               geom)
        c_idx[c, :ch_idx.shape[0]] = ch_idx

    # input
    x_tf = tf.placeholder("float", [T, C])

    # detect spike index
    local_max_idx_tf = nnd.get_spikes(x_tf, T, nneigh, c_idx,
                                      temporal_window, th_detect)

    # get score train
    score_train_tf = nnd.get_score_train(x_tf)

    # get energy for detected index
    energy_tf = tf.reduce_sum(tf.square(score_train_tf), axis=2)
    energy_val_tf = tf.gather_nd(energy_tf, local_max_idx_tf)

    # get triage probability
    triage_prob_tf = nnt.triage_prob(x_tf, T, nneigh, c_idx)

    # gather all results above
    result = (local_max_idx_tf, score_train_tf, energy_val_tf,
              triage_prob_tf)

    # remove duplicates
    energy_train_tf = tf.placeholder("float", [T, C])
    spike_index_tf = remove_duplicate_spikes_by_energy(energy_train_tf, T,
                                                       c_idx,
                                                       temporal_window)

    # get score
    score_train_placeholder = tf.placeholder("float",
                                             [T, C, temporal_features])
    spike_index_clear_tf = tf.placeholder("int64", [None, 2])
    score_tf = get_score(score_train_placeholder, spike_index_clear_tf, T,
                         temporal_features, c_idx)

    ###############################
    # get values of above tensors #
    ###############################

    with tf.Session() as sess:
        nnd.saver.restore(sess, nnd.path_to_detector_model)
        nnd.saver_ae.restore(sess, nnd.path_to_ae_model)
        nnt.saver.restore(sess, nnt.path_to_triage_model)

        local_max_idx, score_train, energy_val, triage_prob = sess.run(
            result, feed_dict={x_tf: recordings})

        # dense (T, C) energy map with values only at local maxima
        energy_train = np.zeros((T, C))
        energy_train[local_max_idx[:, 0], local_max_idx[:, 1]] = energy_val
        spike_index = sess.run(spike_index_tf,
                               feed_dict={energy_train_tf: energy_train})

        # spikes above the triage threshold are clear, the rest collisions
        idx_clean = triage_prob[spike_index[:, 0],
                                spike_index[:, 1]] > th_triage

        spike_index_clear = spike_index[idx_clean]
        spike_index_collision = spike_index[~idx_clean]

        score = sess.run(score_tf, feed_dict={
            score_train_placeholder: score_train,
            spike_index_clear_tf: spike_index_clear
        })

    return score, spike_index_clear, spike_index_collision
def test_splitting_in_batches_does_not_affect(path_to_config,
                                              path_to_sample_pipeline_folder,
                                              make_tmp_folder,
                                              path_to_standardized_data):
    """Detection on the whole recording matches batched detection.

    The reference run processes the full recording at once; the second run
    goes through the BatchProcessor. Clear spike indexes and their scores
    must agree.
    """
    yass.set_config(path_to_config, make_tmp_folder)
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standardized_data
    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    with open(path.join(path_to_sample_pipeline_folder, 'preprocess',
                        'standardized.yaml')) as f:
        # yaml.load without an explicit Loader is unsafe and rejected by
        # PyYAML >= 6; safe_load is enough for a plain params mapping
        PARAMS = yaml.safe_load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 channel_index)
    triage = KerasModel(triage_fname,
                        allow_longer_waveform_length=True,
                        allow_more_channels=True)
    NNAE = AutoEncoder.load(ae_fname, input_tensor=NND.waveform_tf)

    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)
    out = ('spike_index', 'waveform')
    fn = neuralnetwork.apply.fix_indexes_spike_index

    def triage_and_featurize(spike_index, wfs):
        """Triage + autoencoder on detected waveforms, then post-process."""
        idx_clean = triage.predict_with_threshold(x=wfs,
                                                  threshold=triage_th)
        score = NNAE.predict(wfs)
        rot = NNAE.load_rotation()
        return post_processing(score, spike_index, idx_clean, rot,
                               neighbors)

    # reference: run the detector on the whole recording at once
    with tf.Session() as sess:
        NND.restore(sess)
        spike_index_full, wfs_full = NND.predict_recording(
            data, sess=sess, output_names=out)

    (score_clear_full,
     spike_index_clear_full) = triage_and_featurize(spike_index_full,
                                                    wfs_full)

    # run in batches - buffer size makes sure we can detect spikes if they
    # appear at the end of any batch
    bp = BatchProcessor(PATH_TO_DATA, PARAMS['dtype'], PARAMS['n_channels'],
                        PARAMS['data_order'], '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        NND.restore(sess)
        res = bp.multi_channel_apply(NND.predict_recording,
                                     mode='memory',
                                     sess=sess,
                                     output_names=out,
                                     cleanup_function=fn)

    spike_index_batch, wfs_batch = zip(*res)
    spike_index_batch = np.concatenate(spike_index_batch, axis=0)
    wfs_batch = np.concatenate(wfs_batch, axis=0)

    (score_clear_batch,
     spike_index_clear_batch) = triage_and_featurize(spike_index_batch,
                                                     wfs_batch)

    np.testing.assert_array_equal(spike_index_clear_batch,
                                  spike_index_clear_full)
    np.testing.assert_allclose(score_clear_batch, score_clear_full)
def test_splitting_in_batches_does_not_affect(path_to_tests,
                                              path_to_standarized_data,
                                              path_to_sample_pipeline_folder):
    """Batched detection/triage/featurization equals the single-pass run."""
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standarized_data

    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    with open(
            path.join(path_to_sample_pipeline_folder, 'preprocess',
                      'standarized.yaml')) as f:
        # yaml.load without an explicit Loader is unsafe and rejected by
        # PyYAML >= 6; safe_load is enough for a plain params mapping
        PARAMS = yaml.safe_load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 channel_index)
    NNT = NeuralNetTriage.load(triage_fname, triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    # run all at once
    with tf.Session() as sess:
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        (scores, clear,
         collision) = neuralnetwork.run_detect_triage_featurize(
            data, sess, NND.x_tf, output_tf, neighbors, rot)

    # run in batches - buffer size makes sure we can detect spikes if they
    # appear at the end of any batch
    bp = BatchProcessor(PATH_TO_DATA, PARAMS['dtype'], PARAMS['n_channels'],
                        PARAMS['data_order'], '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=NND.x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res], axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)