def run(fname_recording, recording_dtype, fname_spike_train, output_directory):
    """Generate training data for the detection and denoising networks."""
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    # make output directory if not exist
    if not os.path.exists(output_directory):
        os.mkdir(output_directory)

    # get reader
    reader = READER(fname_recording, recording_dtype, CONFIG)
    reader.spike_size = CONFIG.spike_size_nn

    # get noise covariance
    logger.info('Compute Noise Covariance')
    save_dir = os.path.join(output_directory, 'noise_cov')
    chunk = [0, np.min((5*60*reader.sampling_rate, reader.end))]
    fname_spatial_sig, fname_temporal_sig = get_noise_covariance(
        reader, save_dir, CONFIG, chunk)

    # get processed templates
    logger.info('Crop Templates')
    save_dir = os.path.join(output_directory, 'templates')
    fname_templates_snippets = get_templates_on_local_channels(
        reader, save_dir, fname_spike_train, CONFIG)

    # denoise templates
    fname_templates_denoised = denoise_templates(
        fname_templates_snippets, save_dir)

    # make training data
    logger.info('Make Training Data')
    DetectTD = Detection_Training_Data(
        fname_templates_denoised,
        fname_spatial_sig,
        fname_temporal_sig)
    DenoTD = Denoising_Training_Data(
        fname_templates_denoised,
        fname_spatial_sig,
        fname_temporal_sig)

    return DetectTD, DenoTD
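# --- Usage sketch (illustrative; not part of the original source) ------------
# A minimal example of how the training-data `run` above might be invoked. The
# file names below are hypothetical placeholders; read_config() is assumed to
# have been initialized by the caller, as elsewhere in this package.
def _example_make_training_data():
    DetectTD, DenoTD = run(fname_recording='standardized.bin',    # hypothetical path
                           recording_dtype='float32',
                           fname_spike_train='spike_train.npy',   # hypothetical path
                           output_directory='nn_train')
    return DetectTD, DenoTD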
def residual_ONgpu(recordings_filename, recording_dtype, CONFIG,
                   fname_shifts, fname_templates, output_directory,
                   dtype_out, fname_out, fname_spike_train,
                   update_templates, run_chunk_sec):

    # get data reader
    if run_chunk_sec == 'full':
        chunk_sec = None
    else:
        chunk_sec = run_chunk_sec

    reader = READER(recordings_filename,
                    recording_dtype,
                    CONFIG,
                    CONFIG.resources.n_sec_chunk_gpu_deconv,
                    chunk_sec=chunk_sec)

    # the RESIDUAL_GPU3 path is currently disabled; RESIDUAL_GPU2 is always used
    if False:
        RESIDUAL_GPU3(
            reader, recordings_filename, recording_dtype, CONFIG,
            fname_shifts, fname_templates, output_directory,
            dtype_out, fname_out, fname_spike_train, update_templates)
    else:
        RESIDUAL_GPU2(
            reader, recordings_filename, recording_dtype, CONFIG,
            fname_shifts, fname_templates, output_directory,
            dtype_out, fname_out, fname_spike_train, update_templates)
def residual_ONcpu(fname_templates, fname_spike_train, output_directory,
                   recordings_filename, recording_dtype, dtype_out,
                   fname_out, run_chunk_sec, CONFIG):

    # get data reader
    if run_chunk_sec == 'full':
        chunk_sec = None
    else:
        chunk_sec = run_chunk_sec

    reader = READER(recordings_filename, recording_dtype, CONFIG,
                    CONFIG.resources.n_sec_chunk,
                    chunk_sec=chunk_sec)

    # get residual object
    residual_object = RESIDUAL(fname_templates,
                               fname_spike_train,
                               reader,
                               fname_out,
                               dtype_out)

    # compute residual
    seg_dir = os.path.join(output_directory, 'segs')
    residual_object.compute_residual(seg_dir,
                                     CONFIG.resources.multi_processing,
                                     CONFIG.resources.n_processors)

    # concatenate all segments
    residual_object.save_residual()
def run_neural_network(standardized_path, standardized_dtype,
                       output_directory, run_chunk_sec='full'):
    """Run neural network detection"""
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    # load NN detector
    detector = Detect(CONFIG.neuralnetwork.detect.n_filters,
                      CONFIG.spike_size_nn,
                      CONFIG.channel_index)
    detector.load(CONFIG.neuralnetwork.detect.filename)

    # load NN denoiser
    denoiser = Denoise(CONFIG.neuralnetwork.denoise.n_filters,
                       CONFIG.neuralnetwork.denoise.filter_sizes,
                       CONFIG.spike_size_nn)
    denoiser.load(CONFIG.neuralnetwork.denoise.filename)

    # get data reader
    batch_length = CONFIG.resources.n_sec_chunk * CONFIG.resources.n_processors
    n_sec_chunk = CONFIG.resources.n_sec_chunk_gpu_detect
    print(" batch length (sec): ", batch_length,
          " (longer increases speed a bit)")
    print(" length of each seg (sec): ", n_sec_chunk)
    buffer = CONFIG.spike_size_nn
    if run_chunk_sec == 'full':
        chunk_sec = None
    else:
        chunk_sec = run_chunk_sec

    reader = READER(standardized_path,
                    standardized_dtype,
                    CONFIG,
                    batch_length,
                    buffer,
                    chunk_sec)

    # neighboring channels
    channel_index_dedup = make_channel_index(
        CONFIG.neigh_channels, CONFIG.geom, steps=2)

    # threshold for neuralnet detection
    detect_threshold = CONFIG.detect.threshold

    # loop over each chunk; array_split allows the number of batches
    # to not divide evenly across the available torch devices
    batch_ids = np.arange(reader.n_batches)
    batch_ids_split = np.array_split(batch_ids, len(CONFIG.torch_devices))

    processes = []
    for ii, device in enumerate(CONFIG.torch_devices):
        p = mp.Process(target=run_nn_detction_batch,
                       args=(batch_ids_split[ii], output_directory,
                             reader, n_sec_chunk,
                             detector, denoiser,
                             channel_index_dedup,
                             detect_threshold, device))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
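# --- Sketch: per-device batch assignment (illustrative) ----------------------
# The detection loop above hands each torch device a contiguous slice of batch
# ids. With, say, 10 batches and 3 devices, np.array_split yields groups of
# sizes 4/3/3. This toy function only illustrates the partitioning and is not
# part of the pipeline; the argument values are arbitrary.
def _example_batch_partition(n_batches=10, n_devices=3):
    batch_ids = np.arange(n_batches)
    split = np.array_split(batch_ids, n_devices)
    return [len(s) for s in split]   # -> [4, 3, 3]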
def get_plot_ptps(save_dir,
                  fname_raw, fname_residual,
                  fname_spike_train, fname_scales,
                  fname_shifts, templates_dir,
                  ptp_threshold, n_col, CONFIG,
                  units_in=None, fname_drifts_gt=None,
                  n_nearby_units=3):

    reader_raw = READER(fname_raw, 'float32', CONFIG)
    reader_resid = READER(fname_residual, 'float32', CONFIG)

    update_time = CONFIG.deconvolution.template_update_time

    # load initial templates
    init_templates = np.load(
        os.path.join(templates_dir, 'templates_{}sec.npy'.format(0)))
    n_units = init_templates.shape[0]

    meta_data_dir = os.path.join(save_dir, 'meta_data')
    if not os.path.exists(meta_data_dir):
        os.makedirs(meta_data_dir)

    figs_dir = os.path.join(save_dir, 'figs')
    if not os.path.exists(figs_dir):
        os.makedirs(figs_dir)

    if units_in is None:
        units_in = np.arange(n_units)
    units_in = units_in[units_in < n_units]

    get_plot_ptps_parallel(units_in,
                           reader_raw,
                           reader_resid,
                           fname_spike_train,
                           fname_scales,
                           fname_shifts,
                           templates_dir,
                           meta_data_dir,
                           figs_dir,
                           update_time,
                           ptp_threshold,
                           n_col,
                           fname_drifts_gt,
                           n_nearby_units)
def run_template_update(output_directory,
                        fname_templates, fname_spike_train,
                        fname_shifts, fname_scales,
                        fname_residual, residual_dtype, residual_offset=0,
                        update_weight=50, units_to_update=None):

    fname_templates_out = os.path.join(output_directory, 'templates.npy')

    if not os.path.exists(fname_templates_out):

        print('updating templates')

        CONFIG = read_config()

        # output folder
        if not os.path.exists(output_directory):
            os.makedirs(output_directory)

        # reader
        if CONFIG.deconvolution.deconv_gpu:
            n_sec_chunk = CONFIG.resources.n_sec_chunk_gpu_deconv
        else:
            n_sec_chunk = CONFIG.resources.n_sec_chunk
        reader_residual = READER(fname_residual,
                                 residual_dtype,
                                 CONFIG,
                                 n_sec_chunk,
                                 offset=residual_offset)

        # residual obj that can shift templates in gpu
        residual_comp = RESIDUAL_GPU2(
            None, CONFIG, None, None, None,
            None, None, None, None, None, True)
        residual_comp.load_templates(fname_templates)
        residual_comp.make_bsplines_parallel()

        avg_min_max_vals, weights = get_avg_min_max_vals(
            fname_templates, fname_spike_train,
            fname_shifts, fname_scales,
            reader_residual, residual_comp, units_to_update)

        templates_updated = update_templates(
            fname_templates, weights, avg_min_max_vals,
            update_weight, units_to_update)

        np.save(fname_templates_out, templates_updated)

    return fname_templates_out
def run_voltage_treshold(standardized_path, standardized_dtype,
                         output_directory, run_chunk_sec='full'):
    """Run detection that thresholds on amplitude"""
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    # get data reader
    #n_sec_chunk = CONFIG.resources.n_sec_chunk*CONFIG.resources.n_processors
    batch_length = CONFIG.resources.n_sec_chunk
    n_sec_chunk = 0.5
    print(" batch length (sec): ", batch_length,
          " (longer increases speed a bit)")
    print(" length of each seg (sec): ", n_sec_chunk)
    buffer = CONFIG.spike_size
    if run_chunk_sec == 'full':
        chunk_sec = None
    else:
        chunk_sec = run_chunk_sec

    reader = READER(standardized_path,
                    standardized_dtype,
                    CONFIG,
                    batch_length,
                    buffer,
                    chunk_sec)

    # number of processed chunks
    n_mini_per_big_batch = int(np.ceil(batch_length / n_sec_chunk))
    total_processing = int(reader.n_batches * n_mini_per_big_batch)

    # neighboring channels
    channel_index = make_channel_index(CONFIG.neigh_channels,
                                       CONFIG.geom, steps=2)

    if CONFIG.resources.multi_processing:
        parmap.starmap(run_voltage_threshold_parallel,
                       list(zip(np.arange(reader.n_batches))),
                       reader,
                       n_sec_chunk,
                       CONFIG.detect.threshold,
                       channel_index,
                       output_directory,
                       processes=CONFIG.resources.n_processors,
                       pm_pbar=True)
    else:
        for batch_id in range(reader.n_batches):
            run_voltage_threshold_parallel(
                batch_id, reader, n_sec_chunk,
                CONFIG.detect.threshold,
                channel_index, output_directory)
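# --- Sketch: batch/segment bookkeeping (illustrative) ------------------------
# Each reader batch of `batch_length` seconds is processed in mini-segments of
# `n_sec_chunk` seconds, so the counters above work out as follows for
# hypothetical values (not taken from any particular CONFIG):
def _example_segment_counts(batch_length=10.0, n_sec_chunk=0.5, n_batches=6):
    n_mini_per_big_batch = int(np.ceil(batch_length / n_sec_chunk))   # 20
    total_processing = int(n_batches * n_mini_per_big_batch)          # 120
    return n_mini_per_big_batch, total_processing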
def run(template_fname,
        spike_train_fname,
        shifts_fname,
        output_directory,
        residual_fname,
        residual_dtype):

    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    fname_out = os.path.join(output_directory, 'soft_assignment.npy')
    if os.path.exists(fname_out):
        return fname_out

    # output folder
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    # reader for residual
    reader_resid = READER(residual_fname,
                          residual_dtype,
                          CONFIG,
                          CONFIG.resources.n_sec_chunk_gpu_deconv)

    # load NN detector
    detector = Detect(CONFIG.neuralnetwork.detect.n_filters,
                      CONFIG.spike_size_nn,
                      CONFIG.channel_index)
    detector.load(CONFIG.neuralnetwork.detect.filename)
    detector = detector.cuda()

    # initialize soft assignment calculator
    threshold = CONFIG.deconvolution.threshold / 0.1

    sna = SOFTNOISEASSIGNMENT(spike_train_fname, template_fname,
                              shifts_fname, reader_resid, detector,
                              CONFIG.channel_index, threshold)

    # compute soft assignment
    probs = sna.compute_soft_assignment()
    np.save(fname_out, probs)

    return fname_out
def run(output_directory):
    """Preprocess pipeline: filtering, standardization and whitening filter

    This step (optionally) performs filtering on the data, standardizes it
    and computes a whitening filter. Filtered and standardized data are
    processed in chunks and written to disk.

    Parameters
    ----------
    output_directory: str
        where results will be saved

    Returns
    -------
    standardized_path: str
        Path to standardized data binary file

    standardized_params: str
        Path to standardized data parameters

    channel_index: numpy.ndarray
        Channel indexes

    whiten_filter: numpy.ndarray
        Whitening matrix

    Notes
    -----
    Running the preprocessor will generate the following files in
    CONFIG.data.root_folder/output_directory/:

    * ``filtered.bin`` - Filtered recordings
    * ``filtered.yaml`` - Filtered recordings metadata
    * ``standardized.bin`` - Standardized recordings
    * ``standardized.yaml`` - Standardized recordings metadata
    * ``whitening.npy`` - Whitening filter

    Everything is run on CPU.

    Examples
    --------

    .. literalinclude:: ../../examples/pipeline/preprocess.py
    """

    # **********************************************
    # *********** Initialize ***********************
    # **********************************************

    logger = logging.getLogger(__name__)

    # load config
    CONFIG = read_config()

    # raw data info
    filename_raw = os.path.join(CONFIG.data.root_folder,
                                CONFIG.data.recordings)
    dtype_raw = CONFIG.recordings.dtype
    n_channels = CONFIG.recordings.n_channels

    if not CONFIG.preprocess.apply_filter:
        return filename_raw, dtype_raw

    # if apply filter, get recording reader
    n_sec_chunk = CONFIG.resources.n_sec_chunk
    reader = READER(filename_raw, dtype_raw, CONFIG, n_sec_chunk)
    logger.info("# of chunks: {}".format(reader.n_batches))

    # make output directory
    if not os.path.exists(output_directory):
        logger.info('Creating temporary folder: {}'.format(output_directory))
        os.makedirs(output_directory)
    else:
        logger.info('Temporary folder {} already exists, output will be '
                    'stored there'.format(output_directory))

    # make output parameters
    standardized_path = os.path.join(output_directory, "standardized.bin")
    standardized_params = dict(
        dtype=CONFIG.preprocess.dtype,
        n_channels=n_channels)
    logger.info('Output dtype for transformed data will be {}'.format(
        CONFIG.preprocess.dtype))

    # Check if data already saved to disk and skip:
    if os.path.exists(standardized_path):
        return standardized_path, standardized_params['dtype']

    # **********************************************
    # *********** run filter & standardize *********
    # **********************************************

    # get necessary parameters
    low_frequency = CONFIG.preprocess.filter.low_pass_freq
    high_factor = CONFIG.preprocess.filter.high_factor
    order = CONFIG.preprocess.filter.order
    sampling_rate = CONFIG.recordings.sampling_rate

    # estimate std from a small chunk
    chunk_5sec = 5 * CONFIG.recordings.sampling_rate
    if CONFIG.rec_len < chunk_5sec:
        chunk_5sec = CONFIG.rec_len
    small_batch = reader.read_data(
        data_start=CONFIG.rec_len // 2 - chunk_5sec // 2,
        data_end=CONFIG.rec_len // 2 + chunk_5sec // 2)

    fname_mean_sd = os.path.join(
        output_directory, 'mean_and_standard_dev_value.npz')
    if not os.path.exists(fname_mean_sd):
        get_std(small_batch, sampling_rate,
                fname_mean_sd, CONFIG.preprocess.apply_filter,
                low_frequency, high_factor, order)

    # turn it off
    small_batch = None

    # Make directory to hold filtered batch files:
    filtered_location = os.path.join(output_directory, "filtered_files")
    if not os.path.exists(filtered_location):
        os.makedirs(filtered_location)

    # read config params
    multi_processing = CONFIG.resources.multi_processing
    if CONFIG.resources.multi_processing:
        n_processors = CONFIG.resources.n_processors
        parmap.map(
            filter_standardize_batch,
            [i for i in range(reader.n_batches)],
            reader,
            fname_mean_sd,
            CONFIG.preprocess.apply_filter,
            CONFIG.preprocess.dtype,
            filtered_location,
            low_frequency,
            high_factor,
            order,
            sampling_rate,
            processes=n_processors,
            pm_pbar=True)
    else:
        for batch_id in range(reader.n_batches):
            filter_standardize_batch(
                batch_id, reader, fname_mean_sd,
                CONFIG.preprocess.apply_filter,
                CONFIG.preprocess.dtype,
                filtered_location,
                low_frequency,
                high_factor,
                order,
                sampling_rate)

    # Merge the chunk filtered files and delete the individual chunks
    merge_filtered_files(filtered_location, output_directory)

    # save yaml file with params
    path_to_yaml = standardized_path.replace('.bin', '.yaml')
    with open(path_to_yaml, 'w') as f:
        logger.info('Saving params...')
        yaml.dump(standardized_params, f)

    return standardized_path, standardized_params['dtype']
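# --- Usage sketch (illustrative; not the examples/pipeline/preprocess.py file)
# One way the preprocessing `run` above might be driven. The output directory
# is a hypothetical placeholder, and read_config() is assumed to have been set
# up by the caller.
def _example_preprocess():
    standardized_path, standardized_dtype = run(output_directory='tmp/preprocess')
    print('standardized recording at', standardized_path,
          'stored as', standardized_dtype)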
def post_process(output_directory,
                 fname_templates,
                 fname_spike_train,
                 fname_weights,
                 fname_recording,
                 recording_dtype,
                 units_in,
                 method,
                 ctr):
    '''
    Run a single post-processing method.

    method: str. Options are 'low_ptp', 'duplicate', 'collision',
        'high_mad', 'low_fr', 'high_fr', 'off_center', 'duplicate_l2'
    '''
    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    if method == 'low_ptp':

        # Cat: TODO: move parameter to CONFIG
        threshold = CONFIG.clean_up.min_ptp

        # load templates
        templates = np.load(fname_templates)

        # remove low ptp
        units_out = remove_small_units(templates, threshold, units_in)

        logger.info("{} units after removing low ptp units".format(
            len(units_out)))

    elif method == 'off_center':

        threshold = CONFIG.clean_up.off_center

        # load templates
        templates = np.load(fname_templates)

        # remove off centered units
        units_out = remove_off_centered_units(templates, threshold, units_in)

        logger.info("{} units after removing off centered units".format(
            len(units_out)))

    elif method == 'duplicate':

        # tmp saving dir
        save_dir = os.path.join(output_directory, 'duplicates_{}'.format(ctr))

        # remove duplicates
        units_out = remove_duplicates(
            fname_templates,
            fname_weights,
            save_dir,
            CONFIG,
            units_in,
            CONFIG.resources.multi_processing,
            CONFIG.resources.n_processors)

        logger.info("{} units after removing duplicate units".format(
            len(units_out)))

    elif method == 'duplicate_l2':

        # tmp saving dir
        save_dir = os.path.join(output_directory,
                                'duplicates_l2_{}'.format(ctr))

        # remove duplicates
        n_spikes_big = 100
        min_ptp = 2
        units_out = duplicate_l2(
            fname_templates,
            fname_spike_train,
            CONFIG.neigh_channels,
            save_dir,
            n_spikes_big,
            min_ptp,
            units_in)

        logger.info("{} units after removing L2 duplicate units".format(
            len(units_out)))

    elif method == 'collision':

        # save folder
        save_dir = os.path.join(output_directory, 'collision_{}'.format(ctr))

        # find collision units and remove
        units_out = remove_collision(
            fname_templates,
            save_dir,
            CONFIG,
            units_in,
            CONFIG.resources.multi_processing,
            CONFIG.resources.n_processors)

        logger.info("{} units after removing collision units".format(
            len(units_out)))

    elif method == 'high_mad':

        # get data reader
        reader = READER(fname_recording, recording_dtype, CONFIG)

        # save folder
        save_dir = os.path.join(output_directory, 'mad_{}'.format(ctr))

        # neighboring channels
        neigh_channels = n_steps_neigh_channels(CONFIG.neigh_channels, 2)

        max_violations = CONFIG.clean_up.mad.max_violations
        min_var_gap = CONFIG.clean_up.mad.min_var_gap

        # find high mad units and remove
        units_out = remove_high_mad(
            fname_templates,
            fname_spike_train,
            fname_weights,
            reader,
            neigh_channels,
            save_dir,
            min_var_gap,
            max_violations,
            units_in,
            CONFIG.resources.multi_processing,
            CONFIG.resources.n_processors)

        logger.info("{} units after removing high mad units".format(
            len(units_out)))

    elif method == 'low_fr':

        threshold = CONFIG.clean_up.min_fr

        # length of recording in seconds
        rec_len = np.load(fname_spike_train)[:, 0].ptp()
        rec_len_sec = float(rec_len) / CONFIG.recordings.sampling_rate

        # load weights
        weights = np.load(fname_weights)

        # remove low firing rate units
        units_out = remove_low_fr_units(weights, rec_len_sec,
                                        threshold, units_in)

        logger.info("{} units after removing low fr units".format(
            len(units_out)))

    elif method == 'high_fr':

        # TODO: move parameter to config?
        threshold = 70

        # length of recording in seconds
        rec_len = np.load(fname_spike_train)[:, 0].ptp()
        rec_len_sec = float(rec_len) / CONFIG.recordings.sampling_rate

        # load weights
        weights = np.load(fname_weights)

        # remove high firing rate units
        units_out = remove_high_fr_units(weights, rec_len_sec,
                                         threshold, units_in)

        logger.info("{} units after removing high fr units".format(
            len(units_out)))

    else:
        units_out = np.copy(units_in)
        logger.info("Method not recognized. Nothing removed")

    return units_out
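# --- Sketch: chaining post-processing methods (illustrative) -----------------
# `post_process` filters `units_in` down to `units_out`, so successive criteria
# are naturally applied by feeding each result into the next call. The method
# order and the `n_units` argument below are illustrative, not the pipeline's
# canonical sequence.
def _example_post_process_chain(output_directory, fname_templates,
                                fname_spike_train, fname_weights,
                                fname_recording, recording_dtype, n_units):
    units = np.arange(n_units)
    for ctr, method in enumerate(['low_ptp', 'duplicate', 'high_mad']):
        units = post_process(output_directory, fname_templates,
                             fname_spike_train, fname_weights,
                             fname_recording, recording_dtype,
                             units, method, ctr)
    return units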
def run_post_deconv_split(output_directory,
                          fname_templates,
                          fname_spike_train,
                          fname_shifts,
                          fname_scales,
                          fname_raw,
                          raw_dtype,
                          fname_residual,
                          residual_dtype,
                          residual_offset=0,
                          initial_batch=False):

    CONFIG = read_config()

    # output folder
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    # reader
    if CONFIG.deconvolution.deconv_gpu:
        n_sec_chunk = CONFIG.resources.n_sec_chunk_gpu_deconv
    else:
        n_sec_chunk = CONFIG.resources.n_sec_chunk
    reader_residual = READER(fname_residual,
                             residual_dtype,
                             CONFIG,
                             n_sec_chunk,
                             offset=residual_offset)
    reader_raw = READER(fname_raw, raw_dtype, CONFIG)

    # load input data
    templates = np.load(fname_templates)
    spike_train = np.load(fname_spike_train)
    shifts = np.load(fname_shifts)
    scales = np.load(fname_scales)

    # get cleaned ptp
    fname_cleaned_ptp = os.path.join(output_directory, 'cleaned_ptp.npy')
    fname_spike_times = os.path.join(output_directory, 'spike_times.npy')
    fname_shifts = os.path.join(output_directory, 'shifts_list.npy')
    fname_scales = os.path.join(output_directory, 'scales_list.npy')
    fname_vis_chans = os.path.join(output_directory, 'vis_chans.npy')

    if os.path.exists(fname_cleaned_ptp) and os.path.exists(fname_spike_times):
        cleaned_ptp = np.load(fname_cleaned_ptp, allow_pickle=True)
        spike_times_list = np.load(fname_spike_times, allow_pickle=True)
        shifts_list = np.load(fname_shifts, allow_pickle=True)
        scales_list = np.load(fname_scales, allow_pickle=True)
        vis_chans = np.load(fname_vis_chans, allow_pickle=True)
    else:
        print('get cleaned ptp')
        (cleaned_ptp, spike_times_list,
         shifts_list, scales_list, vis_chans) = get_cleaned_ptp(
            templates, spike_train, shifts, scales,
            reader_residual, fname_templates, CONFIG)

        np.save(fname_shifts, shifts_list, allow_pickle=True)
        np.save(fname_scales, scales_list, allow_pickle=True)
        np.save(fname_vis_chans, vis_chans, allow_pickle=True)
        np.save(fname_cleaned_ptp, cleaned_ptp, allow_pickle=True)
        np.save(fname_spike_times, spike_times_list, allow_pickle=True)

    # split units
    fname_templates_updated = os.path.join(
        output_directory, 'templated_updated.npy')
    fname_spike_train_updated = os.path.join(
        output_directory, 'spike_train_updated.npy')
    fname_shifts_updated = os.path.join(
        output_directory, 'shifts_updated.npy')
    fname_scales_updated = os.path.join(
        output_directory, 'scales_updated.npy')

    if not(os.path.exists(fname_templates_updated) and
           os.path.exists(fname_spike_train_updated)):

        print('run split')

        if initial_batch:
            update_original_templates = True
            min_fr_accept = 0
            min_ptp_accept = 0
            min_fraction_accept = 0
        else:
            update_original_templates = False
            min_fr_accept = 0.5
            min_ptp_accept = 10000
            min_fraction_accept = 0.15

        (templates_updated, spike_train_updated,
         shifts_updated, scales_updated) = run_split(
            cleaned_ptp, spike_times_list,
            shifts_list, scales_list,
            vis_chans, templates, spike_train,
            reader_raw, CONFIG,
            update_original_templates=update_original_templates,
            min_ptp_accept=min_ptp_accept,
            min_fr_accept=min_fr_accept,
            min_fraction_accept=min_fraction_accept)

        # denoise split units
        n_splits = templates_updated.shape[0] - templates.shape[0]
        vis_threshold_strong = 1.
        vis_threshold_weak = 0.5
        rank = 5
        pad_len = int(1.5 * CONFIG.recordings.sampling_rate / 1000.)
        jitter_len = pad_len
        split_templates_denoised = shift_svd_denoise(
            templates_updated[-n_splits:], CONFIG,
            vis_threshold_strong, vis_threshold_weak,
            rank, pad_len, jitter_len)

        templates_updated[-n_splits:] = split_templates_denoised

        ## add new templates and spike train to the existing one
        #templates_updated = np.concatenate((templates, new_temps), axis=0)
        #spike_train_new[:, 1] += templates.shape[0]
        #spike_train_updated = np.concatenate((spike_train, spike_train_new), axis=0)

        idx_sort = np.argsort(spike_train_updated[:, 0])
        spike_train_updated = spike_train_updated[idx_sort]
        shifts_updated = shifts_updated[idx_sort]
        scales_updated = scales_updated[idx_sort]

        np.save(fname_templates_updated, templates_updated)
        np.save(fname_spike_train_updated, spike_train_updated)
        np.save(fname_shifts_updated, shifts_updated)
        np.save(fname_scales_updated, scales_updated)

        # can be used to find gpu memory not freed
        # import gc
        #n_objects = 0
        #for obj in gc.get_objects():
        #    try:
        #        if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
        #            print(obj, type(obj), obj.size())
        #            # n_objects += 1
        #    except:
        #        pass
        #print(n_objects)

        #units_to_process = np.arange(templates.shape[0], templates_updated.shape[0])

    #else:
    #    templates_updated = np.load(fname_templates_updated)
    #    units_to_process = np.arange(templates.shape[0], templates_updated.shape[0])

    ## kill duplicate templates
    #methods = ['low_ptp', 'duplicate']
    #(fname_templates_out, fname_spike_train_out,
    # _, _, _) = postprocess.run(
    #    methods,
    #    os.path.join(output_directory, 'duplicate_remove'),
    #    None,
    #    None,
    #    fname_templates_updated,
    #    fname_spike_train_updated,
    #    None,
    #    None,
    #    None,
    #    None,
    #    units_to_process)

    #return fname_templates_out, fname_spike_train_out

    return (fname_templates_updated, fname_spike_train_updated,
            fname_shifts_updated, fname_scales_updated)
def run(template_fname,
        spike_train_fname,
        shifts_fname,
        scales_fname,
        output_directory,
        residual_fname,
        residual_dtype,
        residual_offset=0,
        compute_noise_soft=True,
        compute_template_soft=True,
        update_templates=False,
        similar_array=None):

    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    fname_noise_soft = os.path.join(
        output_directory, 'noise_soft_assignment.npy')
    fname_template_soft = os.path.join(
        output_directory, 'template_soft_assignment.npz')

    # output folder
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    # reader for residual
    reader_resid = READER(residual_fname,
                          residual_dtype,
                          CONFIG,
                          CONFIG.resources.n_sec_chunk_gpu_deconv,
                          offset=residual_offset)

    ########################
    # Noise soft assignment#
    ########################

    if compute_noise_soft and (not os.path.exists(fname_noise_soft)):

        if CONFIG.neuralnetwork.apply_nn:

            # load NN detector
            detector = Detect(CONFIG.neuralnetwork.detect.n_filters,
                              CONFIG.spike_size_nn,
                              CONFIG.channel_index,
                              CONFIG)
            detector.load(CONFIG.neuralnetwork.detect.filename)
            detector = detector.cuda()

            # initialize soft assignment calculator
            threshold = CONFIG.deconvolution.threshold / 0.1

            # HACK for now.. it needs a proper fix later
            if update_templates:
                template_fname_ = os.path.join(template_fname,
                                               'templates_init.npy')
            else:
                template_fname_ = template_fname
            sna = SOFTNOISEASSIGNMENT(spike_train_fname, template_fname_,
                                      shifts_fname, scales_fname,
                                      reader_resid, detector,
                                      CONFIG.channel_index, threshold)

            # compute soft assignment
            probs_noise = sna.compute_soft_assignment()
            np.save(fname_noise_soft, probs_noise)

            del sna
            del detector
            torch.cuda.empty_cache()

        else:

            spike_train = np.load(spike_train_fname)
            np.save(fname_noise_soft, np.ones(len(spike_train)))

    ###########################
    # Template soft assignment#
    ###########################

    if compute_template_soft and (not os.path.exists(fname_template_soft)):

        # get whitening filters
        fname_spatial_cov = os.path.join(output_directory, 'spatial_cov.npy')
        fname_temporal_cov = os.path.join(output_directory, 'temporal_cov.npy')
        if not (os.path.exists(fname_spatial_cov) and
                os.path.exists(fname_temporal_cov)):
            spatial_cov, temporal_cov = get_noise_covariance(
                reader_resid, CONFIG)
            np.save(fname_spatial_cov, spatial_cov)
            np.save(fname_temporal_cov, temporal_cov)
        else:
            spatial_cov = np.load(fname_spatial_cov)
            temporal_cov = np.load(fname_temporal_cov)

        window_size = 51
        n_chans = 10
        reader_resid = READER(residual_fname,
                              residual_dtype,
                              CONFIG,
                              CONFIG.resources.n_sec_chunk_gpu_deconv,
                              offset=residual_offset)

        TAO = TEMPLATE_ASSIGN_OBJECT(
            fname_spike_train=spike_train_fname,
            fname_templates=template_fname,
            fname_shifts=shifts_fname,
            reader_residual=reader_resid,
            spat_cov=spatial_cov,
            temp_cov=temporal_cov,
            channel_idx=CONFIG.channel_index,
            geom=CONFIG.geom,
            large_unit_threshold=100000,
            n_chans=n_chans,
            rec_chans=CONFIG.channel_index.shape[0],
            sim_units=3,
            temp_thresh=5,
            lik_window=window_size,
            similar_array=similar_array,
            update_templates=update_templates,
            template_update_time=CONFIG.deconvolution.template_update_time)

        probs_templates, _, logprobs_outliers, units_assignment = TAO.run()

        # outlier spike times/units
        chi2_df = (2 * (window_size // 2) + 1) * n_chans
        cut_off = chi2(chi2_df).ppf(.999)

        #s_table = s_score(_)
        #s_table = s_score(probs_templates)
        #logprobs_outliers = logprobs_outliers/chi2_df

        cpu_sps = TAO.spike_train_og
        outliers = cpu_sps[np.where(logprobs_outliers.min(1) > cut_off)[0], :]

        #append log_probs to spike_times
        #logprobs = np.concatenate((cpu_sps, TAO.log_probs), axis=1)

        # compute soft assignment
        #np.save(prob_template_fname, probs_templates)
        #np.save(outlier_fname, outliers)
        #np.save(logprobs_outlier_fname, logprobs_outliers)
        #np.save(units_assign_fname, units_assignment)

        np.savez(
            fname_template_soft,
            probs_templates=probs_templates,
            units_assignment=units_assignment,
            #logprobs=_,
            #sihoulette_score=s_table,
            logprobs_outliers=logprobs_outliers,
            outliers=outliers)

        del TAO
        torch.cuda.empty_cache()

    return fname_noise_soft, fname_template_soft
def update_templates(fname_templates,
                     fname_spike_train,
                     recordings_filename,
                     recording_dtype,
                     output_directory,
                     rate=0.002,
                     unit_ids=None):

    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    # output folder
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    fname_templates_updated = os.path.join(output_directory,
                                           'templates_updated.npy')
    if os.path.exists(fname_templates_updated):
        return fname_templates_updated, None

    reader = READER(recordings_filename,
                    recording_dtype,
                    CONFIG)

    # max channel for each unit
    max_channels = np.load(fname_templates).ptp(1).argmax(1)

    fname_templates_new = run_template_computation(
        fname_spike_train,
        reader,
        output_directory,
        max_channels=max_channels,
        unit_ids=unit_ids,
        multi_processing=CONFIG.resources.multi_processing,
        n_processors=CONFIG.resources.n_processors)

    # load templates
    templates_orig = np.load(fname_templates)
    templates_new = np.load(fname_templates_new)

    n_units, n_times, n_channels = templates_orig.shape
    n_units_new = templates_new.shape[0]

    if unit_ids is None:
        unit_ids = np.arange(n_units)

    # if the last few units have no deconvolved spikes, the array of new
    # templates can be shorter; zero-pad it in that case
    if n_units_new < n_units:
        zero_pad = np.zeros((n_units - n_units_new, n_times, n_channels),
                            'float32')
        templates_new = np.concatenate((templates_new, zero_pad), axis=0)

    # number of deconvolved spikes
    n_spikes = np.zeros(n_units)
    units_unique, n_spikes_unique = np.unique(
        np.load(fname_spike_train)[:, 1], return_counts=True)
    n_spikes[units_unique] = n_spikes_unique

    # update rule if it will be updated
    weight_to_update = np.power((1 - rate), n_spikes)

    # only update for units in unit_ids
    weight = np.ones(n_units)
    weight[unit_ids] = weight_to_update[unit_ids]
    weight = weight[:, None, None]

    # align templates
    templates_orig, templates_new = align_two_set_of_templates(
        templates_orig, templates_new)

    # update and save
    templates_updated = weight * templates_orig + (1 - weight) * templates_new
    np.save(fname_templates_updated, templates_updated)

    # check the difference
    max_diff = np.zeros(n_units)
    max_diff[unit_ids] = np.max(
        np.abs(templates_new[unit_ids] - templates_orig[unit_ids]),
        axis=(1, 2))
    max_diff = max_diff / templates_orig.ptp(1).max(1)

    return fname_templates_updated, max_diff
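# --- Sketch: the exponential update weight (illustrative) --------------------
# update_templates above blends old and new templates with
#     weight = (1 - rate) ** n_spikes,
# so units with many deconvolved spikes lean more heavily on the newly
# computed template. With the default rate of 0.002, the toy values below give
# a feel for the decay (numbers are illustrative only):
def _example_update_weight(rate=0.002):
    for n_spikes in (0, 100, 1000):
        w = (1 - rate) ** n_spikes
        # n_spikes=0    -> w = 1.0   (keep the old template entirely)
        # n_spikes=100  -> w ~ 0.82
        # n_spikes=1000 -> w ~ 0.14  (mostly the new template)
        print(n_spikes, round(w, 3))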
def deconv_ONcpu(fname_templates_in,
                 output_directory,
                 recordings_filename,
                 recording_dtype,
                 threshold,
                 run_chunk_sec,
                 save_up_data,
                 fname_spike_train,
                 fname_spike_train_up,
                 fname_templates,
                 fname_templates_up,
                 CONFIG):

    logger = logging.getLogger(__name__)

    # parameters
    # TODO: read from CONFIG
    if threshold is None:
        threshold = CONFIG.deconvolution.threshold
    elif threshold == 'max':
        min_norm_2 = np.square(
            np.load(fname_templates_in)).sum((1, 2)).min()
        threshold = min_norm_2 * 0.8

    conv_approx_rank = 5
    upsample_max_val = 8
    max_iter = 1000

    if run_chunk_sec == 'full':
        chunk_sec = None
    else:
        chunk_sec = run_chunk_sec

    reader = READER(recordings_filename,
                    recording_dtype,
                    CONFIG,
                    CONFIG.resources.n_sec_chunk,
                    chunk_sec=chunk_sec)

    mp_object = MatchPursuit_objectiveUpsample(
        fname_templates=fname_templates_in,
        save_dir=output_directory,
        reader=reader,
        max_iter=max_iter,
        upsample=upsample_max_val,
        threshold=threshold,
        conv_approx_rank=conv_approx_rank,
        n_processors=CONFIG.resources.n_processors,
        multi_processing=CONFIG.resources.multi_processing)

    logger.info('Number of Units IN: {}'.format(mp_object.temps.shape[2]))

    # directory to save results for each segment
    seg_dir = os.path.join(output_directory, 'seg')
    if not os.path.exists(seg_dir):
        os.makedirs(seg_dir)

    # skip files/batches already completed; this allows more even distribution
    # across cores in case of restart
    # Cat: TODO: if cpu is still being used by endusers, may wish to implement
    #      dynamic file assignment here to deal with slow cores etc.
    fnames_out = []
    batch_ids = []
    for batch_id in range(reader.n_batches):
        fname_temp = os.path.join(
            seg_dir, "seg_{}_deconv.npz".format(str(batch_id).zfill(6)))
        if os.path.exists(fname_temp):
            continue
        fnames_out.append(fname_temp)
        batch_ids.append(batch_id)
    logger.info("running deconvolution on {} batches of {} seconds".format(
        len(batch_ids), CONFIG.resources.n_sec_chunk))

    if len(batch_ids) > 0:
        if CONFIG.resources.multi_processing:
            logger.info("running deconvolution with {} processors".format(
                CONFIG.resources.n_processors))
            batches_in = np.array_split(batch_ids,
                                        CONFIG.resources.n_processors)
            fnames_in = np.array_split(fnames_out,
                                       CONFIG.resources.n_processors)
            parmap.starmap(mp_object.run,
                           list(zip(batches_in, fnames_in)),
                           processes=CONFIG.resources.n_processors,
                           pm_pbar=True)
        else:
            logger.info("running deconvolution")
            for ctr in range(len(batch_ids)):
                mp_object.run([batch_ids[ctr]], [fnames_out[ctr]])

    # collect result
    res = []
    logger.info("gathering deconvolution results")
    for batch_id in range(reader.n_batches):
        fname_out = os.path.join(
            seg_dir, "seg_{}_deconv.npz".format(str(batch_id).zfill(6)))
        res.append(np.load(fname_out)['spike_train'])
    res = np.vstack(res)

    logger.info('Number of Spikes deconvolved: {}'.format(res.shape[0]))

    # save templates and upsampled templates
    np.save(fname_templates, np.load(fname_templates_in))
    #np.save(fname_templates,
    #        mp_object.temps.transpose(2, 0, 1))

    # since deconv spike time is not centered, get shift for centering
    shift = CONFIG.spike_size // 2

    # get spike train and save
    spike_train = np.copy(res)
    # map back to original id
    spike_train[:, 1] = np.int32(spike_train[:, 1] / mp_object.upsample_max_val)
    spike_train[:, 0] += shift
    # save
    np.save(fname_spike_train, spike_train)

    if save_up_data:
        # get upsampled templates and mapping for computing residual
        (templates_up,
         deconv_id_sparse_temp_map) = mp_object.get_sparse_upsampled_templates()

        np.save(fname_templates_up, templates_up.transpose(2, 0, 1))

        # get upsampled spike train
        spike_train_up = np.copy(res)
        spike_train_up[:, 1] = deconv_id_sparse_temp_map[spike_train_up[:, 1]]
        spike_train_up[:, 0] += shift
        np.save(fname_spike_train_up, spike_train_up)
def run(output_directory,
        fname_recording,
        recording_dtype,
        fname_residual=None,
        residual_dtype=None,
        fname_spike_index=None,
        fname_templates=None,
        fname_spike_train=None,
        fname_shifts=None,
        fname_scales=None,
        raw_data=True,
        full_run=False):
    """Spike clustering

    Parameters
    ----------
    spike_index: numpy.ndarray (n_clear_spikes, 2), str or Path
        2D array with indexes for spikes, first column contains the spike
        location in the recording and the second the main channel (channel
        whose amplitude is maximum). Or path to an npy file

    out_dir: str, optional
        Location to store/look for the generated spike train, relative to
        the config output directory

    if_file_exists: str, optional
        One of 'overwrite', 'abort', 'skip'. Controls the behavior for the
        spike_train_cluster.npy file. If 'overwrite' it replaces the file if
        it exists, if 'abort' it raises a ValueError exception if it exists,
        if 'skip' it skips the operation if the file exists (and returns the
        stored file)

    save_results: bool, optional
        Whether to save the spike train to disk
        (in CONFIG.data.root_folder/relative_to/spike_train_cluster.npy),
        defaults to False

    Returns
    -------
    spike_train: (TODO add documentation)

    Examples
    --------

    .. literalinclude:: ../../examples/pipeline/cluster.py
    """
    logger = logging.getLogger(__name__)

    ########################
    ### INITIALIZE #########
    ########################

    CONFIG = read_config()
    # get CONFIG2 for clustering
    # Cat: TODO: Edu said the CONFIG file can be passed as a dictionary
    CONFIG2 = make_CONFIG2(CONFIG)

    os.environ["CUDA_VISIBLE_DEVICES"] = str(CONFIG.resources.gpu_id)

    # output folder
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    # data reader
    reader_raw = READER(fname_recording,
                        recording_dtype,
                        CONFIG,
                        CONFIG.resources.n_sec_chunk_gpu_deconv,
                        chunk_sec=CONFIG.clustering_chunk)
    if fname_residual is not None:
        reader_resid = READER(fname_residual,
                              residual_dtype,
                              CONFIG,
                              CONFIG.resources.n_sec_chunk_gpu_deconv,
                              chunk_sec=CONFIG.clustering_chunk)
    else:
        reader_resid = None

    # nn denoiser
    if CONFIG.neuralnetwork.apply_nn:
        # load NN denoiser
        denoiser = Denoise(CONFIG.neuralnetwork.denoise.n_filters,
                           CONFIG.neuralnetwork.denoise.filter_sizes,
                           CONFIG.spike_size_nn,
                           CONFIG)
        denoiser.load(CONFIG.neuralnetwork.denoise.filename)
        denoiser = denoiser.cuda()
    else:
        denoiser = None

    # if the output exists and want to skip, just finish
    fname_templates_out = os.path.join(output_directory, 'templates.npy')
    fname_spike_train_out = os.path.join(output_directory, 'spike_train.npy')
    if not os.path.exists(fname_templates_out):

        # if clustering on clean waveforms, spike train is given
        # => make spike index and labels
        if fname_spike_index is None:
            savedir = os.path.join(output_directory, 'spike_index')
            if not os.path.exists(savedir):
                os.makedirs(savedir)
            (fname_spike_index,
             fname_labels_input) = make_spike_index_from_spike_train(
                fname_spike_train, fname_templates, savedir)
        else:
            # if we have spike_index, then we have no initial labels
            fname_labels_input = None

        #################################
        #### STAGE 1: Cluster on PTP ####
        #################################

        # keep track of input label because this is the deconv label
        # and it is necessary when making cleaned spikes
        logger.info("Split on PTP")
        (fname_spike_index,
         fname_labels,
         fname_labels_input) = run_split_on_ptp(
            os.path.join(output_directory, 'ptp_split'),
            fname_spike_index,
            CONFIG2,
            raw_data,
            fname_labels_input,
            fname_templates,
            fname_shifts,
            fname_scales,
            reader_raw,
            reader_resid,
            denoiser)

        ############################################
        #### STAGE 2: LOCAL + DISTANT CLUSTERING ###
        ############################################

        # load and align waveforms
        logger.info("load waveforms on local channels")
        units, fnames_input = load_waveforms(
            os.path.join(output_directory, 'input'),
            raw_data,
            fname_labels,
            fname_spike_index,
            fname_labels_input,
            fname_templates,
            fname_shifts,
            fname_scales,
            reader_raw,
            reader_resid,
            CONFIG2)

        if CONFIG.neuralnetwork.apply_nn:
            logger.info("NN denoise")
            # denoise it
            nn_denoise_wf(fnames_input, denoiser, CONFIG.torch_devices, CONFIG)
        else:
            logger.info("denoise")
            denoise_wf(fnames_input)

        #if raw_data:
        # align if raw data
        # no need to align for clean waveforms
        # because input shift is already used for alignment
        logger.info("align waveforms on local channels")
        align_waveforms(fnames_input, CONFIG2)

        # save location for intermediate results
        tmp_save_dir = os.path.join(output_directory, 'cluster_result')
        if not os.path.exists(tmp_save_dir):
            os.makedirs(tmp_save_dir)

        # Cat: TODO: this parallelization may not be optimally asynchronous
        # make arg list first
        args_in = []
        for ctr, unit in enumerate(units):

            # check to see if chunk + channel already completed
            filename_postclustering = os.path.join(
                tmp_save_dir, "cluster_result_{}.npz".format(unit))

            # skip
            if os.path.exists(filename_postclustering):
                continue
            args_in.append([
                raw_data,
                full_run,
                CONFIG2,
                reader_raw,
                reader_resid,
                filename_postclustering,
                fnames_input[ctr]])

        logger.info("starting clustering")
        if CONFIG.resources.multi_processing:
            parmap.map(Cluster, args_in,
                       processes=CONFIG.resources.n_processors,
                       pm_pbar=True)
        else:
            with tqdm(total=len(args_in)) as pbar:
                for arg_in in args_in:
                    Cluster(arg_in)
                    pbar.update()

        # first gather clustering result
        fname_templates_out, fname_spike_train_out = gather_clustering_result(
            tmp_save_dir, output_directory)

        for fname in fnames_input:
            os.remove(fname)

        #check_long_temp = os.path.join(output_directory, 'long_template.npy')
        #if not os.path.exists(check_long_temp):
        #    logger.info("get longer templates")
        #    fname_templates_out = run_template_computation(
        #        output_directory,
        #        fname_spike_train_out,
        #        reader_raw,
        #        spike_size=CONFIG.spike_size,
        #        multi_processing=CONFIG.resources.multi_processing,
        #        n_processors=CONFIG.resources.n_processors)
        #    np.save(check_long_temp, None)

        #check_low_fr_temp = os.path.join(output_directory, 'check_low_fr_template.npy')
        #if not os.path.exists(check_low_fr_temp):
        #    if CONFIG.neuralnetwork.apply_nn:
        #        # denoise wfs before computing templates for low fr units
        #        logger.info("re-estimate templates of low firing rate units")
        #        fname_templates_out = denoise_then_estimate_template(
        #            fname_templates_out,
        #            fname_spike_train_out,
        #            reader_raw,
        #            denoiser,
        #            CONFIG,
        #            n_max_spikes=100)
        #    np.save(check_low_fr_temp, None)

        #check_sharpen = os.path.join(output_directory, 'check_sharpen.npy')
        #if not os.path.exists(check_sharpen):
        #fname_templates_aligned = os.path.join(output_directory, 'templates_aligned.npy')
        #if not os.path.exists(fname_templates_aligned):
        #    logger.info("subsample template alignment")
        #    fname_templates_out = sharpen_templates(fname_templates_out,
        #                                            fname_templates_aligned)

        # zero-out edges
        #check_zero_out = os.path.join(output_directory, 'check_zero_out.npy')
        #if not os.path.exists(check_zero_out):
        #    logger.info("zero out unnecessary parts")
        #    fix_template_edges_by_file(fname_templates_out,
        #                               CONFIG.center_spike_size)
        #    np.save(check_zero_out, None)

    return fname_templates_out, fname_spike_train_out
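# --- Usage sketch (illustrative; not the examples/pipeline/cluster.py file) --
# A minimal call of the clustering `run` above for the raw-data case, where
# only a spike index is available; every path here is a hypothetical
# placeholder.
def _example_cluster_raw():
    fname_templates, fname_spike_train = run(
        output_directory='tmp/cluster',
        fname_recording='tmp/standardized.bin',
        recording_dtype='float32',
        fname_spike_index='tmp/detect/spike_index.npy',
        raw_data=True,
        full_run=False)
    return fname_templates, fname_spike_train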
def run(fname_templates_in,
        output_directory,
        recordings_filename,
        recording_dtype,
        threshold=None,
        run_chunk_sec='full',
        save_up_data=True):
    """Deconvolve spikes

    Parameters
    ----------
    spike_index_all: numpy.ndarray (n_data, 3)
        A 2D array for all potential spikes whose first column indicates the
        spike time and the second column the principal channels. The 3rd
        column indicates % confidence of cluster membership.
        Note: can now have single events assigned to multiple templates

    templates: numpy.ndarray (n_channels, waveform_size, n_templates)
        A 3D array with the templates

    output_directory: str, optional
        Output directory (relative to CONFIG.data.root_folder) used to load
        the recordings to generate templates, defaults to tmp/

    recordings_filename: str, optional
        Recordings filename (relative to CONFIG.data.root_folder/
        output_directory) used to draw the waveforms from, defaults to
        standardized.bin

    Returns
    -------
    spike_train: numpy.ndarray (n_clear_spikes, 2)
        A 2D array with the spike train, first column indicates the spike
        time and the second column the neuron ID

    Examples
    --------

    .. literalinclude:: ../../examples/pipeline/deconvolute.py
    """
    logger = logging.getLogger(__name__)

    CONFIG = read_config()
    CONFIG = make_CONFIG2(CONFIG)

    #print("... deconv using GPU device: ", torch.cuda.current_device())

    # output folder
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    fname_templates = os.path.join(output_directory, 'templates.npy')
    fname_spike_train = os.path.join(output_directory, 'spike_train.npy')
    fname_shifts = os.path.join(output_directory, 'shifts.npy')
    fname_scales = os.path.join(output_directory, 'scales.npy')

    if (os.path.exists(fname_templates) and
            os.path.exists(fname_spike_train) and
            os.path.exists(fname_shifts) and
            os.path.exists(fname_scales)):
        return (fname_templates, fname_spike_train,
                fname_shifts, fname_scales)

    # parameters
    if threshold is None:
        threshold = CONFIG.deconvolution.threshold
    elif threshold == 'low_fp':
        threshold = 150

    if run_chunk_sec == 'full':
        chunk_sec = None
    else:
        chunk_sec = run_chunk_sec

    # reader
    reader = READER(recordings_filename,
                    recording_dtype,
                    CONFIG,
                    CONFIG.resources.n_sec_chunk_gpu_deconv,
                    chunk_sec=chunk_sec)

    # enforce broad buffer
    reader.buffer = 1000

    deconv_ONgpu(fname_templates_in,
                 output_directory,
                 reader,
                 threshold,
                 CONFIG,
                 run_chunk_sec)

    return (fname_templates, fname_spike_train,
            fname_shifts, fname_scales)
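# --- Usage sketch (illustrative; not the examples/pipeline/deconvolute.py file)
# One way the deconvolution `run` above might be invoked after clustering;
# all paths are hypothetical placeholders.
def _example_deconvolve():
    (fname_templates, fname_spike_train,
     fname_shifts, fname_scales) = run(
        fname_templates_in='tmp/cluster/templates.npy',
        output_directory='tmp/deconv',
        recordings_filename='tmp/standardized.bin',
        recording_dtype='float32',
        run_chunk_sec='full')
    return fname_templates, fname_spike_train, fname_shifts, fname_scales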
def run(output_directory,
        fname_spike_train,
        fname_shifts,
        fname_scales,
        fname_templates,
        fname_soft_assignment,
        fname_residual,
        residual_dtype):

    logger = logging.getLogger(__name__)

    CONFIG = read_config()

    # output folder
    if not os.path.exists(output_directory):
        os.makedirs(output_directory)

    fname_spike_train_out = os.path.join(output_directory, 'spike_train.npy')
    fname_templates_out = os.path.join(output_directory, 'templates.npy')
    fname_soft_assignment_out = os.path.join(output_directory,
                                             'soft_assignment.npy')
    fname_shifts_out = os.path.join(output_directory, 'shifts.npy')
    fname_scales_out = os.path.join(output_directory, 'scales.npy')

    if os.path.exists(fname_spike_train_out) and os.path.exists(
            fname_templates_out):
        return (fname_templates_out,
                fname_spike_train_out,
                fname_shifts_out,
                fname_scales_out,
                fname_soft_assignment_out)

    reader_residual = READER(fname_residual,
                             residual_dtype,
                             CONFIG)

    # get whitening filters
    fname_spatial_cov = os.path.join(output_directory, 'spatial_cov.npy')
    fname_temporal_cov = os.path.join(output_directory, 'temporal_cov.npy')
    if not (os.path.exists(fname_spatial_cov) and
            os.path.exists(fname_temporal_cov)):
        spatial_cov, temporal_cov = get_noise_covariance(
            reader_residual, CONFIG)
        np.save(fname_spatial_cov, spatial_cov)
        np.save(fname_temporal_cov, temporal_cov)
    else:
        spatial_cov = np.load(fname_spatial_cov)
        temporal_cov = np.load(fname_temporal_cov)

    # initialize merge: find candidates
    logger.info("finding merge candidates")
    tm = TemplateMerge(
        output_directory,
        reader_residual,
        fname_templates,
        fname_spike_train,
        fname_shifts,
        fname_scales,
        fname_soft_assignment,
        fname_spatial_cov,
        fname_temporal_cov,
        CONFIG.geom,
        CONFIG.resources.multi_processing,
        CONFIG.resources.n_processors)

    # find merge pairs
    logger.info("merging pairs")
    tm.get_merge_pairs()

    # update templates and spike train accordingly
    logger.info("updating templates and spike train")
    (templates_new,
     spike_train_new,
     shifts_new,
     scales_new,
     soft_assignment_new,
     merge_array) = tm.merge_units()

    # save results
    fname_merge_array = os.path.join(output_directory, 'merge_array.npy')
    np.save(fname_merge_array, merge_array)
    np.save(fname_spike_train_out, spike_train_new)
    np.save(fname_templates_out, templates_new)
    np.save(fname_shifts_out, shifts_new)
    np.save(fname_scales_out, scales_new)
    np.save(fname_soft_assignment_out, soft_assignment_new)

    logger.info('Number of units after merge: {}'.format(
        templates_new.shape[0]))

    return (fname_templates_out,
            fname_spike_train_out,
            fname_shifts_out,
            fname_scales_out,
            fname_soft_assignment_out)
def run_cleaned_template_computation(out_dir,
                                     fname_spike_train,
                                     fname_templates,
                                     fname_shifts,
                                     fname_scales,
                                     fname_residual_recording,
                                     dtype_residual_recording,
                                     CONFIG,
                                     unit_ids=None):

    logger = logging.getLogger(__name__)
    logger.info("computing templates from cleaned spikes")

    # make output folder
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    #fname_templates = os.path.join(out_dir, 'templates.npy')
    #if os.path.exists(fname_templates):
    #    return fname_templates

    # make temp folder
    tmp_folder = os.path.join(out_dir, 'tmp_template')
    if not os.path.exists(tmp_folder):
        os.makedirs(tmp_folder)

    # get number of units
    n_units, n_times, n_channels = np.load(fname_templates).shape
    if unit_ids is None:
        unit_ids = np.arange(n_units)

    reader_residual = READER(fname_residual_recording,
                             dtype_residual_recording,
                             CONFIG)

    # run computing function
    if CONFIG.resources.multi_processing:
        n_processors = CONFIG.resources.n_processors
        unit_ids_partition = []
        for j in range(n_processors):
            unit_ids_partition.append(
                unit_ids[slice(j, len(unit_ids), n_processors)])

        parmap.map(run_cleaned_template_computation_parallel,
                   unit_ids_partition,
                   tmp_folder,
                   fname_spike_train,
                   fname_templates,
                   fname_shifts,
                   fname_scales,
                   reader_residual,
                   pm_processes=n_processors,
                   pm_pbar=True)

    else:
        run_cleaned_template_computation_parallel(
            unit_ids,
            tmp_folder,
            fname_spike_train,
            fname_templates,
            fname_shifts,
            fname_scales,
            reader_residual)

    # gather all info
    templates_new = np.zeros((n_units, n_times, n_channels), 'float32')
    for unit in unit_ids:
        fname_out = os.path.join(tmp_folder, 'unit_{}.npy'.format(unit))
        templates_new[unit] = np.load(fname_out)

    fname_templates = os.path.join(out_dir, 'templates.npy')
    np.save(fname_templates, templates_new)

    return fname_templates