def main(argv=None):
    """Launch the MATLAB sorting GUI (``SortingGUI.m``) for a data file.

    The function parses the command line, builds a temporary HDF5 file
    holding the probe mapping, assembles the ``SortingGUI(...)`` call and
    runs MATLAB with it.  The process always terminates through
    ``sys.exit``: with MATLAB's return code on success, or with 1 if
    launching MATLAB raised an exception.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments (without the program name).  Defaults to
        ``sys.argv[1:]``.  With an empty list the help is printed and the
        process exits.
    """
    if argv is None:
        argv = sys.argv[1:]

    header = get_colored_header()
    parser = argparse.ArgumentParser(
        description=header,
        formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('datafile', help='data file')
    parser.add_argument('-e', '--extension',
                        help='extension to consider for visualization',
                        default='')

    if len(argv) == 0:
        parser.print_help()
        sys.exit()

    args = parser.parse_args(argv)

    filename = os.path.abspath(args.datafile)
    extension = args.extension
    params = CircusParser(filename)

    # Start from a fresh log file.
    if os.path.exists(params.logfile):
        os.remove(params.logfile)

    # The return value of init_logging is not needed; the module logger is
    # fetched right below.
    init_logging(params.logfile)
    logger = logging.getLogger(__name__)

    data_file = params.get_data_file()
    data_dtype = data_file.data_dtype
    gain = data_file.gain
    t_start = data_file.t_start
    file_format = data_file.description

    if file_format not in supported_by_matlab:
        print_and_log([
            "File format %s is not supported by MATLAB. Waveforms disabled"
            % file_format
        ], 'info', logger)

    if numpy.iterable(gain):
        print_and_log(
            ['Multiple gains are not supported, using a default value of 1'],
            'info', logger)
        gain = 1

    file_out_suff = params.get('data', 'file_out_suff')

    if hasattr(data_file, 'data_offset'):
        data_offset = data_file.data_offset
    else:
        data_offset = 0

    probe = params.probe

    if extension != '':
        extension = '-' + extension

    def generate_matlab_mapping(probe):
        """Write the probe layout to a temporary HDF5 file; return its path."""
        p = {}
        positions = []
        nodes = []
        for key in probe['channel_groups'].keys():
            group = probe['channel_groups'][key]
            p.update(group['geometry'])
            nodes += group['channels']
            positions += [p[channel] for channel in group['channels']]
        # Order the positions by increasing channel id.
        idx = numpy.argsort(nodes)
        positions = numpy.array(positions)[idx]

        t = tempfile.NamedTemporaryFile().name + '.hdf5'
        to_write = {
            'positions': positions / 10.,
            'permutation': numpy.sort(nodes),
            'nb_total': numpy.array([probe['total_nb_channels']])
        }
        # The context manager closes the file even if writing fails.
        with h5py.File(t, 'w') as cfile:
            write_datasets(cfile, to_write.keys(), to_write)
        return t

    mapping = generate_matlab_mapping(probe)

    if not params.getboolean('data', 'overwrite'):
        filename = params.get('data', 'data_file_no_overwrite')
    else:
        filename = params.get('data', 'data_file')

    apply_patch_for_similarities(params, extension)

    gui_file = pkg_resources.resource_filename(
        'circus', os.path.join('matlab_GUI', 'SortingGUI.m'))
    # Change to the directory of the matlab file
    os.chdir(os.path.abspath(os.path.dirname(gui_file)))

    # Use quotation marks for string arguments
    if file_format not in supported_by_matlab:
        gui_params = [
            params.rate,
            os.path.abspath(file_out_suff),
            '%s.mat' % extension,
            mapping,
            2,
            t_start
        ]
        # NOTE(review): is_string has 5 entries for 6 parameters, so the
        # zip() below drops t_start from the MATLAB call in this branch.
        # Kept as is because the MATLAB side is not visible from here --
        # confirm whether t_start should be passed.
        is_string = [False, True, True, True, False]
    else:
        gui_params = [
            params.rate,
            os.path.abspath(file_out_suff),
            '%s.mat' % extension,
            mapping,
            2,
            t_start,
            data_dtype,
            data_offset,
            gain,
            filename
        ]
        is_string = [
            False, True, True, True, False, False, True, False, False, True
        ]

    arguments = ', '.join([
        "'%s'" % arg if s else "%s" % arg
        for arg, s in zip(gui_params, is_string)
    ])
    matlab_command = 'SortingGUI(%s)' % arguments

    print_and_log(["Launching the MATLAB GUI..."], 'info', logger)
    print_and_log([matlab_command], 'debug', logger)

    if params.getboolean('fitting', 'collect_all'):
        print_and_log([
            'You can not view the unfitted spikes with the MATLAB GUI',
            'Please consider using phy if you really would like to see them'
        ], 'info', logger)

    try:
        # SystemExit is not an Exception subclass, so the normal exit path
        # is not intercepted by the handler below.
        sys.exit(
            subprocess.call(
                ['matlab', '-nodesktop', '-nosplash', '-r', matlab_command]))
    except Exception:
        print_and_log(
            ["Something wrong with MATLAB. Try circus-gui-python instead?"],
            'error', logger)
        sys.exit(1)
def main(params, nb_cpu, nb_gpu, use_gpu, extension):
    """Export the sorting results to a ``<file_out_suff><extension>.GUI`` folder.

    Rank 0 of ``comm`` writes the whitening matrices, channel positions and
    map, spike times / templates / amplitudes, the templates and the
    template similarities as ``.npy`` files.  Optionally, all ranks then
    compute and store the PC features.

    Parameters
    ----------
    params : parser object
        Gives access to the configuration (``get``, ``getint``,
        ``getboolean``), the probe, the data file and the log file.
    nb_cpu : int
        Only used in the log message.
    nb_gpu, use_gpu :
        Not used by this function.
    extension : str
        Suffix appended to the names of the result files and of the
        output folder.
    """
    init_logging(params.logfile)
    logger = logging.getLogger('circus.converting')
    data_file = params.data_file
    file_out_suff = params.get('data', 'file_out_suff')
    probe = params.probe
    output_path = params.get('data', 'file_out_suff') + extension + '.GUI'
    N_e = params.getint('data', 'N_e')
    N_t = params.getint('detection', 'N_t')
    erase_all = params.getboolean('converting', 'erase_all')
    export_pcs = params.get('converting', 'export_pcs')
    export_all = params.getboolean('converting', 'export_all')
    sparse_export = params.getboolean('converting', 'sparse_export')

    if export_all and not params.getboolean('fitting', 'collect_all'):
        if comm.rank == 0:
            print_and_log(['Export unfitted spikes only if [fitting] collect_all is True'], 'error', logger)
        sys.exit(0)

    def generate_mapping(probe):
        """Return the channel positions ordered by increasing channel id."""
        p = {}
        positions = []
        nodes = []
        for key in probe['channel_groups'].keys():
            group = probe['channel_groups'][key]
            p.update(group['geometry'])
            nodes += group['channels']
            positions += [p[channel] for channel in group['channels']]
        idx = numpy.argsort(nodes)
        positions = numpy.array(positions)[idx]
        return positions

    def get_max_loc_channel(params):
        """Return the largest number of neighbours listed for any channel."""
        nodes, edges = get_nodes_and_edges(params)
        max_loc_channel = 0
        for key in edges.keys():
            if len(edges[key]) > max_loc_channel:
                max_loc_channel = len(edges[key])
        return max_loc_channel

    def write_results(path, params, extension):
        """Save spike times, template ids and amplitudes sorted by time."""
        result = io.get_results(params, extension)
        # Chunks are gathered in lists and concatenated once at the end,
        # instead of re-allocating the full arrays for every template.
        spikes = [numpy.zeros(0, dtype=numpy.uint64)]
        clusters = [numpy.zeros(0, dtype=numpy.uint32)]
        amplitudes = [numpy.zeros(0, dtype=numpy.double)]
        N_tm = len(result['spiketimes'])

        # Iterate over a snapshot of the keys: entries are popped in the loop.
        for key in list(result['spiketimes'].keys()):
            temp_id = int(key.split('_')[-1])
            data = result['spiketimes'].pop(key).astype(numpy.uint64)
            spikes.append(data)
            data = result['amplitudes'].pop(key).astype(numpy.double)
            amplitudes.append(data[:, 0])
            clusters.append(temp_id*numpy.ones(len(data), dtype=numpy.uint32))

        if export_all:
            print_and_log(["Last %d templates are unfitted spikes on all electrodes" %N_e], 'info', logger)
            garbage = io.load_data(params, 'garbage', extension)
            for key in list(garbage['gspikes'].keys()):
                elec_id = int(key.split('_')[-1])
                data = garbage['gspikes'].pop(key).astype(numpy.uint64)
                spikes.append(data)
                amplitudes.append(numpy.zeros(len(data)))
                # Unfitted spikes get ids placed after the real templates.
                clusters.append((elec_id + N_tm)*numpy.ones(len(data), dtype=numpy.uint32))

        spikes = numpy.concatenate(spikes)
        amplitudes = numpy.concatenate(amplitudes)
        clusters = numpy.concatenate(clusters)

        idx = numpy.argsort(spikes)

        numpy.save(os.path.join(output_path, 'spike_templates'), clusters[idx])
        numpy.save(os.path.join(output_path, 'spike_times'), spikes[idx])
        numpy.save(os.path.join(output_path, 'amplitudes'), amplitudes[idx])
        return

    def write_templates(path, params, extension):
        """Save the templates (and their channel mapping); return N_tm."""
        templates = io.load_data(params, 'templates', extension)
        # Only the first half of the columns is exported.
        N_tm = templates.shape[1]//2

        if sparse_export:
            # Width of the sparse layout: the largest number of channels
            # with a non-zero waveform over all templates.
            n_channels_max = 0
            for t in xrange(N_tm):
                data = numpy.sum(numpy.sum(templates[:, t].toarray().reshape(N_e, N_t), 1) != 0)
                if data > n_channels_max:
                    n_channels_max = data
        else:
            n_channels_max = N_e

        if export_all:
            to_write_sparse = numpy.zeros((N_tm + N_e, N_t, n_channels_max), dtype=numpy.float32)
            mapping_sparse = -1 * numpy.ones((N_tm + N_e, n_channels_max), dtype=numpy.int32)
        else:
            to_write_sparse = numpy.zeros((N_tm, N_t, n_channels_max), dtype=numpy.float32)
            mapping_sparse = -1 * numpy.ones((N_tm, n_channels_max), dtype=numpy.int32)

        for t in xrange(N_tm):
            tmp = templates[:, t].toarray().reshape(N_e, N_t).T
            x, y = tmp.nonzero()
            nb_loc = len(numpy.unique(y))

            if sparse_export:
                # Map each used channel to a compact column index.
                all_positions = numpy.zeros(y.max()+1, dtype=numpy.int32)
                all_positions[numpy.unique(y)] = numpy.arange(nb_loc, dtype=numpy.int32)
                pos = all_positions[y]
                to_write_sparse[t, x, pos] = tmp[x, y]
                mapping_sparse[t, numpy.arange(nb_loc)] = numpy.unique(y)
            else:
                pos = y
                to_write_sparse[t, x, pos] = tmp[x, y]

        if export_all:
            for t in xrange(N_tm, N_tm + N_e):
                mapping_sparse[t, 0] = t - N_tm

        numpy.save(os.path.join(output_path, 'templates'), to_write_sparse)

        if sparse_export:
            numpy.save(os.path.join(output_path, 'template_ind'), mapping_sparse)

        return N_tm

    def write_pcs(path, params, extension, N_tm, mode=0):
        """Compute and store the PC features.

        With ``mode == 0`` every spike is processed; with ``mode == 1`` at
        most 500 randomly chosen spikes per template are processed and
        their indices are stored in ``pc_feature_spike_ids.npy``.
        """
        spikes = numpy.load(os.path.join(output_path, 'spike_times.npy'))
        labels = numpy.load(os.path.join(output_path, 'spike_templates.npy'))
        max_loc_channel = get_max_loc_channel(params)
        nb_features = params.getint('whitening', 'output_dim')
        sign_peaks = params.get('detection', 'peaks')
        nodes, edges = get_nodes_and_edges(params)
        N_total = params.getint('data', 'N_total')

        if export_all:
            nb_templates = N_tm + N_e
        else:
            nb_templates = N_tm

        pc_features_ind = numpy.zeros((nb_templates, max_loc_channel), dtype=numpy.int32)
        clusters = io.load_data(params, 'clusters', extension)
        best_elec = clusters['electrodes']
        if export_all:
            best_elec = numpy.concatenate((best_elec, numpy.arange(N_e)))
        inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
        inv_nodes[nodes] = numpy.argsort(nodes)

        for count, elec in enumerate(best_elec):
            nb_loc = len(edges[nodes[elec]])
            pc_features_ind[count, numpy.arange(nb_loc)] = inv_nodes[edges[nodes[elec]]]

        if sign_peaks in ['negative', 'both']:
            basis_proj, basis_rec = io.load_data(params, 'basis')
        elif sign_peaks in ['positive']:
            basis_proj, basis_rec = io.load_data(params, 'basis-pos')

        # Number of feature rows per template, and where each one starts.
        all_offsets = numpy.zeros(nb_templates, dtype=numpy.int32)
        for target in xrange(nb_templates):
            if mode == 0:
                all_offsets[target] = len(numpy.where(labels == target)[0])
            elif mode == 1:
                all_offsets[target] = min(500, len(numpy.where(labels == target)[0]))

        all_paddings = numpy.concatenate(([0] , numpy.cumsum(all_offsets)))
        total_pcs = numpy.sum(all_offsets)

        pc_file = os.path.join(output_path, 'pc_features.npy')
        pc_file_ids = os.path.join(output_path, 'pc_feature_spike_ids.npy')

        from numpy.lib.format import open_memmap

        # Rank 0 creates the files; after the barrier every rank reopens
        # them in read/write mode.
        if comm.rank == 0:
            pc_features = open_memmap(pc_file, shape=(total_pcs, nb_features, max_loc_channel), dtype=numpy.float32, mode='w+')
            if mode == 1:
                pc_ids = open_memmap(pc_file_ids, shape=(total_pcs, ), dtype=numpy.int32, mode='w+')

        comm.Barrier()
        pc_features = open_memmap(pc_file, mode='r+')
        if mode == 1:
            pc_ids = open_memmap(pc_file_ids, mode='r+')

        # Templates are distributed over the ranks in a round-robin way.
        to_explore = xrange(comm.rank, nb_templates, comm.size)

        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(to_explore)

        for target in to_explore:

            count = all_paddings[target]

            if mode == 1:
                idx = numpy.random.permutation(numpy.where(labels == target)[0])[:500]
                pc_ids[count:count+len(idx)] = idx
            elif mode == 0:
                idx = numpy.where(labels == target)[0]

            elec = best_elec[target]
            indices = inv_nodes[edges[nodes[elec]]]
            labels_i = target*numpy.ones(len(idx))
            times_i = numpy.take(spikes, idx).astype(numpy.int64)
            sub_data = io.get_stas(params, times_i, labels_i, elec, neighs=indices, nodes=nodes, auto_align=False)

            pcs = numpy.dot(sub_data, basis_proj)
            pcs = numpy.swapaxes(pcs, 1,2)
            if mode == 0:
                pc_features[idx, :, :len(indices)] = pcs
            elif mode == 1:
                pc_features[count:count+len(idx), :, :len(indices)] = pcs

        comm.Barrier()

        if comm.rank == 0:
            numpy.save(os.path.join(output_path, 'pc_feature_ind'), pc_features_ind.astype(numpy.uint32)) #n_templates, n_loc_chan

    do_export = True
    if comm.rank == 0:
        if os.path.exists(output_path):
            if not erase_all:
                do_export = query_yes_no(Fore.WHITE + "Export already made! Do you want to erase everything?", default=None)
            if do_export:
                if os.path.exists(os.path.abspath('.phy')):
                    shutil.rmtree(os.path.abspath('.phy'))
                shutil.rmtree(output_path)
        # Always broadcast exactly once, whatever the answer was, so that
        # the other ranks waiting in bcast below can never hang.
        do_export = bool(do_export)
        comm.bcast(numpy.array([int(do_export)], dtype=numpy.int32), root=0)
    else:
        do_export = bool(comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)[0])

    comm.Barrier()

    if do_export:

        if comm.rank == 0:
            os.makedirs(output_path)
            print_and_log(["Exporting data for the phy GUI with %d CPUs..." %nb_cpu], 'info', logger)

            if params.getboolean('whitening', 'spatial'):
                whitening_mat = io.load_data(params, 'spatial_whitening').astype(numpy.double)
                numpy.save(os.path.join(output_path, 'whitening_mat'), whitening_mat)
                numpy.save(os.path.join(output_path, 'whitening_mat_inv'), numpy.linalg.inv(whitening_mat))
            else:
                numpy.save(os.path.join(output_path, 'whitening_mat'), numpy.eye(N_e))

            numpy.save(os.path.join(output_path, 'channel_positions'), generate_mapping(probe).astype(numpy.double))
            nodes, edges = get_nodes_and_edges(params)
            numpy.save(os.path.join(output_path, 'channel_map'), nodes.astype(numpy.int32))

            write_results(output_path, params, extension)

            N_tm = write_templates(output_path, params, extension)

            apply_patch_for_similarities(params, extension)

            # The context manager closes the file even if reading fails.
            with h5py.File(file_out_suff + '.templates%s.hdf5' %extension, 'r', libver='earliest') as template_file:
                similarities = template_file.get('maxoverlap')[:]
            norm = N_e*N_t

            if export_all:
                to_write = numpy.zeros((N_tm + N_e, N_tm + N_e), dtype=numpy.single)
                to_write[:N_tm, :N_tm] = (similarities[:N_tm, :N_tm]/norm).astype(numpy.single)
            else:
                to_write = (similarities[:N_tm, :N_tm]/norm).astype(numpy.single)
            numpy.save(os.path.join(output_path, 'similar_templates'), to_write)

            comm.bcast(numpy.array([N_tm], dtype=numpy.int32), root=0)

        else:
            N_tm = int(comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)[0])

        comm.Barrier()

        # 0: export all PCs, 1: export some, 2: export none.
        make_pcs = 2
        if comm.rank == 0:

            if export_pcs == 'prompt':
                key = ''
                while key not in ['a', 's', 'n']:
                    print(Fore.WHITE + "Do you want SpyKING CIRCUS to export PCs? (a)ll / (s)ome / (n)o")
                    key = raw_input('')
            else:
                key = export_pcs

            if key == 'a':
                make_pcs = 0
            elif key == 's':
                make_pcs = 1

            # Broadcast unconditionally: an unexpected value of export_pcs
            # is treated as "no export" instead of leaving the other ranks
            # blocked in bcast.
            comm.bcast(numpy.array([make_pcs], dtype=numpy.int32), root=0)

            if key == 'n':
                if os.path.exists(os.path.join(output_path, 'pc_features.npy')):
                    os.remove(os.path.join(output_path, 'pc_features.npy'))
                if os.path.exists(os.path.join(output_path, 'pc_feature_ind.npy')):
                    os.remove(os.path.join(output_path, 'pc_feature_ind.npy'))
        else:
            make_pcs = comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)
            make_pcs = make_pcs[0]

        comm.Barrier()
        if make_pcs < 2:
            write_pcs(output_path, params, extension, N_tm, make_pcs)
def main(params, nb_cpu, nb_gpu, use_gpu, extension):
    """Export the sorting results to a ``<file_out_suff><extension>.GUI`` folder.

    Rank 0 of ``comm`` writes the whitening matrices, channel positions,
    shanks and map, spike times / templates / amplitudes, optional cluster
    labels and purity, the templates and the template similarities.  All
    ranks then optionally compute the PC features, and a ``params.py``
    file describing the raw data is written into the output folder.

    Parameters
    ----------
    params : parser object
        Gives access to the configuration (``get``, ``getint``,
        ``getfloat``, ``getboolean``), the probe, the data file and the
        log file.
    nb_cpu : int
        Only used in the log message.
    nb_gpu, use_gpu :
        Not used by this function.
    extension : str
        Suffix appended to the names of the result files and of the
        output folder.
    """
    _ = init_logging(params.logfile)
    logger = logging.getLogger('circus.converting')
    data_file = params.data_file
    file_out_suff = params.get('data', 'file_out_suff')
    probe = params.probe
    output_path = params.get('data', 'file_out_suff') + extension + '.GUI'
    N_e = params.getint('data', 'N_e')
    prelabelling = params.getboolean('converting', 'prelabelling')
    N_t = params.getint('detection', 'N_t')
    erase_all = params.getboolean('converting', 'erase_all')
    export_pcs = params.get('converting', 'export_pcs')
    export_all = params.getboolean('converting', 'export_all')
    sparse_export = params.getboolean('converting', 'sparse_export')
    rpv_threshold = params.getfloat('converting', 'rpv_threshold')

    if export_all and not params.getboolean('fitting', 'collect_all'):
        if comm.rank == 0:
            print_and_log([
                'Export unfitted spikes only if [fitting] collect_all is True'
            ], 'error', logger)
        sys.exit(0)

    def generate_mapping(probe):
        """Return (positions, shanks), one entry per listed channel."""
        p = {}
        positions = []
        nodes = []
        shanks = []
        for key in probe['channel_groups'].keys():
            group = probe['channel_groups'][key]
            p.update(group['geometry'])
            nodes += group['channels']
            positions += [p[channel] for channel in group['channels']]
            # Every channel of a group is tagged with the group key.
            shanks += [key] * len(group['channels'])
        positions = numpy.array(positions)
        shanks = numpy.array(shanks)
        return positions, shanks

    def get_max_loc_channel(params, extension):
        """Return the largest number of channels attached to one template.

        Taken from the stored supports when available, otherwise from the
        largest neighbourhood in the probe edges.
        """
        if test_if_support(params, extension):
            supports = io.load_data(params, 'supports', extension)
            max_loc_channel = numpy.sum(supports, 1).max()
        else:
            nodes, edges = get_nodes_and_edges(params)
            max_loc_channel = 0
            for key in edges.keys():
                if len(edges[key]) > max_loc_channel:
                    max_loc_channel = len(edges[key])
        return max_loc_channel

    def write_results(path, params, extension):
        """Save spike times, template ids and amplitudes sorted by time.

        With prelabelling enabled, ``cluster_group.tsv`` is also written.
        """
        result = io.get_results(params, extension)
        spikes = [numpy.zeros(0, dtype=numpy.uint64)]
        clusters = [numpy.zeros(0, dtype=numpy.uint32)]
        amplitudes = [numpy.zeros(0, dtype=numpy.double)]
        N_tm = len(result['spiketimes'])

        has_purity = test_if_purity(params, extension)

        if prelabelling:
            labels = []
            norms = io.load_data(params, 'norm-templates', extension)
            norms = norms[:len(norms) // 2]
            if has_purity:
                purity = io.load_data(params, 'purity', extension)

        # Iterate over a snapshot of the keys: entries are popped in the
        # loop, which raises RuntimeError on a live dict view in Python 3.
        for key in list(result['spiketimes'].keys()):
            temp_id = int(key.split('_')[-1])
            myspikes = result['spiketimes'].pop(key).astype(numpy.uint64)
            spikes.append(myspikes)
            myamplitudes = result['amplitudes'].pop(key).astype(numpy.double)
            amplitudes.append(myamplitudes[:, 0])
            clusters.append(temp_id * numpy.ones(len(myamplitudes), dtype=numpy.uint32))
            rpv = get_rpv(myspikes, params.data_file.sampling_rate)

            if prelabelling:
                if has_purity:
                    if rpv <= rpv_threshold:
                        # Low rpv but low purity: left without a label.
                        if purity[temp_id] > 0.75:
                            labels += [[temp_id, 'good']]
                    else:
                        if purity[temp_id] > 0.75:
                            labels += [[temp_id, 'mua']]
                        else:
                            labels += [[temp_id, 'noise']]
                else:
                    median_amp = numpy.median(myamplitudes[:, 0])
                    if rpv <= rpv_threshold and numpy.abs(median_amp - 1) < 0.25:
                        labels += [[temp_id, 'good']]
                    else:
                        if median_amp < 0.5:
                            labels += [[temp_id, 'mua']]
                        elif norms[temp_id] < 0.1:
                            labels += [[temp_id, 'noise']]

        if export_all:
            print_and_log([
                "Last %d templates are unfitted spikes on all electrodes" % N_e
            ], 'info', logger)
            garbage = io.load_data(params, 'garbage', extension)
            for key in list(garbage['gspikes'].keys()):
                elec_id = int(key.split('_')[-1])
                data = garbage['gspikes'].pop(key).astype(numpy.uint64)
                spikes.append(data)
                amplitudes.append(numpy.ones(len(data)))
                # Unfitted spikes get ids placed after the real templates.
                clusters.append((elec_id + N_tm) * numpy.ones(len(data), dtype=numpy.uint32))

        if prelabelling:
            with open(os.path.join(output_path, 'cluster_group.tsv'), 'w') as f:
                f.write('cluster_id\tgroup\n')
                for l in labels:
                    f.write('%s\t%s\n' % (l[0], l[1]))

        spikes = numpy.concatenate(spikes).astype(numpy.uint64)
        amplitudes = numpy.concatenate(amplitudes).astype(numpy.double)
        clusters = numpy.concatenate(clusters).astype(numpy.uint32)

        idx = numpy.argsort(spikes)

        numpy.save(os.path.join(output_path, 'spike_templates'), clusters[idx])
        numpy.save(os.path.join(output_path, 'spike_times'), spikes[idx])
        numpy.save(os.path.join(output_path, 'amplitudes'), amplitudes[idx])
        return

    def write_templates(path, params, extension):
        """Save the templates (and their channel mapping); return N_tm.

        When purity data exists, ``cluster_purity.tsv`` is written too.
        """
        templates = io.load_data(params, 'templates', extension)
        # Only the first half of the columns is exported.
        N_tm = templates.shape[1] // 2
        nodes, edges = get_nodes_and_edges(params)

        if sparse_export:
            # Width of the sparse layout: the largest number of channels
            # with a non-zero waveform over all templates.
            n_channels_max = 0
            for t in range(N_tm):
                data = numpy.sum(
                    numpy.sum(templates[:, t].toarray().reshape(N_e, N_t), 1) != 0)
                if data > n_channels_max:
                    n_channels_max = data
        else:
            n_channels_max = N_e

        if export_all:
            to_write_sparse = numpy.zeros((N_tm + N_e, N_t, n_channels_max), dtype=numpy.float32)
            mapping_sparse = -1 * numpy.ones((N_tm + N_e, n_channels_max), dtype=numpy.int32)
        else:
            to_write_sparse = numpy.zeros((N_tm, N_t, n_channels_max), dtype=numpy.float32)
            mapping_sparse = -1 * numpy.ones((N_tm, n_channels_max), dtype=numpy.int32)

        has_purity = test_if_purity(params, extension)
        if has_purity:
            purity = io.load_data(params, 'purity', extension)
            with open(os.path.join(output_path, 'cluster_purity.tsv'), 'w') as f:
                f.write('cluster_id\tpurity\n')
                for i in range(N_tm):
                    f.write('%d\t%g\n' % (i, purity[i]))

        for t in range(N_tm):
            tmp = templates[:, t].toarray().reshape(N_e, N_t).T
            x, y = tmp.nonzero()
            nb_loc = len(numpy.unique(y))

            if sparse_export:
                # Map each used channel to a compact column index.
                all_positions = numpy.zeros(y.max() + 1, dtype=numpy.int32)
                all_positions[numpy.unique(y)] = numpy.arange(nb_loc, dtype=numpy.int32)
                pos = all_positions[y]
                to_write_sparse[t, x, pos] = tmp[x, y]
                mapping_sparse[t, numpy.arange(nb_loc)] = numpy.unique(y)
            else:
                pos = y
                to_write_sparse[t, x, pos] = tmp[x, y]

        if export_all:
            garbage = io.load_data(params, 'garbage', extension)
            for t in range(N_tm, N_tm + N_e):
                elec = t - N_tm
                spikes = garbage['gspikes'].pop('elec_%d' % elec).astype(numpy.uint64)
                # At most 100 randomly chosen unfitted spikes per electrode.
                spikes = numpy.random.permutation(spikes)[:100]
                mapping_sparse[t, 0] = t - N_tm
                waveform = io.get_stas(params,
                                       times_i=spikes,
                                       labels_i=numpy.ones(len(spikes)),
                                       src=elec,
                                       neighs=[elec],
                                       nodes=nodes,
                                       mean_mode=True)

                if sparse_export:
                    to_write_sparse[t, :, 0] = waveform
                else:
                    to_write_sparse[t, :, elec] = waveform

        numpy.save(os.path.join(output_path, 'templates'), to_write_sparse)

        if sparse_export:
            numpy.save(os.path.join(output_path, 'template_ind'), mapping_sparse)

        return N_tm

    def write_pcs(path, params, extension, N_tm, mode=0):
        """Compute and store the PC features.

        With ``mode == 0`` every spike is processed; with ``mode == 1`` at
        most 500 randomly chosen spikes per template are processed and
        their indices are stored in ``pc_feature_spike_ids.npy``.
        """
        spikes = numpy.load(os.path.join(output_path, 'spike_times.npy'))
        labels = numpy.load(os.path.join(output_path, 'spike_templates.npy'))
        max_loc_channel = get_max_loc_channel(params, extension)
        nb_features = params.getint('whitening', 'output_dim')
        sign_peaks = params.get('detection', 'peaks')
        nodes, edges = get_nodes_and_edges(params)
        N_total = params.getint('data', 'N_total')
        has_support = test_if_support(params, extension)
        if has_support:
            supports = io.load_data(params, 'supports', extension)
        else:
            inv_nodes = numpy.zeros(N_total, dtype=numpy.int32)
            inv_nodes[nodes] = numpy.arange(len(nodes))

        if export_all:
            nb_templates = N_tm + N_e
        else:
            nb_templates = N_tm

        pc_features_ind = numpy.zeros((nb_templates, max_loc_channel), dtype=numpy.int32)
        best_elec = io.load_data(params, 'electrodes', extension)
        if export_all:
            best_elec = numpy.concatenate((best_elec, numpy.arange(N_e)))

        if has_support:
            for count, support in enumerate(supports):
                nb_loc = numpy.sum(support)
                pc_features_ind[count, numpy.arange(nb_loc)] = numpy.where(support == True)[0]
        else:
            for count, elec in enumerate(best_elec):
                nb_loc = len(edges[nodes[elec]])
                pc_features_ind[count, numpy.arange(nb_loc)] = inv_nodes[edges[nodes[elec]]]

        if sign_peaks in ['negative', 'both']:
            basis_proj, basis_rec = io.load_data(params, 'basis')
        elif sign_peaks in ['positive']:
            basis_proj, basis_rec = io.load_data(params, 'basis-pos')

        # Number of feature rows per template, and where each one starts.
        all_offsets = numpy.zeros(nb_templates, dtype=numpy.int32)
        for target in range(nb_templates):
            if mode == 0:
                all_offsets[target] = len(numpy.where(labels == target)[0])
            elif mode == 1:
                all_offsets[target] = min(500, len(numpy.where(labels == target)[0]))

        all_paddings = numpy.concatenate(([0], numpy.cumsum(all_offsets)))
        total_pcs = numpy.sum(all_offsets)

        pc_file = os.path.join(output_path, 'pc_features.npy')
        pc_file_ids = os.path.join(output_path, 'pc_feature_spike_ids.npy')

        from numpy.lib.format import open_memmap

        # Rank 0 creates the files; after the barrier every rank reopens
        # them in read/write mode.
        if comm.rank == 0:
            pc_features = open_memmap(pc_file,
                                      shape=(total_pcs, nb_features, max_loc_channel),
                                      dtype=numpy.float32, mode='w+')
            if mode == 1:
                pc_ids = open_memmap(pc_file_ids, shape=(total_pcs, ),
                                     dtype=numpy.int32, mode='w+')

        comm.Barrier()
        pc_features = open_memmap(pc_file, mode='r+')
        if mode == 1:
            pc_ids = open_memmap(pc_file_ids, mode='r+')

        # Templates are distributed over the ranks in a round-robin way.
        to_explore = range(comm.rank, nb_templates, comm.size)

        if comm.rank == 0:
            to_explore = get_tqdm_progressbar(params, to_explore)

        for target in to_explore:

            count = all_paddings[target]

            if mode == 1:
                idx = numpy.random.permutation(numpy.where(labels == target)[0])[:500]
                pc_ids[count:count + len(idx)] = idx
            elif mode == 0:
                idx = numpy.where(labels == target)[0]

            elec = best_elec[target]

            if has_support:
                indices = numpy.where(supports[target])[0]
            else:
                indices = inv_nodes[edges[nodes[elec]]]
            labels_i = target * numpy.ones(len(idx))
            times_i = numpy.take(spikes, idx).astype(numpy.int64)
            sub_data = io.get_stas(params, times_i, labels_i, elec,
                                   neighs=indices, nodes=nodes,
                                   auto_align=False)

            pcs = numpy.dot(sub_data, basis_proj)
            pcs = numpy.swapaxes(pcs, 1, 2)
            if mode == 0:
                pc_features[idx, :, :len(indices)] = pcs
            elif mode == 1:
                pc_features[count:count + len(idx), :, :len(indices)] = pcs

        comm.Barrier()

        if comm.rank == 0:
            numpy.save(os.path.join(output_path, 'pc_feature_ind'),
                       pc_features_ind.astype(numpy.uint32))  # n_templates, n_loc_chan

    do_export = True
    if comm.rank == 0:
        if os.path.exists(output_path):
            if not erase_all:
                do_export = query_yes_no(
                    Fore.WHITE + "Export already made! Do you want to erase everything?",
                    default=None)
            if do_export:
                if os.path.exists(os.path.abspath('.phy')):
                    shutil.rmtree(os.path.abspath('.phy'))
                shutil.rmtree(output_path)
        do_export = bool(do_export)
        comm.bcast(numpy.array([int(do_export)], dtype=numpy.int32), root=0)
    else:
        do_export = bool(
            comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)[0])

    comm.Barrier()

    if do_export:

        apply_patch_for_similarities(params, extension)

        if comm.rank == 0:
            os.makedirs(output_path)
            print_and_log(
                ["Exporting data for the phy GUI with %d CPUs..." % nb_cpu],
                'info', logger)

            if params.getboolean('whitening', 'spatial'):
                whitening_mat = io.load_data(
                    params, 'spatial_whitening').astype(numpy.double)
                numpy.save(os.path.join(output_path, 'whitening_mat'), whitening_mat)
                numpy.save(os.path.join(output_path, 'whitening_mat_inv'),
                           numpy.linalg.inv(whitening_mat))
            else:
                numpy.save(os.path.join(output_path, 'whitening_mat'), numpy.eye(N_e))

            positions, shanks = generate_mapping(probe)
            numpy.save(os.path.join(output_path, 'channel_positions'),
                       positions.astype(numpy.double))
            numpy.save(os.path.join(output_path, 'channel_shanks'),
                       shanks.astype(numpy.double))
            nodes, edges = get_nodes_and_edges(params)
            numpy.save(os.path.join(output_path, 'channel_map'),
                       nodes.astype(numpy.int32))

            write_results(output_path, params, extension)

            N_tm = write_templates(output_path, params, extension)

            # The context manager closes the file even if reading fails.
            with h5py.File(file_out_suff + '.templates%s.hdf5' % extension,
                           'r', libver='earliest') as template_file:
                similarities = template_file.get('maxoverlap')[:]
            norm = N_e * N_t

            if export_all:
                to_write = numpy.zeros((N_tm + N_e, N_tm + N_e), dtype=numpy.single)
                to_write[:N_tm, :N_tm] = (similarities[:N_tm, :N_tm] / norm).astype(numpy.single)
            else:
                to_write = (similarities[:N_tm, :N_tm] / norm).astype(numpy.single)
            numpy.save(os.path.join(output_path, 'similar_templates'), to_write)

            comm.bcast(numpy.array([N_tm], dtype=numpy.int32), root=0)

        else:
            N_tm = int(comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)[0])

        comm.Barrier()

        # 0: export all PCs, 1: export some, 2: export none.
        make_pcs = 2
        if comm.rank == 0:

            if export_pcs == 'prompt':
                # raw_input only exists in Python 2; fall back to input.
                try:
                    read_answer = raw_input
                except NameError:
                    read_answer = input
                key = ''
                while key not in ['a', 's', 'n']:
                    print(
                        Fore.WHITE +
                        "Do you want SpyKING CIRCUS to export PCs? (a)ll / (s)ome / (n)o"
                    )
                    key = read_answer('')
            else:
                key = export_pcs

            if key == 'a':
                make_pcs = 0
            elif key == 's':
                make_pcs = 1

            # Broadcast unconditionally: an unexpected value of export_pcs
            # is treated as "no export" instead of leaving the other ranks
            # blocked in bcast.
            comm.bcast(numpy.array([make_pcs], dtype=numpy.int32), root=0)

            if key == 'n':
                if os.path.exists(os.path.join(output_path, 'pc_features.npy')):
                    os.remove(os.path.join(output_path, 'pc_features.npy'))
                if os.path.exists(os.path.join(output_path, 'pc_feature_ind.npy')):
                    os.remove(os.path.join(output_path, 'pc_feature_ind.npy'))
        else:
            make_pcs = comm.bcast(numpy.array([0], dtype=numpy.int32), root=0)
            make_pcs = make_pcs[0]

        comm.Barrier()
        if make_pcs < 2:
            write_pcs(output_path, params, extension, N_tm, make_pcs)

        # NOTE(review): the original indentation was lost; this section is
        # assumed to belong to the export branch and to run on every rank,
        # as the statement order suggests -- confirm.
        supported_by_phy = ['raw_binary', 'mcs_raw_binary', 'mda']
        file_format = data_file.description
        gui_params = {}

        # dat_path is written verbatim into params.py, so every value must
        # already be a valid Python expression (a quoted raw string or a
        # list of quoted raw strings).
        if file_format in supported_by_phy:
            if not params.getboolean('data', 'overwrite'):
                gui_params['dat_path'] = 'r"%s"' % params.get(
                    'data', 'data_file_no_overwrite')
            else:
                if params.get('data', 'stream_mode') == 'multi-files':
                    data_file = params.get_data_file(source=True,
                                                     has_been_created=False)
                    gui_params['dat_path'] = "["
                    for f in data_file.get_file_names():
                        gui_params['dat_path'] += 'r"%s", ' % f
                    gui_params['dat_path'] += "]"
                else:
                    gui_params['dat_path'] = 'r"%s"' % params.get(
                        'data', 'data_file')
        else:
            gui_params['dat_path'] = 'r"giverandomname.dat"'

        gui_params['n_channels_dat'] = params.nb_channels
        gui_params['n_features_per_channel'] = 5
        gui_params['dtype'] = data_file.data_dtype
        if 'data_offset' in data_file.params.keys():
            gui_params['offset'] = data_file.data_offset
        gui_params['sample_rate'] = params.rate
        gui_params['dir_path'] = output_path
        gui_params['hp_filtered'] = True

        with open(os.path.join(output_path, 'params.py'), 'w') as f:
            for key, value in gui_params.items():
                if key in ['dir_path', 'dtype']:
                    f.write('%s = r"%s"\n' % (key, value))
                else:
                    f.write("%s = %s\n" % (key, value))