def plot_spike_clusters_and_gt_overlay(clusters, nll, data_arr, gt_labels, geom, channels, colors,
                                       topn=2, sort_by_count=True, min_cls_size=0,
                                       figdir="./", fname_postfix="",
                                       size_single=(9, 9), vertical=False,
                                       plot_params={"time_scale": 1.1, "scale": 8., "alpha_overlay": 0.1}):
    """Plot spikes colored by assigned clusters vs. ground truth clusters,
    with all spikes overlayed on each other.

    One overlay figure is rendered for the ground-truth labels and one for
    each of the top clusterings returned by ``get_topn_clusters``. The
    individual figures are then merged into a single image with
    ``combine_imgs`` (``combine_imgs_vertical`` when ``vertical`` is true)
    and the per-figure files are deleted.

    Args:
        clusters, nll: forwarded unchanged to ``get_topn_clusters``.
        data_arr, geom, channels, colors: forwarded unchanged to
            ``plot_raw_spikes_overlay``.
        gt_labels: labels used for the "Ground Truth" figure.
        topn: upper bound on the number of NCP clusterings to plot.
        sort_by_count, min_cls_size: forwarded to the NCP overlay plots only
            (the ground-truth plot does not receive them).
        figdir: forwarded to ``plot_raw_spikes_overlay`` as ``figdir``.
        fname_postfix: prefix of the per-figure ``fname_postfix`` values
            (``'_overlay_gt'`` / ``'_overlay_top<k>'`` are appended).
        size_single, vertical: forwarded to ``plot_raw_spikes_overlay``;
            ``vertical`` also selects the image-combining helper.
        plot_params: dict providing the keys ``'time_scale'``, ``'scale'``
            and ``'alpha_overlay'``. It is only read, never modified.

    Returns:
        None. The combined image is written next to the ground-truth figure,
        with ``_overlay_gt`` replaced by ``_overlay_combined`` in its name.
    """
    # Keyword arguments shared by the ground-truth and the NCP overlay plots.
    shared_kwargs = dict(
        size_single=size_single, vertical=vertical,
        time_scale=plot_params['time_scale'],
        scale=plot_params['scale'],
        alpha_overlay=plot_params['alpha_overlay'],
        titlesize=25, figdir=figdir, show=False)

    gt_path = plot_raw_spikes_overlay(
        data_arr, gt_labels, geom, channels, colors,
        figtitle="Ground Truth",
        fname_postfix=fname_postfix + '_overlay_gt',
        **shared_kwargs)
    fig_paths = [gt_path]

    topn_clusters, topn_nll = get_topn_clusters(clusters, nll, topn)
    # Never index past what was actually returned (fewer than ``topn``
    # clusterings may be available).
    n_top = min(topn, len(topn_nll))
    for i in range(n_top):
        snll = topn_nll[i]
        cs = topn_clusters[i]
        K = len(set(cs))
        pr = np.exp(-snll)
        title = 'NCP: {} Clusters (Prob: {:.3f})'.format(K, pr)
        topn_path = plot_raw_spikes_overlay(
            data_arr, cs, geom, channels, colors,
            min_cls_size=min_cls_size, sort_by_count=sort_by_count,
            figtitle=title,
            fname_postfix=fname_postfix + '_overlay_top{}'.format(i + 1),
            **shared_kwargs)
        fig_paths.append(topn_path)

    # Rewrite only the trailing '_overlay_gt' marker so that a '_gt' inside
    # ``figdir`` or ``fname_postfix`` is left untouched.
    head, marker, tail = gt_path.rpartition("_overlay_gt")
    if marker:
        combined_path = head + "_overlay_combined" + tail
    else:
        combined_path = gt_path.replace("_gt", "_combined")

    if vertical:
        combine_imgs_vertical(fig_paths, combined_path)
    else:
        combine_imgs(fig_paths, combined_path)

    # Remove the intermediate figures, but never the combined output itself.
    for f in fig_paths:
        if f != combined_path and os.path.exists(f):
            os.remove(f)
def plot_spike_clusters_and_gt_in_rows(clusters, nll, data_arr, gt_labels, topn=2,
                                       figdir="./", fname_postfix="",
                                       plot_params={"spacing": 1.25, "width": 0.9,
                                                    "vscale": 1.5, "subplot_adj": 0.9},
                                       downsample=None):
    """Plot spikes colored by assigned clusters vs. ground truth clusters,
    each spike as a row.

    Spikes are reordered by ground-truth label (``np.argsort(gt_labels)``)
    before plotting, and the same permutation is applied to ``data_arr`` and
    to every top clustering. One figure is rendered for the ground truth and
    one per clustering returned by ``get_topn_clusters``; the figures are
    merged with ``combine_imgs`` and the per-figure files are deleted.

    Args:
        clusters, nll: forwarded unchanged to ``get_topn_clusters``.
        data_arr: indexed along its first axis by the label ordering, then
            forwarded to ``plot_raw_spikes_in_rows``.
        gt_labels: ground-truth labels; must support fancy indexing.
        topn: forwarded to ``get_topn_clusters``.
        figdir: forwarded to ``plot_raw_spikes_in_rows`` as ``figdir``.
        fname_postfix: prefix of the per-figure ``fname_postfix`` values
            (``'_rows_gt'`` / ``'_rows_pred<i>'`` are appended).
        plot_params: dict providing the keys ``'spacing'``, ``'width'``,
            ``'vscale'`` and ``'subplot_adj'``. It is only read.
        downsample: forwarded to ``combine_imgs``.

    Returns:
        None. The combined image is written next to the ground-truth figure,
        with ``_rows_gt`` replaced by ``_rows_combined`` in its name.
    """
    topn_clusters, topn_nll = get_topn_clusters(clusters, nll, topn)

    # Group spikes of the same ground-truth cluster into adjacent rows.
    reorder = np.argsort(gt_labels)
    gt_labels = gt_labels[reorder]
    data_arr = data_arr[reorder]
    topn_clusters = topn_clusters[:, reorder]

    # Keyword arguments shared by all row plots.
    shared_kwargs = dict(
        spacing=plot_params["spacing"], width=plot_params["width"],
        vscale=plot_params["vscale"], subplot_adj=plot_params['subplot_adj'],
        figdir=figdir, show=False)

    gt_path = plot_raw_spikes_in_rows(
        data_arr, gt_labels,
        figtitle='Ground truth',
        fname_postfix=fname_postfix + '_rows_gt',
        **shared_kwargs)
    fig_paths = [gt_path]

    for i in range(len(topn_nll)):
        snll = topn_nll[i]
        cs = topn_clusters[i]
        K = len(set(cs))
        pr = np.exp(-snll)
        title = 'NCP: {} Clusters (Prob: {:.3f})'.format(K, pr)
        fpath = plot_raw_spikes_in_rows(
            data_arr, cs,
            figtitle=title,
            fname_postfix=fname_postfix + '_rows_pred{}'.format(i),
            **shared_kwargs)
        fig_paths.append(fpath)

    # The combined image must get its own file name: writing it onto one of
    # the per-figure paths would let the cleanup below delete the result.
    # Deriving it from the ground-truth path also works when no prediction
    # figure was produced.
    head, marker, tail = gt_path.rpartition("_rows_gt")
    if marker:
        combined_path = head + "_rows_combined" + tail
    else:
        combined_path = gt_path.replace("_gt", "_combined")

    combine_imgs(fig_paths, combined_path, downsample=downsample)

    # Remove the intermediate figures, but never the combined output itself.
    for f in fig_paths:
        if f != combined_path and os.path.exists(f):
            os.remove(f)
# For every input name: load the spikes, run NCP clustering, and store the
# results (all clusterings, the top-n subset, inputs and timing) in a
# compressed ``<fname>_ncp.npz`` file inside ``data_dir``.
for fname in fnames_list:
    input_path = os.path.join(input_dir, "data_input", "{}.npz".format(fname))
    # np.load on an .npz archive returns a lazy NpzFile that holds the file
    # open; read the arrays inside a ``with`` block so the handle is released
    # on every iteration.
    with np.load(input_path) as npz:
        data_arr, gt_labels = npz['data_arr'], npz['gt_labels']

    # Swap the last two axes and add a leading axis of length 1 before
    # handing the data to torch.
    data_ncp = torch.from_numpy(np.array([data_arr.transpose(0, 2, 1)]))

    print("Running inference on {}:".format(fname))
    t = time.time()
    clusters, nll, highest_prob = cluster_spikes_ncp(
        model, data_ncp, NCP_Sampler, S=n_parallel_sample, beam=use_beam)
    inference_time = time.time() - t
    # NOTE(review): "{:4f}" sets a minimum field width of 4, not 4 decimal
    # places -- confirm whether "{:.4f}" was intended.
    print("  time {:4f}".format(inference_time))

    topn_clusters, topn_nll = get_topn_clusters(clusters, nll, topn)

    # save data
    npz_fname = os.path.join(data_dir, "{}_ncp.npz".format(fname))
    np.savez_compressed(
        npz_fname,
        clusters=clusters, nll=nll,
        topn_clusters=topn_clusters, topn_nll=topn_nll,
        data_arr=data_arr, gt_labels=gt_labels,
        inference_time=inference_time)
def plot_raw_and_encoded_spikes_tsne(clusters, nll, data_arr, data_encoded, colors, topn=2,
                                     extra_clusters=None, extra_name=None, gt_labels=None,
                                     min_cls_size=10, sort_by_count=True,
                                     figdir="./", fname_postfix="",
                                     size_single=(6, 6),
                                     tsne_params={'seed': 0, 'perplexity': 30},
                                     plot_params={'pt_scale': 1},
                                     show=True):
    """Plot spike waveforms and the NCP-encoded vectors using t-SNE.

    Two 2-D t-SNE embeddings are computed: one of ``data_arr`` flattened to
    ``(N, -1)`` and one of ``data_encoded``. Each labelling to display gets
    one subplot row (left: raw spikes, right: encoded spikes) in which points
    are colored by cluster. Rows appear in the order: ground truth (if
    given), the top NCP clusterings, the extra clustering (if given).

    Args:
        clusters, nll: forwarded unchanged to ``get_topn_clusters``.
        data_arr: array whose first axis indexes spikes.
        data_encoded: passed to t-SNE as-is (presumably one encoded vector
            per spike -- confirm against caller).
        colors: sequence indexed by cluster position; must hold at least as
            many entries as the largest number of clusters in any labelling.
        topn: upper bound on the number of NCP clusterings to plot.
        extra_clusters, extra_name: optional additional labelling and its
            display name; used only when both are not None.
        gt_labels: optional ground-truth labelling.
        min_cls_size: clusters with ``count <= min_cls_size`` are drawn grey.
        sort_by_count: if true, clusters are renumbered by decreasing size so
            the largest cluster receives ``colors[0]``.
        figdir, fname_postfix: the figure is saved as
            ``<figdir>/<fname_postfix>_tsne.png`` and ``..._tsne.pdf`` when
            ``show`` is false.
        size_single: (width, height) of a single subplot.
        tsne_params: dict providing ``'seed'`` and ``'perplexity'``.
        plot_params: dict providing ``'pt_scale'``, a multiplier for the
            scatter point size.
        show: if true, display the figure instead of saving it.

    Returns:
        None when ``show`` is true, otherwise the path of the saved PDF.
    """
    from MulticoreTSNE import MulticoreTSNE as TSNE

    N = data_arr.shape[0]
    # Point size shrinks with the square root of the number of spikes.
    pt_size = 26. * (100. / N) ** 0.5 * plot_params['pt_scale']

    topn_clusters, topn_nll = get_topn_clusters(clusters, nll, topn)
    plot_clusters = topn_clusters[:topn]
    # Count what is actually available; fewer than ``topn`` clusterings may
    # have been returned.
    n_top = len(plot_clusters)
    plot_clusters_names = ['NCP clusters #{}'.format(i) for i in range(n_top)]

    if extra_clusters is not None and extra_name is not None:
        plot_clusters = np.concatenate([plot_clusters, [extra_clusters]])
        plot_clusters_names.append(extra_name)
    if gt_labels is not None:
        # Prepend to the rows assembled so far, so that an extra clustering
        # added above is kept and rows stay aligned with their names.
        plot_clusters = np.concatenate([[gt_labels], plot_clusters])
        plot_clusters_names = ["Ground Truth"] + plot_clusters_names
    n_plots = len(plot_clusters_names)

    # squeeze=False always yields a 2-D axes array, including for one row.
    fig, axes = plt.subplots(
        n_plots, 2, squeeze=False,
        figsize=(size_single[0] * 2, size_single[1] * n_plots))

    tsne = TSNE(n_jobs=4, n_components=2,
                perplexity=tsne_params['perplexity'],
                random_state=tsne_params['seed'])
    tsne_encoded = tsne.fit_transform(data_encoded)
    data_reshape = data_arr.reshape((data_arr.shape[0], -1))
    tsne_raw = tsne.fit_transform(data_reshape)

    fontsize = 12
    for row in range(n_plots):
        cs = np.asarray(plot_clusters[row])
        cluster_ids, counts = np.unique(cs, return_counts=True)

        if sort_by_count:
            # Renumber clusters 0..K-1 from largest to smallest.
            sorted_idx = np.argsort(-counts)
            cluster_ids, counts = cluster_ids[sorted_idx], counts[sorted_idx]
            cluster_rename = {cid: rank for rank, cid in enumerate(cluster_ids)}
            cluster_ids = np.vectorize(cluster_rename.get)(cluster_ids)
            clusters_use = np.vectorize(cluster_rename.get)(cs)
        else:
            clusters_use = cs

        # Clusters that are too small are all drawn in grey.
        is_large = counts > min_cls_size
        color_mapping = {
            cid: (colors[j] if is_large[j] else 'grey')
            for j, cid in enumerate(cluster_ids)
        }

        ax_raw, ax_encoded = axes[row][0], axes[row][1]
        for k in cluster_ids:
            mask = (clusters_use == k)
            ax_raw.scatter(tsne_raw[mask, 0], tsne_raw[mask, 1],
                           color=color_mapping[k], s=pt_size)
            ax_encoded.scatter(tsne_encoded[mask, 0], tsne_encoded[mask, 1],
                               color=color_mapping[k], s=pt_size)

        ax_raw.set_title(
            't-SNE of raw spikes ({})'.format(plot_clusters_names[row]),
            fontsize=fontsize)
        ax_encoded.set_title(
            't-SNE of NCP-encoded spikes ({})'.format(plot_clusters_names[row]),
            fontsize=fontsize)

    plt.tight_layout()

    if show:
        plt.show()
        return None

    # Save both a raster and a vector version; the PDF path is returned.
    save_path = os.path.join(figdir, "{}_tsne.png".format(fname_postfix))
    plt.savefig(save_path)
    save_path = os.path.join(figdir, "{}_tsne.pdf".format(fname_postfix))
    plt.savefig(save_path)
    plt.close()
    return save_path
def plot_spike_clusters_and_templates_overlay(clusters, nll, data_arr, geom, channels, colors,
                                              topn=1, min_cls_size=10,
                                              templates=None, template_name="Templates",
                                              gt_labels=None,
                                              extra_clusters=None, extra_name=None,
                                              sort_by_count=True,
                                              figdir="./", fname_postfix="",
                                              size_single=(9, 9), vertical=False,
                                              plot_params={"time_scale": 1.1, "scale": 8., "alpha_overlay": 0.1}):
    """Plot spikes colored by assigned clusters and their templates, with all
    spikes overlayed on each other.

    Figures are produced in this order and merged into one image: ground
    truth (if ``gt_labels`` is given), templates (if ``templates`` is given),
    the top NCP clusterings, and an extra clustering (if both
    ``extra_clusters`` and ``extra_name`` are given). The per-figure files
    are deleted after merging.

    Args:
        clusters, nll: forwarded unchanged to ``get_topn_clusters``.
        data_arr, geom, channels, colors: forwarded unchanged to the
            plotting helpers.
        topn: upper bound on the number of NCP clusterings to plot.
        min_cls_size, sort_by_count: forwarded to every
            ``plot_raw_spikes_overlay`` call.
        templates, template_name: optional templates, drawn with
            ``plot_templates_separate`` and titled ``template_name``.
        gt_labels: optional ground-truth labelling.
        extra_clusters, extra_name: optional additional labelling and the
            name used in its title and file name.
        figdir: forwarded to the plotting helpers as ``figdir``.
        fname_postfix: prefix of the per-figure ``fname_postfix`` values.
        size_single, vertical: forwarded to the plotting helpers;
            ``vertical`` also selects the image-combining helper.
        plot_params: dict providing the keys ``'time_scale'``, ``'scale'``
            and ``'alpha_overlay'``. It is only read, never modified.

    Returns:
        None. The combined image is named after the last NCP figure, with
        ``_overlay_pred`` replaced by ``_overlay_top``.

    Raises:
        ValueError: if no NCP clustering is available to plot, since the
            combined file name is derived from the last NCP figure.
    """
    topn_clusters, topn_nll = get_topn_clusters(clusters, nll, topn)
    # Never index past what was actually returned, and fail before any file
    # is written if there is nothing to name the combined image after.
    n_top = min(topn, len(topn_nll))
    if n_top < 1:
        raise ValueError(
            "at least one NCP clustering is required (topn={})".format(topn))

    # Keyword arguments shared by every raw-spike overlay plot.
    overlay_kwargs = dict(
        size_single=size_single, vertical=vertical,
        min_cls_size=min_cls_size, sort_by_count=sort_by_count,
        time_scale=plot_params['time_scale'],
        scale=plot_params['scale'],
        alpha_overlay=plot_params['alpha_overlay'],
        titlesize=25, figdir=figdir, show=False)

    fig_paths = []

    if gt_labels is not None:
        K = len(set(gt_labels))
        title = 'Ground Truth: {} Clusters'.format(K)
        extra_path = plot_raw_spikes_overlay(
            data_arr, gt_labels, geom, channels, colors,
            figtitle=title,
            fname_postfix=fname_postfix + '_overlay_gt',
            **overlay_kwargs)
        fig_paths.append(extra_path)

    if templates is not None:
        # Templates are drawn fully opaque (alpha_overlay=1).
        gt_path = plot_templates_separate(
            templates, geom, channels, colors,
            size_single=size_single, vertical=vertical,
            time_scale=plot_params['time_scale'],
            scale=plot_params['scale'],
            alpha_overlay=1,
            figtitle=template_name, titlesize=25,
            figdir=figdir, fname_postfix=fname_postfix, show=False)
        fig_paths.append(gt_path)

    for i in range(n_top):
        snll = topn_nll[i]
        cs = topn_clusters[i]
        K = len(set(cs))
        pr = np.exp(-snll)
        title = 'NCP: {} Clusters (Prob: {:.3f})'.format(K, pr)
        topn_path = plot_raw_spikes_overlay(
            data_arr, cs, geom, channels, colors,
            figtitle=title,
            fname_postfix=fname_postfix + '_overlay_pred{}'.format(i + 1),
            **overlay_kwargs)
        fig_paths.append(topn_path)

    if extra_clusters is not None and extra_name is not None:
        K = len(set(extra_clusters))
        title = '{}: {} Clusters'.format(extra_name, K)
        extra_path = plot_raw_spikes_overlay(
            data_arr, extra_clusters, geom, channels, colors,
            figtitle=title,
            fname_postfix=fname_postfix + '_overlay_{}'.format(extra_name),
            **overlay_kwargs)
        fig_paths.append(extra_path)

    # Rewrite only the trailing '_overlay_pred' marker so that a 'pred'
    # inside ``figdir`` or ``fname_postfix`` is left untouched.
    head, marker, tail = topn_path.rpartition("_overlay_pred")
    if marker:
        combined_path = head + "_overlay_top" + tail
    else:
        combined_path = topn_path.replace("pred", "top")

    if vertical:
        combine_imgs_vertical(fig_paths, combined_path)
    else:
        combine_imgs(fig_paths, combined_path)

    # Remove the intermediate figures, but never the combined output itself.
    for f in fig_paths:
        if f != combined_path and os.path.exists(f):
            os.remove(f)