예제 #1
0
def plot_y_centroids(yi,
                     lami,
                     wi,
                     vlim=(-4, 4),
                     fs_xlabel=12,
                     fs_ylabel=12,
                     fs_xticks=12,
                     fs_yticks=12,
                     rotation=90,
                     ha="center",
                     markernames=None,
                     ylab="subpopulations",
                     gridlines_color='black',
                     gridlines_lw=1,
                     cm=blue2red.cm(9),
                     population=None,
                     Zi=None):
    """Draw a heatmap of per-subpopulation marker means (centroids).

    Only subpopulations with positive weight are shown. Image row ``k``
    belongs to the subpopulation with the k-th smallest weight and is
    labelled ``"<name> (<weight in %>)"``.

    Args:
        yi: 2-D array of values; column ``j`` is marker ``j``.
        lami: 1-D array of 1-based subpopulation labels, one per row of ``yi``.
        wi: 1-D array of weights; ``wi[k]`` belongs to label ``k + 1``.
        vlim: ``(vmin, vmax)`` colour limits of the heatmap.
        fs_xlabel, fs_ylabel: font sizes of the axis labels.
        fs_xticks, fs_yticks: font sizes of the tick labels.
        rotation, ha: rotation / horizontal alignment of the x tick labels.
        markernames: x tick labels; ``None`` or empty means ``1..J``.
        ylab: y axis label.
        gridlines_color, gridlines_lw: colour / line width passed to ``gridlines``.
        cm: colormap of the heatmap.
        population, Zi: when both are given, a row is named
            ``population.label(Zi[:, k])`` instead of its label number.
    """
    J = yi.shape[1]

    # Subpopulations with positive weight, ordered by ascending weight.
    positive = np.argwhere(wi > 0)[:, 0]
    order = np.argsort(wi[positive])
    selected_features = positive[order]
    wi_sel_sorted = wi[positive][order]
    K_sel = selected_features.shape[0]

    y_centers = np.zeros((J, K_sel))
    yticks = []
    for k, feature in enumerate(selected_features):
        k_ = feature + 1  # labels in ``lami`` are 1-based
        # Mean of every marker over the cells assigned to subpopulation k_.
        y_centers[:, k] = yi[lami == k_, :].mean(axis=0)
        w_ik_perc = (wi_sel_sorted[k] * 100).round(1)
        if population is None or Zi is None:
            name = k_
        else:
            name = population.label(Zi[:, k_ - 1])
        yticks.append('{} ({}%)'.format(name, w_ik_perc))

    im = plt.imshow(y_centers.T,
                    aspect='auto',
                    cmap=cm,
                    vmin=vlim[0],
                    vmax=vlim[1])

    # ``len`` rather than ``== []`` so numpy arrays / tuples of names work.
    if markernames is None or len(markernames) == 0:
        markernames = np.arange(J) + 1

    plt.xticks(range(J),
               markernames,
               rotation=rotation,
               fontsize=fs_xticks,
               ha=ha)
    plt.yticks(range(K_sel), yticks, fontsize=fs_yticks)
    plt.xlabel('markers', fontsize=fs_xlabel)
    plt.ylabel(ylab, fontsize=fs_ylabel)
    gridlines(y_centers.T, color=gridlines_color, lw=gridlines_lw)
    colorbar_horizontal(im)
예제 #2
0
def plot_y(yi,
           wi_mean,
           lami_est,
           fs_lab=10,
           fs_cbar=10,
           lw=3,
           cm=blue2red.cm(6),
           vlim=(-3, 3),
           fs_xlab=10,
           fs_ylab=10,
           ylab="cells",
           markernames=None,
           interpolation=None,
           rotation=90,
           ha="center"):
    """Draw ``yi`` as a heatmap with rows grouped by subpopulation.

    Rows are sorted by the labels returned from ``relabel_lam`` and yellow
    horizontal lines mark the boundaries between consecutive groups. A
    horizontal colorbar is attached above the heatmap.

    Args:
        yi: 2-D array of values; column ``j`` is marker ``j``.
        wi_mean: 1-D array of subpopulation weights, or an integer ``K``. In
            the integer case the weights are estimated as the fraction of
            ``lami_est`` equal to each label ``1..K``.
        lami_est: 1-D array of subpopulation labels, one per row of ``yi``.
        fs_lab: font size of the axis labels.
        fs_cbar: font size of the colorbar tick labels.
        lw: line width of the group separators.
        cm: colormap of the heatmap.
        vlim: ``(vmin, vmax)`` colour limits.
        fs_xlab: font size of the x tick labels.
        fs_ylab: font size of the y tick labels.
        ylab: y axis label.
        markernames: x tick labels; ``None`` or empty means ``1..J``.
        interpolation: forwarded to ``plt.imshow``.
        rotation, ha: rotation / horizontal alignment of the x tick labels.
    """
    J = yi.shape[1]
    vmin, vmax = vlim

    # Accept numpy integers too (e.g. a count obtained from an array).
    if isinstance(wi_mean, (int, np.integer)):
        K = wi_mean
        wi_mean = np.array([(lami_est == k + 1).mean() for k in range(K)])

    lami_new, counts = relabel_lam(lami_est, wi_mean)
    counts_cumsum = np.cumsum(counts)
    yi_sorted = yi[np.argsort(lami_new), :]

    im = plt.imshow(yi_sorted,
                    aspect='auto',
                    vmin=vmin,
                    vmax=vmax,
                    cmap=cm,
                    interpolation=interpolation)
    # One separator after every group except the last.
    for c in counts_cumsum[:-1]:
        plt.axhline(c, color='yellow', linewidth=lw)

    if markernames is None or len(markernames) == 0:
        markernames = np.arange(J) + 1
    # Set labels, rotation and alignment in one call so they apply to the
    # final tick labels.
    plt.xticks(np.arange(J), markernames, fontsize=fs_xlab,
               rotation=rotation, ha=ha)
    plt.yticks(fontsize=fs_ylab)
    plt.xlabel("markers", fontsize=fs_lab)
    plt.ylabel(ylab, fontsize=fs_lab)

    # Horizontal colorbar in a thin axis carved out above the heatmap.
    ax = plt.gca()
    ax_divider = make_axes_locatable(ax)
    cax = ax_divider.append_axes("top", size="7%", pad="2%")
    cax.xaxis.set_ticks_position("top")
    cbar = colorbar(im, cax=cax, orientation="horizontal")
    cbar.ax.tick_params(labelsize=fs_cbar)
예제 #3
0
    y = copy.deepcopy(data['y'])

    # Get a subsampmle of data
    if 0 < subsample < 1:
        for i in range(len(y)):
            Ni = y[i].shape[0]
            idx = np.random.choice(Ni, int(Ni * subsample), replace=False)
            y[i] = y[i][idx, :]

    # Print size of data
    print('N: {}'.format([yi.shape[0] for yi in y]))

    # Color map
    cm_greys = plt.cm.get_cmap('Greys', 5)
    VMIN, VMAX = VLIM = (-4, 4)
    cm = blue2red.cm(9)

    # Plot yi histograms
    # plt.hist(y[0][:, 1], bins=100, density=True); plt.xlim(-15, 15); plt.show()
    # plt.hist(y[1][:, 3], bins=100, density=True); plt.xlim(-15, 15); plt.show()
    # plt.hist(y[2][:, -1], bins=100, density=True); plt.xlim(-15, 15); plt.show()

    # Heatmaps
    for i in range(I):
        plt.imshow(y[i], aspect='auto', vmin=VMIN, vmax=VMAX, cmap=cm)
        plt.colorbar()
        plt.savefig('{}/y{}.pdf'.format(img_dir, i + 1))
        plt.close()

    K = 30
    L = [5, 3]
예제 #4
0
            yi_path = f'{path}/img/txt/y{i}_mean.csv'
            yi = np.loadtxt(yi_path, delimiter=',')
            lami_path = f'{path}/img/txt/lam{i}_best.txt'
            lami = np.loadtxt(lami_path, dtype=int)
            wi_path = f'{path}/img/txt/W{i}_best.txt'
            wi = np.loadtxt(wi_path)
            zi_path = f'{path}/img/txt/Z{i}_best.txt'
            zi = np.loadtxt(zi_path)

            # Plot
            plt.figure(figsize=(6, 6))
            plot_yz.plot_y_centroids(yi,
                                     lami,
                                     wi,
                                     vlim=(-3, 3),
                                     cm=blue2red.cm(6),
                                     population=population,
                                     Zi=zi,
                                     fs_xlabel=16,
                                     fs_ylabel=16,
                                     fs_xticks=16,
                                     fs_yticks=16)
            outpath = f'{path}/img/y{i}_centroid.pdf'
            plt.savefig(outpath, bbox_inches="tight")
            plt.close()

            # Plot Z estimate.
            plt.figure(figsize=(6, 6))
            plot_yz.plot_Z(Z_mean=zi,
                           wi_mean=wi,
                           lami_est=lami,
예제 #5
0
        population = Population()

        # Rcounts
        R_path = f'{path}/img/txt/Rcounts.txt'
        R = np.loadtxt(R_path)
        plt.figure(figsize=(5,5))
        zinfo.plot_num_selected_features(R, refs=[6, 5], ymax=1.05)
        plt.savefig(f'{path}/img/Rcounts.pdf', bbox_inches='tight')
        plt.close()

        for i in (1, 2):
            # Read data
            yi_path = f'{path}/img/txt/y{i}_mean.csv'
            yi = np.loadtxt(yi_path, delimiter=',')
            lami_path = f'{path}/img/txt/lam{i}_best.txt'
            lami = np.loadtxt(lami_path, dtype=int)
            wi_path = f'{path}/img/txt/W{i}_best.txt'
            wi = np.loadtxt(wi_path)
            zi_path = f'{path}/img/txt/Z{i}_best.txt'
            zi = np.loadtxt(zi_path)

            # Plot
            plt.figure(figsize=(6,6))
            plot_yz.plot_y_centroids(yi, lami, wi, vlim=(-3, 3), cm=blue2red.cm(6),
                                     population=population, Zi=zi,
                                     fs_xlabel=16, fs_ylabel=16,
                                     fs_xticks=16, fs_yticks=16)
            outpath = f'{path}/img/y{i}_centroid.pdf'
            plt.savefig(outpath, bbox_inches="tight")
            plt.close()
예제 #6
0
def plot_yz(yi,
            Z_mean,
            wi_mean,
            lami_est,
            w_thresh=.01,
            cm_greys=plt.get_cmap('Greys', 5),
            markernames=None,
            rotation=90,
            ha="center",
            cm_y=blue2red.cm(6),
            vlim_y=(-3, 3),
            fs_w=10,
            w_digits=1):
    """Plot sorted data ``yi`` (top panel) above the estimated ``Z`` (bottom).

    Top: ``yi`` with rows sorted by the labels from ``relabel_lam``. Yellow
    lines separate the groups, and a horizontal colorbar sits above the panel.
    Bottom: the columns ``k`` of ``Z_mean`` with ``wi_mean[k] > w_thresh``,
    drawn as rows in ascending order of weight. They are labelled ``k + 1`` on
    the left and with the weight in percent on the right.

    Args:
        yi: 2-D array of values; column ``j`` is marker ``j``.
        Z_mean: 2-D array whose column ``k`` belongs to feature ``k``; drawn
            with colour limits 0..1 (values presumably in [0, 1]).
        wi_mean: 1-D array of feature weights.
        lami_est: 1-D array of labels, one per row of ``yi``.
        w_thresh: only features with weight strictly above this are shown.
        cm_greys: colormap of the ``Z`` panel.
        markernames: x tick labels of the top panel; ``None``/empty -> ``1..J``.
        rotation, ha: rotation / horizontal alignment of the x tick labels.
        cm_y: colormap of the ``yi`` panel.
        vlim_y: ``(vmin, vmax)`` colour limits of the ``yi`` panel.
        fs_w: font size of the ``Z`` row labels (both sides).
        w_digits: decimals of the percentage labels.
    """
    J = yi.shape[1]
    vmin_y, vmax_y = vlim_y

    gs = gridspec.GridSpec(2, 1, height_ratios=[5, 2])

    # --- Top panel: yi with rows grouped by label ---
    lami_new, counts = relabel_lam(lami_est, wi_mean)
    counts_cumsum = np.cumsum(counts)
    yi_sorted = yi[np.argsort(lami_new), :]

    plt.subplot(gs[0])
    im = plt.imshow(yi_sorted,
                    aspect='auto',
                    vmin=vmin_y,
                    vmax=vmax_y,
                    cmap=cm_y)
    # One separator after every group except the last.
    for c in counts_cumsum[:-1]:
        plt.axhline(c, color='yellow')
    if markernames is None or len(markernames) == 0:
        markernames = np.arange(J) + 1
    plt.xticks(np.arange(J), markernames, rotation=rotation, ha=ha)

    ax = plt.gca()
    ax_divider = make_axes_locatable(ax)
    cax = ax_divider.append_axes("top", size="7%", pad="2%")
    cax.xaxis.set_ticks_position("top")
    colorbar(im, cax=cax, orientation="horizontal")

    # --- Bottom panel: columns of Z with weight above the threshold ---
    w = np.asarray(wi_mean)
    k_ord = np.argsort(w)
    # Masking the argsort keeps ascending-weight order and an integer dtype,
    # even when no feature passes the threshold.
    z_cols = k_ord[w[k_ord] > w_thresh]
    Z_hat = Z_mean[:, z_cols].T
    K = z_cols.shape[0]

    plt.subplot(gs[1])
    plt.imshow(Z_hat, aspect='auto', vmin=0, vmax=1, cmap=cm_greys)
    ax = plt.gca()
    plt.xticks([])
    plt.yticks(np.arange(K), z_cols + 1, fontsize=fs_w)
    add_gridlines_Z(Z_hat)
    plt.colorbar(orientation='horizontal', pad=.05)

    # Right-hand labels: weight of each shown feature, aligned with its row.
    # Mirroring the left axis' y-limits puts tick i exactly on image row i.
    ax2 = ax.twinx()
    ax2.set_ylim(ax.get_ylim())
    w_perc = [str((wp * 100).round(w_digits)) + '%' for wp in w[z_cols]]
    ax2.set_yticks(np.arange(K))
    ax2.set_yticklabels(w_perc, fontsize=fs_w)
    ax2.tick_params(length=0)

    fig = plt.gcf()
    fig.subplots_adjust(hspace=0.2)