def main(argv=None):
    """Split results obtained with N streams into one result file per stream.

    The results of the sorting are loaded, sliced according to the start and
    stop times of every source of the data file, and each slice is written to
    ``<file_out_suff>.result<extension>_<index>.hdf5``.

    Parameters
    ----------
    argv : list of str, optional
        Command line arguments (without the program name). Defaults to
        ``sys.argv[1:]``. When empty, the help is printed and the process
        exits.

    Notes
    -----
    Exits with status 1 when the ``stream_mode`` option of the ``data``
    section is ``'None'`` or ``'none'``.
    """
    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    header += '''Utility to split results obtained when using N streams into N individual results files, one per stream '''

    parser = argparse.ArgumentParser(
        description=header,
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file')
    parser.add_argument('-e', '--extension',
                        help='extension to consider for slicing results',
                        default='')

    if not argv:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)

    filename = os.path.abspath(args.datafile)
    extension = args.extension
    if extension != '':
        extension = '-' + extension

    params = CircusParser(filename)
    # Start from a fresh log file for this run.
    if os.path.exists(params.logfile):
        os.remove(params.logfile)
    _ = init_logging(params.logfile)
    logger = logging.getLogger(__name__)

    file_out_suff = params.get('data', 'file_out_suff')

    if params.get('data', 'stream_mode') in ['None', 'none']:
        print_and_log(['No streams in the datafile!'], 'error', logger)
        sys.exit(1)

    data_file = params.get_data_file()
    result = circus.shared.files.get_results(params, extension=extension)

    # One [t_start, t_stop] pair per source of the data file.
    times = [[source.t_start, source.t_stop] for source in data_file._sources]
    sub_results = slice_result(result, times)

    keys = ['spiketimes', 'amplitudes']
    for count, sub_result in enumerate(sub_results):
        output_path = file_out_suff + '.result%s_%d.hdf5' % (extension, count)
        # The context manager closes the file even if a write fails.
        with h5py.File(output_path, 'w', libver='earliest') as mydata:
            for key in keys:
                mydata.create_group(key)
                for temp in sub_result[key].keys():
                    tmp_path = '%s/%s' % (key, temp)
                    mydata.create_dataset(tmp_path, data=sub_result[key][temp])
def main(params, nb_cpu, nb_gpu, use_gpu, extension):
    """Run the merging step, optionally through the Qt merging GUI.

    Rank 0 first looks for the outputs of a previous merging
    (``.templates-merged.hdf5``, ``.clusters-merged.hdf5``,
    ``.result-merged.hdf5`` and the ``-merged.GUI`` directory) and removes
    them if ``[merging] erase_all`` is set or if the user agrees. A
    ``gui.MergeWindow`` is then created and the process exits.

    Parameters
    ----------
    params : object
        Parameter object providing ``logfile``, ``get``, ``getboolean`` and
        ``getfloat``.
    nb_cpu, nb_gpu, use_gpu
        Not used by this function.
    extension : str
        Extension of the input files, passed to ``gui.MergeWindow``.

    Notes
    -----
    This function always terminates the process through ``sys.exit``.
    """
    _ = init_logging(params.logfile)
    logger = logging.getLogger('circus.merging')
    file_out_suff = params.get('data', 'file_out_suff')
    erase_all = params.getboolean('merging', 'erase_all')
    extension_in = extension
    extension_out = '-merged'

    # Erase previous results (if user agrees).
    if comm.rank == 0:
        candidate_file_paths = [
            file_out_suff + ".%s%s.hdf5" % (file_id, extension_out)
            for file_id in ['templates', 'clusters', 'result']
        ]
        existing_file_paths = [
            file_path for file_path in candidate_file_paths
            if os.path.isfile(file_path)
        ]
        candidate_directory_paths = [file_out_suff + "%s.GUI" % extension_out]
        existing_directory_paths = [
            directory_path for directory_path in candidate_directory_paths
            if os.path.isdir(directory_path)
        ]

        if existing_file_paths or existing_directory_paths:
            if erase_all:
                erase = True
            else:
                erase = query_yes_no(
                    "Merging already done! Do you want to erase previous merging results?",
                    default=None)
            if erase:
                for path in existing_file_paths:
                    os.remove(path)
                    print_and_log(["Removed file %s" % path], 'debug', logger)
                for path in existing_directory_paths:
                    shutil.rmtree(path)
                    print_and_log(["Removed directory %s" % path], 'debug', logger)

    # Every rank waits until rank 0 is done with the clean-up.
    comm.Barrier()

    if comm.rank == 0 and params.getfloat('merging', 'auto_mode') == 0:
        app = QApplication([])
        try:
            pylab.style.use('ggplot')
        except Exception:
            # Best effort only: the plotting style is purely cosmetic.
            pass
    else:
        app = None

    if comm.rank == 0:
        print_and_log(['Launching the merging GUI...'], 'debug', logger)

    _ = gui.MergeWindow(params, app, extension_in, extension_out)

    if app is None:
        # No Qt application exists (automatic mode or rank != 0), so there is
        # no event loop to run; calling app.exec_() would raise
        # AttributeError.
        sys.exit(0)
    sys.exit(app.exec_())
def _collect_peaks(signals, heights, amplitudes_source, N_e, spike_width,
                   dist_peaks, N_t):
    """Detect peaks on the first ``N_e`` columns of ``signals``.

    Parameters
    ----------
    signals : 2-D array
        Peaks are searched column by column with ``scipy.signal.find_peaks``.
    heights : sequence
        Minimal peak height, one value per column.
    amplitudes_source : 2-D array
        Array from which the amplitude of every detected peak is read.
    N_e : int
        Number of columns to process.
    spike_width, dist_peaks, N_t
        Forwarded to ``find_peaks`` as ``width``, ``distance`` and ``wlen``.

    Returns
    -------
    tuple of three lists
        Per-column arrays of peak indices, electrode indices (``uint32``)
        and amplitudes.
    """
    peaktimes_list = []
    elecs_list = []
    amps_list = []
    for i in range(N_e):
        peaktimes = scipy.signal.find_peaks(signals[:, i],
                                            height=heights[i],
                                            width=spike_width,
                                            distance=dist_peaks,
                                            wlen=N_t)[0]
        peaktimes_list.append(peaktimes)
        elecs_list.append(i * numpy.ones(len(peaktimes), dtype='uint32'))
        amps_list.append(amplitudes_source[peaktimes, i])
    return peaktimes_list, elecs_list, amps_list


def main(params, nb_cpu, nb_gpu, use_gpu):
    """Extract the multi-unit activity (MUA) from the data file.

    Every MPI rank processes its share of the chunks: the data are optionally
    whitened (spatially and temporally), peaks are detected on every channel
    (on the raw signal or on its matched-filtered version), peaks falling on
    dead times or too close to the chunk borders are discarded, and the
    remaining peak times, electrodes and amplitudes are appended to three
    per-rank binary files (``.mua-<rank>.data``, ``.elec-<rank>.data`` and
    ``.amp-<rank>.data``). Rank 0 finally calls ``io.collect_mua``.

    Parameters
    ----------
    params : object
        Parameter object (provides ``logfile``, ``data_file`` and the
        ``get*`` accessors).
    nb_cpu : int
        Number of CPUs, only used to pick the GPU device.
    nb_gpu : int
        Number of GPUs, only used to pick the GPU device.
    use_gpu : bool
        If true, the spatial whitening is performed with ``cudamat``.

    Raises
    ------
    ValueError
        If matched filtering is disabled and ``[detection] peaks`` is not one
        of ``'negative'``, ``'positive'`` or ``'both'``.
    """
    _ = init_logging(params.logfile)
    # Result unused here; the call is kept because the helper is defined
    # elsewhere and its side effects are not visible from this function.
    _ = get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    N_t = params.getint('detection', 'N_t')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    dist_peaks = params.getint('detection', 'dist_peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    spike_width = params.getfloat('detection', 'spike_width')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = detect_memory(params)
    nodes, _ = get_nodes_and_edges(params)
    max_chunk = params.getfloat('fitting', 'max_chunk')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    data_file.open()

    if use_gpu:
        import cudamat as cmt
        # Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if matched_filter:
        # Waveforms are time-reversed and normalised before being used as
        # convolution kernels.
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg))
            matched_tresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos))
            matched_tresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        all_dead_times = get_dead_times(params)

    thresholds = io.load_data(params, 'thresholds')

    comm.Barrier()

    if comm.rank == 0:
        print_and_log(["Extracting MUA activity..."], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    else:
        spatial_whitening = None
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')
    else:
        temporal_whitening = None

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks = int(min(nb_chunks, max_chunk))

    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.mua-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    electrodes_file = open(file_out_suff + '.elec-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amp-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    output_files = (spiketimes_file, electrodes_file, amplitudes_file)

    try:
        if use_gpu and do_spatial_whitening:
            spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False)

        # Chunks are distributed over the ranks in a round-robin fashion.
        to_explore = range(comm.rank, processed_chunks, comm.size)
        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(to_explore)

        for gidx in to_explore:
            # We need to deal with the borders by padding the chunks with
            # dist_peaks samples on the sides that have neighbouring data.
            is_first = data_file.is_first_chunk(gidx, nb_chunks)
            is_last = data_file.is_last_chunk(gidx, nb_chunks)

            if is_last:
                padding = (-dist_peaks, 0)
            elif is_first:
                padding = (0, dist_peaks)
            else:
                padding = (-dist_peaks, dist_peaks)

            local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes)
            len_chunk = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            # Extracting the peaks. The empty arrays guarantee that the
            # concatenations below never receive an empty list.
            local_peaktimes = [numpy.zeros(0, dtype=numpy.uint32)]
            local_elecs = [numpy.zeros(0, dtype=numpy.uint32)]
            local_amps = [numpy.zeros(0, dtype=numpy.float32)]

            detections = []
            if matched_filter:
                if sign_peaks in ['positive', 'both']:
                    filter_chunk = scipy.ndimage.convolve1d(
                        local_chunk, waveform_pos, axis=0, mode='constant')
                    detections.append(_collect_peaks(
                        filter_chunk, matched_tresholds_pos, filter_chunk,
                        N_e, spike_width, dist_peaks, N_t))
                if sign_peaks in ['negative', 'both']:
                    filter_chunk = scipy.ndimage.convolve1d(
                        local_chunk, waveform_neg, axis=0, mode='constant')
                    detections.append(_collect_peaks(
                        filter_chunk, matched_tresholds_neg, filter_chunk,
                        N_e, spike_width, dist_peaks, N_t))
            else:
                # Peaks are searched on a rectified copy of the signal, but
                # amplitudes are always read from the signal itself.
                if sign_peaks == 'negative':
                    rectified_chunk = -local_chunk
                elif sign_peaks == 'positive':
                    rectified_chunk = local_chunk
                elif sign_peaks == 'both':
                    rectified_chunk = numpy.abs(local_chunk)
                else:
                    raise ValueError(
                        "Unexpected value for [detection] peaks: %s" % sign_peaks)
                detections.append(_collect_peaks(
                    rectified_chunk, thresholds, local_chunk,
                    N_e, spike_width, dist_peaks, N_t))

            for peaktimes_list, elecs_list, amps_list in detections:
                local_peaktimes.extend(peaktimes_list)
                local_elecs.extend(elecs_list)
                local_amps.extend(amps_list)

            local_peaktimes = numpy.concatenate(local_peaktimes)
            local_elecs = numpy.concatenate(local_elecs)
            local_amps = numpy.concatenate(local_amps)

            # Offset converting chunk-relative indices into absolute times.
            g_offset = t_offset + padding[0]

            if ignore_dead_times:
                dead_indices = numpy.searchsorted(
                    all_dead_times, [t_offset, t_offset + chunk_size])
                if dead_indices[0] != dead_indices[1]:
                    is_included = numpy.in1d(
                        local_peaktimes + g_offset,
                        all_dead_times[dead_indices[0]:dead_indices[1]])
                    local_peaktimes = local_peaktimes[~is_included]
                    local_elecs = local_elecs[~is_included]
                    local_amps = local_amps[~is_included]

            # Removing the useless borders.
            local_borders = (dist_peaks, len_chunk - dist_peaks)
            idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
            local_peaktimes = numpy.compress(idx, local_peaktimes) + g_offset
            local_elecs = numpy.compress(idx, local_elecs)
            local_amps = numpy.compress(idx, local_amps)

            # tobytes() replaces the deprecated tostring() (same output).
            spiketimes_file.write(local_peaktimes.astype(numpy.uint32).tobytes())
            electrodes_file.write(local_elecs.tobytes())
            amplitudes_file.write(local_amps.tobytes())

        sys.stderr.flush()

        # Make sure everything is on disk before rank 0 collects the files.
        for output_file in output_files:
            output_file.flush()
            os.fsync(output_file.fileno())
    finally:
        for output_file in output_files:
            output_file.close()

    comm.Barrier()

    if comm.rank == 0:
        io.collect_mua(comm.size, params, erase=True)

    data_file.close()
def _count_clustered_points(template_index, electrodes, clusters):
    """Return the number of clustered points associated with a template.

    The template is located among the templates sharing its preferred
    electrode, the matching clustering label is looked up in
    ``clusters['clusters_<electrode>']`` and the points carrying that label
    are counted.
    """
    elec = electrodes[template_index]
    # Index of the template among the templates with the same electrode.
    first_index = np.where(electrodes == elec)[0][0]
    nic = template_index - first_index
    labels = clusters['clusters_{}'.format(elec)]
    # Labels actually used by the clustering (-1 marks unlabelled points).
    used_labels = np.unique(labels[labels > -1])
    return len(np.where(labels == used_labels[nic])[0])


def main(params, nb_cpu, nb_gpu, use_gpu, extension):
    """Import the manual curation done in the phy GUI back into the results.

    The phy results found in ``<file_out_suff><extension>.GUI`` are read.
    Templates sharing a cluster in the 'good' or 'unsorted' groups are
    merged, templates in the 'mua' or 'noise' groups and templates without
    any spike are removed, and new ``-deconverted`` templates, clusters and
    result files are written.

    Parameters
    ----------
    params : object
        Parameter object (provides ``logfile`` and ``get``).
    nb_cpu : int
        Number of CPUs, only used in the log message.
    nb_gpu, use_gpu
        Not used by this function.
    extension : str
        Extension of the input files.

    Raises
    ------
    ValueError
        If a cluster belongs to an unexpected group.

    Notes
    -----
    Must be executed on rank 0 only. The process exits (status 0) if the
    input directory is missing or if the user refuses to delete existing
    output files.
    """
    assert comm.rank == 0

    input_extension = extension

    _ = init_logging(params.logfile)
    logger = logging.getLogger('circus.deconverting')

    # Retrieve parameters.
    input_path = params.get('data', 'file_out_suff') + input_extension + '.GUI'
    output_path = params.get('data', 'file_out_suff')
    output_extension = '-deconverted'
    clusters_path = output_path + '.clusters{}.hdf5'.format(output_extension)
    templates_path = output_path + '.templates{}.hdf5'.format(output_extension)
    result_path = output_path + '.result{}.hdf5'.format(output_extension)
    output_paths = [result_path, templates_path, clusters_path]

    # Check if input path exists.
    if not os.path.isdir(input_path):
        print_and_log([
            "Can't find directory: {}".format(input_path),
            "You must first export your results into the phy format (use the converting method)."
        ], 'error', logger)
        sys.exit(0)

    # Check if results are already present.
    if os.path.isfile(clusters_path) \
            and os.path.isfile(templates_path) \
            and os.path.isfile(result_path):
        print_and_log([
            "Phy results already imported.",
            "Delete the following files to run the deconversion again:",
            " - {}".format(result_path),
            " - {}".format(templates_path),
            " - {}".format(clusters_path),
        ], 'error', logger)
        if query_yes_no("Do you want to delete these files?"):
            os.remove(result_path)
            os.remove(templates_path)
            os.remove(clusters_path)
        else:
            sys.exit(0)

    # Check if results are partially present.
    if os.path.isfile(clusters_path) \
            or os.path.isfile(templates_path) \
            or os.path.isfile(result_path):
        print_and_log([
            "Phy results partially imported.",
            "Delete the following files to be able to run the deconversion:",
        ] + [
            " - {}".format(path)
            for path in output_paths
            if os.path.isfile(path)
        ], 'error', logger)
        if query_yes_no("Do you want to delete these files?"):
            for path in output_paths:
                if os.path.isfile(path):
                    os.remove(path)
        else:
            sys.exit(0)

    # Log introduction message.
    message = "Importing data from the phy GUI with {} CPUs...".format(nb_cpu)
    print_and_log([message], 'info', logger)

    # Read phy results.
    phy_results = load_phy_results(input_path)
    spike_templates = phy_results['spike_templates']
    spike_clusters = phy_results['spike_clusters']
    cluster_group = phy_results['cluster_group']

    # Read spyking-circus results (the file is closed as soon as the arrays
    # have been copied in memory).
    templates_input_path = output_path + ".templates{}.hdf5".format(input_extension)
    with h5py.File(templates_input_path, mode='r', libver='earliest') as templates_input_file:
        overlaps = templates_input_file.get('maxoverlap')[:]
        maxlag_dataset = templates_input_file.get('maxlag')
        if maxlag_dataset is None:
            # i.e. 'maxlag' not in HDF5 file.
            lag = np.zeros(overlaps.shape, dtype=np.int32)
        else:
            lag = maxlag_dataset[:]

    # Set correct lag option.
    correct_lag = True

    # Determine association map between spike templates and spike clusters.
    templates_to_clusters = dict(zip(spike_templates, spike_clusters))
    # Determine association map between spike cluster and default spike template.
    clusters_to_templates = {spike_cluster: None for spike_cluster in cluster_group}

    # TODO remove templates without spikes.

    to_merge = []
    to_remove = []

    # Do all the merges.
    old_results = io.load_data(params, 'results', extension=input_extension)
    electrodes = io.load_data(params, 'electrodes', extension=input_extension)
    clusters = io.load_data(params, 'clusters', extension=input_extension)
    for spike_template, spike_cluster in templates_to_clusters.items():
        spike_group = cluster_group[spike_cluster]
        if spike_group in ['good', 'unsorted']:
            if clusters_to_templates[spike_cluster] is None:
                # First template seen for this cluster: it becomes the default.
                clusters_to_templates[spike_cluster] = spike_template
            else:
                # Retrieve pair of templates to merge.
                default_spike_template = clusters_to_templates[spike_cluster]
                nb_points1 = _count_clustered_points(default_spike_template, electrodes, clusters)
                nb_points2 = _count_clustered_points(spike_template, electrodes, clusters)
                # Keep the template with strictly more clustered points,
                # otherwise keep the most recent one.
                if nb_points1 > nb_points2:
                    to_keep, to_delete = default_spike_template, spike_template
                else:
                    to_keep, to_delete = spike_template, default_spike_template
                # Merge templates (if necessary).
                if to_keep != to_delete:
                    key1 = 'temp_{}'.format(to_keep)
                    key2 = 'temp_{}'.format(to_delete)
                    amplitudes1 = old_results['amplitudes'][key1]
                    amplitudes2 = old_results['amplitudes'][key2]
                    old_results['amplitudes'][key1] = np.vstack((amplitudes1, amplitudes2))
                    spiketimes1 = old_results['spiketimes'][key1]
                    spiketimes2 = old_results['spiketimes'][key2]
                    if correct_lag:
                        # The lag may be negative: shift in a signed type.
                        spiketimes2 = spiketimes2.astype(np.int64)
                        spiketimes2 += lag[to_keep, to_delete]
                        spiketimes2 = spiketimes2.astype(np.uint32)
                    old_results['spiketimes'][key1] = np.concatenate((spiketimes1, spiketimes2))
                    # Keep spike times (and their amplitudes) sorted.
                    indices = np.argsort(old_results['spiketimes'][key1])
                    old_results['amplitudes'][key1] = old_results['amplitudes'][key1][indices]
                    old_results['spiketimes'][key1] = old_results['spiketimes'][key1][indices]
                    old_results['amplitudes'].pop(key2)
                    old_results['spiketimes'].pop(key2)
                # Update internal variables.
                clusters_to_templates[spike_cluster] = to_keep
                to_merge.append((to_keep, to_delete))
        elif spike_group in ['mua', 'noise']:
            to_remove.append(spike_template)
        else:
            message = "Unexpected group value: {}".format(spike_group)
            raise ValueError(message)

    # Remove unmentioned templates (e.g. without any fitted spike).
    old_templates = load_data(params, 'templates', extension=input_extension)
    initial_nb_templates = old_templates.shape[1] // 2
    all_spike_templates = set(range(0, initial_nb_templates))
    mentioned_spike_templates = set(templates_to_clusters.keys())
    unmentioned_spike_templates = list(all_spike_templates - mentioned_spike_templates)
    to_remove.extend(unmentioned_spike_templates)

    if not to_merge:
        # np.int was removed from NumPy; it was an alias of the builtin int.
        to_merge = np.zeros((0, 2), dtype=int)
    else:
        to_merge = np.array(to_merge)
        to_merge = to_merge[np.lexsort((to_merge[:, 1], to_merge[:, 0])), :]
    to_remove.sort()

    # Log some information.
    nb_merges = to_merge.shape[0]
    nb_removals = len(to_remove)
    final_nb_templates = initial_nb_templates - nb_merges - nb_removals
    print_and_log([
        "Manual sorting with the Python GUI (i.e. phy):",
        " initial number of templates: {}".format(initial_nb_templates),
        " number of merges: {}".format(nb_merges),
        " number of removals: {}".format(nb_removals),
        " final number of templates: {}".format(final_nb_templates),
    ], 'info', logger)

    # Slice templates.
    to_keep = slice_templates(params, to_merge=to_merge, to_remove=to_remove,
                              extension=output_extension, input_extension=extension)

    # Slice clusters.
    light = True
    dataname = 'clusters-light' if light else 'clusters'
    clusters = io.load_data(params, dataname, extension=extension)
    slice_clusters(params, clusters, to_merge=to_merge, to_remove=to_remove,
                   extension=output_extension, input_extension=extension,
                   light=light, method='new')

    # Finalize result: renumber the kept templates from 0.
    new_results = {
        'spiketimes': {},
        'amplitudes': {},
    }
    for k, template_index in enumerate(to_keep):
        old_key = 'temp_{}'.format(template_index)
        new_key = 'temp_{}'.format(k)
        new_results['spiketimes'][new_key] = old_results['spiketimes'].pop(old_key)
        new_results['amplitudes'][new_key] = old_results['amplitudes'].pop(old_key)
        # Check if the number of spikes is not equal to 0.
        nb_spikes = len(new_results['spiketimes'][new_key])
        if nb_spikes == 0:
            print_and_log(
                ["{} - template {} has no spikes".format(k, template_index)],
                'error', logger)

    keys = ['spiketimes', 'amplitudes']
    # TODO add support for [fitting] collect_all=True (not supported).

    # Save new result to output file.
    result_output_path = output_path + ".result{}.hdf5".format(output_extension)
    with h5py.File(result_output_path, mode='w', libver='earliest') as result_output_file:
        for key in keys:
            result_output_file.create_group(key)
            for temp in new_results[key].keys():
                tmp_path = "{}/{}".format(key, temp)
                result_output_file.create_dataset(tmp_path, data=new_results[key][temp])

    # Add additional information to templates file.
    templates_output_path = output_path + ".templates{}.hdf5".format(output_extension)
    with h5py.File(templates_output_path, mode='r+', libver='earliest') as templates_output_file:
        new_shape = (len(to_keep), len(to_keep))
        templates_output_file.create_dataset(
            'version',
            data=np.array(circus.__version__.split('.'), dtype=np.int32))
        maxoverlaps = templates_output_file.create_dataset(
            'maxoverlap', shape=new_shape, dtype=np.float32)
        maxlag = templates_output_file.create_dataset(
            'maxlag', shape=new_shape, dtype=np.int32)
        # Restrict the overlap and lag matrices to the kept templates.
        for k, index in enumerate(to_keep):
            maxoverlaps[k, :] = overlaps[index, to_keep]
            maxlag[k, :] = lag[index, to_keep]

    # Log conclusion message.
    message = "Data from the phy GUI imported."
    print_and_log([message], 'info', logger)
def _require_phy_dependency(package, logger):
    """Exit with status 1 (after logging an error) if ``package`` cannot be imported."""
    import importlib

    try:
        importlib.import_module(package)
    except ImportError:
        print_and_log(
            ['The package %s required by phy is not installed' % package],
            'error', logger)
        sys.exit(1)


def main(argv=None):
    """Launch the phy GUI on results exported with the converting method.

    Parameters
    ----------
    argv : list of str, optional
        Command line arguments (without the program name). Defaults to
        ``sys.argv[1:]``. When empty, the help is printed and the process
        exits.

    Notes
    -----
    Exits with status 1 if a package required by phy is missing, if
    phy-contrib is older than 1.0.12, or if the user declines to continue
    with data that should be re-exported. If the ``.GUI`` output directory
    does not exist, an error is logged and the function returns without
    launching the GUI.
    """
    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    header += '''Utility to launch the phy GUI and visualize the results. [data must be first converted with the converting mode] '''

    parser = argparse.ArgumentParser(
        description=header,
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file')
    parser.add_argument('-e', '--extension',
                        help='extension to consider for visualization',
                        default='')

    if not argv:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)

    filename = os.path.abspath(args.datafile)
    extension = args.extension
    params = CircusParser(filename)
    # Start from a fresh log file for this run.
    if os.path.exists(params.logfile):
        os.remove(params.logfile)
    _ = init_logging(params.logfile)
    logger = logging.getLogger(__name__)

    if extension != '':
        extension = '-' + extension

    for package in ('traitlets', 'click', 'joblib'):
        _require_phy_dependency(package, logger)

    if HAVE_PHYCONTRIB:
        mytest = StrictVersion(phycontrib.__version__) >= StrictVersion("1.0.12")
        if not mytest:
            print_and_log(
                ['You need to update phy-contrib to the latest git version'],
                'error', logger)
            sys.exit(1)
        print_and_log([
            'phy-contrib is deprecated, you should upgrade to phy 2.0 and phylib'
        ], 'info', logger)

    if HAVE_PHYLIB:
        for package in ('colorcet', 'qtconsole'):
            _require_phy_dependency(package, logger)

    if not test_patch_for_similarities(params, extension):
        print_and_log(
            ['You should re-export the data because of a fix in 0.6'],
            'error', logger)
        continue_anyway = query_yes_no(
            Fore.WHITE + "Continue anyway (results may not be fully correct)?",
            default=None)
        if not continue_anyway:
            sys.exit(1)

    data_file = params.get_data_file()
    data_dtype = data_file.data_dtype
    # 'in' works on Python 2 and 3 (dict.has_key only exists on Python 2).
    if 'data_offset' in data_file.params:
        data_offset = data_file.data_offset
    else:
        data_offset = 0

    file_format = data_file.description

    if file_format not in supported_by_phy:
        print_and_log([
            "File format %s is not supported by phy. TraceView disabled" % file_format
        ], 'info', logger)

    if numpy.iterable(data_file.gain):
        print_and_log(
            ['Multiple gains are not supported, using a default value of 1'],
            'info', logger)
    elif data_file.gain != 1:
        print_and_log([
            "Gain of %g is not supported by phy. Expecting a scaling mismatch" % data_file.gain
        ], 'info', logger)

    output_path = params.get('data', 'file_out_suff') + extension + '.GUI'

    if not os.path.exists(output_path):
        print_and_log(
            ['Data should be first exported with the converting method!'],
            'error', logger)
        return

    print_and_log(["Launching the phy GUI..."], 'info', logger)

    gui_params = {}
    if file_format in supported_by_phy:
        if not params.getboolean('data', 'overwrite'):
            gui_params['dat_path'] = r"%s" % params.get('data', 'data_file_no_overwrite')
        elif params.get('data', 'stream_mode') == 'multi-files':
            data_file = params.get_data_file(source=True, has_been_created=False)
            gui_params['dat_path'] = [r"%s" % f for f in data_file.get_file_names()]
        else:
            gui_params['dat_path'] = r"%s" % params.get('data', 'data_file')
    else:
        # Placeholder path used when the file format is not supported by phy.
        gui_params['dat_path'] = 'giverandomname.dat'
    gui_params['n_channels_dat'] = params.nb_channels
    gui_params['n_features_per_channel'] = 5
    gui_params['dtype'] = data_dtype
    gui_params['offset'] = data_offset
    gui_params['sample_rate'] = params.rate
    gui_params['dir_path'] = output_path
    gui_params['hp_filtered'] = True

    os.chdir(output_path)
    create_app()
    controller = TemplateController(**gui_params)
    gui = controller.create_gui()
    gui.show()
    run_app()
    gui.close()
    del gui
def main(params, nb_cpu, nb_gpu, use_gpu):
    """Fit the stored templates to the recording, chunk by chunk.

    Every MPI rank handles the chunks ``comm.rank, comm.rank + comm.size, ...``.
    For each chunk the data are optionally whitened, peaks are detected, the
    snippets around the peaks are projected onto the templates and
    (peak, template) matches are accepted greedily when the normalised
    amplitude lies within the amplitude limits of the template. Accepted
    matches are appended to the per-rank ``.spiketimes-<rank>.data``,
    ``.amplitudes-<rank>.data`` and ``.templates-<rank>.data`` files (plus the
    ``.gspiketimes``/``.gtemplates`` files when ``[fitting] collect_all`` is
    set). Rank 0 finally calls ``io.collect_data`` to merge them.

    Parameters
    ----------
    params
        Project parameter object (read through ``get``/``getint``/...).
    nb_cpu, nb_gpu
        Forwarded to ``io.get_overlaps``; also used to choose the GPU device.
    use_gpu
        When true, ``cudamat`` is imported and used for the projections.
    """
    _ = init_logging(params.logfile)
    SHARED_MEMORY = get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    data_file.open()
    N_e = params.getint('data', 'N_e')
    N_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = params.getint('fitting', 'chunk_size')
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = get_nodes_and_edges(params)
    raw_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    # A real list (not a lazy map object) because it is indexed below.
    tmp_limits = [float(value) for value in raw_limits]
    amp_auto = params.getboolean('fitting', 'amp_auto')
    nb_chances = params.getint('fitting', 'nb_chances')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    collect_all = params.getboolean('fitting', 'collect_all')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')

    #################################################################

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if SHARED_MEMORY:
        templates = io.load_data_memshared(params, 'templates', normalize=True, transpose=True)
        N_tm = templates.shape[0]
    else:
        templates = io.load_data(params, 'templates')
        N_tm = templates.shape[1]

    temp_2_shift = 2 * template_shift
    full_gpu = use_gpu and gpu_only
    # The stored matrix holds two components per template, hence the halving.
    n_tm = N_tm // 2
    n_scalar = N_e * N_t
    temp_window = numpy.arange(-template_shift, template_shift + 1)
    size_window = N_e * (2 * template_shift + 1)

    if not amp_auto:
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')

    if not SHARED_MEMORY:
        # Normalise every column in place, then transpose so that rows are templates.
        for idx in xrange(templates.shape[1]):
            myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx + 1])
            templates.data[myslice] /= norm_templates[idx]
        templates = templates.T

    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg))
            matched_tresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos))
            matched_tresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        all_dead_times = get_dead_times(params)

    thresholds = io.load_data(params, 'thresholds')

    if collect_all:
        # For every template, the channels on which its summed waveform is non-zero.
        neighbors = {}
        for i in xrange(n_tm):
            tmp = templates[i, :].toarray().reshape(N_e, N_t) * norm_templates[i]
            neighbors[i] = numpy.where(numpy.sum(tmp, 1) != 0)[0]

    if use_gpu:
        templates = cmt.SparseCUDAMatrix(templates, copy_on_host=False)

    info_string = ''

    if comm.rank == 0:
        if use_gpu:
            info_string = "using %d GPUs" % (comm.size)
        else:
            info_string = "using %d CPUs" % (comm.size)

    comm.Barrier()

    c_overlap = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    over_shape = c_overlap.get('over_shape')[:]
    N_over = int(numpy.sqrt(over_shape[0]))
    S_over = over_shape[1]

    ## If the number of overlaps is different from templates, we need to recompute them
    if N_over != N_tm:
        if comm.rank == 0:
            print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger)
        c_overlap = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        N_over = int(numpy.sqrt(over_shape[0]))
        S_over = over_shape[1]

    if SHARED_MEMORY:
        c_overs = io.load_data_memshared(params, 'overlaps')
    else:
        c_overs = io.load_data(params, 'overlaps')

    comm.Barrier()

    if n_tm == 0:
        if comm.rank == 0:
            print_and_log(["No templates present. Redo clustering?"], 'default', logger)
        sys.exit(0)

    if comm.rank == 0:
        print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." % (info_string, n_tm)], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    if full_gpu:
        try:
            # If memory on the GPU is large enough, we load the overlaps onto it
            for i in xrange(N_over):
                c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i], copy_on_host=False)
        except Exception:
            if comm.rank == 0:
                print_and_log(["Not enough memory on GPUs: GPUs are used for projection only"], 'info', logger)
            for i in xrange(N_over):
                if i in c_overs:
                    del c_overs[i]
            full_gpu = False

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks = int(min(nb_chunks, max_chunk))

    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank, 'wb')
    comm.Barrier()

    if collect_all:
        garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' % comm.rank, 'wb')
        comm.Barrier()
        garbage_temp_file = open(file_out_suff + '.gtemplates-%d.data' % comm.rank, 'wb')
        comm.Barrier()

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False)

    last_chunk_size = 0

    to_explore = xrange(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gidx in to_explore:
        ## We need to deal with the borders by taking chunks of size [0, chunck_size+template_shift]
        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last = data_file.is_last_chunk(gidx, nb_chunks)

        if is_first and is_last:
            # NOTE(review): assumes both flags can be true for a single-chunk
            # recording; in that case there is nothing to pad with on either side.
            padding = (0, 0)
        elif is_last:
            padding = (-temp_2_shift, 0)
        elif is_first:
            padding = (0, temp_2_shift)
        else:
            padding = (-temp_2_shift, temp_2_shift)

        result = {'spiketimes': [], 'amplitudes': [], 'templates': []}

        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes)
        len_chunk = len(local_chunk)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant')

        # Extracting the peaks.
        if collect_all:
            all_found_spikes = {}
            for i in xrange(N_e):
                all_found_spikes[i] = []

        local_peaktimes = numpy.zeros(0, dtype=numpy.uint32)

        if matched_filter:
            # Positive waveform first, then negative (same order as the detection passes).
            matched_passes = []
            if sign_peaks in ['positive', 'both']:
                matched_passes.append((waveform_pos, matched_tresholds_pos))
            if sign_peaks in ['negative', 'both']:
                matched_passes.append((waveform_neg, matched_tresholds_neg))
            for waveform, matched_thresholds in matched_passes:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i], matched_thresholds[i])
                    local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
        else:
            for i in xrange(N_e):
                if sign_peaks == 'negative':
                    peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=True)
                elif sign_peaks == 'positive':
                    peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=False)
                elif sign_peaks == 'both':
                    peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]), thresholds[i], valley=False)
                local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()

        local_peaktimes = numpy.unique(local_peaktimes)
        g_offset = t_offset + padding[0]

        if ignore_dead_times:
            dead_indices = numpy.searchsorted(all_dead_times, [t_offset, t_offset + chunk_size])
            if dead_indices[0] != dead_indices[1]:
                local_peaktimes = numpy.array(
                    list(set(local_peaktimes + g_offset).difference(
                        all_dead_times[dead_indices[0]:dead_indices[1]])),
                    dtype=numpy.uint32) - g_offset
                local_peaktimes = numpy.sort(local_peaktimes)

        # Removing the useless borders: a full template window must fit around each peak.
        local_borders = (template_shift, len_chunk - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if collect_all:
            for i in xrange(N_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.uint32)

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        all_found_spikes[i] = numpy.array(
                            list(set(all_found_spikes[i] + g_offset).difference(
                                all_dead_times[dead_indices[0]:dead_indices[1]])),
                            dtype=numpy.uint32) - g_offset
                        all_found_spikes[i] = numpy.sort(all_found_spikes[i])

                idx = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        n_t = len(local_peaktimes)

        if n_t > 0:
            # Computing b, the scalar products between every template and every peak snippet.
            if collect_all:
                c_local_chunk = local_chunk.copy()

            local_chunk = local_chunk.T.ravel()
            sub_mat = numpy.zeros((size_window, n_t), dtype=numpy.float32)

            # The gather indices only depend on the chunk length: rebuild them lazily.
            if len_chunk != last_chunk_size:
                slice_indices = numpy.zeros(0, dtype=numpy.int32)
                for idx in xrange(N_e):
                    slice_indices = numpy.concatenate((slice_indices, len_chunk * idx + temp_window))
                last_chunk_size = len_chunk

            for count, idx in enumerate(local_peaktimes):
                sub_mat[:, count] = numpy.take(local_chunk, slice_indices + idx)

            del local_chunk

            if use_gpu:
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b = cmt.sparse_dot(templates, sub_mat)
            else:
                b = templates.dot(sub_mat)

            del sub_mat

            local_restriction = (t_offset, t_offset + chunk_size)
            all_spikes = local_peaktimes + g_offset

            if use_gpu and not full_gpu:
                b = b.asarray()

            failure = numpy.zeros(n_t, dtype=numpy.int32)

            if full_gpu:
                mask = numpy.zeros((2 * n_tm, n_t), dtype=numpy.float32)
                mask[:n_tm, :] = 1
                data = cmt.empty(mask.shape)
                patch_gpu = b.shape[1] == 1
            else:
                mask = numpy.ones((n_tm, n_t), dtype=numpy.float32)
                sub_b = b[:n_tm, :]

            if collect_all:
                c_all_times = numpy.zeros((len_chunk, N_e), dtype=bool)
                c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in xrange(N_e):
                    c_all_times[all_found_spikes[i], i] = True

            while numpy.mean(failure) < nb_chances:

                # Visit the peaks by decreasing best (masked) scalar product.
                if full_gpu:
                    gpu_mask = cmt.CUDAMatrix(mask, copy_on_host=False)
                    b.mult(gpu_mask, data)
                    tmp_mat = data.max(0)
                    argmax_bi = numpy.argsort(tmp_mat.asarray()[0, :])[::-1]
                    del tmp_mat
                else:
                    data = sub_b * mask
                    argmax_bi = numpy.argsort(numpy.max(data, 0))[::-1]

                for peak_index in argmax_bi:

                    if full_gpu:
                        # b lives on the device and changes after every accepted
                        # match, so take a fresh host copy for the look-ups.
                        b_array = b.asarray()
                        sub_b = b_array[:n_tm, :]
                    else:
                        b_array = b

                    # NOTE(review): the best template is chosen from the unmasked
                    # products, so `mask` only influences the visiting order above.
                    peak_scalar_products = numpy.take(sub_b, peak_index, axis=1)
                    best_template_index = numpy.argmax(peak_scalar_products, axis=0)
                    best_template2_index = best_template_index + n_tm

                    best_amp = sub_b[best_template_index, peak_index] / n_scalar
                    # The second component is stored n_tm rows below the first one
                    # (the GPU branch used to read the first-component row here).
                    best_amp2 = b_array[best_template2_index, peak_index] / n_scalar
                    best_amp_n = best_amp / norm_templates[best_template_index]
                    best_amp2_n = best_amp2 / norm_templates[best_template2_index]

                    # Verify amplitude constraint.
                    a_min = amp_limits[best_template_index, 0]
                    a_max = amp_limits[best_template_index, 1]

                    if (a_min <= best_amp_n) & (best_amp_n <= a_max):
                        # Keep the matching.
                        peak_time_step = local_peaktimes[peak_index]

                        # Named `time_deltas` (not `data`) so that the buffer used by
                        # b.mult(gpu_mask, data) is not overwritten.
                        time_deltas = (local_peaktimes - peak_time_step).astype(numpy.int32)
                        is_neighbor = numpy.where(numpy.abs(time_deltas) <= temp_2_shift)[0]
                        idx_neighbor = time_deltas[is_neighbor] + temp_2_shift
                        nb_neighbors = len(is_neighbor)
                        indices = numpy.zeros((S_over, nb_neighbors), dtype=numpy.int32)
                        indices[idx_neighbor, numpy.arange(nb_neighbors)] = 1

                        # Subtract the contribution of the fitted template from the
                        # scalar products of all the peaks it overlaps with.
                        if full_gpu:
                            indices = cmt.CUDAMatrix(indices, copy_on_host=False)
                            if patch_gpu:
                                b_lines = b.get_col_slice(0, b.shape[0])
                            else:
                                b_lines = b.get_col_slice(is_neighbor[0], is_neighbor[-1] + 1)
                            # best_amp / best_amp2 are scalars: pass them as they are.
                            tmp1 = cmt.sparse_dot(c_overs[best_template_index], indices, mult=-best_amp)
                            tmp2 = cmt.sparse_dot(c_overs[best_template2_index], indices, mult=-best_amp2)
                            b_lines.add(tmp1.add(tmp2))
                            del tmp1, tmp2
                        else:
                            tmp1 = c_overs[best_template_index].multiply(-best_amp)
                            tmp2 = c_overs[best_template2_index].multiply(-best_amp2)
                            b[:, is_neighbor] += (tmp1 + tmp2).dot(indices)

                        # Add matching to the result (only inside the un-padded chunk).
                        t_spike = all_spikes[peak_index]

                        if (t_spike >= local_restriction[0]) and (t_spike < local_restriction[1]):
                            result['spiketimes'] += [t_spike]
                            result['amplitudes'] += [(best_amp_n, best_amp2_n)]
                            result['templates'] += [best_template_index]
                        # Mark current matching as tried.
                        mask[best_template_index, peak_index] = 0
                    else:
                        # Reject the matching.
                        # Update failure counter of the peak.
                        failure[peak_index] += 1
                        # If the maximal number of failures is reached then mark peak as solved (i.e. not fitted).
                        if failure[peak_index] == nb_chances:
                            mask[:, peak_index] = 0
                        else:
                            mask[best_template_index, peak_index] = 0

            spikes_to_write = numpy.array(result['spiketimes'], dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'], dtype=numpy.uint32)

            spiketimes_file.write(spikes_to_write.tostring())
            amplitudes_file.write(amplitudes_to_write.tostring())
            templates_file.write(templates_to_write.tostring())

            if collect_all:
                # Clear the detections explained by a fitted spike; what remains
                # above threshold is written to the "garbage" files.
                for temp, spike in zip(templates_to_write, spikes_to_write - g_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False

                gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times

                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_tresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_tresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(matched_tresholds_neg[bestlecs], matched_tresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]

                gspikes = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write = numpy.array(gspikes + g_offset, dtype=numpy.uint32)
                gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.uint32)

                garbage_times_file.write(gspikes_to_write.tostring())
                garbage_temp_file.write(gtemplates_to_write.tostring())

            if full_gpu:
                del gpu_mask, b, data

    sys.stderr.flush()

    # Push everything to disk before rank 0 gathers the per-rank files.
    output_files = [spiketimes_file, amplitudes_file, templates_file]
    if collect_all:
        output_files += [garbage_temp_file, garbage_times_file]
    for handle in output_files:
        handle.flush()
        os.fsync(handle.fileno())
        handle.close()

    comm.Barrier()

    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)

    data_file.close()
def main(params, nb_cpu, nb_gpu, use_gpu, extension):
    """Export the sorting results to ``<file_out_suff><extension>.GUI`` for phy.

    Rank 0 (re)creates the output directory and writes the whitening matrices,
    channel map and positions, spike times / templates / amplitudes, the
    templates themselves and their similarity matrix. Depending on
    ``[converting] export_pcs`` (or on the interactive answer when it is
    ``'prompt'``), all ranks then cooperate to write the PC features.

    Parameters
    ----------
    params
        Project parameter object (read through ``get``/``getint``/...).
    nb_cpu
        Only used in the log message.
    nb_gpu, use_gpu
        Not used by this step.
    extension
        Suffix appended to the result file names and to the output directory.
    """
    _ = init_logging(params.logfile)
    logger = logging.getLogger('circus.converting')
    file_out_suff = params.get('data', 'file_out_suff')
    probe = params.probe
    output_path = params.get('data', 'file_out_suff') + extension + '.GUI'
    N_e = params.getint('data', 'N_e')
    N_t = params.getint('detection', 'N_t')
    erase_all = params.getboolean('converting', 'erase_all')
    export_pcs = params.get('converting', 'export_pcs')
    export_all = params.getboolean('converting', 'export_all')
    sparse_export = params.getboolean('converting', 'sparse_export')

    if export_all and not params.getboolean('fitting', 'collect_all'):
        if comm.rank == 0:
            print_and_log(['Export unfitted spikes only if [fitting] collect_all is True'], 'error', logger)
        sys.exit(0)

    def generate_mapping(probe):
        """Return the channel positions of the probe, ordered by channel id."""
        geometry = {}
        positions = []
        channels = []
        for group in probe['channel_groups'].values():
            geometry.update(group['geometry'])
            channels += group['channels']
            positions += [geometry[channel] for channel in group['channels']]
        order = numpy.argsort(channels)
        return numpy.array(positions)[order]

    def get_max_loc_channel(params):
        """Return the largest neighbourhood size found in the probe edges (0 if none)."""
        nodes, edges = get_nodes_and_edges(params)
        return max([0] + [len(edges[key]) for key in edges.keys()])

    def write_results(path, params, extension):
        """Write spike_times / spike_templates / amplitudes, sorted by spike time."""
        result = io.get_results(params, extension)
        N_tm = len(result['spiketimes'])

        # Collect the pieces and concatenate once; the leading empty arrays pin
        # the dtypes (and keep the result valid when there is no spike at all).
        spikes = [numpy.zeros(0, dtype=numpy.uint64)]
        clusters = [numpy.zeros(0, dtype=numpy.uint32)]
        amplitudes = [numpy.zeros(0, dtype=numpy.double)]

        # Iterate over a snapshot of the keys because entries are popped on the way.
        for key in list(result['spiketimes'].keys()):
            temp_id = int(key.split('_')[-1])
            data = result['spiketimes'].pop(key).astype(numpy.uint64)
            spikes.append(data)
            data = result['amplitudes'].pop(key).astype(numpy.double)
            amplitudes.append(data[:, 0])
            clusters.append(temp_id*numpy.ones(len(data), dtype=numpy.uint32))

        if export_all:
            print_and_log(["Last %d templates are unfitted spikes on all electrodes" %N_e], 'info', logger)
            garbage = io.load_data(params, 'garbage', extension)
            for key in list(garbage['gspikes'].keys()):
                elec_id = int(key.split('_')[-1])
                data = garbage['gspikes'].pop(key).astype(numpy.uint64)
                spikes.append(data)
                amplitudes.append(numpy.zeros(len(data)))
                # Unfitted spikes are filed under pseudo-templates N_tm .. N_tm + N_e - 1.
                clusters.append((elec_id + N_tm)*numpy.ones(len(data), dtype=numpy.uint32))

        spikes = numpy.concatenate(spikes)
        clusters = numpy.concatenate(clusters)
        amplitudes = numpy.concatenate(amplitudes)

        idx = numpy.argsort(spikes)
        numpy.save(os.path.join(output_path, 'spike_templates'), clusters[idx])
        numpy.save(os.path.join(output_path, 'spike_times'), spikes[idx])
        numpy.save(os.path.join(output_path, 'amplitudes'), amplitudes[idx])
        return

    def write_templates(path, params, extension):
        """Write the templates (and their channel mapping if sparse); return N_tm."""
        templates = io.load_data(params, 'templates', extension)
        N_tm = templates.shape[1]//2

        if sparse_export:
            # Size the channel axis with the same criterion as the fill below
            # (any non-zero sample), so that both always agree.
            n_channels_max = 0
            for t in xrange(N_tm):
                waveform = templates[:, t].toarray().reshape(N_e, N_t)
                nb_active = int(numpy.sum(numpy.any(waveform != 0, axis=1)))
                if nb_active > n_channels_max:
                    n_channels_max = nb_active
        else:
            n_channels_max = N_e

        if export_all:
            nb_exported = N_tm + N_e
        else:
            nb_exported = N_tm
        to_write_sparse = numpy.zeros((nb_exported, N_t, n_channels_max), dtype=numpy.float32)
        mapping_sparse = -1 * numpy.ones((nb_exported, n_channels_max), dtype=numpy.int32)

        for t in xrange(N_tm):
            tmp = templates[:, t].toarray().reshape(N_e, N_t).T
            x, y = tmp.nonzero()
            if sparse_export:
                if len(y) == 0:
                    # All-zero template: nothing to copy, and y.max() would raise.
                    continue
                channels = numpy.unique(y)
                nb_loc = len(channels)
                all_positions = numpy.zeros(y.max()+1, dtype=numpy.int32)
                all_positions[channels] = numpy.arange(nb_loc, dtype=numpy.int32)
                pos = all_positions[y]
                to_write_sparse[t, x, pos] = tmp[x, y]
                mapping_sparse[t, numpy.arange(nb_loc)] = channels
            else:
                pos = y
                to_write_sparse[t, x, pos] = tmp[x, y]

        if export_all:
            for t in xrange(N_tm, N_tm + N_e):
                mapping_sparse[t, 0] = t - N_tm

        numpy.save(os.path.join(output_path, 'templates'), to_write_sparse)

        if sparse_export:
            numpy.save(os.path.join(output_path, 'template_ind'), mapping_sparse)

        return N_tm

    def write_pcs(path, params, extension, N_tm, mode=0):
        """Write the PC features of all spikes (mode 0) or of at most 500 per template (mode 1)."""
        max_spikes_per_template = 500

        spikes = numpy.load(os.path.join(output_path, 'spike_times.npy'))
        labels = numpy.load(os.path.join(output_path, 'spike_templates.npy'))
        max_loc_channel = get_max_loc_channel(params)
        nb_features = params.getint('whitening', 'output_dim')
        sign_peaks = params.get('detection', 'peaks')
        nodes, edges = get_nodes_and_edges(params)
        N_total = params.getint('data', 'N_total')

        if export_all:
            nb_templates = N_tm + N_e
        else:
            nb_templates = N_tm

        pc_features_ind = numpy.zeros((nb_templates, max_loc_channel), dtype=numpy.int32)
        clusters = io.load_data(params, 'clusters', extension)
        best_elec = clusters['electrodes']
        if export_all:
            best_elec = numpy.concatenate((best_elec, numpy.arange(N_e)))
        inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
        inv_nodes[nodes] = numpy.argsort(nodes)

        for count, elec in enumerate(best_elec):
            nb_loc = len(edges[nodes[elec]])
            pc_features_ind[count, numpy.arange(nb_loc)] = inv_nodes[edges[nodes[elec]]]

        if sign_peaks in ['negative', 'both']:
            basis_proj, basis_rec = io.load_data(params, 'basis')
        elif sign_peaks in ['positive']:
            basis_proj, basis_rec = io.load_data(params, 'basis-pos')

        # Number of rows every template occupies in the output, and where they start.
        all_offsets = numpy.zeros(nb_templates, dtype=numpy.int32)
        for target in xrange(nb_templates):
            nb_spikes = len(numpy.where(labels == target)[0])
            if mode == 0:
                all_offsets[target] = nb_spikes
            elif mode == 1:
                all_offsets[target] = min(max_spikes_per_template, nb_spikes)

        all_paddings = numpy.concatenate(([0], numpy.cumsum(all_offsets)))
        total_pcs = numpy.sum(all_offsets)

        pc_file = os.path.join(output_path, 'pc_features.npy')
        pc_file_ids = os.path.join(output_path, 'pc_feature_spike_ids.npy')

        from numpy.lib.format import open_memmap

        # Rank 0 creates the files; after the barrier every rank maps them read/write.
        if comm.rank == 0:
            pc_features = open_memmap(pc_file, shape=(total_pcs, nb_features, max_loc_channel), dtype=numpy.float32, mode='w+')
            if mode == 1:
                pc_ids = open_memmap(pc_file_ids, shape=(total_pcs, ), dtype=numpy.int32, mode='w+')

        comm.Barrier()
        pc_features = open_memmap(pc_file, mode='r+')
        if mode == 1:
            pc_ids = open_memmap(pc_file_ids, mode='r+')

        to_explore = xrange(comm.rank, nb_templates, comm.size)

        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(to_explore)

        for target in to_explore:

            count = all_paddings[target]

            if mode == 1:
                idx = numpy.random.permutation(numpy.where(labels == target)[0])[:max_spikes_per_template]
                pc_ids[count:count+len(idx)] = idx
            elif mode == 0:
                idx = numpy.where(labels == target)[0]

            elec = best_elec[target]
            indices = inv_nodes[edges[nodes[elec]]]
            labels_i = target*numpy.ones(len(idx))
            times_i = numpy.take(spikes, idx).astype(numpy.int64)
            sub_data = io.get_stas(params, times_i, labels_i, elec, neighs=indices, nodes=nodes, auto_align=False)

            pcs = numpy.dot(sub_data, basis_proj)
            pcs = numpy.swapaxes(pcs, 1, 2)
            if mode == 0:
                pc_features[idx, :, :len(indices)] = pcs
            elif mode == 1:
                pc_features[count:count+len(idx), :, :len(indices)] = pcs

        comm.Barrier()

        if comm.rank == 0:
            numpy.save(os.path.join(output_path, 'pc_feature_ind'), pc_features_ind.astype(numpy.uint32))  # n_templates, n_loc_chan

    do_export = True
    if comm.rank == 0:
        if os.path.exists(output_path):
            if not erase_all:
                do_export = query_yes_no(Fore.WHITE + "Export already made! Do you want to erase everything?", default=None)
            if do_export:
                if os.path.exists(os.path.abspath('.phy')):
                    shutil.rmtree(os.path.abspath('.phy'))
                shutil.rmtree(output_path)
        # Always broadcast, whatever was answered: the other ranks are waiting
        # in the matching bcast below.
        do_export = bool(do_export)
        comm.bcast(numpy.array([int(do_export)], dtype=numpy.int32), root=0)
    else:
        do_export = bool(comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)[0])

    comm.Barrier()

    if do_export:

        if comm.rank == 0:
            os.makedirs(output_path)
            print_and_log(["Exporting data for the phy GUI with %d CPUs..." %nb_cpu], 'info', logger)

            if params.getboolean('whitening', 'spatial'):
                whitening_mat = io.load_data(params, 'spatial_whitening').astype(numpy.double)
                numpy.save(os.path.join(output_path, 'whitening_mat'), whitening_mat)
                numpy.save(os.path.join(output_path, 'whitening_mat_inv'), numpy.linalg.inv(whitening_mat))
            else:
                numpy.save(os.path.join(output_path, 'whitening_mat'), numpy.eye(N_e))

            numpy.save(os.path.join(output_path, 'channel_positions'), generate_mapping(probe).astype(numpy.double))
            nodes, edges = get_nodes_and_edges(params)
            numpy.save(os.path.join(output_path, 'channel_map'), nodes.astype(numpy.int32))

            write_results(output_path, params, extension)
            N_tm = write_templates(output_path, params, extension)

            apply_patch_for_similarities(params, extension)

            template_file = h5py.File(file_out_suff + '.templates%s.hdf5' %extension, 'r', libver='earliest')
            try:
                similarities = template_file.get('maxoverlap')[:]
            finally:
                template_file.close()

            norm = N_e*N_t

            if export_all:
                # The pseudo-templates of the unfitted spikes get zero similarity.
                to_write = numpy.zeros((N_tm + N_e, N_tm + N_e), dtype=numpy.single)
                to_write[:N_tm, :N_tm] = (similarities[:N_tm, :N_tm]/norm).astype(numpy.single)
            else:
                to_write = (similarities[:N_tm, :N_tm]/norm).astype(numpy.single)
            numpy.save(os.path.join(output_path, 'similar_templates'), to_write)

            comm.bcast(numpy.array([N_tm], dtype=numpy.int32), root=0)

        else:
            N_tm = int(comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)[0])

        comm.Barrier()

        # 0: export all the PCs, 1: export some of them, 2: export none.
        make_pcs = 2
        if comm.rank == 0:

            if export_pcs == 'prompt':
                key = ''
                while key not in ['a', 's', 'n']:
                    print(Fore.WHITE + "Do you want SpyKING CIRCUS to export PCs? (a)ll / (s)ome / (n)o")
                    key = raw_input('')
            else:
                key = export_pcs

            # Any value other than 'a' / 's' means "no export"; it is broadcast
            # in every case so that the other ranks never wait forever.
            make_pcs = {'a': 0, 's': 1}.get(key, 2)
            comm.bcast(numpy.array([make_pcs], dtype=numpy.int32), root=0)

            if key == 'n':
                if os.path.exists(os.path.join(output_path, 'pc_features.npy')):
                    os.remove(os.path.join(output_path, 'pc_features.npy'))
                if os.path.exists(os.path.join(output_path, 'pc_feature_ind.npy')):
                    os.remove(os.path.join(output_path, 'pc_feature_ind.npy'))
        else:
            make_pcs = comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)
            make_pcs = make_pcs[0]

        comm.Barrier()
        if make_pcs < 2:
            write_pcs(output_path, params, extension, N_tm, make_pcs)
def main(params, nb_cpu, nb_gpu, use_gpu, file_name, benchmark, sim_same_elec):
    """
    Useful tool to create synthetic datasets for benchmarking.

    Existing templates are moved to other electrodes (or kept on the same
    one for the 'synchrony' and 'smart-search' modes), rescaled, and injected
    at random times into a copy of the recording. The ground truth (spike
    times, amplitudes, template ids, fitted amplitudes, voltages) is written
    next to the generated data.

    Arguments
    ---------
    params : parser object giving access to the configuration and data file
    nb_cpu, nb_gpu, use_gpu : unused here, kept for the common `main` interface
    file_name : path of the raw file to generate (must end with `.dat`, or
        have no extension, in which case `.dat` is appended)
    benchmark : {'fitting', 'clustering', 'synchrony', 'pca-validation',
        'smart-search', 'drifts'}
        NOTE: the validation below rejects 'pca-validation', so the branches
        handling that mode are currently unreachable from this entry point.
    sim_same_elec : maximal similarity allowed between a moved template and
        any existing template (defaults to 0.8 when None)
    """
    if sim_same_elec is None:
        sim_same_elec = 0.8

    init_logging(params.logfile)
    logger = logging.getLogger('circus.benchmarking')

    numpy.random.seed(265)
    file_name = os.path.abspath(file_name)
    data_suff, ext = os.path.splitext(os.path.basename(file_name))
    file_out, ext = os.path.splitext(file_name)

    if ext == '':
        ext = '.dat'
        file_name += ext

    if ext != '.dat':
        if comm.rank == 0:
            print_and_log(['Benchmarking produces raw files: select a .dat extension'], 'error', logger)
        sys.exit(0)

    if benchmark not in ['fitting', 'clustering', 'synchrony', 'smart-search', 'drifts']:
        if comm.rank == 0:
            print_and_log(['Benchmark need to be in [fitting, clustering, synchrony, smart-search, drifts]'], 'error', logger)
        sys.exit(0)

    # The extension `.p` or `.pkl` or `.pickle` seems more appropriate than `.pic`.
    # see: http://stackoverflow.com/questions/4530111/python-saving-objects-and-using-pickle-extension-of-filename
    # see: https://wiki.python.org/moin/UsingPickle
    def write_benchmark(filename, benchmark, cells, rates, amplitudes, sampling, probe, trends=None):
        """Save benchmark parameters in `<filename>.pic` to remember them."""
        try:
            import cPickle as pickle  # Python 2
        except ImportError:
            import pickle
        to_write = {'benchmark': benchmark}
        to_write['cells'] = cells
        to_write['rates'] = rates
        to_write['probe'] = probe
        to_write['amplitudes'] = amplitudes
        to_write['sampling'] = sampling
        if benchmark == 'drifts':
            to_write['drifts'] = trends
        # Binary mode is required by pickle on Python 3 and harmless on
        # Python 2; the context manager guarantees the file is closed.
        with open(filename + '.pic', 'wb') as pic_file:
            pickle.dump(to_write, pic_file)

    # Retrieve some key parameters.
    templates = io.load_data(params, 'templates')
    N_tm = templates.shape[1] // 2
    trends = None

    # Choose the cells, rates and amplitudes according to the benchmark mode.
    if benchmark == 'fitting':
        nb_insert = 25
        n_cells = numpy.random.random_integers(0, N_tm - 1, nb_insert)
        rate = nb_insert * [10]
        amplitude = numpy.linspace(0.5, 5, nb_insert)
    elif benchmark == 'clustering':
        n_point = 5
        n_cells = numpy.random.random_integers(0, N_tm - 1, n_point ** 2)
        x, y = numpy.mgrid[0:n_point, 0:n_point]
        rate = numpy.linspace(0.5, 20, n_point)[x.flatten()]
        amplitude = numpy.linspace(0.5, 5, n_point)[y.flatten()]
    elif benchmark == 'synchrony':
        nb_insert = 5
        corrcoef = 0.2
        n_cells = nb_insert * [numpy.random.random_integers(0, N_tm - 1, 1)[0]]
        rate = 10. / corrcoef
        amplitude = 2
    elif benchmark == 'pca-validation':
        nb_insert = 10
        n_cells = numpy.random.random_integers(0, N_tm - 1, nb_insert)
        rate_min = 0.5
        rate_max = 20.0
        rate = rate_min + (rate_max - rate_min) * numpy.random.random_sample(nb_insert)
        amplitude_min = 0.5
        amplitude_max = 5.0
        amplitude = amplitude_min + (amplitude_max - amplitude_min) * numpy.random.random_sample(nb_insert)
    elif benchmark == 'smart-search':
        nb_insert = 10
        n_cells = nb_insert * [numpy.random.random_integers(0, templates.shape[1] // 2 - 1, 1)[0]]
        rate = 1 + 5 * numpy.arange(nb_insert)
        amplitude = 2
    elif benchmark == 'drifts':
        n_point = 5
        n_cells = numpy.random.random_integers(0, templates.shape[1] // 2 - 1, n_point ** 2)
        x, y = numpy.mgrid[0:n_point, 0:n_point]
        rate = 5 * numpy.ones(n_point)[x.flatten()]
        amplitude = numpy.linspace(0.5, 5, n_point)[y.flatten()]
        trends = numpy.random.randn(n_point ** 2)

    # Delete the output directory tree if this output directory exists.
    if comm.rank == 0:
        if os.path.exists(file_out):
            shutil.rmtree(file_out)

    # Check and normalize some variables.
    if n_cells is None:
        n_cells = 1
        cells = [numpy.random.permutation(numpy.arange(n_cells))[0]]
    elif not numpy.iterable(n_cells):
        cells = [n_cells]
        n_cells = 1
    else:
        cells = n_cells
        n_cells = len(cells)

    if numpy.iterable(rate):
        assert len(rate) == len(cells), "Should have the same number of rates and cells"
    else:
        rate = [rate] * len(cells)

    if numpy.iterable(amplitude):
        assert len(amplitude) == len(cells), "Should have the same number of amplitudes and cells"
    else:
        amplitude = [amplitude] * len(cells)

    # Retrieve some additional key parameters.
    data_file = params.get_data_file(source=True)
    N_e = params.getint('data', 'N_e')
    N_total = params.nb_channels
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    nodes, edges = get_nodes_and_edges(params)
    N_t = params.getint('detection', 'N_t')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    N_tm_init = templates.shape[1] // 2
    thresholds = io.load_data(params, 'thresholds')
    limits = io.load_data(params, 'limits')
    best_elecs = io.load_data(params, 'electrodes')
    norms = io.load_data(params, 'norm-templates')

    # Create output directory if it does not exist.
    if comm.rank == 0:
        if not os.path.exists(file_out):
            os.makedirs(file_out)

    # Save benchmark parameters in a file to remember them.
    if comm.rank == 0:
        write_benchmark(file_out, benchmark, cells, rate, amplitude, params.rate, params.get('data', 'mapping'), trends)

    # Synchronize all the threads/processes.
    comm.Barrier()

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    # Retrieve some additional key parameters.
    chunk_size = params.getint('data', 'chunk_size')
    scalings = []

    params.set('data', 'data_file', file_name)

    data_file_out = params.get_data_file(is_empty=True)
    data_file_out.allocate(shape=data_file.shape)

    # Synchronize all the threads/processes.
    comm.Barrier()

    # For each wanted synthesized cell insert a generated template in the set
    # of existing templates.
    for gcount, cell_id in enumerate(cells):
        best_elec = best_elecs[cell_id]
        indices = inv_nodes[edges[nodes[best_elec]]]
        count = 0
        new_indices = []
        all_elecs = numpy.random.permutation(numpy.arange(N_e))
        reference = templates[:, cell_id].toarray().reshape(N_e, N_t)
        # Initialize the similarity (i.e. default value).
        similarity = 1.0

        # Find the first eligible template for the wanted synthesized cell.
        while len(new_indices) != len(indices) or (similarity > sim_same_elec):
            similarity = 0
            if count == len(all_elecs):
                if comm.rank == 0:
                    print_and_log(["No electrode to move template %d (max similarity is %g)" %(cell_id, similarity)], 'error', logger)
                sys.exit(0)
            else:
                # Get the next shuffled electrode.
                n_elec = all_elecs[count]

                if benchmark not in ['synchrony', 'smart-search']:
                    # Process if the shuffled electrode and the nearest
                    # electrode to the synthesized cell are not identical.
                    local_test = n_elec != best_elec
                else:
                    # Process if the shuffled electrode and the nearest
                    # electrode to the synthesized cell are identical.
                    local_test = n_elec == best_elec

                if local_test:
                    # Shuffle the neighboring electrodes without modifying
                    # the nearest electrode to the synthesized cell.
                    new_indices = inv_nodes[edges[nodes[n_elec]]]
                    idx = numpy.where(new_indices != best_elec)[0]
                    new_indices[idx] = numpy.random.permutation(new_indices[idx])

                    if len(new_indices) == len(indices):
                        # Shuffle the templates on the neighboring electrodes.
                        new_temp = numpy.zeros(reference.shape, dtype=numpy.float32)
                        new_temp[new_indices, :] = reference[indices, :]
                        # Compute the scaling factor which normalizes the
                        # shuffled template.
                        gmin = new_temp.min()
                        data = numpy.where(new_temp == gmin)
                        scaling = -thresholds[data[0][0]] / gmin
                        for i in range(templates.shape[1] // 2):
                            match = templates[:, i].toarray().reshape(N_e, N_t)
                            d = numpy.corrcoef(match.flatten(), scaling * new_temp.flatten())[0, 1]
                            if d > similarity:
                                similarity = d
                else:
                    new_indices = []

            # Go to the next shuffled electrode.
            count += 1

        ## Insert the selected template.

        # Retrieve the number of existing templates in the dataset.
        N_tm = templates.shape[1] // 2

        # Generate the template of the synthesized cell from the selected
        # template, the target amplitude and the rescaling (i.e. threshold of
        # the target electrode).
        to_insert = numpy.zeros(reference.shape, dtype=numpy.float32)
        to_insert[new_indices] = scaling * amplitude[gcount] * templates[:, cell_id].toarray().reshape(N_e, N_t)[indices]
        to_insert = to_insert.flatten()
        to_insert2 = numpy.zeros(reference.shape, dtype=numpy.float32)
        to_insert2[new_indices] = scaling * amplitude[gcount] * templates[:, cell_id + N_tm].toarray().reshape(N_e, N_t)[indices]
        to_insert2 = to_insert2.flatten()

        # Compute the norm of the generated template.
        mynorm = numpy.sqrt(numpy.sum(to_insert ** 2) / (N_e * N_t))
        mynorm2 = numpy.sqrt(numpy.sum(to_insert2 ** 2) / (N_e * N_t))

        # Insert the limits of the generated template.
        limits = numpy.vstack((limits, limits[cell_id]))
        # Insert the best electrode of the generated template.
        best_elecs = numpy.concatenate((best_elecs, [n_elec]))

        # Insert the norm of the generated template (i.e. central component
        # and orthogonal component).
        norms = numpy.insert(norms, N_tm, mynorm)
        norms = numpy.insert(norms, 2 * N_tm + 1, mynorm2)
        # Insert the scaling of the generated template.
        scalings += [scaling]

        # Retrieve the data about the existing templates.
        templates = templates.tocoo()
        xdata = templates.row
        ydata = templates.col
        zdata = templates.data

        # Shift by one the orthogonal components of the existing templates.
        idx = numpy.where(ydata >= N_tm)[0]
        ydata[idx] += 1

        # Insert the central component of the selected template.
        dx = to_insert.nonzero()[0].astype(numpy.int32)
        xdata = numpy.concatenate((xdata, dx))
        ydata = numpy.concatenate((ydata, N_tm * numpy.ones(len(dx), dtype=numpy.int32)))
        zdata = numpy.concatenate((zdata, to_insert[dx]))

        # Insert the orthogonal component of the selected template.
        dx = to_insert2.nonzero()[0].astype(numpy.int32)
        xdata = numpy.concatenate((xdata, dx))
        ydata = numpy.concatenate((ydata, (2 * N_tm + 1) * numpy.ones(len(dx), dtype=numpy.int32)))
        zdata = numpy.concatenate((zdata, to_insert2[dx]))

        # Reconstruct the matrix of templates.
        templates = scipy.sparse.csc_matrix((zdata, (xdata, ydata)), shape=(N_e * N_t, 2 * (N_tm + 1)))

    # Remove all the expired data.
    if benchmark == 'pca-validation':
        # Keep only the injected templates.
        N_tm_init = 0
        # Floor division: N_tm is used as a slice bound below.
        N_tm = templates.shape[1] // 2

        limits = limits[N_tm - nb_insert:, :]
        best_elecs = best_elecs[N_tm - nb_insert:]
        norms = numpy.concatenate((norms[N_tm - nb_insert:N_tm], norms[2 * N_tm - nb_insert:2 * N_tm]))

        templates = templates.tocoo()
        xdata = templates.row
        ydata = templates.col
        zdata = templates.data

        idx_cen = numpy.logical_and(N_tm - nb_insert <= ydata, ydata < N_tm)
        idx_cen = numpy.where(idx_cen)[0]
        idx_ort = numpy.logical_and(2 * N_tm - nb_insert <= ydata, ydata < 2 * N_tm)
        idx_ort = numpy.where(idx_ort)[0]
        ydata[idx_cen] = ydata[idx_cen] - (N_tm - nb_insert)
        ydata[idx_ort] = ydata[idx_ort] - 2 * (N_tm - nb_insert)
        idx = numpy.concatenate((idx_cen, idx_ort))
        xdata = xdata[idx]
        ydata = ydata[idx]
        zdata = zdata[idx]
        templates = scipy.sparse.csc_matrix((zdata, (xdata, ydata)), shape=(N_e * N_t, 2 * nb_insert))

    # Retrieve the information about the organisation of the chunks of data.
    nb_chunks, _ = data_file.analyze(chunk_size)

    # Display information about the generated benchmark.
    if comm.rank == 0:
        print_and_log(["Generating benchmark data [%s] with %d cells" %(benchmark, n_cells)], 'info', logger)
        purge(file_out, '.data')

    template_shift = params.getint('detection', 'template_shift')
    # Re-seed per rank so that every process injects different spikes.
    numpy.random.seed(comm.rank)

    # Chunks are distributed over the processes in a round-robin fashion.
    to_explore = range(comm.rank, nb_chunks, comm.size)

    # Initialize the progress bar about the generation of the benchmark.
    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    data_file_out.open(mode='r+')

    # Open the thread/process' files to collect the results. Each entry is
    # (key in `result`, which is also the file name stem, dtype on disk).
    result_specs = [
        ('spiketimes', numpy.uint32),
        ('amplitudes', numpy.float32),
        ('templates', numpy.int32),
        ('real_amps', numpy.float32),
        ('voltages', numpy.float32),
    ]
    result_files = {}
    for key, _ in result_specs:
        result_filename = os.path.join(file_out, data_suff + '.%s-%d.data' % (key, comm.rank))
        result_files[key] = open(result_filename, 'wb')

    # Minimal number of samples between two injected spikes of a same cell.
    refractory = int(5 * 1e-3 * params.rate)

    # For each chunk of data associated to the current thread/process generate
    # the new chunk of data (i.e. with the added synthesized cells).
    for gidx in to_explore:

        result = {'spiketimes': [], 'amplitudes': [], 'templates': [], 'real_amps': [], 'voltages': []}
        offset = gidx * chunk_size
        local_chunk, _ = data_file.get_data(gidx, chunk_size, nodes=nodes)

        if benchmark == 'pca-validation':
            # Clear the current data chunk.
            local_chunk = numpy.zeros(local_chunk.shape, dtype=local_chunk.dtype)

        # Handle whitening if necessary.
        if do_spatial_whitening:
            local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant')

        # String equality (not identity): `is` against a literal is False for
        # strings built at runtime, which silently disabled this mode.
        if benchmark == 'synchrony':
            # Generate some spike indices (i.e. times) at the given rate for
            # 'synchrony' mode. Each synthesized cell will use a subset of
            # these spike times.
            mips = numpy.random.rand(chunk_size) < rate[0] / float(params.rate)

        # For each synthesized cell generate its spike indices (i.e. times)
        # and add them to the dataset.
        for idx in range(len(cells)):
            if benchmark == 'synchrony':
                # Choose a subset of the spike indices generated before. The
                # size of this subset is parameterized by the target
                # correlation coefficient.
                sidx = numpy.where(mips)[0]
                spikes = numpy.zeros(chunk_size, dtype=bool)
                spikes[sidx[numpy.random.rand(len(sidx)) < corrcoef]] = True
            else:
                # Generate some spike indices at the given rate.
                spikes = numpy.random.rand(chunk_size) < rate[idx] / float(params.rate)

            if benchmark == 'drifts':
                # NOTE(review): `spikes` is still a boolean mask here and
                # `amplitudes` is later indexed by spike rank, not by sample
                # index; confirm this is the intended drift profile.
                amplitudes = numpy.ones(len(spikes)) + trends[idx]*((spikes + offset)/(5*60*float(params.rate)))
            else:
                amplitudes = numpy.ones(len(spikes))

            # Padding with `False` to avoid the insertion of partial spikes
            # at the edges of the signal.
            spikes[:N_t] = False
            spikes[-N_t:] = False
            # Find the indices of the spike samples.
            spikes = numpy.where(spikes)[0]
            n_template = N_tm_init + idx
            loc_template = templates[:, n_template].toarray().reshape(N_e, N_t)
            first_flat = loc_template.T.flatten()
            norm_flat = numpy.sum(first_flat ** 2)

            # For each index (i.e. spike sample location) add the spike to
            # the chunk of data, enforcing the refractory period.
            t_last = - refractory
            for scount, spike in enumerate(spikes):
                if (spike - t_last) > refractory:
                    local_chunk[spike-template_shift:spike+template_shift+1, :] += amplitudes[scount]*loc_template.T
                    amp = numpy.dot(local_chunk[spike-template_shift:spike+template_shift+1, :].flatten(), first_flat)
                    amp /= norm_flat
                    result['real_amps'] += [amp]
                    result['spiketimes'] += [spike + offset]
                    result['amplitudes'] += [(amplitudes[scount], 0)]
                    result['templates'] += [n_template]
                    result['voltages'] += [local_chunk[spike, best_elecs[idx]]]
                    t_last = spike

        # Write the results into the thread/process' files.
        for key, dtype in result_specs:
            result_files[key].write(numpy.array(result[key], dtype=dtype).tobytes())

        # Overwrite the chunk of data in the output file.
        data_file_out.set_data(offset, local_chunk)

    # Close the thread/process' files, forcing the data to disk first.
    for key, _ in result_specs:
        result_file = result_files[key]
        result_file.flush()
        os.fsync(result_file.fileno())
        result_file.close()

    data_file_out.close()
    data_file.close()

    # Synchronize all the threads/processes.
    comm.Barrier()

    ## Eventually, perform all the administrative tasks.
    ## (i.e. files and folders management).

    file_params = file_out + '.params'

    if comm.rank == 0:
        # Create `injected` directory if it does not exist.
        result_path = os.path.join(file_out, 'injected')
        if not os.path.exists(result_path):
            os.makedirs(result_path)

        # Copy initial configuration file from `<dataset1>.params` to
        # `<dataset2>.params`.
        shutil.copy2(params.get('data', 'data_file_noext') + '.params', file_params)
        new_params = CircusParser(file_name)
        # Copy initial basis file from `<dataset1>/<dataset1>.basis.hdf5` to
        # `<dataset2>/injected/<dataset2>.basis.hdf5`.
        shutil.copy2(params.get('data', 'file_out') + '.basis.hdf5',
                     os.path.join(result_path, data_suff + '.basis.hdf5'))

        # Save templates into `<dataset>/<dataset>.templates.hdf5`.
        templates = templates.tocoo()
        # gzip compression is only requested when `hdf5_compress` is set.
        compression = {'compression': 'gzip'} if hdf5_compress else {}
        with h5py.File(os.path.join(file_out, data_suff + '.templates.hdf5'), 'w') as mydata:
            mydata.create_dataset('temp_x', data=templates.row, **compression)
            mydata.create_dataset('temp_y', data=templates.col, **compression)
            mydata.create_dataset('temp_data', data=templates.data, **compression)
            mydata.create_dataset('temp_shape', data=numpy.array([N_e, N_t, templates.shape[1]], dtype=numpy.int32))
            mydata.create_dataset('limits', data=limits)
            mydata.create_dataset('norms', data=norms)

        # Save electrodes into `<dataset>/<dataset>.clusters.hdf5`.
        with h5py.File(os.path.join(file_out, data_suff + '.clusters.hdf5'), 'w') as mydata:
            mydata.create_dataset('electrodes', data=best_elecs)

    comm.Barrier()

    if comm.rank == 0:
        # Gather data from all threads/processes.
        new_params.get_data_file()
        io.collect_data(comm.size, new_params, erase=True, with_real_amps=True, with_voltages=True, benchmark=True)
        # Change some flags in the configuration file.
        new_params.write('whitening', 'temporal', 'False')  # Disable temporal filtering
        new_params.write('whitening', 'spatial', 'False')  # Disable spatial filtering
        new_params.write('data', 'data_dtype', 'float32')  # Set type of the data to float32
        new_params.write('data', 'dtype_offset', 'auto')  # Set padding for data to auto
        # Move results from `<dataset>/<dataset>.result.hdf5` to
        # `<dataset>/injected/<dataset>.result.hdf5`.
        shutil.move(os.path.join(file_out, data_suff + '.result.hdf5'),
                    os.path.join(result_path, data_suff + '.result.hdf5'))

        # Save scalings into `<dataset>/injected/<dataset>.scalings.npy`.
        numpy.save(os.path.join(result_path, data_suff + '.scalings'), scalings)

        # Copy basis from `<dataset>/injected/<dataset>.basis.hdf5` to
        # `<dataset>/<dataset>.basis.hdf5`.
        shutil.copy2(os.path.join(result_path, data_suff + '.basis.hdf5'),
                     os.path.join(file_out, data_suff + '.basis.hdf5'))

        if benchmark not in ['fitting', 'synchrony']:
            # Move templates from `<dataset>/<dataset>.templates.hdf5` to
            # `<dataset>/injected/<dataset>.templates.hdf5`.
            shutil.move(os.path.join(file_out, data_suff + '.templates.hdf5'),
                        os.path.join(result_path, data_suff + '.templates.hdf5'))
def main(params, nb_cpu, nb_gpu, use_gpu, extension):
    """
    Export the sorting results to a `<file_out_suff><extension>.GUI`
    directory readable by the phy GUI.

    Rank 0 writes the whitening matrices, channel map, spikes, templates and
    template similarities; all ranks then optionally compute the PC features
    (`write_pcs`), and a `params.py` describing the raw data is written.

    Arguments
    ---------
    params : parser object giving access to the configuration and data file
    nb_cpu : number of CPUs, only used in the log message
    nb_gpu, use_gpu : unused here, kept for the common `main` interface
    extension : suffix of the result files to export (e.g. '' or '-merged')
    """
    _ = init_logging(params.logfile)
    logger = logging.getLogger('circus.converting')
    data_file = params.data_file
    file_out_suff = params.get('data', 'file_out_suff')
    probe = params.probe
    output_path = params.get('data', 'file_out_suff') + extension + '.GUI'
    N_e = params.getint('data', 'N_e')
    prelabelling = params.getboolean('converting', 'prelabelling')
    N_t = params.getint('detection', 'N_t')
    erase_all = params.getboolean('converting', 'erase_all')
    export_pcs = params.get('converting', 'export_pcs')
    export_all = params.getboolean('converting', 'export_all')
    sparse_export = params.getboolean('converting', 'sparse_export')
    rpv_threshold = params.getfloat('converting', 'rpv_threshold')

    if export_all and not params.getboolean('fitting', 'collect_all'):
        if comm.rank == 0:
            print_and_log([
                'Export unfitted spikes only if [fitting] collect_all is True'
            ], 'error', logger)
        sys.exit(0)

    def generate_mapping(probe):
        """Return the (positions, shanks) arrays of all channels of `probe`."""
        p = {}
        positions = []
        shanks = []
        for key in probe['channel_groups']:
            group = probe['channel_groups'][key]
            p.update(group['geometry'])
            positions += [p[channel] for channel in group['channels']]
            shanks += [key] * len(group['channels'])
        return numpy.array(positions), numpy.array(shanks)

    def get_max_loc_channel(params, extension):
        """Return the maximal number of channels a template can span."""
        if test_if_support(params, extension):
            supports = io.load_data(params, 'supports', extension)
            max_loc_channel = numpy.sum(supports, 1).max()
        else:
            nodes, edges = get_nodes_and_edges(params)
            max_loc_channel = 0
            for key in edges:
                if len(edges[key]) > max_loc_channel:
                    max_loc_channel = len(edges[key])
        return max_loc_channel

    def write_results(path, params, extension):
        """Write spike times, template ids and amplitudes (sorted by time),
        plus `cluster_group.tsv` when prelabelling is enabled.

        `path` is unused: files go to the enclosing `output_path`.
        """
        result = io.get_results(params, extension)
        spikes = [numpy.zeros(0, dtype=numpy.uint64)]
        clusters = [numpy.zeros(0, dtype=numpy.uint32)]
        amplitudes = [numpy.zeros(0, dtype=numpy.double)]
        N_tm = len(result['spiketimes'])

        has_purity = test_if_purity(params, extension)

        if prelabelling:
            labels = []
            norms = io.load_data(params, 'norm-templates', extension)
            # Keep the norms of the central components only.
            norms = norms[:len(norms) // 2]
            if has_purity:
                purity = io.load_data(params, 'purity', extension)

        for key in list(result['spiketimes'].keys()):
            temp_id = int(key.split('_')[-1])
            # Entries are popped to release memory as soon as possible.
            myspikes = result['spiketimes'].pop(key).astype(numpy.uint64)
            spikes.append(myspikes)
            myamplitudes = result['amplitudes'].pop(key).astype(numpy.double)
            amplitudes.append(myamplitudes[:, 0])
            clusters.append(temp_id * numpy.ones(len(myamplitudes), dtype=numpy.uint32))
            rpv = get_rpv(myspikes, params.data_file.sampling_rate)

            if prelabelling:
                if has_purity:
                    if rpv <= rpv_threshold:
                        # Low-purity units with few refractory violations
                        # are deliberately left unlabelled.
                        if purity[temp_id] > 0.75:
                            labels += [[temp_id, 'good']]
                    else:
                        if purity[temp_id] > 0.75:
                            labels += [[temp_id, 'mua']]
                        else:
                            labels += [[temp_id, 'noise']]
                else:
                    median_amp = numpy.median(myamplitudes[:, 0])
                    if rpv <= rpv_threshold and numpy.abs(median_amp - 1) < 0.25:
                        labels += [[temp_id, 'good']]
                    else:
                        if median_amp < 0.5:
                            labels += [[temp_id, 'mua']]
                        elif norms[temp_id] < 0.1:
                            labels += [[temp_id, 'noise']]

        if export_all:
            print_and_log([
                "Last %d templates are unfitted spikes on all electrodes" % N_e
            ], 'info', logger)
            garbage = io.load_data(params, 'garbage', extension)
            for key in list(garbage['gspikes'].keys()):
                elec_id = int(key.split('_')[-1])
                data = garbage['gspikes'].pop(key).astype(numpy.uint64)
                spikes.append(data)
                amplitudes.append(numpy.ones(len(data)))
                # Unfitted spikes get ids after the real templates.
                clusters.append((elec_id + N_tm) * numpy.ones(len(data), dtype=numpy.uint32))

        if prelabelling:
            with open(os.path.join(output_path, 'cluster_group.tsv'), 'w') as f:
                f.write('cluster_id\tgroup\n')
                for l in labels:
                    f.write('%s\t%s\n' % (l[0], l[1]))

        spikes = numpy.concatenate(spikes).astype(numpy.uint64)
        amplitudes = numpy.concatenate(amplitudes).astype(numpy.double)
        clusters = numpy.concatenate(clusters).astype(numpy.uint32)

        idx = numpy.argsort(spikes)

        numpy.save(os.path.join(output_path, 'spike_templates'), clusters[idx])
        numpy.save(os.path.join(output_path, 'spike_times'), spikes[idx])
        numpy.save(os.path.join(output_path, 'amplitudes'), amplitudes[idx])
        return

    def write_templates(path, params, extension):
        """Write `templates.npy` (and `template_ind.npy` for sparse export,
        `cluster_purity.tsv` when purity is available); return the number of
        templates. `path` is unused: files go to `output_path`.
        """
        templates = io.load_data(params, 'templates', extension)
        N_tm = templates.shape[1] // 2
        nodes, edges = get_nodes_and_edges(params)

        if sparse_export:
            # Largest number of non-silent channels over all templates.
            n_channels_max = 0
            for t in range(N_tm):
                data = numpy.sum(
                    numpy.sum(templates[:, t].toarray().reshape(N_e, N_t), 1) != 0)
                if data > n_channels_max:
                    n_channels_max = data
        else:
            n_channels_max = N_e

        nb_rows = N_tm + N_e if export_all else N_tm
        to_write_sparse = numpy.zeros((nb_rows, N_t, n_channels_max), dtype=numpy.float32)
        mapping_sparse = -1 * numpy.ones((nb_rows, n_channels_max), dtype=numpy.int32)

        has_purity = test_if_purity(params, extension)
        if has_purity:
            purity = io.load_data(params, 'purity', extension)
            with open(os.path.join(output_path, 'cluster_purity.tsv'), 'w') as f:
                f.write('cluster_id\tpurity\n')
                for i in range(N_tm):
                    f.write('%d\t%g\n' % (i, purity[i]))

        for t in range(N_tm):
            tmp = templates[:, t].toarray().reshape(N_e, N_t).T
            x, y = tmp.nonzero()
            if len(y) == 0:
                # All-zero template: nothing to write, and `y.max()` below
                # would raise on an empty array.
                continue
            channels = numpy.unique(y)
            nb_loc = len(channels)

            if sparse_export:
                # Map each used channel to a compact column index.
                all_positions = numpy.zeros(y.max() + 1, dtype=numpy.int32)
                all_positions[channels] = numpy.arange(nb_loc, dtype=numpy.int32)
                pos = all_positions[y]
                to_write_sparse[t, x, pos] = tmp[x, y]
                mapping_sparse[t, numpy.arange(nb_loc)] = channels
            else:
                to_write_sparse[t, x, y] = tmp[x, y]

        if export_all:
            garbage = io.load_data(params, 'garbage', extension)
            for t in range(N_tm, N_tm + N_e):
                elec = t - N_tm
                spikes = garbage['gspikes'].pop('elec_%d' % elec).astype(numpy.int64)
                # Average at most 100 randomly chosen unfitted spikes.
                spikes = numpy.random.permutation(spikes)[:100]
                mapping_sparse[t, 0] = t - N_tm
                waveform = io.get_stas(params,
                                       times_i=spikes,
                                       labels_i=numpy.ones(len(spikes)),
                                       src=elec,
                                       neighs=[elec],
                                       nodes=nodes,
                                       mean_mode=True)

                if sparse_export:
                    to_write_sparse[t, :, 0] = waveform
                else:
                    to_write_sparse[t, :, elec] = waveform

        numpy.save(os.path.join(output_path, 'templates'), to_write_sparse)

        if sparse_export:
            numpy.save(os.path.join(output_path, 'template_ind'), mapping_sparse)

        return N_tm

    def write_pcs(path, params, extension, N_tm, mode=0):
        """Compute and write the PC features of the spikes.

        `mode` 0 exports all spikes, `mode` 1 at most 500 random spikes per
        template (their ids are stored in `pc_feature_spike_ids.npy`).
        Templates are distributed over the MPI ranks, which all write into
        the same memory-mapped file. `path` is unused.
        """
        spikes = numpy.load(os.path.join(output_path, 'spike_times.npy'))
        labels = numpy.load(os.path.join(output_path, 'spike_templates.npy'))
        max_loc_channel = get_max_loc_channel(params, extension)
        nb_features = params.getint('whitening', 'output_dim')
        sign_peaks = params.get('detection', 'peaks')
        nodes, edges = get_nodes_and_edges(params)
        N_total = params.getint('data', 'N_total')
        has_support = test_if_support(params, extension)
        if has_support:
            supports = io.load_data(params, 'supports', extension)
        else:
            inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
            inv_nodes[nodes] = numpy.arange(len(nodes))

        nb_templates = N_tm + N_e if export_all else N_tm

        pc_features_ind = numpy.zeros((nb_templates, max_loc_channel), dtype=numpy.int32)
        best_elec = io.load_data(params, 'electrodes', extension)
        if export_all:
            best_elec = numpy.concatenate((best_elec, numpy.arange(N_e)))

        if has_support:
            for count, support in enumerate(supports):
                nb_loc = numpy.sum(support)
                pc_features_ind[count, numpy.arange(nb_loc)] = numpy.where(
                    support == True)[0]
        else:
            for count, elec in enumerate(best_elec):
                nb_loc = len(edges[nodes[elec]])
                pc_features_ind[count, numpy.arange(nb_loc)] = inv_nodes[edges[nodes[elec]]]

        if sign_peaks in ['negative', 'both']:
            basis_proj, basis_rec = io.load_data(params, 'basis')
        elif sign_peaks in ['positive']:
            basis_proj, basis_rec = io.load_data(params, 'basis-pos')

        # Number of exported spikes per template, and resulting row offsets.
        all_offsets = numpy.zeros(nb_templates, dtype=numpy.int32)
        for target in range(nb_templates):
            if mode == 0:
                all_offsets[target] = len(numpy.where(labels == target)[0])
            elif mode == 1:
                all_offsets[target] = min(
                    500, len(numpy.where(labels == target)[0]))

        all_paddings = numpy.concatenate(([0], numpy.cumsum(all_offsets)))
        total_pcs = numpy.sum(all_offsets)

        pc_file = os.path.join(output_path, 'pc_features.npy')
        pc_file_ids = os.path.join(output_path, 'pc_feature_spike_ids.npy')

        from numpy.lib.format import open_memmap

        # Rank 0 creates the files; every rank then re-opens them.
        if comm.rank == 0:
            pc_features = open_memmap(pc_file,
                                      shape=(total_pcs, nb_features, max_loc_channel),
                                      dtype=numpy.float32,
                                      mode='w+')
            if mode == 1:
                pc_ids = open_memmap(pc_file_ids,
                                     shape=(total_pcs, ),
                                     dtype=numpy.int32,
                                     mode='w+')

        comm.Barrier()
        pc_features = open_memmap(pc_file, mode='r+')
        if mode == 1:
            pc_ids = open_memmap(pc_file_ids, mode='r+')

        to_explore = list(range(comm.rank, nb_templates, comm.size))

        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(params, to_explore)

        for target in to_explore:

            count = all_paddings[target]

            if mode == 1:
                idx = numpy.random.permutation(
                    numpy.where(labels == target)[0])[:500]
                pc_ids[count:count + len(idx)] = idx
            elif mode == 0:
                idx = numpy.where(labels == target)[0]

            elec = best_elec[target]

            if has_support:
                if target >= len(supports):
                    # Unfitted-spike pseudo templates live on one electrode.
                    indices = [target - N_tm]
                else:
                    indices = numpy.where(supports[target])[0]
            else:
                indices = inv_nodes[edges[nodes[elec]]]
            labels_i = target * numpy.ones(len(idx))
            times_i = numpy.take(spikes, idx).astype(numpy.int64)
            sub_data = io.get_stas(params, times_i, labels_i, elec,
                                   neighs=indices, nodes=nodes,
                                   auto_align=False)

            pcs = numpy.dot(sub_data, basis_proj)
            pcs = numpy.swapaxes(pcs, 1, 2)
            if mode == 0:
                pc_features[idx, :, :len(indices)] = pcs
            elif mode == 1:
                pc_features[count:count + len(idx), :, :len(indices)] = pcs

        comm.Barrier()

        if comm.rank == 0:
            numpy.save(os.path.join(output_path, 'pc_feature_ind'),
                       pc_features_ind.astype(numpy.uint32))  # n_templates, n_loc_chan

    # Rank 0 decides whether to export and tells the other ranks.
    do_export = True
    if comm.rank == 0:
        if os.path.exists(output_path):
            if not erase_all:
                do_export = query_yes_no(
                    Fore.WHITE +
                    "Export already made! Do you want to erase everything?",
                    default=None)

            if do_export:
                if os.path.exists(os.path.abspath('.phy')):
                    shutil.rmtree(os.path.abspath('.phy'))
                shutil.rmtree(output_path)
        comm.bcast(numpy.array([1 if do_export else 0], dtype=numpy.int32), root=0)
    else:
        do_export = bool(
            comm.bcast(numpy.array([0], dtype=numpy.int32), root=0))

    comm.Barrier()

    if do_export:

        if comm.rank == 0:
            os.makedirs(output_path)
            print_and_log(
                ["Exporting data for the phy GUI with %d CPUs..." % nb_cpu],
                'info', logger)

            if params.getboolean('whitening', 'spatial'):
                whitening_mat = io.load_data(
                    params, 'spatial_whitening').astype(numpy.double)
                numpy.save(os.path.join(output_path, 'whitening_mat'), whitening_mat)
                numpy.save(os.path.join(output_path, 'whitening_mat_inv'),
                           numpy.linalg.inv(whitening_mat))
            else:
                numpy.save(os.path.join(output_path, 'whitening_mat'), numpy.eye(N_e))

            positions, shanks = generate_mapping(probe)
            numpy.save(os.path.join(output_path, 'channel_positions'),
                       positions.astype(numpy.double))
            numpy.save(os.path.join(output_path, 'channel_shanks'),
                       shanks.astype(numpy.double))
            nodes, edges = get_nodes_and_edges(params)
            numpy.save(os.path.join(output_path, 'channel_map'),
                       nodes.astype(numpy.int32))

            write_results(output_path, params, extension)

            N_tm = write_templates(output_path, params, extension)

            template_file = h5py.File(file_out_suff + '.templates%s.hdf5' % extension,
                                      'r', libver='earliest')
            similarities = template_file.get('maxoverlap')[:]
            template_file.close()
            norm = N_e * N_t

            if export_all:
                to_write = numpy.zeros((N_tm + N_e, N_tm + N_e), dtype=numpy.single)
                to_write[:N_tm, :N_tm] = (similarities[:N_tm, :N_tm] / norm).astype(numpy.single)
            else:
                to_write = (similarities[:N_tm, :N_tm] / norm).astype(numpy.single)
            numpy.save(os.path.join(output_path, 'similar_templates'), to_write)

            comm.bcast(numpy.array([N_tm], dtype=numpy.int32), root=0)

        else:
            N_tm = int(comm.bcast(numpy.array([0], dtype=numpy.int32), root=0))

        comm.Barrier()

        # 0: export all PCs, 1: export some PCs, 2: do not export PCs.
        make_pcs = 2
        if comm.rank == 0:

            if export_pcs == 'prompt':
                key = ''
                while key not in ['a', 's', 'n']:
                    print((
                        Fore.WHITE +
                        "Do you want SpyKING CIRCUS to export PCs? (a)ll / (s)ome / (n)o"
                    ))
                    key = input('')
            else:
                key = export_pcs

            if key == 'a':
                make_pcs = 0
            elif key == 's':
                make_pcs = 1

            # Always broadcast exactly once: an unexpected `export_pcs` value
            # is treated as "no PCs" instead of leaving the other ranks
            # blocked in their matching bcast.
            comm.bcast(numpy.array([make_pcs], dtype=numpy.int32), root=0)

            if key == 'n':
                if os.path.exists(os.path.join(output_path, 'pc_features.npy')):
                    os.remove(os.path.join(output_path, 'pc_features.npy'))
                if os.path.exists(os.path.join(output_path, 'pc_feature_ind.npy')):
                    os.remove(os.path.join(output_path, 'pc_feature_ind.npy'))
        else:
            make_pcs = comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)
            make_pcs = make_pcs[0]

        comm.Barrier()
        if make_pcs < 2:
            write_pcs(output_path, params, extension, N_tm, make_pcs)

        supported_by_phy = ['raw_binary', 'mcs_raw_binary', 'mda']
        file_format = data_file.description
        gui_params = {}

        # `dat_path` is written verbatim into params.py (see below), so its
        # value must already be a valid Python expression.
        if file_format in supported_by_phy:
            if not params.getboolean('data', 'overwrite'):
                gui_params['dat_path'] = 'r"%s"' % params.get(
                    'data', 'data_file_no_overwrite')
            else:
                if params.get('data', 'stream_mode') == 'multi-files':
                    data_file = params.get_data_file(source=True,
                                                     has_been_created=False)
                    gui_params['dat_path'] = "["
                    for f in data_file.get_file_names():
                        gui_params['dat_path'] += 'r"%s", ' % f
                    gui_params['dat_path'] += "]"
                else:
                    gui_params['dat_path'] = 'r"%s"' % params.get(
                        'data', 'data_file')
        else:
            # Placeholder for formats phy cannot read directly.
            gui_params['dat_path'] = 'r"giverandomname.dat"'
        gui_params['n_channels_dat'] = params.nb_channels
        gui_params['n_features_per_channel'] = 5
        gui_params['dtype'] = data_file.data_dtype
        if 'data_offset' in data_file.params:
            gui_params['offset'] = data_file.data_offset
        gui_params['sample_rate'] = params.rate
        gui_params['dir_path'] = output_path
        gui_params['hp_filtered'] = True

        with open(os.path.join(output_path, 'params.py'), 'w') as f:
            for key, value in gui_params.items():
                if key in ['dir_path', 'dtype']:
                    f.write('%s = r"%s"\n' % (key, value))
                else:
                    f.write("%s = %s\n" % (key, value))
def _per_channel_mads(chunk, nb_channels):
    """Return the median absolute deviation of the first ``nb_channels`` columns of ``chunk``."""
    mads = numpy.zeros(nb_channels, dtype=numpy.float32)
    for i in range(nb_channels):
        center = numpy.median(chunk[:, i], 0)
        mads[i] = numpy.median(numpy.abs(chunk[:, i] - center), 0)
    return mads


def main(params, nb_cpu, nb_gpu, use_gpu):
    """Estimate detection thresholds, whitening filters and the PCA basis.

    All results are stored in ``<file_out_suff>.basis.hdf5``. The work is done
    in two parts:

    1. Whitening. Every MPI rank loads one randomly chosen chunk. Per-channel
       thresholds (median absolute deviations, averaged over ranks) are
       written first. If enabled in the ``[whitening]`` section, a temporal
       and/or a spatial whitening filter is estimated (away from detected
       peaks when ``ignore_spikes`` is set) and the thresholds are recomputed
       on the whitened data.
    2. Basis. Waveform snippets around detected peaks are collected over the
       chunks, gathered on rank 0, and a PCA basis is fitted for each peak
       polarity. When matched filtering is enabled, matched thresholds are
       computed as well.

    Args:
        params: Project configuration object. It is read through
            ``get``/``getint``/``getfloat``/``getboolean`` and through its
            ``data_file``, ``rate``, ``nb_channels`` and ``logfile``
            attributes.
        nb_cpu: Number of CPUs; only used to pick a GPU device id.
        nb_gpu: Number of GPUs; only used to pick a GPU device id.
        use_gpu: When true, the spatial whitening product is computed with
            ``cudamat``.

    Note:
        The process is terminated with ``sys.exit(0)`` when no waveform at all
        could be collected.
    """
    # Part 1: Whitening
    numpy.random.seed(420)
    _ = init_logging(params.logfile)
    logger = logging.getLogger('circus.whitening')
    #################################################################
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    dist_peaks = params.getint('detection', 'dist_peaks')
    template_shift = params.getint('detection', 'template_shift')
    file_out_suff = params.get('data', 'file_out_suff')
    spike_width = params.getfloat('detection', 'spike_width')
    matched_filter = params.getboolean('detection', 'matched-filter')
    fudge = params.getfloat('whitening', 'fudge')
    sign_peaks = params.get('detection', 'peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    ignore_spikes = params.getboolean('whitening', 'ignore_spikes')
    chunk_size = detect_memory(params, whitening=True)
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('whitening', 'safety_time')
    safety_space = params.getboolean('whitening', 'safety_space')
    sort_waveforms = params.getboolean('whitening', 'sort_waveforms')
    nb_temp_white = min(max(20, comm.size), N_e)
    max_silence_2 = 5000
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    jitter_range = params.getint('detection', 'jitter_range')
    template_shift_2 = template_shift + jitter_range
    use_hanning = params.getboolean('detection', 'hanning')
    rejection_threshold = params.getfloat('detection', 'rejection_threshold')
    noise_window = params.getint('detection', 'noise_time')
    data_file.open()
    #################################################################

    basis_path = file_out_suff + '.basis.hdf5'

    if use_hanning:
        hanning_filter = numpy.hanning(N_t)

    if comm.rank == 0:
        print_and_log(
            ["Analyzing data to get whitening matrices and thresholds..."],
            'default', logger)

    # For every electrode, the (re-indexed) electrodes listed in its edges.
    nodes_indices = {}
    for elec in numpy.arange(N_e):
        nodes_indices[elec] = inv_nodes[edges[nodes[elec]]]

    if use_gpu:
        import cudamat as cmt
        # Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:
        # Fewer chunks than ranks: shrink the chunks so every rank gets one.
        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # Chunks are drawn at random so that signals come from all over the
    # recording. The last chunk index is left out when there are more chunks
    # than ranks.
    if nb_chunks > comm.size:
        all_chunks = numpy.random.permutation(
            numpy.arange(nb_chunks - 1, dtype=numpy.int32))
    else:
        all_chunks = numpy.random.permutation(
            numpy.arange(nb_chunks, dtype=numpy.int32))

    all_electrodes = numpy.random.permutation(N_e)

    # From here on, every rank draws its own random numbers.
    numpy.random.seed(comm.rank)

    for gidx in [all_chunks[comm.rank]]:

        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes)
        local_shape = len(local_chunk)

        # Median absolute deviations on this rank's chunk, averaged on rank 0.
        thresholds = _per_channel_mads(local_chunk, N_e)
        gdata = gather_array(thresholds, comm)
        if comm.rank == 0:
            gdata = gdata.reshape((comm.size, N_e))
            thresholds = numpy.mean(gdata, 0)
            # Mode 'w': the basis file is (re)created here.
            with h5py.File(basis_path, 'w', libver='earliest') as bfile:
                io.write_datasets(bfile, ['thresholds'],
                                  {'thresholds': thresholds},
                                  compression=hdf5_compress)
        comm.Barrier()
        thresholds = io.load_data(params, 'thresholds')

        local_borders = (template_shift, local_shape - template_shift)
        found_peaktimes = []

        if ignore_spikes:
            # Extracting the peaks.
            for i in range(N_e):
                peaktimes = scipy.signal.find_peaks(
                    numpy.abs(local_chunk[:, i]),
                    height=thresholds[i],
                    width=spike_width,
                    wlen=N_t)[0]
                peaktimes = peaktimes.astype(numpy.uint32)

                # Drop the peaks too close to the chunk borders.
                idx = (peaktimes >= local_borders[0]) & (peaktimes < local_borders[1])
                peaktimes = numpy.compress(idx, peaktimes)

                found_peaktimes.append(peaktimes)
        else:
            for i in range(N_e):
                found_peaktimes.append(numpy.zeros(0, dtype=numpy.uint32))

        all_peaktimes = numpy.concatenate(found_peaktimes)
        local_peaktimes = numpy.unique(all_peaktimes)

        if len(local_peaktimes) > 0:

            diff_times = local_peaktimes[-1] - local_peaktimes[0]
            # numpy.bool_ rather than the numpy.bool alias, which was removed
            # in NumPy 1.24.
            all_times = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool_)
            padded_peaks = (local_peaktimes - local_peaktimes[0]).astype(numpy.int32)
            min_times = numpy.maximum(padded_peaks - safety_time, 0)
            max_times = numpy.minimum(padded_peaks + safety_time + 1, diff_times + 1)

            test_extremas = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool_)
            for i in range(N_e):
                test_extremas[i, found_peaktimes[i] - local_peaktimes[0]] = True

            argmax_peak = numpy.random.permutation(numpy.arange(len(local_peaktimes)))
            all_idx = numpy.take(local_peaktimes, argmax_peak)

            # Mask a window of +/- safety_time samples around every peak, on
            # the best electrode only or on all the electrodes in its edges.
            for idx, peak in zip(argmax_peak, all_idx):
                all_elecs = numpy.where(test_extremas[:, peak - local_peaktimes[0]])[0]
                data = local_chunk[peak, all_elecs]
                elec = all_elecs[numpy.argmax(numpy.abs(data))]
                indices = nodes_indices[elec]
                if safety_space:
                    all_times[indices, min_times[idx]:max_times[idx]] = True
                else:
                    all_times[elec, min_times[idx]:max_times[idx]] = True
        else:
            all_times = numpy.zeros((N_e, len(local_chunk)), dtype=numpy.bool_)

    # NOTE(review): when peaks were found, the columns of all_times are
    # relative to local_peaktimes[0], whereas the unmasked indices derived
    # from it below index local_chunk directly. Confirm the offset is intended.

    if do_temporal_whitening:

        local_res_temp = []

        for elec in all_electrodes[numpy.arange(comm.rank, nb_temp_white, comm.size)]:
            res = numpy.zeros((0, N_t), dtype=numpy.float32)
            scount = 0
            indices = nodes_indices[elec]
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0), 0)
            esubset = numpy.where(all_times_elec == False)[0]
            bound = len(esubset) - N_t
            # Stack runs of N_t consecutive unmasked samples, at most
            # max_silence_2 of them.
            while (scount < bound) and (len(res) < max_silence_2):
                myslice = esubset[scount:scount + N_t]
                if numpy.all((myslice - esubset[scount]) == numpy.arange(N_t)):
                    scount += N_t
                    res = numpy.vstack((res, local_chunk[myslice, elec]))
                else:
                    scount += 1
            if len(res) > 5:
                local_res_temp += [numpy.cov(res.T)]

        nb_elecs = numpy.array([len(local_res_temp)], dtype=numpy.float32)
        local_res_temp = numpy.array(local_res_temp, dtype=numpy.float32)
        if len(local_res_temp) == 0:
            local_res_temp = numpy.zeros(0, dtype=numpy.float32)
        else:
            local_res_temp = numpy.sum(local_res_temp, 0)
        all_res_temp = gather_array(local_res_temp.ravel(), comm, 0, 1)
        all_elecs = gather_array(nb_elecs, comm, 0, 1)

    if do_spatial_whitening:

        local_res_spac = numpy.zeros((N_e, N_e), dtype=numpy.float32)
        local_silences = []

        for elec in numpy.arange(comm.rank, N_e, comm.size):
            indices = nodes_indices[elec]
            all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0), 0)
            esubset = numpy.where(all_times_elec == False)[0]
            local_data = local_chunk[esubset][:, indices]
            local_whitening = get_whitening_matrix(
                local_data, fudge=fudge).astype(numpy.float32)
            # Keep only the row of the local matrix that belongs to `elec`.
            pos = numpy.where(elec == indices)[0]
            local_res_spac[elec, indices] = local_whitening[pos]
            local_silences += [len(esubset)]

        all_res_spac = gather_array(local_res_spac.ravel(), comm, 0, 1)
        all_silences = gather_array(
            numpy.array(local_silences, dtype=numpy.int32), comm, 0, 1, 'uint32')

    if comm.rank == 0:

        to_write = {}

        if do_temporal_whitening:
            try:
                nb_silences = numpy.sum(all_elecs > 0)
                all_res_temp = all_res_temp.reshape((nb_silences, N_t**2))
            except Exception:
                # Only reported here; the computation below still goes on
                # with whatever was gathered.
                print_and_log([
                    "No silent periods detected: something wrong with the parameters?"
                ], 'error', logger)
            all_res_temp = numpy.sum(all_res_temp, 0)
            all_res_temp = all_res_temp.reshape((N_t, N_t)) / numpy.sum(all_elecs)
            temporal_whitening = get_whitening_matrix(
                all_res_temp.astype(numpy.double),
                fudge=1e-3)[template_shift].astype(numpy.float32)
            temporal_whitening /= temporal_whitening.sum()
            to_write['temporal'] = temporal_whitening
            have_nans = numpy.sum(numpy.isnan(temporal_whitening))

            if have_nans > 0:
                # Fall back on a filter with a single 1 at index N_t // 2.
                temporal_whitening = numpy.zeros(N_t, dtype=numpy.float32)
                temporal_whitening[N_t // 2] = 1
                to_write['temporal'] = temporal_whitening
                print_and_log(
                    ["Disabling temporal whitening because of NaNs found"],
                    'info', logger)

        if do_spatial_whitening:
            all_res_spac = all_res_spac.reshape(comm.size, N_e, N_e)
            spatial_whitening = numpy.sum(all_res_spac, 0)
            to_write['spatial'] = spatial_whitening

            if ignore_spikes:
                print_and_log([
                    "Found %gs without spikes to compute the whitening matrix..."
                    % (numpy.mean(all_silences) / params.rate)
                ], 'default', logger)
            else:
                print_and_log([
                    "Found %gs to compute the whitening matrix..."
                    % (numpy.mean(all_silences) / params.rate)
                ], 'default', logger)

            have_nans = numpy.sum(numpy.isnan(spatial_whitening))

            if have_nans > 0:
                # Fall back on the identity matrix.
                spatial_whitening = numpy.eye(spatial_whitening.shape[0],
                                              dtype=numpy.float32)
                to_write['spatial'] = spatial_whitening
                print_and_log(
                    ["Disabling spatial whitening because of NaNs found"],
                    'info', logger)

        with h5py.File(basis_path, 'r+', libver='earliest') as bfile:
            io.write_datasets(bfile, list(to_write.keys()), to_write,
                              compression=hdf5_compress)

    comm.Barrier()

    if do_spatial_whitening or do_temporal_whitening:

        if comm.rank == 0:
            print_and_log(
                ["Because of whitening, need to recompute the thresholds..."],
                'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            thresholds = _per_channel_mads(local_chunk, N_e)
            gdata = gather_array(thresholds, comm)
            if comm.rank == 0:
                gdata = gdata.reshape((comm.size, N_e))
                thresholds = numpy.mean(gdata, 0)
                with h5py.File(basis_path, 'r+', libver='earliest') as bfile:
                    # Drop the thresholds written before whitening.
                    bfile.pop('thresholds')
                    io.write_datasets(bfile, ['thresholds'],
                                      {'thresholds': thresholds},
                                      compression=hdf5_compress)
            comm.Barrier()

    # Part 2: Basis
    numpy.random.seed(422)

    SHARED_MEMORY = get_shared_memory_flag(params)
    #################################################################
    alignment = params.getboolean('detection', 'alignment')
    over_factor = params.getint('detection', 'oversampling_factor')
    nb_jitter = params.getint('detection', 'nb_jitter')
    nodes, edges = get_nodes_and_edges(params)
    _, positions = get_nodes_and_positions(params)
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    use_barycenter = params.getboolean('detection', 'use_barycenter')
    if matched_filter:
        chunk_size = detect_memory(params, whitening=True)
    else:
        chunk_size = detect_memory(params)
    safety_time = params.getint('whitening', 'safety_time')
    max_elts_elec = params.getint('whitening', 'max_elts')
    output_dim = params.getfloat('whitening', 'output_dim')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    smoothing_factor = params.getfloat('detection', 'smoothing_factor')
    if sign_peaks == 'both':
        max_elts_elec *= 2
    nb_elts = int(params.getfloat('whitening', 'nb_elts') * N_e * max_elts_elec)

    weird_thresh = params.get('detection', 'weird_thresh')
    if weird_thresh != '':
        ignore_artefacts = True
        weird_thresh = io.load_data(params, 'weird-thresholds')
    else:
        ignore_artefacts = False

    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    if ignore_dead_times:
        if SHARED_MEMORY:
            all_dead_times, mpi_memory_3 = get_dead_times(params)
        else:
            all_dead_times = get_dead_times(params)
    data_file.open()
    #################################################################

    if comm.rank == 0:
        print_and_log(["Searching spikes to construct the PCA basis..."],
                      'default', logger)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    if nb_chunks < comm.size:

        res = io.data_stats(params, show=False)
        chunk_size = int(res * params.rate // comm.size)
        if comm.rank == 0:
            print_and_log(
                ["Too much cores, automatically resizing the data chunks"],
                'debug', logger)

        nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # Number of waveforms collected so far on every electrode.
    groups = {}
    for i in range(N_e):
        groups[i] = 0

    # Chunks are explored in random order so that signals come from all over
    # the recording.
    all_chunks = numpy.random.permutation(numpy.arange(nb_chunks, dtype=numpy.int32))
    # The quotas are shared between the ranks.
    max_elts_elec //= comm.size
    nb_elts //= comm.size

    elt_count_pos = 0
    elt_count_neg = 0

    if sign_peaks in ['positive', 'both']:
        times_pos = numpy.zeros(nb_elts, dtype=numpy.int32)
        electrodes_pos = numpy.zeros(nb_elts, dtype=numpy.int32)
        extremum_pos = numpy.zeros(nb_elts, dtype=numpy.float32)
        elts_pos = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)
    if sign_peaks in ['negative', 'both']:
        times_neg = numpy.zeros(nb_elts, dtype=numpy.int32)
        electrodes_neg = numpy.zeros(nb_elts, dtype=numpy.int32)
        extremum_neg = numpy.zeros(nb_elts, dtype=numpy.float32)
        elts_neg = numpy.zeros((N_t, nb_elts), dtype=numpy.float32)

    thresholds = io.load_data(params, 'thresholds')
    mads = io.load_data(params, 'mads')
    stds = io.load_data(params, 'stds')

    if alignment:
        cdata = numpy.linspace(-jitter_range, +jitter_range, nb_jitter)
        xdata = numpy.arange(-template_shift_2, template_shift_2 + 1)
        xoff = len(cdata) / 2.0
        # Snippets are extended by jitter_range on both sides when aligning.
        snippet_duration = template_shift_2
        m_size = 2 * template_shift_2 + 1
        align_factor = m_size
        local_factors = align_factor * ((smoothing_factor * mads)**2)
    else:
        snippet_duration = template_shift
        xdata = numpy.arange(-template_shift, template_shift + 1)

    if rejection_threshold > 0:
        reject_noise = True
        noise_levels = stds * (2 * noise_window + 1)
    else:
        reject_noise = False

    to_explore = all_chunks[comm.rank::comm.size]

    upper_bounds = max_elts_elec

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(params, to_explore)

    for gidx in to_explore:

        # Once the quota is reached, the remaining chunks are skipped.
        if (elt_count_pos + elt_count_neg) < nb_elts:
            local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            local_borders = (snippet_duration, local_shape - snippet_duration)

            if ignore_dead_times:
                dead_indices = numpy.searchsorted(
                    all_dead_times, [t_offset, t_offset + local_shape])

            # Extracting the peaks.
            found_peaktimes = []
            found_peak_amplitudes = []
            for i in range(N_e):
                height = thresholds[i]
                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(
                        -local_chunk[:, i], height=height, distance=dist_peaks)[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(
                        local_chunk[:, i], height=height, distance=dist_peaks)[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(
                        numpy.abs(local_chunk[:, i]), height=height,
                        distance=dist_peaks)[0]
                else:
                    peaktimes = numpy.empty(0, dtype=numpy.uint32)

                if ignore_artefacts:
                    # Peaks that also cross the artefact threshold are dropped.
                    artetimes = scipy.signal.find_peaks(
                        numpy.abs(local_chunk[:, i]), height=weird_thresh[i])[0]
                    to_keep = numpy.logical_not(numpy.in1d(peaktimes, artetimes))
                    peaktimes = peaktimes[to_keep]

                idx = (peaktimes >= local_borders[0]) & (peaktimes < local_borders[1])
                peaktimes = peaktimes[idx]

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        is_included = numpy.in1d(
                            peaktimes + t_offset,
                            all_dead_times[dead_indices[0]:dead_indices[1]])
                        peaktimes = peaktimes[~is_included]

                peaktimes = peaktimes.astype(numpy.uint32)
                found_peaktimes.append(peaktimes)

                peak_amplitudes = local_chunk[peaktimes, i]
                found_peak_amplitudes.append(peak_amplitudes)

            # Concatenate once, after the loop, for efficiency.
            all_peaktimes = numpy.concatenate(found_peaktimes)
            all_peak_amplitudes = numpy.concatenate(found_peak_amplitudes)
            local_peaktimes, local_indices = numpy.unique(all_peaktimes,
                                                          return_inverse=True)

            if len(local_peaktimes) > 0:

                diff_times = (local_peaktimes[-1] - local_peaktimes[0]) + 1
                all_times = numpy.zeros((N_e, diff_times), dtype=numpy.bool_)
                padded_peaks = (local_peaktimes - local_peaktimes[0]).astype(numpy.int32)
                min_times = numpy.maximum(padded_peaks - safety_time, 0)
                max_times = numpy.minimum(padded_peaks + safety_time + 1,
                                          diff_times + 1)

                test_extremas = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool_)
                for i in range(N_e):
                    test_extremas[i, found_peaktimes[i] - local_peaktimes[0]] = True

                if sort_waveforms:
                    # Consider the peaks by decreasing extremum.
                    order = numpy.argsort(-numpy.abs(all_peak_amplitudes))
                    all_idx = numpy.take(all_peaktimes, order)
                    argmax_peak = local_indices[order]
                else:
                    n_times = len(all_peaktimes)
                    shuffling = numpy.random.permutation(numpy.arange(n_times))
                    all_idx = numpy.take(all_peaktimes, shuffling)
                    argmax_peak = local_indices[shuffling]

                # Selection of the peaks with spatio-temporal masks.
                for midx, peak in zip(argmax_peak, all_idx):
                    if (elt_count_neg + elt_count_pos) == nb_elts:
                        break

                    all_elecs = numpy.where(
                        test_extremas[:, peak - local_peaktimes[0]])[0]
                    data = local_chunk[peak, all_elecs]

                    if sign_peaks == 'negative':
                        if N_e > 1:
                            if use_barycenter:
                                weighed_position = data[:, numpy.newaxis] * positions[all_elecs]
                                barycenter = weighed_position.sum(0) / data.sum()
                                elec = numpy.argmin(
                                    numpy.linalg.norm(
                                        barycenter - positions[all_elecs], axis=1))
                            else:
                                elec = numpy.argmin(data)
                        else:
                            elec = 0
                        negative_peak = True
                    elif sign_peaks == 'positive':
                        if N_e > 1:
                            if use_barycenter:
                                weighed_position = data[:, numpy.newaxis] * positions[all_elecs]
                                barycenter = weighed_position.sum(0) / data.sum()
                                # Closest electrode to the barycenter, as for
                                # negative peaks (argmax would select the
                                # farthest one).
                                elec = numpy.argmin(
                                    numpy.linalg.norm(
                                        barycenter - positions[all_elecs], axis=1))
                            else:
                                elec = numpy.argmax(data)
                        else:
                            elec = 0
                        negative_peak = False
                    elif sign_peaks == 'both':
                        if N_e == 1:
                            if data < 0:
                                negative_peak = True
                            elif data > 0:
                                negative_peak = False
                            elec = 0
                        else:
                            if numpy.abs(numpy.max(data)) > numpy.abs(numpy.min(data)):
                                elec = numpy.argmax(data)
                                negative_peak = False
                            else:
                                elec = numpy.argmin(data)
                                negative_peak = True

                    # From a position within all_elecs to an electrode index.
                    elec = all_elecs[elec]

                    if groups[elec] < upper_bounds:

                        indices = nodes_indices[elec]
                        myslice = all_times[indices, min_times[midx]:max_times[midx]]

                        # Skip the peak if its window overlaps an already
                        # collected waveform on these electrodes.
                        if not myslice.any():

                            sub_mat = local_chunk[peak - snippet_duration:
                                                  peak + snippet_duration + 1, elec]

                            if reject_noise:
                                slice_window = sub_mat[
                                    snippet_duration - noise_window:
                                    snippet_duration + noise_window + 1]
                                value = numpy.linalg.norm(slice_window) / noise_levels[elec]
                                is_noise = value < rejection_threshold
                            else:
                                is_noise = False

                            if not is_noise:

                                extrema = sub_mat[snippet_duration]

                                if alignment:
                                    # Smoothing spline, or an interpolating
                                    # one (s=0) if the smoothed fit fails.
                                    try:
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata, sub_mat,
                                            s=local_factors[elec], k=3)
                                    except Exception:
                                        f = scipy.interpolate.UnivariateSpline(
                                            xdata, sub_mat, k=3, s=0)

                                    if negative_peak:
                                        rmin = (numpy.argmin(f(cdata)) - xoff) / over_factor
                                    else:
                                        rmin = (numpy.argmax(f(cdata)) - xoff) / over_factor

                                    ddata = numpy.linspace(
                                        rmin - template_shift,
                                        rmin + template_shift, N_t)

                                    # The spline fitted above is reused: a
                                    # new fit with the very same arguments
                                    # would be identical.
                                    sub_mat = f(ddata).astype(numpy.float32)

                                if negative_peak:
                                    times_neg[elt_count_neg] = peak + t_offset
                                    electrodes_neg[elt_count_neg] = elec
                                    extremum_neg[elt_count_neg] = extrema
                                    elts_neg[:, elt_count_neg] = sub_mat
                                    elt_count_neg += 1
                                else:
                                    times_pos[elt_count_pos] = peak + t_offset
                                    electrodes_pos[elt_count_pos] = elec
                                    extremum_pos[elt_count_pos] = extrema
                                    elts_pos[:, elt_count_pos] = sub_mat
                                    elt_count_pos += 1

                                groups[elec] += 1
                                all_times[indices, min_times[midx]:max_times[midx]] = True
                                test_extremas[elec, peak - local_peaktimes[0]] = False

    sys.stderr.flush()

    print_and_log([
        "Node %d has collected %d waveforms" %
        (comm.rank, elt_count_pos + elt_count_neg)
    ], 'debug', logger)

    if sign_peaks in ['negative', 'both']:
        times_neg = gather_array(times_neg[:elt_count_neg], comm, 0, 1, dtype='int32')
        electrodes_neg = gather_array(electrodes_neg[:elt_count_neg], comm, 0, 1,
                                      dtype='int32')
        extremum_neg = gather_array(extremum_neg[:elt_count_neg], comm, 0, 1)
        gdata_neg = gather_array(elts_neg[:, :elt_count_neg].T, comm, 0, 1)
    if sign_peaks in ['positive', 'both']:
        times_pos = gather_array(times_pos[:elt_count_pos], comm, 0, 1, dtype='int32')
        electrodes_pos = gather_array(electrodes_pos[:elt_count_pos], comm, 0, 1,
                                      dtype='int32')
        extremum_pos = gather_array(extremum_pos[:elt_count_pos], comm, 0, 1)
        gdata_pos = gather_array(elts_pos[:, :elt_count_pos].T, comm, 0, 1)

    nb_waveforms = 0

    if comm.rank == 0:
        if sign_peaks in ['negative', 'both']:
            nb_waveforms += gdata_neg.shape[0]
        if sign_peaks in ['positive', 'both']:
            nb_waveforms += gdata_pos.shape[0]

    # Share the total, so that all the ranks take the same exit decision.
    nb_waveforms = all_gather_array(
        numpy.array([nb_waveforms], dtype=numpy.float32), comm, 0)[0]

    if comm.rank == 0:
        print_and_log([
            "Found %d waveforms over %d requested" %
            (nb_waveforms, int(nb_elts * comm.size))
        ], 'default', logger)

        if nb_waveforms == 0:
            print_and_log(
                ['No waveforms found! Are the data properly loaded??'],
                'error', logger)

    if nb_waveforms == 0:
        sys.exit(0)

    if comm.rank == 0:
        # Do the PCA on the gathered waveforms and store the basis obtained.
        res = {}
        pca_pos = None
        pca_neg = None
        if sign_peaks in ['negative', 'both']:
            res['times'] = times_neg
            res['electrodes'] = electrodes_neg
            res['extremum'] = extremum_neg
            if len(gdata_neg) > 0:
                pca = PCA(output_dim)
                if use_hanning:
                    pca.fit(gdata_neg * hanning_filter)
                else:
                    pca.fit(gdata_neg)
                res['proj'] = pca.components_.T.astype(numpy.float32)
                pca_neg = numpy.sum(pca.explained_variance_ratio_)
            else:
                res['proj'] = numpy.identity(int(output_dim), dtype=numpy.float32)
            res['rec'] = res['proj'].T
            res['waveform'] = numpy.median(gdata_neg, 0)
            # Keep a random subset of at most 2500 waveforms.
            idx = numpy.random.permutation(numpy.arange(gdata_neg.shape[0]))[:2500]
            res['waveforms'] = gdata_neg[idx, :]
        if sign_peaks in ['positive', 'both']:
            res['times_pos'] = times_pos
            res['electrodes_pos'] = electrodes_pos
            res['extremum_pos'] = extremum_pos
            if len(gdata_pos) > 0:
                pca = PCA(output_dim)
                if use_hanning:
                    pca.fit(gdata_pos * hanning_filter)
                else:
                    pca.fit(gdata_pos)
                res['proj_pos'] = pca.components_.T.astype(numpy.float32)
                pca_pos = numpy.sum(pca.explained_variance_ratio_)
            else:
                res['proj_pos'] = numpy.identity(int(output_dim), dtype=numpy.float32)
            res['rec_pos'] = res['proj_pos'].T
            res['waveform_pos'] = numpy.median(gdata_pos, 0)
            idx = numpy.random.permutation(numpy.arange(gdata_pos.shape[0]))[:2500]
            res['waveforms_pos'] = gdata_pos[idx, :]

        with h5py.File(basis_path, 'r+', libver='earliest') as bfile:
            io.write_datasets(bfile, list(res.keys()), res,
                              compression=hdf5_compress)

        if sign_peaks == 'positive':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj_pos'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'negative':
            print_and_log([
                "A basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'info', logger)
        elif sign_peaks == 'both':
            print_and_log([
                "Two basis with %s dimensions has been built" %
                res['proj'].shape[1]
            ], 'debug', logger)
        if pca_pos is not None:
            print_and_log([
                "The percentage of variance explained is %s for positive spikes"
                % pca_pos
            ], 'debug', logger)
        if pca_neg is not None:
            print_and_log([
                "The percentage of variance explained is %s for negative spikes"
                % pca_neg
            ], 'debug', logger)

    comm.Barrier()

    if matched_filter:

        if comm.rank == 0:
            print_and_log([
                "Because of matched filters, need to recompute the thresholds..."
            ], 'default', logger)

        if do_spatial_whitening:
            spatial_whitening = io.load_data(params, 'spatial_whitening')
            if use_gpu:
                spatial_whitening = cmt.CUDAMatrix(spatial_whitening,
                                                   copy_on_host=False)
        if do_temporal_whitening:
            temporal_whitening = io.load_data(params, 'temporal_whitening')

        # The stored median waveforms, time-reversed and normalised, are used
        # as convolution kernels.
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg))
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos))

        for gidx in [all_chunks[comm.rank]]:
            local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes)
            local_shape = len(local_chunk)

            if do_spatial_whitening:
                if use_gpu:
                    local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                    local_chunk = local_chunk.dot(spatial_whitening).asarray()
                else:
                    local_chunk = numpy.dot(local_chunk, spatial_whitening)
            if do_temporal_whitening:
                local_chunk = scipy.ndimage.convolve1d(
                    local_chunk, temporal_whitening, axis=0, mode='constant')

            # Express the signal in units of the current thresholds.
            local_chunk /= thresholds

            if sign_peaks in ['negative', 'both']:
                tmp_chunk = scipy.ndimage.convolve1d(
                    local_chunk, waveform_neg, axis=0, mode='constant')
                thresholds = _per_channel_mads(tmp_chunk, N_e)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    with h5py.File(basis_path, 'r+', libver='earliest') as bfile:
                        io.write_datasets(bfile, ['matched_thresholds'],
                                          {'matched_thresholds': thresholds},
                                          compression=hdf5_compress)
                comm.Barrier()

            if sign_peaks in ['positive', 'both']:
                tmp_chunk = scipy.ndimage.convolve1d(
                    local_chunk, waveform_pos, axis=0, mode='constant')
                thresholds = _per_channel_mads(tmp_chunk, N_e)
                gdata = gather_array(thresholds, comm)
                if comm.rank == 0:
                    gdata = gdata.reshape((comm.size, N_e))
                    thresholds = numpy.mean(gdata, 0)
                    with h5py.File(basis_path, 'r+', libver='earliest') as bfile:
                        io.write_datasets(bfile, ['matched_thresholds_pos'],
                                          {'matched_thresholds_pos': thresholds},
                                          compression=hdf5_compress)
                comm.Barrier()

    data_file.close()

    if SHARED_MEMORY and ignore_dead_times:
        mpi_memory_3.Free()
def main(params, nb_cpu, nb_gpu, use_gpu):
    """Run the filtering step of the pipeline.

    Depending on the parameter file, this filters every chunk of the
    recording, subtracts the median computed over all channels, and
    subtracts averaged stimulation artefacts. The ``noedits`` section is
    read to skip operations already performed, and is updated (by rank 0)
    once an operation has been completed.

    Parameters
    ----------
    params : parser object
        Project parameter object; queried through ``get``, ``getint``,
        ``getfloat``, ``getboolean``, ``write``, ``get_data_file`` and the
        ``rate``, ``nb_channels`` and ``logfile`` attributes.
    nb_cpu, nb_gpu, use_gpu
        Not used by this step; kept so that the signature matches the other
        pipeline steps.

    Notes
    -----
    Calls ``sys.exit(1)`` when the cut-off frequencies or the trigger files
    are invalid.
    """
    # init_logging is called for its side effect; the step uses its own
    # named logger.
    _ = init_logging(params.logfile)
    logger = logging.getLogger('circus.filtering')
    #################################################################
    do_filter = params.getboolean('filtering', 'filter')
    filter_done = params.getboolean('noedits', 'filter_done')
    artefacts_done = params.getboolean('noedits', 'artefacts_done')
    median_done = params.getboolean('noedits', 'median_done')
    clean_artefact = params.getboolean('triggers', 'clean_artefact')
    remove_median = params.getboolean('filtering', 'remove_median')
    nodes, edges = get_nodes_and_edges(params)
    #################################################################

    def filter_file(data_file_in, data_file_out, do_filtering,
                    do_remove_median):
        """Filter ``data_file_in`` chunk by chunk and write the result into
        ``data_file_out``.

        Chunks are distributed over the MPI ranks in a round-robin fashion.
        ``do_filtering`` enables the per-channel filter, ``do_remove_median``
        enables the subtraction of the median over channels.
        """
        # 'cut_off' is either a single number (low cut, the high cut is then
        # automatic) or a "low, high" pair where high may be 'auto'.
        try:
            cut_off = params.getfloat('filtering', 'cut_off')
            cut_off = [cut_off, 0.95 * (params.rate / 2.)]
        except Exception:
            cut_off = params.get('filtering', 'cut_off')
            cut_off = cut_off.split(',')
            try:
                cut_off[0] = float(cut_off[0])
            except Exception:
                if comm.rank == 0:
                    print_and_log(
                        ['First value of cut off must be a valid number'],
                        'error', logger)
                sys.exit(1)

            cut_off[1] = cut_off[1].replace(' ', '')
            if cut_off[1] == 'auto':
                cut_off[1] = 0.95 * (params.rate / 2.)
            else:
                try:
                    cut_off[1] = float(cut_off[1])
                except Exception:
                    if comm.rank == 0:
                        print_and_log([
                            'Second value of cut off must either auto, or a valid a number'
                        ], 'error', logger)
                    sys.exit(1)

        chunk_size = params.getint('data', 'chunk_size')
        nb_chunks, _ = data_file_in.analyze(chunk_size)

        # Hard-coded choice of the filter: Butterworth if do_butter, lowess
        # if do_lowess, wavelet (WMLDR) when both are False.
        do_butter = False
        do_lowess = False
        if do_butter:
            b, a = signal.butter(3,
                                 numpy.array(cut_off) / (params.rate / 2.),
                                 'pass')
        N_total = params.nb_channels

        if comm.rank == 0:
            to_write = []
            if do_filtering:
                # Report the filter that is really applied: the cut-off
                # frequencies are only used by the Butterworth filter.
                if do_butter:
                    to_write += [
                        "Filtering the signal with a Butterworth filter in (%g, %g) Hz"
                        % (cut_off[0], cut_off[1])
                    ]
                elif do_lowess:
                    to_write += ["Filtering the signal with a lowess filter"]
                else:
                    to_write += [
                        "Filtering the signal with a wavelet high-pass filter"
                    ]
            if do_remove_median:
                to_write += [
                    "Median over all channels is substracted to each channels"
                ]
            print_and_log(to_write, 'default', logger)

        to_explore = xrange(comm.rank, nb_chunks, comm.size)
        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(to_explore)

        for gidx in to_explore:
            local_chunk, t_offset = data_file_in.get_data(gidx, chunk_size)

            if do_filtering:
                for i in nodes:
                    if do_butter:
                        local_chunk[:, i] = signal.filtfilt(
                            b, a, local_chunk[:, i])
                    elif do_lowess:
                        local_chunk[:, i] = lowess(local_chunk[:, i])
                    else:
                        local_chunk[:, i] = WMLDR(local_chunk[:, i])
                    # Re-center every filtered channel on its median.
                    local_chunk[:, i] -= numpy.median(local_chunk[:, i])

            if do_remove_median:
                # Only the channels listed in nodes contribute to the median.
                if not numpy.all(nodes == numpy.arange(N_total)):
                    global_median = numpy.median(
                        numpy.take(local_chunk, nodes, axis=1), 1)
                else:
                    global_median = numpy.median(local_chunk, 1)
                for i in nodes:
                    local_chunk[:, i] -= global_median

            # When writing into a distinct output file, the offset of the
            # chunk flagged by is_first_chunk is made relative to the start
            # of the input (or of its stream).
            if data_file_in != data_file_out and data_file_in.is_first_chunk(
                    gidx, nb_chunks):
                if data_file_in.is_stream:
                    stream_idx = data_file_in._get_streams_index_by_time(
                        t_offset)
                    g_offset = t_offset - numpy.sum(
                        data_file_in._times[:stream_idx + 1])
                else:
                    g_offset = t_offset - data_file_in.t_start
            else:
                g_offset = t_offset

            data_file_out.set_data(g_offset, local_chunk)

        comm.Barrier()

    def lowess(data, frac=0.01):
        """Local regression (lowess) filter for spike extraction.

        The lowess fit of every row of ``data`` is subtracted from that row.
        Dependency: ``pip install cylowess``.
        """
        import cylowess
        data = numpy.atleast_2d(data)
        nt = data.shape[1]
        abscissa = numpy.arange(nt, dtype='float')
        delta = 67.61
        # Filter data in place, iterating over channels in rows.
        for chani in range(len(data)):
            fit = cylowess.lowess(numpy.asarray(data[chani], dtype='float'),
                                  abscissa,
                                  frac=frac,
                                  it=0,
                                  delta=delta)[:, 1]
            data[chani] = data[chani] - fit
        return data

    def WMLDR(data, wname="db4", maxlevel=6, mode='sym'):
        """Function by Martin Spacek from https://github.com/spyke/spyke

        Perform wavelet multi-level decomposition and reconstruction (WMLDR)
        on multichannel data. See Wiltschko2008. Default to Daubechies(4)
        wavelet. Modifies data in-place, at least for now. The effective
        cutoff frequency is:

            fc = (sampfreq / 2) / 2**maxlevel     (Wiltschko2008)

        For sampfreq of 25 kHz and maxlevel of 6, the effective cutoff
        frequency is 195 Hz. For sampfreq of 30 kHz and maxlevel of 6, the
        effective cutoff frequency is 234 Hz.

        TODO: for now, this only returns highpass data. In the future, this
        probably should return both low and highpass data (and not modify it
        in-place). The Discussion in Wiltschko2008 suggests that this approach
        cannot be used to extract the LFP, but I don't see why you can't
        simply subtract the highpass data from the raw data to get the
        lowpass data.

        Signal extension modes (from PyWavelets docs):

        PyWavelets provides several methods of signal extrapolation that can
        be used to minimize edge effects. PyWavelet's default is 'sym':

        zpd - zero-padding - signal is extended by adding zero samples:
        ... 0  0 | x1 x2 ... xn | 0  0 ...

        cpd - constant-padding - border values are replicated:
        ... x1 x1 | x1 x2 ... xn | xn xn ...

        sym - symmetric-padding - signal is extended by mirroring samples:
        ... x2 x1 | x1 x2 ... xn | xn xn-1 ...

        ppd - periodic-padding - signal is treated as a periodic one:
        ... xn-1 xn | x1 x2 ... xn | x1 x2 ...

        sp1 - smooth-padding - signal is extended according to the first
        derivatives calculated on the edges (straight line)

        DWT performed for these extension modes is slightly redundant, but
        ensures perfect reconstruction. To receive the smallest possible
        number of coefficients, computations can be performed with the
        periodization mode:

        per - periodization - is like periodic-padding but gives the smallest
        possible number of decomposition coefficients. IDWT must be performed
        with the same mode.
        """
        import pywt

        data = numpy.atleast_2d(data)
        nt = data.shape[1]
        # Reconstructed signals always seem to have an even number of points.
        # If the number of input data points is odd, trim the last data point
        # from the reconstructed signal.
        isodd = nt % 2
        # Filter data in place, iterating over channels in rows.
        for chani in range(len(data)):
            # Decompose the signal.
            cs = pywt.wavedec(data[chani], wname, mode=mode, level=maxlevel)
            # Destroy the approximation coefficients to get highpass data.
            # PyWavelets documents zero arrays (not None entries) as the
            # supported way to omit a set of coefficients.
            cs[0] = numpy.zeros_like(cs[0])
            # Reconstruct the signal.
            recsignal = pywt.waverec(cs, wname, mode=mode)
            ntrec = len(recsignal)
            data[chani] = recsignal[:ntrec - isodd]

        return data

    def load_triggers(data_file):
        """Load and validate the trigger files used by both artefact passes.

        Returns ``(artefacts, windows, nb_stimuli)`` where ``artefacts`` holds
        one (label, time) row per stimulation and ``windows`` one
        (label, duration) row per stimulus, both as int64 arrays. When
        ``trig_in_ms`` is set, times and durations are scaled by the integer
        number of samples per ms of ``data_file``. Exits if the number of
        distinct labels does not match the number of windows.
        """
        # The files are kept as floats until after the ms conversion, so that
        # both passes derive exactly the same artefact times.
        artefacts = numpy.loadtxt(params.get('triggers', 'trig_file'))
        windows = numpy.loadtxt(params.get('triggers', 'trig_windows'))

        # Single-row files are loaded as 1D arrays.
        if len(windows.shape) == 1:
            windows = windows.reshape(1, 2)
        if len(artefacts.shape) == 1:
            artefacts = artefacts.reshape(1, 2)

        if params.getboolean('triggers', 'trig_in_ms'):
            if comm.rank == 0:
                print_and_log(['Artefact times are read in ms'], 'debug',
                              logger)
            steps_per_ms = numpy.int64(data_file.sampling_rate * 1e-3)
            artefacts[:, 1] *= steps_per_ms
            windows[:, 1] *= steps_per_ms
        else:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in timesteps'],
                              'debug', logger)

        artefacts = artefacts.astype(numpy.int64)
        windows = windows.astype(numpy.int64)

        nb_stimuli = len(numpy.unique(artefacts[:, 0]))
        if nb_stimuli != len(windows):
            if comm.rank == 0:
                print_and_log(['Error in the trigger files'], 'error', logger)
            sys.exit(1)

        return artefacts, windows, nb_stimuli

    def compute_artefacts(data_file):
        """Compute the averaged artefact of every stimulus.

        Stimulus labels are distributed over the MPI ranks. Returns a dict
        mapping each local label to the value returned by ``get_artefact``.
        Exits if two of the selected stimulation times of a label are closer
        than its window.
        """
        make_plots = params.get('triggers', 'make_plots')
        plot_path = os.path.join(params.get('data', 'data_file_noext'),
                                 'plots')

        artefacts, windows, nb_stimuli = load_triggers(data_file)

        all_labels = artefacts[:, 0].astype(numpy.int32)
        all_times = artefacts[:, 1].astype(numpy.int32)

        # Keep only the stimulations for which the longest window ends before
        # t_stop.
        mask = (all_times >= 0) & (
            all_times + numpy.max(windows[:, 1]) < data_file.t_stop)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        local_labels = numpy.unique(all_labels)[comm.rank::comm.size]

        if comm.rank == 0:
            to_write = [
                "Computing averaged artefacts from %d stimuli" % (nb_stimuli)
            ]
            print_and_log(to_write, 'default', logger)
            if not os.path.exists(plot_path):
                os.makedirs(plot_path)
            local_labels = get_tqdm_progressbar(local_labels)

        comm.Barrier()
        # First we need to get the average artefacts.
        art_dict = {}
        for artefact in local_labels:
            indices = numpy.where(all_labels == artefact)[0].astype(
                numpy.int32)
            tmp = numpy.where(windows[:, 0] == artefact)[0]
            tau = numpy.int64(windows[tmp, 1])
            pspikes = all_times[indices]
            # At most 500 randomly chosen occurrences are averaged.
            times = numpy.sort(numpy.random.permutation(pspikes)[:500])
            if len(numpy.where(numpy.diff(times) < tau)[0]) > 0:
                if comm.rank == 0:
                    print_and_log([
                        'Stimulation times for artefact %d are too close!' %
                        artefact
                    ], 'error', logger)
                sys.exit(1)

            art_dict[artefact] = get_artefact(params, times, tau, nodes)
            if make_plots not in ['None', '']:
                save = [plot_path, '%d.%s' % (artefact, make_plots)]
                plot.view_artefact(art_dict[artefact], save=save)

        return art_dict

    def remove_artefacts(data_file, art_dict):
        """Subtract the averaged artefacts of ``art_dict`` from ``data_file``.

        Every rank handles the stimulations of its own labels (the same
        round-robin split as in ``compute_artefacts``) and writes the
        corrected snippets back into ``data_file``.
        """
        artefacts, windows, nb_stimuli = load_triggers(data_file)

        all_labels = artefacts[:, 0].astype(numpy.int32)
        all_times = artefacts[:, 1].astype(numpy.int32)
        local_labels = numpy.unique(all_labels)[comm.rank::comm.size]

        mask = numpy.in1d(all_labels, local_labels)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        mask = (all_times >= 0) & (all_times < data_file.t_stop)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        if comm.rank == 0:
            to_write = ["Removing artefacts from %d stimuli" % (nb_stimuli)]
            print_and_log(to_write, 'default', logger)
            all_times = get_tqdm_progressbar(all_times)

        comm.Barrier()

        for label, t_art in zip(all_labels, all_times):
            tmp = numpy.where(windows[:, 0] == label)[0][0]
            tau = numpy.int64(windows[tmp, 1])

            # Shorten the window when it would run past the end of the data.
            if (data_file.t_stop - t_art) < tau:
                tau = data_file.t_stop - t_art

            local_chunk = data_file.get_snippet(t_art, tau)
            for idx, i in enumerate(nodes):
                local_chunk[:, i] -= art_dict[label][idx, :tau]
            data_file.set_data(t_art, local_chunk)

        comm.Barrier()

    if comm.rank == 0:
        print_and_log(['Initializing the filtering step...'], 'debug', logger)

    overwrite = params.getboolean('data', 'overwrite')

    if overwrite:
        if comm.rank == 0:
            print_and_log(['Reading the input file...'], 'debug', logger)
        data_file_in = params.get_data_file()
        data_file_out = data_file_in
    else:
        if comm.rank == 0:
            print_and_log(
                ['Overwrite is set to False, so creating a new datafile...'],
                'debug', logger)

        if comm.rank == 0:
            print_and_log(['Reading the input file...'], 'debug', logger)

        data_file_in = params.get_data_file(source=True,
                                            has_been_created=False)

        if comm.rank == 0:
            print_and_log(
                ['Reading the output file and allocating ressources...'],
                'debug', logger)

        # The output file reuses the description of the input, stored as
        # float32 without any offset.
        description = data_file_in.get_description()
        description['data_dtype'] = 'float32'
        description['dtype_offset'] = 0
        description['data_offset'] = 0

        data_file_out = params.get_data_file(is_empty=True,
                                             params=description)
        data_file_out.allocate(shape=data_file_in.shape)

    if clean_artefact:
        if not (os.path.exists(params.get('triggers', 'trig_file')) and
                os.path.exists(params.get('triggers', 'trig_windows'))):
            if comm.rank == 0:
                print_and_log(
                    ['trig_file or trig_windows file can not be found'],
                    'error', logger)
            sys.exit(1)

    to_write = []

    if do_filter and filter_done:
        do_filter = False
        to_write += ["Filtering has already been done"]
    if remove_median and median_done:
        remove_median = False
        to_write += [
            "Median over all channels has already been substracted to each channels"
        ]

    if comm.rank == 0:
        print_and_log(to_write, 'debug', logger)

    if overwrite:
        data_file_in.open(mode='r+')
    else:
        data_file_in.open()
        data_file_out.open(mode='r+')

    if do_filter or remove_median:
        filter_file(data_file_in, data_file_out, do_filter, remove_median)

    # Record what has been done, so that a later run does not redo it.
    if comm.rank == 0:
        if do_filter:
            params.write('noedits', 'filter_done', 'True')
        if remove_median:
            params.write('noedits', 'median_done', 'True')

    if clean_artefact and artefacts_done:
        clean_artefact = False
        if comm.rank == 0:
            print_and_log(['Artefacts have already been removed'], 'debug',
                          logger)

    if clean_artefact:
        art_dict = compute_artefacts(data_file_in)
        remove_artefacts(data_file_out, art_dict)

    if comm.rank == 0:
        if clean_artefact:
            params.write('noedits', 'artefacts_done', 'True')

    data_file_in.close()
    if not overwrite:
        data_file_out.close()

    comm.Barrier()
def main(params, nb_cpu, nb_gpu, use_gpu): # Part 1: Whitening numpy.random.seed(420) #params = detect_memory(params) logger = init_logging(params.logfile) logger = logging.getLogger('circus.whitening') ################################################################# data_file = params.data_file N_e = params.getint('data', 'N_e') hdf5_compress = params.getboolean('data', 'hdf5_compress') N_total = params.nb_channels N_t = params.getint('detection', 'N_t') dist_peaks = params.getint('detection', 'dist_peaks') template_shift = params.getint('detection', 'template_shift') file_out_suff = params.get('data', 'file_out_suff') spike_thresh = params.getfloat('detection', 'spike_thresh') spike_width = params.getfloat('detection', 'spike_width') matched_filter = params.getboolean('detection', 'matched-filter') matched_thresh = params.getfloat('detection', 'matched_thresh') sign_peaks = params.get('detection', 'peaks') do_temporal_whitening = params.getboolean('whitening', 'temporal') do_spatial_whitening = params.getboolean('whitening', 'spatial') chunk_size = detect_memory(params, whitening=True) plot_path = os.path.join(params.get('data', 'file_out_suff'), 'plots') nodes, edges = get_nodes_and_edges(params) safety_time = params.getint('whitening', 'safety_time') safety_space = params.getboolean('whitening', 'safety_space') nb_temp_white = min(max(20, comm.size), N_e) max_silence_1 = int(20 * params.rate // comm.size) max_silence_2 = 5000 inv_nodes = numpy.zeros(N_total, dtype=numpy.int32) inv_nodes[nodes] = numpy.arange(len(nodes)) jitter_range = params.getint('detection', 'jitter_range') template_shift_2 = template_shift + jitter_range use_hanning = params.getboolean('detection', 'hanning') data_file.open() ################################################################# if use_hanning: hanning_filter = numpy.hanning(N_t) if comm.rank == 0: print_and_log( ["Analyzing data to get whitening matrices and thresholds..."], 'default', logger) if use_gpu: import cudamat as cmt 
## Need to properly handle multi GPU per MPI nodes? if nb_gpu > nb_cpu: gpu_id = int(comm.rank // nb_cpu) else: gpu_id = 0 cmt.cuda_set_device(gpu_id) cmt.init() cmt.cuda_sync_threads() nb_chunks, last_chunk_len = data_file.analyze(chunk_size) if nb_chunks < comm.size: res = io.data_stats(params, show=False) chunk_size = int(res * params.rate // comm.size) if comm.rank == 0: print_and_log( ["Too much cores, automatically resizing the data chunks"], 'debug', logger) nb_chunks, last_chunk_len = data_file.analyze(chunk_size) # I guess this is more relevant, to take signals from all over the recordings all_chunks = numpy.random.permutation( numpy.arange(nb_chunks, dtype=numpy.int32)) all_electrodes = numpy.random.permutation(N_e) for gidx in [all_chunks[comm.rank]]: #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..." local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes) local_shape = len(local_chunk) #print "Node", comm.rank, "computes the median absolute deviations in a random chunk" thresholds = numpy.zeros(N_e, dtype=numpy.float32) for i in xrange(N_e): u = numpy.median(local_chunk[:, i], 0) thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u), 0) gdata = gather_array(thresholds, comm) if comm.rank == 0: gdata = gdata.reshape((comm.size, N_e)) thresholds = numpy.mean(gdata, 0) bfile = h5py.File(file_out_suff + '.basis.hdf5', 'w', libver='earliest') io.write_datasets(bfile, ['thresholds'], {'thresholds': thresholds}, compression=hdf5_compress) bfile.close() comm.Barrier() thresholds = io.load_data(params, 'thresholds') #print "Extracting the peaks..." local_peaktimes = numpy.zeros(0, dtype=numpy.uint32) for i in xrange(N_e): peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]), height=thresholds[i], width=spike_width, wlen=N_t)[0] local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes)) local_peaktimes = numpy.unique(local_peaktimes) #print "Removing the useless borders..." 
local_borders = (template_shift, local_shape - template_shift) idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1]) local_peaktimes = numpy.compress(idx, local_peaktimes) if len(local_peaktimes) > 0: diff_times = local_peaktimes[-1] - local_peaktimes[0] all_times = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool) min_times = numpy.maximum( local_peaktimes - local_peaktimes[0] - safety_time, 0) max_times = numpy.minimum( local_peaktimes - local_peaktimes[0] + safety_time + 1, diff_times) argmax_peak = numpy.random.permutation( numpy.arange(len(local_peaktimes))) all_idx = numpy.take(local_peaktimes, argmax_peak) #print "Selection of the peaks with spatio-temporal masks..." for idx, peak in zip(argmax_peak, all_idx): elec = numpy.argmax(numpy.abs(local_chunk[peak])) indices = numpy.take(inv_nodes, edges[nodes[elec]]) if safety_space: all_times[indices, min_times[idx]:max_times[idx]] = True else: all_times[elec, min_times[idx]:max_times[idx]] = True else: all_times = numpy.zeros((N_e, len(local_chunk)), dtype=numpy.bool) if do_temporal_whitening: local_res_temp = [] for elec in all_electrodes[numpy.arange(comm.rank, nb_temp_white, comm.size)]: res = numpy.zeros((0, N_t), dtype=numpy.float32) scount = 0 indices = numpy.take(inv_nodes, edges[nodes[elec]]) all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0), 0) esubset = numpy.where(all_times_elec == False)[0] bound = len(esubset) - N_t while (scount < bound) and (len(res) < max_silence_2): myslice = esubset[scount:scount + N_t] if numpy.all((myslice - esubset[scount]) == numpy.arange(N_t)): scount += N_t res = numpy.vstack((res, local_chunk[myslice, elec])) else: scount += 1 if len(res) > 5: local_res_temp += [numpy.cov(res.T)] nb_elecs = numpy.array([len(local_res_temp)], dtype=numpy.float32) local_res_temp = numpy.array(local_res_temp, dtype=numpy.float32) if len(local_res_temp) == 0: local_res_temp = numpy.zeros(0, dtype=numpy.float32) else: local_res_temp = 
numpy.sum(local_res_temp, 0) all_res_temp = gather_array(local_res_temp.ravel(), comm, 0, 1) all_elecs = gather_array(nb_elecs, comm, 0, 1) if do_spatial_whitening: local_res_spac = numpy.zeros((N_e, N_e), dtype=numpy.float32) local_silences = [] for elec in numpy.arange(comm.rank, N_e, comm.size): indices = numpy.take(inv_nodes, edges[nodes[elec]]) all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0), 0) esubset = numpy.where(all_times_elec == False)[0] local_data = local_chunk[esubset][:, indices] local_whitening = get_whitening_matrix(local_data).astype( numpy.float32) pos = numpy.where(elec == indices)[0] local_res_spac[elec, indices] = local_whitening[pos] local_silences += [len(esubset)] all_res_spac = gather_array(local_res_spac.ravel(), comm, 0, 1) all_silences = gather_array( numpy.array(local_silences, dtype=numpy.int32), comm, 0, 1, 'uint32') if comm.rank == 0: to_write = {} if do_temporal_whitening: try: nb_silences = numpy.sum(all_elecs > 0) all_res_temp = all_res_temp.reshape((nb_silences, N_t**2)) except Exception: print_and_log([ "No silent periods detected: something wrong with the parameters?" 
], 'error', logger) all_res_temp = numpy.sum(all_res_temp, 0) all_res_temp = all_res_temp.reshape( (N_t, N_t)) / numpy.sum(all_elecs) temporal_whitening = get_whitening_matrix( all_res_temp.astype(numpy.double), fudge=1e-3)[template_shift].astype(numpy.float32) temporal_whitening /= temporal_whitening.sum() to_write['temporal'] = temporal_whitening have_nans = numpy.sum(numpy.isnan(temporal_whitening)) if have_nans > 0: temporal_whitening = numpy.zeros(N_t, dtype=numpy.float32) temporal_whitening[N_t // 2] = 1 to_write['temporal'] = temporal_whitening print_and_log( ["Disabling temporal whitening because of NaNs found"], 'info', logger) if do_spatial_whitening: all_res_spac = all_res_spac.reshape(comm.size, N_e, N_e) spatial_whitening = numpy.sum(all_res_spac, 0) to_write['spatial'] = spatial_whitening print_and_log([ "Found %gs without spikes for whitening matrices..." % (numpy.mean(all_silences) / params.rate) ], 'default', logger) have_nans = numpy.sum(numpy.isnan(spatial_whitening)) if have_nans > 0: spatial_whitening = numpy.eye(spatial_whitening.shape[0], dtype=numpy.float32) to_write['spatial'] = spatial_whitening print_and_log( ["Disabling spatial whitening because of NaNs found"], 'info', logger) bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='earliest') io.write_datasets(bfile, to_write.keys(), to_write, compression=hdf5_compress) bfile.close() comm.Barrier() if do_spatial_whitening or do_temporal_whitening: if comm.rank == 0: print_and_log( ["Because of whitening, need to recompute the thresholds..."], 'default', logger) if do_spatial_whitening: spatial_whitening = io.load_data(params, 'spatial_whitening') if use_gpu: spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False) if do_temporal_whitening: temporal_whitening = io.load_data(params, 'temporal_whitening') for gidx in [all_chunks[comm.rank]]: local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes) local_shape = len(local_chunk) if 
do_spatial_whitening: if use_gpu: local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False) local_chunk = local_chunk.dot(spatial_whitening).asarray() else: local_chunk = numpy.dot(local_chunk, spatial_whitening) if do_temporal_whitening: local_chunk = scipy.ndimage.filters.convolve1d( local_chunk, temporal_whitening, axis=0, mode='constant') thresholds = numpy.zeros(N_e, dtype=numpy.float32) for i in xrange(N_e): u = numpy.median(local_chunk[:, i], 0) thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u), 0) gdata = gather_array(thresholds, comm) if comm.rank == 0: gdata = gdata.reshape((comm.size, N_e)) thresholds = numpy.mean(gdata, 0) bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='earliest') bfile.pop('thresholds') io.write_datasets(bfile, ['thresholds'], {'thresholds': thresholds}, compression=hdf5_compress) bfile.close() comm.Barrier() #if comm.rank == 0: #if not os.path.exists(plot_path): # os.makedirs(plot_path) #N_elec = min(int(numpy.sqrt(data_file.N_e)), 5) #plot.view_fit(filename, t_start=0, t_stop=1, fit_on=False, square=True, # n_elec=N_elec, save=[plot_path, 'electrodes']) # Part 2: Basis numpy.random.seed(422) ################################################################# file_out = params.get('data', 'file_out') alignment = params.getboolean('detection', 'alignment') smoothing = params.getboolean('detection', 'smoothing') isolation = params.getboolean('detection', 'isolation') over_factor = params.getint('detection', 'oversampling_factor') spike_thresh = params.getfloat('detection', 'spike_thresh') nodes, edges = get_nodes_and_edges(params) do_temporal_whitening = params.getboolean('whitening', 'temporal') do_spatial_whitening = params.getboolean('whitening', 'spatial') if matched_filter: chunk_size = detect_memory(params, whitening=True) else: chunk_size = detect_memory(params) safety_time = params.getint('whitening', 'safety_time') max_elts_elec = params.getint('whitening', 'max_elts') output_dim = 
params.getfloat('whitening', 'output_dim') inv_nodes = numpy.zeros(N_total, dtype=numpy.int32) inv_nodes[nodes] = numpy.arange(len(nodes)) smoothing_factor = params.getfloat( 'detection', 'smoothing_factor') * (1. / spike_thresh)**2 if sign_peaks == 'both': max_elts_elec *= 2 nb_elts = int( params.getfloat('whitening', 'nb_elts') * N_e * max_elts_elec) ignore_dead_times = params.getboolean('triggers', 'ignore_times') if ignore_dead_times: all_dead_times = get_dead_times(params) data_file.open() ################################################################# if comm.rank == 0: print_and_log(["Searching spikes to construct the PCA basis..."], 'default', logger) nb_chunks, last_chunk_len = data_file.analyze(chunk_size) if nb_chunks < comm.size: res = io.data_stats(params, show=False) chunk_size = int(res * params.rate // comm.size) if comm.rank == 0: print_and_log( ["Too much cores, automatically resizing the data chunks"], 'debug', logger) nb_chunks, last_chunk_len = data_file.analyze(chunk_size) groups = {} for i in xrange(N_e): groups[i] = 0 # I guess this is more relevant, to take signals from all over the recordings all_chunks = numpy.random.permutation( numpy.arange(nb_chunks, dtype=numpy.int32)) max_elts_elec //= comm.size nb_elts //= comm.size elt_count_pos = 0 elt_count_neg = 0 if sign_peaks in ['positive', 'both']: elts_pos = numpy.zeros((N_t, nb_elts), dtype=numpy.float32) if sign_peaks in ['negative', 'both']: elts_neg = numpy.zeros((N_t, nb_elts), dtype=numpy.float32) chunks_to_load = all_chunks[comm.rank::comm.size] thresholds = io.load_data(params, 'thresholds') mads = io.load_data(params, 'mads') if alignment: cdata = numpy.linspace(-jitter_range, jitter_range, int(over_factor * 2 * jitter_range)) xdata = numpy.arange(-template_shift_2, template_shift_2 + 1) xoff = len(cdata) / 2. 
if isolation: yoff = numpy.array(range(0, N_t // 4) + range(3 * N_t // 4, N_t)) to_explore = xrange(comm.rank, nb_chunks, comm.size) if comm.rank == 0: to_explore = get_tqdm_progressbar(to_explore) for gcount, gidx in enumerate(to_explore): gidx = all_chunks[gidx] if ((elt_count_pos + elt_count_neg) < nb_elts): #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..." local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes) local_shape = len(local_chunk) if do_spatial_whitening: if use_gpu: local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False) local_chunk = local_chunk.dot(spatial_whitening).asarray() else: local_chunk = numpy.dot(local_chunk, spatial_whitening) if do_temporal_whitening: local_chunk = scipy.ndimage.filters.convolve1d( local_chunk, temporal_whitening, axis=0, mode='constant') #print "Extracting the peaks..." all_peaktimes = numpy.zeros(0, dtype=numpy.uint32) all_extremas = numpy.zeros(0, dtype=numpy.uint32) for i in xrange(N_e): if sign_peaks == 'negative': peaktimes = scipy.signal.find_peaks(-local_chunk[:, i], height=thresholds[i], width=spike_width, distance=dist_peaks, wlen=N_t)[0] elif sign_peaks == 'positive': peaktimes = scipy.signal.find_peaks(local_chunk[:, i], height=thresholds[i], width=spike_width, wlen=N_t)[0] elif sign_peaks == 'both': peaktimes = scipy.signal.find_peaks(numpy.abs( local_chunk[:, i]), height=thresholds[i], width=spike_width, wlen=N_t)[0] all_peaktimes = numpy.concatenate((all_peaktimes, peaktimes)) all_extremas = numpy.concatenate( (all_extremas, i * numpy.ones(len(peaktimes), dtype=numpy.uint32))) #print "Removing the useless borders..." 
if alignment: local_borders = (template_shift_2, local_shape - template_shift_2) else: local_borders = (template_shift, local_shape - template_shift) idx = (all_peaktimes >= local_borders[0]) & (all_peaktimes < local_borders[1]) all_peaktimes = numpy.compress(idx, all_peaktimes) all_extremas = numpy.compress(idx, all_extremas) local_peaktimes = numpy.unique(all_peaktimes) if ignore_dead_times: dead_indices = numpy.searchsorted( all_dead_times, [t_offset, t_offset + local_shape]) if dead_indices[0] != dead_indices[1]: is_included = numpy.in1d( local_peaktimes + t_offset, all_dead_times[dead_indices[0]:dead_indices[1]]) local_peaktimes = local_peaktimes[~is_included] local_peaktimes = numpy.sort(local_peaktimes) if len(local_peaktimes) > 0: diff_times = local_peaktimes[-1] - local_peaktimes[0] all_times = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool) min_times = numpy.maximum( local_peaktimes - local_peaktimes[0] - safety_time, 0) max_times = numpy.minimum( local_peaktimes - local_peaktimes[0] + safety_time + 1, diff_times) n_times = len(local_peaktimes) argmax_peak = numpy.random.permutation(numpy.arange(n_times)) all_idx = numpy.take(local_peaktimes, argmax_peak) #print "Selection of the peaks with spatio-temporal masks..." 
for midx, peak in zip(argmax_peak, all_idx): if (elt_count_neg + elt_count_pos) == nb_elts: break if sign_peaks == 'negative': elec = numpy.argmin(local_chunk[peak]) negative_peak = True elif sign_peaks == 'positive': elec = numpy.argmax(local_chunk[peak]) negative_peak = False elif sign_peaks == 'both': if N_e == 1: if local_chunk[peak] < 0: negative_peak = True elif local_chunk[peak] > 0: negative_peak = False elec = 0 else: if numpy.abs(numpy.max( local_chunk[peak])) > numpy.abs( numpy.min(local_chunk[peak])): elec = numpy.argmax(local_chunk[peak]) negative_peak = False else: elec = numpy.argmin(local_chunk[peak]) negative_peak = True indices = numpy.take(inv_nodes, edges[nodes[elec]]) myslice = all_times[indices, min_times[midx]:max_times[midx]] is_local_extrema = elec in all_extremas[all_peaktimes == peak] if is_local_extrema and not myslice.any(): upper_bounds = max_elts_elec if groups[elec] < upper_bounds: if not alignment: sub_mat = local_chunk[peak - template_shift:peak + template_shift + 1, elec] elif alignment: ydata = local_chunk[peak - template_shift_2:peak + template_shift_2 + 1, elec] if smoothing: factor = smoothing_factor * xdata.size f = scipy.interpolate.UnivariateSpline( xdata, ydata, s=factor, k=3) else: f = scipy.interpolate.UnivariateSpline( xdata, ydata, k=3, s=0) if negative_peak: rmin = (numpy.argmin(f(cdata)) - xoff) / over_factor else: rmin = (numpy.argmax(f(cdata)) - xoff) / over_factor ddata = numpy.linspace(rmin - template_shift, rmin + template_shift, N_t) sub_mat = f(ddata).astype(numpy.float32) if alignment: if negative_peak: if numpy.min(sub_mat) >= -thresholds[elec]: to_accept = False else: if numpy.max(sub_mat) <= thresholds[elec]: to_accept = False if isolation: to_accept = numpy.all( numpy.max(numpy.abs(sub_mat[yoff])) <= thresholds[elec]) else: to_accept = True if to_accept: if negative_peak: elts_neg[:, elt_count_neg] = sub_mat else: elts_pos[:, elt_count_pos] = sub_mat if negative_peak: elt_count_neg += 1 else: 
elt_count_pos += 1 groups[elec] += 1 all_times[indices, min_times[midx]:max_times[midx]] = True sys.stderr.flush() if isolation: print_and_log([ "Node %d has collected %d isolated waveforms" % (comm.rank, elt_count_pos + elt_count_neg) ], 'debug', logger) else: print_and_log([ "Node %d has collected %d waveforms" % (comm.rank, elt_count_pos + elt_count_neg) ], 'debug', logger) if sign_peaks in ['negative', 'both']: gdata_neg = gather_array(elts_neg[:, :elt_count_neg].T, comm, 0, 1) if sign_peaks in ['positive', 'both']: gdata_pos = gather_array(elts_pos[:, :elt_count_pos].T, comm, 0, 1) nb_waveforms = 0 if comm.rank == 0: #DO PCA on elts and store the basis obtained. if sign_peaks in ['negative', 'both']: nb_waveforms += gdata_neg.shape[0] if sign_peaks in ['positive', 'both']: nb_waveforms += gdata_pos.shape[0] nb_waveforms = all_gather_array( numpy.array([nb_waveforms], dtype=numpy.float32), comm, 0)[0] if comm.rank == 0: if isolation: print_and_log([ "Found %d isolated waveforms over %d requested" % (nb_waveforms, int(nb_elts * comm.size)) ], 'default', logger) else: print_and_log([ "Found %d waveforms over %d requested" % (nb_waveforms, int(nb_elts * comm.size)) ], 'default', logger) if nb_waveforms == 0: print_and_log( ['No waveforms found! 
Are the data properly loaded??'], 'error', logger) if nb_waveforms == 0: sys.exit(0) if comm.rank == 0: res = {} pca = None pca_pos = None pca_neg = None if sign_peaks in ['negative', 'both']: if len(gdata_neg) > 0: pca = PCA(output_dim) if use_hanning: pca.fit(gdata_neg * hanning_filter) else: pca.fit(gdata_neg) res['proj'] = pca.components_.T.astype(numpy.float32) pca_neg = numpy.sum(pca.explained_variance_ratio_) else: res['proj'] = numpy.identity(int(output_dim), dtype=numpy.float32) res['rec'] = res['proj'].T res['waveform'] = numpy.median(gdata_neg, 0) idx = numpy.random.permutation(numpy.arange( gdata_neg.shape[0]))[:1000] res['waveforms'] = gdata_neg[idx, :] if sign_peaks in ['positive', 'both']: if len(gdata_pos) > 0: pca = PCA(output_dim) if use_hanning: pca.fit(gdata_pos * hanning_filter) else: pca.fit(gdata_pos) res['proj_pos'] = pca.components_.T.astype(numpy.float32) pca_pos = numpy.sum(pca.explained_variance_ratio_) else: res['proj_pos'] = numpy.identity(int(output_dim), dtype=numpy.float32) res['rec_pos'] = res['proj_pos'].T res['waveform_pos'] = numpy.median(gdata_pos, 0) idx = numpy.random.permutation(numpy.arange( gdata_pos.shape[0]))[:1000] res['waveforms_pos'] = gdata_pos[idx, :] bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='earliest') io.write_datasets(bfile, res.keys(), res, compression=hdf5_compress) if sign_peaks == 'positive': print_and_log([ "A basis with %s dimensions has been built" % res['proj_pos'].shape[1] ], 'info', logger) elif sign_peaks == 'negative': print_and_log([ "A basis with %s dimensions has been built" % res['proj'].shape[1] ], 'info', logger) elif sign_peaks == 'both': print_and_log([ "Two basis with %s dimensions has been built" % res['proj'].shape[1] ], 'debug', logger) if pca_pos is not None: print_and_log([ "The percentage of variance explained is %s for positive spikes" % pca_pos ], 'debug', logger) if pca_neg is not None: print_and_log([ "The percentage of variance explained is %s for negative 
spikes" % pca_neg ], 'debug', logger) bfile.close() comm.Barrier() if matched_filter: if comm.rank == 0: print_and_log([ "Because of matched filters, need to recompute the thresholds..." ], 'default', logger) if do_spatial_whitening: spatial_whitening = io.load_data(params, 'spatial_whitening') if use_gpu: spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False) if do_temporal_whitening: temporal_whitening = io.load_data(params, 'temporal_whitening') if sign_peaks in ['negative', 'both']: waveform_neg = io.load_data(params, 'waveform')[::-1] waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg)) if sign_peaks in ['positive', 'both']: waveform_pos = io.load_data(params, 'waveform-pos')[::-1] waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos)) for gidx in [all_chunks[comm.rank]]: local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes) local_shape = len(local_chunk) if do_spatial_whitening: if use_gpu: local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False) local_chunk = local_chunk.dot(spatial_whitening).asarray() else: local_chunk = numpy.dot(local_chunk, spatial_whitening) if do_temporal_whitening: local_chunk = scipy.ndimage.filters.convolve1d( local_chunk, temporal_whitening, axis=0, mode='constant') local_chunk /= thresholds if sign_peaks in ['negative', 'both']: tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant') thresholds = numpy.zeros(N_e, dtype=numpy.float32) for i in xrange(N_e): u = numpy.median(tmp_chunk[:, i], 0) thresholds[i] = numpy.median( numpy.abs(tmp_chunk[:, i] - u), 0) gdata = gather_array(thresholds, comm) if comm.rank == 0: gdata = gdata.reshape((comm.size, N_e)) thresholds = numpy.mean(gdata, 0) bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='earliest') io.write_datasets(bfile, ['matched_thresholds'], {'matched_thresholds': thresholds}, compression=hdf5_compress) bfile.close() comm.Barrier() if sign_peaks 
in ['positive', 'both']: tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant') thresholds = numpy.zeros(N_e, dtype=numpy.float32) for i in xrange(N_e): u = numpy.median(tmp_chunk[:, i], 0) thresholds[i] = numpy.median( numpy.abs(tmp_chunk[:, i] - u), 0) gdata = gather_array(thresholds, comm) if comm.rank == 0: gdata = gdata.reshape((comm.size, N_e)) thresholds = numpy.mean(gdata, 0) bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='earliest') io.write_datasets(bfile, ['matched_thresholds_pos'], {'matched_thresholds_pos': thresholds}, compression=hdf5_compress) bfile.close() comm.Barrier() data_file.close()
def main(params, nb_cpu, nb_gpu, use_gpu):
    """Fit the templates to the recording with a greedy matching loop.

    Each MPI rank processes the chunks ``rank, rank + size, ...`` of the
    data file. For every chunk the peaks are detected, the scalar products
    between all templates and the snippets around the peaks are computed,
    and the best (template, peak) pair whose scalar product lies within the
    allowed amplitude bounds is repeatedly accepted and its contribution
    subtracted through the pre-computed overlaps, until no candidate is
    left.

    Results are appended to per-rank binary files named after the
    ``file_out_suff`` parameter (spike times, amplitudes, templates and,
    depending on the options, artefacts, MSE, unfitted "garbage" spikes and
    debug traces). Rank 0 finally gathers them with ``io.collect_data``.

    Args:
        params: project parameter object; read through ``get``, ``getint``,
            ``getfloat`` and ``getboolean`` and through its ``data_file``,
            ``nb_channels`` and ``logfile`` attributes.
        nb_cpu: forwarded to ``io.get_overlaps``.
        nb_gpu: forwarded to ``io.get_overlaps``.
        use_gpu: forwarded to ``io.get_overlaps``.

    Returns:
        None. The process exits with status 0 when no template is present.
    """
    _ = init_logging(params.logfile)
    SHARED_MEMORY = get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    n_e = params.getint('data', 'N_e')
    n_total = params.nb_channels
    n_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    ratio_thresh = params.getfloat('fitting', 'ratio_thresh')
    two_components = params.getboolean('fitting', 'two_components')
    sparse_threshold = params.getfloat('fitting', 'sparse_thresh')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    templates_normalization = params.getboolean('clustering', 'templates_normalization')  # TODO test, switch, test!
    chunk_size = detect_memory(params, fitting=True)
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = get_nodes_and_edges(params)
    # 'amp_limits' is stored as a string such as "(a, b)".
    tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    tmp_limits = [float(v) for v in tmp_limits]
    amp_auto = params.getboolean('fitting', 'amp_auto')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    collect_all = params.getboolean('fitting', 'collect_all')
    min_second_component = params.getfloat('fitting', 'min_second_component')
    debug = params.getboolean('fitting', 'debug')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes = numpy.zeros(n_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    data_file.open()
    supports = io.load_data(params, 'supports')
    low_channels_thr = params.getint('detection', 'low_channels_thr')
    median_channels = numpy.median(numpy.sum(supports, 1))
    fixed_amplitudes = params.getboolean('clustering', 'fixed_amplitudes')

    # A non-empty 'weird_thresh' enables artefact detection.
    weird_thresh = params.get('detection', 'weird_thresh')
    if weird_thresh != '':
        ignore_artefacts = True
        weird_thresh = io.load_data(params, 'weird-thresholds')
    else:
        ignore_artefacts = False

    if not fixed_amplitudes:
        # Time bins over which the amplitude limits are allowed to vary.
        nb_amp_bins = params.getint('clustering', 'nb_amp_bins')
        splits = numpy.linspace(0, params.data_file.duration, nb_amp_bins)
        interpolated_times = numpy.zeros(len(splits) - 1, dtype=numpy.float32)
        for count in range(0, len(splits) - 1):
            interpolated_times[count] = (splits[count] + splits[count + 1])/2
        interpolated_times = numpy.concatenate(([0], interpolated_times, [params.data_file.duration]))
        nb_amp_times = len(splits) + 1

    mse_error = params.getboolean('fitting', 'mse_error')
    if mse_error:
        stds = io.load_data(params, 'stds')
        stds_norm = numpy.linalg.norm(stds)

    #################################################################

    if SHARED_MEMORY:
        # The shared-memory loader is asked to normalize and transpose.
        templates, mpi_memory_1 = io.load_data_memshared(
            params, 'templates', normalize=templates_normalization, transpose=True,
            sparse_threshold=sparse_threshold
        )
        N_tm, x = templates.shape
        is_sparse = not isinstance(templates, numpy.ndarray)
    else:
        templates = io.load_data(params, 'templates')
        x, N_tm = templates.shape
        if N_tm > 0:
            sparsity = templates.nnz / (x * N_tm)
            is_sparse = sparsity < sparse_threshold
        else:
            is_sparse = True
        if not is_sparse:
            if comm.rank == 0:
                print_and_log(['Templates sparsity is low (%g): densified to speedup the algorithm' % sparsity], 'debug', logger)
            templates = templates.toarray()

    temp_2_shift = 2 * template_shift
    temp_3_shift = 3 * template_shift
    # Templates are stored as [first components, second components].
    n_tm = N_tm // 2
    n_scalar = n_e * n_t
    temp_window = numpy.arange(-template_shift, template_shift + 1)
    size_window = n_e * (2 * template_shift + 1)

    if not amp_auto:
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')
    sub_norm_templates = n_scalar * norm_templates[:n_tm]
    if not templates_normalization:
        norm_templates_2 = (norm_templates ** 2.0) * n_scalar
        sub_norm_templates_2 = norm_templates_2[:n_tm]

    if not SHARED_MEMORY:
        # Normalize templates (if necessary).
        if templates_normalization:
            if is_sparse:
                for idx in range(templates.shape[1]):
                    myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx+1])
                    templates.data[myslice] /= norm_templates[idx]
            else:
                for idx in range(templates.shape[1]):
                    templates[:, idx] /= norm_templates[idx]
        # Transpose templates.
        templates = templates.T

    maxoverlap = io.load_data(params, 'maxoverlap')/n_scalar
    similar = numpy.where(maxoverlap > 0.5)
    idx = similar[0] < similar[1]
    similar = similar[0][idx], similar[1][idx]
    nb_mixtures = len(similar[0])

    waveform_neg = numpy.empty(0)  # default assignment (for PyCharm code inspection)
    matched_thresholds_neg = None  # default assignment (for PyCharm code inspection)
    waveform_pos = numpy.empty(0)  # default assignment (for PyCharm code inspection)
    matched_thresholds_pos = None  # default assignment (for PyCharm code inspection)
    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg))
            matched_thresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos))
            matched_thresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        if SHARED_MEMORY:
            all_dead_times, mpi_memory_3 = get_dead_times(params)
        else:
            all_dead_times = get_dead_times(params)
    else:
        all_dead_times = None  # default assignment (for PyCharm code inspection)

    thresholds = io.get_accurate_thresholds(params, ratio_thresh)

    # For every template, the channels on which it is not identically zero.
    neighbors = {}
    if collect_all:
        is_sparse = not isinstance(templates, numpy.ndarray)
        for i in range(0, n_tm):
            if is_sparse:
                tmp = templates[i, :].toarray().reshape(n_e, n_t)
            else:
                tmp = templates[i].reshape(n_e, n_t)
            if templates_normalization:
                tmp = tmp * norm_templates[i]
            neighbors[i] = numpy.where(numpy.sum(tmp, axis=1) != 0.0)[0]

    info_string = ''
    if comm.rank == 0:
        info_string = "using %d CPUs" % comm.size

    comm.Barrier()

    c_overlap = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    over_shape = c_overlap.get('over_shape')[:]
    n_over = int(numpy.sqrt(over_shape[0]))
    s_over = over_shape[1]
    s_center = s_over // 2

    # If the number of overlaps is different from templates, we need to recompute them.
    if n_over != N_tm:
        if comm.rank == 0:
            print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger)
        c_overlap = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        n_over = int(numpy.sqrt(over_shape[0]))
        s_over = over_shape[1]

    if SHARED_MEMORY:
        c_overs, mpi_memory_2 = io.load_data_memshared(params, 'overlaps')
    else:
        c_overs = io.load_data(params, 'overlaps')

    comm.Barrier()

    if n_tm == 0:
        if comm.rank == 0:
            print_and_log(["No templates present. Redo clustering?"], 'default', logger)
        sys.exit(0)

    if comm.rank == 0:
        print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." % (info_string, n_tm)], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    else:
        spatial_whitening = None  # default assignment (for PyCharm code inspection)
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')
    else:
        temporal_whitening = None  # default assignment (for PyCharm code inspection)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks = int(min(nb_chunks, max_chunk))

    # One set of output files per rank.
    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank, 'wb')
    comm.Barrier()

    if ignore_artefacts:
        comm.Barrier()
        arte_spiketimes_file = open(file_out_suff + '.times-%d.sata' % comm.rank, 'wb')
        comm.Barrier()
        arte_electrodes_file = open(file_out_suff + '.elec-%d.sata' % comm.rank, 'wb')
        comm.Barrier()
        arte_amplitudes_file = open(file_out_suff + '.amp-%d.sata' % comm.rank, 'wb')
        comm.Barrier()

    if mse_error:
        mse_file = open(file_out_suff + '.mses-%d.data' % comm.rank, 'wb')
        comm.Barrier()

    if collect_all:
        garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' % comm.rank, 'wb')
        comm.Barrier()
        garbage_temp_file = open(file_out_suff + '.gtemplates-%d.data' % comm.rank, 'wb')
        comm.Barrier()
    else:
        garbage_times_file = None  # default assignment (for PyCharm code inspection)
        garbage_temp_file = None  # default assignment (for PyCharm code inspection)

    # Debug traces: (field label, on-disk dtype), in the order the files are
    # opened, written and closed. The builtin `bool` replaces the removed
    # `numpy.bool` alias.
    debug_fields = [
        ('chunk_nbs', numpy.uint32),
        ('iteration_nbs', numpy.uint32),
        ('peak_nbs', numpy.uint32),
        ('peak_local_time_steps', numpy.uint32),
        ('peak_time_steps', numpy.uint32),
        ('peak_scalar_products', numpy.float32),
        ('peak_solved_flags', numpy.float32),
        ('template_nbs', numpy.uint32),
        ('success_flags', bool),
    ]
    debug_files = {}
    if debug:
        # Open debug files.
        for field_label, _ in debug_fields:
            debug_files[field_label] = open(
                file_out_suff + '.%s_debug_%d.data' % (field_label, comm.rank), mode='wb'
            )
            comm.Barrier()

    # NOTE(review): last_chunk_size is never updated, so the acceptance
    # window of the last chunk ([t_offset, t_offset + my_chunk_size)) is
    # empty. `last_chunk_len` above may be the intended value - confirm
    # against data_file.analyze() before changing.
    last_chunk_size = 0
    slice_indices = numpy.zeros(0, dtype=numpy.int32)

    to_explore = list(range(comm.rank, processed_chunks, comm.size))

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(params, to_explore)

    if fixed_amplitudes:
        # Bounds on the scalar products do not depend on time: compute once.
        min_scalar_products = amp_limits[:, 0][:, numpy.newaxis]
        max_scalar_products = amp_limits[:, 1][:, numpy.newaxis]
        if templates_normalization:
            min_sps = min_scalar_products * sub_norm_templates[:, numpy.newaxis]
            max_sps = max_scalar_products * sub_norm_templates[:, numpy.newaxis]
        else:
            min_sps = min_scalar_products * sub_norm_templates_2[:, numpy.newaxis]
            max_sps = max_scalar_products * sub_norm_templates_2[:, numpy.newaxis]

    for gcount, gidx in enumerate(to_explore):
        # We need to deal with the borders by taking chunks of size [0, chunck_size + template_shift].
        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last = data_file.is_last_chunk(gidx, nb_chunks)

        if not (is_first and is_last):
            if is_last:
                padding = (-temp_3_shift, 0)
            elif is_first:
                padding = (0, temp_3_shift)
            else:
                padding = (-temp_3_shift, temp_3_shift)
        else:
            padding = (0, 0)

        result = {
            'spiketimes': [],
            'amplitudes': [],
            'templates': [],
        }
        if mse_error:
            # Fits that fall in the padding; only used to rebuild the signal.
            mse_fit = {
                'spiketimes': [],
                'amplitudes': [],
                'templates': [],
            }
        result_debug = {field_label: [] for field_label, _ in debug_fields}

        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes)
        len_chunk = len(local_chunk)
        if is_last:
            my_chunk_size = last_chunk_size
        else:
            my_chunk_size = chunk_size

        if do_spatial_whitening:
            local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant')

        # Extracting peaks.
        all_found_spikes = {}
        if collect_all:
            for i in range(n_e):
                all_found_spikes[i] = []

        local_peaktimes = [numpy.empty(0, dtype=numpy.uint32)]

        if ignore_artefacts:
            artefacts_peaktimes = [numpy.zeros(0, dtype=numpy.uint32)]
            artefacts_elecs = [numpy.zeros(0, dtype=numpy.uint32)]
            artefacts_amps = [numpy.zeros(0, dtype=numpy.float32)]

        if matched_filter:
            # Positive pass first, then negative pass (when requested).
            matched_passes = []
            if sign_peaks in ['positive', 'both']:
                matched_passes.append((waveform_pos, matched_thresholds_pos))
            if sign_peaks in ['negative', 'both']:
                matched_passes.append((waveform_neg, matched_thresholds_neg))

            for waveform, matched_thresholds in matched_passes:
                filter_chunk = scipy.ndimage.convolve1d(local_chunk, waveform, axis=0, mode='constant')
                for i in range(n_e):
                    peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds[i])[0]

                    if ignore_artefacts:
                        artetimes = scipy.signal.find_peaks(numpy.abs(filter_chunk[:, i]), height=weird_thresh[i])[0]
                        to_keep = numpy.logical_not(numpy.in1d(peaktimes, artetimes))
                        peaktimes = peaktimes[to_keep]
                        artefacts_peaktimes.append(artetimes)
                        artefacts_elecs.append(i*numpy.ones(len(artetimes), dtype='uint32'))
                        artefacts_amps.append(local_chunk[artetimes, i])

                    local_peaktimes.append(peaktimes)
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
        else:
            for i in range(n_e):
                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(-local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]), height=thresholds[i])[0]
                else:
                    raise ValueError("Unexpected value %s" % sign_peaks)

                if ignore_artefacts:
                    artetimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]), height=weird_thresh[i])[0]
                    to_keep = numpy.logical_not(numpy.in1d(peaktimes, artetimes))
                    peaktimes = peaktimes[to_keep]
                    artefacts_peaktimes.append(artetimes)
                    artefacts_elecs.append(i*numpy.ones(len(artetimes), dtype='uint32'))
                    artefacts_amps.append(local_chunk[artetimes, i])

                local_peaktimes.append(peaktimes)
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()

        local_peaktimes = numpy.concatenate(local_peaktimes)
        if ignore_artefacts:
            artefacts_peaktimes = numpy.concatenate(artefacts_peaktimes)
            artefacts_elecs = numpy.concatenate(artefacts_elecs)
            artefacts_amps = numpy.concatenate(artefacts_amps)

        local_peaktimes = numpy.unique(local_peaktimes)
        # Offset converting indices local to the (padded) chunk into global
        # time steps: global = local + g_offset.
        g_offset = t_offset + padding[0]

        if ignore_dead_times:
            dead_indices = numpy.searchsorted(all_dead_times, [t_offset, t_offset + my_chunk_size])
            if dead_indices[0] != dead_indices[1]:
                is_included = numpy.in1d(local_peaktimes + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]])
                local_peaktimes = local_peaktimes[~is_included]
                if ignore_artefacts:
                    is_included = numpy.in1d(
                        artefacts_peaktimes + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]]
                    )
                    artefacts_peaktimes = artefacts_peaktimes[~is_included]
                    artefacts_elecs = artefacts_elecs[~is_included]
                    artefacts_amps = artefacts_amps[~is_included]
                local_peaktimes = numpy.sort(local_peaktimes)
        else:
            dead_indices = None  # default assignment (for PyCharm code inspection)

        # Remove the peaks too close to the borders to extract a full snippet.
        local_borders = (template_shift, len_chunk - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if ignore_artefacts:
            artefacts_peaktimes = artefacts_peaktimes + g_offset
            idx = (artefacts_peaktimes >= t_offset) & (artefacts_peaktimes < t_offset + my_chunk_size)
            artefacts_peaktimes = numpy.compress(idx, artefacts_peaktimes)
            artefacts_elecs = numpy.compress(idx, artefacts_elecs)
            artefacts_amps = numpy.compress(idx, artefacts_amps)

        if collect_all:
            for i in range(n_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.uint32)

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        is_included = numpy.in1d(
                            all_found_spikes[i] + g_offset,
                            all_dead_times[dead_indices[0]:dead_indices[1]]
                        )
                        all_found_spikes[i] = all_found_spikes[i][~is_included]
                        all_found_spikes[i] = numpy.sort(all_found_spikes[i])

                idx = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        nb_local_peak_times = len(local_peaktimes)

        if nb_local_peak_times > 0:
            # Keep an untouched copy of the signal: local_chunk is released below.
            if collect_all or mse_error:
                c_local_chunk = local_chunk.copy()
            else:
                c_local_chunk = None  # default assignment (for PyCharm code inspection)

            # Snippets around every peak, flattened to match the template layout.
            sub_mat = local_chunk[local_peaktimes[:, None] + temp_window]
            sub_mat = sub_mat.transpose(2, 1, 0).reshape(size_window, nb_local_peak_times)
            del local_chunk

            # b[i, j]: scalar product between template i and the snippet of peak j.
            b = templates.dot(sub_mat)

            del sub_mat

            local_restriction = (t_offset, t_offset + my_chunk_size)
            all_spikes = local_peaktimes + g_offset

            if collect_all:
                c_all_times = numpy.zeros((len_chunk, n_e), dtype=bool)
                c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in range(n_e):
                    c_all_times[all_found_spikes[i], i] = True
            else:
                c_all_times = None  # default assignment (for PyCharm code inspection)
                c_min_times = None  # default assignment (for PyCharm code inspection)
                c_max_times = None  # default assignment (for PyCharm code inspection)

            iteration_nb = 0
            # View on the first components only: updates of b are seen through it.
            data = b[:n_tm, :]

            if not fixed_amplitudes:
                amp_index = numpy.searchsorted(splits, local_restriction[0], 'right')
                scaling = 1/(splits[amp_index] - splits[amp_index - 1])
                min_scalar_products = amp_limits[:, amp_index, 0] + (
                    amp_limits[:, amp_index, 0] - amp_limits[:, amp_index+1, 0]
                )*scaling
                # The upper bound must be built from the upper limits (column 1)
                # of both bins, mirroring the lower bound above.
                max_scalar_products = amp_limits[:, amp_index, 1] + (
                    amp_limits[:, amp_index, 1] - amp_limits[:, amp_index+1, 1]
                )*scaling

                min_scalar_products = min_scalar_products[:, numpy.newaxis]
                max_scalar_products = max_scalar_products[:, numpy.newaxis]

                if templates_normalization:
                    min_sps = min_scalar_products * sub_norm_templates[:, numpy.newaxis]
                    max_sps = max_scalar_products * sub_norm_templates[:, numpy.newaxis]
                else:
                    min_sps = min_scalar_products * sub_norm_templates_2[:, numpy.newaxis]
                    max_sps = max_scalar_products * sub_norm_templates_2[:, numpy.newaxis]

            while True:
                is_valid = (data > min_sps)*(data < max_sps)
                valid_indices = numpy.where(is_valid)

                if len(valid_indices[0]) == 0:
                    break

                # Best remaining (template, peak) pair among the valid ones.
                valid_values = data[is_valid]
                best_amplitude_idx = valid_values.argmax()
                best_template_index = valid_indices[0][best_amplitude_idx]
                peak_index = valid_indices[1][best_amplitude_idx]
                peak_scalar_product = valid_values[best_amplitude_idx]
                best_template2_index = best_template_index + n_tm

                if templates_normalization:
                    best_amp = b[best_template_index, peak_index] / n_scalar
                    best_amp_n = best_amp / norm_templates[best_template_index]
                    if two_components:
                        best_amp2 = b[best_template2_index, peak_index] / n_scalar
                        best_amp2_n = best_amp2 / norm_templates[best_template2_index]
                    else:
                        best_amp2 = 0
                        best_amp2_n = 0
                else:
                    best_amp = b[best_template_index, peak_index] / norm_templates_2[best_template_index]
                    best_amp_n = best_amp
                    if two_components:
                        best_amp2 = b[best_template2_index, peak_index] / norm_templates_2[best_template2_index]
                        best_amp2_n = best_amp2
                    else:
                        best_amp2 = 0
                        best_amp2_n = 0

                peak_time_step = local_peaktimes[peak_index]

                # Peaks close enough in time to be affected by this fit.
                peak_data = (local_peaktimes - peak_time_step).astype(numpy.int32)
                is_neighbor = numpy.abs(peak_data) <= temp_2_shift
                idx_neighbor = peak_data[is_neighbor] + temp_2_shift

                # Subtract the fitted template(s) from the scalar products.
                tmp1 = c_overs[best_template_index].multiply(-best_amp)
                if numpy.abs(best_amp2_n) > min_second_component:
                    tmp1 += c_overs[best_template2_index].multiply(-best_amp2)

                to_add = tmp1.toarray()[:, idx_neighbor]
                b[:, is_neighbor] += to_add

                # Never select this (template, peak) pair again.
                b[best_template_index, peak_index] = -numpy.inf

                # Add matching to the result.
                t_spike = all_spikes[peak_index]

                if (t_spike >= local_restriction[0]) and (t_spike < local_restriction[1]):
                    result['spiketimes'] += [t_spike]
                    result['amplitudes'] += [(best_amp_n, best_amp2_n)]
                    result['templates'] += [best_template_index]
                elif mse_error:
                    mse_fit['spiketimes'] += [t_spike]
                    mse_fit['amplitudes'] += [(best_amp_n, best_amp2_n)]
                    mse_fit['templates'] += [best_template_index]

                # Save debug data.
                if debug:
                    result_debug['chunk_nbs'] += [gidx]
                    result_debug['iteration_nbs'] += [iteration_nb]
                    result_debug['peak_nbs'] += [peak_index]
                    result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]]
                    result_debug['peak_time_steps'] += [all_spikes[peak_index]]
                    result_debug['peak_scalar_products'] += [peak_scalar_product]
                    result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]]
                    result_debug['template_nbs'] += [best_template_index]
                    result_debug['success_flags'] += [True]

                iteration_nb += 1

            spikes_to_write = numpy.array(result['spiketimes'], dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'], dtype=numpy.uint32)

            # tobytes() replaces the removed ndarray.tostring() (same output).
            spiketimes_file.write(spikes_to_write.tobytes())
            amplitudes_file.write(amplitudes_to_write.tobytes())
            templates_file.write(templates_to_write.tobytes())

            if ignore_artefacts:
                arte_spiketimes_file.write(artefacts_peaktimes.astype(numpy.uint32).tobytes())
                arte_electrodes_file.write(artefacts_elecs.tobytes())
                arte_amplitudes_file.write(artefacts_amps.tobytes())

            if mse_error:
                # Rebuild the signal explained by the fits and compare it to the data.
                curve = numpy.zeros((len_chunk, n_e), dtype=numpy.float32)
                for spike, temp_id, amplitude in zip(result['spiketimes'], result['templates'], result['amplitudes']):
                    spike = spike - t_offset - padding[0]
                    if is_sparse:
                        tmp1 = templates[temp_id].toarray().reshape(n_e, n_t)
                        tmp2 = templates[temp_id + n_tm].toarray().reshape(n_e, n_t)
                    else:
                        tmp1 = templates[temp_id].reshape(n_e, n_t)
                        tmp2 = templates[temp_id + n_tm].reshape(n_e, n_t)

                    curve[spike - template_shift:spike + template_shift + 1, :] += (
                        amplitude[0] * tmp1 + amplitude[1] * tmp2
                    ).T

                for spike, temp_id, amplitude in zip(
                    mse_fit['spiketimes'], mse_fit['templates'], mse_fit['amplitudes']
                ):
                    # Global -> local index is "- g_offset", i.e.
                    # "- t_offset - padding[0]", exactly as in the loop above.
                    spike = spike - t_offset - padding[0]
                    if is_sparse:
                        tmp1 = templates[temp_id].toarray().reshape(n_e, n_t)
                        tmp2 = templates[temp_id + n_tm].toarray().reshape(n_e, n_t)
                    else:
                        tmp1 = templates[temp_id].reshape(n_e, n_t)
                        tmp2 = templates[temp_id + n_tm].reshape(n_e, n_t)
                    # Best effort: a border fit that cannot be added is ignored.
                    try:
                        curve[int(spike) - template_shift:int(spike) + template_shift + 1, :] += (
                            amplitude[0] * tmp1 + amplitude[1] * tmp2
                        ).T
                    except Exception:
                        pass

                # Explicit stop index: "-padding[1]" would yield an empty slice
                # (and an MSE of 0) whenever padding[1] == 0.
                mse_start = -padding[0]
                mse_stop = len(curve) - padding[1]
                mse = numpy.linalg.norm((curve - c_local_chunk)[mse_start:mse_stop])
                nb_points = len(curve) - (padding[1] - padding[0])
                mse_ratio = mse/(numpy.sqrt(nb_points)*stds_norm)
                mse_to_write = numpy.array([g_offset, mse_ratio], dtype=numpy.float32)
                mse_file.write(mse_to_write.tobytes())

            if collect_all:
                # Un-flag the detected peaks that are explained by a fitted spike.
                for temp, spike in zip(templates_to_write, spikes_to_write - g_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False

                gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times

                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_thresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_thresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(matched_thresholds_neg[bestlecs], matched_thresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                else:
                    raise ValueError("Unexpected value %s" % sign_peaks)

                gspikes = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write = numpy.array(gspikes + g_offset, dtype=numpy.uint32)
                gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.uint32)

                garbage_times_file.write(gspikes_to_write.tobytes())
                garbage_temp_file.write(gtemplates_to_write.tobytes())

            if debug:
                # Write debug data to debug files.
                for field_label, field_dtype in debug_fields:
                    field_to_write = numpy.array(result_debug[field_label], dtype=field_dtype)
                    debug_files[field_label].write(field_to_write.tobytes())

    sys.stderr.flush()

    def _sync_and_close(handle):
        # Push the buffered data to disk before releasing the handle.
        handle.flush()
        os.fsync(handle.fileno())
        handle.close()

    _sync_and_close(spiketimes_file)
    _sync_and_close(amplitudes_file)
    _sync_and_close(templates_file)

    if collect_all:
        _sync_and_close(garbage_temp_file)
        _sync_and_close(garbage_times_file)

    if mse_error:
        _sync_and_close(mse_file)

    if ignore_artefacts:
        _sync_and_close(arte_spiketimes_file)
        _sync_and_close(arte_electrodes_file)
        _sync_and_close(arte_amplitudes_file)

    if debug:
        # Close debug files.
        for field_label, _ in debug_fields:
            _sync_and_close(debug_files[field_label])

    comm.Barrier()

    if SHARED_MEMORY:
        for memory in mpi_memory_1 + mpi_memory_2:
            memory.Free()
        if ignore_dead_times:
            mpi_memory_3.Free()

    # Rank 0 merges the per-rank files.
    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)
        if ignore_artefacts:
            io.collect_artefacts(comm.size, params, erase=True)

    data_file.close()
def main(params, nb_cpu, nb_gpu, use_gpu, extension):
    """Run the merging step on the results identified by ``extension``.

    Rank 0 first looks for the output of a previous merging run (the
    ``-merged`` HDF5 files and GUI directory) and removes it, either
    unconditionally (``merging/erase_all``) or after asking the user.
    When ``merging/auto_mode`` equals 0 a Qt application is created on
    rank 0 and handed to ``gui.MergeWindow``; otherwise ``None`` is passed.

    Parameters
    ----------
    params
        Parameter object; read through ``get``, ``getboolean``,
        ``getfloat`` and its ``logfile`` attribute.
    nb_cpu, nb_gpu, use_gpu
        Not used by this step.
    extension
        Suffix of the input results, forwarded unchanged to
        ``gui.MergeWindow``.
    """
    _ = init_logging(params.logfile)
    logger = logging.getLogger('circus.merging')
    file_out_suff = params.get('data', 'file_out_suff')
    erase_all = params.getboolean('merging', 'erase_all')
    extension_in = extension
    extension_out = '-merged'

    # Erase previous results (if user agrees). Only rank 0 touches the disk.
    if comm.rank == 0:
        existing_file_paths = [
            file_path
            for file_path in [
                file_out_suff + ".%s%s.hdf5" % (file_id, extension_out)
                for file_id in ['templates', 'clusters', 'result']
            ]
            if os.path.isfile(file_path)
        ]
        existing_directory_path = [
            directory_path
            for directory_path in [file_out_suff + "%s.GUI" % extension_out]
            if os.path.isdir(directory_path)
        ]
        if existing_file_paths or existing_directory_path:
            if erase_all:
                erase = True
            else:
                erase = query_yes_no(
                    "Merging already done! Do you want to erase previous merging results?",
                    default=None)
            if erase:
                # We are already on rank 0 here, so no further rank test is needed.
                for path in existing_file_paths:
                    os.remove(path)
                    print_and_log(["Removed file %s" % path], 'debug', logger)
                for path in existing_directory_path:
                    shutil.rmtree(path)
                    print_and_log(["Removed directory %s" % path], 'debug', logger)

    comm.Barrier()

    app = None
    if params.getfloat('merging', 'auto_mode') == 0:
        # The DISPLAY check is skipped on Windows only.
        if sys.platform != 'win32':
            display = os.environ.get('DISPLAY', '')
            # Raw string: "\d" is not a valid escape in a normal string literal.
            if re.search(r":\d", display) is None:
                print_and_log(
                    ['Preview mode can not be used, check DISPLAY variable'],
                    'error', logger)
                sys.exit(0)

        if comm.rank == 0:
            try:
                from PyQt5.QtWidgets import QApplication
            except ImportError:
                # Fall back on whichever Qt4 binding matplotlib is configured for.
                from matplotlib.backends import qt_compat
                use_pyside = qt_compat.QT_API == qt_compat.QT_API_PYSIDE
                if use_pyside:
                    from PySide.QtGui import QApplication
                else:
                    from PyQt4.QtGui import QApplication
            app = QApplication([])
            try:
                pylab.style.use('ggplot')
            except Exception:
                # Best effort: the plotting style is cosmetic, so a failure is ignored.
                pass

    if comm.rank == 0:
        print_and_log(['Launching the merging GUI...'], 'debug', logger)
    _ = gui.MergeWindow(params, app, extension_in, extension_out)
    # NOTE(review): ``app`` is None in automatic mode and on non-root ranks;
    # this line presumably relies on MergeWindow ending the process itself in
    # those cases -- confirm against gui.MergeWindow.
    sys.exit(app.exec_())
def main(params, nb_cpu, nb_gpu, use_gpu):
    """Run the gathering step.

    Initialises logging on ``params.logfile`` and calls
    ``io.collect_data`` for ``nb_cpu`` processes with ``erase=False``.
    ``nb_gpu`` and ``use_gpu`` are accepted but not used.
    """
    log_path = params.logfile
    init_logging(log_path)
    # The step logger is still requested, although nothing is logged here.
    logging.getLogger('circus.gathering')
    io.collect_data(nb_cpu, params, erase=False)
def main(argv=None):
    """Command-line entry point that opens sorting results in the MATLAB GUI.

    Parses ``argv`` (defaults to ``sys.argv[1:]``), loads the parameters of
    the given data file, writes the probe layout to a temporary HDF5 file,
    builds a ``SortingGUI(...)`` command and runs it through the ``matlab``
    executable. The function always ends by calling ``sys.exit``: with no
    status when no argument is given, with MATLAB's return code otherwise,
    or with 1 if launching MATLAB raised.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments: the data file and an optional
        ``-e/--extension``.
    """
    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    parser = argparse.ArgumentParser(
        description=header, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file')
    parser.add_argument(
        '-e', '--extension',
        help='extension to consider for visualization', default='')

    if len(argv) == 0:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)

    params = CircusParser(os.path.abspath(args.datafile))
    extension = args.extension
    if os.path.exists(params.logfile):
        os.remove(params.logfile)
    _ = init_logging(params.logfile)
    logger = logging.getLogger(__name__)

    data_file = params.get_data_file()
    data_dtype = data_file.data_dtype
    gain = data_file.gain
    t_start = data_file.t_start
    file_format = data_file.description
    matlab_can_read_data = file_format in supported_by_matlab

    if not matlab_can_read_data:
        print_and_log([
            "File format %s is not supported by MATLAB. Waveforms disabled"
            % file_format
        ], 'info', logger)

    if numpy.iterable(gain):
        print_and_log(
            ['Multiple gains are not supported, using a default value of 1'],
            'info', logger)
        gain = 1

    file_out_suff = params.get('data', 'file_out_suff')
    data_offset = getattr(data_file, 'data_offset', 0)
    probe = params.probe
    if extension != '':
        extension = '-' + extension

    def generate_matlab_mapping(probe):
        """Write the probe layout to a temporary HDF5 file and return its path."""
        p = {}
        positions = []
        nodes = []
        for group in probe['channel_groups'].values():
            p.update(group['geometry'])
            nodes += group['channels']
            positions += [p[channel] for channel in group['channels']]
        idx = numpy.argsort(nodes)
        positions = numpy.array(positions)[idx]

        # mkstemp creates the file itself, so the name cannot be taken by
        # another process between choosing it and opening it (the previous
        # code reused the name of an already-deleted temporary file).
        handle, path = tempfile.mkstemp(suffix='.hdf5')
        os.close(handle)

        to_write = {
            'positions': positions / 10.,
            'permutation': numpy.sort(nodes),
            'nb_total': numpy.array([probe['total_nb_channels']])
        }
        cfile = h5py.File(path, 'w')
        try:
            write_datasets(cfile, to_write.keys(), to_write)
        finally:
            cfile.close()
        return path

    mapping = generate_matlab_mapping(probe)

    if not params.getboolean('data', 'overwrite'):
        raw_data_path = params.get('data', 'data_file_no_overwrite')
    else:
        raw_data_path = params.get('data', 'data_file')

    apply_patch_for_similarities(params, extension)

    gui_file = pkg_resources.resource_filename(
        'circus', os.path.join('matlab_GUI', 'SortingGUI.m'))
    # Change to the directory of the matlab file
    os.chdir(os.path.abspath(os.path.dirname(gui_file)))

    # Each argument carries its own "quote as a MATLAB string" flag so values
    # and flags cannot get out of step. The former parallel list was one entry
    # short when the format is unsupported, so zip() silently dropped t_start.
    gui_params = [
        (params.rate, False),
        (os.path.abspath(file_out_suff), True),
        ('%s.mat' % extension, True),
        (mapping, True),
        (2, False),
        (t_start, False),
    ]
    if matlab_can_read_data:
        # Raw-data description, only passed when MATLAB can read the format.
        gui_params += [
            (data_dtype, True),
            (data_offset, False),
            (gain, False),
            (raw_data_path, True),
        ]

    arguments = ', '.join(
        "'%s'" % value if quoted else "%s" % value
        for value, quoted in gui_params
    )
    matlab_command = 'SortingGUI(%s)' % arguments

    print_and_log(["Launching the MATLAB GUI..."], 'info', logger)
    print_and_log([matlab_command], 'debug', logger)

    if params.getboolean('fitting', 'collect_all'):
        print_and_log([
            'You can not view the unfitted spikes with the MATLAB GUI',
            'Please consider using phy if you really would like to see them'
        ], 'info', logger)

    try:
        # sys.exit raises SystemExit, which is not caught by the handler below.
        sys.exit(
            subprocess.call(
                ['matlab', '-nodesktop', '-nosplash', '-r', matlab_command]))
    except Exception:
        print_and_log(
            ["Something wrong with MATLAB. Try circus-gui-python instead?"],
            'error', logger)
        sys.exit(1)
def main(params, nb_cpu, nb_gpu, use_gpu):
    """Fit the templates to the recording and write the fitted spikes to disk.

    Every MPI rank processes the chunks ``comm.rank, comm.rank + comm.size,
    ...``. For each chunk the peaks are detected, the data around each peak
    is projected onto the templates, and peaks are iteratively assigned to
    their best-matching template while the amplitude stays within
    ``amp_limits``. Accepted spikes are appended to the per-rank files
    ``<file_out_suff>.{spiketimes,amplitudes,templates}-<rank>.data``; with
    ``fitting/collect_all`` the remaining threshold crossings go to the
    ``gspiketimes``/``gtemplates`` files. Rank 0 finally calls
    ``io.collect_data(comm.size, params, erase=True)``.

    Parameters
    ----------
    params
        Parameter object (``get``/``getint``/``getfloat``/``getboolean``,
        ``data_file``, ``nb_channels``, ``logfile``).
    nb_cpu, nb_gpu
        Forwarded to the overlap loaders; also used to choose the GPU id.
    use_gpu
        When true, ``cudamat`` is imported and used for the projections.

    NOTE(review): this routine uses Python 2 idioms (``xrange``,
    ``dict.has_key``, indexing the result of ``map``) -- confirm the target
    interpreter before reusing it.
    """
    #################################################################
    # NOTE(review): file_out, spike_thresh, noise_thr, edges, inv_nodes,
    # last_spikes and all_spikes are assigned below but never read here.
    logger = init_logging(params.logfile)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    data_file.open()
    N_e = params.getint('data', 'N_e')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = params.getint('fitting', 'chunk_size')
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = get_nodes_and_edges(params)
    # 'amp_limits' is read as text such as "(a, b)" and split into its two bounds.
    tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    # NOTE(review): tmp_limits is indexed below; that requires map() to return a list (Python 2).
    tmp_limits = map(float, tmp_limits)
    amp_auto = params.getboolean('fitting', 'amp_auto')
    space_explo = params.getfloat('fitting', 'space_explo')
    nb_chances = params.getint('fitting', 'nb_chances')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    noise_thr = params.getfloat('clustering', 'noise_thr')
    collect_all = params.getboolean('fitting', 'collect_all')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    #################################################################

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank//nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    # The shared-memory loader is asked for normalized, transposed templates;
    # otherwise both operations are done by hand further down.
    if SHARED_MEMORY:
        templates = io.load_data_memshared(params, 'templates', normalize=True, transpose=True)
        N_tm, x = templates.shape
    else:
        templates = io.load_data(params, 'templates')
        x, N_tm = templates.shape

    temp_2_shift = 2*template_shift
    full_gpu = use_gpu and gpu_only
    # Presumably each template is stored with a second component in the
    # second half (indexed with "+ n_tm" below) -- TODO confirm.
    n_tm = N_tm//2
    n_scalar = N_e*N_t
    last_spikes = numpy.zeros((n_tm, 1), dtype=numpy.int32)
    temp_window = numpy.arange(-template_shift, template_shift+1)

    # Amplitude bounds per template: fixed from the parameters, or loaded.
    if not amp_auto:
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')

    if not SHARED_MEMORY:
        # Divide every template column by its norm, then transpose.
        for idx in xrange(templates.shape[1]):
            myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx+1])
            templates.data[myslice] /= norm_templates[idx]
        templates = templates.T

    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg))* len(waveform_neg))
            matched_tresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos))* len(waveform_pos))
            matched_tresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        dead_times = numpy.loadtxt(params.get('triggers', 'dead_file'))
        # A file with a single interval is loaded as a 1-D array.
        if len(dead_times.shape) == 1:
            dead_times = dead_times.reshape(1, 2)
        dead_in_ms = params.getboolean('triggers', 'dead_in_ms')
        if dead_in_ms:
            # Convert milliseconds to time steps.
            dead_times *= numpy.int64(data_file.sampling_rate*1e-3)
        dead_times = dead_times.astype(numpy.int64)
        # Expand every [start, stop) interval into the list of its time steps.
        all_dead_times = []
        for i in xrange(len(dead_times)):
            all_dead_times += range(dead_times[i, 0], dead_times[i, 1])

    thresholds = io.load_data(params, 'thresholds')

    if collect_all:
        # For each template, the channels on which it is not identically zero.
        neighbors = {}
        for i in xrange(n_tm):
            tmp = templates[i, :].toarray().reshape(N_e, N_t) * norm_templates[i]
            neighbors[i] = numpy.where(numpy.sum(tmp, 1) != 0)[0]

    if use_gpu:
        templates = cmt.SparseCUDAMatrix(templates, copy_on_host=False)

    info_string = ''

    if comm.rank == 0:
        if use_gpu:
            info_string = "using %d GPUs" %(comm.size)
        else:
            info_string = "using %d CPUs" %(comm.size)

    comm.Barrier()

    c_overlap = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    over_shape = c_overlap.get('over_shape')[:]
    N_over = int(numpy.sqrt(over_shape[0]))
    S_over = over_shape[1]
    ## If the number of overlaps is different from templates, we need to recompute them
    if N_over != N_tm:
        if comm.rank == 0:
            print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger)
        c_overlap = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        N_over = int(numpy.sqrt(over_shape[0]))
        S_over = over_shape[1]

    if SHARED_MEMORY:
        c_overs = io.load_data_memshared(params, 'overlaps', nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    else:
        c_overlap = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_x = c_overlap.get('over_x')[:]
        over_y = c_overlap.get('over_y')[:]
        over_data = c_overlap.get('over_data')[:]
        over_shape = c_overlap.get('over_shape')[:]
        c_overlap.close()

        # To be faster, we rearrange the overlaps into a dictionary. This has a cost: twice the memory usage for
        # a short period of time
        c_overs = {}
        overlaps = scipy.sparse.csr_matrix((over_data, (over_x, over_y)), shape=(over_shape[0], over_shape[1]))
        del over_x, over_y, over_data

        # One block of N_over rows per template.
        for i in xrange(N_over):
            c_overs[i] = overlaps[i*N_over:(i+1)*N_over]

        del overlaps

    comm.Barrier()

    if comm.rank == 0:
        print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." %(info_string, n_tm)], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    if full_gpu:
        try:
            # If memory on the GPU is large enough, we load the overlaps onto it
            for i in xrange(N_over):
                c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i], copy_on_host=False)
        except Exception:
            if comm.rank == 0:
                print_and_log(["Not enough memory on GPUs: GPUs are used for projection only"], 'info', logger)
            # NOTE(review): dict.has_key() exists only in Python 2.
            for i in xrange(N_over):
                if c_overs.has_key(i):
                    del c_overs[i]
            full_gpu = False

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    # max_chunk caps the number of chunks that are fitted.
    processed_chunks = int(min(nb_chunks, max_chunk))

    # One output file per rank and per quantity, with a barrier after each open.
    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' %comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' %comm.rank, 'wb')
    comm.Barrier()
    templates_file = open(file_out_suff + '.templates-%d.data' %comm.rank, 'wb')
    comm.Barrier()

    if collect_all:
        garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' %comm.rank, 'wb')
        comm.Barrier()
        garbage_temp_file = open(file_out_suff + '.gtemplates-%d.data' %comm.rank, 'wb')
        comm.Barrier()

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False)

    # Length of the last chunk for which slice_indices was built (see below).
    last_chunk_size = 0

    # Chunks are distributed round-robin over the ranks.
    to_explore = xrange(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gcount, gidx in enumerate(to_explore):
        #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        ## We need to deal with the borders by taking chunks of size [0, chunk_size+template_shift]

        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last = data_file.is_last_chunk(gidx, nb_chunks)

        # Extend the chunk by 2*template_shift on each side that has a neighbour.
        if is_last:
            padding = (-2*template_shift, 0)
        elif is_first:
            padding = (0, 2*template_shift)
        else:
            padding = (-2*template_shift, 2*template_shift)

        result = {'spiketimes' : [], 'amplitudes' : [], 'templates' : []}

        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes)
        len_chunk = len(local_chunk)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant')

        #print "Extracting the peaks..."

        if collect_all:
            # Peaks found on each channel, kept separately for the "garbage" output.
            all_found_spikes = {}
            for i in xrange(N_e):
                all_found_spikes[i] = []

        local_peaktimes = numpy.zeros(0, dtype=numpy.int32)

        if matched_filter:
            # Peaks are detected on the chunk convolved with the stored waveform(s).
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i], matched_tresholds_pos[i])
                    local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i], matched_tresholds_neg[i])
                    local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
        else:
            for i in xrange(N_e):
                if sign_peaks == 'negative':
                    peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=True)
                elif sign_peaks == 'positive':
                    peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=False)
                elif sign_peaks == 'both':
                    peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]), thresholds[i], valley=False)
                local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()

        # A time step detected on several channels is kept once.
        local_peaktimes = numpy.unique(local_peaktimes)

        if ignore_dead_times:
            # Dead times are compared after adding t_offset to the local peak times.
            local_peaktimes = numpy.array(list(set(local_peaktimes + t_offset).difference(all_dead_times)), dtype=numpy.int32) - t_offset
            local_peaktimes = numpy.sort(local_peaktimes)

        #print "Removing the useless borders..."
        local_borders = (template_shift, len_chunk - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if collect_all:
            # Same dead-time and border filtering, channel by channel.
            for i in xrange(N_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.int32)

                if ignore_dead_times:
                    all_found_spikes[i] = numpy.array(list(set(all_found_spikes[i] + t_offset).difference(all_dead_times)), dtype=numpy.int32) - t_offset
                    all_found_spikes[i] = numpy.sort(all_found_spikes[i])

                idx = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        n_t = len(local_peaktimes)
        all_indices = numpy.arange(n_t)

        if full_gpu:
            # all_indices = cmt.CUDAMatrix(all_indices)
            tmp_gpu = cmt.CUDAMatrix(local_peaktimes.reshape((1, n_t)), copy_on_host=False)

        if n_t > 0:
            #print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..."

            if collect_all:
                # Keep a 2-D copy: local_chunk is flattened and deleted below.
                c_local_chunk = local_chunk.copy()

            local_chunk = local_chunk.T.ravel()
            # One column per peak: the window around the peak on every channel.
            sub_mat = numpy.zeros((N_e*(2*template_shift+1), n_t), dtype=numpy.float32)

            # The flat window indices depend only on the chunk length, so they
            # are rebuilt only when that length changes.
            if len_chunk != last_chunk_size:
                slice_indices = numpy.zeros(0, dtype=numpy.int32)
                for idx in xrange(N_e):
                    slice_indices = numpy.concatenate((slice_indices, len_chunk*idx + temp_window))
                last_chunk_size = len_chunk

            for count, idx in enumerate(local_peaktimes):
                sub_mat[:, count] = numpy.take(local_chunk, slice_indices + idx)

            del local_chunk

            # b holds the products of all templates with all peak snippets.
            if use_gpu:
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b = cmt.sparse_dot(templates, sub_mat)
            else:
                b = templates.dot(sub_mat)

            del sub_mat

            local_offset = padding[0] + t_offset
            # Only spikes inside these bounds are written out (see ``good`` below).
            local_bounds = (temp_2_shift, len_chunk - temp_2_shift)
            all_spikes = local_peaktimes + local_offset

            # Because for GPU, slicing by columns is more efficient, we need to transpose b
            #b = b.transpose()
            if use_gpu and not full_gpu:
                b = b.asarray()

            # Number of rejected fitting attempts per peak.
            failure = numpy.zeros(n_t, dtype=numpy.int32)

            # mask is zeroed for (template, peak) pairs that must not be tried again.
            if full_gpu:
                mask = numpy.zeros((2*n_tm, n_t), dtype=numpy.float32)
                mask[:n_tm, :] = 1
                data = cmt.empty(mask.shape)
                patch_gpu= b.shape[1] == 1
            else:
                mask = numpy.ones((n_tm, n_t), dtype=numpy.float32)
                sub_b = b[:n_tm, :]

            min_time = local_peaktimes.min()
            max_time = local_peaktimes.max()
            local_len = max_time - min_time + 1
            # Per peak, the interval (relative to min_time) that is marked as
            # occupied once the peak is selected in a pass.
            min_times = numpy.maximum(local_peaktimes - min_time - temp_2_shift, 0)
            max_times = numpy.minimum(local_peaktimes - min_time + temp_2_shift + 1, max_time - min_time)
            # Upper bound on the number of peaks handled in one pass.
            max_n_t = int(space_explo*(max_time-min_time+1)//(2*temp_2_shift + 1))

            if collect_all:
                c_all_times = numpy.zeros((len_chunk, N_e), dtype=numpy.bool)
                c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in xrange(N_e):
                    c_all_times[all_found_spikes[i], i] = True

            # Iterate until the mean number of failures per peak reaches nb_chances.
            while (numpy.mean(failure) < nb_chances):

                # Order the peaks by decreasing best masked product.
                if full_gpu:
                    gpu_mask = cmt.CUDAMatrix(mask, copy_on_host=False)
                    b.mult(gpu_mask, data)
                    tmp_mat = data.max(0)
                    argmax_bi = numpy.argsort(tmp_mat.asarray()[0, :])[::-1]
                    del tmp_mat
                else:
                    data = sub_b * mask
                    argmax_bi = numpy.argsort(numpy.max(data, 0))[::-1]

                while (len(argmax_bi) > 0):

                    subset = []
                    indices = []
                    all_times = numpy.zeros(local_len, dtype=numpy.bool)

                    # Greedily pick peaks whose intervals do not overlap.
                    for count, idx in enumerate(argmax_bi):
                        myslice = all_times[min_times[idx]:max_times[idx]]
                        if not myslice.any():
                            subset += [idx]
                            indices += [count]
                            all_times[min_times[idx]:max_times[idx]] = True
                        if len(subset) > max_n_t:
                            break

                    subset = numpy.array(subset, dtype=numpy.int32)
                    argmax_bi = numpy.delete(argmax_bi, indices)

                    if full_gpu:
                        b_array = b.asarray()
                        sub_b = b_array[:n_tm, :]

                    # Best template for every selected peak.
                    inds_t, inds_temp = subset, numpy.argmax(numpy.take(sub_b, subset, axis=1), 0)

                    if full_gpu:
                        best_amp = sub_b[inds_temp, inds_t]/n_scalar
                        best_amp2 = b_array[inds_temp + n_tm, inds_t]/n_scalar
                    else:
                        best_amp = sub_b[inds_temp, inds_t]/n_scalar
                        best_amp2 = b[inds_temp + n_tm, inds_t]/n_scalar

                    # Never try the same (template, peak) pair twice.
                    mask[inds_temp, inds_t] = 0

                    best_amp_n = best_amp/numpy.take(norm_templates, inds_temp)
                    best_amp2_n = best_amp2/numpy.take(norm_templates, inds_temp + n_tm)

                    # A fit is accepted when its normalized amplitude is within the limits.
                    all_idx = ((best_amp_n >= amp_limits[inds_temp, 0]) & (best_amp_n <= amp_limits[inds_temp, 1]))
                    to_keep = numpy.where(all_idx == True)[0]
                    to_reject = numpy.where(all_idx == False)[0]
                    ts = numpy.take(local_peaktimes, inds_t[to_keep])
                    good = (ts >= local_bounds[0]) & (ts < local_bounds[1])

                    # We reduce to only the good times that will be kept
                    #to_keep = to_keep[good]
                    #ts = ts[good]

                    if len(ts) > 0:
                        # tmp[k, j] = local_peaktimes[j] - ts[k]; condition marks the
                        # peaks within temp_2_shift of accepted spike k.
                        if full_gpu:
                            tmp = cmt.CUDAMatrix(numpy.ones((len(ts), 1)), copy_on_host=False)
                            tmp3 = cmt.CUDAMatrix(-ts.reshape((len(ts), 1)), copy_on_host=False)
                            tmp = tmp.dot(tmp_gpu)
                            tmp.add_col_vec(tmp3)
                            condition = cmt.empty(tmp.shape)
                            cmt.abs(tmp, condition).less_than(temp_2_shift + 1)
                            condition = condition.asarray().astype(numpy.bool)
                            tmp = tmp.asarray().astype(numpy.int32)
                        else:
                            tmp = numpy.dot(numpy.ones((len(ts), 1), dtype=numpy.int32), local_peaktimes.reshape((1, n_t)))
                            tmp -= ts.reshape((len(ts), 1))
                            condition = numpy.abs(tmp) <= temp_2_shift

                        for count, keep in enumerate(to_keep):

                            idx_b = numpy.compress(condition[count, :], all_indices)
                            ytmp = tmp[count, condition[count, :]] + temp_2_shift

                            # Selector of the overlap columns matching each time lag.
                            indices = numpy.zeros((S_over, len(ytmp)), dtype=numpy.float32)
                            indices[ytmp, numpy.arange(len(ytmp))] = 1

                            # Subtract the fitted template (both components, scaled by
                            # their amplitudes) from the products of the nearby peaks.
                            if full_gpu:
                                indices = cmt.CUDAMatrix(indices, copy_on_host=False)
                                if patch_gpu:
                                    b_lines = b.get_col_slice(0, b.shape[0])
                                else:
                                    b_lines = b.get_col_slice(idx_b[0], idx_b[-1]+1)

                                tmp1 = cmt.sparse_dot(c_overs[inds_temp[keep]], indices, mult=-best_amp[keep])
                                tmp2 = cmt.sparse_dot(c_overs[inds_temp[keep] + n_tm], indices, mult=-best_amp2[keep])
                                b_lines.add(tmp1.add(tmp2))
                                del tmp1, tmp2
                            else:
                                tmp1 = c_overs[inds_temp[keep]].multiply(-best_amp[keep]).dot(indices)
                                tmp2 = c_overs[inds_temp[keep] + n_tm].multiply(-best_amp2[keep]).dot(indices)
                                b[:, idx_b] += tmp1 + tmp2

                            # The spike is recorded only if it lies inside local_bounds.
                            if good[count]:

                                t_spike = ts[count] + local_offset
                                result['spiketimes'] += [t_spike]
                                result['amplitudes'] += [(best_amp_n[keep], best_amp2_n[keep])]
                                result['templates'] += [inds_temp[keep]]

                    # Rejected peaks get one more failure; those that used up all
                    # their chances are masked for every template.
                    myslice = numpy.take(inds_t, to_reject)
                    failure[myslice] += 1
                    sub_idx = (numpy.take(failure, myslice) >= nb_chances)

                    mask[:, numpy.compress(sub_idx, myslice)] = 0

            spikes_to_write = numpy.array(result['spiketimes'], dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'], dtype=numpy.int32)

            # NOTE(review): tostring() is a deprecated alias of tobytes() in newer
            # NumPy -- confirm the supported NumPy range.
            spiketimes_file.write(spikes_to_write.tostring())
            amplitudes_file.write(amplitudes_to_write.tostring())
            templates_file.write(templates_to_write.tostring())

            if collect_all:

                # Clear the detections explained by a fitted spike, on the channels
                # of its template.
                for temp, spike in zip(templates_to_write, spikes_to_write - local_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False

                # Time steps that still carry at least one unexplained detection.
                gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times

                # Keep those whose extreme value still crosses the threshold of its channel.
                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_tresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_tresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(matched_tresholds_neg[bestlecs], matched_tresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]

                gspikes = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write = numpy.array(gspikes + local_offset, dtype=numpy.uint32)
                gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.int32)

                garbage_times_file.write(gspikes_to_write.tostring())
                garbage_temp_file.write(gtemplates_to_write.tostring())

            if full_gpu:
                del gpu_mask, b, data

    # Flush and fsync every per-rank file before closing it, so the data is on
    # disk when rank 0 collects the results after the barrier.
    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    if collect_all:

        garbage_temp_file.flush()
        os.fsync(garbage_temp_file.fileno())
        garbage_temp_file.close()

        garbage_times_file.flush()
        os.fsync(garbage_times_file.fileno())
        garbage_times_file.close()

    comm.Barrier()

    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)

    data_file.close()
def main(params, nb_cpu, nb_gpu, use_gpu): ################################################################# # params = detect_memory(params) _ = init_logging(params.logfile) SHARED_MEMORY = get_shared_memory_flag(params) logger = logging.getLogger('circus.fitting') data_file = params.data_file n_e = params.getint('data', 'N_e') n_total = params.nb_channels n_t = params.getint('detection', 'N_t') template_shift = params.getint('detection', 'template_shift') # file_out = params.get('data', 'file_out') file_out_suff = params.get('data', 'file_out_suff') sign_peaks = params.get('detection', 'peaks') matched_filter = params.getboolean('detection', 'matched-filter') # spike_thresh = params.getfloat('detection', 'spike_thresh') ratio_thresh = params.getfloat('fitting', 'ratio_thresh') two_components = params.getboolean('fitting', 'two_components') # spike_width = params.getfloat('detection', 'spike_width') # dist_peaks = params.getint('detection', 'dist_peaks') do_temporal_whitening = params.getboolean('whitening', 'temporal') do_spatial_whitening = params.getboolean('whitening', 'spatial') templates_normalization = params.getboolean('clustering', 'templates_normalization') # TODO test, switch, test! 
chunk_size = detect_memory(params, fitting=True) gpu_only = params.getboolean('fitting', 'gpu_only') nodes, edges = get_nodes_and_edges(params) tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',') tmp_limits = [float(v) for v in tmp_limits] amp_auto = params.getboolean('fitting', 'amp_auto') auto_nb_chances = params.getboolean('fitting', 'auto_nb_chances') if auto_nb_chances: nb_chances = io.load_data(params, 'nb_chances') max_nb_chances = params.getint('fitting', 'max_nb_chances') percent_nb_chances = params.getfloat('fitting', 'percent_nb_chances') total_nb_chances = max(1, numpy.nanpercentile(nb_chances, percent_nb_chances)) total_nb_chances = min(total_nb_chances, max_nb_chances) if comm.rank == 0: print_and_log(['nb_chances set automatically to %g' %total_nb_chances], 'debug', logger) else: total_nb_chances = params.getfloat('fitting', 'nb_chances') max_chunk = params.getfloat('fitting', 'max_chunk') # noise_thr = params.getfloat('clustering', 'noise_thr') collect_all = params.getboolean('fitting', 'collect_all') min_second_component = params.getfloat('fitting', 'min_second_component') debug = params.getboolean('fitting', 'debug') ignore_dead_times = params.getboolean('triggers', 'ignore_times') inv_nodes = numpy.zeros(n_total, dtype=numpy.int32) inv_nodes[nodes] = numpy.arange(len(nodes)) data_file.open() ################################################################# if use_gpu: import cudamat as cmt # # Need to properly handle multi GPU per MPI nodes? 
if nb_gpu > nb_cpu: gpu_id = int(comm.rank // nb_cpu) else: gpu_id = 0 cmt.cuda_set_device(gpu_id) cmt.init() cmt.cuda_sync_threads() if SHARED_MEMORY: templates, _ = io.load_data_memshared(params, 'templates', normalize=templates_normalization, transpose=True) N_tm, x = templates.shape else: templates = io.load_data(params, 'templates') x, N_tm = templates.shape temp_2_shift = 2 * template_shift temp_3_shift = 3 * template_shift full_gpu = use_gpu and gpu_only n_tm = N_tm // 2 n_scalar = n_e * n_t temp_window = numpy.arange(-template_shift, template_shift + 1) size_window = n_e * (2 * template_shift + 1) if not amp_auto: amp_limits = numpy.zeros((n_tm, 2)) amp_limits[:, 0] = tmp_limits[0] amp_limits[:, 1] = tmp_limits[1] else: amp_limits = io.load_data(params, 'limits') norm_templates = io.load_data(params, 'norm-templates') if not templates_normalization: norm_templates_2 = (norm_templates ** 2.0) * n_scalar if not SHARED_MEMORY: # Normalize templates (if necessary). if templates_normalization: for idx in range(templates.shape[1]): myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx+1]) templates.data[myslice] /= norm_templates[idx] # Transpose templates. 
templates = templates.T waveform_neg = numpy.empty(0) # default assignment (for PyCharm code inspection) matched_thresholds_neg = None # default assignment (for PyCharm code inspection) waveform_pos = numpy.empty(0) # default assignment (for PyCharm code inspection) matched_thresholds_pos = None # default assignment (for PyCharm code inspection) if matched_filter: if sign_peaks in ['negative', 'both']: waveform_neg = io.load_data(params, 'waveform')[::-1] waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg)) matched_thresholds_neg = io.load_data(params, 'matched-thresholds') if sign_peaks in ['positive', 'both']: waveform_pos = io.load_data(params, 'waveform-pos')[::-1] waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos)) matched_thresholds_pos = io.load_data(params, 'matched-thresholds-pos') if ignore_dead_times: all_dead_times = get_dead_times(params) else: all_dead_times = None # default assignment (for PyCharm code inspection) thresholds = io.get_accurate_thresholds(params, ratio_thresh) neighbors = {} if collect_all: for i in range(0, n_tm): tmp = templates[i, :].toarray().reshape(n_e, n_t) if templates_normalization: tmp = tmp * norm_templates[i] neighbors[i] = numpy.where(numpy.sum(tmp, axis=1) != 0.0)[0] if use_gpu: templates = cmt.SparseCUDAMatrix(templates, copy_on_host=False) info_string = '' if comm.rank == 0: if use_gpu: info_string = "using %d GPUs" % comm.size else: info_string = "using %d CPUs" % comm.size comm.Barrier() c_overlap = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu) over_shape = c_overlap.get('over_shape')[:] n_over = int(numpy.sqrt(over_shape[0])) s_over = over_shape[1] # # If the number of overlaps is different from templates, we need to recompute them. 
if n_over != N_tm: if comm.rank == 0: print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger) c_overlap = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu) over_shape = c_overlap.get('over_shape')[:] n_over = int(numpy.sqrt(over_shape[0])) s_over = over_shape[1] if SHARED_MEMORY: c_overs, _ = io.load_data_memshared(params, 'overlaps') else: c_overs = io.load_data(params, 'overlaps') comm.Barrier() if n_tm == 0: if comm.rank == 0: print_and_log(["No templates present. Redo clustering?"], 'default', logger) sys.exit(0) if comm.rank == 0: print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." % (info_string, n_tm)], 'default', logger) purge(file_out_suff, '.data') if do_spatial_whitening: spatial_whitening = io.load_data(params, 'spatial_whitening') else: spatial_whitening = None # default assignment (for PyCharm code inspection) if do_temporal_whitening: temporal_whitening = io.load_data(params, 'temporal_whitening') else: temporal_whitening = None # default assignment (for PyCharm code inspection) if full_gpu: try: # If memory on the GPU is large enough, we load the overlaps onto it for i in range(n_over): c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i], copy_on_host=False) except Exception: if comm.rank == 0: print_and_log(["Not enough memory on GPUs: GPUs are used for projection only"], 'info', logger) for i in range(n_over): if i in c_overs: del c_overs[i] full_gpu = False nb_chunks, last_chunk_len = data_file.analyze(chunk_size) processed_chunks = int(min(nb_chunks, max_chunk)) comm.Barrier() spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank, 'wb') comm.Barrier() amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank, 'wb') comm.Barrier() templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank, 'wb') comm.Barrier() if collect_all: garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' % comm.rank, 'wb') 
comm.Barrier() garbage_temp_file = open(file_out_suff + '.gtemplates-%d.data' % comm.rank, 'wb') comm.Barrier() else: garbage_times_file = None # default assignment (for PyCharm code inspection) garbage_temp_file = None # default assignment (for PyCharm code inspection) if debug: # Open debug files. chunk_nbs_debug_file = open(file_out_suff + '.chunk_nbs_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() iteration_nbs_debug_file = open(file_out_suff + '.iteration_nbs_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() peak_nbs_debug_file = open(file_out_suff + '.peak_nbs_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() peak_local_time_steps_debug_file = open( file_out_suff + '.peak_local_time_steps_debug_%d.data' % comm.rank, mode='wb' ) comm.Barrier() peak_time_steps_debug_file = open(file_out_suff + '.peak_time_steps_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() peak_scalar_products_debug_file = open( file_out_suff + '.peak_scalar_products_debug_%d.data' % comm.rank, mode='wb' ) comm.Barrier() peak_solved_flags_debug_file = open(file_out_suff + '.peak_solved_flags_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() template_nbs_debug_file = open(file_out_suff + '.template_nbs_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() success_flags_debug_file = open(file_out_suff + '.success_flags_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() else: chunk_nbs_debug_file = None # default assignment (for PyCharm code inspection) iteration_nbs_debug_file = None # default assignment (for PyCharm code inspection) peak_nbs_debug_file = None # default assignment (for PyCharm code inspection) peak_local_time_steps_debug_file = None # default assignment (for PyCharm code inspection) peak_time_steps_debug_file = None # default assignment (for PyCharm code inspection) peak_scalar_products_debug_file = None # default assignment (for PyCharm code inspection) peak_solved_flags_debug_file = None # default assignment (for PyCharm code inspection) 
template_nbs_debug_file = None # default assignment (for PyCharm code inspection) success_flags_debug_file = None # default assignment (for PyCharm code inspection) if use_gpu and do_spatial_whitening: spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False) last_chunk_size = 0 slice_indices = numpy.zeros(0, dtype=numpy.int32) to_explore = range(comm.rank, processed_chunks, comm.size) if comm.rank == 0: to_explore = get_tqdm_progressbar(params, to_explore) for gcount, gidx in enumerate(to_explore): # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..." # # We need to deal with the borders by taking chunks of size [0, chunck_size + template_shift]. is_first = data_file.is_first_chunk(gidx, nb_chunks) is_last = data_file.is_last_chunk(gidx, nb_chunks) if not (is_first and is_last): if is_last: padding = (-temp_3_shift, 0) elif is_first: padding = (0, temp_3_shift) else: padding = (-temp_3_shift, temp_3_shift) else: padding = (0, 0) result = { 'spiketimes': [], 'amplitudes': [], 'templates': [], } result_debug = { 'chunk_nbs': [], 'iteration_nbs': [], 'peak_nbs': [], 'peak_local_time_steps': [], 'peak_time_steps': [], 'peak_scalar_products': [], 'peak_solved_flags': [], 'template_nbs': [], 'success_flags': [], } local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes) len_chunk = len(local_chunk) if do_spatial_whitening: if use_gpu: local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False) local_chunk = local_chunk.dot(spatial_whitening).asarray() else: local_chunk = numpy.dot(local_chunk, spatial_whitening) if do_temporal_whitening: local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant') # Extracting peaks. 
all_found_spikes = {} if collect_all: for i in range(n_e): all_found_spikes[i] = [] local_peaktimes = [numpy.empty(0, dtype=numpy.uint32)] if matched_filter: if sign_peaks in ['positive', 'both']: filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant') for i in range(n_e): peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_pos[i])[0] local_peaktimes.append(peaktimes) if collect_all: all_found_spikes[i] += peaktimes.tolist() if sign_peaks in ['negative', 'both']: filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant') for i in range(n_e): peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_neg[i])[0] local_peaktimes.append(peaktimes) if collect_all: all_found_spikes[i] += peaktimes.tolist() local_peaktimes = numpy.concatenate(local_peaktimes) else: for i in range(n_e): if sign_peaks == 'negative': peaktimes = scipy.signal.find_peaks(-local_chunk[:, i], height=thresholds[i])[0] elif sign_peaks == 'positive': peaktimes = scipy.signal.find_peaks(local_chunk[:, i], height=thresholds[i])[0] elif sign_peaks == 'both': peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]), height=thresholds[i])[0] else: raise ValueError("Unexpected value %s" % sign_peaks) local_peaktimes.append(peaktimes) if collect_all: all_found_spikes[i] += peaktimes.tolist() local_peaktimes = numpy.concatenate(local_peaktimes) local_peaktimes = numpy.unique(local_peaktimes) g_offset = t_offset + padding[0] if ignore_dead_times: dead_indices = numpy.searchsorted(all_dead_times, [t_offset, t_offset + chunk_size]) if dead_indices[0] != dead_indices[1]: is_included = numpy.in1d(local_peaktimes + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]]) local_peaktimes = local_peaktimes[~is_included] local_peaktimes = numpy.sort(local_peaktimes) else: dead_indices = None # default assignment (for PyCharm code inspection) # print "Removing the 
useless borders..." local_borders = (template_shift, len_chunk - template_shift) idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1]) local_peaktimes = numpy.compress(idx, local_peaktimes) if collect_all: for i in range(n_e): all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.uint32) if ignore_dead_times: if dead_indices[0] != dead_indices[1]: is_included = numpy.in1d( all_found_spikes[i] + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]] ) all_found_spikes[i] = all_found_spikes[i][~is_included] all_found_spikes[i] = numpy.sort(all_found_spikes[i]) idx = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1]) all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i]) nb_local_peak_times = len(local_peaktimes) if full_gpu: # all_indices = cmt.CUDAMatrix(all_indices) # tmp_gpu = cmt.CUDAMatrix(local_peaktimes.reshape((1, nb_local_peak_times)), copy_on_host=False) _ = cmt.CUDAMatrix(local_peaktimes.reshape((1, nb_local_peak_times)), copy_on_host=False) if nb_local_peak_times > 0: # print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..." 
if collect_all: c_local_chunk = local_chunk.copy() else: c_local_chunk = None # default assignment (for PyCharm code inspection) sub_mat = local_chunk[local_peaktimes[:, None] + temp_window] sub_mat = sub_mat.transpose(2, 1, 0).reshape(size_window, nb_local_peak_times) del local_chunk if use_gpu: sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False) b = cmt.sparse_dot(templates, sub_mat) else: b = templates.dot(sub_mat) del sub_mat local_restriction = (t_offset, t_offset + chunk_size) all_spikes = local_peaktimes + g_offset # Because for GPU, slicing by columns is more efficient, we need to transpose b # b = b.transpose() if use_gpu and not full_gpu: b = b.asarray() failure = numpy.zeros(nb_local_peak_times, dtype=numpy.int32) if full_gpu: mask = numpy.zeros((2 * n_tm, nb_local_peak_times), dtype=numpy.float32) mask[:n_tm, :] = 1 # data = cmt.empty(mask.shape) _ = cmt.empty(mask.shape) patch_gpu = b.shape[1] == 1 else: patch_gpu = None if collect_all: c_all_times = numpy.zeros((len_chunk, n_e), dtype=numpy.bool) c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0) c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk) for i in range(n_e): c_all_times[all_found_spikes[i], i] = True else: c_all_times = None # default assignment (for PyCharm code inspection) c_min_times = None # default assignment (for PyCharm code inspection) c_max_times = None # default assignment (for PyCharm code inspection) iteration_nb = 0 local_max = 0 numerous_argmax = False nb_argmax = n_tm best_indices = numpy.zeros(0, dtype=numpy.int32) data = b[:n_tm, :] flatten_data = data.ravel() while numpy.mean(failure) < total_nb_chances: # Is there a way to update sub_b * mask at the same time? 
if full_gpu: b_array = b.asarray() else: b_array = None if numerous_argmax: if len(best_indices) == 0: best_indices = largest_indices(flatten_data, nb_argmax) best_template_index, peak_index = numpy.unravel_index(best_indices[0], data.shape) else: best_template_index, peak_index = numpy.unravel_index(data.argmax(), data.shape) peak_scalar_product = data[best_template_index, peak_index] best_template2_index = best_template_index + n_tm if templates_normalization: if full_gpu: best_amp = b_array[best_template_index, peak_index] / n_scalar best_amp2 = b_array[best_template2_index, peak_index] / n_scalar else: best_amp = b[best_template_index, peak_index] / n_scalar if two_components: best_amp2 = b[best_template2_index, peak_index] / n_scalar else: best_amp2 = 0.0 best_amp_n = best_amp / norm_templates[best_template_index] best_amp2_n = best_amp2 / norm_templates[best_template2_index] else: if full_gpu: best_amp = b_array[best_template_index, peak_index] best_amp = best_amp / norm_templates_2[best_template_index] # TODO is `best_amp` value correct? best_amp2 = b_array[best_template2_index, peak_index] best_amp2 = best_amp2 / norm_templates_2[best_template2_index] # TODO is `best_amp2` value correct? else: best_amp = b[best_template_index, peak_index] best_amp = best_amp / norm_templates_2[best_template_index] # TODO is `best_amp` value correct? if two_components: best_amp2 = b[best_template2_index, peak_index] best_amp2 = best_amp2 / norm_templates_2[best_template2_index] # TODO is `best_amp2` value correct? else: best_amp2 = 0.0 best_amp_n = best_amp best_amp2_n = best_amp2 # Verify amplitude constraint. a_min, a_max = amp_limits[best_template_index, :] if (a_min <= best_amp_n) & (best_amp_n <= a_max): # Keep the matching. 
peak_time_step = local_peaktimes[peak_index] peak_data = (local_peaktimes - peak_time_step).astype(np.int32) is_neighbor = np.where(np.abs(peak_data) <= temp_2_shift)[0] idx_neighbor = peak_data[is_neighbor] + temp_2_shift nb_neighbors = len(is_neighbor) indices = np.zeros((s_over, nb_neighbors), dtype=np.int32) indices[idx_neighbor, np.arange(nb_neighbors)] = 1 if full_gpu: indices = cmt.CUDAMatrix(indices, copy_on_host=False) if patch_gpu: b_lines = b.get_col_slice(0, b.shape[0]) else: b_lines = b.get_col_slice(is_neighbor[0], is_neighbor[-1]+1) tmp1 = cmt.sparse_dot(c_overs[best_template_index], indices, mult=-best_amp) tmp2 = cmt.sparse_dot(c_overs[best_template2_index], indices, mult=-best_amp2) b_lines.add(tmp1.add(tmp2)) del tmp1, tmp2 else: tmp1 = c_overs[best_template_index].multiply(-best_amp) if numpy.abs(best_amp2) > min_second_component: tmp1 += c_overs[best_template2_index].multiply(-best_amp2) b[:, is_neighbor] += tmp1.dot(indices) numerous_argmax = False # Add matching to the result. t_spike = all_spikes[peak_index] if (t_spike >= local_restriction[0]) and (t_spike < local_restriction[1]): result['spiketimes'] += [t_spike] result['amplitudes'] += [(best_amp_n, best_amp2_n)] result['templates'] += [best_template_index] # Mark current matching as tried. b[best_template_index, peak_index] = -numpy.inf # Save debug data. if debug: result_debug['chunk_nbs'] += [gidx] result_debug['iteration_nbs'] += [iteration_nb] result_debug['peak_nbs'] += [peak_index] result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]] result_debug['peak_time_steps'] += [all_spikes[peak_index]] result_debug['peak_scalar_products'] += [peak_scalar_product] result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]] result_debug['template_nbs'] += [best_template_index] result_debug['success_flags'] += [True] else: # Reject the matching. numerous_argmax = True # Update failure counter of the peak. 
failure[peak_index] += 1 # If the maximal number of failures is reached then mark peak as solved (i.e. not fitted). if failure[peak_index] >= total_nb_chances: # Mark all the matching associated to the current peak as tried. b[:, peak_index] = -numpy.inf index = numpy.arange(n_tm) * nb_local_peak_times + peak_index else: # Mark current matching as tried. b[best_template_index, peak_index] = -numpy.inf index = best_template_index * nb_local_peak_times + peak_index if numerous_argmax: best_indices = best_indices[~numpy.in1d(best_indices, index)] # Save debug data. if debug: result_debug['chunk_nbs'] += [gidx] result_debug['iteration_nbs'] += [iteration_nb] result_debug['peak_nbs'] += [peak_index] result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]] result_debug['peak_time_steps'] += [all_spikes[peak_index]] result_debug['peak_scalar_products'] += [peak_scalar_product] result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]] result_debug['template_nbs'] += [best_template_index] result_debug['success_flags'] += [False] iteration_nb += 1 spikes_to_write = numpy.array(result['spiketimes'], dtype=numpy.uint32) amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32) templates_to_write = numpy.array(result['templates'], dtype=numpy.uint32) spiketimes_file.write(spikes_to_write.tostring()) amplitudes_file.write(amplitudes_to_write.tostring()) templates_file.write(templates_to_write.tostring()) if collect_all: for temp, spike in zip(templates_to_write, spikes_to_write - g_offset): c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0] c_all_times = numpy.take(c_all_times, gspikes, axis=0) c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times if sign_peaks == 'negative': bestlecs = numpy.argmin(c_local_chunk, 1) if matched_filter: threshs = -matched_thresholds_neg[bestlecs] else: threshs = -thresholds[bestlecs] idx = 
numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0] elif sign_peaks == 'positive': bestlecs = numpy.argmax(c_local_chunk, 1) if matched_filter: threshs = matched_thresholds_pos[bestlecs] else: threshs = thresholds[bestlecs] idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0] elif sign_peaks == 'both': c_local_chunk = numpy.abs(c_local_chunk) bestlecs = numpy.argmax(c_local_chunk, 1) if matched_filter: threshs = numpy.minimum(matched_thresholds_neg[bestlecs], matched_thresholds_pos[bestlecs]) else: threshs = thresholds[bestlecs] idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0] else: raise ValueError("Unexpected value %s" % sign_peaks) gspikes = numpy.take(gspikes, idx) bestlecs = numpy.take(bestlecs, idx) gspikes_to_write = numpy.array(gspikes + g_offset, dtype=numpy.uint32) gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.uint32) garbage_times_file.write(gspikes_to_write.tostring()) garbage_temp_file.write(gtemplates_to_write.tostring()) if debug: # Write debug data to debug files. 
for field_label, field_dtype, field_file in [ ('chunk_nbs', numpy.uint32, chunk_nbs_debug_file), ('iteration_nbs', numpy.uint32, iteration_nbs_debug_file), ('peak_nbs', numpy.uint32, peak_nbs_debug_file), ('peak_local_time_steps', numpy.uint32, peak_local_time_steps_debug_file), ('peak_time_steps', numpy.uint32, peak_time_steps_debug_file), ('peak_scalar_products', numpy.float32, peak_scalar_products_debug_file), ('peak_solved_flags', numpy.float32, peak_solved_flags_debug_file), ('template_nbs', numpy.uint32, template_nbs_debug_file), ('success_flags', numpy.bool, success_flags_debug_file), ]: field_to_write = numpy.array(result_debug[field_label], dtype=field_dtype) field_file.write(field_to_write.tostring()) if full_gpu: del b, data sys.stderr.flush() spiketimes_file.flush() os.fsync(spiketimes_file.fileno()) spiketimes_file.close() amplitudes_file.flush() os.fsync(amplitudes_file.fileno()) amplitudes_file.close() templates_file.flush() os.fsync(templates_file.fileno()) templates_file.close() if collect_all: garbage_temp_file.flush() os.fsync(garbage_temp_file.fileno()) garbage_temp_file.close() garbage_times_file.flush() os.fsync(garbage_times_file.fileno()) garbage_times_file.close() if debug: # Close debug files. for field_file in [ chunk_nbs_debug_file, iteration_nbs_debug_file, peak_nbs_debug_file, peak_local_time_steps_debug_file, peak_time_steps_debug_file, peak_scalar_products_debug_file, peak_solved_flags_debug_file, template_nbs_debug_file, success_flags_debug_file, ]: field_file.flush() os.fsync(field_file.fileno()) field_file.close() comm.Barrier() if comm.rank == 0: io.collect_data(comm.size, params, erase=True) data_file.close()
def main(params, nb_cpu, nb_gpu, use_gpu):
    """Extract templates from already found clusters and write them to disk.

    The function works in two phases:

    1. Data chunks are visited in a random order.  For every known spike
       that falls in a chunk (and is far enough from the chunk borders) the
       snippet around the spike is projected with ``basis_rec`` and stored,
       up to ``max_elts`` snippets per cluster.  Cluster ``c`` is handled by
       the rank for which ``c % comm.size == comm.rank``.
    2. For every cluster owned by this rank, a first component (median of
       the stored snippets) and a second component (PCA of the residuals)
       are computed and stored as sparse triplets ``temp_x`` / ``temp_y`` /
       ``temp_data`` together with norms and amplitude limits.

    Results are written to ``<file_out_suff>.templates.hdf5`` and
    ``<file_out_suff>.clusters.hdf5``; the per-rank intermediate files are
    removed by rank 0.  Similar templates are then merged and the overlaps
    are recomputed.

    Parameters
    ----------
    params :
        Configuration object.  It must provide ``get`` / ``getint`` /
        ``getfloat`` / ``getboolean`` as well as the attributes
        ``data_file``, ``logfile`` and ``nb_channels``.
    nb_cpu : int
        Number of CPUs; only used to pick the GPU device id.
    nb_gpu : int
        Number of GPUs; only used to pick the GPU device id.
    use_gpu : bool
        If true, ``cudamat`` is imported and used for spatial whitening.

    Returns
    -------
    None

    Notes
    -----
    NOTE(review): several names used below are never assigned inside this
    function (``shift``, ``tmpidx``, ``to_write``, ``ielec``, ``loc_norms``,
    ``n_clusters``, ``remove_mixture``).  Unless they are provided at module
    level, the corresponding statements raise ``NameError``.  They are left
    untouched because their intended values cannot be derived from this
    function alone.
    """
    numpy.random.seed(426236)
    parallel_hdf5 = get_parallel_hdf5_flag(params)
    _ = init_logging(params.logfile)
    logger = logging.getLogger('circus.extracting')
    #################################################################
    data_file = params.data_file
    N_e = params.getint('data', 'N_e')
    # The section name used to be misspelled ('detecton'); every other
    # detection parameter is read from the 'detection' section.
    N_t = params.getint('detection', 'N_t')
    N_total = params.nb_channels
    template_shift = params.getint('detection', 'template_shift')
    chunk_size = params.getint('data', 'chunk_size')
    file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    nodes, edges = get_nodes_and_edges(params)
    safety_time = params.getint('extracting', 'safety_time')
    max_elts_temp = params.getint('extracting', 'max_elts')
    output_dim = params.getfloat('extracting', 'output_dim')
    noise_thr = params.getfloat('extracting', 'noise_thr')
    hdf5_compress = params.getboolean('data', 'hdf5_compress')
    blosc_compress = params.getboolean('data', 'blosc_compress')
    tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    # Materialise the values: ``amp_limits[1]`` is indexed further down and
    # a bare ``map`` object is not subscriptable on Python 3.
    amp_limits = [float(value) for value in tmp_limits]
    elt_count = 0
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    #################################################################

    if comm.rank == 0:
        print_and_log(["Extracting templates from already found clusters..."], 'default', logger)

    thresholds = io.load_data(params, 'thresholds')
    basis_proj, basis_rec = io.load_data(params, 'basis')
    clusters, spiketimes, N_clusters = io.load_data(params, 'spike-cluster')
    # Map every cluster label onto its rank among the sorted unique labels.
    unique_clusters = numpy.unique(clusters)
    inv_clusters = numpy.zeros(clusters.max() + 1, dtype=numpy.int32)
    inv_clusters[unique_clusters] = numpy.argsort(unique_clusters)

    if use_gpu:
        import cudamat as cmt
        # Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False)

    # Per-cluster accumulators: projected snippets and their spike times.
    result = {}
    for i in range(N_clusters):
        result['data_tmp_' + str(i)] = numpy.zeros((0, N_e * basis_proj.shape[1]), dtype=numpy.float32)
        result['times_' + str(i)] = numpy.zeros(0, dtype=numpy.int32)

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)

    # I guess this is more relevant, to take signals from all over the recordings
    all_chunks = numpy.random.permutation(numpy.arange(nb_chunks))

    # Number of clusters owned by this rank, and the matching snippet budget.
    nb_templates = numpy.sum(comm.rank == numpy.mod(numpy.arange(N_clusters), comm.size))
    nb_elts = max_elts_temp * nb_templates

    to_explore = all_chunks
    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    # Iterate over ``to_explore`` (not ``all_chunks``) so that the progress
    # bar created on rank 0 is actually driven by the loop.
    for gidx in to_explore:
        if elt_count >= nb_elts:
            # Budget reached: nothing left to collect on this rank.
            continue

        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(
                local_chunk, temporal_whitening, axis=0, mode='constant')

        # Select the spikes falling into the current chunk.
        in_chunk = numpy.where((spiketimes >= gidx * chunk_size) &
                               (spiketimes < (gidx + 1) * chunk_size))[0]
        local_offset = t_offset
        local_peaktimes = spiketimes[in_chunk] - local_offset

        # Remove the spikes too close to the chunk borders.
        local_borders = (template_shift, chunk_size - template_shift)
        keep = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
        local_peaktimes = local_peaktimes[keep]
        # ``keep`` has one entry per in-chunk spike, so it has to be applied
        # to the in-chunk labels rather than to the whole ``clusters`` array.
        local_clusters = inv_clusters[clusters[in_chunk][keep]]

        if len(local_peaktimes) == 0:
            continue

        # NOTE(review): the mask width and the clipping below use the first
        # and last entries as extrema, i.e. presumably ``spiketimes`` is
        # sorted -- confirm against the loader.
        all_times = numpy.zeros((N_e, local_peaktimes[-1] - local_peaktimes[0] + 1), dtype=bool)
        min_times = numpy.maximum(local_peaktimes - local_peaktimes[0] - safety_time, 0)
        max_times = numpy.minimum(local_peaktimes - local_peaktimes[0] + safety_time + 1,
                                  local_peaktimes[-1] - local_peaktimes[0])

        # Visit the spikes in random order; the three arrays below stay
        # aligned (min/max_times are indexed by position, as before).
        n_times = len(local_peaktimes)
        argmax_peak = numpy.random.permutation(numpy.arange(n_times))
        clusters_id = local_clusters[argmax_peak]
        local_peaktimes = local_peaktimes[argmax_peak]

        # Selection of the peaks with spatio-temporal masks.
        for idx in range(len(local_peaktimes)):
            if elt_count == nb_elts:
                break

            temp = clusters_id[idx]
            if numpy.mod(temp, comm.size) != comm.rank:
                # This cluster belongs to another rank.
                continue

            elec = numpy.argmin(local_chunk[local_peaktimes[idx]])
            indices = inv_nodes[edges[nodes[elec]]]
            myslice = all_times[indices, min_times[idx]:max_times[idx]]
            peak = local_peaktimes[idx]
            if myslice.any():
                # Another spike already occupies this spatio-temporal window.
                continue

            data_key = 'data_tmp_' + str(temp)
            times_key = 'times_' + str(temp)
            if len(result[data_key]) < max_elts_temp:
                elt_count += 1
                sub_mat = local_chunk[peak - template_shift:peak + template_shift + 1, :]
                sub_mat = numpy.dot(basis_rec, sub_mat)
                nx, ny = sub_mat.shape
                sub_mat = sub_mat.reshape((1, nx * ny))
                result[data_key] = numpy.vstack((result[data_key], sub_mat))
                to_add = numpy.array([peak + local_offset], dtype=numpy.int32)
                result[times_key] = numpy.concatenate((result[times_key], to_add))
            # NOTE(review): the window is marked as used even when the
            # cluster was already full -- confirm that this is intended.
            all_times[indices, min_times[idx]:max_times[idx]] = True

    total_nb_elts = 0
    for temp in range(N_clusters):
        total_nb_elts += len(result['data_tmp_' + str(temp)])

    gdata = gather_array(numpy.array([total_nb_elts], dtype=numpy.float32), comm, 0)
    if comm.rank == 0:
        print_and_log([
            "Found %d spikes over %d requested" % (int(numpy.sum(gdata)), int(nb_elts))
        ], 'default', logger)

    comm.Barrier()

    # Count the clusters of this rank for which at least one snippet exists.
    local_nb_clusters = 0
    for temp in range(comm.rank, N_clusters, comm.size):
        if len(result['data_tmp_' + str(temp)]) > 0:
            local_nb_clusters += 1

    gdata3 = gather_array(numpy.array([local_nb_clusters], dtype=numpy.float32), comm, 0)

    comm.Barrier()
    if comm.rank == 0:
        print_and_log(["Extracting the templates..."], 'default', logger)

    total_nb_clusters = int(
        comm.bcast(numpy.array([int(numpy.sum(gdata3))], dtype=numpy.int32), root=0)[0])
    # offsets[i + 1] holds the number of non-empty clusters of rank i.
    offsets = numpy.zeros(comm.size, dtype=numpy.int32)
    for i in range(comm.size - 1):
        offsets[i + 1] = comm.bcast(numpy.array([local_nb_clusters], dtype=numpy.int32), root=i)[0]

    if parallel_hdf5:
        # All ranks write into one shared file, each at its own offset.
        node_pad = numpy.sum(offsets[:comm.rank + 1])
        hfile = h5py.File(file_out_suff + '.templates.hdf5', 'w', driver='mpio', comm=comm, libver='earliest')
        norms = hfile.create_dataset('norms', shape=(2 * total_nb_clusters, ), dtype=numpy.float32, chunks=True)
        electrodes = hfile.create_dataset('electrodes', shape=(total_nb_clusters, ), dtype=numpy.int32, chunks=True)
        amps_lims = hfile.create_dataset('limits', shape=(total_nb_clusters, 2), dtype=numpy.float32, chunks=True)
        g_count = node_pad
        g_offset = total_nb_clusters
    else:
        # Every rank writes its own file; rank 0 merges them afterwards.
        node_pad = 0
        hfile = h5py.File(file_out_suff + '.templates-%d.hdf5' % comm.rank, 'w', libver='earliest')
        electrodes = hfile.create_dataset('electrodes', shape=(local_nb_clusters, ), dtype=numpy.int32, chunks=True)
        norms = hfile.create_dataset('norms', shape=(2 * local_nb_clusters, ), dtype=numpy.float32, chunks=True)
        amps_lims = hfile.create_dataset('limits', shape=(local_nb_clusters, 2), dtype=numpy.float32, chunks=True)
        g_count = 0
        g_offset = local_nb_clusters

    cfile = h5py.File(file_out_suff + '.clusters-%d.hdf5' % comm.rank, 'w', libver='earliest')
    count_templates = node_pad

    # Sparse (row, column, value) triplets describing the templates.
    temp_x = numpy.zeros(0, dtype=numpy.int32)
    temp_y = numpy.zeros(0, dtype=numpy.int32)
    temp_data = numpy.zeros(0, dtype=numpy.float32)

    to_explore = range(comm.rank, N_clusters, comm.size)
    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for temp in to_explore:
        n_data = len(result['data_tmp_' + str(temp)])
        if n_data > 0:
            data = result['data_tmp_' + str(temp)].reshape(n_data, basis_proj.shape[1], N_e)
            # First component: median of the stored snippets.
            first_component = numpy.median(data, axis=0)
            tmp_templates = numpy.dot(first_component.T, basis_rec)
            # NOTE(review): ``tmpidx`` is not assigned in this function and
            # ``indices`` still holds whatever the collection loop left in
            # it (or nothing at all) -- verify.
            electrodes[g_count] = indices[tmpidx[0][0]]
            indices = inv_nodes[edges[nodes[electrodes[-1]]]]
            templates = numpy.zeros((N_e, N_t), dtype=numpy.float32)
            # NOTE(review): ``shift`` is not assigned in this function.
            if shift > 0:
                templates[indices, shift:] = tmp_templates[:, :-shift]
            elif shift < 0:
                templates[indices, :shift] = tmp_templates[:, -shift:]
            else:
                templates[indices, :] = tmp_templates

            templates = templates.flatten()
            dx = templates.nonzero()[0].astype(numpy.int32)

            temp_x = numpy.concatenate((temp_x, dx))
            temp_y = numpy.concatenate(
                (temp_y, count_templates * numpy.ones(len(dx), dtype=numpy.int32)))
            temp_data = numpy.concatenate((temp_data, templates[dx]))

            norms[g_count] = numpy.sqrt(numpy.sum(templates.flatten()**2) / (N_e * N_t))

            # Amplitude of every snippet along the first component, then
            # removal of that projection to obtain the residuals.
            x, y, z = data.shape
            data_flat = data.reshape(x, y * z)
            first_flat = first_component.reshape(y * z, 1)
            amplitudes = numpy.dot(data_flat, first_flat)
            amplitudes /= numpy.sum(first_flat**2)
            for i in range(x):
                data_flat[i, :] -= amplitudes[i] * first_flat[:, 0]

            # Amplitude limits: median +/- 10 median absolute deviations,
            # bounded by a threshold-based lower limit and amp_limits[1].
            variations = 10 * numpy.median(numpy.abs(amplitudes - numpy.median(amplitudes)))
            physical_limit = noise_thr * (-thresholds[indices[tmpidx[0][0]]]) / tmp_templates.min()
            amp_min = max(physical_limit, numpy.median(amplitudes) - variations)
            amp_max = min(amp_limits[1], numpy.median(amplitudes) + variations)
            amps_lims[g_count] = [amp_min, amp_max]

            # Second component: first principal component of the residuals.
            if len(data_flat) > 1:
                pca = PCA(1)
                pca.fit_transform(data_flat.astype(numpy.double))
                second_component = pca.components_.T.astype(numpy.float32).reshape(y, z)
            else:
                second_component = data_flat.reshape(y, z) / numpy.sum(data_flat**2)

            tmp_templates = numpy.dot(second_component.T, basis_rec)
            offset = total_nb_clusters + count_templates
            sub_templates = numpy.zeros((N_e, N_t), dtype=numpy.float32)
            if shift > 0:
                sub_templates[indices, shift:] = tmp_templates[:, :-shift]
            elif shift < 0:
                sub_templates[indices, :shift] = tmp_templates[:, -shift:]
            else:
                sub_templates[indices, :] = tmp_templates

            sub_templates = sub_templates.flatten()
            dx = sub_templates.nonzero()[0].astype(numpy.int32)

            temp_x = numpy.concatenate((temp_x, dx))
            temp_y = numpy.concatenate(
                (temp_y, offset * numpy.ones(len(dx), dtype=numpy.int32)))
            temp_data = numpy.concatenate((temp_data, sub_templates[dx]))

            norms[g_count + g_offset] = numpy.sqrt(
                numpy.sum(sub_templates.flatten()**2) / (N_e * N_t))

            count_templates += 1
            g_count += 1

        # NOTE(review): ``to_write`` and ``ielec`` are not assigned in this
        # function -- verify where they are expected to come from.
        io.write_datasets(cfile, to_write, result, ielec, compress=hdf5_compress)

    # At the end we should have a templates variable to store.
    cfile.close()
    # NOTE(review): ``templates`` is only bound if at least one cluster of
    # this rank had data; otherwise this ``del`` fails.
    del result, templates, amps_lims

    comm.Barrier()

    # We need to gather the sparse arrays
    temp_x = gather_array(temp_x, comm, dtype='int32', compress=blosc_compress)
    temp_y = gather_array(temp_y, comm, dtype='int32', compress=blosc_compress)
    temp_data = gather_array(temp_data, comm, compress=blosc_compress)

    if parallel_hdf5:
        if comm.rank == 0:
            # Merge the per-rank cluster files into a single one.
            rs = [
                h5py.File(file_out_suff + '.clusters-%d.hdf5' % i, 'r', libver='earliest')
                for i in range(comm.size)
            ]
            cfile = h5py.File(file_out_suff + '.clusters.hdf5', 'w', libver='earliest')
            io.write_datasets(cfile, ['electrodes'], {'electrodes': electrodes[:]}, compress=hdf5_compress)
            for i in range(comm.size):
                for j in range(i, N_e, comm.size):
                    io.write_datasets(cfile, to_write, rs[i], j, compress=hdf5_compress)
                rs[i].close()
                os.remove(file_out_suff + '.clusters-%d.hdf5' % i)
            cfile.close()
        hfile.close()
    else:
        hfile.close()
        if comm.rank == 0:
            # Merge the per-rank template and cluster files.
            ts = [
                h5py.File(file_out_suff + '.templates-%d.hdf5' % i, 'r', libver='earliest')
                for i in range(comm.size)
            ]
            rs = [
                h5py.File(file_out_suff + '.clusters-%d.hdf5' % i, 'r', libver='earliest')
                for i in range(comm.size)
            ]
            hfile = h5py.File(file_out_suff + '.templates.hdf5', 'w', libver='earliest')
            cfile = h5py.File(file_out_suff + '.clusters.hdf5', 'w', libver='earliest')
            electrodes = hfile.create_dataset('electrodes', shape=(total_nb_clusters, ), dtype=numpy.int32, chunks=True)
            norms = hfile.create_dataset('norms', shape=(2 * total_nb_clusters, ), dtype=numpy.float32, chunks=True)
            amplitudes = hfile.create_dataset('limits', shape=(total_nb_clusters, 2), dtype=numpy.float32, chunks=True)
            count = 0
            for i in range(comm.size):
                # NOTE(review): this function never creates a 'templates'
                # dataset in the per-rank files, and ``loc_norms`` /
                # ``n_clusters`` are not assigned here -- verify.
                loc_temp = ts[i].get('templates')
                middle = loc_temp.shape[2] // 2
                norms[count:count + middle] = loc_norms[:middle]
                norms[n_clusters + count:n_clusters + count + middle] = loc_norms[middle:]
                electrodes[count:count + middle] = ts[i].get('electrodes')
                amplitudes[count:count + middle] = ts[i].get('limits')
                count += middle
                for j in range(i, N_e, comm.size):
                    io.write_datasets(cfile, to_write, rs[i], j, compress=hdf5_compress)
                ts[i].close()
                rs[i].close()
                os.remove(file_out_suff + '.templates-%d.hdf5' % i)
                os.remove(file_out_suff + '.clusters-%d.hdf5' % i)
            io.write_datasets(cfile, ['electrodes'], {'electrodes': electrodes[:]}, compress=hdf5_compress)
            hfile.close()
            cfile.close()

    if comm.rank == 0:
        # Append the gathered sparse template description to the final file.
        hfile = h5py.File(file_out_suff + '.templates.hdf5', 'r+', libver='earliest')
        hfile.create_dataset('temp_x', data=temp_x)
        hfile.create_dataset('temp_y', data=temp_y)
        hfile.create_dataset('temp_data', data=temp_data)
        hfile.create_dataset('temp_shape', data=numpy.array(
            [N_e, N_t, 2 * total_nb_clusters], dtype=numpy.int32))
        hfile.close()

    comm.Barrier()

    if comm.rank == 0:
        print_and_log(["Merging similar templates..."], 'default', logger)

    merged1 = algo.merging_cc(params, parallel_hdf5)

    comm.Barrier()
    # NOTE(review): ``remove_mixture`` is not assigned in this function.
    if remove_mixture:
        if comm.rank == 0:
            print_and_log(["Removing mixtures..."], 'default', logger)
        merged2 = algo.delete_mixtures(params, parallel_hdf5)
    else:
        merged2 = [0, 0]

    if comm.rank == 0:
        print_and_log([
            "Number of global merges : %d" % merged1[1],
            "Number of mixtures removed : %d" % merged2[1]
        ], 'info', logger)

    comm.Barrier()
    io.get_overlaps(params, erase=True, parallel_hdf5=parallel_hdf5)

    data_file.close()
def main(params, nb_cpu, nb_gpu, use_gpu):
    """Run the filtering step on the data file described by ``params``.

    Depending on the configuration this band-pass filters every chunk,
    subtracts the median taken over the channels, and removes averaged
    stimulation artefacts. Each completed operation is recorded in the
    ``noedits`` section of the parameter file by rank 0, and operations
    already flagged as done there are skipped.

    Parameters
    ----------
    params : project parameter object
        Gives access to the configuration sections, the sampling rate,
        the number of channels and the data file objects.
    nb_cpu, nb_gpu, use_gpu :
        Not used by this step.

    Notes
    -----
    Configuration errors are logged by rank 0 and then every rank leaves
    through ``sys.exit(0)``.
    """
    init_logging(params.logfile)
    logger = logging.getLogger('circus.filtering')
    #################################################################
    do_filter = params.getboolean('filtering', 'filter')
    filter_done = params.getboolean('noedits', 'filter_done')
    artefacts_done = params.getboolean('noedits', 'artefacts_done')
    median_done = params.getboolean('noedits', 'median_done')
    clean_artefact = params.getboolean('triggers', 'clean_artefact')
    remove_median = params.getboolean('filtering', 'remove_median')
    overwrite = params.getboolean('data', 'overwrite')
    nodes, edges = get_nodes_and_edges(params)
    #################################################################

    def filter_file(data_file_in, data_file_out, do_filtering, do_remove_median):
        """Filter and/or median-subtract every chunk handled by this rank.

        Chunks are read from ``data_file_in`` and written to
        ``data_file_out`` (which may be the same object).
        """
        try:
            # A single number is the low cut-off; the high one is then set
            # to 95% of the Nyquist frequency.
            cut_off = params.getfloat('filtering', 'cut_off')
            cut_off = [cut_off, 0.95 * (params.rate / 2.)]
        except Exception:
            # Otherwise expect "low, high" where high may be 'auto'.
            cut_off = params.get('filtering', 'cut_off')
            cut_off = cut_off.split(',')
            try:
                cut_off[0] = float(cut_off[0])
            except Exception:
                if comm.rank == 0:
                    print_and_log(
                        ['First value of cut off must be a valid number'],
                        'error', logger)
                sys.exit(0)

            cut_off[1] = cut_off[1].replace(' ', '')
            if cut_off[1] == 'auto':
                cut_off[1] = 0.95 * (params.rate / 2.)
            else:
                try:
                    cut_off[1] = float(cut_off[1])
                except Exception:
                    if comm.rank == 0:
                        print_and_log([
                            'Second value of cut off must either auto, or a valid a number'
                        ], 'error', logger)
                    sys.exit(0)

        chunk_size = params.getint('data', 'chunk_size')
        nb_chunks, _ = data_file_in.analyze(chunk_size)

        # Third order Butterworth band-pass, frequencies normalised by Nyquist.
        b, a = signal.butter(3, numpy.array(cut_off) / (params.rate / 2.), 'pass')
        N_total = params.nb_channels

        if comm.rank == 0:
            to_write = []
            if do_filtering:
                to_write += [
                    "Filtering the signal with a Butterworth filter in (%g, %g) Hz"
                    % (cut_off[0], cut_off[1])
                ]
            if do_remove_median:
                to_write += [
                    "Median over all channels is substracted to each channels"
                ]
            print_and_log(to_write, 'default', logger)

        # Chunks are distributed round-robin over the MPI ranks.
        to_explore = xrange(comm.rank, nb_chunks, comm.size)

        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(to_explore)

        for gidx in to_explore:

            local_chunk, t_offset = data_file_in.get_data(gidx, chunk_size)

            if do_filtering:
                for i in nodes:
                    try:
                        local_chunk[:, i] = signal.filtfilt(
                            b, a, local_chunk[:, i])
                    except Exception:
                        # Deliberate best effort: a chunk that cannot be
                        # filtered is left unfiltered.
                        pass
                    local_chunk[:, i] -= numpy.median(local_chunk[:, i])

            if do_remove_median:
                if not numpy.all(nodes == numpy.arange(N_total)):
                    # Only the channels listed in nodes enter the median.
                    global_median = numpy.median(
                        numpy.take(local_chunk, nodes, axis=1), 1)
                else:
                    global_median = numpy.median(local_chunk, 1)
                for i in nodes:
                    local_chunk[:, i] -= global_median

            if data_file_in != data_file_out and data_file_in.is_first_chunk(
                    gidx, nb_chunks):
                if data_file_in.is_stream:
                    g_offset = t_offset - numpy.sum(
                        data_file_in._times[:data_file_in._get_streams_index_by_time(t_offset) + 1])
                else:
                    g_offset = t_offset - data_file_in.t_start
            else:
                g_offset = t_offset

            data_file_out.set_data(g_offset, local_chunk)

        comm.Barrier()

    def load_triggers(data_file):
        """Load the trigger and window files as int64 arrays in timesteps.

        Returns ``(artefacts, windows, nb_stimuli)``. Both arrays have two
        columns; the code below uses column 0 as the stimulus label and
        column 1 as a time (artefacts) or a duration (windows). The
        conversion from ms is done on the raw values *before* the integer
        cast so that fractional ms values are not truncated first.
        """
        artefacts = numpy.loadtxt(params.get('triggers', 'trig_file'))
        windows = numpy.loadtxt(params.get('triggers', 'trig_windows'))

        # A file with a single row is loaded as a 1D array.
        if len(windows.shape) == 1:
            windows = windows.reshape(1, 2)

        if len(artefacts.shape) == 1:
            artefacts = artefacts.reshape(1, 2)

        if params.getboolean('triggers', 'trig_in_ms'):
            if comm.rank == 0:
                print_and_log(['Artefact times are read in ms'], 'debug',
                              logger)
            artefacts[:, 1] *= numpy.int64(data_file.sampling_rate * 1e-3)
            windows[:, 1] *= numpy.int64(data_file.sampling_rate * 1e-3)
        else:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in timesteps'],
                              'debug', logger)

        artefacts = artefacts.astype(numpy.int64)
        windows = windows.astype(numpy.int64)

        nb_stimuli = len(numpy.unique(artefacts[:, 0]))

        # There must be exactly one window per distinct stimulus label.
        if nb_stimuli != len(windows):
            if comm.rank == 0:
                print_and_log(['Error in the trigger files'], 'error', logger)
            sys.exit(0)

        return artefacts, windows, nb_stimuli

    def compute_artefacts(data_file):
        """Return a dict mapping each label handled by this rank to its
        averaged artefact, as computed by ``get_artefact``."""
        artefacts, windows, nb_stimuli = load_triggers(data_file)
        make_plots = params.get('triggers', 'make_plots')
        plot_path = os.path.join(params.get('data', 'data_file_noext'),
                                 'plots')

        all_labels = artefacts[:, 0].astype(numpy.int32)
        # int64: sample indices of long recordings do not fit in int32.
        all_times = artefacts[:, 1].astype(numpy.int64)

        # Keep only triggers whose longest window fits in the recording.
        mask = (all_times >= 0) & (all_times + numpy.max(windows[:, 1]) <
                                   data_file.t_stop)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        local_labels = numpy.unique(all_labels)[comm.rank::comm.size]

        if comm.rank == 0:
            to_write = [
                "Computing averaged artefacts from %d stimuli" % (nb_stimuli)
            ]
            print_and_log(to_write, 'default', logger)
            if not os.path.exists(plot_path):
                os.makedirs(plot_path)
            local_labels = get_tqdm_progressbar(local_labels)

        comm.Barrier()
        # First we need to get the average artefacts
        art_dict = {}
        for artefact in local_labels:
            indices = numpy.where(all_labels == artefact)[0].astype(
                numpy.int32)
            tmp = numpy.where(windows[:, 0] == artefact)[0]
            tau = numpy.int64(windows[tmp, 1])
            pspikes = all_times[indices]
            # At most 500 randomly chosen occurrences are averaged.
            times = numpy.sort(numpy.random.permutation(pspikes)[:500])
            if len(numpy.where(numpy.diff(times) < tau)[0]) > 0:
                if comm.rank == 0:
                    print_and_log([
                        'Stimulation times for artefact %d are too close!' %
                        artefact
                    ], 'error', logger)
                sys.exit(0)

            art_dict[artefact] = get_artefact(params, times, tau, nodes)
            if make_plots not in ['None', '']:
                save = [plot_path, '%d.%s' % (artefact, make_plots)]
                plot.view_artefact(art_dict[artefact], save=save)

        return art_dict

    def remove_artefacts(data_file, art_dict):
        """Subtract the averaged artefacts of ``art_dict`` from ``data_file``
        at every trigger time handled by this rank."""
        artefacts, windows, nb_stimuli = load_triggers(data_file)

        all_labels = artefacts[:, 0].astype(numpy.int32)
        all_times = artefacts[:, 1].astype(numpy.int64)
        # NOTE(review): labels are distributed here before the time mask,
        # whereas compute_artefacts distributes them after its own mask;
        # confirm both always assign the same labels to a given rank.
        local_labels = numpy.unique(all_labels)[comm.rank::comm.size]

        mask = numpy.in1d(all_labels, local_labels)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        mask = (all_times >= 0) & (all_times < data_file.t_stop)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        if comm.rank == 0:
            to_write = ["Removing artefacts from %d stimuli" % (nb_stimuli)]
            print_and_log(to_write, 'default', logger)
            all_times = get_tqdm_progressbar(all_times)

        comm.Barrier()

        for count, t_trigger in enumerate(all_times):

            label = all_labels[count]
            tmp = numpy.where(windows[:, 0] == label)[0][0]
            tau = numpy.int64(windows[tmp, 1])

            # Shorten the window when it would run past the end of the
            # data (the previous code referenced an undefined name here).
            if (data_file.t_stop - t_trigger) < tau:
                tau = data_file.t_stop - t_trigger

            local_chunk = data_file.get_snippet(t_trigger, tau)
            for idx, i in enumerate(nodes):
                local_chunk[:, i] -= art_dict[label][idx, :tau]
            data_file.set_data(t_trigger, local_chunk)

        comm.Barrier()

    if comm.rank == 0:
        print_and_log(['Initializing the filtering step...'], 'debug', logger)

    if overwrite:
        if comm.rank == 0:
            print_and_log(['Reading the input file...'], 'debug', logger)

        # In-place mode: input and output are the same object.
        data_file_in = params.get_data_file()
        data_file_out = data_file_in
    else:
        if comm.rank == 0:
            print_and_log(
                ['Overwrite is set to False, so creating a new datafile...'],
                'debug', logger)

        if comm.rank == 0:
            print_and_log(['Reading the input file...'], 'debug', logger)

        has_been_created = os.path.exists(
            params.get('data', 'data_file_no_overwrite'))

        if not has_been_created and (filter_done or median_done
                                     or artefacts_done):
            if comm.rank == 0:
                print_and_log([
                    'The filtering is done but file not present. See no_edits section'
                ], 'error', logger)
            sys.exit(0)

        # Read from the source only while the output file does not exist yet.
        data_file_in = params.get_data_file(
            source=not has_been_created, has_been_created=has_been_created)

        if comm.rank == 0:
            print_and_log(
                ['Reading the output file and allocating ressources...'],
                'debug', logger)

        description = data_file_in.get_description()
        description['data_dtype'] = 'float32'
        description['dtype_offset'] = 0
        description['data_offset'] = 0

        data_file_out = params.get_data_file(is_empty=not has_been_created,
                                             params=description)

        if not has_been_created:
            data_file_out.allocate(shape=data_file_in.shape)

        comm.Barrier()

    if clean_artefact:
        if not (os.path.exists(params.get('triggers', 'trig_file')) and
                os.path.exists(params.get('triggers', 'trig_windows'))):
            if comm.rank == 0:
                print_and_log(
                    ['trig_file or trig_windows file can not be found'],
                    'error', logger)
            sys.exit(0)

    to_write = []

    if do_filter and filter_done:
        do_filter = False
        to_write += ["Filtering has already been done"]
    if remove_median and median_done:
        remove_median = False
        to_write += [
            "Median over all channels has already been substracted to each channels"
        ]

    if comm.rank == 0 and len(to_write) > 0:
        print_and_log(to_write, 'info', logger)

    if overwrite:
        data_file_in.open(mode='r+')
    else:
        data_file_in.open()
        data_file_out.open(mode='r+')

    if do_filter or remove_median:
        filter_file(data_file_in, data_file_out, do_filter, remove_median)

    if comm.rank == 0:
        if do_filter:
            params.write('noedits', 'filter_done', 'True')
        if remove_median:
            params.write('noedits', 'median_done', 'True')

    if clean_artefact and artefacts_done:
        clean_artefact = False
        if comm.rank == 0:
            print_and_log(['Artefacts have already been removed'], 'debug',
                          logger)

    if clean_artefact:
        art_dict = compute_artefacts(data_file_in)
        remove_artefacts(data_file_out, art_dict)

    if comm.rank == 0:
        if clean_artefact:
            params.write('noedits', 'artefacts_done', 'True')

    data_file_in.close()
    if not overwrite:
        data_file_out.close()

    comm.Barrier()
def main(argv=None):
    """Group files spread over several folders into one folder of symlinks.

    The first positional argument is a text file listing one folder per
    line. Every file in those folders whose extension matches the second
    positional argument (and, with ``-d`` / ``-t``, every ``.dead`` /
    ``.trig`` file) is linked into the output folder as
    ``sc_<folder index>_<file name>``.

    Parameters
    ----------
    argv : list of str, optional
        Command line arguments; defaults to ``sys.argv[1:]``.

    Notes
    -----
    Exits with status 1 on errors and with status 0 when the user declines
    to erase an existing output folder.
    """
    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    header += '''Utility to group files within several folders into a single virtual folder, such that they can be processed together with the multi-files mode. If you want to also process .dead or .trig files in order to later on concatenate artefacts, please use the -d or -t options '''

    parser = argparse.ArgumentParser(description=header,
                                     formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('folders', help='text file with the list of folders to consider')
    parser.add_argument('extension', help='file extension to consider within folders')
    parser.add_argument('-o', '--output', help='name of the output folder [default is output]', default='output')
    parser.add_argument('-d', '--dead', help='Search for all .dead files', action='store_true')
    parser.add_argument('-t', '--trig', help='Search for all .trig files', action='store_true')

    if len(argv) == 0:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)

    folders_file = os.path.abspath(args.folders)
    output = os.path.abspath(args.output)
    # Accept the extension with or without its leading dot.
    extension = args.extension.lower().lstrip('.')

    filename, _ = os.path.splitext(os.path.basename(folders_file))

    init_logging(filename + '.log')
    logger = logging.getLogger(__name__)

    if not os.path.exists(folders_file):
        print_and_log(['The folder file %s does not exists!' % folders_file], 'error', logger)
        sys.exit(1)

    try:
        with open(folders_file, 'r') as myfile:
            # Blank lines are skipped: an empty path would otherwise be
            # resolved to the current working directory.
            folders = [os.path.abspath(line.strip())
                       for line in myfile if line.strip()]
    except Exception:
        print_and_log(['Check the syntax of the folder file'], 'error', logger)
        sys.exit(1)

    # Check every listed folder before the output folder is erased.
    for folder in folders:
        if not os.path.isdir(folder):
            print_and_log(['The folder %s listed in %s does not exist!' % (folder, folders_file)],
                          'error', logger)
            sys.exit(1)

    if os.path.exists(output):
        do_folders = query_yes_no(Fore.WHITE + "Folder %s already exists! Do you want to erase everything?" % output, default=None)
        if not do_folders:
            sys.exit(0)
        shutil.rmtree(output)

    os.makedirs(output)

    for count, folder in enumerate(folders):
        for entry in os.listdir(folder):
            _, ext = os.path.splitext(entry)
            ext = ext.strip('.').lower()
            wanted = (ext == extension) \
                or (args.dead and ext == 'dead') \
                or (args.trig and ext == 'trig')
            if not wanted:
                continue

            original_file = os.path.join(folder, entry)
            linked_file = os.path.join(
                output, 'sc_{c}_{f}'.format(c=count, f=os.path.basename(original_file)))
            # os.symlink fails on an existing path; lexists also catches a
            # dangling link, which os.path.exists reports as missing.
            if os.path.lexists(linked_file):
                os.remove(linked_file)
            os.symlink(original_file, linked_file)
def main(argv=None):
    """Launch the phy template GUI on previously exported results.

    The results are looked up in ``<file_out_suff><-extension>.GUI``; if
    that folder does not exist an error is logged and nothing is launched.

    Parameters
    ----------
    argv : list of str, optional
        Command line arguments; defaults to ``sys.argv[1:]``.

    Notes
    -----
    Exits with status 1 when the installed phy-contrib is older than
    1.0.12.
    """
    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    parser = argparse.ArgumentParser(
        description=header, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file')
    parser.add_argument('-e', '--extension',
                        help='extension to consider for visualization',
                        default='')

    if len(argv) == 0:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)

    filename = os.path.abspath(args.datafile)
    extension = args.extension
    params = CircusParser(filename)
    # Start from a fresh log file.
    if os.path.exists(params.logfile):
        os.remove(params.logfile)
    init_logging(params.logfile)
    logger = logging.getLogger(__name__)

    mytest = StrictVersion(phycontrib.__version__) >= StrictVersion("1.0.12")
    if not mytest:
        print_and_log(
            ['You need to update phy-contrib to the latest git version'],
            'error', logger)
        sys.exit(1)

    data_file = params.get_data_file()
    data_dtype = data_file.data_dtype
    # Not every data file object defines a data offset.
    data_offset = getattr(data_file, 'data_offset', 0)

    file_format = data_file.description
    file_out_suff = params.get('data', 'file_out_suff')

    if file_format not in supported_by_phy:
        print_and_log([
            "File format %s is not supported by phy. TraceView disabled" %
            file_format
        ], 'info', logger)

    # The gain is only reported; it is not passed on to the GUI below.
    if numpy.iterable(data_file.gain):
        print_and_log(
            ['Multiple gains are not supported, using a default value of 1'],
            'info', logger)
    elif data_file.gain != 1:
        # Format the data file's gain: the previous code formatted a local
        # variable that had not been assigned yet.
        print_and_log([
            "Gain of %g is not supported by phy. Expecting a scaling mismatch"
            % data_file.gain
        ], 'info', logger)

    if extension != '':
        extension = '-' + extension
    output_path = file_out_suff + extension + '.GUI'

    if not os.path.exists(output_path):
        print_and_log(
            ['Data should be first exported with the converting method!'],
            'error', logger)
    else:
        print_and_log(["Launching the phy GUI..."], 'info', logger)

        gui_params = {}
        if file_format in supported_by_phy:
            gui_params['dat_path'] = params.get('data', 'data_file')
        else:
            gui_params['dat_path'] = ''
        gui_params['n_channels_dat'] = params.nb_channels
        gui_params['n_features_per_channel'] = 5
        gui_params['dtype'] = data_dtype
        gui_params['offset'] = data_offset
        gui_params['sample_rate'] = params.rate
        gui_params['hp_filtered'] = True

        # The controller is created from inside the exported folder.
        os.chdir(output_path)
        create_app()
        controller = TemplateController(**gui_params)
        gui = controller.create_gui()

        gui.show()
        run_app()
        gui.close()
        del gui
def main(argv=None):
    """Concatenate per-file ``.dead`` and ``.trig`` files of a multi-files stream.

    For every source of the data file, the sibling ``.dead`` and ``.trig``
    files (same path, different extension) are read, shifted by the start
    time of that source and stacked. The results are written next to the
    data file as ``dead_zones.txt`` and ``triggers.txt``, in ms when the
    corresponding ``dead_in_ms`` / ``trig_in_ms`` option is set and in
    timesteps otherwise.

    Parameters
    ----------
    argv : list of str, optional
        Command line arguments; defaults to ``sys.argv[1:]``.

    Notes
    -----
    Exits with status 1 on errors (triggers outside a file, unsupported
    stream mode).
    """
    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    header += '''Utility to concatenate artefacts/dead times before using stream mode. Code will look for .dead and .trig files, and concatenate them automatically taking care of file offsets '''
    parser = argparse.ArgumentParser(
        description=header, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file')

    if len(argv) == 0:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)

    filename = os.path.abspath(args.datafile)
    params = CircusParser(filename)
    dead_in_ms = params.getboolean('triggers', 'dead_in_ms')
    trig_in_ms = params.getboolean('triggers', 'trig_in_ms')

    # Start from a fresh log file.
    if os.path.exists(params.logfile):
        os.remove(params.logfile)

    init_logging(params.logfile)
    logger = logging.getLogger(__name__)

    stream_mode = params.get('data', 'stream_mode')

    if stream_mode == 'multi-files':
        data_file = params.get_data_file(source=True, has_been_created=False)
        # One millisecond spans this many timesteps.
        steps_per_ms = data_file.sampling_rate * 1e-3

        all_times_dead = numpy.zeros((0, 2), dtype=numpy.int64)
        all_times_trig = numpy.zeros((0, 2), dtype=numpy.int64)

        for f in data_file._sources:
            # Swap only the trailing extension: str.replace would also
            # rewrite any earlier occurrence of it in the path, and would
            # mangle the whole name when the file has no extension.
            name, _ = os.path.splitext(f.file_name)
            dead_file = name + '.dead'
            trig_file = name + '.trig'

            if os.path.exists(dead_file):
                print_and_log(['Found file %s' % dead_file], 'default',
                              logger)
                times = get_dead_times(dead_file, data_file.sampling_rate,
                                       dead_in_ms)
                if times.max() > f.duration or times.min() < 0:
                    print_and_log([
                        'Dead zones larger than duration for file %s' %
                        f.file_name, '-> Clipping automatically'
                    ], 'error', logger)
                    times = numpy.minimum(times, f.duration)
                    times = numpy.maximum(times, 0)
                # Shift from file-local to stream-global times.
                times += f.t_start
                all_times_dead = numpy.vstack((all_times_dead, times))

            if os.path.exists(trig_file):
                print_and_log(['Found file %s' % trig_file], 'default',
                              logger)
                times = get_trig_times(trig_file, data_file.sampling_rate,
                                       trig_in_ms)
                if times[:, 1].max() > f.duration or times[:, 1].min() < 0:
                    print_and_log([
                        'Triggers larger than duration for file %s' %
                        f.file_name
                    ], 'error', logger)
                    sys.exit(1)
                # Only column 1 holds times.
                times[:, 1] += f.t_start
                all_times_trig = numpy.vstack((all_times_trig, times))

        if len(all_times_dead) > 0:
            output_file = os.path.join(os.path.dirname(filename),
                                       'dead_zones.txt')
            print_and_log(['Saving global artefact file in %s' % output_file],
                          'default', logger)
            if dead_in_ms:
                # Timesteps -> ms. float64 keeps large sample indices exact.
                all_times_dead = all_times_dead.astype(
                    numpy.float64) / steps_per_ms
            numpy.savetxt(output_file, all_times_dead)

        if len(all_times_trig) > 0:
            output_file = os.path.join(os.path.dirname(filename),
                                       'triggers.txt')
            print_and_log(['Saving global artefact file in %s' % output_file],
                          'default', logger)
            if trig_in_ms:
                # Convert the time column only; column 0 is left as is.
                all_times_trig = all_times_trig.astype(numpy.float64)
                all_times_trig[:, 1] /= steps_per_ms
            numpy.savetxt(output_file, all_times_trig)

    elif stream_mode == 'single-file':
        print_and_log(['Not implemented'], 'error', logger)
        sys.exit(1)
    else:
        print_and_log(
            ['You should select a valid stream_mode such as multi-files'],
            'error', logger)
        sys.exit(1)
def main(params, nb_cpu, nb_gpu, use_gpu): # Part 1: Whitening numpy.random.seed(420) logger = init_logging(params.logfile) logger = logging.getLogger('circus.whitening') ################################################################# data_file = params.data_file data_file.open() N_e = params.getint('data', 'N_e') N_total = params.nb_channels N_t = params.getint('detection', 'N_t') dist_peaks = params.getint('detection', 'dist_peaks') template_shift = params.getint('detection', 'template_shift') file_out_suff = params.get('data', 'file_out_suff') file_out = params.get('data', 'file_out') spike_thresh = params.getfloat('detection', 'spike_thresh') matched_filter = params.getboolean('detection', 'matched-filter') matched_thresh = params.getfloat('detection', 'matched_thresh') sign_peaks = params.get('detection', 'peaks') do_temporal_whitening = params.getboolean('whitening', 'temporal') do_spatial_whitening = params.getboolean('whitening', 'spatial') chunk_size = params.getint('whitening', 'chunk_size') plot_path = os.path.join(params.get('data', 'data_file_noext'), 'plots') nodes, edges = get_nodes_and_edges(params) safety_time = params.getint('whitening', 'safety_time') nb_temp_white = min(max(20, comm.size), N_e) max_silence_1 = int(20 * params.rate // comm.size) max_silence_2 = 5000 inv_nodes = numpy.zeros(N_total, dtype=numpy.int32) inv_nodes[nodes] = numpy.argsort(nodes) ################################################################# if comm.rank == 0: print_and_log( ["Analyzing data to get whitening matrices and thresholds..."], 'default', logger) if use_gpu: import cudamat as cmt ## Need to properly handle multi GPU per MPI nodes? 
if nb_gpu > nb_cpu: gpu_id = int(comm.rank // nb_cpu) else: gpu_id = 0 cmt.cuda_set_device(gpu_id) cmt.init() cmt.cuda_sync_threads() nb_chunks, last_chunk_len = data_file.analyze(chunk_size) if nb_chunks < comm.size: res = io.data_stats(params, show=False) chunk_size = int(res * params.rate // comm.size) if comm.rank == 0: print_and_log( ["Too much cores, automatically resizing the data chunks"], 'debug', logger) nb_chunks, last_chunk_len = data_file.analyze(chunk_size) # I guess this is more relevant, to take signals from all over the recordings all_chunks = numpy.random.permutation( numpy.arange(nb_chunks, dtype=numpy.int32)) all_electrodes = numpy.random.permutation(N_e) for gidx in [all_chunks[comm.rank]]: #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..." local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes) local_shape = len(local_chunk) #print "Node", comm.rank, "computes the median absolute deviations in a random chunk" thresholds = numpy.zeros(N_e, dtype=numpy.float32) for i in xrange(N_e): u = numpy.median(local_chunk[:, i], 0) thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u), 0) gdata = gather_array(thresholds, comm) if comm.rank == 0: gdata = gdata.reshape((comm.size, N_e)) thresholds = numpy.mean(gdata, 0) bfile = h5py.File(file_out_suff + '.basis.hdf5', 'w', libver='latest') io.write_datasets(bfile, ['thresholds'], {'thresholds': thresholds}) bfile.close() comm.Barrier() thresholds = io.load_data(params, 'thresholds') #print "Extracting the peaks..." local_peaktimes = numpy.zeros(0, dtype=numpy.int32) for i in xrange(N_e): peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]), thresholds[i], valley=False, mpd=dist_peaks) local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes)) local_peaktimes = numpy.unique(local_peaktimes) #print "Removing the useless borders..." 
local_borders = (template_shift, local_shape - template_shift) idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1]) local_peaktimes = numpy.compress(idx, local_peaktimes) if len(local_peaktimes) > 0: diff_times = local_peaktimes[-1] - local_peaktimes[0] all_times = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool) min_times = numpy.maximum( local_peaktimes - local_peaktimes[0] - safety_time, 0) max_times = numpy.minimum( local_peaktimes - local_peaktimes[0] + safety_time + 1, diff_times) argmax_peak = numpy.random.permutation( numpy.arange(len(local_peaktimes))) all_idx = numpy.take(local_peaktimes, argmax_peak) #print "Selection of the peaks with spatio-temporal masks..." for idx, peak in zip(argmax_peak, all_idx): elec = numpy.argmax(numpy.abs(local_chunk[peak])) indices = numpy.take(inv_nodes, edges[nodes[elec]]) all_times[indices, min_times[idx]:max_times[idx]] = True else: all_times = numpy.zeros((N_e, len(local_chunk)), dtype=numpy.bool) all_times_Ne = numpy.any(all_times, 0) subset = numpy.where(all_times_Ne == False)[0] all_silences = [] if do_spatial_whitening: local_silences = numpy.take(local_chunk, subset, axis=0)[:max_silence_1] all_silences = gather_array(local_silences, comm, 0, 1) local_res = [] if do_temporal_whitening: for elec in all_electrodes[numpy.arange(comm.rank, nb_temp_white, comm.size)]: res = numpy.zeros((0, N_t), dtype=numpy.float32) scount = 0 indices = numpy.take(inv_nodes, edges[nodes[elec]]) all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0), 0) esubset = numpy.where(all_times_elec == False)[0] bound = len(esubset) - N_t while (scount < bound) and (len(res) < max_silence_2): myslice = esubset[scount:scount + N_t] if numpy.all((myslice - esubset[scount]) == numpy.arange(N_t)): scount += N_t res = numpy.vstack((res, local_chunk[myslice, elec])) else: scount += 1 if len(res) > 5: local_res += [numpy.cov(res.T)] nb_elecs = numpy.array([len(local_res)], dtype=numpy.float32) local_res = 
numpy.array(local_res, dtype=numpy.float32) if len(local_res) == 0: local_res = numpy.zeros(0, dtype=numpy.float32) else: local_res = numpy.sum(local_res, 0) all_res = gather_array(local_res.ravel(), comm, 0, 1) all_elecs = gather_array(nb_elecs, comm, 0, 1) if comm.rank == 0: to_write = {} if do_temporal_whitening: try: nb_silences = numpy.sum(all_elecs > 0) all_res = all_res.reshape((nb_silences, N_t**2)) except Exception: print_and_log([ "No silent periods detected: something wrong with the parameters?" ], 'error', logger) all_res = numpy.sum(all_res, 0) all_res = all_res.reshape((N_t, N_t)) / numpy.sum(all_elecs) temporal_whitening = get_whitening_matrix( all_res.astype(numpy.double), fudge=1e-3)[template_shift].astype(numpy.float32) temporal_whitening /= temporal_whitening.sum() to_write['temporal'] = temporal_whitening have_nans = numpy.sum(numpy.isnan(temporal_whitening)) if have_nans > 0: temporal_whitening = numpy.zeros(N_t, dtype=numpy.float32) temporal_whitening[N_t // 2] = 1 to_write['temporal'] = temporal_whitening print_and_log( ["Disabling temporal whitening because of NaNs found"], 'info', logger) if do_spatial_whitening: if len(all_silences) / params.rate == 0: print_and_log([ "No silent periods detected: something wrong with the parameters?" ], 'error', logger) spatial_whitening = get_whitening_matrix( all_silences.astype(numpy.double)).astype(numpy.float32) to_write['spatial'] = spatial_whitening print_and_log([ "Found %gs without spikes for whitening matrices..." 
% (len(all_silences) / params.rate) ], 'default', logger) have_nans = numpy.sum(numpy.isnan(spatial_whitening)) if have_nans > 0: spatial_whitening = numpy.eye(spatial_whitening.shape[0], dtype=numpy.float32) to_write['spatial'] = spatial_whitening print_and_log( ["Disabling spatial whitening because of NaNs found"], 'info', logger) bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='latest') io.write_datasets(bfile, to_write.keys(), to_write) bfile.close() del all_silences comm.Barrier() if do_spatial_whitening or do_temporal_whitening: if comm.rank == 0: print_and_log( ["Because of whitening, need to recompute the thresholds..."], 'default', logger) if do_spatial_whitening: spatial_whitening = io.load_data(params, 'spatial_whitening') if use_gpu: spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False) if do_temporal_whitening: temporal_whitening = io.load_data(params, 'temporal_whitening') for gidx in [all_chunks[comm.rank]]: local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes) local_shape = len(local_chunk) if do_spatial_whitening: if use_gpu: local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False) local_chunk = local_chunk.dot(spatial_whitening).asarray() else: local_chunk = numpy.dot(local_chunk, spatial_whitening) if do_temporal_whitening: local_chunk = scipy.ndimage.filters.convolve1d( local_chunk, temporal_whitening, axis=0, mode='constant') thresholds = numpy.zeros(N_e, dtype=numpy.float32) for i in xrange(N_e): u = numpy.median(local_chunk[:, i], 0) thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u), 0) gdata = gather_array(thresholds, comm) if comm.rank == 0: gdata = gdata.reshape((comm.size, N_e)) thresholds = numpy.mean(gdata, 0) bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='latest') bfile.pop('thresholds') io.write_datasets(bfile, ['thresholds'], {'thresholds': thresholds}) bfile.close() comm.Barrier() #if comm.rank == 0: #if not os.path.exists(plot_path): # 
os.makedirs(plot_path) #N_elec = min(int(numpy.sqrt(data_file.N_e)), 5) #plot.view_fit(filename, t_start=0, t_stop=1, fit_on=False, square=True, # n_elec=N_elec, save=[plot_path, 'electrodes']) # Part 2: Basis numpy.random.seed(422) ################################################################# file_out = params.get('data', 'file_out') alignment = params.getboolean('detection', 'alignment') spike_thresh = params.getfloat('detection', 'spike_thresh') nodes, edges = get_nodes_and_edges(params) do_temporal_whitening = params.getboolean('whitening', 'temporal') do_spatial_whitening = params.getboolean('whitening', 'spatial') chunk_size = params.getint('data', 'chunk_size') safety_time = params.getint('whitening', 'safety_time') max_elts_elec = params.getint('whitening', 'max_elts') nb_elts = int( params.getfloat('whitening', 'nb_elts') * N_e * max_elts_elec) output_dim = params.getfloat('whitening', 'output_dim') inv_nodes = numpy.zeros(N_total, dtype=numpy.int32) inv_nodes[nodes] = numpy.argsort(nodes) if sign_peaks == 'both': max_elts_elec *= 2 ################################################################# if comm.rank == 0: print_and_log(["Searching spikes to construct the PCA basis..."], 'default', logger) nb_chunks, last_chunk_len = data_file.analyze(chunk_size) if nb_chunks < comm.size: res = io.data_stats(params, show=False) chunk_size = int(res * params.rate // comm.size) if comm.rank == 0: print_and_log( ["Too much cores, automatically resizing the data chunks"], 'debug', logger) nb_chunks, last_chunk_len = data_file.analyze(chunk_size) groups = {} for i in xrange(N_e): groups[i] = 0 # I guess this is more relevant, to take signals from all over the recordings all_chunks = numpy.random.permutation( numpy.arange(nb_chunks, dtype=numpy.int32)) max_elts_elec //= comm.size nb_elts //= comm.size elt_count_pos = 0 elt_count_neg = 0 if sign_peaks in ['positive', 'both']: elts_pos = numpy.zeros((N_t, nb_elts), dtype=numpy.float32) if sign_peaks in ['negative', 
'both']: elts_neg = numpy.zeros((N_t, nb_elts), dtype=numpy.float32) chunks_to_load = all_chunks[comm.rank::comm.size] thresholds = io.load_data(params, 'thresholds') if comm.rank == 0: pbar = get_progressbar(nb_elts) if alignment: cdata = numpy.linspace(-template_shift, template_shift, 5 * N_t) xdata = numpy.arange(-2 * template_shift, 2 * template_shift + 1) for gcount, gidx in enumerate(chunks_to_load): if ((elt_count_pos + elt_count_neg) < nb_elts): #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..." local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes) local_shape = len(local_chunk) if do_spatial_whitening: if use_gpu: local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False) local_chunk = local_chunk.dot(spatial_whitening).asarray() else: local_chunk = numpy.dot(local_chunk, spatial_whitening) if do_temporal_whitening: local_chunk = scipy.ndimage.filters.convolve1d( local_chunk, temporal_whitening, axis=0, mode='constant') #print "Extracting the peaks..." all_peaktimes = numpy.zeros(0, dtype=numpy.int32) all_extremas = numpy.zeros(0, dtype=numpy.int32) for i in xrange(N_e): if sign_peaks == 'negative': peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=True, mpd=dist_peaks) elif sign_peaks == 'positive': peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=False, mpd=dist_peaks) elif sign_peaks == 'both': peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]), thresholds[i], valley=False, mpd=dist_peaks) all_peaktimes = numpy.concatenate((all_peaktimes, peaktimes)) all_extremas = numpy.concatenate( (all_extremas, i * numpy.ones(len(peaktimes), dtype=numpy.int32))) #print "Removing the useless borders..." 
if alignment: local_borders = (2 * template_shift, local_shape - 2 * template_shift) else: local_borders = (template_shift, local_shape - template_shift) idx = (all_peaktimes >= local_borders[0]) & (all_peaktimes < local_borders[1]) all_peaktimes = numpy.compress(idx, all_peaktimes) all_extremas = numpy.compress(idx, all_extremas) local_peaktimes = numpy.unique(all_peaktimes) if len(local_peaktimes) > 0: diff_times = local_peaktimes[-1] - local_peaktimes[0] all_times = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool) min_times = numpy.maximum( local_peaktimes - local_peaktimes[0] - safety_time, 0) max_times = numpy.minimum( local_peaktimes - local_peaktimes[0] + safety_time + 1, diff_times) n_times = len(local_peaktimes) argmax_peak = numpy.random.permutation(numpy.arange(n_times)) all_idx = numpy.take(local_peaktimes, argmax_peak) #print "Selection of the peaks with spatio-temporal masks..." for midx, peak in zip(argmax_peak, all_idx): if (elt_count_neg + elt_count_pos) == nb_elts: break if sign_peaks == 'negative': elec = numpy.argmin(local_chunk[peak]) negative_peak = True elif sign_peaks == 'positive': elec = numpy.argmax(local_chunk[peak]) negative_peak = False elif sign_peaks == 'both': if numpy.abs(numpy.max(local_chunk[peak])) > numpy.abs( numpy.min(local_chunk[peak])): elec = numpy.argmax(local_chunk[peak]) negative_peak = False else: elec = numpy.argmin(local_chunk[peak]) negative_peak = True indices = numpy.take(inv_nodes, edges[nodes[elec]]) myslice = all_times[indices, min_times[midx]:max_times[midx]] is_local_extrema = elec in all_extremas[all_peaktimes == peak] if is_local_extrema and not myslice.any(): upper_bounds = max_elts_elec if groups[elec] < upper_bounds: if negative_peak: elts_neg[:, elt_count_neg] = local_chunk[ peak - template_shift:peak + template_shift + 1, elec] else: elts_pos[:, elt_count_pos] = local_chunk[ peak - template_shift:peak + template_shift + 1, elec] if alignment: ydata = local_chunk[peak - 2 * template_shift:peak + 2 * 
template_shift + 1, elec] f = scipy.interpolate.UnivariateSpline(xdata, ydata, s=0) if negative_peak: rmin = (numpy.argmin(f(cdata)) - len(cdata) / 2.) / 5. else: rmin = (numpy.argmax(f(cdata)) - len(cdata) / 2.) / 5. ddata = numpy.linspace(rmin - template_shift, rmin + template_shift, N_t) if negative_peak: elts_neg[:, elt_count_neg] = f(ddata).astype( numpy.float32) else: elts_pos[:, elt_count_pos] = f(ddata).astype( numpy.float32) if negative_peak: elt_count_neg += 1 else: elt_count_pos += 1 groups[elec] += 1 all_times[indices, min_times[midx]:max_times[midx]] = True if comm.rank == 0: pbar.update(elt_count_pos + elt_count_neg) if comm.rank == 0: if (elt_count_pos + elt_count_neg < (gcount + 1) * max_elts_elec // len(chunks_to_load)): pbar.update( (gcount + 1) * max_elts_elec // len(chunks_to_load)) if comm.rank == 0: pbar.finish() print_and_log([ "Node %d has collected %d waveforms" % (comm.rank, elt_count_pos + elt_count_neg) ], 'debug', logger) if sign_peaks in ['negative', 'both']: gdata_neg = gather_array(elts_neg[:, :elt_count_neg].T, comm, 0, 1) if sign_peaks in ['positive', 'both']: gdata_pos = gather_array(elts_pos[:, :elt_count_pos].T, comm, 0, 1) if comm.rank == 0: #DO PCA on elts and store the basis obtained. 
nb_waveforms = 0 if sign_peaks in ['negative', 'both']: nb_waveforms += gdata_neg.shape[0] if sign_peaks in ['positive', 'both']: nb_waveforms += gdata_pos.shape[0] print_and_log([ "Found %d waveforms over %d requested" % (nb_waveforms, int(nb_elts * comm.size)) ], 'default', logger) pca = PCA(output_dim, copy=False) res = {} if sign_peaks in ['negative', 'both']: if len(gdata_neg) > 0: res_pca = pca.fit_transform(gdata_neg.astype( numpy.double)).astype(numpy.float32) res['proj'] = pca.components_.T.astype(numpy.float32) else: res['proj'] = numpy.identity(N_t, dtype=numpy.float32) res['rec'] = res['proj'].T res['waveform'] = numpy.median(gdata_neg, 0) idx = numpy.random.permutation(numpy.arange( gdata_neg.shape[0]))[:1000] res['waveforms'] = gdata_neg[idx, :] if sign_peaks in ['positive', 'both']: if len(gdata_pos) > 0: res_pca = pca.fit_transform(gdata_pos.astype( numpy.double)).astype(numpy.float32) res['proj_pos'] = pca.components_.T.astype(numpy.float32) else: res['proj_pos'] = numpy.identity(N_t, dtype=numpy.float32) res['rec_pos'] = res['proj_pos'].T res['waveform_pos'] = numpy.median(gdata_pos, 0) idx = numpy.random.permutation(numpy.arange( gdata_pos.shape[0]))[:1000] res['waveforms_pos'] = gdata_pos[idx, :] bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='latest') io.write_datasets(bfile, res.keys(), res) if sign_peaks == 'positive': print_and_log([ "A basis with %s dimensions has been built" % res['proj_pos'].shape[1] ], 'info', logger) elif sign_peaks == 'negative': print_and_log([ "A basis with %s dimensions has been built" % res['proj'].shape[1] ], 'info', logger) elif sign_peaks == 'both': print_and_log([ "Two basis with %s dimensions has been built" % res['proj'].shape[1] ], 'info', logger) bfile.close() comm.Barrier() if matched_filter: if comm.rank == 0: print_and_log([ "Because of matched filters, need to recompute the thresholds..." 
], 'default', logger) if do_spatial_whitening: spatial_whitening = io.load_data(params, 'spatial_whitening') if use_gpu: spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False) if do_temporal_whitening: temporal_whitening = io.load_data(params, 'temporal_whitening') if sign_peaks in ['negative', 'both']: waveform_neg = io.load_data(params, 'waveform') waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg)) if sign_peaks in ['positive', 'both']: waveform_pos = io.load_data(params, 'waveform-pos') waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos)) for gidx in [all_chunks[comm.rank]]: local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes) local_shape = len(local_chunk) if do_spatial_whitening: if use_gpu: local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False) local_chunk = local_chunk.dot(spatial_whitening).asarray() else: local_chunk = numpy.dot(local_chunk, spatial_whitening) if do_temporal_whitening: local_chunk = scipy.ndimage.filters.convolve1d( local_chunk, temporal_whitening, axis=0, mode='constant') if sign_peaks in ['negative', 'both']: tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant') thresholds = numpy.zeros(N_e, dtype=numpy.float32) for i in xrange(N_e): u = numpy.median(tmp_chunk[:, i], 0) thresholds[i] = numpy.median( numpy.abs(tmp_chunk[:, i] - u), 0) gdata = gather_array(thresholds, comm) if comm.rank == 0: gdata = gdata.reshape((comm.size, N_e)) thresholds = numpy.mean(gdata, 0) bfile = h5py.File(file_out + '.basis.hdf5', 'r+', libver='latest') io.write_datasets(bfile, ['matched_thresholds'], {'matched_thresholds': thresholds}) bfile.close() comm.Barrier() if sign_peaks in ['positive', 'both']: tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant') thresholds = numpy.zeros(N_e, dtype=numpy.float32) for i in xrange(N_e): u = numpy.median(tmp_chunk[:, i], 0) thresholds[i] = 
numpy.median( numpy.abs(tmp_chunk[:, i] - u), 0) gdata = gather_array(thresholds, comm) if comm.rank == 0: gdata = gdata.reshape((comm.size, N_e)) thresholds = numpy.mean(gdata, 0) bfile = h5py.File(file_out + '.basis.hdf5', 'r+', libver='latest') io.write_datasets(bfile, ['matched_thresholds_pos'], {'matched_thresholds_pos': thresholds}) bfile.close() comm.Barrier() data_file.close()
def main(argv=None):
    """
    Command-line entry point that drives the requested processing steps.

    Parses the command line, makes sure a ``.params`` file sits next to the
    data file (offering to create one otherwise), optionally builds a short
    preview recording, then runs every requested step either in-process
    (``circus.launch``) or through ``mpirun``. In batch mode, the data file
    is read as a list of ``spyking-circus`` command lines instead.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments. Defaults to ``sys.argv[1:]``.
    """

    if argv is None:
        argv = sys.argv[1:]

    parallel_hdf5 = h5py.get_config().mpi
    user_path = pjoin(os.path.expanduser('~'), 'spyking-circus')
    tasks_list = None

    if not os.path.exists(user_path):
        os.makedirs(user_path)

    # The outcome is only recorded; nothing below reads HAVE_CUDA.
    try:
        import cudamat as cmt
        cmt.init()
        HAVE_CUDA = True
    except Exception:
        HAVE_CUDA = False

    all_steps = [
        'whitening', 'clustering', 'fitting', 'gathering', 'extracting', 'filtering',
        'converting', 'deconverting', 'benchmarking', 'merging', 'validating', 'thresholding'
    ]

    config_file = os.path.abspath(pkg_resources.resource_filename('circus', 'config.params'))

    header = get_colored_header()
    header += Fore.GREEN + 'Local CPUs : ' + Fore.CYAN + str(psutil.cpu_count()) + '\n'
    header += Fore.GREEN + 'Parallel HDF5 : ' + Fore.CYAN + str(parallel_hdf5) + '\n'

    do_upgrade = ''
    if not SHARED_MEMORY:
        do_upgrade = Fore.WHITE + ' [please consider upgrading MPI]'

    header += Fore.GREEN + 'Shared memory : ' + Fore.CYAN + str(SHARED_MEMORY) + do_upgrade + '\n'
    header += '\n'
    header += Fore.GREEN + "##################################################################"
    header += Fore.RESET

    method_help = '''by default, all steps are performed,
but a subset x,y can be done. Steps are:
 - filtering
 - whitening
 - clustering
 - fitting
 - merging [with or without a GUI for meta merging]
 - (extra) converting [export results to phy format]
 - (extra) thresholding [to get MUA activity only]
 - (extra) deconverting [import results from phy format]
 - (extra) gathering [force collection of results]
 - (extra) extracting [get templates from spike times]
 - (extra) benchmarking [with -o and -t]
 - (extra) validating [to compare performance with GT neurons]'''

    parser = argparse.ArgumentParser(description=header, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file (or a list of commands if batch mode)')
    parser.add_argument('-i', '--info', help='list the file formats supported by SpyKING CIRCUS',
                        action='store_true')
    parser.add_argument('-m', '--method', default='filtering,whitening,clustering,fitting,merging',
                        help=method_help)
    parser.add_argument('-c', '--cpu', type=int, default=max(1, int(psutil.cpu_count()/2)),
                        help='number of CPU')
    parser.add_argument('-H', '--hostfile', help='hostfile for MPI',
                        default=pjoin(user_path, 'circus.hosts'))
    parser.add_argument('-b', '--batch', help='datafile is a list of commands to launch, in a batch mode',
                        action='store_true')
    parser.add_argument('-p', '--preview', help='GUI to display the first second filtered with thresholds',
                        action='store_true')
    parser.add_argument('-r', '--result', help='GUI to display the results on top of raw data',
                        action='store_true')
    parser.add_argument('-s', '--second', type=int, default=0,
                        help='If preview mode, begining of the preview [in s]')
    parser.add_argument('-e', '--extension',
                        help='extension to consider for merging, converting and deconverting',
                        default='None')
    parser.add_argument('-o', '--output', help='output file [for generation of synthetic benchmarks]')
    parser.add_argument('-t', '--type', help='benchmark type',
                        choices=['fitting', 'clustering', 'synchrony'])

    if len(argv) == 0:
        parser.print_help()
        sys.exit(0)

    args = parser.parse_args(argv)

    steps = args.method.split(',')
    for step in steps:
        if step not in all_steps:
            print_error(['The method "%s" is not recognized' % step])
            sys.exit(0)

    # To save some typing later (GPU support is disabled: nb_gpu is always 0).
    nb_gpu = 0
    (nb_cpu, hostfile, batch, preview, result, extension, output, benchmark, info, second) = \
        (args.cpu, args.hostfile, args.batch, args.preview, args.result, args.extension,
         args.output, args.type, args.info, args.second)
    filename = os.path.abspath(args.datafile)
    real_file = filename

    f_next, extens = os.path.splitext(filename)

    if info:
        # In info mode the positional argument is a file format name, not a path.
        if args.datafile.lower() in __supported_data_files__:
            filename = 'tmp'
            if len(__supported_data_files__[args.datafile.lower()].extension) > 0:
                filename += __supported_data_files__[args.datafile.lower()].extension[0]

            __supported_data_files__[args.datafile.lower()](filename, {}, is_empty=True)._display_requirements_()
        else:
            print_and_log([
                '', 'To get info on any particular file format, do:',
                '>> spyking-circus file_format -i', ''
            ], 'default')
            print_and_log(list_all_file_format())
        sys.exit(0)

    if extens == '.params':
        print_error(['You should launch the code on the data file!'])
        sys.exit(0)

    file_params = f_next + '.params'
    if not os.path.exists(file_params) and not batch:
        print(Fore.RED + 'The parameter file %s is not present!' % file_params)
        create_params = query_yes_no(Fore.WHITE + "Do you want SpyKING CIRCUS to create a parameter file?")

        if create_params:
            print(Fore.WHITE + "Creating %s" % file_params)
            print(Fore.WHITE + "Fill it properly before launching the code! (see documentation)")
            print_info(['Keep in mind that filtering is performed on site, so please',
                        'be sure to keep a copy of your data elsewhere'])
            shutil.copyfile(config_file, file_params)
        # Either way the run stops here: the user must fill in the parameters first.
        sys.exit(0)
    elif batch:
        tasks_list = filename

    if not batch:
        file_params = f_next + '.params'

        if not os.path.exists(file_params):
            print_and_log(["%s does not exist" % file_params], 'error')
            sys.exit(0)

        # The module is named ``configparser`` on Python 3 and ``ConfigParser`` on Python 2.
        try:
            import configparser
        except ImportError:
            import ConfigParser as configparser

        parser = configparser.ConfigParser()

        # Strip tabs from the parameter file in place before parsing it.
        with open(file_params, 'r') as myfile:
            lines = myfile.readlines()
        with open(file_params, 'w') as myfile:
            for raw_line in lines:
                myfile.write(raw_line.replace('\t', ''))

        parser.read(file_params)

        for section in CircusParser.__all_sections__:
            if parser.has_section(section):
                for (key, value) in parser.items(section):
                    # Drop trailing "# ..." comments from the values.
                    parser.set(section, key, value.split('#')[0].rstrip())
            else:
                parser.add_section(section)

        try:
            use_output_dir = parser.get('data', 'output_dir') != ''
        except Exception:
            use_output_dir = False

        if use_output_dir:
            path = os.path.abspath(os.path.expanduser(parser.get('data', 'output_dir')))
            file_out = os.path.join(path, os.path.basename(f_next))
            if not os.path.exists(file_out):
                os.makedirs(file_out)
        else:
            file_out = f_next

        logfile = file_out + '.log'
        if os.path.exists(logfile):
            os.remove(logfile)

        logger = init_logging(logfile)
        params = CircusParser(filename)
        data_file = params.get_data_file(source=True, has_been_created=False)
        overwrite = params.getboolean('data', 'overwrite')
        file_format = params.get('data', 'file_format')
        if overwrite:
            support_parallel_write = data_file.parallel_write
            is_writable = data_file.is_writable
        else:
            # Without overwrite, the write capabilities of 'raw_binary' are used instead.
            support_parallel_write = __supported_data_files__['raw_binary'].parallel_write
            is_writable = __supported_data_files__['raw_binary'].is_writable

    if preview:
        print_and_log(['Preview mode, showing only seconds [%d-%d] of the recording' % (second, second+1)],
                      'info', logger)

        tmp_path_loc = os.path.join(os.path.abspath(params.get('data', 'file_out')), 'tmp')

        if not os.path.exists(tmp_path_loc):
            os.makedirs(tmp_path_loc)

        filename = os.path.join(tmp_path_loc, 'preview.dat')
        f_next, extens = os.path.splitext(filename)
        preview_params = f_next + '.params'

        shutil.copyfile(file_params, preview_params)
        steps = ['filtering', 'whitening']

        chunk_size = int(params.rate)

        data_file.open()
        nb_chunks, _ = data_file.analyze(chunk_size)

        if nb_chunks <= (second + 1):
            print_and_log(['Recording is too short to display seconds [%d-%d]' % (second, second+1)])
            sys.exit(0)

        local_chunk = data_file.get_snippet(int(second*params.rate), int(1.2*chunk_size))
        description = data_file.get_description()
        data_file.close()

        new_params = CircusParser(filename, create_folders=False)

        new_params.write('data', 'chunk_size', '1')
        new_params.write('data', 'file_format', 'raw_binary')
        new_params.write('data', 'data_dtype', 'float32')
        new_params.write('data', 'data_offset', '0')
        new_params.write('data', 'dtype_offset', '0')
        new_params.write('data', 'stream_mode', 'None')
        new_params.write('data', 'overwrite', 'True')
        new_params.write('triggers', 'ignore_times', 'False')
        new_params.write('data', 'sampling_rate', str(params.rate))
        new_params.write('whitening', 'safety_time', '0')
        new_params.write('clustering', 'safety_time', '0')
        new_params.write('whitening', 'chunk_size', '1')
        new_params.write('data', 'preview_path', params.file_params)
        new_params.write('data', 'output_dir', '')

        description['data_dtype'] = 'float32'
        description['dtype_offset'] = 0
        description['data_offset'] = 0
        description['gain'] = 1.

        # Re-read the parameters now that the preview settings have been written.
        new_params = CircusParser(filename)
        data_file_out = new_params.get_data_file(is_empty=True, params=description)

        support_parallel_write = data_file_out.parallel_write
        is_writable = data_file_out.is_writable

        data_file_out.allocate(shape=local_chunk.shape, data_dtype=numpy.float32)
        data_file_out.open('r+')
        data_file_out.set_data(0, local_chunk)
        data_file_out.close()

    if tasks_list is not None:
        with open(tasks_list, 'r') as f:
            for line in f:
                # Splitting on any whitespace skips blank lines and never
                # produces empty arguments.
                command_args = line.split()
                if command_args:
                    subprocess.check_call(['spyking-circus'] + command_args)
    else:
        print_and_log(['Config file: %s' % (f_next + '.params')], 'debug', logger)
        print_and_log(['Data file : %s' % filename], 'debug', logger)

        print(get_colored_header())
        print(Fore.GREEN + "File : " + Fore.CYAN + real_file)
        if preview:
            print(Fore.GREEN + "Steps : " + Fore.CYAN + "preview mode")
        elif result:
            print(Fore.GREEN + "Steps : " + Fore.CYAN + "result mode")
        else:
            print(Fore.GREEN + "Steps : " + Fore.CYAN + ", ".join(steps))
        print(Fore.GREEN + "Number of CPU : " + Fore.CYAN + str(nb_cpu) + "/" + str(psutil.cpu_count()))
        print(Fore.GREEN + "Parallel HDF5 : " + Fore.CYAN + str(parallel_hdf5))

        do_upgrade = ''
        use_shared_memory = get_shared_memory_flag(params)

        if not SHARED_MEMORY:
            do_upgrade = Fore.WHITE + ' [please consider upgrading MPI]'

        print(Fore.GREEN + "Shared memory : " + Fore.CYAN + str(use_shared_memory) + do_upgrade)
        print(Fore.GREEN + "Hostfile : " + Fore.CYAN + hostfile)
        print("")
        print(Fore.GREEN + "##################################################################")
        print("")
        print(Fore.RESET)

        # Launch the subtasks
        subtasks = [('filtering', 'mpirun'),
                    ('whitening', 'mpirun'),
                    ('clustering', 'mpirun'),
                    ('fitting', 'mpirun'),
                    ('extracting', 'mpirun'),
                    ('gathering', 'python'),
                    ('converting', 'mpirun'),
                    ('deconverting', 'mpirun'),
                    ('benchmarking', 'mpirun'),
                    ('merging', 'mpirun'),
                    ('validating', 'mpirun'),
                    ('thresholding', 'mpirun')]

        # GPU support is disabled; the flag is forwarded to the subtasks as a string.
        use_gpu = 'False'

        rec_time = data_stats(params) / 60.0

        if preview:
            params = new_params

        if nb_cpu < psutil.cpu_count():
            if use_gpu != 'True' and not result:
                print_and_log(['Using only %d out of %d local CPUs available (-c to change)'
                               % (nb_cpu, psutil.cpu_count())], 'info', logger)

        if params.getboolean('detection', 'matched-filter') and not params.getboolean('clustering', 'smart_search'):
            print_and_log(['Smart Search should be activated for matched filtering'], 'info', logger)

        if rec_time > 30 and not params.getboolean('clustering', 'smart_search'):
            print_and_log(['Smart Search should be activated for long recordings'], 'info', logger)

        n_edges = get_averaged_n_edges(params)
        if n_edges > 100 and not params.getboolean('clustering', 'compress'):
            print_and_log(['Template compression is highly recommended based on parameters'], 'info', logger)

        if not result:
            for subtask, command in subtasks:
                if subtask not in steps:
                    continue

                if command == 'python':
                    # Directly call the launcher
                    try:
                        circus.launch(subtask, filename, nb_cpu, nb_gpu, use_gpu)
                    except Exception:
                        print_and_log(['Step "%s" failed!' % subtask], 'error', logger)
                        sys.exit(0)
                elif command == 'mpirun':
                    # Use mpirun to make the call
                    mpi_args = gather_mpi_arguments(hostfile, params)
                    one_cpu = False

                    if subtask in ['filtering', 'benchmarking'] and not is_writable:
                        if not preview and overwrite:
                            print_and_log(['The file format %s is read only!' % file_format,
                                           'You should set overwite to False, to create a copy of the data.',
                                           'However, note that if you have streams, informations on times',
                                           'will be discarded'], 'info', logger)
                            sys.exit(0)

                    if subtask in ['filtering'] and not support_parallel_write and (args.cpu > 1):
                        print_and_log(['No parallel writes for %s: only 1 node used for %s'
                                       % (file_format, subtask)], 'info', logger)
                        nb_tasks = str(1)
                        one_cpu = True
                    else:
                        nb_tasks = str(args.cpu)

                    if subtask == 'benchmarking':
                        if (output is None) or (benchmark is None):
                            print_and_log(["To generate synthetic datasets, you must provide output and type"],
                                          'error', logger)
                            sys.exit(0)
                        mpi_args += [
                            '-np', nb_tasks, 'spyking-circus-subtask', subtask, filename,
                            str(nb_cpu), str(nb_gpu), use_gpu, output, benchmark
                        ]
                    elif subtask in ['merging', 'converting']:
                        mpi_args += [
                            '-np', nb_tasks, 'spyking-circus-subtask', subtask, filename,
                            str(nb_cpu), str(nb_gpu), use_gpu, extension
                        ]
                    elif subtask in ['deconverting']:
                        # Deconverting always runs as a single task.
                        nb_tasks = str(1)
                        nb_cpu = 1
                        mpi_args += [
                            '-np', nb_tasks, 'spyking-circus-subtask', subtask, filename,
                            str(nb_cpu), str(nb_gpu), use_gpu, extension
                        ]
                    else:
                        mpi_args += [
                            '-np', nb_tasks, 'spyking-circus-subtask', subtask, filename,
                            str(nb_cpu), str(nb_gpu), use_gpu, str(one_cpu)
                        ]

                    print_and_log(['Launching task %s' % subtask], 'debug', logger)
                    print_and_log(['Command: %s' % str(mpi_args)], 'debug', logger)

                    try:
                        subprocess.check_call(mpi_args)
                    except subprocess.CalledProcessError as e:
                        print_and_log(['Step "%s" failed for reason %s!' % (subtask, e)], 'error', logger)
                        sys.exit(0)

    if preview or result:
        from circus.shared import gui
        import pylab
        try:
            from PyQt5.QtWidgets import QApplication
        except ImportError:
            from matplotlib.backends import qt_compat
            use_pyside = qt_compat.QT_API == qt_compat.QT_API_PYSIDE
            if use_pyside:
                from PySide.QtGui import QApplication
            else:
                from PyQt4.QtGui import QApplication
        app = QApplication([])
        try:
            pylab.style.use('ggplot')
        except Exception:
            # The plotting style is purely cosmetic.
            pass

        if preview:
            print_and_log(['Launching the preview GUI...'], 'debug', logger)
            mygui = gui.PreviewGUI(new_params)
            shutil.rmtree(tmp_path_loc)
        elif result:
            data_file = params.get_data_file()
            print_and_log(['Launching the result GUI...'], 'debug', logger)
            mygui = gui.PreviewGUI(params, show_fit=True)
        sys.exit(app.exec_())
def main(params, nb_cpu, nb_gpu, use_gpu):
    """
    Run the filtering step on the recording described by ``params``.

    Depending on the parameters, the data are filtered with a Butterworth
    filter, the median and/or common ground is subtracted, saturated samples
    are recorded, and averaged stimulation artefacts are removed. Progress is
    recorded in the ``noedits`` section so that completed work is not redone.

    Parameters
    ----------
    params : CircusParser
        The parameters of the run.
    nb_cpu, nb_gpu, use_gpu :
        Not used by this step.
    """

    logger = init_logging(params.logfile)
    logger = logging.getLogger('circus.filtering')
    #################################################################
    do_filter = params.getboolean('filtering', 'filter')
    filter_done = check_if_done(params, 'filter_done', logger)
    artefacts_done = check_if_done(params, 'artefacts_done', logger)
    median_done = check_if_done(params, 'median_done', logger)
    ground_done = check_if_done(params, 'ground_done', logger)
    file_out_suff = params.get('data', 'file_out_suff')
    clean_artefact = params.getboolean('triggers', 'clean_artefact')
    remove_median = params.getboolean('filtering', 'remove_median')
    sat_value = params.get('filtering', 'sat_value')
    sat_threshold = params.get('filtering', 'sat_threshold')
    if sat_value != '':
        flag_saturation = True
        sat_value = float(sat_value)
    else:
        flag_saturation = False

    common_ground = params.common_ground
    remove_ground = len(common_ground) > 0
    nodes, edges = get_nodes_and_edges(params)
    N_total = params.nb_channels
    # Maps a channel index to its position in ``nodes``.
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    #################################################################

    def filter_file(data_file_in, data_file_out, do_filtering, do_remove_median, do_remove_ground):
        """
        Performs a high-pass and low-pass Butterworth filter on the data file.

        Parameters
        ----------
        data_file_in :
            The data file that is read.
        data_file_out :
            The data file that is written (may be the same object as
            ``data_file_in``).
        do_filtering : bool
            Whether the Butterworth filter is applied.
        do_remove_median : bool
            Whether the median over channels is subtracted.
        do_remove_ground : bool
            Whether the common ground channel(s) are subtracted.
        """

        try:
            cut_off = params.getfloat('filtering', 'cut_off', check=False)
            cut_off = [cut_off, 0.95 * (params.rate / 2.0)]  # Nyquist
        except Exception:
            # Not a single number: expect "low, high" where high may be "auto".
            cut_off = params.get('filtering', 'cut_off', check=False)
            cut_off = cut_off.split(',')
            try:
                cut_off[0] = float(cut_off[0])
            except Exception:
                if comm.rank == 0:
                    print_and_log(['First value of cut off must be a valid number'], 'error', logger)
                sys.exit(0)

            cut_off[1] = cut_off[1].replace(' ', '')
            if cut_off[1] == 'auto':
                cut_off[1] = 0.95 * (params.rate / 2.0)
            else:
                try:
                    cut_off[1] = float(cut_off[1])
                except Exception:
                    if comm.rank == 0:
                        print_and_log(['Second value of cut off must either auto, or a valid a number'],
                                      'error', logger)
                    sys.exit(0)

        chunk_size = detect_memory(params, filtering=True)
        butter_order = params.getint('filtering', 'butter_order')
        nb_chunks, _ = data_file_in.analyze(chunk_size)

        # Cut-off frequencies are normalized by the Nyquist frequency.
        b, a = signal.butter(butter_order, numpy.array(cut_off)/(params.rate/2.), 'pass')
        N_total = params.nb_channels
        nb_shanks = len(params.probe['channel_groups'])

        if nb_shanks > 1:
            shank_channels = {}
            for i in list(params.probe['channel_groups'].keys()):
                shank_channels[i] = numpy.array(params.probe['channel_groups'][i]['channels'],
                                                dtype=numpy.int32)
        else:
            channel_group = list(params.probe['channel_groups'].keys())[0]

        process_all_channels = numpy.all(nodes == numpy.arange(N_total))
        # Number of extra samples read on each side of a chunk (0.1 s worth).
        duration = int(0.1*params.rate)

        if comm.rank == 0:
            to_write = []
            if do_filtering:
                to_write += ["Filtering with a Butterworth filter (order %d) in [%g, %g] Hz"
                             % (butter_order, cut_off[0], cut_off[1])]
            if do_remove_median:
                to_write += ["Median over all channels is subtracted to each channels"]
            if do_remove_ground:
                to_write += ["Channels %s are used as reference channels in respective shanks" % common_ground]

            print_and_log(to_write, 'default', logger)

        # Chunks are distributed round-robin over the MPI ranks.
        to_explore = list(range(comm.rank, nb_chunks, comm.size))

        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(params, to_explore)

        if data_file_in == data_file_out:
            data_file_in.open(mode='r+')
        else:
            data_file_in.open(mode='r')
            data_file_out.open(mode='r+')

        if flag_saturation:
            comm.Barrier()
            # One set of temporary files per rank; merged by rank 0 at the end.
            saturation_times = open(file_out_suff + '.times-%d.data' % comm.rank, 'wb')
            saturation_channels = open(file_out_suff + '.channels-%d.data' % comm.rank, 'wb')
            saturation_values = open(file_out_suff + '.values-%d.data' % comm.rank, 'wb')
            if data_file_in.data_dtype in ['float32', numpy.float32, 'float64', numpy.float64]:
                max_value = numpy.finfo(data_file_in.data_dtype).max
            else:
                max_value = numpy.iinfo(data_file_in.data_dtype).max - data_file_in.dtype_offset
            saturation = sat_value * max_value

        for count, gidx in enumerate(to_explore):

            is_first = data_file_in.is_first_chunk(gidx, nb_chunks)
            is_last = data_file_in.is_last_chunk(gidx, nb_chunks)

            if not (is_first and is_last):
                if is_first:
                    padding = (0, duration)
                elif is_last:
                    padding = (-duration, 0)
                else:
                    padding = (-duration, duration)
            else:
                padding = (0, 0)

            local_chunk, t_offset = data_file_in.get_data(gidx, chunk_size, padding)
            len_chunk = len(local_chunk)

            if flag_saturation:
                if sat_threshold == 'negative':
                    indices = numpy.where(local_chunk <= saturation * data_file_in.gain)
                elif sat_threshold == 'positive':
                    indices = numpy.where(local_chunk >= saturation * data_file_in.gain)
                elif sat_threshold == 'both':
                    indices = numpy.where(numpy.abs(local_chunk) >= saturation * data_file_in.gain)

                if not process_all_channels:
                    to_keep = numpy.in1d(indices[1], nodes)
                    channels = inv_nodes[indices[1][to_keep]]
                    times = indices[0][to_keep]
                else:
                    channels = indices[1]
                    times = indices[0]

                to_keep = (times >= padding[0]) & (times < (len_chunk-numpy.abs(padding[1])))
                sub_times = times[to_keep]
                sub_channels = channels[to_keep]

                # tobytes() replaces the deprecated (and now removed) tostring().
                saturation_times.write((sub_times + t_offset).astype(numpy.uint32).tobytes())
                saturation_channels.write(sub_channels.astype(numpy.uint32).tobytes())
                saturation_values.write(local_chunk[sub_times, sub_channels].tobytes())

            if do_filtering:
                if not process_all_channels:
                    local_chunk[:, nodes] = signal.filtfilt(b, a, local_chunk[:, nodes], axis=0)
                    local_chunk[:, nodes] -= numpy.median(local_chunk[:, nodes], 0)
                else:
                    local_chunk = signal.filtfilt(b, a, local_chunk, axis=0)
                    local_chunk -= numpy.median(local_chunk, 0)
                if flag_saturation:
                    # Saturated samples are set to zero after filtering.
                    local_chunk[times, channels] = 0

            # Drop the padding that was read on each side of the chunk.
            local_chunk = local_chunk[numpy.abs(padding[0]):len_chunk-numpy.abs(padding[1])]

            if do_remove_median:
                if nb_shanks == 1:
                    if not process_all_channels:
                        global_median = numpy.median(numpy.take(local_chunk, nodes, axis=1), 1)
                    else:
                        global_median = numpy.median(local_chunk, 1)

                    local_chunk -= global_median[:, numpy.newaxis]
                else:
                    for i in list(params.probe['channel_groups'].keys()):
                        global_median = numpy.median(numpy.take(local_chunk, shank_channels[i], axis=1), 1)
                        local_chunk[:, shank_channels[i]] -= global_median[:, numpy.newaxis]

            if do_remove_ground:
                if nb_shanks == 1:
                    ground = local_chunk[:, common_ground[channel_group]]
                    local_chunk -= ground[:, numpy.newaxis]
                else:
                    for i in list(params.probe['channel_groups'].keys()):
                        ground = local_chunk[:, common_ground[i]]
                        local_chunk[:, shank_channels[i]] -= ground[:, numpy.newaxis]

            if data_file_in != data_file_out:
                if data_file_in.is_stream:
                    g_offset = t_offset
                else:
                    g_offset = t_offset - data_file_in.t_start
            else:
                g_offset = t_offset

            data_file_out.set_data(g_offset, local_chunk)

        sys.stderr.flush()

        comm.Barrier()

        if flag_saturation:
            for handle in (saturation_times, saturation_values, saturation_channels):
                handle.flush()
                os.fsync(handle.fileno())
                handle.close()

            # We need to gather the results into a single file
            if comm.rank == 0:
                collect_saturation(comm.size, params, erase=True)

        sys.stderr.flush()
        comm.Barrier()

    def compute_artefacts(data_file):
        """
        Compute artefact locations based on the [triggers] section of the params file.

        Parameters
        ----------
        data_file :
            The data file the artefacts are computed from.

        Return
        ------
        dict
            A dictionary with the location of the artefacts
        """

        trig_in_ms = params.getboolean('triggers', 'trig_in_ms')
        artefacts = numpy.loadtxt(params.get('triggers', 'trig_file'), comments=['#', '//'])
        windows = numpy.loadtxt(params.get('triggers', 'trig_windows'), comments=['#', '//'])
        make_plots = params.get('triggers', 'make_plots')
        plot_path = os.path.join(params.get('data', 'file_out_suff'), 'plots')

        # A file with a single row is loaded as a 1-D array.
        if len(windows.shape) == 1:
            windows = windows.reshape(1, 2)

        if len(artefacts.shape) == 1:
            artefacts = artefacts.reshape(1, 2)

        if trig_in_ms:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in ms'], 'debug', logger)
            artefacts[:, 1] *= numpy.int64(data_file.sampling_rate*1e-3)
            windows[:, 1] *= numpy.int64(data_file.sampling_rate*1e-3)
        else:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in timesteps'], 'debug', logger)

        artefacts = artefacts.astype(numpy.int64)
        windows = windows.astype(numpy.int64)

        nb_stimuli = len(numpy.unique(artefacts[:, 0]))
        mytest = numpy.all(numpy.in1d(numpy.unique(artefacts[:, 0]), numpy.unique(windows[:, 0])))

        if not mytest:
            if comm.rank == 0:
                print_and_log(['Error in the trigger file: not all artefacts are defined'], 'error', logger)
            sys.exit(0)

        all_labels = artefacts[:, 0]
        all_times = artefacts[:, 1]

        # Keep only triggers whose largest window fits inside the recording.
        mask = (all_times >= 0) & (all_times + numpy.max(windows[:, 1]) < data_file.t_stop)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        # Stimulus labels are distributed round-robin over the MPI ranks.
        local_labels = numpy.unique(all_labels)[comm.rank::comm.size]

        if comm.rank == 0:
            to_write = ["Computing averaged artefacts from %d stimuli" % nb_stimuli]
            print_and_log(to_write, 'default', logger)
            if not os.path.exists(plot_path):
                os.makedirs(plot_path)
            local_labels = get_tqdm_progressbar(params, local_labels)

        comm.Barrier()
        # First we need to get the average artefacts
        art_dict = {}
        for count, artefact in enumerate(local_labels):
            indices = numpy.where(all_labels == artefact)[0].astype(numpy.uint32)
            tmp = numpy.where(windows[:, 0] == artefact)[0]
            tau = numpy.int64(windows[tmp, 1])
            pspikes = all_times[indices]
            # At most 500 randomly chosen occurrences are used.
            times = numpy.sort(numpy.random.permutation(pspikes)[:500])
            if len(numpy.where(numpy.diff(times) < tau)[0]) > 0:
                if comm.rank == 0:
                    print_and_log(['Stimulation times for artefact %d are too close!' % artefact],
                                  'error', logger)
                sys.exit(0)

            art_dict[artefact] = get_artefact(params, times, tau, nodes)
            if make_plots not in ['None', '']:
                save = [plot_path, '%d.%s' % (artefact, make_plots)]
                plot.view_artefact(art_dict[artefact], save=save)

        sys.stderr.flush()
        return art_dict

    def remove_artefacts(data_file, art_dict):
        """
        Remove artefact times based on the [triggers] section of the params file.

        Parameters
        ----------
        data_file :
            The data file the artefacts are subtracted from (modified in place).
        art_dict : dict
            a dictionary with the artefact times.
        """

        trig_in_ms = params.getboolean('triggers', 'trig_in_ms')
        artefacts = numpy.loadtxt(params.get('triggers', 'trig_file'), comments=['#', '//'])
        windows = numpy.loadtxt(params.get('triggers', 'trig_windows'), comments=['#', '//'])
        make_plots = params.get('triggers', 'make_plots')
        plot_path = os.path.join(params.get('data', 'file_out_suff'), 'plots')

        # A file with a single row is loaded as a 1-D array.
        if len(windows.shape) == 1:
            windows = windows.reshape(1, 2)

        if len(artefacts.shape) == 1:
            artefacts = artefacts.reshape(1, 2)

        if trig_in_ms:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in ms'], 'debug', logger)
            artefacts[:, 1] *= numpy.int64(data_file.sampling_rate * 1e-3)
            windows[:, 1] *= numpy.int64(data_file.sampling_rate * 1e-3)
        else:
            if comm.rank == 0:
                print_and_log(['Artefact times are read in timesteps'], 'debug', logger)

        artefacts = artefacts.astype(numpy.int64)
        windows = windows.astype(numpy.int64)

        nb_stimuli = len(numpy.unique(artefacts[:, 0]))
        mytest = numpy.all(numpy.in1d(numpy.unique(artefacts[:, 0]), numpy.unique(windows[:, 0])))

        if not mytest:
            if comm.rank == 0:
                print_and_log(['Error in the trigger files: not all artefacts are defined'], 'error', logger)
            sys.exit(0)

        all_labels = artefacts[:, 0]
        all_times = artefacts[:, 1]
        local_labels = numpy.unique(all_labels)[comm.rank::comm.size]

        # Keep only the triggers whose label is handled by this rank ...
        mask = numpy.in1d(all_labels, local_labels)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        # ... and that start inside the recording.
        mask = (all_times >= 0) & (all_times < data_file.t_stop)
        all_times = numpy.compress(mask, all_times)
        all_labels = numpy.compress(mask, all_labels)

        if comm.rank == 0:
            to_write = ["Removing artefacts from %d stimuli" % nb_stimuli]
            print_and_log(to_write, 'default', logger)
            all_times = get_tqdm_progressbar(params, all_times)

        comm.Barrier()

        for count, time in enumerate(all_times):

            label = all_labels[count]
            tmp = numpy.where(windows[:, 0] == label)[0][0]
            tau = numpy.int64(windows[tmp, 1])

            # Shorten the window when it would run past the end of the recording.
            if (data_file.t_stop - time) < tau:
                tau = data_file.t_stop - time

            local_chunk = data_file.get_snippet(time, tau)
            for idx, i in enumerate(nodes):
                local_chunk[:, i] -= art_dict[label][idx, :tau]
            data_file.set_data(time, local_chunk)

        comm.Barrier()
        sys.stderr.flush()

    if comm.rank == 0:
        print_and_log(['Initializing the filtering step...'], 'debug', logger)

    if params.getboolean('data', 'overwrite'):
        if comm.rank == 0:
            print_and_log(['Reading the input file...'], 'debug', logger)

        data_file_in = params.get_data_file()
        data_file_out = data_file_in
    else:
        if comm.rank == 0:
            print_and_log(['Overwrite is set to False, so creating a new datafile...'], 'debug', logger)

        if comm.rank == 0:
            print_and_log(['Reading the input file...'], 'debug', logger)

        has_been_created = os.path.exists(params.get('data', 'data_file_no_overwrite'))

        if not has_been_created and (filter_done or median_done or artefacts_done):
            if comm.rank == 0:
                print_and_log(['The filtering is done but file not present. See no_edits section'],
                              'error', logger)
            sys.exit(0)

        if not has_been_created:
            data_file_in = params.get_data_file(source=True, has_been_created=has_been_created)
        else:
            data_file_in = params.get_data_file(source=False, has_been_created=has_been_created)

        if comm.rank == 0:
            print_and_log(['Reading the output file and allocating ressources...'], 'debug', logger)

        description = data_file_in.get_description()
        description['data_dtype'] = 'float32'
        description['dtype_offset'] = 0
        description['data_offset'] = 0

        comm.Barrier()
        data_file_out = params.get_data_file(is_empty=not has_been_created, params=description)

        if comm.rank == 0:
            print_and_log(['Allocating space for filtered files...'], 'debug', logger)

        if not has_been_created:
            data_file_out.allocate(shape=data_file_in.shape)

        comm.Barrier()

    if clean_artefact:
        if not (os.path.exists(params.get('triggers', 'trig_file')) and
                os.path.exists(params.get('triggers', 'trig_windows'))):
            if comm.rank == 0:
                print_and_log(['trig_file or trig_windows file can not be found'], 'error', logger)
            sys.exit(0)

    to_write = []

    # Skip every operation that the noedits section reports as already done.
    if do_filter and filter_done:
        do_filter = False
        to_write += ["Filtering has already been done"]
    if remove_median and median_done:
        remove_median = False
        to_write += ["Median over all channels has already been removed"]
    if remove_ground and ground_done:
        remove_ground = False
        to_write += ["Common ground %s has alread been subtracted" % common_ground]

    if comm.rank == 0 and len(to_write) > 0:
        print_and_log(to_write, 'info', logger)

    if params.getboolean('data', 'overwrite'):
        data_file_in.open(mode='r+')
    else:
        data_file_in.open(mode='r')
        data_file_out.open(mode='r+')

    if do_filter or remove_median or remove_ground:
        if comm.rank == 0:
            if do_filter:
                params.write('noedits', 'filter_done', 'Started')
            if remove_median:
                params.write('noedits', 'median_done', 'Started')
            if remove_ground:
                params.write('noedits', 'ground_done', 'Started')
        filter_file(data_file_in, data_file_out, do_filter, remove_median, remove_ground)

    if comm.rank == 0:
        if do_filter:
            params.write('noedits', 'filter_done', 'True')
        if remove_median:
            params.write('noedits', 'median_done', 'True')
        if remove_ground:
            params.write('noedits', 'ground_done', 'True')

    if clean_artefact and artefacts_done:
        clean_artefact = False
        if comm.rank == 0:
            print_and_log(['Artefacts have already been removed'], 'debug', logger)

    if clean_artefact:
        art_dict = compute_artefacts(data_file_in)
        if comm.rank == 0:
            params.write('noedits', 'artefacts_done', 'Started')
        remove_artefacts(data_file_out, art_dict)

    if comm.rank == 0:
        if clean_artefact:
            params.write('noedits', 'artefacts_done', 'True')

    data_file_in.close()
    if not params.getboolean('data', 'overwrite'):
        data_file_out.close()

    comm.Barrier()
def main(argv=None):
    """Launch the phy template GUI on previously exported results.

    Parses the command line, checks that the installed phy-contrib is
    recent enough, writes a ``params.py`` file inside the ``.GUI`` export
    directory and then starts the phy ``TemplateController`` GUI from it.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments (without the program name). Defaults to
        ``sys.argv[1:]``. Expected content: the data file path and an
        optional ``-e/--extension`` suffix.

    Raises
    ------
    SystemExit
        When no argument is given (after printing the help), when
        phy-contrib is older than 1.0.12, or when the user declines to
        continue after the similarities patch check fails.

    Notes
    -----
    If the ``.GUI`` export directory does not exist, an error is logged
    and the function returns without launching anything.
    """
    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    parser = argparse.ArgumentParser(
        description=header, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file')
    parser.add_argument('-e', '--extension',
                        help='extension to consider for visualization',
                        default='')

    if not argv:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)

    filename = os.path.abspath(args.datafile)
    extension = args.extension
    params = CircusParser(filename)

    # Start from a fresh log file for this session.
    if os.path.exists(params.logfile):
        os.remove(params.logfile)
    # Only the side effect of init_logging is needed; the module logger is
    # fetched explicitly right after.
    _ = init_logging(params.logfile)
    logger = logging.getLogger(__name__)

    if StrictVersion(phycontrib.__version__) < StrictVersion("1.0.12"):
        print_and_log(
            ['You need to update phy-contrib to the latest git version'],
            'error', logger)
        sys.exit(1)

    # The check is done with the raw extension, before it gets its '-' prefix.
    if not test_patch_for_similarities(params, extension):
        print_and_log(
            ['You should re-export the data because of a fix in 0.6'],
            'error', logger)
        continue_anyway = query_yes_no(
            Fore.WHITE + "Continue anyway (results may not be fully correct)?",
            default=None)
        if not continue_anyway:
            sys.exit(1)

    data_file = params.get_data_file()
    data_dtype = data_file.data_dtype
    # `in` instead of dict.has_key(): has_key() was removed in Python 3.
    if 'data_offset' in data_file.params:
        data_offset = data_file.data_offset
    else:
        data_offset = 0

    file_format = data_file.description
    file_out_suff = params.get('data', 'file_out_suff')

    if file_format not in supported_by_phy:
        print_and_log([
            "File format %s is not supported by phy. TraceView disabled"
            % file_format
        ], 'info', logger)

    # The gain itself is not forwarded to phy; the user is only warned.
    if numpy.iterable(data_file.gain):
        print_and_log(
            ['Multiple gains are not supported, using a default value of 1'],
            'info', logger)
    elif data_file.gain != 1:
        print_and_log([
            "Gain of %g is not supported by phy. Expecting a scaling mismatch"
            % data_file.gain
        ], 'info', logger)

    if extension != '':
        extension = '-' + extension
    output_path = file_out_suff + extension + '.GUI'

    if not os.path.exists(output_path):
        print_and_log(
            ['Data should be first exported with the converting method!'],
            'error', logger)
        return

    print_and_log(["Launching the phy GUI..."], 'info', logger)

    gui_params = {}
    if file_format in supported_by_phy:
        if not params.getboolean('data', 'overwrite'):
            gui_params['dat_path'] = params.get('data', 'data_file_no_overwrite')
        elif params.get('data', 'stream_mode') == 'multi-files':
            data_file = params.get_data_file(source=True, has_been_created=False)
            gui_params['dat_path'] = ' '.join(data_file.get_file_names())
        else:
            gui_params['dat_path'] = params.get('data', 'data_file')
    else:
        # Placeholder path used when the format cannot be read by phy.
        gui_params['dat_path'] = 'giverandomname.dat'
    gui_params['n_channels_dat'] = params.nb_channels
    gui_params['n_features_per_channel'] = 5
    gui_params['dtype'] = data_dtype
    gui_params['offset'] = data_offset
    gui_params['sample_rate'] = params.rate
    gui_params['dir_path'] = output_path
    gui_params['hp_filtered'] = True

    # String-valued entries are quoted so that params.py stays valid Python.
    quoted_keys = ('dir_path', 'dat_path', 'dtype')
    # The context manager closes the file even if a write fails.
    with open(os.path.join(output_path, 'params.py'), 'w') as params_file:
        for key, value in gui_params.items():
            if key in quoted_keys:
                params_file.write('%s = "%s"\n' % (key, value))
            else:
                params_file.write("%s = %s\n" % (key, value))

    os.chdir(output_path)
    create_app()
    controller = TemplateController(**gui_params)
    gui = controller.create_gui()

    gui.show()
    run_app()
    gui.close()
    del gui