示例#1
0
def plot_spike_clusters_and_gt_overlay(clusters, nll, data_arr, gt_labels, geom, channels, colors, topn=2,
                                       sort_by_count=True, min_cls_size=0,
                                       figdir="./", fname_postfix="", size_single=(9, 9), vertical=False,
                                       plot_params={"time_scale": 1.1, "scale": 8., "alpha_overlay": 0.1}):
    """Plot spikes colored by assigned clusters vs. ground truth clusters, with all spikes overlayed on each other.

    One overlay figure is rendered with ``plot_raw_spikes_overlay`` for the
    ground-truth labels and one for each of the top NCP clusterings returned by
    ``get_topn_clusters``. The panels are stitched into a single image (stacked
    vertically when ``vertical`` is True) whose file name is the ground-truth
    panel's name with its last ``"_gt"`` replaced by ``"_combined"``. The
    per-panel images are deleted afterwards, even if a plotting step fails.

    Args:
        clusters, nll: sampled cluster assignments and their negative
            log-likelihoods, forwarded to ``get_topn_clusters``; each NCP panel
            title shows ``exp(-nll)`` as the probability.
        data_arr: spike waveforms, forwarded to ``plot_raw_spikes_overlay``.
        gt_labels: ground-truth cluster label per spike.
        geom, channels, colors: forwarded to ``plot_raw_spikes_overlay``.
        topn: maximum number of NCP clusterings to plot.
        sort_by_count, min_cls_size: forwarded to the NCP panels only.
        figdir, fname_postfix: directory / file-name suffix for saved figures.
        size_single: figure size of a single panel.
        vertical: stack panels vertically instead of side by side.
        plot_params: must contain ``time_scale``, ``scale`` and
            ``alpha_overlay``. Only read, never modified.
    """
    fig_paths = []
    try:
        gt_path = plot_raw_spikes_overlay(data_arr, gt_labels, geom, channels, colors, size_single=size_single,
                                          vertical=vertical,
                                          time_scale=plot_params['time_scale'],
                                          scale=plot_params['scale'],
                                          alpha_overlay=plot_params['alpha_overlay'],
                                          figtitle="Ground Truth", titlesize=25,
                                          figdir=figdir, fname_postfix=fname_postfix + '_overlay_gt', show=False)
        fig_paths.append(gt_path)

        topn_clusters, topn_nll = get_topn_clusters(clusters, nll, topn)

        # get_topn_clusters may yield fewer than `topn` samples; never index past them.
        for i in range(min(topn, len(topn_nll))):
            cs = topn_clusters[i]
            K = len(set(cs))
            pr = np.exp(-topn_nll[i])
            title = 'NCP: {} Clusters (Prob: {:.3f})'.format(K, pr)
            topn_path = plot_raw_spikes_overlay(data_arr, cs, geom, channels, colors, size_single=size_single,
                                                vertical=vertical,
                                                min_cls_size=min_cls_size,
                                                sort_by_count=sort_by_count,
                                                time_scale=plot_params['time_scale'],
                                                scale=plot_params['scale'],
                                                alpha_overlay=plot_params['alpha_overlay'],
                                                figtitle=title, titlesize=25,
                                                figdir=figdir,
                                                fname_postfix=fname_postfix + '_overlay_top{}'.format(i + 1),
                                                show=False)
            fig_paths.append(topn_path)

        # Rename only the last "_gt" in the file name (the suffix appended above);
        # a blanket replace would also rewrite "_gt" inside figdir or fname_postfix.
        gt_dir, gt_name = os.path.split(gt_path)
        head, sep, tail = gt_name.rpartition("_gt")
        if sep:
            combined_name = head + "_combined" + tail
        else:
            # Never reuse a panel's name: every panel is deleted below.
            root, ext = os.path.splitext(gt_name)
            combined_name = root + "_combined" + ext
        combined_path = os.path.join(gt_dir, combined_name)

        if vertical:
            combine_imgs_vertical(fig_paths, combined_path)
        else:
            combine_imgs(fig_paths, combined_path)
    finally:
        # Per-panel images are intermediates; remove them even if a step above failed.
        for f in fig_paths:
            if os.path.exists(f):
                os.remove(f)
示例#2
0
def plot_spike_clusters_and_gt_in_rows(clusters, nll, data_arr, gt_labels, topn=2, figdir="./", fname_postfix="",
                                       plot_params={"spacing": 1.25, "width": 0.9, "vscale": 1.5, "subplot_adj": 0.9}, downsample=None):
    """Plot spikes colored by assigned clusters vs. ground truth clusters, each spike as a row.

    Spikes are first reordered by ground-truth label (``np.argsort(gt_labels)``),
    and the same permutation is applied to every NCP sample. One figure is
    rendered with ``plot_raw_spikes_in_rows`` for the ground truth and one per
    NCP clustering returned by ``get_topn_clusters``. The figures are stitched
    by ``combine_imgs`` into a single image whose file name is the ground-truth
    panel's name with its last ``"_gt"`` replaced by ``"_combined"``. The
    per-panel images are deleted afterwards, even if a plotting step fails.

    Args:
        clusters, nll: sampled cluster assignments and their negative
            log-likelihoods, forwarded to ``get_topn_clusters``; each NCP panel
            title shows ``exp(-nll)`` as the probability.
        data_arr: spike waveforms, indexed by spike along the first axis.
        gt_labels: ground-truth cluster label per spike (array indexable by
            an index array).
        topn: number of NCP clusterings requested from ``get_topn_clusters``.
        figdir, fname_postfix: directory / file-name suffix for saved figures.
        plot_params: must contain ``spacing``, ``width``, ``vscale`` and
            ``subplot_adj``. Only read, never modified.
        downsample: forwarded to ``combine_imgs``.
    """
    topn_clusters, topn_nll = get_topn_clusters(clusters, nll, topn)

    # Group spikes by true unit; topn_clusters is indexed as (sample, spike).
    reorder = np.argsort(gt_labels)
    gt_labels = gt_labels[reorder]
    data_arr = data_arr[reorder]
    topn_clusters = topn_clusters[:, reorder]

    # Layout options shared by every panel.
    row_kwargs = dict(spacing=plot_params["spacing"], width=plot_params["width"],
                      vscale=plot_params["vscale"], subplot_adj=plot_params['subplot_adj'])

    fig_paths = []
    try:
        gt_path = plot_raw_spikes_in_rows(data_arr, gt_labels,
                                          figtitle='Ground truth', figdir=figdir,
                                          fname_postfix=fname_postfix + '_rows_gt', show=False,
                                          **row_kwargs)
        fig_paths.append(gt_path)

        for i in range(len(topn_nll)):
            cs = topn_clusters[i]
            K = len(set(cs))
            pr = np.exp(-topn_nll[i])
            title = 'NCP: {} Clusters (Prob: {:.3f})'.format(K, pr)
            fpath = plot_raw_spikes_in_rows(data_arr, cs,
                                            figtitle=title, figdir=figdir,
                                            fname_postfix=fname_postfix + '_rows_pred{}'.format(i), show=False,
                                            **row_kwargs)
            fig_paths.append(fpath)

        # The combined image must not share a path with any panel, because every
        # panel is deleted below. Derive it from the ground-truth panel, which
        # always exists, and rename only within the file name.
        gt_dir, gt_name = os.path.split(gt_path)
        head, sep, tail = gt_name.rpartition("_gt")
        if sep:
            combined_name = head + "_combined" + tail
        else:
            root, ext = os.path.splitext(gt_name)
            combined_name = root + "_combined" + ext
        combine_imgs(fig_paths, os.path.join(gt_dir, combined_name), downsample=downsample)
    finally:
        # Per-panel images are intermediates; remove them even if a step above failed.
        for f in fig_paths:
            if os.path.exists(f):
                os.remove(f)

    # NOTE(review): the loop below uses names that are not defined in this
    # function (fnames_list, input_dir, model, NCP_Sampler, n_parallel_sample,
    # use_beam, data_dir, cluster_spikes_ncp) and overwrites data_arr /
    # gt_labels / topn_clusters. It looks like it was spliced in from a separate
    # inference routine -- confirm whether it belongs here.
    for fname in fnames_list:

        npz = np.load(
            os.path.join(input_dir, "data_input", "{}.npz".format(fname)))
        data_arr, gt_labels = npz['data_arr'], npz['gt_labels']

        data_ncp = torch.from_numpy(np.array([data_arr.transpose(0, 2, 1)]))
        print("Running inference on {}:".format(fname))
        t = time.time()
        clusters, nll, highest_prob = cluster_spikes_ncp(model,
                                                         data_ncp,
                                                         NCP_Sampler,
                                                         S=n_parallel_sample,
                                                         beam=use_beam)
        inference_time = time.time() - t
        print("  time {:4f}".format(inference_time))

        topn_clusters, topn_nll = get_topn_clusters(clusters, nll, topn)

        # save data
        npz_fname = os.path.join(data_dir, "{}_ncp.npz".format(fname))
        np.savez_compressed(npz_fname,
                            clusters=clusters,
                            nll=nll,
                            topn_clusters=topn_clusters,
                            topn_nll=topn_nll,
                            data_arr=data_arr,
                            gt_labels=gt_labels,
                            inference_time=inference_time)
示例#4
0
def plot_raw_and_encoded_spikes_tsne(clusters, nll, data_arr, data_encoded, colors, topn=2,
                                     extra_clusters=None, extra_name=None, gt_labels=None,
                                     min_cls_size=10, sort_by_count=True,
                                     figdir="./", fname_postfix="", size_single=(6, 6),
                                     tsne_params={'seed': 0, 'perplexity': 30},
                                     plot_params={'pt_scale': 1}, show=True):
    """Plot spike waveforms and the NCP-encoded vectors using t-SNE.

    Draws one row of two scatter plots per clustering. The left panel is a
    t-SNE embedding of the flattened raw waveforms, and the right panel is a
    t-SNE embedding of ``data_encoded``. Rows, top to bottom:

    1. the ground truth, if ``gt_labels`` is given;
    2. the top NCP clusterings from ``get_topn_clusters`` (at most ``topn``);
    3. ``extra_clusters``, if both ``extra_clusters`` and ``extra_name`` are given.

    Points are colored per cluster from ``colors``. Clusters with
    ``min_cls_size`` or fewer spikes are drawn in grey.

    Args:
        clusters, nll: sampled cluster assignments and their negative
            log-likelihoods, forwarded to ``get_topn_clusters``.
        data_arr: spike waveforms; the first axis indexes spikes and the
            remaining axes are flattened before t-SNE.
        data_encoded: encoded vectors, one row per spike.
        colors: colors indexed by cluster position (by decreasing size when
            ``sort_by_count``); must be long enough for every non-grey cluster.
        topn: maximum number of NCP clusterings to plot.
        extra_clusters, extra_name: an additional labelling and its row title.
        gt_labels: optional ground-truth label per spike.
        min_cls_size: clusters with at most this many spikes are drawn grey.
        sort_by_count: relabel clusters by decreasing size before coloring.
        figdir, fname_postfix: where to save the figure when ``show`` is False.
        size_single: size of one panel; the figure is 2 panels wide.
        tsne_params: must contain ``seed`` and ``perplexity``. Only read.
        plot_params: must contain ``pt_scale`` (marker-size multiplier). Only read.
        show: display the figure instead of saving it.

    Returns:
        None when ``show`` is True; otherwise the path of the saved PDF (a PNG
        with the same stem is saved next to it).

    Raises:
        ValueError: if there is no row to plot.
    """

    from MulticoreTSNE import MulticoreTSNE as TSNE

    N = data_arr.shape[0]
    # Markers shrink as the number of spikes grows (size 26 at N=100).
    pt_size = 26. * (100 / N) ** 0.5 * plot_params['pt_scale']
    topn_clusters, topn_nll = get_topn_clusters(clusters, nll, topn)

    # Assemble rows: NCP samples, then [extra], then prepend [ground truth].
    plot_clusters = topn_clusters[:topn]
    plot_clusters_names = ['NCP clusters #{}'.format(i) for i in range(len(plot_clusters))]

    if extra_clusters is not None and extra_name is not None:
        plot_clusters = np.concatenate([plot_clusters, [extra_clusters]])
        plot_clusters_names.append(extra_name)

    if gt_labels is not None:
        # Prepend to the rows assembled so far, so extra_clusters is not dropped.
        plot_clusters = np.concatenate([[gt_labels], plot_clusters])
        plot_clusters_names = ["Ground Truth"] + plot_clusters_names

    n_plots = len(plot_clusters_names)
    if n_plots == 0:
        raise ValueError("nothing to plot: no NCP clustering, gt_labels or extra_clusters")

    # squeeze=False keeps `axes` 2-D (n_plots x 2) even when n_plots == 1.
    _, axes = plt.subplots(n_plots, 2, figsize=(size_single[0] * 2, size_single[1] * n_plots),
                           squeeze=False)
    tsne = TSNE(n_jobs=4, n_components=2,
                perplexity=tsne_params['perplexity'], random_state=tsne_params['seed'])
    tsne_encoded = tsne.fit_transform(data_encoded)
    tsne_raw = tsne.fit_transform(data_arr.reshape((N, -1)))

    fontsize = 12
    for row, (cs, name) in enumerate(zip(plot_clusters, plot_clusters_names)):
        cluster_ids, counts = np.unique(cs, return_counts=True)
        if sort_by_count:
            # Relabel clusters 0..K-1 by decreasing size so the largest gets colors[0].
            order = np.argsort(-counts)
            cluster_ids, counts = cluster_ids[order], counts[order]
            cluster_rename = {cid: rank for rank, cid in enumerate(cluster_ids)}
            cluster_ids = np.vectorize(cluster_rename.get)(cluster_ids)
            clusters_use = np.vectorize(cluster_rename.get)(cs)
        else:
            clusters_use = cs

        large_clusters = cluster_ids[counts > min_cls_size]
        color_mapping = {cid: (colors[j] if cid in large_clusters else 'grey')
                         for j, cid in enumerate(cluster_ids)}

        for k in cluster_ids:
            mask = (clusters_use == k)
            axes[row][0].scatter(tsne_raw[mask, 0], tsne_raw[mask, 1],
                                 color=color_mapping[k], s=pt_size)
            axes[row][1].scatter(tsne_encoded[mask, 0], tsne_encoded[mask, 1],
                                 color=color_mapping[k], s=pt_size)

        axes[row][0].set_title(
            't-SNE of raw spikes ({})'.format(name), fontsize=fontsize)
        axes[row][1].set_title(
            't-SNE of NCP-encoded spikes ({})'.format(name), fontsize=fontsize)

    plt.tight_layout()
    if show:
        plt.show()
        return None

    # Save both formats; only the PDF path is returned.
    png_path = os.path.join(figdir, "{}_tsne.png".format(fname_postfix))
    plt.savefig(png_path)
    save_path = os.path.join(figdir, "{}_tsne.pdf".format(fname_postfix))
    plt.savefig(save_path)
    plt.close()
    return save_path
示例#5
0
def plot_spike_clusters_and_templates_overlay(clusters, nll, data_arr, geom, channels, colors,
                                              topn=1, min_cls_size=10,
                                              templates=None, template_name="Templates",
                                              gt_labels=None,
                                              extra_clusters=None, extra_name=None,
                                              sort_by_count=True, figdir="./", fname_postfix="", size_single=(9, 9), vertical=False,
                                              plot_params={"time_scale": 1.1, "scale": 8., "alpha_overlay": 0.1}):
    """Plot spikes colored by assigned clusters and their templates, with all spikes overlayed on each other.

    Panels, in order:

    1. the ground truth, if ``gt_labels`` is given;
    2. the templates, if ``templates`` is given (drawn by ``plot_templates_separate``);
    3. the top NCP clusterings from ``get_topn_clusters`` (at most ``topn``);
    4. ``extra_clusters``, if both ``extra_clusters`` and ``extra_name`` are given.

    The panels are stitched into one image (stacked vertically when
    ``vertical`` is True). Its name is the last NCP panel's file name with
    ``"pred"`` replaced by ``"top"``. The per-panel images are deleted
    afterwards, even if a plotting step fails.

    Args:
        clusters, nll: sampled cluster assignments and their negative
            log-likelihoods, forwarded to ``get_topn_clusters``; each NCP panel
            title shows ``exp(-nll)`` as the probability.
        data_arr, geom, channels, colors: forwarded to the plotting helpers.
        topn: maximum number of NCP clusterings to plot.
        min_cls_size, sort_by_count: forwarded to every overlay panel.
        templates, template_name: optional templates panel and its title.
        gt_labels: optional ground-truth label per spike.
        extra_clusters, extra_name: an additional labelling and its title.
        figdir, fname_postfix: directory / file-name suffix for saved figures.
        size_single: figure size of a single panel.
        vertical: stack panels vertically instead of side by side.
        plot_params: must contain ``time_scale``, ``scale`` and
            ``alpha_overlay``. Only read, never modified.

    Raises:
        ValueError: if no NCP clustering was plotted (e.g. ``topn < 1``),
            because the combined image is named after an NCP panel.
    """

    def plot_overlay(labels, title, suffix):
        # Every overlay panel shares the same styling; only labels, title and file suffix differ.
        return plot_raw_spikes_overlay(data_arr, labels, geom, channels, colors, size_single=size_single,
                                       vertical=vertical,
                                       min_cls_size=min_cls_size,
                                       sort_by_count=sort_by_count,
                                       time_scale=plot_params['time_scale'],
                                       scale=plot_params['scale'],
                                       alpha_overlay=plot_params['alpha_overlay'],
                                       figtitle=title, titlesize=25,
                                       figdir=figdir, fname_postfix=fname_postfix + suffix, show=False)

    fig_paths = []
    try:
        if gt_labels is not None:
            title = 'Ground Truth: {} Clusters'.format(len(set(gt_labels)))
            fig_paths.append(plot_overlay(gt_labels, title, '_overlay_gt'))

        if templates is not None:
            template_path = plot_templates_separate(templates, geom, channels, colors, size_single=size_single,
                                                    vertical=vertical,
                                                    time_scale=plot_params['time_scale'],
                                                    scale=plot_params['scale'],
                                                    alpha_overlay=1,
                                                    figtitle=template_name, titlesize=25,
                                                    figdir=figdir, fname_postfix=fname_postfix, show=False)
            fig_paths.append(template_path)

        topn_clusters, topn_nll = get_topn_clusters(clusters, nll, topn)

        last_pred_path = None
        # get_topn_clusters may yield fewer than `topn` samples; never index past them.
        for i in range(min(topn, len(topn_nll))):
            cs = topn_clusters[i]
            pr = np.exp(-topn_nll[i])
            title = 'NCP: {} Clusters (Prob: {:.3f})'.format(len(set(cs)), pr)
            last_pred_path = plot_overlay(cs, title, '_overlay_pred{}'.format(i + 1))
            fig_paths.append(last_pred_path)

        if last_pred_path is None:
            raise ValueError("no NCP clustering was plotted (topn={}); "
                             "cannot name the combined figure".format(topn))

        if extra_clusters is not None and extra_name is not None:
            title = '{}: {} Clusters'.format(extra_name, len(set(extra_clusters)))
            fig_paths.append(plot_overlay(extra_clusters, title, '_overlay_{}'.format(extra_name)))

        # Swap "pred" -> "top" in the file name only, so directory names are untouched.
        pred_dir, pred_name = os.path.split(last_pred_path)
        combined_name = pred_name.replace("pred", "top")
        if combined_name == pred_name:
            # Never reuse a panel's name: every panel is deleted below.
            root, ext = os.path.splitext(pred_name)
            combined_name = root + "_top" + ext
        combined_path = os.path.join(pred_dir, combined_name)

        if vertical:
            combine_imgs_vertical(fig_paths, combined_path)
        else:
            combine_imgs(fig_paths, combined_path)
    finally:
        # Per-panel images are intermediates; remove them even if a step above failed.
        for f in fig_paths:
            if os.path.exists(f):
                os.remove(f)