def main(params, nb_cpu, nb_gpu, use_gpu): ################################################################# # params = detect_memory(params) _ = init_logging(params.logfile) SHARED_MEMORY = get_shared_memory_flag(params) logger = logging.getLogger('circus.fitting') data_file = params.data_file n_e = params.getint('data', 'N_e') n_total = params.nb_channels n_t = params.getint('detection', 'N_t') template_shift = params.getint('detection', 'template_shift') # file_out = params.get('data', 'file_out') file_out_suff = params.get('data', 'file_out_suff') sign_peaks = params.get('detection', 'peaks') matched_filter = params.getboolean('detection', 'matched-filter') # spike_thresh = params.getfloat('detection', 'spike_thresh') ratio_thresh = params.getfloat('fitting', 'ratio_thresh') two_components = params.getboolean('fitting', 'two_components') # spike_width = params.getfloat('detection', 'spike_width') # dist_peaks = params.getint('detection', 'dist_peaks') do_temporal_whitening = params.getboolean('whitening', 'temporal') do_spatial_whitening = params.getboolean('whitening', 'spatial') templates_normalization = params.getboolean('clustering', 'templates_normalization') # TODO test, switch, test! chunk_size = detect_memory(params, fitting=True) gpu_only = params.getboolean('fitting', 'gpu_only') nodes, edges = get_nodes_and_edges(params) tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',') tmp_limits = [float(v) for v in tmp_limits] amp_auto = params.getboolean('fitting', 'amp_auto') auto_nb_chances = params.getboolean('fitting', 'auto_nb_chances') if auto_nb_chances: nb_chances = io.load_data(params, 'nb_chances') max_nb_chances = params.getint('fitting', 'max_nb_chances') percent_nb_chances = params.getfloat('fitting', 'percent_nb_chances') total_nb_chances = max(1, numpy.nanpercentile(nb_chances, percent_nb_chances)) total_nb_chances = min(total_nb_chances, max_nb_chances) if comm.rank == 0: print_and_log(['nb_chances set automatically to %g' %total_nb_chances], 'debug', logger) else: total_nb_chances = params.getfloat('fitting', 'nb_chances') max_chunk = params.getfloat('fitting', 'max_chunk') # noise_thr = params.getfloat('clustering', 'noise_thr') collect_all = params.getboolean('fitting', 'collect_all') min_second_component = params.getfloat('fitting', 'min_second_component') debug = params.getboolean('fitting', 'debug') ignore_dead_times = params.getboolean('triggers', 'ignore_times') inv_nodes = numpy.zeros(n_total, dtype=numpy.int32) inv_nodes[nodes] = numpy.arange(len(nodes)) data_file.open() ################################################################# if use_gpu: import cudamat as cmt # # Need to properly handle multi GPU per MPI nodes? 
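# Reminder on the mapping built above: `nodes` lists the channels kept after
# probe/dead-channel selection, and `inv_nodes` is its inverse lookup table,
# i.e. inv_nodes[full_probe_channel] gives that channel's column index inside
# the chunks returned by data_file.get_data(..., nodes=nodes).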
if nb_gpu > nb_cpu: gpu_id = int(comm.rank // nb_cpu) else: gpu_id = 0 cmt.cuda_set_device(gpu_id) cmt.init() cmt.cuda_sync_threads() if SHARED_MEMORY: templates, _ = io.load_data_memshared(params, 'templates', normalize=templates_normalization, transpose=True) N_tm, x = templates.shape else: templates = io.load_data(params, 'templates') x, N_tm = templates.shape temp_2_shift = 2 * template_shift temp_3_shift = 3 * template_shift full_gpu = use_gpu and gpu_only n_tm = N_tm // 2 n_scalar = n_e * n_t temp_window = numpy.arange(-template_shift, template_shift + 1) size_window = n_e * (2 * template_shift + 1) if not amp_auto: amp_limits = numpy.zeros((n_tm, 2)) amp_limits[:, 0] = tmp_limits[0] amp_limits[:, 1] = tmp_limits[1] else: amp_limits = io.load_data(params, 'limits') norm_templates = io.load_data(params, 'norm-templates') if not templates_normalization: norm_templates_2 = (norm_templates ** 2.0) * n_scalar if not SHARED_MEMORY: # Normalize templates (if necessary). if templates_normalization: for idx in range(templates.shape[1]): myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx+1]) templates.data[myslice] /= norm_templates[idx] # Transpose templates. templates = templates.T waveform_neg = numpy.empty(0) # default assignment (for PyCharm code inspection) matched_thresholds_neg = None # default assignment (for PyCharm code inspection) waveform_pos = numpy.empty(0) # default assignment (for PyCharm code inspection) matched_thresholds_pos = None # default assignment (for PyCharm code inspection) if matched_filter: if sign_peaks in ['negative', 'both']: waveform_neg = io.load_data(params, 'waveform')[::-1] waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg)) matched_thresholds_neg = io.load_data(params, 'matched-thresholds') if sign_peaks in ['positive', 'both']: waveform_pos = io.load_data(params, 'waveform-pos')[::-1] waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos)) matched_thresholds_pos = io.load_data(params, 'matched-thresholds-pos') if ignore_dead_times: all_dead_times = get_dead_times(params) else: all_dead_times = None # default assignment (for PyCharm code inspection) thresholds = io.get_accurate_thresholds(params, ratio_thresh) neighbors = {} if collect_all: for i in range(0, n_tm): tmp = templates[i, :].toarray().reshape(n_e, n_t) if templates_normalization: tmp = tmp * norm_templates[i] neighbors[i] = numpy.where(numpy.sum(tmp, axis=1) != 0.0)[0] if use_gpu: templates = cmt.SparseCUDAMatrix(templates, copy_on_host=False) info_string = '' if comm.rank == 0: if use_gpu: info_string = "using %d GPUs" % comm.size else: info_string = "using %d CPUs" % comm.size comm.Barrier() c_overlap = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu) over_shape = c_overlap.get('over_shape')[:] n_over = int(numpy.sqrt(over_shape[0])) s_over = over_shape[1] # # If the number of overlaps is different from templates, we need to recompute them. if n_over != N_tm: if comm.rank == 0: print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger) c_overlap = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu) over_shape = c_overlap.get('over_shape')[:] n_over = int(numpy.sqrt(over_shape[0])) s_over = over_shape[1] if SHARED_MEMORY: c_overs, _ = io.load_data_memshared(params, 'overlaps') else: c_overs = io.load_data(params, 'overlaps') comm.Barrier() if n_tm == 0: if comm.rank == 0: print_and_log(["No templates present. 
Redo clustering?"], 'default', logger) sys.exit(0) if comm.rank == 0: print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." % (info_string, n_tm)], 'default', logger) purge(file_out_suff, '.data') if do_spatial_whitening: spatial_whitening = io.load_data(params, 'spatial_whitening') else: spatial_whitening = None # default assignment (for PyCharm code inspection) if do_temporal_whitening: temporal_whitening = io.load_data(params, 'temporal_whitening') else: temporal_whitening = None # default assignment (for PyCharm code inspection) if full_gpu: try: # If memory on the GPU is large enough, we load the overlaps onto it for i in range(n_over): c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i], copy_on_host=False) except Exception: if comm.rank == 0: print_and_log(["Not enough memory on GPUs: GPUs are used for projection only"], 'info', logger) for i in range(n_over): if i in c_overs: del c_overs[i] full_gpu = False nb_chunks, last_chunk_len = data_file.analyze(chunk_size) processed_chunks = int(min(nb_chunks, max_chunk)) comm.Barrier() spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank, 'wb') comm.Barrier() amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank, 'wb') comm.Barrier() templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank, 'wb') comm.Barrier() if collect_all: garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' % comm.rank, 'wb') comm.Barrier() garbage_temp_file = open(file_out_suff + '.gtemplates-%d.data' % comm.rank, 'wb') comm.Barrier() else: garbage_times_file = None # default assignment (for PyCharm code inspection) garbage_temp_file = None # default assignment (for PyCharm code inspection) if debug: # Open debug files. chunk_nbs_debug_file = open(file_out_suff + '.chunk_nbs_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() iteration_nbs_debug_file = open(file_out_suff + '.iteration_nbs_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() peak_nbs_debug_file = open(file_out_suff + '.peak_nbs_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() peak_local_time_steps_debug_file = open( file_out_suff + '.peak_local_time_steps_debug_%d.data' % comm.rank, mode='wb' ) comm.Barrier() peak_time_steps_debug_file = open(file_out_suff + '.peak_time_steps_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() peak_scalar_products_debug_file = open( file_out_suff + '.peak_scalar_products_debug_%d.data' % comm.rank, mode='wb' ) comm.Barrier() peak_solved_flags_debug_file = open(file_out_suff + '.peak_solved_flags_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() template_nbs_debug_file = open(file_out_suff + '.template_nbs_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() success_flags_debug_file = open(file_out_suff + '.success_flags_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() else: chunk_nbs_debug_file = None # default assignment (for PyCharm code inspection) iteration_nbs_debug_file = None # default assignment (for PyCharm code inspection) peak_nbs_debug_file = None # default assignment (for PyCharm code inspection) peak_local_time_steps_debug_file = None # default assignment (for PyCharm code inspection) peak_time_steps_debug_file = None # default assignment (for PyCharm code inspection) peak_scalar_products_debug_file = None # default assignment (for PyCharm code inspection) peak_solved_flags_debug_file = None # default assignment (for PyCharm code inspection) template_nbs_debug_file = None # default assignment (for PyCharm code inspection) success_flags_debug_file = None # 
default assignment (for PyCharm code inspection) if use_gpu and do_spatial_whitening: spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False) last_chunk_size = 0 slice_indices = numpy.zeros(0, dtype=numpy.int32) to_explore = range(comm.rank, processed_chunks, comm.size) if comm.rank == 0: to_explore = get_tqdm_progressbar(params, to_explore) for gcount, gidx in enumerate(to_explore): # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..." # # We need to deal with the borders by taking chunks of size [0, chunck_size + template_shift]. is_first = data_file.is_first_chunk(gidx, nb_chunks) is_last = data_file.is_last_chunk(gidx, nb_chunks) if not (is_first and is_last): if is_last: padding = (-temp_3_shift, 0) elif is_first: padding = (0, temp_3_shift) else: padding = (-temp_3_shift, temp_3_shift) else: padding = (0, 0) result = { 'spiketimes': [], 'amplitudes': [], 'templates': [], } result_debug = { 'chunk_nbs': [], 'iteration_nbs': [], 'peak_nbs': [], 'peak_local_time_steps': [], 'peak_time_steps': [], 'peak_scalar_products': [], 'peak_solved_flags': [], 'template_nbs': [], 'success_flags': [], } local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes) len_chunk = len(local_chunk) if do_spatial_whitening: if use_gpu: local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False) local_chunk = local_chunk.dot(spatial_whitening).asarray() else: local_chunk = numpy.dot(local_chunk, spatial_whitening) if do_temporal_whitening: local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant') # Extracting peaks. all_found_spikes = {} if collect_all: for i in range(n_e): all_found_spikes[i] = [] local_peaktimes = [numpy.empty(0, dtype=numpy.uint32)] if matched_filter: if sign_peaks in ['positive', 'both']: filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant') for i in range(n_e): peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_pos[i])[0] local_peaktimes.append(peaktimes) if collect_all: all_found_spikes[i] += peaktimes.tolist() if sign_peaks in ['negative', 'both']: filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant') for i in range(n_e): peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_neg[i])[0] local_peaktimes.append(peaktimes) if collect_all: all_found_spikes[i] += peaktimes.tolist() local_peaktimes = numpy.concatenate(local_peaktimes) else: for i in range(n_e): if sign_peaks == 'negative': peaktimes = scipy.signal.find_peaks(-local_chunk[:, i], height=thresholds[i])[0] elif sign_peaks == 'positive': peaktimes = scipy.signal.find_peaks(local_chunk[:, i], height=thresholds[i])[0] elif sign_peaks == 'both': peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]), height=thresholds[i])[0] else: raise ValueError("Unexpected value %s" % sign_peaks) local_peaktimes.append(peaktimes) if collect_all: all_found_spikes[i] += peaktimes.tolist() local_peaktimes = numpy.concatenate(local_peaktimes) local_peaktimes = numpy.unique(local_peaktimes) g_offset = t_offset + padding[0] if ignore_dead_times: dead_indices = numpy.searchsorted(all_dead_times, [t_offset, t_offset + chunk_size]) if dead_indices[0] != dead_indices[1]: is_included = numpy.in1d(local_peaktimes + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]]) local_peaktimes = local_peaktimes[~is_included] local_peaktimes = numpy.sort(local_peaktimes) else: dead_indices = 
None # default assignment (for PyCharm code inspection) # print "Removing the useless borders..." local_borders = (template_shift, len_chunk - template_shift) idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1]) local_peaktimes = numpy.compress(idx, local_peaktimes) if collect_all: for i in range(n_e): all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.uint32) if ignore_dead_times: if dead_indices[0] != dead_indices[1]: is_included = numpy.in1d( all_found_spikes[i] + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]] ) all_found_spikes[i] = all_found_spikes[i][~is_included] all_found_spikes[i] = numpy.sort(all_found_spikes[i]) idx = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1]) all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i]) nb_local_peak_times = len(local_peaktimes) if full_gpu: # all_indices = cmt.CUDAMatrix(all_indices) # tmp_gpu = cmt.CUDAMatrix(local_peaktimes.reshape((1, nb_local_peak_times)), copy_on_host=False) _ = cmt.CUDAMatrix(local_peaktimes.reshape((1, nb_local_peak_times)), copy_on_host=False) if nb_local_peak_times > 0: # print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..." if collect_all: c_local_chunk = local_chunk.copy() else: c_local_chunk = None # default assignment (for PyCharm code inspection) sub_mat = local_chunk[local_peaktimes[:, None] + temp_window] sub_mat = sub_mat.transpose(2, 1, 0).reshape(size_window, nb_local_peak_times) del local_chunk if use_gpu: sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False) b = cmt.sparse_dot(templates, sub_mat) else: b = templates.dot(sub_mat) del sub_mat local_restriction = (t_offset, t_offset + chunk_size) all_spikes = local_peaktimes + g_offset # Because for GPU, slicing by columns is more efficient, we need to transpose b # b = b.transpose() if use_gpu and not full_gpu: b = b.asarray() failure = numpy.zeros(nb_local_peak_times, dtype=numpy.int32) if full_gpu: mask = numpy.zeros((2 * n_tm, nb_local_peak_times), dtype=numpy.float32) mask[:n_tm, :] = 1 # data = cmt.empty(mask.shape) _ = cmt.empty(mask.shape) patch_gpu = b.shape[1] == 1 else: patch_gpu = None if collect_all: c_all_times = numpy.zeros((len_chunk, n_e), dtype=numpy.bool) c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0) c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk) for i in range(n_e): c_all_times[all_found_spikes[i], i] = True else: c_all_times = None # default assignment (for PyCharm code inspection) c_min_times = None # default assignment (for PyCharm code inspection) c_max_times = None # default assignment (for PyCharm code inspection) iteration_nb = 0 local_max = 0 numerous_argmax = False nb_argmax = n_tm best_indices = numpy.zeros(0, dtype=numpy.int32) data = b[:n_tm, :] flatten_data = data.ravel() while numpy.mean(failure) < total_nb_chances: # Is there a way to update sub_b * mask at the same time? 
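# The while-loop below performs the greedy template matching for this chunk:
# at each iteration the (template, peak) pair with the largest scalar product
# in `b` is selected, its amplitude is normalised and tested against the
# template's [a_min, a_max] limits. Accepted matches are subtracted from `b`
# for all temporally neighbouring peaks through the precomputed overlaps
# (c_overs), so already-explained signal cannot be fitted twice; rejected
# matches increment the peak's failure counter and the corresponding entry is
# set to -inf so it is never reconsidered. The loop stops once the average
# failure count reaches `total_nb_chances`.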
if full_gpu: b_array = b.asarray() else: b_array = None if numerous_argmax: if len(best_indices) == 0: best_indices = largest_indices(flatten_data, nb_argmax) best_template_index, peak_index = numpy.unravel_index(best_indices[0], data.shape) else: best_template_index, peak_index = numpy.unravel_index(data.argmax(), data.shape) peak_scalar_product = data[best_template_index, peak_index] best_template2_index = best_template_index + n_tm if templates_normalization: if full_gpu: best_amp = b_array[best_template_index, peak_index] / n_scalar best_amp2 = b_array[best_template2_index, peak_index] / n_scalar else: best_amp = b[best_template_index, peak_index] / n_scalar if two_components: best_amp2 = b[best_template2_index, peak_index] / n_scalar else: best_amp2 = 0.0 best_amp_n = best_amp / norm_templates[best_template_index] best_amp2_n = best_amp2 / norm_templates[best_template2_index] else: if full_gpu: best_amp = b_array[best_template_index, peak_index] best_amp = best_amp / norm_templates_2[best_template_index] # TODO is `best_amp` value correct? best_amp2 = b_array[best_template2_index, peak_index] best_amp2 = best_amp2 / norm_templates_2[best_template2_index] # TODO is `best_amp2` value correct? else: best_amp = b[best_template_index, peak_index] best_amp = best_amp / norm_templates_2[best_template_index] # TODO is `best_amp` value correct? if two_components: best_amp2 = b[best_template2_index, peak_index] best_amp2 = best_amp2 / norm_templates_2[best_template2_index] # TODO is `best_amp2` value correct? else: best_amp2 = 0.0 best_amp_n = best_amp best_amp2_n = best_amp2 # Verify amplitude constraint. a_min, a_max = amp_limits[best_template_index, :] if (a_min <= best_amp_n) & (best_amp_n <= a_max): # Keep the matching. peak_time_step = local_peaktimes[peak_index] peak_data = (local_peaktimes - peak_time_step).astype(np.int32) is_neighbor = np.where(np.abs(peak_data) <= temp_2_shift)[0] idx_neighbor = peak_data[is_neighbor] + temp_2_shift nb_neighbors = len(is_neighbor) indices = np.zeros((s_over, nb_neighbors), dtype=np.int32) indices[idx_neighbor, np.arange(nb_neighbors)] = 1 if full_gpu: indices = cmt.CUDAMatrix(indices, copy_on_host=False) if patch_gpu: b_lines = b.get_col_slice(0, b.shape[0]) else: b_lines = b.get_col_slice(is_neighbor[0], is_neighbor[-1]+1) tmp1 = cmt.sparse_dot(c_overs[best_template_index], indices, mult=-best_amp) tmp2 = cmt.sparse_dot(c_overs[best_template2_index], indices, mult=-best_amp2) b_lines.add(tmp1.add(tmp2)) del tmp1, tmp2 else: tmp1 = c_overs[best_template_index].multiply(-best_amp) if numpy.abs(best_amp2) > min_second_component: tmp1 += c_overs[best_template2_index].multiply(-best_amp2) b[:, is_neighbor] += tmp1.dot(indices) numerous_argmax = False # Add matching to the result. t_spike = all_spikes[peak_index] if (t_spike >= local_restriction[0]) and (t_spike < local_restriction[1]): result['spiketimes'] += [t_spike] result['amplitudes'] += [(best_amp_n, best_amp2_n)] result['templates'] += [best_template_index] # Mark current matching as tried. b[best_template_index, peak_index] = -numpy.inf # Save debug data. 
if debug: result_debug['chunk_nbs'] += [gidx] result_debug['iteration_nbs'] += [iteration_nb] result_debug['peak_nbs'] += [peak_index] result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]] result_debug['peak_time_steps'] += [all_spikes[peak_index]] result_debug['peak_scalar_products'] += [peak_scalar_product] result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]] result_debug['template_nbs'] += [best_template_index] result_debug['success_flags'] += [True] else: # Reject the matching. numerous_argmax = True # Update failure counter of the peak. failure[peak_index] += 1 # If the maximal number of failures is reached then mark peak as solved (i.e. not fitted). if failure[peak_index] >= total_nb_chances: # Mark all the matching associated to the current peak as tried. b[:, peak_index] = -numpy.inf index = numpy.arange(n_tm) * nb_local_peak_times + peak_index else: # Mark current matching as tried. b[best_template_index, peak_index] = -numpy.inf index = best_template_index * nb_local_peak_times + peak_index if numerous_argmax: best_indices = best_indices[~numpy.in1d(best_indices, index)] # Save debug data. if debug: result_debug['chunk_nbs'] += [gidx] result_debug['iteration_nbs'] += [iteration_nb] result_debug['peak_nbs'] += [peak_index] result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]] result_debug['peak_time_steps'] += [all_spikes[peak_index]] result_debug['peak_scalar_products'] += [peak_scalar_product] result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]] result_debug['template_nbs'] += [best_template_index] result_debug['success_flags'] += [False] iteration_nb += 1 spikes_to_write = numpy.array(result['spiketimes'], dtype=numpy.uint32) amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32) templates_to_write = numpy.array(result['templates'], dtype=numpy.uint32) spiketimes_file.write(spikes_to_write.tostring()) amplitudes_file.write(amplitudes_to_write.tostring()) templates_file.write(templates_to_write.tostring()) if collect_all: for temp, spike in zip(templates_to_write, spikes_to_write - g_offset): c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0] c_all_times = numpy.take(c_all_times, gspikes, axis=0) c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times if sign_peaks == 'negative': bestlecs = numpy.argmin(c_local_chunk, 1) if matched_filter: threshs = -matched_thresholds_neg[bestlecs] else: threshs = -thresholds[bestlecs] idx = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0] elif sign_peaks == 'positive': bestlecs = numpy.argmax(c_local_chunk, 1) if matched_filter: threshs = matched_thresholds_pos[bestlecs] else: threshs = thresholds[bestlecs] idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0] elif sign_peaks == 'both': c_local_chunk = numpy.abs(c_local_chunk) bestlecs = numpy.argmax(c_local_chunk, 1) if matched_filter: threshs = numpy.minimum(matched_thresholds_neg[bestlecs], matched_thresholds_pos[bestlecs]) else: threshs = thresholds[bestlecs] idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0] else: raise ValueError("Unexpected value %s" % sign_peaks) gspikes = numpy.take(gspikes, idx) bestlecs = numpy.take(bestlecs, idx) gspikes_to_write = numpy.array(gspikes + g_offset, dtype=numpy.uint32) gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.uint32) garbage_times_file.write(gspikes_to_write.tostring()) garbage_temp_file.write(gtemplates_to_write.tostring()) 
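# collect_all bookkeeping above: every fitted spike masks a +/- template_shift
# window on the channels covered by its template (`neighbors[temp]`); peaks
# that remain unmasked and still cross the (matched-filter) threshold are kept
# as "garbage" events, assigned to the electrode carrying the extremal value,
# and written to the .gspiketimes / .gtemplates files. Side note: on recent
# NumPy versions ndarray.tostring() is a deprecated alias of tobytes(); the
# bytes written are identical.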
if debug: # Write debug data to debug files. for field_label, field_dtype, field_file in [ ('chunk_nbs', numpy.uint32, chunk_nbs_debug_file), ('iteration_nbs', numpy.uint32, iteration_nbs_debug_file), ('peak_nbs', numpy.uint32, peak_nbs_debug_file), ('peak_local_time_steps', numpy.uint32, peak_local_time_steps_debug_file), ('peak_time_steps', numpy.uint32, peak_time_steps_debug_file), ('peak_scalar_products', numpy.float32, peak_scalar_products_debug_file), ('peak_solved_flags', numpy.float32, peak_solved_flags_debug_file), ('template_nbs', numpy.uint32, template_nbs_debug_file), ('success_flags', numpy.bool, success_flags_debug_file), ]: field_to_write = numpy.array(result_debug[field_label], dtype=field_dtype) field_file.write(field_to_write.tostring()) if full_gpu: del b, data sys.stderr.flush() spiketimes_file.flush() os.fsync(spiketimes_file.fileno()) spiketimes_file.close() amplitudes_file.flush() os.fsync(amplitudes_file.fileno()) amplitudes_file.close() templates_file.flush() os.fsync(templates_file.fileno()) templates_file.close() if collect_all: garbage_temp_file.flush() os.fsync(garbage_temp_file.fileno()) garbage_temp_file.close() garbage_times_file.flush() os.fsync(garbage_times_file.fileno()) garbage_times_file.close() if debug: # Close debug files. for field_file in [ chunk_nbs_debug_file, iteration_nbs_debug_file, peak_nbs_debug_file, peak_local_time_steps_debug_file, peak_time_steps_debug_file, peak_scalar_products_debug_file, peak_solved_flags_debug_file, template_nbs_debug_file, success_flags_debug_file, ]: field_file.flush() os.fsync(field_file.fileno()) field_file.close() comm.Barrier() if comm.rank == 0: io.collect_data(comm.size, params, erase=True) data_file.close()
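# The fitting loop above streams each rank's results to flat binary files
# (spike times as uint32, (amp, amp2) pairs as float32, template indices as
# uint32) before io.collect_data merges and erases them. The helper below is
# an illustrative sketch of how one rank's raw files could be decoded; the
# function name is hypothetical and nothing in the pipeline calls it.
def _sketch_read_fitting_results(file_out_suff, rank):
    import numpy
    spiketimes = numpy.fromfile(file_out_suff + '.spiketimes-%d.data' % rank, dtype=numpy.uint32)
    amplitudes = numpy.fromfile(file_out_suff + '.amplitudes-%d.data' % rank, dtype=numpy.float32)
    templates = numpy.fromfile(file_out_suff + '.templates-%d.data' % rank, dtype=numpy.uint32)
    # Amplitudes were written as (best_amp_n, best_amp2_n) pairs.
    amplitudes = amplitudes.reshape(-1, 2)
    return spiketimes, amplitudes, templates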
def main(params, nb_cpu, nb_gpu, use_gpu): # Part 1: Whitening numpy.random.seed(420) # params = detect_memory(params) _ = init_logging(params.logfile) logger = logging.getLogger('circus.whitening') ################################################################# data_file = params.data_file N_e = params.getint('data', 'N_e') hdf5_compress = params.getboolean('data', 'hdf5_compress') N_total = params.nb_channels N_t = params.getint('detection', 'N_t') dist_peaks = params.getint('detection', 'dist_peaks') template_shift = params.getint('detection', 'template_shift') file_out_suff = params.get('data', 'file_out_suff') spike_thresh = params.getfloat('detection', 'spike_thresh') spike_width = params.getfloat('detection', 'spike_width') matched_filter = params.getboolean('detection', 'matched-filter') matched_thresh = params.getfloat('detection', 'matched_thresh') fudge = params.getfloat('whitening', 'fudge') sign_peaks = params.get('detection', 'peaks') do_temporal_whitening = params.getboolean('whitening', 'temporal') do_spatial_whitening = params.getboolean('whitening', 'spatial') ignore_spikes = params.getboolean('whitening', 'ignore_spikes') chunk_size = detect_memory(params, whitening=True) plot_path = os.path.join(params.get('data', 'file_out_suff'), 'plots') nodes, edges = get_nodes_and_edges(params) safety_time = params.getint('whitening', 'safety_time') safety_space = params.getboolean('whitening', 'safety_space') sort_waveforms = params.getboolean('whitening', 'sort_waveforms') nb_temp_white = min(max(20, comm.size), N_e) max_silence_1 = int(20 * params.rate // comm.size) max_silence_2 = 5000 inv_nodes = numpy.zeros(N_total, dtype=numpy.int32) inv_nodes[nodes] = numpy.arange(len(nodes)) jitter_range = params.getint('detection', 'jitter_range') template_shift_2 = template_shift + jitter_range use_hanning = params.getboolean('detection', 'hanning') rejection_threshold = params.getfloat('detection', 'rejection_threshold') noise_window = params.getint('detection', 'noise_time') data_file.open() ################################################################# if use_hanning: hanning_filter = numpy.hanning(N_t) if comm.rank == 0: print_and_log( ["Analyzing data to get whitening matrices and thresholds..."], 'default', logger) nodes_indices = {} for elec in numpy.arange(N_e): nodes_indices[elec] = inv_nodes[edges[nodes[elec]]] if use_gpu: import cudamat as cmt # # Need to properly handle multi GPU per MPI nodes? if nb_gpu > nb_cpu: gpu_id = int(comm.rank // nb_cpu) else: gpu_id = 0 cmt.cuda_set_device(gpu_id) cmt.init() cmt.cuda_sync_threads() nb_chunks, last_chunk_len = data_file.analyze(chunk_size) if nb_chunks < comm.size: res = io.data_stats(params, show=False) chunk_size = int(res * params.rate // comm.size) if comm.rank == 0: print_and_log( ["Too much cores, automatically resizing the data chunks"], 'debug', logger) nb_chunks, last_chunk_len = data_file.analyze(chunk_size) # I guess this is more relevant, to take signals from all over the recordings. if nb_chunks > comm.size: all_chunks = numpy.random.permutation( numpy.arange(nb_chunks - 1, dtype=numpy.int32)) else: all_chunks = numpy.random.permutation( numpy.arange(nb_chunks, dtype=numpy.int32)) all_electrodes = numpy.random.permutation(N_e) numpy.random.seed(comm.rank) for gidx in [all_chunks[comm.rank]]: # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..." 
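# The randomly chosen chunk processed below is used to estimate per-channel
# detection thresholds as the median absolute deviation (MAD) of the signal,
# i.e. thresholds[i] = median(|x_i - median(x_i)|); the per-rank estimates are
# gathered, averaged across ranks and written to <file_out_suff>.basis.hdf5
# under 'thresholds'.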
local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes) local_shape = len(local_chunk) # print "Node", comm.rank, "computes the median absolute deviations in a random chunk" thresholds = numpy.zeros(N_e, dtype=numpy.float32) for i in range(N_e): u = numpy.median(local_chunk[:, i], 0) thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u), 0) gdata = gather_array(thresholds, comm) if comm.rank == 0: gdata = gdata.reshape((comm.size, N_e)) thresholds = numpy.mean(gdata, 0) bfile = h5py.File(file_out_suff + '.basis.hdf5', 'w', libver='earliest') io.write_datasets(bfile, ['thresholds'], {'thresholds': thresholds}, compression=hdf5_compress) bfile.close() comm.Barrier() thresholds = io.load_data(params, 'thresholds') local_borders = (template_shift, local_shape - template_shift) found_peaktimes = [] if ignore_spikes: # Extracting the peaks. local_peaktimes = [np.empty(0, dtype=numpy.uint32)] for i in range(N_e): peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]), height=thresholds[i], width=spike_width, wlen=N_t)[0] peaktimes = peaktimes.astype(numpy.uint32) # print "Removing the useless borders..." idx = (peaktimes >= local_borders[0]) & (peaktimes < local_borders[1]) peaktimes = numpy.compress(idx, peaktimes) found_peaktimes.append(peaktimes) else: for i in range(N_e): found_peaktimes.append(numpy.zeros(0, dtype=numpy.uint32)) all_peaktimes = numpy.concatenate(found_peaktimes) local_peaktimes = numpy.unique(all_peaktimes) if len(local_peaktimes) > 0: diff_times = local_peaktimes[-1] - local_peaktimes[0] all_times = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool) padded_peaks = (local_peaktimes - local_peaktimes[0]).astype( numpy.int32) min_times = numpy.maximum(padded_peaks - safety_time, 0) max_times = numpy.minimum(padded_peaks + safety_time + 1, diff_times + 1) test_extremas = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool) for i in range(N_e): test_extremas[i, found_peaktimes[i] - local_peaktimes[0]] = True argmax_peak = numpy.random.permutation( numpy.arange(len(local_peaktimes))) all_idx = numpy.take(local_peaktimes, argmax_peak) # print "Selection of the peaks with spatio-temporal masks..." 
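# The loop below builds the spatio-temporal exclusion mask `all_times`: around
# every detected peak, a +/- safety_time window is flagged on the peak's
# electrode (or on its whole neighbourhood when safety_space is set). The
# whitening estimates that follow only use the complementary, "silent"
# samples: the temporal whitening filter comes from covariance matrices of
# N_t-long silent stretches, and the spatial whitening matrix from the silent
# samples of each electrode's neighbourhood via get_whitening_matrix().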
for idx, peak in zip(argmax_peak, all_idx): all_elecs = numpy.where(test_extremas[:, peak - local_peaktimes[0]])[0] data = local_chunk[peak, all_elecs] elec = all_elecs[numpy.argmax(numpy.abs(data))] indices = nodes_indices[elec] if safety_space: all_times[indices, min_times[idx]:max_times[idx]] = True else: all_times[elec, min_times[idx]:max_times[idx]] = True else: all_times = numpy.zeros((N_e, len(local_chunk)), dtype=numpy.bool) if do_temporal_whitening: local_res_temp = [] for elec in all_electrodes[numpy.arange(comm.rank, nb_temp_white, comm.size)]: res = numpy.zeros((0, N_t), dtype=numpy.float32) scount = 0 indices = nodes_indices[elec] all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0), 0) esubset = numpy.where(all_times_elec == False)[0] bound = len(esubset) - N_t while (scount < bound) and (len(res) < max_silence_2): myslice = esubset[scount:scount + N_t] if numpy.all((myslice - esubset[scount]) == numpy.arange(N_t)): scount += N_t res = numpy.vstack((res, local_chunk[myslice, elec])) else: scount += 1 if len(res) > 5: local_res_temp += [numpy.cov(res.T)] nb_elecs = numpy.array([len(local_res_temp)], dtype=numpy.float32) local_res_temp = numpy.array(local_res_temp, dtype=numpy.float32) if len(local_res_temp) == 0: local_res_temp = numpy.zeros(0, dtype=numpy.float32) else: local_res_temp = numpy.sum(local_res_temp, 0) all_res_temp = gather_array(local_res_temp.ravel(), comm, 0, 1) all_elecs = gather_array(nb_elecs, comm, 0, 1) if do_spatial_whitening: local_res_spac = numpy.zeros((N_e, N_e), dtype=numpy.float32) local_silences = [] for elec in numpy.arange(comm.rank, N_e, comm.size): indices = nodes_indices[elec] all_times_elec = numpy.any(numpy.take(all_times, indices, axis=0), 0) esubset = numpy.where(all_times_elec == False)[0] local_data = local_chunk[esubset][:, indices] local_whitening = get_whitening_matrix( local_data, fudge=fudge).astype(numpy.float32) pos = numpy.where(elec == indices)[0] local_res_spac[elec, indices] = local_whitening[pos] local_silences += [len(esubset)] all_res_spac = gather_array(local_res_spac.ravel(), comm, 0, 1) all_silences = gather_array( numpy.array(local_silences, dtype=numpy.int32), comm, 0, 1, 'uint32') if comm.rank == 0: to_write = {} if do_temporal_whitening: try: nb_silences = numpy.sum(all_elecs > 0) all_res_temp = all_res_temp.reshape((nb_silences, N_t**2)) except Exception: print_and_log([ "No silent periods detected: something wrong with the parameters?" ], 'error', logger) all_res_temp = numpy.sum(all_res_temp, 0) all_res_temp = all_res_temp.reshape( (N_t, N_t)) / numpy.sum(all_elecs) temporal_whitening = get_whitening_matrix( all_res_temp.astype(numpy.double), fudge=1e-3)[template_shift].astype(numpy.float32) temporal_whitening /= temporal_whitening.sum() to_write['temporal'] = temporal_whitening have_nans = numpy.sum(numpy.isnan(temporal_whitening)) if have_nans > 0: temporal_whitening = numpy.zeros(N_t, dtype=numpy.float32) temporal_whitening[N_t // 2] = 1 to_write['temporal'] = temporal_whitening print_and_log( ["Disabling temporal whitening because of NaNs found"], 'info', logger) if do_spatial_whitening: all_res_spac = all_res_spac.reshape(comm.size, N_e, N_e) spatial_whitening = numpy.sum(all_res_spac, 0) to_write['spatial'] = spatial_whitening if ignore_spikes: print_and_log([ "Found %gs without spikes to compute the whitening matrix..." % (numpy.mean(all_silences) / params.rate) ], 'default', logger) else: print_and_log([ "Found %gs to compute the whitening matrix..." 
% (numpy.mean(all_silences) / params.rate) ], 'default', logger) have_nans = numpy.sum(numpy.isnan(spatial_whitening)) if have_nans > 0: spatial_whitening = numpy.eye(spatial_whitening.shape[0], dtype=numpy.float32) to_write['spatial'] = spatial_whitening print_and_log( ["Disabling spatial whitening because of NaNs found"], 'info', logger) bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='earliest') io.write_datasets(bfile, list(to_write.keys()), to_write, compression=hdf5_compress) bfile.close() comm.Barrier() if do_spatial_whitening or do_temporal_whitening: if comm.rank == 0: print_and_log( ["Because of whitening, need to recompute the thresholds..."], 'default', logger) if do_spatial_whitening: spatial_whitening = io.load_data(params, 'spatial_whitening') if use_gpu: spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False) if do_temporal_whitening: temporal_whitening = io.load_data(params, 'temporal_whitening') for gidx in [all_chunks[comm.rank]]: local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes) local_shape = len(local_chunk) if do_spatial_whitening: if use_gpu: local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False) local_chunk = local_chunk.dot(spatial_whitening).asarray() else: local_chunk = numpy.dot(local_chunk, spatial_whitening) if do_temporal_whitening: local_chunk = scipy.ndimage.filters.convolve1d( local_chunk, temporal_whitening, axis=0, mode='constant') thresholds = numpy.zeros(N_e, dtype=numpy.float32) for i in range(N_e): u = numpy.median(local_chunk[:, i], 0) thresholds[i] = numpy.median(numpy.abs(local_chunk[:, i] - u), 0) gdata = gather_array(thresholds, comm) if comm.rank == 0: gdata = gdata.reshape((comm.size, N_e)) thresholds = numpy.mean(gdata, 0) bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='earliest') bfile.pop('thresholds') io.write_datasets(bfile, ['thresholds'], {'thresholds': thresholds}, compression=hdf5_compress) bfile.close() comm.Barrier() # if comm.rank == 0: # if not os.path.exists(plot_path): # os.makedirs(plot_path) # N_elec = min(int(numpy.sqrt(data_file.N_e)), 5) # plot.view_fit(filename, t_start=0, t_stop=1, fit_on=False, square=True, # n_elec=N_elec, save=[plot_path, 'electrodes']) # Part 2: Basis numpy.random.seed(422) SHARED_MEMORY = get_shared_memory_flag(params) ################################################################# file_out = params.get('data', 'file_out') alignment = params.getboolean('detection', 'alignment') over_factor = params.getint('detection', 'oversampling_factor') nb_jitter = params.getint('detection', 'nb_jitter') spike_thresh = params.getfloat('detection', 'spike_thresh') nodes, edges = get_nodes_and_edges(params) _, positions = get_nodes_and_positions(params) do_temporal_whitening = params.getboolean('whitening', 'temporal') do_spatial_whitening = params.getboolean('whitening', 'spatial') use_barycenter = params.getboolean('detection', 'use_barycenter') if matched_filter: chunk_size = detect_memory(params, whitening=True) else: chunk_size = detect_memory(params) safety_time = params.getint('whitening', 'safety_time') max_elts_elec = params.getint('whitening', 'max_elts') output_dim = params.getfloat('whitening', 'output_dim') inv_nodes = numpy.zeros(N_total, dtype=numpy.int32) inv_nodes[nodes] = numpy.arange(len(nodes)) smoothing_factor = params.getfloat('detection', 'smoothing_factor') if sign_peaks == 'both': max_elts_elec *= 2 nb_elts = int( params.getfloat('whitening', 'nb_elts') * N_e * max_elts_elec) weird_thresh = 
params.get('detection', 'weird_thresh') if weird_thresh != '': ignore_artefacts = True weird_thresh = io.load_data(params, 'weird-thresholds') else: ignore_artefacts = False ignore_dead_times = params.getboolean('triggers', 'ignore_times') if ignore_dead_times: if SHARED_MEMORY: all_dead_times, mpi_memory_3 = get_dead_times(params) else: all_dead_times = get_dead_times(params) data_file.open() ################################################################# if comm.rank == 0: print_and_log(["Searching spikes to construct the PCA basis..."], 'default', logger) nb_chunks, last_chunk_len = data_file.analyze(chunk_size) if nb_chunks < comm.size: res = io.data_stats(params, show=False) chunk_size = int(res * params.rate // comm.size) if comm.rank == 0: print_and_log( ["Too much cores, automatically resizing the data chunks"], 'debug', logger) nb_chunks, last_chunk_len = data_file.analyze(chunk_size) groups = {} for i in range(N_e): groups[i] = 0 # I guess this is more relevant, to take signals from all over the recordings all_chunks = numpy.random.permutation( numpy.arange(nb_chunks, dtype=numpy.int32)) max_elts_elec //= comm.size nb_elts //= comm.size elt_count_pos = 0 elt_count_neg = 0 if sign_peaks in ['positive', 'both']: times_pos = numpy.zeros(nb_elts, dtype=numpy.int32) electrodes_pos = numpy.zeros(nb_elts, dtype=numpy.int32) extremum_pos = numpy.zeros(nb_elts, dtype=numpy.float32) elts_pos = numpy.zeros((N_t, nb_elts), dtype=numpy.float32) if sign_peaks in ['negative', 'both']: times_neg = numpy.zeros(nb_elts, dtype=numpy.int32) electrodes_neg = numpy.zeros(nb_elts, dtype=numpy.int32) extremum_neg = numpy.zeros(nb_elts, dtype=numpy.float32) elts_neg = numpy.zeros((N_t, nb_elts), dtype=numpy.float32) thresholds = io.load_data(params, 'thresholds') mads = io.load_data(params, 'mads') stds = io.load_data(params, 'stds') if alignment: cdata = numpy.linspace(-jitter_range, +jitter_range, nb_jitter) xdata = numpy.arange(-template_shift_2, template_shift_2 + 1) xoff = len(cdata) / 2.0 snippet_duration = template_shift_2 m_size = 2 * template_shift_2 + 1 align_factor = m_size local_factors = align_factor * ((smoothing_factor * mads)**2) else: snippet_duration = template_shift xdata = numpy.arange(-template_shift, template_shift + 1) if rejection_threshold > 0: reject_noise = True noise_levels = stds * (2 * noise_window + 1) else: reject_noise = False to_explore = all_chunks[comm.rank::comm.size] upper_bounds = max_elts_elec if comm.rank == 0: to_explore = get_tqdm_progressbar(params, to_explore) for gcount, gidx in enumerate(to_explore): if (elt_count_pos + elt_count_neg) < nb_elts: # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..." local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes) local_shape = len(local_chunk) if do_spatial_whitening: if use_gpu: local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False) local_chunk = local_chunk.dot(spatial_whitening).asarray() else: local_chunk = numpy.dot(local_chunk, spatial_whitening) if do_temporal_whitening: local_chunk = scipy.ndimage.filters.convolve1d( local_chunk, temporal_whitening, axis=0, mode='constant') local_borders = (snippet_duration, local_shape - snippet_duration) if ignore_dead_times: dead_indices = numpy.searchsorted( all_dead_times, [t_offset, t_offset + local_shape]) # Extracting the peaks. 
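# Peak extraction below relies on scipy.signal.find_peaks, called per channel
# with height=thresholds[i] and distance=dist_peaks (on the raw, inverted or
# rectified trace depending on `sign_peaks`). Candidate peaks are then pruned:
# artefact crossings above `weird_thresh` are dropped when configured, peaks
# too close to the chunk borders are removed, and peaks falling inside dead
# periods are discarded.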
all_peaktimes = [numpy.empty(0, dtype=numpy.uint32)] found_peaktimes = [] found_peak_amplitudes = [] for i in range(N_e): height = thresholds[i] if sign_peaks == 'negative': peaktimes = scipy.signal.find_peaks(-local_chunk[:, i], height=height, distance=dist_peaks)[0] elif sign_peaks == 'positive': peaktimes = scipy.signal.find_peaks(local_chunk[:, i], height=height, distance=dist_peaks)[0] elif sign_peaks == 'both': peaktimes = scipy.signal.find_peaks(numpy.abs( local_chunk[:, i]), height=height, distance=dist_peaks)[0] else: peaktimes = numpy.empty(0, dtype=numpy.uint32) if ignore_artefacts: artetimes = scipy.signal.find_peaks( numpy.abs(local_chunk[:, i]), height=weird_thresh[i])[0] to_keep = numpy.logical_not( numpy.in1d(peaktimes, artetimes)) peaktimes = peaktimes[to_keep] idx = (peaktimes >= local_borders[0]) & (peaktimes < local_borders[1]) peaktimes = peaktimes[idx] if ignore_dead_times: if dead_indices[0] != dead_indices[1]: is_included = numpy.in1d( peaktimes + t_offset, all_dead_times[dead_indices[0]:dead_indices[1]]) peaktimes = peaktimes[~is_included] peaktimes = peaktimes.astype(numpy.uint32) found_peaktimes.append(peaktimes) peak_amplitudes = local_chunk[peaktimes, i] found_peak_amplitudes.append(peak_amplitudes) all_peaktimes = numpy.concatenate( found_peaktimes) # i.e. concatenate once for efficiency all_peak_amplitudes = numpy.concatenate(found_peak_amplitudes) local_peaktimes, local_indices = numpy.unique(all_peaktimes, return_inverse=True) if len(local_peaktimes) > 0: diff_times = (local_peaktimes[-1] - local_peaktimes[0]) + 1 all_times = numpy.zeros((N_e, diff_times), dtype=numpy.bool) padded_peaks = (local_peaktimes - local_peaktimes[0]).astype( numpy.int32) min_times = numpy.maximum(padded_peaks - safety_time, 0) max_times = numpy.minimum(padded_peaks + safety_time + 1, diff_times + 1) test_extremas = numpy.zeros((N_e, diff_times + 1), dtype=numpy.bool) for i in range(N_e): test_extremas[i, found_peaktimes[i] - local_peaktimes[0]] = True # Consider the peaks by decreasing extremum. if sort_waveforms: order = numpy.argsort(-np.abs(all_peak_amplitudes)) all_idx = numpy.take(all_peaktimes, order) argmax_peak = local_indices[order] else: n_times = len(all_peaktimes) shuffling = numpy.random.permutation(numpy.arange(n_times)) all_idx = numpy.take(all_peaktimes, shuffling) argmax_peak = local_indices[shuffling] # print "Selection of the peaks with spatio-temporal masks..." for midx, peak in zip(argmax_peak, all_idx): if (elt_count_neg + elt_count_pos) == nb_elts: break all_elecs = numpy.where( test_extremas[:, peak - local_peaktimes[0]])[0] data = local_chunk[peak, all_elecs] #target_area = test_extremas[:, min_times[midx]:max_times[midx]].sum(1) #all_elecs = numpy.where(target_area)[0] #data = local_chunk[peak, all_elecs] if sign_peaks == 'negative': if N_e > 1: if use_barycenter: weighed_position = data[:, numpy. newaxis] * positions[ all_elecs] barycenter = weighed_position.sum( 0) / data.sum() elec = numpy.argmin( numpy.linalg.norm(barycenter - positions[all_elecs], axis=1)) else: elec = numpy.argmin(data) else: elec = 0 negative_peak = True elif sign_peaks == 'positive': if N_e > 1: if use_barycenter: weighed_position = data[:, numpy. 
newaxis] * positions[ all_elecs] barycenter = weighed_position.sum( 0) / data.sum() elec = numpy.argmax( numpy.linalg.norm(barycenter - positions[all_elecs], axis=1)) else: elec = numpy.argmax(data) else: elec = 0 negative_peak = False elif sign_peaks == 'both': if N_e == 1: if data < 0: negative_peak = True elif data > 0: negative_peak = False elec = 0 else: if numpy.abs(numpy.max(data)) > numpy.abs( numpy.min(data)): elec = numpy.argmax(data) negative_peak = False else: elec = numpy.argmin(data) negative_peak = True elec = all_elecs[elec] if groups[elec] < upper_bounds: indices = nodes_indices[elec] myslice = all_times[indices, min_times[midx]:max_times[midx]] if not myslice.any(): sub_mat = local_chunk[peak - snippet_duration:peak + snippet_duration + 1, elec] if reject_noise: slice_window = sub_mat[ snippet_duration - noise_window:snippet_duration + noise_window + 1] value = numpy.linalg.norm( slice_window) / noise_levels[elec] is_noise = value < rejection_threshold else: is_noise = False if not is_noise: extrema = sub_mat[snippet_duration] if alignment: smoothed = True try: f = scipy.interpolate.UnivariateSpline( xdata, sub_mat, s=local_factors[elec], k=3) except Exception: smoothed = False f = scipy.interpolate.UnivariateSpline( xdata, sub_mat, k=3, s=0) if negative_peak: rmin = (numpy.argmin(f(cdata)) - xoff) / over_factor else: rmin = (numpy.argmax(f(cdata)) - xoff) / over_factor ddata = numpy.linspace( rmin - template_shift, rmin + template_shift, N_t) if smoothed: f = scipy.interpolate.UnivariateSpline( xdata, sub_mat, s=local_factors[elec], k=3) else: f = scipy.interpolate.UnivariateSpline( xdata, sub_mat, s=0, k=3) sub_mat = f(ddata).astype(numpy.float32) if negative_peak: times_neg[elt_count_neg] = peak + t_offset electrodes_neg[elt_count_neg] = elec extremum_neg[elt_count_neg] = extrema elts_neg[:, elt_count_neg] = sub_mat elt_count_neg += 1 else: times_pos[elt_count_pos] = peak + t_offset electrodes_pos[elt_count_pos] = elec extremum_pos[elt_count_pos] = extrema elts_pos[:, elt_count_pos] = sub_mat elt_count_pos += 1 groups[elec] += 1 all_times[ indices, min_times[midx]:max_times[midx]] = True test_extremas[elec, peak - local_peaktimes[0]] = False sys.stderr.flush() print_and_log([ "Node %d has collected %d waveforms" % (comm.rank, elt_count_pos + elt_count_neg) ], 'debug', logger) if sign_peaks in ['negative', 'both']: times_neg = gather_array(times_neg[:elt_count_neg], comm, 0, 1, dtype='int32') electrodes_neg = gather_array(electrodes_neg[:elt_count_neg], comm, 0, 1, dtype='int32') extremum_neg = gather_array(extremum_neg[:elt_count_neg], comm, 0, 1) gdata_neg = gather_array(elts_neg[:, :elt_count_neg].T, comm, 0, 1) if sign_peaks in ['positive', 'both']: times_pos = gather_array(times_pos[:elt_count_pos], comm, 0, 1, dtype='int32') electrodes_pos = gather_array(electrodes_pos[:elt_count_pos], comm, 0, 1, dtype='int32') extremum_pos = gather_array(extremum_pos[:elt_count_pos], comm, 0, 1) gdata_pos = gather_array(elts_pos[:, :elt_count_pos].T, comm, 0, 1) nb_waveforms = 0 if comm.rank == 0: # DO PCA on elts and store the basis obtained. if sign_peaks in ['negative', 'both']: nb_waveforms += gdata_neg.shape[0] if sign_peaks in ['positive', 'both']: nb_waveforms += gdata_pos.shape[0] nb_waveforms = all_gather_array( numpy.array([nb_waveforms], dtype=numpy.float32), comm, 0)[0] if comm.rank == 0: print_and_log([ "Found %d waveforms over %d requested" % (nb_waveforms, int(nb_elts * comm.size)) ], 'default', logger) if nb_waveforms == 0: print_and_log( ['No waveforms found! 
Are the data properly loaded??'], 'error', logger) if nb_waveforms == 0: sys.exit(0) if comm.rank == 0: res = {} pca = None pca_pos = None pca_neg = None warning_n_t = False if sign_peaks in ['negative', 'both']: res['times'] = times_neg res['electrodes'] = electrodes_neg res['extremum'] = extremum_neg if len(gdata_neg) > 0: pca = PCA(output_dim) if use_hanning: pca.fit(gdata_neg * hanning_filter) else: pca.fit(gdata_neg) res['proj'] = pca.components_.T.astype(numpy.float32) pca_neg = numpy.sum(pca.explained_variance_ratio_) else: res['proj'] = numpy.identity(int(output_dim), dtype=numpy.float32) res['rec'] = res['proj'].T res['waveform'] = numpy.median(gdata_neg, 0) # dispersion = numpy.std(gdata_neg, 0) / numpy.median(stds) # ratio = numpy.sum(dispersion > 1.1) / float(len(dispersion)) # if ratio < 0.25: # print_and_log(["Time window N_t in [detection] seems too large!"], 'info', logger) # warning_n_t = True # elif ratio == 1: # print_and_log(["Time window N_t in [detection] seems too small!"], 'info', logger) # warning_n_t = True idx = numpy.random.permutation(numpy.arange( gdata_neg.shape[0]))[:2500] res['waveforms'] = gdata_neg[idx, :] if sign_peaks in ['positive', 'both']: res['times_pos'] = times_pos res['electrodes_pos'] = electrodes_pos res['extremum_pos'] = extremum_pos if len(gdata_pos) > 0: pca = PCA(output_dim) if use_hanning: pca.fit(gdata_pos * hanning_filter) else: pca.fit(gdata_pos) res['proj_pos'] = pca.components_.T.astype(numpy.float32) pca_pos = numpy.sum(pca.explained_variance_ratio_) else: res['proj_pos'] = numpy.identity(int(output_dim), dtype=numpy.float32) res['rec_pos'] = res['proj_pos'].T res['waveform_pos'] = numpy.median(gdata_pos, 0) # dispersion = numpy.std(gdata_pos, 0) / numpy.median(stds) # ratio = numpy.sum(dispersion > 1.1) / float(len(dispersion)) # if ratio < 0.25 and not warning_n_t: # print_and_log(["Time window N_t in [detection] seems too large!"], 'info', logger) # elif ratio == 1 and not warning_n_t: # print_and_log(["Time window N_t in [detection] seems too small!"], 'info', logger) idx = numpy.random.permutation(numpy.arange( gdata_pos.shape[0]))[:2500] res['waveforms_pos'] = gdata_pos[idx, :] bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='earliest') io.write_datasets(bfile, list(res.keys()), res, compression=hdf5_compress) if sign_peaks == 'positive': print_and_log([ "A basis with %s dimensions has been built" % res['proj_pos'].shape[1] ], 'info', logger) elif sign_peaks == 'negative': print_and_log([ "A basis with %s dimensions has been built" % res['proj'].shape[1] ], 'info', logger) elif sign_peaks == 'both': print_and_log([ "Two basis with %s dimensions has been built" % res['proj'].shape[1] ], 'debug', logger) if pca_pos is not None: print_and_log([ "The percentage of variance explained is %s for positive spikes" % pca_pos ], 'debug', logger) if pca_neg is not None: print_and_log([ "The percentage of variance explained is %s for negative spikes" % pca_neg ], 'debug', logger) bfile.close() comm.Barrier() if matched_filter: if comm.rank == 0: print_and_log([ "Because of matched filters, need to recompute the thresholds..." 
], 'default', logger) if do_spatial_whitening: spatial_whitening = io.load_data(params, 'spatial_whitening') if use_gpu: spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False) if do_temporal_whitening: temporal_whitening = io.load_data(params, 'temporal_whitening') if sign_peaks in ['negative', 'both']: waveform_neg = io.load_data(params, 'waveform')[::-1] waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg)) if sign_peaks in ['positive', 'both']: waveform_pos = io.load_data(params, 'waveform-pos')[::-1] waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos)) for gidx in [all_chunks[comm.rank]]: local_chunk, t_offset = data_file.get_data(gidx, chunk_size, nodes=nodes) local_shape = len(local_chunk) if do_spatial_whitening: if use_gpu: local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False) local_chunk = local_chunk.dot(spatial_whitening).asarray() else: local_chunk = numpy.dot(local_chunk, spatial_whitening) if do_temporal_whitening: local_chunk = scipy.ndimage.filters.convolve1d( local_chunk, temporal_whitening, axis=0, mode='constant') local_chunk /= thresholds if sign_peaks in ['negative', 'both']: tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant') thresholds = numpy.zeros(N_e, dtype=numpy.float32) for i in range(N_e): u = numpy.median(tmp_chunk[:, i], 0) thresholds[i] = numpy.median( numpy.abs(tmp_chunk[:, i] - u), 0) gdata = gather_array(thresholds, comm) if comm.rank == 0: gdata = gdata.reshape((comm.size, N_e)) thresholds = numpy.mean(gdata, 0) bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='earliest') io.write_datasets(bfile, ['matched_thresholds'], {'matched_thresholds': thresholds}, compression=hdf5_compress) bfile.close() comm.Barrier() if sign_peaks in ['positive', 'both']: tmp_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant') thresholds = numpy.zeros(N_e, dtype=numpy.float32) for i in range(N_e): u = numpy.median(tmp_chunk[:, i], 0) thresholds[i] = numpy.median( numpy.abs(tmp_chunk[:, i] - u), 0) gdata = gather_array(thresholds, comm) if comm.rank == 0: gdata = gdata.reshape((comm.size, N_e)) thresholds = numpy.mean(gdata, 0) bfile = h5py.File(file_out_suff + '.basis.hdf5', 'r+', libver='earliest') io.write_datasets(bfile, ['matched_thresholds_pos'], {'matched_thresholds_pos': thresholds}, compression=hdf5_compress) bfile.close() comm.Barrier() data_file.close() if SHARED_MEMORY and ignore_dead_times: mpi_memory_3.Free()
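# The basis construction above stores the PCA projection matrix as
# pca.components_.T, i.e. with shape (N_t, n_components), under 'proj' (and
# 'proj_pos' for positive peaks). The sketch below shows how such a matrix is
# obtained and how snippets can be projected onto it; the function name and
# the fixed output_dim are illustrative only, and the plain dot product skips
# the mean-centering that sklearn's PCA.transform would apply.
def _sketch_build_basis(waveforms, output_dim=5):
    import numpy
    from sklearn.decomposition import PCA
    pca = PCA(output_dim)
    pca.fit(waveforms)                                  # waveforms: (n_snippets, N_t)
    proj = pca.components_.T.astype(numpy.float32)      # (N_t, output_dim)
    features = waveforms.dot(proj)                      # (n_snippets, output_dim)
    return proj, features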
def main(params, nb_cpu, nb_gpu, use_gpu): ################################################################# # params = detect_memory(params) _ = init_logging(params.logfile) SHARED_MEMORY = get_shared_memory_flag(params) logger = logging.getLogger('circus.fitting') data_file = params.data_file N_e = params.getint('data', 'N_e') N_total = params.nb_channels N_t = params.getint('detection', 'N_t') template_shift = params.getint('detection', 'template_shift') file_out = params.get('data', 'file_out') file_out_suff = params.get('data', 'file_out_suff') sign_peaks = params.get('detection', 'peaks') dist_peaks = params.getint('detection', 'dist_peaks') matched_filter = params.getboolean('detection', 'matched-filter') spike_thresh = params.getfloat('detection', 'spike_thresh') spike_width = params.getfloat('detection', 'spike_width') do_temporal_whitening = params.getboolean('whitening', 'temporal') do_spatial_whitening = params.getboolean('whitening', 'spatial') chunk_size = detect_memory(params) gpu_only = params.getboolean('fitting', 'gpu_only') nodes, edges = get_nodes_and_edges(params) tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',') tmp_limits = map(float, tmp_limits) amp_auto = params.getboolean('fitting', 'amp_auto') nb_chances = params.getint('fitting', 'nb_chances') max_chunk = params.getfloat('fitting', 'max_chunk') noise_thr = params.getfloat('clustering', 'noise_thr') collect_all = params.getboolean('fitting', 'collect_all') ignore_dead_times = params.getboolean('triggers', 'ignore_times') inv_nodes = numpy.zeros(N_total, dtype=numpy.int32) inv_nodes[nodes] = numpy.arange(len(nodes)) data_file.open() ################################################################# if use_gpu: import cudamat as cmt # # Need to properly handle multi GPU per MPI nodes? 
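# This entry point extracts multi-unit activity (MUA): instead of fitting
# templates, it records every threshold (or matched-filter) crossing per
# electrode, together with its amplitude, and writes them to per-rank
# .mua / .elec / .amp files that io.collect_mua merges at the end.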
if nb_gpu > nb_cpu: gpu_id = int(comm.rank // nb_cpu) else: gpu_id = 0 cmt.cuda_set_device(gpu_id) cmt.init() cmt.cuda_sync_threads() if matched_filter: if sign_peaks in ['negative', 'both']: waveform_neg = io.load_data(params, 'waveform')[::-1] waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg)) matched_tresholds_neg = io.load_data(params, 'matched-thresholds') if sign_peaks in ['positive', 'both']: waveform_pos = io.load_data(params, 'waveform-pos')[::-1] waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos)) matched_tresholds_pos = io.load_data(params, 'matched-thresholds-pos') if ignore_dead_times: all_dead_times = get_dead_times(params) thresholds = io.load_data(params, 'thresholds') comm.Barrier() if comm.rank == 0: print_and_log(["Extracting MUA activity..."], 'default', logger) purge(file_out_suff, '.data') if do_spatial_whitening: spatial_whitening = io.load_data(params, 'spatial_whitening') else: spatial_whitening = None # default assignment (PyCharm code inspection) if do_temporal_whitening: temporal_whitening = io.load_data(params, 'temporal_whitening') else: temporal_whitening = None # default assignment (PyCharm code inspection) nb_chunks, last_chunk_len = data_file.analyze(chunk_size) processed_chunks = int(min(nb_chunks, max_chunk)) comm.Barrier() spiketimes_file = open(file_out_suff + '.mua-%d.data' % comm.rank, 'wb') comm.Barrier() electrodes_file = open(file_out_suff + '.elec-%d.data' % comm.rank, 'wb') comm.Barrier() amplitudes_file = open(file_out_suff + '.amp-%d.data' % comm.rank, 'wb') comm.Barrier() if use_gpu and do_spatial_whitening: spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False) to_explore = range(comm.rank, processed_chunks, comm.size) if comm.rank == 0: to_explore = get_tqdm_progressbar(to_explore) for gcount, gidx in enumerate(to_explore): # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..." # # We need to deal with the borders by taking chunks of size [0, chunck_size + template_shift]. is_first = data_file.is_first_chunk(gidx, nb_chunks) is_last = data_file.is_last_chunk(gidx, nb_chunks) if is_last: padding = (-dist_peaks, 0) elif is_first: padding = (0, dist_peaks) else: padding = (-dist_peaks, dist_peaks) result = {'spiketimes': [], 'amplitudes': [], 'templates': []} local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes) len_chunk = len(local_chunk) if do_spatial_whitening: if use_gpu: local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False) local_chunk = local_chunk.dot(spatial_whitening).asarray() else: local_chunk = numpy.dot(local_chunk, spatial_whitening) if do_temporal_whitening: local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant') # print "Extracting the peaks..." 
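# In the matched-filter branch below, each whitened chunk is convolved with
# the time-reversed average waveform (waveform_pos / waveform_neg) and peaks
# are detected on the filtered trace with the per-channel matched thresholds;
# otherwise find_peaks runs directly on the signed or rectified voltage with
# the MAD-based thresholds. For every peak the electrode index and the
# amplitude at the peak sample are stored alongside the spike time.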
local_peaktimes = [numpy.zeros(0, dtype=numpy.uint32)] local_elecs = [numpy.zeros(0, dtype=numpy.uint32)] local_amps = [numpy.zeros(0, dtype=numpy.float32)] if matched_filter: if sign_peaks in ['positive', 'both']: filter_chunk = scipy.ndimage.filters.convolve1d( local_chunk, waveform_pos, axis=0, mode='constant') for i in range(N_e): peaktimes = scipy.signal.find_peaks( filter_chunk[:, i], height=matched_tresholds_pos[i], width=spike_width, distance=dist_peaks, wlen=N_t)[0] local_peaktimes.append(peaktimes) local_elecs.append( i * numpy.ones(len(peaktimes), dtype='uint32')) local_amps.append(filter_chunk[peaktimes, i]) if sign_peaks in ['negative', 'both']: filter_chunk = scipy.ndimage.filters.convolve1d( local_chunk, waveform_neg, axis=0, mode='constant') for i in range(N_e): peaktimes = scipy.signal.find_peaks( filter_chunk[:, i], height=matched_tresholds_neg[i], width=spike_width, distance=dist_peaks, wlen=N_t)[0] local_peaktimes.append(peaktimes) local_elecs.append( i * numpy.ones(len(peaktimes), dtype='uint32')) local_amps.append(filter_chunk[peaktimes, i]) else: for i in range(N_e): if sign_peaks == 'negative': peaktimes = scipy.signal.find_peaks(-local_chunk[:, i], height=thresholds[i], width=spike_width, distance=dist_peaks, wlen=N_t)[0] elif sign_peaks == 'positive': peaktimes = scipy.signal.find_peaks(local_chunk[:, i], height=thresholds[i], width=spike_width, distance=dist_peaks, wlen=N_t)[0] elif sign_peaks == 'both': peaktimes = scipy.signal.find_peaks(numpy.abs( local_chunk[:, i]), height=thresholds[i], width=spike_width, distance=dist_peaks, wlen=N_t)[0] local_peaktimes.append(peaktimes) local_elecs.append(i * numpy.ones(len(peaktimes), dtype='uint32')) local_amps.append(local_chunk[peaktimes, i]) local_peaktimes = numpy.concatenate(local_peaktimes) local_elecs = numpy.concatenate(local_elecs) local_amps = numpy.concatenate(local_amps) g_offset = t_offset + padding[0] if ignore_dead_times: dead_indices = numpy.searchsorted( all_dead_times, [t_offset, t_offset + chunk_size]) if dead_indices[0] != dead_indices[1]: is_included = numpy.in1d( local_peaktimes + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]]) local_peaktimes = local_peaktimes[~is_included] local_elecs = local_elecs[~is_included] local_amps = local_amps[~is_included] # print "Removing the useless borders..." local_borders = (dist_peaks, len_chunk - dist_peaks) idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1]) local_peaktimes = numpy.compress(idx, local_peaktimes) + g_offset local_elecs = numpy.compress(idx, local_elecs) local_amps = numpy.compress(idx, local_amps) spiketimes_file.write(local_peaktimes.astype(numpy.uint32).tostring()) electrodes_file.write(local_elecs.tostring()) amplitudes_file.write(local_amps.tostring()) sys.stderr.flush() spiketimes_file.flush() os.fsync(spiketimes_file.fileno()) spiketimes_file.close() electrodes_file.flush() os.fsync(electrodes_file.fileno()) electrodes_file.close() amplitudes_file.flush() os.fsync(amplitudes_file.fileno()) amplitudes_file.close() comm.Barrier() if comm.rank == 0: io.collect_mua(comm.size, params, erase=True) data_file.close()
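# NOTE (illustration): the MUA extraction above boils down to running
# scipy.signal.find_peaks channel by channel against a per-channel threshold, with
# negative peaks detected on the inverted trace. Stand-alone toy example (the
# threshold value and the planted spike positions are invented).
def _example_threshold_crossing_detection():
    import numpy
    import scipy.signal
    rng = numpy.random.default_rng(42)
    trace = rng.normal(0.0, 1.0, size=2000)
    trace[[200, 800, 1500]] -= 8.0                    # three fake negative spikes
    threshold = 5.0
    peaktimes, _ = scipy.signal.find_peaks(-trace, height=threshold)
    return peaktimes                                  # the three planted positions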
def main(params, nb_cpu, nb_gpu, use_gpu): ################################################################# # params = detect_memory(params) _ = init_logging(params.logfile) SHARED_MEMORY = get_shared_memory_flag(params) logger = logging.getLogger('circus.fitting') data_file = params.data_file n_e = params.getint('data', 'N_e') n_total = params.nb_channels n_t = params.getint('detection', 'N_t') template_shift = params.getint('detection', 'template_shift') # file_out = params.get('data', 'file_out') file_out_suff = params.get('data', 'file_out_suff') sign_peaks = params.get('detection', 'peaks') matched_filter = params.getboolean('detection', 'matched-filter') # spike_thresh = params.getfloat('detection', 'spike_thresh') ratio_thresh = params.getfloat('fitting', 'ratio_thresh') two_components = params.getboolean('fitting', 'two_components') sparse_threshold = params.getfloat('fitting', 'sparse_thresh') # spike_width = params.getfloat('detection', 'spike_width') # dist_peaks = params.getint('detection', 'dist_peaks') do_temporal_whitening = params.getboolean('whitening', 'temporal') do_spatial_whitening = params.getboolean('whitening', 'spatial') templates_normalization = params.getboolean('clustering', 'templates_normalization') # TODO test, switch, test! chunk_size = detect_memory(params, fitting=True) gpu_only = params.getboolean('fitting', 'gpu_only') nodes, edges = get_nodes_and_edges(params) tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',') tmp_limits = [float(v) for v in tmp_limits] amp_auto = params.getboolean('fitting', 'amp_auto') max_chunk = params.getfloat('fitting', 'max_chunk') # noise_thr = params.getfloat('clustering', 'noise_thr') collect_all = params.getboolean('fitting', 'collect_all') min_second_component = params.getfloat('fitting', 'min_second_component') debug = params.getboolean('fitting', 'debug') ignore_dead_times = params.getboolean('triggers', 'ignore_times') inv_nodes = numpy.zeros(n_total, dtype=numpy.int32) inv_nodes[nodes] = numpy.arange(len(nodes)) data_file.open() supports = io.load_data(params, 'supports') low_channels_thr = params.getint('detection', 'low_channels_thr') median_channels = numpy.median(numpy.sum(supports, 1)) fixed_amplitudes = params.getboolean('clustering', 'fixed_amplitudes') weird_thresh = params.get('detection', 'weird_thresh') if weird_thresh != '': ignore_artefacts = True weird_thresh = io.load_data(params, 'weird-thresholds') else: ignore_artefacts = False if not fixed_amplitudes: nb_amp_bins = params.getint('clustering', 'nb_amp_bins') splits = np.linspace(0, params.data_file.duration, nb_amp_bins) interpolated_times = np.zeros(len(splits) - 1, dtype=numpy.float32) for count in range(0, len(splits) - 1): interpolated_times[count] = (splits[count] + splits[count + 1])/2 interpolated_times = numpy.concatenate(([0], interpolated_times, [params.data_file.duration])) nb_amp_times = len(splits) + 1 mse_error = params.getboolean('fitting', 'mse_error') if mse_error: stds = io.load_data(params, 'stds') stds_norm = numpy.linalg.norm(stds) # if median_channels < low_channels_thr: # normalization = False # if comm.rank == 0: # print_and_log(['Templates defined on few channels (%g), turning off normalization' %median_channels], 'debug', logger) ################################################################# if SHARED_MEMORY: templates, mpi_memory_1 = io.load_data_memshared(params, 'templates', normalize=templates_normalization, transpose=True, sparse_threshold=sparse_threshold) N_tm, x = 
templates.shape is_sparse = not isinstance(templates, numpy.ndarray) else: templates = io.load_data(params, 'templates') x, N_tm = templates.shape if N_tm > 0: sparsity = templates.nnz / (x * N_tm) is_sparse = sparsity < sparse_threshold else: is_sparse = True if not is_sparse: if comm.rank == 0: print_and_log(['Templates sparsity is low (%g): densified to speedup the algorithm' %sparsity], 'debug', logger) templates = templates.toarray() temp_2_shift = 2 * template_shift temp_3_shift = 3 * template_shift n_tm = N_tm // 2 n_scalar = n_e * n_t temp_window = numpy.arange(-template_shift, template_shift + 1) size_window = n_e * (2 * template_shift + 1) if not amp_auto: amp_limits = numpy.zeros((n_tm, 2)) amp_limits[:, 0] = tmp_limits[0] amp_limits[:, 1] = tmp_limits[1] else: amp_limits = io.load_data(params, 'limits') norm_templates = io.load_data(params, 'norm-templates') sub_norm_templates = n_scalar * norm_templates[:n_tm] if not templates_normalization: norm_templates_2 = (norm_templates ** 2.0) * n_scalar sub_norm_templates_2 = norm_templates_2[:n_tm] if not SHARED_MEMORY: # Normalize templates (if necessary). if templates_normalization: if is_sparse: for idx in range(templates.shape[1]): myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx+1]) templates.data[myslice] /= norm_templates[idx] else: for idx in range(templates.shape[1]): templates[:, idx] /= norm_templates[idx] # Transpose templates. templates = templates.T maxoverlap = io.load_data(params, 'maxoverlap')/n_scalar similar = np.where(maxoverlap > 0.5) idx = similar[0] < similar[1] similar = similar[0][idx], similar[1][idx] nb_mixtures = len(similar[0]) waveform_neg = numpy.empty(0) # default assignment (for PyCharm code inspection) matched_thresholds_neg = None # default assignment (for PyCharm code inspection) waveform_pos = numpy.empty(0) # default assignment (for PyCharm code inspection) matched_thresholds_pos = None # default assignment (for PyCharm code inspection) if matched_filter: if sign_peaks in ['negative', 'both']: waveform_neg = io.load_data(params, 'waveform')[::-1] waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg)) matched_thresholds_neg = io.load_data(params, 'matched-thresholds') if sign_peaks in ['positive', 'both']: waveform_pos = io.load_data(params, 'waveform-pos')[::-1] waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos)) matched_thresholds_pos = io.load_data(params, 'matched-thresholds-pos') if ignore_dead_times: if SHARED_MEMORY: all_dead_times, mpi_memory_3 = get_dead_times(params) else: all_dead_times = get_dead_times(params) else: all_dead_times = None # default assignment (for PyCharm code inspection) thresholds = io.get_accurate_thresholds(params, ratio_thresh) neighbors = {} if collect_all: is_sparse = not isinstance(templates, numpy.ndarray) for i in range(0, n_tm): if is_sparse: tmp = templates[i, :].toarray().reshape(n_e, n_t) else: tmp = templates[i].reshape(n_e, n_t) if templates_normalization: tmp = tmp * norm_templates[i] neighbors[i] = numpy.where(numpy.sum(tmp, axis=1) != 0.0)[0] #N_tm, x = templates.shape #sparsity_factor = templates.nnz / (N_tm * x) #if sparsity_factor > sparse_threshold: # if comm.rank == 0: # print_and_log(['Templates are not sparse enough, we densify them for'], 'default', logger) # templates = templates.toarray() info_string = '' if comm.rank == 0: info_string = "using %d CPUs" % comm.size comm.Barrier() c_overlap = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu) over_shape = 
c_overlap.get('over_shape')[:] n_over = int(numpy.sqrt(over_shape[0])) s_over = over_shape[1] s_center = s_over // 2 # # If the number of overlaps is different from templates, we need to recompute them. if n_over != N_tm: if comm.rank == 0: print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger) c_overlap = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu) over_shape = c_overlap.get('over_shape')[:] n_over = int(numpy.sqrt(over_shape[0])) s_over = over_shape[1] if SHARED_MEMORY: c_overs, mpi_memory_2 = io.load_data_memshared(params, 'overlaps') else: c_overs = io.load_data(params, 'overlaps') comm.Barrier() if n_tm == 0: if comm.rank == 0: print_and_log(["No templates present. Redo clustering?"], 'default', logger) sys.exit(0) if comm.rank == 0: print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." % (info_string, n_tm)], 'default', logger) purge(file_out_suff, '.data') if do_spatial_whitening: spatial_whitening = io.load_data(params, 'spatial_whitening') else: spatial_whitening = None # default assignment (for PyCharm code inspection) if do_temporal_whitening: temporal_whitening = io.load_data(params, 'temporal_whitening') else: temporal_whitening = None # default assignment (for PyCharm code inspection) nb_chunks, last_chunk_len = data_file.analyze(chunk_size) processed_chunks = int(min(nb_chunks, max_chunk)) comm.Barrier() spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank, 'wb') comm.Barrier() amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank, 'wb') comm.Barrier() templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank, 'wb') comm.Barrier() if ignore_artefacts: comm.Barrier() arte_spiketimes_file = open(file_out_suff + '.times-%d.sata' % comm.rank, 'wb') comm.Barrier() arte_electrodes_file = open(file_out_suff + '.elec-%d.sata' % comm.rank, 'wb') comm.Barrier() arte_amplitudes_file = open(file_out_suff + '.amp-%d.sata' % comm.rank, 'wb') comm.Barrier() if mse_error: mse_file = open(file_out_suff + '.mses-%d.data' % comm.rank, 'wb') comm.Barrier() if collect_all: garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' % comm.rank, 'wb') comm.Barrier() garbage_temp_file = open(file_out_suff + '.gtemplates-%d.data' % comm.rank, 'wb') comm.Barrier() else: garbage_times_file = None # default assignment (for PyCharm code inspection) garbage_temp_file = None # default assignment (for PyCharm code inspection) if debug: # Open debug files. 
chunk_nbs_debug_file = open(file_out_suff + '.chunk_nbs_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() iteration_nbs_debug_file = open(file_out_suff + '.iteration_nbs_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() peak_nbs_debug_file = open(file_out_suff + '.peak_nbs_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() peak_local_time_steps_debug_file = open( file_out_suff + '.peak_local_time_steps_debug_%d.data' % comm.rank, mode='wb' ) comm.Barrier() peak_time_steps_debug_file = open(file_out_suff + '.peak_time_steps_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() peak_scalar_products_debug_file = open( file_out_suff + '.peak_scalar_products_debug_%d.data' % comm.rank, mode='wb' ) comm.Barrier() peak_solved_flags_debug_file = open(file_out_suff + '.peak_solved_flags_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() template_nbs_debug_file = open(file_out_suff + '.template_nbs_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() success_flags_debug_file = open(file_out_suff + '.success_flags_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() else: chunk_nbs_debug_file = None # default assignment (for PyCharm code inspection) iteration_nbs_debug_file = None # default assignment (for PyCharm code inspection) peak_nbs_debug_file = None # default assignment (for PyCharm code inspection) peak_local_time_steps_debug_file = None # default assignment (for PyCharm code inspection) peak_time_steps_debug_file = None # default assignment (for PyCharm code inspection) peak_scalar_products_debug_file = None # default assignment (for PyCharm code inspection) peak_solved_flags_debug_file = None # default assignment (for PyCharm code inspection) template_nbs_debug_file = None # default assignment (for PyCharm code inspection) success_flags_debug_file = None # default assignment (for PyCharm code inspection) last_chunk_size = 0 slice_indices = numpy.zeros(0, dtype=numpy.int32) to_explore = list(range(comm.rank, processed_chunks, comm.size)) if comm.rank == 0: to_explore = get_tqdm_progressbar(params, to_explore) if fixed_amplitudes: min_scalar_products = amp_limits[:,0][:, numpy.newaxis] max_scalar_products = amp_limits[:,1][:, numpy.newaxis] if templates_normalization: min_sps = min_scalar_products * sub_norm_templates[:, numpy.newaxis] max_sps = max_scalar_products * sub_norm_templates[:, numpy.newaxis] else: min_sps = min_scalar_products * sub_norm_templates_2[:, numpy.newaxis] max_sps = max_scalar_products * sub_norm_templates_2[:, numpy.newaxis] for gcount, gidx in enumerate(to_explore): # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..." # # We need to deal with the borders by taking chunks of size [0, chunck_size + template_shift]. 
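# NOTE (illustration): the border handling mentioned in the comment above follows a
# simple rule, sketched below: pad the chunk by a safety margin on every side that
# has a neighbouring chunk, so spikes close to a border can still be fitted. The
# margin used by the fitting loop is temp_3_shift; this helper takes the margin as
# an argument and uses plain first/last index tests instead of the data_file helpers.
def _example_chunk_padding(gidx, nb_chunks, margin):
    is_first = (gidx == 0)
    is_last = (gidx == nb_chunks - 1)
    if is_first and is_last:
        return (0, 0)              # a single chunk has no neighbours
    elif is_last:
        return (-margin, 0)        # only extend towards the previous chunk
    elif is_first:
        return (0, margin)         # only extend towards the next chunk
    else:
        return (-margin, margin)   # extend on both sides
# A peak found at local index t then sits at absolute time t + t_offset + padding[0].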
is_first = data_file.is_first_chunk(gidx, nb_chunks) is_last = data_file.is_last_chunk(gidx, nb_chunks) if not (is_first and is_last): if is_last: padding = (-temp_3_shift, 0) elif is_first: padding = (0, temp_3_shift) else: padding = (-temp_3_shift, temp_3_shift) else: padding = (0, 0) result = { 'spiketimes': [], 'amplitudes': [], 'templates': [], } if mse_error: mse_fit = { 'spiketimes': [], 'amplitudes': [], 'templates': [], } result_debug = { 'chunk_nbs': [], 'iteration_nbs': [], 'peak_nbs': [], 'peak_local_time_steps': [], 'peak_time_steps': [], 'peak_scalar_products': [], 'peak_solved_flags': [], 'template_nbs': [], 'success_flags': [], } local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes) len_chunk = len(local_chunk) if is_last: my_chunk_size = last_chunk_size else: my_chunk_size = chunk_size if do_spatial_whitening: local_chunk = numpy.dot(local_chunk, spatial_whitening) if do_temporal_whitening: local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant') # Extracting peaks. all_found_spikes = {} if collect_all: for i in range(n_e): all_found_spikes[i] = [] local_peaktimes = [numpy.empty(0, dtype=numpy.uint32)] if ignore_artefacts: artefacts_peaktimes = [numpy.zeros(0, dtype=numpy.uint32)] artefacts_elecs = [numpy.zeros(0, dtype=numpy.uint32)] artefacts_amps = [numpy.zeros(0, dtype=numpy.float32)] if matched_filter: if sign_peaks in ['positive', 'both']: filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant') for i in range(n_e): peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_pos[i])[0] if ignore_artefacts: artetimes = scipy.signal.find_peaks(numpy.abs(filter_chunk[:, i]), height=weird_thresh[i])[0] to_keep = numpy.logical_not(numpy.in1d(peaktimes, artetimes)) peaktimes = peaktimes[to_keep] artefacts_peaktimes.append(artetimes) artefacts_elecs.append(i*numpy.ones(len(artetimes), dtype='uint32')) artefacts_amps.append(local_chunk[artetimes, i]) local_peaktimes.append(peaktimes) if collect_all: all_found_spikes[i] += peaktimes.tolist() if sign_peaks in ['negative', 'both']: filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant') for i in range(n_e): peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_neg[i])[0] if ignore_artefacts: artetimes = scipy.signal.find_peaks(numpy.abs(filter_chunk[:, i]), height=weird_thresh[i])[0] to_keep = numpy.logical_not(numpy.in1d(peaktimes, artetimes)) peaktimes = peaktimes[to_keep] artefacts_peaktimes.append(artetimes) artefacts_elecs.append(i*numpy.ones(len(artetimes), dtype='uint32')) artefacts_amps.append(local_chunk[artetimes, i]) local_peaktimes.append(peaktimes) if collect_all: all_found_spikes[i] += peaktimes.tolist() local_peaktimes = numpy.concatenate(local_peaktimes) if ignore_artefacts: artefacts_peaktimes = numpy.concatenate(artefacts_peaktimes) artefacts_elecs = numpy.concatenate(artefacts_elecs) artefacts_amps = numpy.concatenate(artefacts_amps) else: for i in range(n_e): if sign_peaks == 'negative': peaktimes = scipy.signal.find_peaks(-local_chunk[:, i], height=thresholds[i])[0] elif sign_peaks == 'positive': peaktimes = scipy.signal.find_peaks(local_chunk[:, i], height=thresholds[i])[0] elif sign_peaks == 'both': peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]), height=thresholds[i])[0] else: raise ValueError("Unexpected value %s" % sign_peaks) if ignore_artefacts: artetimes = 
scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]), height=weird_thresh[i])[0] to_keep = numpy.logical_not(numpy.in1d(peaktimes, artetimes)) peaktimes = peaktimes[to_keep] artefacts_peaktimes.append(artetimes) artefacts_elecs.append(i*numpy.ones(len(artetimes), dtype='uint32')) artefacts_amps.append(local_chunk[artetimes, i]) local_peaktimes.append(peaktimes) if collect_all: all_found_spikes[i] += peaktimes.tolist() local_peaktimes = numpy.concatenate(local_peaktimes) if ignore_artefacts: artefacts_peaktimes = numpy.concatenate(artefacts_peaktimes) artefacts_elecs = numpy.concatenate(artefacts_elecs) artefacts_amps = numpy.concatenate(artefacts_amps) local_peaktimes = numpy.unique(local_peaktimes) g_offset = t_offset + padding[0] if ignore_dead_times: dead_indices = numpy.searchsorted(all_dead_times, [t_offset, t_offset + my_chunk_size]) if dead_indices[0] != dead_indices[1]: is_included = numpy.in1d(local_peaktimes + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]]) local_peaktimes = local_peaktimes[~is_included] if ignore_artefacts: is_included = numpy.in1d(artefacts_peaktimes + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]]) artefacts_peaktimes = artefacts_peaktimes[~is_included] artefacts_elecs = artefacts_elecs[~is_included] artefacts_amps = artefacts_amps[~is_included] local_peaktimes = numpy.sort(local_peaktimes) else: dead_indices = None # default assignment (for PyCharm code inspection) # print "Removing the useless borders..." local_borders = (template_shift, len_chunk - template_shift) idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1]) local_peaktimes = numpy.compress(idx, local_peaktimes) if ignore_artefacts: artefacts_peaktimes = artefacts_peaktimes + g_offset idx = (artefacts_peaktimes >= t_offset) & (artefacts_peaktimes < t_offset + my_chunk_size) artefacts_peaktimes = numpy.compress(idx, artefacts_peaktimes) artefacts_elecs = numpy.compress(idx, artefacts_elecs) artefacts_amps = numpy.compress(idx, artefacts_amps) if collect_all: for i in range(n_e): all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.uint32) if ignore_dead_times: if dead_indices[0] != dead_indices[1]: is_included = numpy.in1d( all_found_spikes[i] + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]] ) all_found_spikes[i] = all_found_spikes[i][~is_included] all_found_spikes[i] = numpy.sort(all_found_spikes[i]) idx = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1]) all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i]) nb_local_peak_times = len(local_peaktimes) if nb_local_peak_times > 0: # print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..." 
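# NOTE (illustration): dead periods are stored as one sorted array of banned time
# steps; each chunk first locates its slab of that array with searchsorted, then
# drops any peak whose absolute time falls inside it, as done above. Self-contained
# sketch with invented times.
def _example_dead_time_filtering():
    import numpy
    all_dead_times = numpy.array([120, 121, 122, 500, 501], dtype=numpy.int64)
    t_offset, my_chunk_size = 100, 300
    local_peaktimes = numpy.array([5, 21, 150, 250], dtype=numpy.int64)   # chunk-relative
    lo, hi = numpy.searchsorted(all_dead_times, [t_offset, t_offset + my_chunk_size])
    if lo != hi:
        banned = numpy.in1d(local_peaktimes + t_offset, all_dead_times[lo:hi])
        local_peaktimes = local_peaktimes[~banned]
    return local_peaktimes   # the peak at absolute time 121 (local 21) has been removed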
if collect_all or mse_error: c_local_chunk = local_chunk.copy() else: c_local_chunk = None # default assignment (for PyCharm code inspection) sub_mat = local_chunk[local_peaktimes[:, None] + temp_window] sub_mat = sub_mat.transpose(2, 1, 0).reshape(size_window, nb_local_peak_times) del local_chunk b = templates.dot(sub_mat) del sub_mat local_restriction = (t_offset, t_offset + my_chunk_size) all_spikes = local_peaktimes + g_offset if collect_all: c_all_times = numpy.zeros((len_chunk, n_e), dtype=numpy.bool) c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0) c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk) for i in range(n_e): c_all_times[all_found_spikes[i], i] = True else: c_all_times = None # default assignment (for PyCharm code inspection) c_min_times = None # default assignment (for PyCharm code inspection) c_max_times = None # default assignment (for PyCharm code inspection) iteration_nb = 0 data = b[:n_tm, :] if not fixed_amplitudes: amp_index = numpy.searchsorted(splits, local_restriction[0], 'right') scaling = 1/(splits[amp_index] - splits[amp_index - 1]) min_scalar_products = amp_limits[:, amp_index, 0] + (amp_limits[:, amp_index, 0] - amp_limits[:, amp_index+1, 0])*scaling max_scalar_products = amp_limits[:, amp_index, 1] + (amp_limits[:, amp_index, 1] - amp_limits[:, amp_index+1, 0])*scaling min_scalar_products = min_scalar_products[:, numpy.newaxis] max_scalar_products = max_scalar_products[:, numpy.newaxis] if templates_normalization: min_sps = min_scalar_products * sub_norm_templates[:, numpy.newaxis] max_sps = max_scalar_products * sub_norm_templates[:, numpy.newaxis] else: min_sps = min_scalar_products * sub_norm_templates_2[:, numpy.newaxis] max_sps = max_scalar_products * sub_norm_templates_2[:, numpy.newaxis] while True: is_valid = (data > min_sps)*(data < max_sps) valid_indices = numpy.where(is_valid) if len(valid_indices[0]) == 0: break best_amplitude_idx = data[is_valid].argmax() best_template_index, peak_index = valid_indices[0][best_amplitude_idx], valid_indices[1][best_amplitude_idx] peak_scalar_product = data[is_valid][best_amplitude_idx] best_template2_index = best_template_index + n_tm if templates_normalization: best_amp = b[best_template_index, peak_index] / n_scalar best_amp_n = best_amp / norm_templates[best_template_index] if two_components: best_amp2 = b[best_template2_index, peak_index] / n_scalar best_amp2_n = best_amp2 / norm_templates[best_template2_index] else: best_amp2 = 0 best_amp2_n = 0 else: best_amp = b[best_template_index, peak_index] / norm_templates_2[best_template_index] best_amp_n = best_amp if two_components: best_amp2 = b[best_template2_index, peak_index] / norm_templates_2[best_template2_index] best_amp2_n = best_amp2 else: best_amp2 = 0 best_amp2_n = 0 peak_time_step = local_peaktimes[peak_index] peak_data = (local_peaktimes - peak_time_step).astype(np.int32) is_neighbor = np.abs(peak_data) <= temp_2_shift idx_neighbor = peak_data[is_neighbor] + temp_2_shift tmp1 = c_overs[best_template_index].multiply(-best_amp) if numpy.abs(best_amp2_n) > min_second_component: tmp1 += c_overs[best_template2_index].multiply(-best_amp2) to_add = tmp1.toarray()[:, idx_neighbor] b[:, is_neighbor] += to_add b[best_template_index, peak_index] = -numpy.inf # Add matching to the result. 
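# NOTE (illustration): the while-loop above is a greedy matching pursuit over the
# scalar products b between templates and peak-centred snippets: pick the best
# (template, peak) pair whose product lies inside its amplitude band
# [min_sps, max_sps], accept it, cancel that template's contribution from the
# remaining products through the precomputed overlaps, and repeat until no pair is
# valid. A heavily simplified dense sketch of that control flow; the `overlaps`
# matrix and the `amplitude_of` callable are toy stand-ins, not the real data layout.
def _example_greedy_matching_pursuit(b, min_sps, max_sps, overlaps, amplitude_of):
    # b: (n_templates, n_peaks) scalar products, modified in place.
    # overlaps: (n_templates, n_templates) zero-lag overlap between templates.
    import numpy
    fitted = []
    while True:
        is_valid = (b > min_sps) & (b < max_sps)
        if not is_valid.any():
            break
        masked = numpy.where(is_valid, b, -numpy.inf)
        t, p = numpy.unravel_index(masked.argmax(), b.shape)
        amp = amplitude_of(b[t, p], t)
        fitted.append((p, t, amp))
        b[:, p] -= amp * overlaps[:, t]   # explain away this template at this peak
        b[t, p] = -numpy.inf              # never select the same pair twice
    return fitted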
t_spike = all_spikes[peak_index] if (t_spike >= local_restriction[0]) and (t_spike < local_restriction[1]): result['spiketimes'] += [t_spike] result['amplitudes'] += [(best_amp_n, best_amp2_n)] result['templates'] += [best_template_index] elif mse_error: mse_fit['spiketimes'] += [t_spike] mse_fit['amplitudes'] += [(best_amp_n, best_amp2_n)] mse_fit['templates'] += [best_template_index] # Save debug data. if debug: result_debug['chunk_nbs'] += [gidx] result_debug['iteration_nbs'] += [iteration_nb] result_debug['peak_nbs'] += [peak_index] result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]] result_debug['peak_time_steps'] += [all_spikes[peak_index]] result_debug['peak_scalar_products'] += [peak_scalar_product] result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]] result_debug['template_nbs'] += [best_template_index] result_debug['success_flags'] += [True] iteration_nb += 1 spikes_to_write = numpy.array(result['spiketimes'], dtype=numpy.uint32) amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32) templates_to_write = numpy.array(result['templates'], dtype=numpy.uint32) spiketimes_file.write(spikes_to_write.tostring()) amplitudes_file.write(amplitudes_to_write.tostring()) templates_file.write(templates_to_write.tostring()) if ignore_artefacts: arte_spiketimes_file.write(artefacts_peaktimes.astype(numpy.uint32).tostring()) arte_electrodes_file.write(artefacts_elecs.tostring()) arte_amplitudes_file.write(artefacts_amps.tostring()) if mse_error: curve = numpy.zeros((len_chunk, n_e), dtype=numpy.float32) for spike, temp_id, amplitude in zip(result['spiketimes'], result['templates'], result['amplitudes']): spike = spike - t_offset - padding[0] if is_sparse: tmp1 = templates[temp_id].toarray().reshape(n_e, n_t) tmp2 = templates[temp_id + n_tm].toarray().reshape(n_e, n_t) else: tmp1 = templates[temp_id].reshape(n_e, n_t) tmp2 = templates[temp_id + n_tm].reshape(n_e, n_t) curve[spike - template_shift:spike + template_shift + 1, :] += (amplitude[0] * tmp1 + amplitude[1] * tmp2).T for spike, temp_id, amplitude in zip(mse_fit['spiketimes'], mse_fit['templates'], mse_fit['amplitudes']): spike = spike - t_offset + padding[0] if is_sparse: tmp1 = templates[temp_id].toarray().reshape(n_e, n_t) tmp2 = templates[temp_id + n_tm].toarray().reshape(n_e, n_t) else: tmp1 = templates[temp_id].reshape(n_e, n_t) tmp2 = templates[temp_id + n_tm].reshape(n_e, n_t) try: curve[int(spike) - template_shift:int(spike) + template_shift + 1, :] += (amplitude[0] * tmp1 + amplitude[1] * tmp2).T except Exception: pass mse = numpy.linalg.norm((curve - c_local_chunk)[-padding[0]:-padding[1]]) nb_points = len(curve) - (padding[1] - padding[0]) mse_ratio = mse/(numpy.sqrt(nb_points)*stds_norm) mse_to_write = numpy.array([g_offset, mse_ratio], dtype=numpy.float32) mse_file.write(mse_to_write.tostring()) if collect_all: for temp, spike in zip(templates_to_write, spikes_to_write - g_offset): c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0] c_all_times = numpy.take(c_all_times, gspikes, axis=0) c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times if sign_peaks == 'negative': bestlecs = numpy.argmin(c_local_chunk, 1) if matched_filter: threshs = -matched_thresholds_neg[bestlecs] else: threshs = -thresholds[bestlecs] idx = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0] elif sign_peaks == 'positive': bestlecs = numpy.argmax(c_local_chunk, 1) if matched_filter: threshs = 
matched_thresholds_pos[bestlecs] else: threshs = thresholds[bestlecs] idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0] elif sign_peaks == 'both': c_local_chunk = numpy.abs(c_local_chunk) bestlecs = numpy.argmax(c_local_chunk, 1) if matched_filter: threshs = numpy.minimum(matched_thresholds_neg[bestlecs], matched_thresholds_pos[bestlecs]) else: threshs = thresholds[bestlecs] idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0] else: raise ValueError("Unexpected value %s" % sign_peaks) gspikes = numpy.take(gspikes, idx) bestlecs = numpy.take(bestlecs, idx) gspikes_to_write = numpy.array(gspikes + g_offset, dtype=numpy.uint32) gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.uint32) garbage_times_file.write(gspikes_to_write.tostring()) garbage_temp_file.write(gtemplates_to_write.tostring()) if debug: # Write debug data to debug files. for field_label, field_dtype, field_file in [ ('chunk_nbs', numpy.uint32, chunk_nbs_debug_file), ('iteration_nbs', numpy.uint32, iteration_nbs_debug_file), ('peak_nbs', numpy.uint32, peak_nbs_debug_file), ('peak_local_time_steps', numpy.uint32, peak_local_time_steps_debug_file), ('peak_time_steps', numpy.uint32, peak_time_steps_debug_file), ('peak_scalar_products', numpy.float32, peak_scalar_products_debug_file), ('peak_solved_flags', numpy.float32, peak_solved_flags_debug_file), ('template_nbs', numpy.uint32, template_nbs_debug_file), ('success_flags', numpy.bool, success_flags_debug_file), ]: field_to_write = numpy.array(result_debug[field_label], dtype=field_dtype) field_file.write(field_to_write.tostring()) sys.stderr.flush() spiketimes_file.flush() os.fsync(spiketimes_file.fileno()) spiketimes_file.close() amplitudes_file.flush() os.fsync(amplitudes_file.fileno()) amplitudes_file.close() templates_file.flush() os.fsync(templates_file.fileno()) templates_file.close() if collect_all: garbage_temp_file.flush() os.fsync(garbage_temp_file.fileno()) garbage_temp_file.close() garbage_times_file.flush() os.fsync(garbage_times_file.fileno()) garbage_times_file.close() if mse_error: mse_file.flush() os.fsync(mse_file.fileno()) mse_file.close() if ignore_artefacts: arte_spiketimes_file.flush() os.fsync(arte_spiketimes_file.fileno()) arte_spiketimes_file.close() arte_electrodes_file.flush() os.fsync(arte_electrodes_file.fileno()) arte_electrodes_file.close() arte_amplitudes_file.flush() os.fsync(arte_amplitudes_file.fileno()) arte_amplitudes_file.close() if debug: # Close debug files. for field_file in [ chunk_nbs_debug_file, iteration_nbs_debug_file, peak_nbs_debug_file, peak_local_time_steps_debug_file, peak_time_steps_debug_file, peak_scalar_products_debug_file, peak_solved_flags_debug_file, template_nbs_debug_file, success_flags_debug_file, ]: field_file.flush() os.fsync(field_file.fileno()) field_file.close() comm.Barrier() if SHARED_MEMORY: for memory in mpi_memory_1 + mpi_memory_2: memory.Free() if ignore_dead_times: mpi_memory_3.Free() if comm.rank == 0: io.collect_data(comm.size, params, erase=True) if ignore_artefacts: io.collect_artefacts(comm.size, params, erase=True) data_file.close()
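# NOTE (illustration): the mse_error branch above rebuilds the fitted signal from
# the accepted spikes and reports ||chunk - reconstruction|| normalised by the
# expected noise norm, so values around 1 mean the residual sits at the noise
# floor. Toy sketch of that ratio; the data, reconstruction and noise estimates
# below are fabricated.
def _example_mse_ratio():
    import numpy
    rng = numpy.random.default_rng(0)
    chunk = rng.normal(0.0, 1.0, size=(1000, 4)).astype(numpy.float32)   # whitened data
    reconstruction = numpy.zeros_like(chunk)                             # pretend nothing was fitted
    stds = numpy.ones(4, dtype=numpy.float32)                            # per-channel noise estimate
    stds_norm = numpy.linalg.norm(stds)
    residual = numpy.linalg.norm(chunk - reconstruction)
    nb_points = chunk.shape[0]
    return float(residual / (numpy.sqrt(nb_points) * stds_norm))         # close to 1.0 here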