def main(params, nb_cpu, nb_gpu, use_gpu):
    #################################################################
    # params = detect_memory(params)
    logger = init_logging(params.logfile)
    SHARED_MEMORY = get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    data_file.open()
    N_e = params.getint('data', 'N_e')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = params.getint('fitting', 'chunk_size')
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = get_nodes_and_edges(params)
    tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    tmp_limits = map(float, tmp_limits)
    amp_auto = params.getboolean('fitting', 'amp_auto')
    nb_chances = params.getint('fitting', 'nb_chances')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    noise_thr = params.getfloat('clustering', 'noise_thr')
    collect_all = params.getboolean('fitting', 'collect_all')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    #################################################################

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if SHARED_MEMORY:
        templates = io.load_data_memshared(params, 'templates', normalize=True, transpose=True)
        N_tm, x = templates.shape
    else:
        templates = io.load_data(params, 'templates')
        x, N_tm = templates.shape

    temp_2_shift = 2 * template_shift
    full_gpu = use_gpu and gpu_only
    n_tm = N_tm // 2
    n_scalar = N_e * N_t
    temp_window = numpy.arange(-template_shift, template_shift + 1)
    size_window = N_e * (2 * template_shift + 1)

    if not amp_auto:
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')

    if not SHARED_MEMORY:
        for idx in xrange(templates.shape[1]):
            myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx + 1])
            templates.data[myslice] /= norm_templates[idx]
        templates = templates.T

    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg))
            matched_tresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos))
            matched_tresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        all_dead_times = get_dead_times(params)

    thresholds = io.load_data(params, 'thresholds')

    if collect_all:
        neighbors = {}
        for i in xrange(n_tm):
            tmp = templates[i, :].toarray().reshape(N_e, N_t) * norm_templates[i]
            neighbors[i] = numpy.where(numpy.sum(tmp, 1) != 0)[0]

    if use_gpu:
        templates = cmt.SparseCUDAMatrix(templates, copy_on_host=False)
    info_string = ''

    if comm.rank == 0:
        if use_gpu:
            info_string = "using %d GPUs" % (comm.size)
        else:
            info_string = "using %d CPUs" % (comm.size)

    comm.Barrier()

    c_overlap = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    over_shape = c_overlap.get('over_shape')[:]
    N_over = int(numpy.sqrt(over_shape[0]))
    S_over = over_shape[1]

    ## If the number of overlaps is different from templates, we need to recompute them
    if N_over != N_tm:
        if comm.rank == 0:
            print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger)
        c_overlap = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        N_over = int(numpy.sqrt(over_shape[0]))
        S_over = over_shape[1]

    if SHARED_MEMORY:
        c_overs = io.load_data_memshared(params, 'overlaps')
    else:
        c_overs = io.load_data(params, 'overlaps')

    comm.Barrier()

    if n_tm == 0:
        if comm.rank == 0:
            print_and_log(["No templates present. Redo clustering?"], 'default', logger)
        sys.exit(0)

    if comm.rank == 0:
        print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." % (info_string, n_tm)], 'default', logger)
        purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    if full_gpu:
        try:
            # If memory on the GPU is large enough, we load the overlaps onto it
            for i in xrange(N_over):
                c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i], copy_on_host=False)
        except Exception:
            if comm.rank == 0:
                print_and_log(["Not enough memory on GPUs: GPUs are used for projection only"], 'info', logger)
            for i in xrange(N_over):
                if c_overs.has_key(i):
                    del c_overs[i]
            full_gpu = False

    nb_chunks, last_chunk_len = data_file.analyze(chunk_size)
    processed_chunks = int(min(nb_chunks, max_chunk))

    comm.Barrier()
    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank, 'wb')
    comm.Barrier()
    templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank, 'wb')
    comm.Barrier()

    if collect_all:
        garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' % comm.rank, 'wb')
        comm.Barrier()
        garbage_temp_file = open(file_out_suff + '.gtemplates-%d.data' % comm.rank, 'wb')
        comm.Barrier()

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False)

    last_chunk_size = 0

    to_explore = xrange(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(to_explore)

    for gcount, gidx in enumerate(to_explore):
        # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
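        # Chunks are distributed round-robin across MPI ranks: rank r processes
        # chunks r, r + comm.size, r + 2 * comm.size, ... so no communication is
        # needed inside this loop; the per-rank output files are merged by
        # io.collect_data at the end.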
        ## We need to deal with the borders by taking chunks of size [0, chunk_size + template_shift]
        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last = data_file.is_last_chunk(gidx, nb_chunks)

        if is_last:
            padding = (-temp_2_shift, 0)
        elif is_first:
            padding = (0, temp_2_shift)
        else:
            padding = (-temp_2_shift, temp_2_shift)

        result = {'spiketimes': [], 'amplitudes': [], 'templates': []}

        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes)
        len_chunk = len(local_chunk)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant')

        # print "Extracting the peaks..."
        if collect_all:
            all_found_spikes = {}
            for i in xrange(N_e):
                all_found_spikes[i] = []

        local_peaktimes = numpy.zeros(0, dtype=numpy.uint32)

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i], matched_tresholds_pos[i])
                    local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i], matched_tresholds_neg[i])
                    local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
        else:
            for i in xrange(N_e):
                if sign_peaks == 'negative':
                    peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=True)
                elif sign_peaks == 'positive':
                    peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=False)
                elif sign_peaks == 'both':
                    peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]), thresholds[i], valley=False)
                local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()

        local_peaktimes = numpy.unique(local_peaktimes)
        g_offset = t_offset + padding[0]

        if ignore_dead_times:
            dead_indices = numpy.searchsorted(all_dead_times, [t_offset, t_offset + chunk_size])
            if dead_indices[0] != dead_indices[1]:
                local_peaktimes = numpy.array(
                    list(set(local_peaktimes + g_offset).difference(all_dead_times[dead_indices[0]:dead_indices[1]])),
                    dtype=numpy.uint32) - g_offset
                local_peaktimes = numpy.sort(local_peaktimes)

        # print "Removing the useless borders..."
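        # Peaks too close to the chunk borders are discarded: the full temporal
        # window of length N_t around a peak must fit inside the (padded) chunk
        # for the template scalar products to be well defined. The overlapping
        # padding between consecutive chunks ensures such spikes are recovered
        # by a neighboring chunk.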
        local_borders = (template_shift, len_chunk - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if collect_all:
            for i in xrange(N_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.uint32)

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        all_found_spikes[i] = numpy.array(
                            list(set(all_found_spikes[i] + g_offset).difference(
                                all_dead_times[dead_indices[0]:dead_indices[1]])),
                            dtype=numpy.uint32) - g_offset
                        all_found_spikes[i] = numpy.sort(all_found_spikes[i])

                idx = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        n_t = len(local_peaktimes)

        if full_gpu:
            # all_indices = cmt.CUDAMatrix(all_indices)
            tmp_gpu = cmt.CUDAMatrix(local_peaktimes.reshape((1, n_t)), copy_on_host=False)

        if n_t > 0:
            # print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..."
            if collect_all:
                c_local_chunk = local_chunk.copy()

            local_chunk = local_chunk.T.ravel()
            sub_mat = numpy.zeros((size_window, n_t), dtype=numpy.float32)

            if len_chunk != last_chunk_size:
                slice_indices = numpy.zeros(0, dtype=numpy.int32)
                for idx in xrange(N_e):
                    slice_indices = numpy.concatenate((slice_indices, len_chunk * idx + temp_window))
                last_chunk_size = len_chunk

            for count, idx in enumerate(local_peaktimes):
                sub_mat[:, count] = numpy.take(local_chunk, slice_indices + idx)

            del local_chunk

            if use_gpu:
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b = cmt.sparse_dot(templates, sub_mat)
            else:
                b = templates.dot(sub_mat)

            del sub_mat

            local_restriction = (t_offset, t_offset + chunk_size)
            all_spikes = local_peaktimes + g_offset

            # Because for GPU, slicing by columns is more efficient, we need to transpose b
            # b = b.transpose()
            if use_gpu and not full_gpu:
                b = b.asarray()

            failure = numpy.zeros(n_t, dtype=numpy.int32)

            if full_gpu:
                mask = numpy.zeros((2 * n_tm, n_t), dtype=numpy.float32)
                mask[:n_tm, :] = 1
                data = cmt.empty(mask.shape)
                patch_gpu = b.shape[1] == 1
            else:
                mask = numpy.ones((n_tm, n_t), dtype=numpy.float32)
                sub_b = b[:n_tm, :]

            if collect_all:
                c_all_times = numpy.zeros((len_chunk, N_e), dtype=numpy.bool)
                c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in xrange(N_e):
                    c_all_times[all_found_spikes[i], i] = True

            while (numpy.mean(failure) < nb_chances):

                if full_gpu:
                    gpu_mask = cmt.CUDAMatrix(mask, copy_on_host=False)
                    b.mult(gpu_mask, data)
                    tmp_mat = data.max(0)
                    argmax_bi = numpy.argsort(tmp_mat.asarray()[0, :])[::-1]
                    del tmp_mat
                else:
                    data = sub_b * mask
                    argmax_bi = numpy.argsort(numpy.max(data, 0))[::-1]

                for peak_index in argmax_bi:

                    if full_gpu:
                        b_array = b.asarray()
                        sub_b = b_array[:n_tm, :]

                    peak_scalar_products = np.take(sub_b, peak_index, axis=1)
                    best_template_index = np.argmax(peak_scalar_products, axis=0)
                    best_template2_index = best_template_index + n_tm

                    if full_gpu:
                        best_amp = sub_b[best_template_index, peak_index] / n_scalar
                        # Bug fix: the second-component amplitude must be read from the
                        # second-component row of b, not from the first one.
                        best_amp2 = b_array[best_template2_index, peak_index] / n_scalar
                    else:
                        best_amp = sub_b[best_template_index, peak_index] / n_scalar
                        best_amp2 = b[best_template2_index, peak_index] / n_scalar

                    best_amp_n = best_amp / norm_templates[best_template_index]
                    best_amp2_n = best_amp2 / norm_templates[best_template2_index]

                    # Verify amplitude constraint.
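                    # A matching is accepted only if the normalized amplitude
                    # falls within the [a_min, a_max] interval of this template
                    # (either fixed through `amp_limits`, or loaded from the
                    # clustering step when `amp_auto` is enabled).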
                    a_min = amp_limits[best_template_index, 0]
                    a_max = amp_limits[best_template_index, 1]

                    if (a_min <= best_amp_n) & (best_amp_n <= a_max):
                        # Keep the matching.
                        peak_time_step = local_peaktimes[peak_index]

                        # Renamed from `data` to `peak_data`: reusing `data` here
                        # clobbered the buffer used by the GPU path above.
                        peak_data = (local_peaktimes - peak_time_step).astype(np.int32)
                        is_neighbor = np.where(np.abs(peak_data) <= temp_2_shift)[0]
                        idx_neighbor = peak_data[is_neighbor] + temp_2_shift
                        nb_neighbors = len(is_neighbor)
                        indices = np.zeros((S_over, nb_neighbors), dtype=np.int32)
                        indices[idx_neighbor, np.arange(nb_neighbors)] = 1

                        if full_gpu:
                            indices = cmt.CUDAMatrix(indices, copy_on_host=False)
                            if patch_gpu:
                                b_lines = b.get_col_slice(0, b.shape[0])
                            else:
                                b_lines = b.get_col_slice(is_neighbor[0], is_neighbor[-1] + 1)
                            # Bug fix: `best_amp` and `best_amp2` are scalars here,
                            # so the stray `[keep]` indexing has been removed.
                            tmp1 = cmt.sparse_dot(c_overs[best_template_index], indices, mult=-best_amp)
                            tmp2 = cmt.sparse_dot(c_overs[best_template2_index], indices, mult=-best_amp2)
                            b_lines.add(tmp1.add(tmp2))
                            del tmp1, tmp2
                        else:
                            tmp1 = c_overs[best_template_index].multiply(-best_amp)
                            tmp2 = c_overs[best_template2_index].multiply(-best_amp2)
                            b[:, is_neighbor] += (tmp1 + tmp2).dot(indices)

                        # Add matching to the result.
                        t_spike = all_spikes[peak_index]
                        if (t_spike >= local_restriction[0]) and (t_spike < local_restriction[1]):
                            # print "Accept spikes", t_spike, local_restriction, ...
                            result['spiketimes'] += [t_spike]
                            result['amplitudes'] += [(best_amp_n, best_amp2_n)]
                            result['templates'] += [best_template_index]
                        # Mark current matching as tried.
                        mask[best_template_index, peak_index] = 0
                    else:
                        # Reject the matching.
                        # Update failure counter of the peak.
                        failure[peak_index] += 1
                        # If the maximal number of failures is reached then mark peak as solved (i.e. not fitted).
                        if failure[peak_index] == nb_chances:
                            mask[:, peak_index] = 0
                        else:
                            mask[best_template_index, peak_index] = 0

            spikes_to_write = numpy.array(result['spiketimes'], dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'], dtype=numpy.uint32)

            spiketimes_file.write(spikes_to_write.tostring())
            amplitudes_file.write(amplitudes_to_write.tostring())
            templates_file.write(templates_to_write.tostring())

            if collect_all:

                for temp, spike in zip(templates_to_write, spikes_to_write - g_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False

                gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times

                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_tresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_tresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(matched_tresholds_neg[bestlecs], matched_tresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]

                gspikes = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write = numpy.array(gspikes + g_offset, dtype=numpy.uint32)
                gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.uint32)
                garbage_times_file.write(gspikes_to_write.tostring())
                garbage_temp_file.write(gtemplates_to_write.tostring())

            if full_gpu:
                del gpu_mask, b, data

    sys.stderr.flush()

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    if collect_all:
        garbage_temp_file.flush()
        os.fsync(garbage_temp_file.fileno())
        garbage_temp_file.close()

        garbage_times_file.flush()
        os.fsync(garbage_times_file.fileno())
        garbage_times_file.close()

    comm.Barrier()

    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)

    data_file.close()
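
# The stand-alone sketch below is NOT part of the original pipeline: it is a
# minimal, single-electrode illustration of the greedy template-matching
# (matching-pursuit) loop implemented above, assuming one dense template and
# no second (derivative) component. All names are hypothetical.

def _greedy_fit_sketch(signal, template, min_amp=0.5, max_iter=100):
    import numpy
    # Normalize so that the peak of the correlation is directly the amplitude.
    kernel = template / numpy.sum(template ** 2)
    residual = signal.astype(numpy.float64).copy()
    fitted = []
    for _ in range(max_iter):
        # Scalar product of the template with every window of the residual.
        b = numpy.correlate(residual, kernel, mode='valid')
        t = int(numpy.argmax(b))
        amp = float(b[t])
        if amp < min_amp:
            break  # nothing left above the acceptance limit
        # Accept the fit and subtract the scaled template from the residual.
        fitted.append((t, amp))
        residual[t:t + len(template)] -= amp * template
    return fitted, residual
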
def main(filename, params, nb_cpu, nb_gpu, use_gpu):

    try:
        SHARED_MEMORY = True
        MPI.Win.Allocate_shared(1, 1, MPI.INFO_NULL, MPI.COMM_SELF).Free()
    except NotImplementedError:
        SHARED_MEMORY = False

    #################################################################
    sampling_rate = params.getint('data', 'sampling_rate')
    N_e = params.getint('data', 'N_e')
    N_t = params.getint('data', 'N_t')
    N_total = params.getint('data', 'N_total')
    template_shift = params.getint('data', 'template_shift')
    file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = int(params.getfloat('fitting', 'chunk') * sampling_rate)
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = io.get_nodes_and_edges(params)
    tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    tmp_limits = map(float, tmp_limits)
    amp_auto = params.getboolean('fitting', 'amp_auto')
    space_explo = params.getfloat('fitting', 'space_explo')
    nb_chances = params.getint('fitting', 'nb_chances')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    #################################################################

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if SHARED_MEMORY:
        templates = io.load_data_memshared(params, comm, 'templates', normalize=True, transpose=True)
        N_tm, x = templates.shape
    else:
        templates = io.load_data(params, 'templates')
        x, N_tm = templates.shape

    N_e = params.getint('data', 'N_e')
    N_t = params.getint('data', 'N_t')
    template_shift = int((N_t - 1) // 2)
    temp_2_shift = 2 * template_shift
    full_gpu = use_gpu and gpu_only
    n_tm = N_tm // 2
    n_scalar = N_e * N_t
    last_spikes = numpy.zeros((n_tm, 1), dtype=numpy.int32)
    temp_window = numpy.arange(-template_shift, template_shift + 1)

    if not amp_auto:
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')

    if not SHARED_MEMORY:
        for idx in xrange(templates.shape[1]):
            myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx + 1])
            templates.data[myslice] /= norm_templates[idx]
        templates = templates.T

    if use_gpu:
        templates = cmt.SparseCUDAMatrix(templates)

    info_string = ''

    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg))
            matched_tresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos))
            matched_tresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if comm.rank == 0:
        if use_gpu:
            info_string = "using %d GPUs" % (comm.size)
        else:
            info_string = "using %d CPUs" % (comm.size)

    comm.Barrier()

    thresholds = io.load_data(params, 'thresholds')

    if SHARED_MEMORY:
        c_overs = io.load_data_memshared(params, comm, 'overlaps', nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        c_overlap = io.get_overlaps(comm, params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        N_over = int(numpy.sqrt(over_shape[0]))
        S_over = over_shape[1]
    else:
        c_overlap = io.get_overlaps(comm, params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_x = c_overlap.get('over_x')[:]
        over_y = c_overlap.get('over_y')[:]
        over_data = c_overlap.get('over_data')[:]
        over_shape = c_overlap.get('over_shape')[:]
        N_over = int(numpy.sqrt(over_shape[0]))
        S_over = over_shape[1]
        c_overlap.close()

        # To be faster, we rearrange the overlaps into a dictionary. This has a cost: twice the memory usage for
        # a short period of time
        c_overs = {}
        overlaps = scipy.sparse.csr_matrix((over_data, (over_x, over_y)), shape=(over_shape[0], over_shape[1]))
        del over_x, over_y, over_data

        for i in xrange(N_over):
            c_overs[i] = overlaps[i * N_over:(i + 1) * N_over]

        del overlaps

    comm.Barrier()

    if comm.rank == 0:
        io.print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." % (info_string, n_tm)], 'default', params)
        io.purge(file_out_suff, '.data')

    if do_spatial_whitening:
        spatial_whitening = io.load_data(params, 'spatial_whitening')
    if do_temporal_whitening:
        temporal_whitening = io.load_data(params, 'temporal_whitening')

    if full_gpu:
        try:
            # If memory on the GPU is large enough, we load the overlaps onto it
            for i in xrange(N_over):
                c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i])
        except Exception:
            if comm.rank == 0:
                io.print_and_log(["Not enough memory on GPUs: GPUs are used for projection only"], 'info', params)
            for i in xrange(N_over):
                if c_overs.has_key(i):
                    del c_overs[i]
            full_gpu = False

    borders, nb_chunks, chunk_len, last_chunk_len = io.analyze_data(params, chunk_size)
    nb_chunks = int(min(nb_chunks, max_chunk))

    if comm.rank == 0:
        pbar = get_progressbar(int(nb_chunks // comm.size))

    spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank, 'wb')
    amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank, 'wb')
    templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank, 'wb')

    comm.Barrier()

    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False)

    last_chunk_size = 0

    for gcount, gidx in enumerate(xrange(comm.rank, nb_chunks, comm.size)):
        # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        ## We need to deal with the borders by taking chunks of size [0, chunk_size + template_shift]
        if gidx == (nb_chunks - 1):
            padding = (-2 * borders, 0)
        elif gidx == 0:
            padding = (0, 2 * borders)
        else:
            padding = (-2 * borders, 2 * borders)

        result = {'spiketimes': [], 'amplitudes': [], 'templates': []}

        local_chunk, local_shape = io.load_chunk(params, gidx, chunk_len, chunk_size, padding, nodes=nodes)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant')

        # print "Extracting the peaks..."
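        # Peak detection: each electrode is scanned independently with its own
        # threshold. With a matched filter the trace is first convolved with
        # the average waveform, which improves detection of small spikes;
        # otherwise raw threshold crossings are used (valleys for negative peaks).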
        local_peaktimes = numpy.zeros(0, dtype=numpy.int32)

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i], matched_tresholds_pos[i])
                    local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i], matched_tresholds_neg[i])
                    local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
        else:
            for i in xrange(N_e):
                if sign_peaks == 'negative':
                    peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=True)
                elif sign_peaks == 'positive':
                    peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=False)
                elif sign_peaks == 'both':
                    peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]), thresholds[i], valley=False)
                local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))

        local_peaktimes = numpy.unique(local_peaktimes)

        # print "Removing the useless borders..."
        local_borders = (template_shift, local_shape - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        n_t = len(local_peaktimes)
        len_chunk = local_chunk.shape[0]
        all_indices = numpy.arange(n_t)

        if full_gpu:
            # all_indices = cmt.CUDAMatrix(all_indices)
            tmp_gpu = cmt.CUDAMatrix(local_peaktimes.reshape((1, n_t)), copy_on_host=False)

        if n_t > 0:
            # print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..."
            local_chunk = local_chunk.T.ravel()
            sub_mat = numpy.zeros((N_e * (2 * template_shift + 1), n_t), dtype=numpy.float32)

            if len_chunk != last_chunk_size:
                slice_indices = numpy.zeros(0, dtype=numpy.int32)
                for idx in xrange(N_e):
                    slice_indices = numpy.concatenate((slice_indices, len_chunk * idx + temp_window))
                last_chunk_size = len_chunk

            for count, idx in enumerate(local_peaktimes):
                sub_mat[:, count] = numpy.take(local_chunk, slice_indices + idx)

            del local_chunk

            if use_gpu:
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b = cmt.sparse_dot(templates, sub_mat)
            else:
                b = templates.dot(sub_mat)

            del sub_mat

            local_offset = gidx * chunk_size + padding[0] // N_total
            local_bounds = (temp_2_shift, local_shape - temp_2_shift)
            all_spikes = local_peaktimes + local_offset
            penalty = numpy.ones((n_tm, n_t), dtype=numpy.float32)

            # Because for GPU, slicing by columns is more efficient, we need to transpose b
            # b = b.transpose()
            if use_gpu and not full_gpu:
                b = b.asarray()

            failure = numpy.zeros(n_t, dtype=numpy.int32)

            if full_gpu:
                mask = cmt.CUDAMatrix(penalty, copy_on_host=False)
                data = cmt.empty(mask.shape)
                cm_zeros = cmt.CUDAMatrix(numpy.zeros(mask.shape), copy_on_host=False)
                patch_gpu = b.shape[1] == 1
            else:
                mask = penalty
                sub_b = b[:n_tm, :]

            min_time = local_peaktimes.min()
            max_time = local_peaktimes.max()
            local_len = max_time - min_time + 1
            min_times = numpy.maximum(local_peaktimes - min_time - temp_2_shift, 0)
            max_times = numpy.minimum(local_peaktimes - min_time + temp_2_shift + 1, max_time - min_time)
            max_n_t = int(space_explo * (max_time - min_time + 1) // (2 * temp_2_shift + 1))

            while (numpy.mean(failure) < nb_chances):

                if full_gpu:
                    sub_b = b.get_row_slice(0, n_tm)
                    sub_b.mult(mask, data)
                    tmp_mat = data.max(0)
                    argmax_bi = numpy.argsort(tmp_mat.asarray()[0, :])[::-1]
                    del tmp_mat, sub_b
                else:
                    data = sub_b * mask
                    argmax_bi = numpy.argsort(numpy.max(data, 0))[::-1]

                while (len(argmax_bi) > 0):

                    subset = []
                    indices = []
                    all_times = numpy.zeros(local_len, dtype=numpy.bool)

                    for count, idx in enumerate(argmax_bi):
                        myslice = all_times[min_times[idx]:max_times[idx]]
                        if not myslice.any():
                            subset += [idx]
                            indices += [count]
                            all_times[min_times[idx]:max_times[idx]] = True
                        if len(subset) > max_n_t:
                            break

                    subset = numpy.array(subset, dtype=numpy.int32)
                    argmax_bi = numpy.delete(argmax_bi, indices)

                    if full_gpu:
                        sub_b = b.get_row_slice(0, n_tm)
                        tmp_mat = sub_b.argmax(0)
                        inds_t, inds_temp = subset, tmp_mat.asarray()[0, :][subset].astype(numpy.int32)
                        del tmp_mat
                    else:
                        inds_t, inds_temp = subset, numpy.argmax(numpy.take(sub_b, subset, axis=1), 0)

                    if full_gpu:
                        best_amp = sub_b.asarray()[inds_temp, inds_t] / n_scalar
                        best_amp2 = b.asarray()[inds_temp + n_tm, inds_t] / n_scalar
                        sub_mask = numpy.ones((sub_b.shape), dtype=numpy.float32)
                        sub_mask[inds_temp, inds_t] = 0
                        sub_mask = cmt.CUDAMatrix(sub_mask, copy_on_host=False)
                        mask.mult(sub_mask)
                        del sub_mask
                    else:
                        mask[inds_temp, inds_t] = 0
                        best_amp = sub_b[inds_temp, inds_t] / n_scalar
                        best_amp2 = b[inds_temp + n_tm, inds_t] / n_scalar

                    best_amp_n = best_amp / numpy.take(norm_templates, inds_temp)
                    best_amp2_n = best_amp2 / numpy.take(norm_templates, inds_temp + n_tm)

                    all_idx = ((best_amp_n >= amp_limits[inds_temp, 0]) & (best_amp_n <= amp_limits[inds_temp, 1]))
                    to_keep = numpy.where(all_idx == True)[0]
                    to_reject = numpy.where(all_idx == False)[0]
                    ts = numpy.take(local_peaktimes, inds_t[to_keep])
                    good = (ts >= local_bounds[0]) & (ts < local_bounds[1])

                    # We reduce to only the good times that will be kept
                    # to_keep = to_keep[good]
                    # ts = ts[good]

                    if len(ts) > 0:
                        if full_gpu:
                            tmp = cmt.CUDAMatrix(numpy.ones((len(ts), 1)), copy_on_host=False)
                            tmp3 = cmt.CUDAMatrix(-ts.reshape((len(ts), 1)), copy_on_host=False)
                            tmp = tmp.dot(tmp_gpu)
                            tmp.add_col_vec(tmp3)
                            condition = cmt.empty(tmp.shape)
                            cmt.abs(tmp, condition).less_than(temp_2_shift + 1)
                            condition = condition.asarray().astype(numpy.bool)
                            tmp = tmp.asarray().astype(numpy.int32)
                        else:
                            tmp = numpy.dot(numpy.ones((len(ts), 1), dtype=numpy.int32), local_peaktimes.reshape((1, n_t)))
                            tmp -= ts.reshape((len(ts), 1))
                            condition = numpy.abs(tmp) <= temp_2_shift

                        for count, keep in enumerate(to_keep):

                            idx_b = numpy.compress(condition[count, :], all_indices)
                            ytmp = tmp[count, condition[count, :]] + temp_2_shift

                            indices = numpy.zeros((S_over, len(ytmp)), dtype=numpy.float32)
                            indices[ytmp, numpy.arange(len(ytmp))] = 1

                            if full_gpu:
                                indices = cmt.CUDAMatrix(indices, copy_on_host=False)
                                if patch_gpu:
                                    b_lines = b.get_col_slice(0, b.shape[0])
                                else:
                                    b_lines = b.get_col_slice(idx_b[0], idx_b[-1] + 1)
                                tmp1 = cmt.sparse_dot(c_overs[inds_temp[keep]], indices, mult=-best_amp[keep])
                                tmp2 = cmt.sparse_dot(c_overs[inds_temp[keep] + n_tm], indices, mult=-best_amp2[keep])
                                b_lines.add(tmp1)
                                b_lines.add(tmp2)
                                del tmp1, tmp2
                            else:
                                tmp1 = c_overs[inds_temp[keep]].multiply(-best_amp[keep]).dot(indices)
                                tmp2 = c_overs[inds_temp[keep] + n_tm].multiply(-best_amp2[keep]).dot(indices)
                                b[:, idx_b] += tmp1 + tmp2

                            if good[count]:
                                t_spike = ts[count] + local_offset
                                result['spiketimes'] += [t_spike]
                                result['amplitudes'] += [(best_amp_n[keep], best_amp2_n[keep])]
                                result['templates'] += [inds_temp[keep]]

                    myslice = numpy.take(inds_t, to_reject)
                    failure[myslice] += 1
                    sub_idx = (numpy.take(failure, myslice) >= nb_chances)

                    if full_gpu:
                        N = numpy.sum(sub_idx)
                        if N > 0:
                            cu_slice = cmt.CUDAMatrix(numpy.compress(sub_idx, myslice).reshape(1, N), copy_on_host=False)
                            mask.set_selected_columns(cu_slice, cm_zeros)
                            del cu_slice
                    else:
                        mask[:, numpy.compress(sub_idx, myslice)] = 0

                if full_gpu:
                    del sub_b

            spikes_to_write = numpy.array(result['spiketimes'], dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'], dtype=numpy.int32)

            spiketimes_file.write(spikes_to_write.tostring())
            amplitudes_file.write(amplitudes_to_write.tostring())
            templates_file.write(templates_to_write.tostring())

            if full_gpu:
                del mask, b, cm_zeros, data

        if comm.rank == 0:
            pbar.update(gcount)

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    comm.Barrier()

    if comm.rank == 0:
        pbar.finish()

    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)
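
# Stand-alone demonstration (NOT from the original file) of the key trick used
# above: once the scalar products `b` are computed, subtracting a fitted
# template never requires touching the raw signal again, because the induced
# change of `b` is a precomputed cross-correlation between templates (the
# `c_overs` matrices). Dense 1-D sketch with hypothetical names.

def _overlap_update_demo(n_t=11, length=100):
    import numpy
    rng = numpy.random.RandomState(42)
    t1, t2 = rng.randn(n_t), rng.randn(n_t)
    signal = rng.randn(length)
    amp, pos = 0.8, 30
    b2 = numpy.correlate(signal, t2, mode='valid')
    # Reference: subtract in the signal domain, then recompute scalar products.
    signal2 = signal.copy()
    signal2[pos:pos + n_t] -= amp * t1
    b2_slow = numpy.correlate(signal2, t2, mode='valid')
    # Fast path: update b2 in place with the template cross-correlation.
    overlap = numpy.correlate(t1, t2, mode='full')  # length 2 * n_t - 1
    b2_fast = b2.copy()
    for j in range(len(b2_fast)):
        k = (n_t - 1) - pos + j
        if 0 <= k < len(overlap):
            b2_fast[j] -= amp * overlap[k]
    assert numpy.allclose(b2_slow, b2_fast)
    return b2_fast
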
def main(params, nb_cpu, nb_gpu, use_gpu):
    #################################################################
    # params = detect_memory(params)
    _ = init_logging(params.logfile)
    SHARED_MEMORY = get_shared_memory_flag(params)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    n_e = params.getint('data', 'N_e')
    n_total = params.nb_channels
    n_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    # file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    # spike_thresh = params.getfloat('detection', 'spike_thresh')
    ratio_thresh = params.getfloat('fitting', 'ratio_thresh')
    two_components = params.getboolean('fitting', 'two_components')
    # spike_width = params.getfloat('detection', 'spike_width')
    # dist_peaks = params.getint('detection', 'dist_peaks')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    templates_normalization = params.getboolean('clustering', 'templates_normalization')  # TODO test, switch, test!
    chunk_size = detect_memory(params, fitting=True)
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = get_nodes_and_edges(params)
    tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    tmp_limits = [float(v) for v in tmp_limits]
    amp_auto = params.getboolean('fitting', 'amp_auto')
    auto_nb_chances = params.getboolean('fitting', 'auto_nb_chances')
    if auto_nb_chances:
        nb_chances = io.load_data(params, 'nb_chances')
        max_nb_chances = params.getint('fitting', 'max_nb_chances')
        percent_nb_chances = params.getfloat('fitting', 'percent_nb_chances')
        total_nb_chances = max(1, numpy.nanpercentile(nb_chances, percent_nb_chances))
        total_nb_chances = min(total_nb_chances, max_nb_chances)
        if comm.rank == 0:
            print_and_log(['nb_chances set automatically to %g' % total_nb_chances], 'debug', logger)
    else:
        total_nb_chances = params.getfloat('fitting', 'nb_chances')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    # noise_thr = params.getfloat('clustering', 'noise_thr')
    collect_all = params.getboolean('fitting', 'collect_all')
    min_second_component = params.getfloat('fitting', 'min_second_component')
    debug = params.getboolean('fitting', 'debug')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes = numpy.zeros(n_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.arange(len(nodes))
    data_file.open()
    #################################################################

    if use_gpu:
        import cudamat as cmt
        # # Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if SHARED_MEMORY:
        templates, _ = io.load_data_memshared(params, 'templates', normalize=templates_normalization, transpose=True)
        N_tm, x = templates.shape
    else:
        templates = io.load_data(params, 'templates')
        x, N_tm = templates.shape

    temp_2_shift = 2 * template_shift
    temp_3_shift = 3 * template_shift
    full_gpu = use_gpu and gpu_only
    n_tm = N_tm // 2
    n_scalar = n_e * n_t

    temp_window = numpy.arange(-template_shift, template_shift + 1)
    size_window = n_e * (2 * template_shift + 1)

    if not amp_auto:
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')
    if not templates_normalization:
        norm_templates_2 = (norm_templates ** 2.0) * n_scalar

    if not SHARED_MEMORY:
        # Normalize templates (if necessary).
        if templates_normalization:
            for idx in range(templates.shape[1]):
                myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx + 1])
                templates.data[myslice] /= norm_templates[idx]
        # Transpose templates.
        templates = templates.T

    waveform_neg = numpy.empty(0)  # default assignment (for PyCharm code inspection)
    matched_thresholds_neg = None  # default assignment (for PyCharm code inspection)
    waveform_pos = numpy.empty(0)  # default assignment (for PyCharm code inspection)
    matched_thresholds_pos = None  # default assignment (for PyCharm code inspection)
    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')[::-1]
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg))
            matched_thresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')[::-1]
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos))
            matched_thresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        all_dead_times = get_dead_times(params)
    else:
        all_dead_times = None  # default assignment (for PyCharm code inspection)

    thresholds = io.get_accurate_thresholds(params, ratio_thresh)

    neighbors = {}
    if collect_all:
        for i in range(0, n_tm):
            tmp = templates[i, :].toarray().reshape(n_e, n_t)
            if templates_normalization:
                tmp = tmp * norm_templates[i]
            neighbors[i] = numpy.where(numpy.sum(tmp, axis=1) != 0.0)[0]

    if use_gpu:
        templates = cmt.SparseCUDAMatrix(templates, copy_on_host=False)

    info_string = ''
    if comm.rank == 0:
        if use_gpu:
            info_string = "using %d GPUs" % comm.size
        else:
            info_string = "using %d CPUs" % comm.size
    comm.Barrier()

    c_overlap = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
    over_shape = c_overlap.get('over_shape')[:]
    n_over = int(numpy.sqrt(over_shape[0]))
    s_over = over_shape[1]
    # # If the number of overlaps is different from templates, we need to recompute them.
    if n_over != N_tm:
        if comm.rank == 0:
            print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger)
        c_overlap = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu)
        over_shape = c_overlap.get('over_shape')[:]
        n_over = int(numpy.sqrt(over_shape[0]))
        s_over = over_shape[1]

    if SHARED_MEMORY:
        c_overs, _ = io.load_data_memshared(params, 'overlaps')
    else:
        c_overs = io.load_data(params, 'overlaps')

    comm.Barrier()

    if n_tm == 0:
        if comm.rank == 0:
            print_and_log(["No templates present. Redo clustering?"], 'default', logger)
Redo clustering?"], 'default', logger) sys.exit(0) if comm.rank == 0: print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." % (info_string, n_tm)], 'default', logger) purge(file_out_suff, '.data') if do_spatial_whitening: spatial_whitening = io.load_data(params, 'spatial_whitening') else: spatial_whitening = None # default assignment (for PyCharm code inspection) if do_temporal_whitening: temporal_whitening = io.load_data(params, 'temporal_whitening') else: temporal_whitening = None # default assignment (for PyCharm code inspection) if full_gpu: try: # If memory on the GPU is large enough, we load the overlaps onto it for i in range(n_over): c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i], copy_on_host=False) except Exception: if comm.rank == 0: print_and_log(["Not enough memory on GPUs: GPUs are used for projection only"], 'info', logger) for i in range(n_over): if i in c_overs: del c_overs[i] full_gpu = False nb_chunks, last_chunk_len = data_file.analyze(chunk_size) processed_chunks = int(min(nb_chunks, max_chunk)) comm.Barrier() spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' % comm.rank, 'wb') comm.Barrier() amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' % comm.rank, 'wb') comm.Barrier() templates_file = open(file_out_suff + '.templates-%d.data' % comm.rank, 'wb') comm.Barrier() if collect_all: garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' % comm.rank, 'wb') comm.Barrier() garbage_temp_file = open(file_out_suff + '.gtemplates-%d.data' % comm.rank, 'wb') comm.Barrier() else: garbage_times_file = None # default assignment (for PyCharm code inspection) garbage_temp_file = None # default assignment (for PyCharm code inspection) if debug: # Open debug files. chunk_nbs_debug_file = open(file_out_suff + '.chunk_nbs_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() iteration_nbs_debug_file = open(file_out_suff + '.iteration_nbs_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() peak_nbs_debug_file = open(file_out_suff + '.peak_nbs_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() peak_local_time_steps_debug_file = open( file_out_suff + '.peak_local_time_steps_debug_%d.data' % comm.rank, mode='wb' ) comm.Barrier() peak_time_steps_debug_file = open(file_out_suff + '.peak_time_steps_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() peak_scalar_products_debug_file = open( file_out_suff + '.peak_scalar_products_debug_%d.data' % comm.rank, mode='wb' ) comm.Barrier() peak_solved_flags_debug_file = open(file_out_suff + '.peak_solved_flags_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() template_nbs_debug_file = open(file_out_suff + '.template_nbs_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() success_flags_debug_file = open(file_out_suff + '.success_flags_debug_%d.data' % comm.rank, mode='wb') comm.Barrier() else: chunk_nbs_debug_file = None # default assignment (for PyCharm code inspection) iteration_nbs_debug_file = None # default assignment (for PyCharm code inspection) peak_nbs_debug_file = None # default assignment (for PyCharm code inspection) peak_local_time_steps_debug_file = None # default assignment (for PyCharm code inspection) peak_time_steps_debug_file = None # default assignment (for PyCharm code inspection) peak_scalar_products_debug_file = None # default assignment (for PyCharm code inspection) peak_solved_flags_debug_file = None # default assignment (for PyCharm code inspection) template_nbs_debug_file = None # default assignment (for PyCharm code inspection) success_flags_debug_file = None # 
    if use_gpu and do_spatial_whitening:
        spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False)

    last_chunk_size = 0
    slice_indices = numpy.zeros(0, dtype=numpy.int32)

    to_explore = range(comm.rank, processed_chunks, comm.size)

    if comm.rank == 0:
        to_explore = get_tqdm_progressbar(params, to_explore)

    for gcount, gidx in enumerate(to_explore):
        # print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..."
        # # We need to deal with the borders by taking chunks of size [0, chunk_size + template_shift].
        is_first = data_file.is_first_chunk(gidx, nb_chunks)
        is_last = data_file.is_last_chunk(gidx, nb_chunks)

        if not (is_first and is_last):
            if is_last:
                padding = (-temp_3_shift, 0)
            elif is_first:
                padding = (0, temp_3_shift)
            else:
                padding = (-temp_3_shift, temp_3_shift)
        else:
            padding = (0, 0)

        result = {
            'spiketimes': [],
            'amplitudes': [],
            'templates': [],
        }
        result_debug = {
            'chunk_nbs': [],
            'iteration_nbs': [],
            'peak_nbs': [],
            'peak_local_time_steps': [],
            'peak_time_steps': [],
            'peak_scalar_products': [],
            'peak_solved_flags': [],
            'template_nbs': [],
            'success_flags': [],
        }

        local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes)
        len_chunk = len(local_chunk)

        if do_spatial_whitening:
            if use_gpu:
                local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False)
                local_chunk = local_chunk.dot(spatial_whitening).asarray()
            else:
                local_chunk = numpy.dot(local_chunk, spatial_whitening)
        if do_temporal_whitening:
            local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant')

        # Extracting peaks.
        all_found_spikes = {}
        if collect_all:
            for i in range(n_e):
                all_found_spikes[i] = []

        local_peaktimes = [numpy.empty(0, dtype=numpy.uint32)]

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant')
                for i in range(n_e):
                    peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_pos[i])[0]
                    local_peaktimes.append(peaktimes)
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant')
                for i in range(n_e):
                    peaktimes = scipy.signal.find_peaks(filter_chunk[:, i], height=matched_thresholds_neg[i])[0]
                    local_peaktimes.append(peaktimes)
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            local_peaktimes = numpy.concatenate(local_peaktimes)
        else:
            for i in range(n_e):
                if sign_peaks == 'negative':
                    peaktimes = scipy.signal.find_peaks(-local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'positive':
                    peaktimes = scipy.signal.find_peaks(local_chunk[:, i], height=thresholds[i])[0]
                elif sign_peaks == 'both':
                    peaktimes = scipy.signal.find_peaks(numpy.abs(local_chunk[:, i]), height=thresholds[i])[0]
                else:
                    raise ValueError("Unexpected value %s" % sign_peaks)
                local_peaktimes.append(peaktimes)
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()
            local_peaktimes = numpy.concatenate(local_peaktimes)

        local_peaktimes = numpy.unique(local_peaktimes)
        g_offset = t_offset + padding[0]

        if ignore_dead_times:
            dead_indices = numpy.searchsorted(all_dead_times, [t_offset, t_offset + chunk_size])
            if dead_indices[0] != dead_indices[1]:
                is_included = numpy.in1d(local_peaktimes + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]])
                local_peaktimes = local_peaktimes[~is_included]
                local_peaktimes = numpy.sort(local_peaktimes)
        else:
            dead_indices = None  # default assignment (for PyCharm code inspection)
        # print "Removing the useless borders..."
        local_borders = (template_shift, len_chunk - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if collect_all:
            for i in range(n_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.uint32)

                if ignore_dead_times:
                    if dead_indices[0] != dead_indices[1]:
                        is_included = numpy.in1d(
                            all_found_spikes[i] + g_offset, all_dead_times[dead_indices[0]:dead_indices[1]]
                        )
                        all_found_spikes[i] = all_found_spikes[i][~is_included]
                        all_found_spikes[i] = numpy.sort(all_found_spikes[i])

                idx = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        nb_local_peak_times = len(local_peaktimes)

        if full_gpu:
            # all_indices = cmt.CUDAMatrix(all_indices)
            # tmp_gpu = cmt.CUDAMatrix(local_peaktimes.reshape((1, nb_local_peak_times)), copy_on_host=False)
            _ = cmt.CUDAMatrix(local_peaktimes.reshape((1, nb_local_peak_times)), copy_on_host=False)

        if nb_local_peak_times > 0:
            # print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..."

            if collect_all:
                c_local_chunk = local_chunk.copy()
            else:
                c_local_chunk = None  # default assignment (for PyCharm code inspection)

            sub_mat = local_chunk[local_peaktimes[:, None] + temp_window]
            sub_mat = sub_mat.transpose(2, 1, 0).reshape(size_window, nb_local_peak_times)

            del local_chunk

            if use_gpu:
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b = cmt.sparse_dot(templates, sub_mat)
            else:
                b = templates.dot(sub_mat)

            del sub_mat

            local_restriction = (t_offset, t_offset + chunk_size)
            all_spikes = local_peaktimes + g_offset

            # Because for GPU, slicing by columns is more efficient, we need to transpose b
            # b = b.transpose()
            if use_gpu and not full_gpu:
                b = b.asarray()

            failure = numpy.zeros(nb_local_peak_times, dtype=numpy.int32)

            if full_gpu:
                mask = numpy.zeros((2 * n_tm, nb_local_peak_times), dtype=numpy.float32)
                mask[:n_tm, :] = 1
                # data = cmt.empty(mask.shape)
                _ = cmt.empty(mask.shape)
                patch_gpu = b.shape[1] == 1
            else:
                patch_gpu = None

            if collect_all:
                c_all_times = numpy.zeros((len_chunk, n_e), dtype=numpy.bool)
                c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in range(n_e):
                    c_all_times[all_found_spikes[i], i] = True
            else:
                c_all_times = None  # default assignment (for PyCharm code inspection)
                c_min_times = None  # default assignment (for PyCharm code inspection)
                c_max_times = None  # default assignment (for PyCharm code inspection)

            iteration_nb = 0
            local_max = 0
            numerous_argmax = False
            nb_argmax = n_tm
            best_indices = numpy.zeros(0, dtype=numpy.int32)

            data = b[:n_tm, :]
            flatten_data = data.ravel()

            while numpy.mean(failure) < total_nb_chances:

                # Is there a way to update sub_b * mask at the same time?
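                # Unlike the older implementations, no explicit mask array is
                # kept here: a matching that has been tried (or a peak that has
                # exhausted its chances) is excluded from future argmax queries
                # by setting the corresponding entries of `b` to -inf. Note that
                # `data` and `flatten_data` are views on `b`, so these updates
                # are visible to the argmax below.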
                if full_gpu:
                    b_array = b.asarray()
                else:
                    b_array = None

                if numerous_argmax:
                    if len(best_indices) == 0:
                        best_indices = largest_indices(flatten_data, nb_argmax)
                    best_template_index, peak_index = numpy.unravel_index(best_indices[0], data.shape)
                else:
                    best_template_index, peak_index = numpy.unravel_index(data.argmax(), data.shape)

                peak_scalar_product = data[best_template_index, peak_index]
                best_template2_index = best_template_index + n_tm

                if templates_normalization:
                    if full_gpu:
                        best_amp = b_array[best_template_index, peak_index] / n_scalar
                        best_amp2 = b_array[best_template2_index, peak_index] / n_scalar
                    else:
                        best_amp = b[best_template_index, peak_index] / n_scalar
                        if two_components:
                            best_amp2 = b[best_template2_index, peak_index] / n_scalar
                        else:
                            best_amp2 = 0.0
                    best_amp_n = best_amp / norm_templates[best_template_index]
                    best_amp2_n = best_amp2 / norm_templates[best_template2_index]
                else:
                    if full_gpu:
                        best_amp = b_array[best_template_index, peak_index]
                        best_amp = best_amp / norm_templates_2[best_template_index]
                        # TODO is `best_amp` value correct?
                        best_amp2 = b_array[best_template2_index, peak_index]
                        best_amp2 = best_amp2 / norm_templates_2[best_template2_index]
                        # TODO is `best_amp2` value correct?
                    else:
                        best_amp = b[best_template_index, peak_index]
                        best_amp = best_amp / norm_templates_2[best_template_index]
                        # TODO is `best_amp` value correct?
                        if two_components:
                            best_amp2 = b[best_template2_index, peak_index]
                            best_amp2 = best_amp2 / norm_templates_2[best_template2_index]
                            # TODO is `best_amp2` value correct?
                        else:
                            best_amp2 = 0.0
                    best_amp_n = best_amp
                    best_amp2_n = best_amp2

                # Verify amplitude constraint.
                a_min, a_max = amp_limits[best_template_index, :]

                if (a_min <= best_amp_n) & (best_amp_n <= a_max):
                    # Keep the matching.
                    peak_time_step = local_peaktimes[peak_index]

                    peak_data = (local_peaktimes - peak_time_step).astype(np.int32)
                    is_neighbor = np.where(np.abs(peak_data) <= temp_2_shift)[0]
                    idx_neighbor = peak_data[is_neighbor] + temp_2_shift
                    nb_neighbors = len(is_neighbor)
                    indices = np.zeros((s_over, nb_neighbors), dtype=np.int32)
                    indices[idx_neighbor, np.arange(nb_neighbors)] = 1

                    if full_gpu:
                        indices = cmt.CUDAMatrix(indices, copy_on_host=False)
                        if patch_gpu:
                            b_lines = b.get_col_slice(0, b.shape[0])
                        else:
                            b_lines = b.get_col_slice(is_neighbor[0], is_neighbor[-1] + 1)
                        tmp1 = cmt.sparse_dot(c_overs[best_template_index], indices, mult=-best_amp)
                        tmp2 = cmt.sparse_dot(c_overs[best_template2_index], indices, mult=-best_amp2)
                        b_lines.add(tmp1.add(tmp2))
                        del tmp1, tmp2
                    else:
                        tmp1 = c_overs[best_template_index].multiply(-best_amp)
                        if numpy.abs(best_amp2) > min_second_component:
                            tmp1 += c_overs[best_template2_index].multiply(-best_amp2)
                        b[:, is_neighbor] += tmp1.dot(indices)

                    numerous_argmax = False

                    # Add matching to the result.
                    t_spike = all_spikes[peak_index]
                    if (t_spike >= local_restriction[0]) and (t_spike < local_restriction[1]):
                        result['spiketimes'] += [t_spike]
                        result['amplitudes'] += [(best_amp_n, best_amp2_n)]
                        result['templates'] += [best_template_index]
                    # Mark current matching as tried.
                    b[best_template_index, peak_index] = -numpy.inf

                    # Save debug data.
                    if debug:
                        result_debug['chunk_nbs'] += [gidx]
                        result_debug['iteration_nbs'] += [iteration_nb]
                        result_debug['peak_nbs'] += [peak_index]
                        result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]]
                        result_debug['peak_time_steps'] += [all_spikes[peak_index]]
                        result_debug['peak_scalar_products'] += [peak_scalar_product]
                        result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]]
                        result_debug['template_nbs'] += [best_template_index]
                        result_debug['success_flags'] += [True]
                else:
                    # Reject the matching.
                    numerous_argmax = True
                    # Update failure counter of the peak.
                    failure[peak_index] += 1
                    # If the maximal number of failures is reached then mark peak as solved (i.e. not fitted).
                    if failure[peak_index] >= total_nb_chances:
                        # Mark all the matchings associated to the current peak as tried.
                        b[:, peak_index] = -numpy.inf
                        index = numpy.arange(n_tm) * nb_local_peak_times + peak_index
                    else:
                        # Mark current matching as tried.
                        b[best_template_index, peak_index] = -numpy.inf
                        index = best_template_index * nb_local_peak_times + peak_index

                    if numerous_argmax:
                        best_indices = best_indices[~numpy.in1d(best_indices, index)]

                    # Save debug data.
                    if debug:
                        result_debug['chunk_nbs'] += [gidx]
                        result_debug['iteration_nbs'] += [iteration_nb]
                        result_debug['peak_nbs'] += [peak_index]
                        result_debug['peak_local_time_steps'] += [local_peaktimes[peak_index]]
                        result_debug['peak_time_steps'] += [all_spikes[peak_index]]
                        result_debug['peak_scalar_products'] += [peak_scalar_product]
                        result_debug['peak_solved_flags'] += [b[best_template_index, peak_index]]
                        result_debug['template_nbs'] += [best_template_index]
                        result_debug['success_flags'] += [False]

                iteration_nb += 1

            spikes_to_write = numpy.array(result['spiketimes'], dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'], dtype=numpy.uint32)

            spiketimes_file.write(spikes_to_write.tostring())
            amplitudes_file.write(amplitudes_to_write.tostring())
            templates_file.write(templates_to_write.tostring())

            if collect_all:

                for temp, spike in zip(templates_to_write, spikes_to_write - g_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False

                gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times

                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_thresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_thresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(matched_thresholds_neg[bestlecs], matched_thresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                else:
                    raise ValueError("Unexpected value %s" % sign_peaks)

                gspikes = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write = numpy.array(gspikes + g_offset, dtype=numpy.uint32)
                gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.uint32)

                garbage_times_file.write(gspikes_to_write.tostring())
                garbage_temp_file.write(gtemplates_to_write.tostring())
            if debug:
                # Write debug data to debug files.
                for field_label, field_dtype, field_file in [
                    ('chunk_nbs', numpy.uint32, chunk_nbs_debug_file),
                    ('iteration_nbs', numpy.uint32, iteration_nbs_debug_file),
                    ('peak_nbs', numpy.uint32, peak_nbs_debug_file),
                    ('peak_local_time_steps', numpy.uint32, peak_local_time_steps_debug_file),
                    ('peak_time_steps', numpy.uint32, peak_time_steps_debug_file),
                    ('peak_scalar_products', numpy.float32, peak_scalar_products_debug_file),
                    ('peak_solved_flags', numpy.float32, peak_solved_flags_debug_file),
                    ('template_nbs', numpy.uint32, template_nbs_debug_file),
                    ('success_flags', numpy.bool, success_flags_debug_file),
                ]:
                    field_to_write = numpy.array(result_debug[field_label], dtype=field_dtype)
                    field_file.write(field_to_write.tostring())

            if full_gpu:
                del b, data

    sys.stderr.flush()

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    if collect_all:
        garbage_temp_file.flush()
        os.fsync(garbage_temp_file.fileno())
        garbage_temp_file.close()

        garbage_times_file.flush()
        os.fsync(garbage_times_file.fileno())
        garbage_times_file.close()

    if debug:
        # Close debug files.
        for field_file in [
            chunk_nbs_debug_file,
            iteration_nbs_debug_file,
            peak_nbs_debug_file,
            peak_local_time_steps_debug_file,
            peak_time_steps_debug_file,
            peak_scalar_products_debug_file,
            peak_solved_flags_debug_file,
            template_nbs_debug_file,
            success_flags_debug_file,
        ]:
            field_file.flush()
            os.fsync(field_file.fileno())
            field_file.close()

    comm.Barrier()

    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)

    data_file.close()
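
# `largest_indices`, used in the fitting loop above, is defined elsewhere in
# the package. A typical argpartition-based implementation of "indices of the
# k largest values of a flat array, ordered by decreasing value" would look
# like this hypothetical sketch:

def _largest_indices_sketch(flat_array, k):
    import numpy
    k = min(k, flat_array.size)
    idx = numpy.argpartition(flat_array, -k)[-k:]  # k largest values, unordered
    return idx[numpy.argsort(-flat_array[idx])]  # order them, largest first
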
def main(params, nb_cpu, nb_gpu, use_gpu):
    #################################################################
    logger = init_logging(params.logfile)
    logger = logging.getLogger('circus.fitting')
    data_file = params.data_file
    data_file.open()
    N_e = params.getint('data', 'N_e')
    N_total = params.nb_channels
    N_t = params.getint('detection', 'N_t')
    template_shift = params.getint('detection', 'template_shift')
    file_out = params.get('data', 'file_out')
    file_out_suff = params.get('data', 'file_out_suff')
    sign_peaks = params.get('detection', 'peaks')
    matched_filter = params.getboolean('detection', 'matched-filter')
    spike_thresh = params.getfloat('detection', 'spike_thresh')
    do_temporal_whitening = params.getboolean('whitening', 'temporal')
    do_spatial_whitening = params.getboolean('whitening', 'spatial')
    chunk_size = params.getint('fitting', 'chunk_size')
    gpu_only = params.getboolean('fitting', 'gpu_only')
    nodes, edges = get_nodes_and_edges(params)
    tmp_limits = params.get('fitting', 'amp_limits').replace('(', '').replace(')', '').split(',')
    tmp_limits = map(float, tmp_limits)
    amp_auto = params.getboolean('fitting', 'amp_auto')
    space_explo = params.getfloat('fitting', 'space_explo')
    nb_chances = params.getint('fitting', 'nb_chances')
    max_chunk = params.getfloat('fitting', 'max_chunk')
    noise_thr = params.getfloat('clustering', 'noise_thr')
    collect_all = params.getboolean('fitting', 'collect_all')
    ignore_dead_times = params.getboolean('triggers', 'ignore_times')
    inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
    inv_nodes[nodes] = numpy.argsort(nodes)
    #################################################################

    if use_gpu:
        import cudamat as cmt
        ## Need to properly handle multi GPU per MPI nodes?
        if nb_gpu > nb_cpu:
            gpu_id = int(comm.rank // nb_cpu)
        else:
            gpu_id = 0
        cmt.cuda_set_device(gpu_id)
        cmt.init()
        cmt.cuda_sync_threads()

    if SHARED_MEMORY:
        templates = io.load_data_memshared(params, 'templates', normalize=True, transpose=True)
        N_tm, x = templates.shape
    else:
        templates = io.load_data(params, 'templates')
        x, N_tm = templates.shape

    temp_2_shift = 2 * template_shift
    full_gpu = use_gpu and gpu_only
    n_tm = N_tm // 2
    n_scalar = N_e * N_t
    last_spikes = numpy.zeros((n_tm, 1), dtype=numpy.int32)
    temp_window = numpy.arange(-template_shift, template_shift + 1)

    if not amp_auto:
        amp_limits = numpy.zeros((n_tm, 2))
        amp_limits[:, 0] = tmp_limits[0]
        amp_limits[:, 1] = tmp_limits[1]
    else:
        amp_limits = io.load_data(params, 'limits')

    norm_templates = io.load_data(params, 'norm-templates')

    if not SHARED_MEMORY:
        for idx in xrange(templates.shape[1]):
            myslice = numpy.arange(templates.indptr[idx], templates.indptr[idx + 1])
            templates.data[myslice] /= norm_templates[idx]
        templates = templates.T

    if matched_filter:
        if sign_peaks in ['negative', 'both']:
            waveform_neg = io.load_data(params, 'waveform')
            waveform_neg /= (numpy.abs(numpy.sum(waveform_neg)) * len(waveform_neg))
            matched_tresholds_neg = io.load_data(params, 'matched-thresholds')
        if sign_peaks in ['positive', 'both']:
            waveform_pos = io.load_data(params, 'waveform-pos')
            waveform_pos /= (numpy.abs(numpy.sum(waveform_pos)) * len(waveform_pos))
            matched_tresholds_pos = io.load_data(params, 'matched-thresholds-pos')

    if ignore_dead_times:
        dead_times = numpy.loadtxt(params.get('triggers', 'dead_file'))
        if len(dead_times.shape) == 1:
            dead_times = dead_times.reshape(1, 2)
        dead_in_ms = params.getboolean('triggers', 'dead_in_ms')
        if dead_in_ms:
            dead_times *= numpy.int64(data_file.sampling_rate * 1e-3)
        dead_times = dead_times.astype(numpy.int64)
        all_dead_times = []
        for i in xrange(len(dead_times)):
xrange(len(dead_times)): all_dead_times += range(dead_times[i, 0], dead_times[i, 1]) thresholds = io.load_data(params, 'thresholds') if collect_all: neighbors = {} for i in xrange(n_tm): tmp = templates[i, :].toarray().reshape(N_e, N_t) * norm_templates[i] neighbors[i] = numpy.where(numpy.sum(tmp, 1) != 0)[0] if use_gpu: templates = cmt.SparseCUDAMatrix(templates, copy_on_host=False) info_string = '' if comm.rank == 0: if use_gpu: info_string = "using %d GPUs" %(comm.size) else: info_string = "using %d CPUs" %(comm.size) comm.Barrier() c_overlap = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu) over_shape = c_overlap.get('over_shape')[:] N_over = int(numpy.sqrt(over_shape[0])) S_over = over_shape[1] ## If the number of overlaps is different from templates, we need to recompute them if N_over != N_tm: if comm.rank == 0: print_and_log(['Templates have been modified, recomputing the overlaps...'], 'default', logger) c_overlap = io.get_overlaps(params, erase=True, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu) over_shape = c_overlap.get('over_shape')[:] N_over = int(numpy.sqrt(over_shape[0])) S_over = over_shape[1] if SHARED_MEMORY: c_overs = io.load_data_memshared(params, 'overlaps', nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu) else: c_overlap = io.get_overlaps(params, nb_cpu=nb_cpu, nb_gpu=nb_gpu, use_gpu=use_gpu) over_x = c_overlap.get('over_x')[:] over_y = c_overlap.get('over_y')[:] over_data = c_overlap.get('over_data')[:] over_shape = c_overlap.get('over_shape')[:] c_overlap.close() # To be faster, we rearrange the overlaps into a dictionnary. This has a cost: twice the memory usage for # a short period of time c_overs = {} overlaps = scipy.sparse.csr_matrix((over_data, (over_x, over_y)), shape=(over_shape[0], over_shape[1])) del over_x, over_y, over_data for i in xrange(N_over): c_overs[i] = overlaps[i*N_over:(i+1)*N_over] del overlaps comm.Barrier() if comm.rank == 0: print_and_log(["Here comes the SpyKING CIRCUS %s and %d templates..." 
%(info_string, n_tm)], 'default', logger) purge(file_out_suff, '.data') if do_spatial_whitening: spatial_whitening = io.load_data(params, 'spatial_whitening') if do_temporal_whitening: temporal_whitening = io.load_data(params, 'temporal_whitening') if full_gpu: try: # If memory on the GPU is large enough, we load the overlaps onto it for i in xrange(N_over): c_overs[i] = cmt.SparseCUDAMatrix(c_overs[i], copy_on_host=False) except Exception: if comm.rank == 0: print_and_log(["Not enough memory on GPUs: GPUs are used for projection only"], 'info', logger) for i in xrange(N_over): if c_overs.has_key(i): del c_overs[i] full_gpu = False nb_chunks, last_chunk_len = data_file.analyze(chunk_size) processed_chunks = int(min(nb_chunks, max_chunk)) comm.Barrier() spiketimes_file = open(file_out_suff + '.spiketimes-%d.data' %comm.rank, 'wb') comm.Barrier() amplitudes_file = open(file_out_suff + '.amplitudes-%d.data' %comm.rank, 'wb') comm.Barrier() templates_file = open(file_out_suff + '.templates-%d.data' %comm.rank, 'wb') comm.Barrier() if collect_all: garbage_times_file = open(file_out_suff + '.gspiketimes-%d.data' %comm.rank, 'wb') comm.Barrier() garbage_temp_file = open(file_out_suff + '.gtemplates-%d.data' %comm.rank, 'wb') comm.Barrier() if use_gpu and do_spatial_whitening: spatial_whitening = cmt.CUDAMatrix(spatial_whitening, copy_on_host=False) last_chunk_size = 0 to_explore = xrange(comm.rank, processed_chunks, comm.size) if comm.rank == 0: to_explore = get_tqdm_progressbar(to_explore) for gcount, gidx in enumerate(to_explore): #print "Node", comm.rank, "is analyzing chunk", gidx, "/", nb_chunks, " ..." ## We need to deal with the borders by taking chunks of size [0, chunck_size+template_shift] is_first = data_file.is_first_chunk(gidx, nb_chunks) is_last = data_file.is_last_chunk(gidx, nb_chunks) if is_last: padding = (-2*template_shift, 0) elif is_first: padding = (0, 2*template_shift) else: padding = (-2*template_shift, 2*template_shift) result = {'spiketimes' : [], 'amplitudes' : [], 'templates' : []} local_chunk, t_offset = data_file.get_data(gidx, chunk_size, padding, nodes=nodes) len_chunk = len(local_chunk) if do_spatial_whitening: if use_gpu: local_chunk = cmt.CUDAMatrix(local_chunk, copy_on_host=False) local_chunk = local_chunk.dot(spatial_whitening).asarray() else: local_chunk = numpy.dot(local_chunk, spatial_whitening) if do_temporal_whitening: local_chunk = scipy.ndimage.filters.convolve1d(local_chunk, temporal_whitening, axis=0, mode='constant') #print "Extracting the peaks..." 
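        # Peak detection on this chunk: with a matched filter, the chunk is
        # convolved with the average waveform and peaks are compared against
        # the matched thresholds; otherwise raw thresholds are applied per
        # electrode, on valleys, peaks, or the absolute signal, depending on
        # the polarity requested by 'sign_peaks'.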
        if collect_all:
            all_found_spikes = {}
            for i in xrange(N_e):
                all_found_spikes[i] = []

        local_peaktimes = numpy.zeros(0, dtype=numpy.int32)

        if matched_filter:
            if sign_peaks in ['positive', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_pos, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i], matched_tresholds_pos[i])
                    local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
            if sign_peaks in ['negative', 'both']:
                filter_chunk = scipy.ndimage.filters.convolve1d(local_chunk, waveform_neg, axis=0, mode='constant')
                for i in xrange(N_e):
                    peaktimes = algo.detect_peaks(filter_chunk[:, i], matched_tresholds_neg[i])
                    local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
                    if collect_all:
                        all_found_spikes[i] += peaktimes.tolist()
        else:
            for i in xrange(N_e):
                if sign_peaks == 'negative':
                    peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=True)
                elif sign_peaks == 'positive':
                    peaktimes = algo.detect_peaks(local_chunk[:, i], thresholds[i], valley=False)
                elif sign_peaks == 'both':
                    peaktimes = algo.detect_peaks(numpy.abs(local_chunk[:, i]), thresholds[i], valley=False)
                local_peaktimes = numpy.concatenate((local_peaktimes, peaktimes))
                if collect_all:
                    all_found_spikes[i] += peaktimes.tolist()

        local_peaktimes = numpy.unique(local_peaktimes)

        if ignore_dead_times:
            local_peaktimes = numpy.array(list(set(local_peaktimes + t_offset).difference(all_dead_times)), dtype=numpy.int32) - t_offset
            local_peaktimes = numpy.sort(local_peaktimes)

        #print "Removing the useless borders..."
        local_borders = (template_shift, len_chunk - template_shift)
        idx = (local_peaktimes >= local_borders[0]) & (local_peaktimes < local_borders[1])
        local_peaktimes = numpy.compress(idx, local_peaktimes)

        if collect_all:
            for i in xrange(N_e):
                all_found_spikes[i] = numpy.array(all_found_spikes[i], dtype=numpy.int32)
                if ignore_dead_times:
                    all_found_spikes[i] = numpy.array(list(set(all_found_spikes[i] + t_offset).difference(all_dead_times)), dtype=numpy.int32) - t_offset
                    all_found_spikes[i] = numpy.sort(all_found_spikes[i])
                idx = (all_found_spikes[i] >= local_borders[0]) & (all_found_spikes[i] < local_borders[1])
                all_found_spikes[i] = numpy.compress(idx, all_found_spikes[i])

        n_t = len(local_peaktimes)
        all_indices = numpy.arange(n_t)

        if full_gpu:
            # all_indices = cmt.CUDAMatrix(all_indices)
            tmp_gpu = cmt.CUDAMatrix(local_peaktimes.reshape((1, n_t)), copy_on_host=False)

        if n_t > 0:
            #print "Computing the b (should full_gpu by putting all chunks on GPU if possible?)..."
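            # Projection step: a snippet of shape (N_e, 2 * template_shift + 1)
            # is gathered around each peak (via flat indexing into the raveled
            # chunk) and all snippets are projected onto every template in a
            # single sparse dot product, so that b[i, j] holds the scalar
            # product between template i and the snippet around peak j.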
            if collect_all:
                c_local_chunk = local_chunk.copy()

            local_chunk = local_chunk.T.ravel()
            sub_mat = numpy.zeros((N_e * (2 * template_shift + 1), n_t), dtype=numpy.float32)

            if len_chunk != last_chunk_size:
                slice_indices = numpy.zeros(0, dtype=numpy.int32)
                for idx in xrange(N_e):
                    slice_indices = numpy.concatenate((slice_indices, len_chunk * idx + temp_window))
                last_chunk_size = len_chunk

            for count, idx in enumerate(local_peaktimes):
                sub_mat[:, count] = numpy.take(local_chunk, slice_indices + idx)

            del local_chunk

            if use_gpu:
                sub_mat = cmt.CUDAMatrix(sub_mat, copy_on_host=False)
                b = cmt.sparse_dot(templates, sub_mat)
            else:
                b = templates.dot(sub_mat)
            del sub_mat

            local_offset = padding[0] + t_offset
            local_bounds = (temp_2_shift, len_chunk - temp_2_shift)
            all_spikes = local_peaktimes + local_offset

            # Because for GPU, slicing by columns is more efficient, we need to transpose b
            #b = b.transpose()

            if use_gpu and not full_gpu:
                b = b.asarray()

            failure = numpy.zeros(n_t, dtype=numpy.int32)

            if full_gpu:
                mask = numpy.zeros((2 * n_tm, n_t), dtype=numpy.float32)
                mask[:n_tm, :] = 1
                data = cmt.empty(mask.shape)
                patch_gpu = b.shape[1] == 1
            else:
                mask = numpy.ones((n_tm, n_t), dtype=numpy.float32)
                sub_b = b[:n_tm, :]

            min_time = local_peaktimes.min()
            max_time = local_peaktimes.max()
            local_len = max_time - min_time + 1
            min_times = numpy.maximum(local_peaktimes - min_time - temp_2_shift, 0)
            max_times = numpy.minimum(local_peaktimes - min_time + temp_2_shift + 1, max_time - min_time)
            max_n_t = int(space_explo * (max_time - min_time + 1) // (2 * temp_2_shift + 1))

            if collect_all:
                c_all_times = numpy.zeros((len_chunk, N_e), dtype=numpy.bool_)
                c_min_times = numpy.maximum(numpy.arange(len_chunk) - template_shift, 0)
                c_max_times = numpy.minimum(numpy.arange(len_chunk) + template_shift + 1, len_chunk)
                for i in xrange(N_e):
                    c_all_times[all_found_spikes[i], i] = True

            while (numpy.mean(failure) < nb_chances):

                if full_gpu:
                    gpu_mask = cmt.CUDAMatrix(mask, copy_on_host=False)
                    b.mult(gpu_mask, data)
                    tmp_mat = data.max(0)
                    argmax_bi = numpy.argsort(tmp_mat.asarray()[0, :])[::-1]
                    del tmp_mat
                else:
                    data = sub_b * mask
                    argmax_bi = numpy.argsort(numpy.max(data, 0))[::-1]

                while (len(argmax_bi) > 0):

                    subset = []
                    indices = []
                    all_times = numpy.zeros(local_len, dtype=numpy.bool_)

                    for count, idx in enumerate(argmax_bi):
                        myslice = all_times[min_times[idx]:max_times[idx]]
                        if not myslice.any():
                            subset += [idx]
                            indices += [count]
                            all_times[min_times[idx]:max_times[idx]] = True
                        if len(subset) > max_n_t:
                            break

                    subset = numpy.array(subset, dtype=numpy.int32)
                    argmax_bi = numpy.delete(argmax_bi, indices)

                    if full_gpu:
                        b_array = b.asarray()
                        sub_b = b_array[:n_tm, :]

                    inds_t, inds_temp = subset, numpy.argmax(numpy.take(sub_b, subset, axis=1), 0)

                    if full_gpu:
                        best_amp = sub_b[inds_temp, inds_t] / n_scalar
                        best_amp2 = b_array[inds_temp + n_tm, inds_t] / n_scalar
                    else:
                        best_amp = sub_b[inds_temp, inds_t] / n_scalar
                        best_amp2 = b[inds_temp + n_tm, inds_t] / n_scalar

                    mask[inds_temp, inds_t] = 0

                    best_amp_n = best_amp / numpy.take(norm_templates, inds_temp)
                    best_amp2_n = best_amp2 / numpy.take(norm_templates, inds_temp + n_tm)

                    all_idx = ((best_amp_n >= amp_limits[inds_temp, 0]) & (best_amp_n <= amp_limits[inds_temp, 1]))
                    to_keep = numpy.where(all_idx == True)[0]
                    to_reject = numpy.where(all_idx == False)[0]
                    ts = numpy.take(local_peaktimes, inds_t[to_keep])
                    good = (ts >= local_bounds[0]) & (ts < local_bounds[1])

                    # We reduce to only the good times that will be kept
                    #to_keep = to_keep[good]
                    #ts = ts[good]

                    if len(ts) > 0:
                        if full_gpu:
                            tmp = cmt.CUDAMatrix(numpy.ones((len(ts), 1)), copy_on_host=False)
                            tmp3 = cmt.CUDAMatrix(-ts.reshape((len(ts), 1)), copy_on_host=False)
                            tmp = tmp.dot(tmp_gpu)
                            tmp.add_col_vec(tmp3)
                            condition = cmt.empty(tmp.shape)
                            cmt.abs(tmp, condition).less_than(temp_2_shift + 1)
                            condition = condition.asarray().astype(numpy.bool_)
                            tmp = tmp.asarray().astype(numpy.int32)
                        else:
                            tmp = numpy.dot(numpy.ones((len(ts), 1), dtype=numpy.int32), local_peaktimes.reshape((1, n_t)))
                            tmp -= ts.reshape((len(ts), 1))
                            condition = numpy.abs(tmp) <= temp_2_shift

                        for count, keep in enumerate(to_keep):

                            idx_b = numpy.compress(condition[count, :], all_indices)
                            ytmp = tmp[count, condition[count, :]] + temp_2_shift

                            indices = numpy.zeros((S_over, len(ytmp)), dtype=numpy.float32)
                            indices[ytmp, numpy.arange(len(ytmp))] = 1

                            if full_gpu:
                                indices = cmt.CUDAMatrix(indices, copy_on_host=False)
                                if patch_gpu:
                                    b_lines = b.get_col_slice(0, b.shape[0])
                                else:
                                    b_lines = b.get_col_slice(idx_b[0], idx_b[-1] + 1)
                                tmp1 = cmt.sparse_dot(c_overs[inds_temp[keep]], indices, mult=-best_amp[keep])
                                tmp2 = cmt.sparse_dot(c_overs[inds_temp[keep] + n_tm], indices, mult=-best_amp2[keep])
                                b_lines.add(tmp1.add(tmp2))
                                del tmp1, tmp2
                            else:
                                tmp1 = c_overs[inds_temp[keep]].multiply(-best_amp[keep]).dot(indices)
                                tmp2 = c_overs[inds_temp[keep] + n_tm].multiply(-best_amp2[keep]).dot(indices)
                                b[:, idx_b] += tmp1 + tmp2

                            if good[count]:
                                t_spike = ts[count] + local_offset
                                result['spiketimes'] += [t_spike]
                                result['amplitudes'] += [(best_amp_n[keep], best_amp2_n[keep])]
                                result['templates'] += [inds_temp[keep]]

                    myslice = numpy.take(inds_t, to_reject)
                    failure[myslice] += 1
                    sub_idx = (numpy.take(failure, myslice) >= nb_chances)
                    mask[:, numpy.compress(sub_idx, myslice)] = 0

            spikes_to_write = numpy.array(result['spiketimes'], dtype=numpy.uint32)
            amplitudes_to_write = numpy.array(result['amplitudes'], dtype=numpy.float32)
            templates_to_write = numpy.array(result['templates'], dtype=numpy.int32)

            spiketimes_file.write(spikes_to_write.tostring())
            amplitudes_file.write(amplitudes_to_write.tostring())
            templates_file.write(templates_to_write.tostring())

            if collect_all:

                for temp, spike in zip(templates_to_write, spikes_to_write - local_offset):
                    c_all_times[c_min_times[spike]:c_max_times[spike], neighbors[temp]] = False

                gspikes = numpy.where(numpy.sum(c_all_times, 1) > 0)[0]
                c_all_times = numpy.take(c_all_times, gspikes, axis=0)
                c_local_chunk = numpy.take(c_local_chunk, gspikes, axis=0) * c_all_times

                if sign_peaks == 'negative':
                    bestlecs = numpy.argmin(c_local_chunk, 1)
                    if matched_filter:
                        threshs = -matched_tresholds_neg[bestlecs]
                    else:
                        threshs = -thresholds[bestlecs]
                    idx = numpy.where(numpy.min(c_local_chunk, 1) < threshs)[0]
                elif sign_peaks == 'positive':
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = matched_tresholds_pos[bestlecs]
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]
                elif sign_peaks == 'both':
                    c_local_chunk = numpy.abs(c_local_chunk)
                    bestlecs = numpy.argmax(c_local_chunk, 1)
                    if matched_filter:
                        threshs = numpy.minimum(matched_tresholds_neg[bestlecs], matched_tresholds_pos[bestlecs])
                    else:
                        threshs = thresholds[bestlecs]
                    idx = numpy.where(numpy.max(c_local_chunk, 1) > threshs)[0]

                gspikes = numpy.take(gspikes, idx)
                bestlecs = numpy.take(bestlecs, idx)
                gspikes_to_write = numpy.array(gspikes + local_offset, dtype=numpy.uint32)
                gtemplates_to_write = numpy.array(bestlecs, dtype=numpy.int32)

                garbage_times_file.write(gspikes_to_write.tostring())
                garbage_temp_file.write(gtemplates_to_write.tostring())

            if full_gpu:
                del gpu_mask, b, data

    spiketimes_file.flush()
    os.fsync(spiketimes_file.fileno())
    spiketimes_file.close()

    amplitudes_file.flush()
    os.fsync(amplitudes_file.fileno())
    amplitudes_file.close()

    templates_file.flush()
    os.fsync(templates_file.fileno())
    templates_file.close()

    if collect_all:
        garbage_temp_file.flush()
        os.fsync(garbage_temp_file.fileno())
        garbage_temp_file.close()

        garbage_times_file.flush()
        os.fsync(garbage_times_file.fileno())
        garbage_times_file.close()

    comm.Barrier()

    if comm.rank == 0:
        io.collect_data(comm.size, params, erase=True)

    data_file.close()
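
# The sketch below is illustrative only and is never called by the pipeline:
# it reduces the greedy accept/reject rule of the fitting loop above to plain
# numpy on toy data. All names it defines are hypothetical, and the
# overlap-subtraction step (the c_overs update) is deliberately omitted.
def _toy_matching_pursuit_demo():
    import numpy

    n_tm, n_t = 3, 4            # toy template and peak counts
    n_scalar = 100.0            # normalization constant, as N_e * N_t above
    nb_chances = 3
    rng = numpy.random.RandomState(42)

    # Fake scalar products between templates and peak snippets.
    sub_b = rng.uniform(0.0, 2.0 * n_scalar, size=(n_tm, n_t))
    norm_templates = numpy.ones(n_tm)
    amp_limits = numpy.tile([0.5, 1.5], (n_tm, 1))

    mask = numpy.ones((n_tm, n_t), dtype=numpy.float32)
    failure = numpy.zeros(n_t, dtype=numpy.int32)
    accepted = []

    while numpy.mean(failure) < nb_chances:
        data = sub_b * mask
        peak = int(numpy.argmax(numpy.max(data, 0)))   # most promising peak
        temp = int(numpy.argmax(data[:, peak]))        # its best template
        amp = data[temp, peak] / n_scalar / norm_templates[temp]
        mask[temp, peak] = 0                           # never retry this pair
        if amp_limits[temp, 0] <= amp <= amp_limits[temp, 1]:
            accepted.append((peak, temp, amp))
        else:
            failure[peak] += 1
            if failure[peak] >= nb_chances:
                mask[:, peak] = 0                      # give up on this peak

    return accepted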