def get_o_layer(standarized_path, standarized_params, output_directory='tmp/',
                output_dtype='float32', output_filename='o_layer.bin',
                if_file_exists='skip', save_partial_results=False):
    """Get the output of NN detector instead of outputting the spike index

    Parameters
    ----------
    standarized_path: str
        Path to the standarized recordings
    standarized_params: dict
        Standarized recordings parameters, the 'dtype', 'n_channels' and
        'data_format' keys are read
    output_directory: str, optional
        Folder for the output, relative to CONFIG.data.root_folder
    output_dtype: str, optional
        dtype the output layer is cast to before being written
    output_filename: str, optional
        Output file name inside output_directory
    if_file_exists: str, optional
        What to do if the output file already exists, forwarded to
        BatchProcessor.multi_channel_apply (e.g. 'overwrite', 'abort',
        'skip')
    save_partial_results: bool, optional
        Currently unused, kept for backwards compatibility

    Returns
    -------
    o_path: str
        Location of the output layer file, as returned by
        multi_channel_apply in 'disk' mode
    o_params: dict
        Parameters of the output file, as returned by multi_channel_apply
    """
    CONFIG = read_config()

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom, 1)

    x_tf = tf.placeholder("float", [None, None])

    # load Neural Net's
    detection_fname = CONFIG.detect.neural_network_detector.filename
    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    NND = NeuralNetDetector(detection_fname)
    o_layer_tf = NND.make_o_layer_tf_tensors(x_tf, channel_index,
                                             detection_th)

    # buffer of spike_size so windows at batch edges are complete
    bp = BatchProcessor(standarized_path, standarized_params['dtype'],
                        standarized_params['n_channels'],
                        standarized_params['data_format'],
                        CONFIG.resources.max_memory,
                        buffer_size=CONFIG.spike_size)

    TMP = os.path.join(CONFIG.data.root_folder, output_directory)
    _output_path = os.path.join(TMP, output_filename)

    # if_file_exists used to be accepted but never forwarded, so it had no
    # effect; pass it through like the other disk-mode callers do
    (o_path, o_params) = bp.multi_channel_apply(
        _get_o_layer,
        mode='disk',
        cleanup_function=fix_indexes,
        output_path=_output_path,
        if_file_exists=if_file_exists,
        cast_dtype=output_dtype,
        x_tf=x_tf,
        o_layer_tf=o_layer_tf,
        NND=NND)

    return o_path, o_params
def test_can_pass_information_between_batches(path_to_data):
    # small max_memory forces several batches; each batch receives the
    # result of the previous one and keeps a running column sum
    bp = BatchProcessor(path_to_data, dtype='int64', n_channels=2,
                        data_order='samples', max_memory='160B')

    def col_sums(data, previous_batch):
        """Sum every column of data, adding the previous batch result if any
        """
        batch_total = np.sum(data, axis=0)

        if previous_batch is not None:
            return np.array([previous_batch[0] + batch_total[0],
                             previous_batch[1] + batch_total[1]])

        return batch_total

    res = bp.multi_channel_apply(col_sums, mode='memory', channels='all',
                                 pass_batch_results=True)

    assert res[0] == 4950 and res[1] == 4950
def _threshold_detection(standarized_path, standarized_params, n_observations,
                         output_directory):
    """Run threshold detector and dimensionality reduction using PCA

    Intermediate results live in CONFIG.data.root_folder/output_directory.
    spike_index_clear.npy, spike_index_collision.npy, waveforms_clear.npy and
    denoised_waveforms.npy are loaded instead of recomputed when they already
    exist; whitened.bin is written with if_file_exists='skip';
    whitening.npy, rotation.npy and score_clear.npy are always (re)written.

    Parameters
    ----------
    standarized_path: str
        Path to the standarized recordings
    standarized_params: dict
        Standarized recordings parameters, the 'dtype', 'n_channels' and
        'data_format' keys are read
    n_observations:
        Number of observations in the recordings, passed to
        detect.remove_incomplete_waveforms
    output_directory: str
        Folder for intermediate results, relative to
        CONFIG.data.root_folder

    Returns
    -------
    scores: numpy.ndarray
        Scores for the clear spikes (also saved to score_clear.npy)
    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        Clear spikes, column 0 is used as the spike location when reading
        waveforms and column 1 as the main channel
    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        Collided spikes; empty after detection (collision detection is not
        implemented for the threshold detector). With 'location' clustering,
        clear spikes that are not isolated are moved here

    Raises
    ------
    ValueError
        If a non-empty spike_index_collision.npy is found on disk
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    OUTPUT_DTYPE = CONFIG.preprocess.dtype
    TMP_FOLDER = os.path.join(CONFIG.data.root_folder, output_directory)

    ###############
    # Whiten data #
    ###############

    # compute Q for whitening (estimated from the first batch only)
    logger.info('Computing whitening matrix...')
    bp = BatchProcessor(standarized_path, standarized_params['dtype'],
                        standarized_params['n_channels'],
                        standarized_params['data_format'],
                        CONFIG.resources.max_memory)
    batches = bp.multi_channel()
    first_batch, _, _ = next(batches)
    Q = whiten.matrix(first_batch, CONFIG.neighChannels, CONFIG.spikeSize)

    path_to_whitening_matrix = os.path.join(TMP_FOLDER, 'whitening.npy')
    np.save(path_to_whitening_matrix, Q)
    logger.info(
        'Saved whitening matrix in {}'.format(path_to_whitening_matrix))

    # apply whitening to every batch
    # NOTE(review): whitened_path/whitened_params are not used below; the
    # detection and PCA steps read the standarized (not whitened) data
    (whitened_path, whitened_params) = bp.multi_channel_apply(
        np.matmul,
        mode='disk',
        output_path=os.path.join(TMP_FOLDER, 'whitened.bin'),
        if_file_exists='skip',
        cast_dtype=OUTPUT_DTYPE,
        b=Q)

    ###################
    # Spike detection #
    ###################

    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')

    # NOTE(review): buffer_size=0 here, while the batch `threshold` detector
    # elsewhere uses buffer_size=spike_size; confirm spikes near batch edges
    # are handled as intended
    bp = BatchProcessor(standarized_path, standarized_params['dtype'],
                        standarized_params['n_channels'],
                        standarized_params['data_format'],
                        CONFIG.resources.max_memory, buffer_size=0)

    # clear spikes
    if os.path.exists(path_to_spike_index_clear):
        # if it exists, load it...
        logger.info('Found file in {}, loading it...'.format(
            path_to_spike_index_clear))

        spike_index_clear = np.load(path_to_spike_index_clear)
    else:
        # if it doesn't, detect spikes...
        logger.info('Did not find file in {}, finding spikes using threshold'
                    ' detector...'.format(path_to_spike_index_clear))

        # apply threshold detector on standarized data
        spikes = bp.multi_channel_apply(detect.threshold, mode='memory',
                                        cleanup_function=detect.fix_indexes,
                                        neighbors=CONFIG.neighChannels,
                                        spike_size=CONFIG.spikeSize,
                                        std_factor=CONFIG.stdFactor)
        # stack the per-batch results into a single index array
        spike_index_clear = np.vstack(spikes)

        logger.info('Removing clear indexes outside the allowed range to '
                    'draw a complete waveform...')
        spike_index_clear, _ = (detect.remove_incomplete_waveforms(
            spike_index_clear, CONFIG.spikeSize + CONFIG.templatesMaxShift,
            n_observations))

        logger.info('Saving spikes in {}...'.format(path_to_spike_index_clear))
        np.save(path_to_spike_index_clear, spike_index_clear)

    path_to_spike_index_collision = os.path.join(TMP_FOLDER,
                                                 'spike_index_collision.npy')

    # collided spikes
    if os.path.exists(path_to_spike_index_collision):
        # if it exists, load it...
        logger.info('Found collided spikes in {}, loading them...'.format(
            path_to_spike_index_collision))

        spike_index_collision = np.load(path_to_spike_index_collision)

        # a cached collision index must be empty for the threshold detector
        if spike_index_collision.shape[0] != 0:
            raise ValueError('Found non-empty collision spike index in {}, '
                             'but threshold detector is selected, collision '
                             'detection is not implemented for threshold '
                             'detector so array must have dimensios (0, 2) '
                             'but had ({}, {})'.format(
                                 path_to_spike_index_collision,
                                 *spike_index_collision.shape))
    else:
        # triage is not implemented on threshold detector, return empty array
        logger.info('Creating empty array for'
                    ' collided spikes (collision detection is not implemented'
                    ' with threshold detector. Saving them in {}'.format(
                        path_to_spike_index_collision))

        spike_index_collision = np.zeros((0, 2), 'int32')
        np.save(path_to_spike_index_collision, spike_index_collision)

    #######################
    # Waveform extraction #
    #######################

    # load and dump waveforms from clear spikes
    path_to_waveforms_clear = os.path.join(TMP_FOLDER, 'waveforms_clear.npy')

    if os.path.exists(path_to_waveforms_clear):
        logger.info('Found clear waveforms in {}, loading them...'.format(
            path_to_waveforms_clear))

        waveforms_clear = np.load(path_to_waveforms_clear)
    else:
        logger.info(
            'Did not find clear waveforms in {}, reading them from {}'.format(
                path_to_waveforms_clear, standarized_path))

        explorer = RecordingExplorer(
            standarized_path, spike_size=CONFIG.spikeSize)
        # waveforms are read around the spike times in column 0
        waveforms_clear = explorer.read_waveforms(spike_index_clear[:, 0])
        np.save(path_to_waveforms_clear, waveforms_clear)
        logger.info('Saved waveform from clear spikes in: {}'.format(
            path_to_waveforms_clear))

    #########################
    # PCA - rotation matrix #
    #########################

    # compute per-batch sufficient statistics for PCA on standarized data
    logger.info('Computing PCA sufficient statistics...')
    stats = bp.multi_channel_apply(
        dim_red.suff_stat,
        mode='memory',
        spike_index=spike_index_clear,
        spike_size=CONFIG.spikeSize)

    # every batch returns (suff_stat, spikes_per_channel); sum them up
    suff_stats = reduce(lambda x, y: np.add(x, y), [e[0] for e in stats])

    spikes_per_channel = reduce(lambda x, y: np.add(x, y),
                                [e[1] for e in stats])

    # compute rotation matrix
    logger.info('Computing PCA projection matrix...')
    rotation = dim_red.project(suff_stats, spikes_per_channel,
                               CONFIG.spikes.temporal_features,
                               CONFIG.neighChannels)
    path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')
    np.save(path_to_rotation, rotation)
    logger.info('Saved rotation matrix in {}...'.format(path_to_rotation))

    main_channel = spike_index_clear[:, 1]

    ###########################################
    # PCA - waveform dimensionality reduction #
    ###########################################
    if CONFIG.clustering.clustering_method == 'location':
        logger.info('Denoising...')
        path_to_denoised_waveforms = os.path.join(TMP_FOLDER,
                                                  'denoised_waveforms.npy')
        if os.path.exists(path_to_denoised_waveforms):
            logger.info(
                'Found denoised waveforms in {}, loading them...'.format(
                    path_to_denoised_waveforms))
            denoised_waveforms = np.load(path_to_denoised_waveforms)
        else:
            # NOTE(review): the two literals below join without a space
            # ("...evaluating themfrom ...")
            logger.info(
                'Did not find denoised waveforms in {}, evaluating them'
                'from {}'.format(path_to_denoised_waveforms,
                                 path_to_waveforms_clear))
            # NOTE(review): waveforms_clear is already in memory at this
            # point, this reload looks redundant
            waveforms_clear = np.load(path_to_waveforms_clear)
            denoised_waveforms = dim_red.denoise(waveforms_clear, rotation,
                                                 CONFIG)
            logger.info('Saving denoised waveforms to {}'.format(
                path_to_denoised_waveforms))
            np.save(path_to_denoised_waveforms, denoised_waveforms)

        isolated_index, x, y = get_isolated_spikes_and_locations(
            denoised_waveforms, main_channel, CONFIG)
        # z-score the locations; NOTE(review): np.std == 0 would divide by 0
        x = (x - np.mean(x)) / np.std(x)
        y = (y - np.mean(y)) / np.std(y)
        # clear spikes that are not isolated are moved to the collision index
        corrupted_index = np.logical_not(
            np.in1d(np.arange(spike_index_clear.shape[0]), isolated_index))
        spike_index_collision = np.concatenate(
            [spike_index_collision, spike_index_clear[corrupted_index]],
            axis=0)
        spike_index_clear = spike_index_clear[isolated_index]
        waveforms_clear = waveforms_clear[isolated_index]

        #################################################
        # Dimensionality reduction (Isolated Waveforms) #
        #################################################
        scores = dim_red.main_channel_scores(waveforms_clear, rotation,
                                             spike_index_clear, CONFIG)
        # NOTE(review): mean is per column (axis=0) but std is over the whole
        # array; confirm this mixed normalization is intended
        scores = (scores - np.mean(scores, axis=0)) / np.std(scores)
        # prepend the normalized x, y locations as extra features
        scores = np.concatenate([
            x[:, np.newaxis, np.newaxis], y[:, np.newaxis, np.newaxis],
            scores[:, :, np.newaxis]
        ], axis=1)

    else:
        logger.info('Reducing spikes dimensionality with PCA matrix...')
        scores = dim_red.score(waveforms_clear, rotation,
                               spike_index_clear[:, 1], CONFIG.neighChannels,
                               CONFIG.geom)

    # save scores
    path_to_score = os.path.join(TMP_FOLDER, 'score_clear.npy')
    np.save(path_to_score, scores)
    logger.info('Saved spike scores in {}...'.format(path_to_score))

    return scores, spike_index_clear, spike_index_collision
def threshold(path_to_data, dtype, n_channels, data_order,
              max_memory, neighbors, spike_size,
              minimum_half_waveform_size, threshold,
              output_path=None,
              spike_index_clear_filename='spike_index_clear.npy',
              if_file_exists='skip'):
    """Threshold spike detection in batches

    Parameters
    ----------
    path_to_data: str
        Path to recordings in binary format
    dtype: str
        Recordings dtype
    n_channels: int
        Number of channels in the recordings
    data_order: str
        Recordings order, one of ('channels', 'samples'). In a dataset with k
        observations per channel and j channels: 'channels' means first k
        contiguous observations come from channel 0, then channel 1, and so
        on. 'samples' means first j contiguous data are the first
        observations from all channels, then the second observations from all
        channels and so on
    max_memory: str
        Max memory to use in each batch (e.g. 100MB, 1GB)
    neighbors: numpy.ndarray (n_channels, n_channels)
        Boolean numpy 2-D array where a i, j entry is True if i is considered
        neighbor of j
    spike_size: int
        Spike size, also used as the buffer size of the batch processor
    minimum_half_waveform_size: int
        This is used to remove spikes that are either at the beginning or end
        of the recordings and whose location does not allow to draw a
        waveform of size at least 2 * minimum_half_waveform_size + 1
    threshold: float
        Threshold used on amplitude for detection
    output_path: str, optional
        Directory to save spike indexes, if None, results won't be stored,
        but only returned by the function
    spike_index_clear_filename: str, optional
        Filename for spike_index_clear, relative to output_path, if None,
        results won't be saved, only returned
    if_file_exists: str, optional
        What to do if output_path/spike_index_clear_filename already exists.
        One of 'overwrite', 'abort', 'skip'. If 'skip' and the file exists,
        detection is not run and the saved array is loaded and returned.
        Otherwise detection runs and the option is forwarded to
        save_numpy_object when the result is saved

    Returns
    -------
    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        2D array with indexes for clear spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum). Collision detection is not
        implemented in the threshold detector, so every detected spike is
        returned here
    """
    if output_path and spike_index_clear_filename:
        path = os.path.join(output_path, spike_index_clear_filename)
    else:
        path = None

    # honor the documented 'skip' contract: reuse a previous result instead
    # of running the detector again
    if path is not None and if_file_exists == 'skip' and os.path.exists(path):
        return np.load(path)

    # instatiate batch processor
    bp = BatchProcessor(path_to_data, dtype, n_channels, data_order,
                        max_memory, buffer_size=spike_size)

    # run threshold detector
    spikes = bp.multi_channel_apply(_threshold,
                                    mode='memory',
                                    cleanup_function=fix_indexes,
                                    neighbors=neighbors,
                                    spike_size=spike_size,
                                    threshold=threshold)

    # no collision detection implemented, all spikes are marked as clear
    spike_index_clear = np.vstack(spikes)

    # remove spikes whose location won't let us draw a complete waveform
    logger.info('Removing clear indexes outside the allowed range to '
                'draw a complete waveform...')
    spike_index_clear, _ = (remove_incomplete_waveforms(
        spike_index_clear, minimum_half_waveform_size,
        bp.reader.observations))

    if path is not None:
        save_numpy_object(spike_index_clear, path,
                          if_file_exists=if_file_exists,
                          name='Spike index clear')

    return spike_index_clear
# Compare single- and multi-channel application of `dummy` on the same
# data stored in 'long' and 'wide' format. big_long, x_long, path_to_long,
# path_to_wide, path_to_out and dummy are defined earlier in this script.
big_long[:, 1] = x_long

bp_long = BatchProcessor(path_to_long, dtype='int64', n_channels=50,
                         data_format='long', max_memory='500MB')

# NOTE(review): every run below writes to the same path_to_out; presumably
# each run replaces the previous output - confirm
path = bp_long.single_channel_apply(dummy, path_to_out)
out = RecordingsReader(path)
# NOTE(review): bare expressions like `out` only display a value in an
# interactive session, they have no effect when run as a script
out

bp_wide = BatchProcessor(path_to_wide, dtype='int64', n_channels=50,
                         data_format='wide', max_memory='500MB')

path = bp_wide.single_channel_apply(dummy, path_to_out)
out = RecordingsReader(path)
out

path = bp_long.multi_channel_apply(dummy, path_to_out)
out = RecordingsReader(path)
out

path = bp_wide.multi_channel_apply(dummy, path_to_out)
out = RecordingsReader(path)
out
def pca(path_to_data, dtype, n_channels, data_order, spike_index, spike_size,
        temporal_features, neighbors_matrix, channel_index, max_memory,
        output_path=None, scores_filename='scores.npy',
        rotation_matrix_filename='rotation.npy',
        spike_index_clear_filename='spike_index_clear_pca.npy',
        if_file_exists='skip'):
    """Apply PCA in batches

    Parameters
    ----------
    path_to_data: str
        Path to recordings in binary format
    dtype: str
        Recordings dtype
    n_channels: int
        Number of channels in the recordings
    data_order: str
        Recordings order, one of ('channels', 'samples'). In a dataset with k
        observations per channel and j channels: 'channels' means first k
        contiguous observations come from channel 0, then channel 1, and so
        on. 'samples' means first j contiguous data are the first
        observations from all channels, then the second observations from all
        channels and so on
    spike_index: numpy.ndarray
        A 2D numpy array, first column is spike time, second column is main
        channel (the channel where spike has the biggest amplitude)
    spike_size: int
        Spike size, also used as the buffer size of the batch processor
    temporal_features: int
        Number of output features
    neighbors_matrix: numpy.ndarray (n_channels, n_channels)
        Boolean numpy 2-D array where a i, j entry is True if i is considered
        neighbor of j
    channel_index: np.array (n_channels, n_neigh)
        Each row indexes its neighboring channels.
        For example, channel_index[c] is the index of
        neighboring channels (including itself)
        If any value is equal to n_channels, it is nothing but
        a space holder in a case that a channel has less than
        n_neigh neighboring channels
    max_memory: str
        Max memory to use in each batch (e.g. 100MB, 1GB)
    output_path: str, optional
        Directory to store the scores, spike index and rotation matrix, if
        None, nothing is saved. Results are always computed; previous results
        on disk are never loaded
    scores_filename: str, optional
        File name for the scores, if falsy, scores are not saved
    rotation_matrix_filename: str, optional
        File name for the rotation matrix, if falsy, it is not saved
    spike_index_clear_filename: str, optional
        File name for the spike index, if falsy, it is not saved
    if_file_exists: str, optional
        What to do if any of the output files already exists. One of
        'overwrite', 'abort', 'skip', forwarded to save_numpy_object

    Returns
    -------
    scores: numpy.ndarray
        Numpy 3D array  of size (n_waveforms, n_reduced_features,
        n_neighboring_channels) Scores for every waveform, second dimension
        in the array is reduced from n_temporal_features to
        n_reduced_features, third dimension depends on the number of
        neighboring channels
    spike_index: numpy.ndarray
        Spike index returned by score for every batch, concatenated
        (presumably row-aligned with scores - confirm)
    rotation_matrix: numpy.ndarray
        3D array (window_size, n_features, n_channels)
    """
    ###########################
    # compute rotation matrix #
    ###########################

    bp = BatchProcessor(path_to_data, dtype, n_channels, data_order,
                        max_memory, buffer_size=spike_size)

    # compute PCA sufficient statistics
    logger.info('Computing PCA sufficient statistics...')
    stats = bp.multi_channel_apply(suff_stat, mode='memory',
                                   spike_index=spike_index,
                                   spike_size=spike_size)
    # every batch returns (suff_stat, spikes_per_channel); sum them up
    suff_stats = reduce(np.add, [e[0] for e in stats])
    spikes_per_channel = reduce(np.add, [e[1] for e in stats])

    # compute PCA projection matrix
    logger.info('Computing PCA projection matrix...')
    rotation = project(suff_stats, spikes_per_channel, temporal_features,
                       neighbors_matrix)

    #####################################
    # waveform dimensionality reduction #
    #####################################

    logger.info('Reducing spikes dimensionality with PCA matrix...')
    res = bp.multi_channel_apply(score,
                                 mode='memory',
                                 pass_batch_info=True,
                                 rot=rotation,
                                 channel_index=channel_index,
                                 spike_index=spike_index)

    scores = np.concatenate([element[0] for element in res], axis=0)
    spike_index = np.concatenate([element[1] for element in res], axis=0)

    # save scores
    if output_path and scores_filename:
        path_to_score = Path(output_path) / scores_filename
        save_numpy_object(scores, path_to_score,
                          if_file_exists=if_file_exists,
                          name='scores')

    if output_path and spike_index_clear_filename:
        path_to_spike_index = Path(output_path) / spike_index_clear_filename
        save_numpy_object(spike_index, path_to_spike_index,
                          if_file_exists=if_file_exists,
                          name='Spike index PCA')

    if output_path and rotation_matrix_filename:
        path_to_rotation = Path(output_path) / rotation_matrix_filename
        save_numpy_object(rotation, path_to_rotation,
                          if_file_exists=if_file_exists,
                          name='rotation matrix')

    return scores, spike_index, rotation
def test_splitting_in_batches_does_not_affect(path_to_config,
                                              path_to_sample_pipeline_folder,
                                              make_tmp_folder,
                                              path_to_standardized_data):
    yass.set_config(path_to_config, make_tmp_folder)
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standardized_data

    with open(path.join(path_to_sample_pipeline_folder, 'preprocess',
                        'standardized.yaml')) as f:
        # yaml.load without an explicit Loader is unsafe and raises
        # TypeError on PyYAML >= 6
        PARAMS = yaml.safe_load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels,
                                       CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 channel_index)
    triage = KerasModel(triage_fname,
                        allow_longer_waveform_length=True,
                        allow_more_channels=True)
    NNAE = AutoEncoder.load(ae_fname, input_tensor=NND.waveform_tf)

    # NOTE(review): both runs below use this same small-batch processor
    bp = BatchProcessor(PATH_TO_DATA, PARAMS['dtype'], PARAMS['n_channels'],
                        PARAMS['data_order'], '100KB',
                        buffer_size=CONFIG.spike_size)

    out = ('spike_index', 'waveform')
    fn = neuralnetwork.apply.fix_indexes_spike_index

    # detector
    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)

        res = bp.multi_channel_apply(NND.predict_recording,
                                     mode='memory',
                                     sess=sess,
                                     output_names=out,
                                     cleanup_function=fn)

    spike_index_new = np.concatenate([element[0] for element in res], axis=0)
    wfs = np.concatenate([element[1] for element in res], axis=0)

    idx_clean = triage.predict_with_threshold(wfs, triage_th)
    score = NNAE.predict(wfs)
    rot = NNAE.load_rotation()
    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

    (score_clear_new,
     spike_index_clear_new) = post_processing(score,
                                              spike_index_new,
                                              idx_clean,
                                              rot,
                                              neighbors)

    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)

        res = bp.multi_channel_apply(NND.predict_recording,
                                     mode='memory',
                                     sess=sess,
                                     output_names=('spike_index',
                                                   'waveform'),
                                     cleanup_function=fn)

    spike_index_batch, wfs = zip(*res)

    spike_index_batch = np.concatenate(spike_index_batch, axis=0)
    wfs = np.concatenate(wfs, axis=0)

    idx_clean = triage.predict_with_threshold(x=wfs, threshold=triage_th)
    score = NNAE.predict(wfs)
    rot = NNAE.load_rotation()
    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

    (score_clear_batch,
     spike_index_clear_batch) = post_processing(score,
                                                spike_index_batch,
                                                idx_clean,
                                                rot,
                                                neighbors)

    # the test previously computed both results but never compared them
    np.testing.assert_array_equal(spike_index_new, spike_index_batch)
    np.testing.assert_array_equal(spike_index_clear_new,
                                  spike_index_clear_batch)
    np.testing.assert_array_equal(score_clear_new, score_clear_batch)
def test_splitting_in_batches_does_not_affect_result(path_to_tests):
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    PATH_TO_DATA = path.join(path_to_tests, 'data/standarized.bin')

    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    with open(path.join(path_to_tests, 'data/standarized.yaml')) as f:
        # yaml.load without an explicit Loader is unsafe and raises
        # TypeError on PyYAML >= 6
        PARAMS = yaml.safe_load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels,
                                       CONFIG.geom)

    # identity whitening filter for every channel
    whiten_filter = np.tile(
        np.eye(channel_index.shape[1], dtype='float32')[np.newaxis, :, :],
        [channel_index.shape[0], 1, 1])

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    (x_tf, output_tf, NND,
     NNAE, NNT) = neuralnetwork.prepare_nn(channel_index,
                                           whiten_filter,
                                           detection_th,
                                           triage_th,
                                           detection_fname,
                                           ae_fname,
                                           triage_fname,
                                           )

    # run all at once
    with tf.Session() as sess:
        # get values of above tensors
        NND.saver.restore(sess, NND.path_to_detector_model)
        NNAE.saver_ae.restore(sess, NNAE.path_to_ae_model)
        NNT.saver.restore(sess, NNT.path_to_triage_model)
        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        (scores, clear,
         collision) = neuralnetwork.run_detect_triage_featurize(data, sess,
                                                                x_tf,
                                                                output_tf,
                                                                neighbors,
                                                                rot)

    # run in batches - buffer size makes sure we can detect spikes if they
    # appear at the end of any batch
    bp = BatchProcessor(PATH_TO_DATA, PARAMS['dtype'], PARAMS['n_channels'],
                        PARAMS['data_order'], '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        # get values of above tensors
        NND.saver.restore(sess, NND.path_to_detector_model)
        NNAE.saver_ae.restore(sess, NNAE.path_to_ae_model)
        NNT.saver.restore(sess, NNT.path_to_triage_model)
        rot = NNAE.load_rotation()
        neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res], axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)
# create batch processor for the data bp = BatchProcessor(path_to_neuropixel_data, dtype='int16', n_channels=385, data_format='wide', max_memory='500MB') # appply a multi channel transformation, each batch will be a temporal # subset with observations from all selected n_channels, the size # of the subset is calculated depending on max_memory. Each batch is # processed and when done, results are save to disk, the next batch is # then loaded and so on bp.multi_channel_apply(sum_one, mode='disk', output_path=path_to_modified_data, channels=[0, 1, 2]) # let's visualize the results raw = RecordingsReader(path_to_neuropixel_data, dtype='int16', n_channels=385, data_format='wide') # you do not need to specify the format since multi_channel_apply # saves a yaml file with such parameters filtered = RecordingsReader(path_to_modified_data) fig, (ax1, ax2) = plt.subplots(2, 1) ax1.plot(raw[:2000, 0]) ax2.plot(filtered[:2000, 0])
def butterworth(path_to_data, dtype, n_channels, data_order,
                low_frequency, high_factor, order, sampling_frequency,
                max_memory, output_path, output_dtype, standarize=False,
                output_filename='filtered.bin', if_file_exists='skip',
                processes='max'):
    """Apply a Butterworth filter to the recordings, one batch at a time

    Parameters
    ----------
    path_to_data: str
        Location of the binary recordings
    dtype: str
        dtype of the recordings
    n_channels: int
        How many channels the recordings have
    data_order: str
        Layout of the recordings, either 'channels' or 'samples'. For k
        observations per channel and j channels, 'channels' stores the k
        observations of channel 0 first, then those of channel 1 and so on;
        'samples' stores the j values of the first observation, then the j
        values of the second observation and so on
    low_frequency: int
        Low pass frequency (Hz)
    high_factor: float
        High pass factor (proportion of sampling rate)
    order: int
        Order of the Butterworth filter
    sampling_frequency: int
        Sampling frequency of the recordings in Hz
    max_memory: str
        Memory budget per batch (e.g. 100MB, 1GB)
    output_path: str
        Folder where the filtered recordings are written
    output_dtype: str
        dtype the filtered data is cast to
    standarize: bool, optional
        If True, the filtered data is also scaled using a standard deviation
        estimate (passed as ``denominator`` to _butterworth_scale) computed
        on the filtered data
    output_filename: str, optional
        File name inside output_path, 'filtered.bin' by default
    if_file_exists: str, optional
        One of 'overwrite', 'abort', 'skip', forwarded to
        multi_channel_apply. 'overwrite' replaces an existing file, 'abort'
        raises a ValueError if the file exists and 'skip' skips the
        operation if the file exists
    processes: str or int, optional
        Number of processes; 'max' uses every core in the machine, an
        integer uses that many

    Returns
    -------
    path: str
        Location of the filtered recordings
    params: dict
        Parameters of the filtered recordings (dtype, n_channels,
        data_order)
    """
    if processes == 'max':
        processes = multiprocess.cpu_count()

    # processor that runs the filter, with buffer_size=200 around batches
    buffered_bp = BatchProcessor(path_to_data, dtype, n_channels, data_order,
                                 max_memory, buffer_size=200)

    if not standarize:
        # filtering only
        batch_fn = _butterworth
    else:
        # the sd must describe the *filtered* data, so the filter is handed
        # to standard_deviation as its preprocessing step; the estimate uses
        # a processor without buffer
        unbuffered_bp = BatchProcessor(path_to_data, dtype, n_channels,
                                       data_order, max_memory, buffer_size=0)
        filter_fn = partial(_butterworth,
                            low_frequency=low_frequency,
                            high_factor=high_factor,
                            order=order,
                            sampling_frequency=sampling_frequency)
        sd = standard_deviation(unbuffered_bp, sampling_frequency,
                                preprocess_fn=filter_fn)

        batch_fn = partial(_butterworth_scale, denominator=sd)
        # partial objects carry no __name__, copy it from the wrapped function
        batch_fn.__name__ = _butterworth_scale.__name__

    filtered_path, filtered_params = buffered_bp.multi_channel_apply(
        batch_fn,
        mode='disk',
        cleanup_function=fix_indexes,
        output_path=os.path.join(output_path, output_filename),
        if_file_exists=if_file_exists,
        cast_dtype=output_dtype,
        low_frequency=low_frequency,
        high_factor=high_factor,
        order=order,
        sampling_frequency=sampling_frequency,
        processes=processes)

    return filtered_path, filtered_params
path_to_neuropixel_data = (os.path.expanduser('~/data/ucl-neuropixel'
                           '/rawDataSample.bin'))
path_to_filtered_data = (os.path.expanduser('~/data/ucl-neuropixel'
                         '/tmp/filtered_multi.bin'))

# create batch processor for the data
bp = BatchProcessor(path_to_neuropixel_data,
                    dtype='int16', n_channels=385, data_format='wide',
                    max_memory='500MB')

# apply a multi channel transformation, each batch will be a temporal
# subset with observations from all selected n_channels, the size
# of the subset is calculated depending on max_memory
# NOTE(review): other multi_channel_apply calls pass mode= and
# output_path= as keywords; confirm the second positional parameter is still
# the output path. The filter kwargs (low_freq, sampling_freq) must match the
# signature of whichever per-batch `butterworth` is imported here.
bp.multi_channel_apply(butterworth, path_to_filtered_data,
                       channels='all',
                       low_freq=300, high_factor=0.1,
                       order=3, sampling_freq=30000)

# let's visualize the results
raw = RecordingsReader(path_to_neuropixel_data, dtype='int16',
                       n_channels=385, data_format='wide')

# you do not need to specify the format since multi_channel_apply
# saves a yaml file with such parameters
filtered = RecordingsReader(path_to_filtered_data)

fig, (ax1, ax2) = plt.subplots(2, 1)
ax1.plot(raw[:2000, 0])
ax2.plot(filtered[:2000, 0])
plt.show()
def pca(path_to_data, dtype, n_channels, data_order, recordings, spike_index,
        spike_size, temporal_features, neighbors_matrix, channel_index,
        max_memory, gmm_params, output_path=None, scores_filename='scores.npy',
        rotation_matrix_filename='rotation.npy',
        spike_index_clear_filename='spike_index_clear_pca.npy',
        if_file_exists='skip'):
    """Apply weighted wavelet PCA (WPCA) in batches

    Wavelet features are extracted around every spike, weighted with
    gmm_weight, written to ``features.bin`` in the folder that contains
    path_to_data, and PCA is computed on that feature file.

    Parameters
    ----------
    path_to_data: str
        Path to recordings in binary format
    dtype: str
        Recordings dtype
    n_channels: int
        Number of channels in the recordings
    data_order: str
        Recordings order, one of ('channels', 'samples'). In a dataset with k
        observations per channel and j channels: 'channels' means first k
        contiguous observations come from channel 0, then channel 1, and so
        on. 'samples' means first j contiguous data are the first
        observations from all channels, then the second observations from all
        channels and so on
    recordings: np.ndarray (n_observations, n_channels)
        Multi-channel recordings. Currently unused, kept for backwards
        compatibility
    spike_index: numpy.ndarray
        A 2D numpy array, first column is spike time, second column is main
        channel (the channel where spike has the biggest amplitude)
    spike_size: int
        Spike size, also used as the buffer size of the batch processor
    temporal_features: int
        Number of output features
    neighbors_matrix: numpy.ndarray (n_channels, n_channels)
        Boolean numpy 2-D array where a i, j entry is True if i is considered
        neighbor of j
    channel_index: np.array (n_channels, n_neigh)
        Each row indexes its neighboring channels.
        For example, channel_index[c] is the index of
        neighboring channels (including itself)
        If any value is equal to n_channels, it is nothing but
        a space holder in a case that a channel has less than
        n_neigh neighboring channels
    max_memory: str
        Max memory to use in each batch (e.g. 100MB, 1GB)
    gmm_params: dict
        Parameters of the Gaussian mixture model, passed to gmm_weight
    output_path: str, optional
        Directory to store the scores, spike index and rotation matrix, if
        None, nothing is saved. Results are always computed; previous results
        on disk are never loaded
    scores_filename: str, optional
        File name for the scores, if falsy, scores are not saved
    rotation_matrix_filename: str, optional
        File name for the rotation matrix, if falsy, it is not saved
    spike_index_clear_filename: str, optional
        File name for the spike index, if falsy, it is not saved
    if_file_exists: str, optional
        What to do if any of the output files already exists. One of
        'overwrite', 'abort', 'skip', forwarded to save_numpy_object

    Returns
    -------
    scores: numpy.ndarray
        Numpy 3D array  of size (n_waveforms, n_reduced_features,
        n_neighboring_channels) Scores for every waveform, z-scored along
        the first axis
    spike_index: numpy.ndarray
        Spike index returned by score_features for every batch, concatenated
    rotation_matrix: numpy.ndarray
        3D array (window_size, n_features, n_channels)
    """
    ###########################
    # compute rotation matrix #
    ###########################

    bp = BatchProcessor(path_to_data, dtype, n_channels, data_order,
                        max_memory, buffer_size=spike_size)

    # compute WPCA
    WAVE, FEATURE, CH = 0, 1, 2

    logger.info('Preforming WPCA')

    logger.info('Computing Wavelets ...')
    feature = bp.multi_channel_apply(wavedec, mode='memory',
                                     pass_batch_info=True,
                                     spike_index=spike_index,
                                     spike_size=spike_size,
                                     wvtype='haar')

    # one concatenate is linear in the total size; the previous pairwise
    # reduce(np.concatenate) copied the growing array once per batch
    features = np.concatenate(feature, axis=0)

    logger.info('Computing weights..')

    # weight the features with the Gaussian mixture model in gmm_params
    weights = gmm_weight(features, gmm_params, spike_index)
    wfeatures = features * weights

    n_features = wfeatures.shape[FEATURE]
    # flatten to 2D: every waveform contributes n_features consecutive rows
    wfeatures_lin = np.reshape(wfeatures,
                               (wfeatures.shape[WAVE] * n_features,
                                wfeatures.shape[CH]))
    # row where each waveform's block of features starts
    feature_index = np.arange(0, wfeatures.shape[WAVE] * n_features,
                              n_features)

    # the feature file is written next to the input recordings
    TMP_FOLDER, _ = os.path.split(path_to_data)
    feature_path = os.path.join(TMP_FOLDER, 'features.bin')
    feature_params = writefile(wfeatures_lin, feature_path)

    bp_feat = BatchProcessor(feature_path,
                             feature_params['dtype'],
                             feature_params['n_channels'],
                             feature_params['data_order'],
                             max_memory,
                             buffer_size=n_features)

    # compute PCA sufficient statistics from extracted features
    logger.info('Computing PCA sufficient statistics...')
    stats = bp_feat.multi_channel_apply(suff_stat_features, mode='memory',
                                        pass_batch_info=True,
                                        spike_index=spike_index,
                                        spike_size=spike_size,
                                        feature_index=feature_index,
                                        feature_size=n_features)

    # every batch returns (suff_stat, spikes_per_channel); sum them up
    suff_stats = reduce(np.add, [e[0] for e in stats])
    spikes_per_channel = reduce(np.add, [e[1] for e in stats])

    # compute PCA projection matrix
    logger.info('Computing PCA projection matrix...')
    rotation = project(suff_stats, spikes_per_channel, temporal_features,
                       neighbors_matrix)

    #####################################
    # waveform dimensionality reduction #
    #####################################

    logger.info('Reducing spikes dimensionality with PCA matrix...')

    # using a new Batch to read feature file
    res = bp_feat.multi_channel_apply(score_features,
                                      mode='memory',
                                      pass_batch_info=True,
                                      rot=rotation,
                                      channel_index=channel_index,
                                      spike_index=spike_index,
                                      feature_index=feature_index)

    scores = np.concatenate([element[0] for element in res], axis=0)
    spike_index = np.concatenate([element[1] for element in res], axis=0)

    # renormalizing PC projections to similar unitary variance
    scores = st.zscore(scores, axis=0)

    # save scores
    if output_path and scores_filename:
        path_to_score = Path(output_path) / scores_filename
        save_numpy_object(scores, path_to_score,
                          if_file_exists=if_file_exists,
                          name='scores')

    if output_path and spike_index_clear_filename:
        path_to_spike_index = Path(output_path) / spike_index_clear_filename
        save_numpy_object(spike_index, path_to_spike_index,
                          if_file_exists=if_file_exists,
                          name='Spike index PCA')

    if output_path and rotation_matrix_filename:
        path_to_rotation = Path(output_path) / rotation_matrix_filename
        save_numpy_object(rotation, path_to_rotation,
                          if_file_exists=if_file_exists,
                          name='rotation matrix')

    return scores, spike_index, rotation
"""Add one to every element in the batch """ return np.max(batch, axis=0) # create batch processor for the data bp = BatchProcessor(path_to_neuropixel_data, dtype='int16', n_channels=385, data_format='wide', max_memory='10MB') # appply a multi channel transformation, each batch will be a temporal # subset with observations from all selected n_channels, the size # of the subset is calculated depending on max_memory. Results # from every batch are returned in a list res = bp.multi_channel_apply(max_in_channel, mode='memory', channels=[0, 1, 2]) # we have one element per batch len(res) # output for the first batch res[0] # stack results from every batch arr = np.stack(res, axis=0)
def run(output_directory='tmp/'):
    """Execute preprocessing pipeline

    Parameters
    ----------
    output_directory: str, optional
      Location to store partial results, relative to CONFIG.data.root_folder,
      defaults to tmp/

    Returns
    -------
    clear_scores: numpy.ndarray (n_spikes, n_features, n_channels)
        3D array with the scores for the clear spikes, first dimension is
        the number of spikes, second is the number of features and third the
        number of channels

    spike_index_clear: numpy.ndarray (n_clear_spikes, 2)
        2D array with indexes for clear spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    spike_index_collision: numpy.ndarray (n_collided_spikes, 2)
        2D array with indexes for collided spikes, first column contains the
        spike location in the recording and the second the main channel
        (channel whose amplitude is maximum)

    Raises
    ------
    ValueError
        If CONFIG.spikes.detection is neither 'threshold' nor 'nn'

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``config.yaml`` - Copy of the configuration file
    * ``metadata.yaml`` - Experiment metadata
    * ``filtered.bin`` - Filtered recordings
    * ``filtered.yaml`` - Filtered recordings metadata
    * ``standarized.bin`` - Standarized recordings
    * ``standarized.yaml`` - Standarized recordings metadata
    * ``whitened.bin`` - Whitened recordings
    * ``whitened.yaml`` - Whitened recordings metadata
    * ``rotation.npy`` - Rotation matrix for dimensionality reduction
    * ``spike_index_clear.npy`` - Same as spike_index_clear returned
    * ``spike_index_collision.npy`` - Same as spike_index_collision returned
    * ``score_clear.npy`` - Scores for clear spikes
    * ``waveforms_clear.npy`` - Waveforms for clear spikes

    Examples
    --------

    .. literalinclude:: ../examples/preprocess.py
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    OUTPUT_DTYPE = CONFIG.preprocess.dtype

    logger.info(
        'Output dtype for transformed data will be {}'.format(OUTPUT_DTYPE))

    TMP = os.path.join(CONFIG.data.root_folder, output_directory)

    if not os.path.exists(TMP):
        logger.info('Creating temporary folder: {}'.format(TMP))
        os.makedirs(TMP)
    else:
        logger.info('Temporary folder {} already exists, output will be '
                    'stored there'.format(TMP))

    path = os.path.join(CONFIG.data.root_folder, CONFIG.data.recordings)
    dtype = CONFIG.recordings.dtype

    # initialize pipeline object, one batch per channel
    pipeline = BatchPipeline(path, dtype, CONFIG.recordings.n_channels,
                             CONFIG.recordings.format,
                             CONFIG.resources.max_memory, TMP)

    # add filter transformation if necessary
    if CONFIG.preprocess.filter:
        filter_op = Transform(
            butterworth,
            'filtered.bin',
            mode='single_channel_one_batch',
            keep=True,
            if_file_exists='skip',
            cast_dtype=OUTPUT_DTYPE,
            low_freq=CONFIG.filter.low_pass_freq,
            high_factor=CONFIG.filter.high_factor,
            order=CONFIG.filter.order,
            sampling_freq=CONFIG.recordings.sampling_rate)

        pipeline.add([filter_op])

    # NOTE(review): this unpacking expects exactly one output from the
    # pipeline; confirm what pipeline.run() returns when filtering is
    # disabled and no transformation was added.
    (filtered_path, ), (filtered_params, ) = pipeline.run()

    # standarize: the standard deviation is estimated from the first batch
    bp = BatchProcessor(filtered_path, filtered_params['dtype'],
                        filtered_params['n_channels'],
                        filtered_params['data_format'],
                        CONFIG.resources.max_memory)
    batches = bp.multi_channel()
    first_batch, _, _ = next(batches)
    sd = standard_deviation(first_batch, CONFIG.recordings.sampling_rate)

    (standarized_path, standarized_params) = bp.multi_channel_apply(
        standarize,
        mode='disk',
        output_path=os.path.join(TMP, 'standarized.bin'),
        if_file_exists='skip',
        cast_dtype=OUTPUT_DTYPE,
        sd=sd)

    standarized = RecordingsReader(standarized_path)
    n_observations = standarized.observations

    detection = CONFIG.spikes.detection

    if detection == 'threshold':
        return _threshold_detection(standarized_path, standarized_params,
                                    n_observations, output_directory)
    elif detection == 'nn':
        return _neural_network_detection(standarized_path,
                                         standarized_params,
                                         n_observations, output_directory)

    # previously this fell through and returned None, which only surfaced
    # later as an obscure unpacking error in the caller
    raise ValueError('Unknown spike detection method "{}", expected '
                     '"threshold" or "nn"'.format(detection))
def standarize(path_to_data, dtype, n_channels, data_order,
               sampling_frequency, max_memory, output_path,
               output_dtype, output_filename='standarized.bin',
               if_file_exists='skip', processes='max'):
    """
    Standarize recordings in batches and write results to disk. The
    standard deviation is computed once by ``standard_deviation`` and every
    batch is divided by it

    Parameters
    ----------
    path_to_data: str
        Path to recordings in binary format

    dtype: str
        Recordings dtype

    n_channels: int
        Number of channels in the recordings

    data_order: str
        Recordings order, one of ('channels', 'samples'). In a dataset with k
        observations per channel and j channels: 'channels' means first k
        contiguous observations come from channel 0, then channel 1, and so
        on. 'samples' means first j contiguous data are the first observations
        from all channels, then the second observations from all channels and
        so on

    sampling_frequency: int
        Recordings sampling frequency in Hz

    max_memory: str
        Max memory to use in each batch (e.g. 100MB, 1GB)

    output_path: str
        Directory where the standarized recordings are stored

    output_dtype: str
        dtype for standarized data

    output_filename: str, optional
        Filename for the output data, defaults to standarized.bin

    if_file_exists: str, optional
        One of 'overwrite', 'abort', 'skip'. If 'overwrite' it replaces the
        standarized data if it exists, if 'abort' it raises a ValueError
        exception if the file exists, if 'skip' it skips the operation if the
        file exists

    processes: str or int, optional
        Number of processes to use, if 'max', it uses all cores in the machine
        if a number, it uses that number of cores

    Returns
    -------
    standarized_path: str
        Path to standarized recordings

    standarized_params: dict
        A dictionary with the parameters for the standarized recordings
        (dtype, n_channels, data_order)
    """
    processes = multiprocess.cpu_count() if processes == 'max' else processes
    _output_path = os.path.join(output_path, output_filename)

    # init batch processor
    bp = BatchProcessor(path_to_data, dtype, n_channels, data_order,
                        max_memory)

    sd = standard_deviation(bp, sampling_frequency)

    def divide(rec):
        return np.divide(rec, sd)

    # apply transformation; if_file_exists was previously accepted but never
    # forwarded, so the documented skip/abort/overwrite behaviour was ignored
    (standarized_path,
     standarized_params) = bp.multi_channel_apply(
        divide,
        mode='disk',
        output_path=_output_path,
        cast_dtype=output_dtype,
        if_file_exists=if_file_exists,
        processes=processes)

    return standarized_path, standarized_params
def _neural_network_detection(standarized_path, standarized_params,
                              n_observations, output_directory):
    """Run neural network detection and autoencoder dimensionality reduction

    If ``score_clear.npy``, ``spike_index_clear.npy`` and
    ``spike_index_collision.npy`` all exist in
    CONFIG.data.root_folder/output_directory they are loaded and returned
    as-is. Otherwise detection runs batch-wise on the standarized recordings,
    indexes too close to the recording edges to draw a complete waveform are
    removed, and the results are saved to that folder.

    When CONFIG.clustering.clustering_method == 'location', the data are also
    whitened, clear waveforms are read and denoised, and non-isolated clear
    spikes are moved to the collision set. The returned scores are then the
    standardized x/y locations concatenated with the main-channel scores.
    Otherwise the scores produced by the detector are whitened with a
    localized whitening matrix.

    Parameters
    ----------
    standarized_path: str
        Path to the standarized recordings
    standarized_params: dict
        Must contain 'dtype', 'n_channels' and 'data_format'
    n_observations: int
        Number of observations in the recordings, used to drop spikes whose
        waveform would run past the end of the recording
    output_directory: str
        Folder, relative to CONFIG.data.root_folder, where intermediate
        results are read from / written to

    Returns
    -------
    scores: numpy.ndarray
        Scores for clear spikes
    clear: numpy.ndarray
        Clear spike index; column 0 is used as the spike location when
        reading waveforms and column 1 as the main channel
    collision: numpy.ndarray
        Collided spike index
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    OUTPUT_DTYPE = CONFIG.preprocess.dtype
    TMP_FOLDER = os.path.join(CONFIG.data.root_folder, output_directory)

    # detect spikes
    # NOTE(review): buffer_size=0 means batches do not overlap; confirm that
    # spikes close to batch boundaries are not missed by the detector
    bp = BatchProcessor(standarized_path, standarized_params['dtype'],
                        standarized_params['n_channels'],
                        standarized_params['data_format'],
                        CONFIG.resources.max_memory,
                        buffer_size=0)

    # check if all scores, clear and collision spikes exist..
    path_to_score = os.path.join(TMP_FOLDER, 'score_clear.npy')
    path_to_spike_index_clear = os.path.join(TMP_FOLDER,
                                             'spike_index_clear.npy')
    path_to_spike_index_collision = os.path.join(TMP_FOLDER,
                                                 'spike_index_collision.npy')

    if all([
            os.path.exists(path_to_score),
            os.path.exists(path_to_spike_index_clear),
            os.path.exists(path_to_spike_index_collision)
    ]):
        logger.info('Loading "{}", "{}" and "{}"'.format(
            path_to_score, path_to_spike_index_clear,
            path_to_spike_index_collision))

        scores = np.load(path_to_score)
        clear = np.load(path_to_spike_index_clear)
        collision = np.load(path_to_spike_index_collision)

    else:
        logger.info('One or more of "{}", "{}" or "{}" files were missing, '
                    'computing...'.format(path_to_score,
                                          path_to_spike_index_clear,
                                          path_to_spike_index_collision))

        # apply the neural network detector (with triage and autoencoder)
        # on standarized data; each element of res holds, per batch,
        # (scores, clear index, collision index)
        autoencoder_filename = CONFIG.neural_network_autoencoder.filename
        mc = bp.multi_channel_apply
        res = mc(
            neuralnetwork.nn_detection,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            neighbors=CONFIG.neighChannels,
            geom=CONFIG.geom,
            temporal_features=CONFIG.spikes.temporal_features,
            # FIXME: what is this?
            temporal_window=3,
            th_detect=CONFIG.neural_network_detector.threshold_spike,
            th_triage=CONFIG.neural_network_triage.threshold_collision,
            detector_filename=CONFIG.neural_network_detector.filename,
            autoencoder_filename=autoencoder_filename,
            triage_filename=CONFIG.neural_network_triage.filename)

        # save clear spikes
        clear = np.concatenate([element[1] for element in res], axis=0)
        logger.info('Removing clear indexes outside the allowed range to '
                    'draw a complete waveform...')
        # idx is kept so the matching rows of the scores can be selected
        # later (see the non-location branch)
        clear, idx = detect.remove_incomplete_waveforms(
            clear, CONFIG.spikeSize + CONFIG.templatesMaxShift,
            n_observations)
        np.save(path_to_spike_index_clear, clear)
        logger.info('Saved spike index clear in {}...'.format(
            path_to_spike_index_clear))

        # save collided spikes
        collision = np.concatenate([element[2] for element in res], axis=0)
        logger.info('Removing collision indexes outside the allowed range to '
                    'draw a complete waveform...')
        collision, _ = detect.remove_incomplete_waveforms(
            collision, CONFIG.spikeSize + CONFIG.templatesMaxShift,
            n_observations)
        np.save(path_to_spike_index_collision, collision)
        logger.info('Saved spike index collision in {}...'.format(
            path_to_spike_index_collision))

        if CONFIG.clustering.clustering_method == 'location':
            #######################
            # Waveform extraction #
            #######################

            # TODO: what should the behaviour be for spike indexes that are
            # when starting/ending the recordings and it is not possible to
            # draw a complete waveform?
            logger.info('Computing whitening matrix...')
            # whitening matrix is estimated from the first batch only
            bp = BatchProcessor(standarized_path, standarized_params['dtype'],
                                standarized_params['n_channels'],
                                standarized_params['data_format'],
                                CONFIG.resources.max_memory)
            batches = bp.multi_channel()
            first_batch, _, _ = next(batches)
            Q = whiten.matrix(first_batch, CONFIG.neighChannels,
                              CONFIG.spikeSize)

            path_to_whitening_matrix = os.path.join(TMP_FOLDER,
                                                    'whitening.npy')
            np.save(path_to_whitening_matrix, Q)
            logger.info('Saved whitening matrix in {}'.format(
                path_to_whitening_matrix))

            # apply whitening to every batch
            (whitened_path, whitened_params) = bp.multi_channel_apply(
                np.matmul,
                mode='disk',
                output_path=os.path.join(TMP_FOLDER, 'whitened.bin'),
                if_file_exists='skip',
                cast_dtype=OUTPUT_DTYPE,
                b=Q)

            main_channel = clear[:, 1]

            # load and dump waveforms from clear spikes
            path_to_waveforms_clear = os.path.join(TMP_FOLDER,
                                                   'waveforms_clear.npy')

            if os.path.exists(path_to_waveforms_clear):
                logger.info(
                    'Found clear waveforms in {}, loading them...'.format(
                        path_to_waveforms_clear))
                waveforms_clear = np.load(path_to_waveforms_clear)
            else:
                logger.info(
                    'Did not find clear waveforms in {}, reading them from {}'.
                    format(path_to_waveforms_clear, whitened_path))
                # waveforms are read from the whitened recordings
                explorer = RecordingExplorer(whitened_path,
                                             spike_size=CONFIG.spikeSize)
                waveforms_clear = explorer.read_waveforms(clear[:, 0], 'all')
                np.save(path_to_waveforms_clear, waveforms_clear)
                logger.info('Saved waveform from clear spikes in: {}'.format(
                    path_to_waveforms_clear))

            # NOTE(review): duplicate of the assignment above; clear has not
            # changed in between
            main_channel = clear[:, 1]

            # save rotation
            detector_filename = CONFIG.neural_network_detector.filename
            autoencoder_filename = CONFIG.neural_network_autoencoder.filename
            rotation = neuralnetwork.load_rotation(detector_filename,
                                                   autoencoder_filename)
            path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')
            logger.info("rotation_matrix_shape = {}".format(rotation.shape))
            np.save(path_to_rotation, rotation)
            logger.info(
                'Saved rotation matrix in {}...'.format(path_to_rotation))

            logger.info('Denoising...')
            path_to_denoised_waveforms = os.path.join(
                TMP_FOLDER, 'denoised_waveforms.npy')
            if os.path.exists(path_to_denoised_waveforms):
                logger.info(
                    'Found denoised waveforms in {}, loading them...'.format(
                        path_to_denoised_waveforms))
                denoised_waveforms = np.load(path_to_denoised_waveforms)
            else:
                # NOTE(review): the two literals below concatenate without a
                # space, producing "evaluating themfrom" in the log
                logger.info(
                    'Did not find denoised waveforms in {}, evaluating them'
                    'from {}'.format(path_to_denoised_waveforms,
                                     path_to_waveforms_clear))
                # re-reads the file saved/loaded above (same content)
                waveforms_clear = np.load(path_to_waveforms_clear)
                denoised_waveforms = dim_red.denoise(waveforms_clear, rotation,
                                                     CONFIG)
                logger.info('Saving denoised waveforms to {}'.format(
                    path_to_denoised_waveforms))
                np.save(path_to_denoised_waveforms, denoised_waveforms)

            isolated_index, x, y = get_isolated_spikes_and_locations(
                denoised_waveforms, main_channel, CONFIG)
            # standardize the estimated spike locations
            x = (x - np.mean(x)) / np.std(x)
            y = (y - np.mean(y)) / np.std(y)

            # clear spikes that are not isolated are moved to collisions
            corrupted_index = np.logical_not(
                np.in1d(np.arange(clear.shape[0]), isolated_index))
            collision = np.concatenate(
                [collision, clear[corrupted_index]], axis=0)

            clear = clear[isolated_index]
            waveforms_clear = waveforms_clear[isolated_index]

            #################################################
            # Dimensionality reduction (Isolated Waveforms) #
            #################################################

            scores = dim_red.main_channel_scores(waveforms_clear, rotation,
                                                 clear, CONFIG)
            # NOTE(review): mean is per column (axis=0) but std is global
            # (no axis); confirm this mix is intended
            scores = (scores - np.mean(scores, axis=0)) / np.std(scores)
            # final scores: x, y, then main-channel scores along axis 1
            scores = np.concatenate([
                x[:, np.newaxis, np.newaxis],
                y[:, np.newaxis, np.newaxis],
                scores[:, :, np.newaxis]
            ], axis=1)

        else:
            # save scores
            scores = np.concatenate([element[0] for element in res], axis=0)

            logger.info(
                'Removing scores for indexes outside the allowed range to '
                'draw a complete waveform...')
            # keep only the rows that survived remove_incomplete_waveforms
            scores = scores[idx]

            # compute Q for whitening
            logger.info('Computing whitening matrix...')
            bp = BatchProcessor(standarized_path, standarized_params['dtype'],
                                standarized_params['n_channels'],
                                standarized_params['data_format'],
                                CONFIG.resources.max_memory)
            batches = bp.multi_channel()
            first_batch, _, _ = next(batches)
            Q = whiten.matrix_localized(first_batch, CONFIG.neighChannels,
                                        CONFIG.geom, CONFIG.spikeSize)

            path_to_whitening_matrix = os.path.join(TMP_FOLDER,
                                                    'whitening.npy')
            np.save(path_to_whitening_matrix, Q)
            logger.info('Saved whitening matrix in {}'.format(
                path_to_whitening_matrix))

            scores = whiten.score(scores, clear[:, 1], Q)

            # NOTE(review): scores are saved again right after this if/else;
            # this save is redundant
            np.save(path_to_score, scores)
            logger.info('Saved spike scores in {}...'.format(path_to_score))

            # save rotation
            detector_filename = CONFIG.neural_network_detector.filename
            autoencoder_filename = CONFIG.neural_network_autoencoder.filename
            rotation = neuralnetwork.load_rotation(detector_filename,
                                                   autoencoder_filename)
            path_to_rotation = os.path.join(TMP_FOLDER, 'rotation.npy')
            np.save(path_to_rotation, rotation)
            logger.info(
                'Saved rotation matrix in {}...'.format(path_to_rotation))

        np.save(path_to_score, scores)
        logger.info('Saved spike scores in {}...'.format(path_to_score))

    return scores, clear, collision
def test_splitting_in_batches_does_not_affect(path_to_tests,
                                              path_to_standarized_data,
                                              path_to_sample_pipeline_folder):
    """Running detection, triage and featurization on the whole recording
    must give exactly the same scores, clear and collision indexes as
    running it in batches through BatchProcessor
    """
    yass.set_config(path.join(path_to_tests, 'config_nnet.yaml'))
    CONFIG = yass.read_config()

    PATH_TO_DATA = path_to_standarized_data

    data = RecordingsReader(PATH_TO_DATA, loader='array').data

    # yaml.load without an explicit Loader is deprecated since PyYAML 5.1
    # and a TypeError since PyYAML 6.0; the params file is plain data
    with open(
            path.join(path_to_sample_pipeline_folder, 'preprocess',
                      'standarized.yaml')) as f:
        PARAMS = yaml.safe_load(f)

    channel_index = make_channel_index(CONFIG.neigh_channels, CONFIG.geom)

    detection_th = CONFIG.detect.neural_network_detector.threshold_spike
    triage_th = CONFIG.detect.neural_network_triage.threshold_collision
    detection_fname = CONFIG.detect.neural_network_detector.filename
    ae_fname = CONFIG.detect.neural_network_autoencoder.filename
    triage_fname = CONFIG.detect.neural_network_triage.filename

    # instantiate neural networks
    NND = NeuralNetDetector.load(detection_fname, detection_th,
                                 channel_index)
    NNT = NeuralNetTriage.load(triage_fname, triage_th,
                               input_tensor=NND.waveform_tf)
    NNAE = AutoEncoder(ae_fname, input_tensor=NND.waveform_tf)

    output_tf = (NNAE.score_tf, NND.spike_index_tf, NNT.idx_clean)

    # the neighbors matrix does not depend on the TF session, compute it
    # once and use it in both runs
    neighbors = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

    # run all at once
    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()

        (scores, clear,
         collision) = neuralnetwork.run_detect_triage_featurize(
            data, sess, NND.x_tf, output_tf, neighbors, rot)

    # run in batches - buffer size makes sure we can detect spikes if they
    # appear at the end of any batch
    bp = BatchProcessor(PATH_TO_DATA, PARAMS['dtype'], PARAMS['n_channels'],
                        PARAMS['data_order'], '100KB',
                        buffer_size=CONFIG.spike_size)

    with tf.Session() as sess:
        # get values of above tensors
        NND.restore(sess)
        NNAE.restore(sess)
        NNT.restore(sess)

        rot = NNAE.load_rotation()

        res = bp.multi_channel_apply(
            neuralnetwork.run_detect_triage_featurize,
            mode='memory',
            cleanup_function=neuralnetwork.fix_indexes,
            sess=sess,
            x_tf=NND.x_tf,
            output_tf=output_tf,
            rot=rot,
            neighbors=neighbors)

    scores_batch = np.concatenate([element[0] for element in res], axis=0)
    clear_batch = np.concatenate([element[1] for element in res], axis=0)
    collision_batch = np.concatenate([element[2] for element in res], axis=0)

    np.testing.assert_array_equal(clear_batch, clear)
    np.testing.assert_array_equal(collision_batch, collision)
    np.testing.assert_array_equal(scores_batch, scores)