def write_gene_pairs_names_in_each_cluster(run_name, gene_pair_indexs,
                                           gene_pair_names,
                                           cluster_labels_of_gene_pairs):
    """Write one CSV per cluster listing the gene pairs assigned to it.

    For every cluster label ``k`` in ``0 .. n_unique_labels - 1`` the file
    ``FIGURE_DIR/<run_name>/cluster_<k+1>/gene_pair_names_of_cluster_<k+1>.csv``
    is written with one row per gene pair: ``index,first_gene,second_gene``.

    Parameters
    ----------
    run_name : str
        Sub-directory of ``FIGURE_DIR`` to write into.
    gene_pair_indexs : sequence
        Index of each gene pair, aligned element-wise with ``gene_pair_names``.
    gene_pair_names : sequence of str
        Pair names of the form ``"<gene1>_<gene2>"``; only the first two
        ``"_"``-separated parts are written.
    cluster_labels_of_gene_pairs : sequence of int
        Cluster label of each gene pair. Labels are presumably contiguous and
        0-based; only labels ``0 .. n_unique_labels - 1`` are written.
    """
    output_dir = os.path.join(FIGURE_DIR, run_name)
    mkdirs([output_dir])
    cluster_labels_of_gene_pairs = np.asarray(cluster_labels_of_gene_pairs)
    gene_pairs = np.asarray(gene_pair_names)
    # Boolean-mask indexing below needs an ndarray, so plain lists are converted.
    gene_pair_indexs = np.asarray(gene_pair_indexs)
    n_cluster = len(np.unique(cluster_labels_of_gene_pairs))
    for cluster_id in range(n_cluster):
        in_cluster = cluster_labels_of_gene_pairs == cluster_id
        gene_pairs_in_this_cluster = gene_pairs[in_cluster]
        print(gene_pairs_in_this_cluster)
        gene_pairs_indexs_in_this_cluster = gene_pair_indexs[in_cluster]
        print(gene_pairs_indexs_in_this_cluster)
        rows = []
        for pair_index, gene_pair_name in zip(
                gene_pairs_indexs_in_this_cluster, gene_pairs_in_this_cluster):
            genes = gene_pair_name.split("_")
            rows.append([pair_index, genes[0], genes[1]])
        # reshape keeps a (0, 3) shape for an empty cluster, so savetxt writes
        # an empty file instead of failing on a 1-D empty array.
        out_arr = np.array(rows).reshape(-1, 3)
        outdir = os.path.join(output_dir, "cluster_%d" % (cluster_id + 1))
        mkdirs([outdir])
        output_fp = os.path.join(
            outdir, "gene_pair_names_of_cluster_%d.csv" % (cluster_id + 1))
        # fmt already carries the comma separators; delimiter would be ignored.
        np.savetxt(output_fp, out_arr, fmt="%s,%s,%s")
def get_unique_gene_pairs_names_in_each_cluster(run_name, gene_pairs,
                                                cluster_labels_of_gene_pairs):
    """Collect the distinct gene names of each cluster and write them to CSV.

    Every gene-pair name is split on ``"_"``, and the resulting gene names of a
    cluster are de-duplicated (sorted, via ``np.unique``). For cluster label
    ``k`` in ``0 .. n_unique_labels - 1`` the names are written one per line
    to ``GENE_PAIR_NAME_DATA/<run_name>/<k+1>.csv``.

    Parameters
    ----------
    run_name : str
        Sub-directory of ``GENE_PAIR_NAME_DATA`` to write into.
    gene_pairs : sequence of str
        Gene-pair names of the form ``"<gene1>_<gene2>"``.
    cluster_labels_of_gene_pairs : sequence of int
        Cluster label of each gene pair, aligned with ``gene_pairs``. Labels
        are presumably contiguous and 0-based.

    Returns
    -------
    list of numpy.ndarray
        ``result[k]`` holds the sorted unique gene names of cluster ``k``.
    """
    output_dir = os.path.join(GENE_PAIR_NAME_DATA, run_name)
    mkdirs([output_dir])

    cluster_labels_of_gene_pairs = np.asarray(cluster_labels_of_gene_pairs)
    gene_pairs = np.asarray(gene_pairs)
    n_cluster = len(np.unique(cluster_labels_of_gene_pairs))
    unique_gene_names = []
    for cluster_id in range(n_cluster):
        gene_pairs_in_this_cluster = gene_pairs[cluster_labels_of_gene_pairs ==
                                                cluster_id]
        all_gene_names_in_this_cluster = [
            gene_name for gene_pair_name in gene_pairs_in_this_cluster
            for gene_name in gene_pair_name.split("_")
        ]
        unique_gene_names_in_this_cluster = np.unique(
            all_gene_names_in_this_cluster)
        unique_gene_names.append(unique_gene_names_in_this_cluster)
        output_fp = os.path.join(output_dir, "%d.csv" % (cluster_id + 1))
        np.savetxt(output_fp, unique_gene_names_in_this_cluster, fmt="%s")
    # Previously the collected names were discarded; returning them is
    # backward compatible with callers that ignored the (None) result.
    return unique_gene_names
def fig8_plot_selected_cluster_trajectories(selected_clusters,
                                            cluster_labels,
                                            traj_lst,
                                            run_name,
                                            N_SAMPLE=4,
                                            fig_format='png',
                                            cmap="RdBu",
                                            color_palette=None):
    """Save one trajectory figure per selected cluster.

    Each figure is written to
    ``FIGURE_DIR/<run_name>/selected_cluster_<k>.<fig_format>``.

    Parameters
    ----------
    selected_clusters : iterable of int
        1-based cluster numbers to plot.
    cluster_labels : sequence of int
        0-based cluster label of every trajectory, aligned with ``traj_lst``.
    traj_lst : sequence of 2-D arrays
        Trajectories; column 0 is drawn on x and column 1 on y. Each presumably
        has 10 time points to match the colour array below (TODO confirm).
    run_name : str
        Sub-directory of ``FIGURE_DIR`` to write into.
    N_SAMPLE, color_palette
        Unused; kept for call compatibility.
    fig_format : str
        File extension and matplotlib output format.
    cmap : str
        Colormap passed to ``colorline``.

    Note: the ``plt.rc`` calls below change global matplotlib rcParams.
    """
    # Caller passes 1-based cluster numbers; labels are 0-based.
    selected_clusters = [item - 1 for item in selected_clusters]
    plt.rc('axes', linewidth=2., labelsize=10.,
           labelpad=6.)  # length for x tick and padding
    plt.rc('xtick.major', size=6, pad=3)  # length for x tick and padding
    plt.rc('ytick.major', size=6, pad=3)  # length for y tick and padding
    plt.rc('lines', mew=5, lw=4)  # line 'marker edge width' and thickness
    plt.rc('ytick', labelsize=8)  # ytick label size
    plt.rc('xtick', labelsize=8)  # xtick label
    plt.rc('figure', dpi=300)  # Sets rendering resolution to 300
    plt.rc('lines',
           markersize=3.1)  # Sets marker size for scatter to reasonable size

    # Colour value per time point: 0.1, 0.2, ..., 1.0.
    c_arr = np.array([(time_point + 1.) / 10. for time_point in range(10)])
    # Convert once instead of once per selected cluster.
    traj_arr = np.array(traj_lst)
    label_arr = np.array(cluster_labels)
    fig_dir = os.path.join(FIGURE_DIR, run_name)
    mkdirs([fig_dir])
    for cluster in selected_clusters:
        fig = plt.figure(1000 + cluster)
        # The figure number is reused across calls; clear it so earlier
        # trajectories are not overlaid on the new plot.
        fig.clf()
        ax = fig.gca()
        for traj in traj_arr[label_arr == cluster]:
            colorline(ax, traj[:, 0], traj[:, 1], c_arr, cmap=cmap)
        ax.set_xlim(-17.14, 65)
        ax.set_ylim(-22.23, 40)
        ax.set_xlabel("PCA Component 1")
        ax.set_ylabel("PCA Component 2")
        fig_fp = os.path.join(
            fig_dir, "selected_cluster_%d.%s" % ((cluster + 1), fig_format))
        # Use fig_format (not a hard-coded 'png') so content matches extension.
        fig.savefig(fig_fp,
                    transparent=True,
                    bbox_inches='tight',
                    pad_inches=0.1,
                    format=fig_format)
def plot_hierarchical_cluster(df,
                              linkage,
                              color_palette,
                              distance_threshold,
                              label_arr,
                              link_cols,
                              run_name,
                              fig_format="png",
                              label_position=-400):
    """Draw the dendrogram with a coloured text label per flat cluster and save it.

    The figure is written to
    ``FIGURE_DIR/<run_name>/cluster_dendrogram_<n>_with_leafs.<fig_format>``.

    Parameters
    ----------
    df : DataFrame
        Not used by the drawing; kept for call compatibility.
    linkage : ndarray
        Linkage matrix passed to ``scipy.cluster.hierarchy.dendrogram``;
        ``linkage[-1, -2]`` (the last merge distance) caps the y-axis.
    color_palette : mapping or sequence
        ``color_palette[k]`` is the background colour of the "cluster k+1" label.
    distance_threshold : float
        ``color_threshold`` for the dendrogram.
    label_arr : sequence of int
        Flat cluster label of every leaf, presumably 0-based and contiguous.
    link_cols : mapping or sequence
        ``link_cols[link_id]`` gives the colour of each dendrogram link.
    run_name : str
        Sub-directory of ``FIGURE_DIR`` to write into.
    fig_format : str
        File extension of the saved figure.
    label_position : float
        y coordinate of the cluster labels (below the leaves).
    """
    label_arr = np.asarray(label_arr)
    n_cluster = len(np.unique(label_arr))
    fig_dir = os.path.join(FIGURE_DIR, run_name)
    mkdirs([fig_dir])

    fig2 = plt.figure(2, figsize=(30, 30))
    # Figure number 2 is reused: drop old content, and re-apply the size,
    # which plt.figure ignores when the figure already exists.
    fig2.clf()
    fig2.set_size_inches(30, 30)
    dendrogram(linkage,
               no_labels=True,
               leaf_rotation=90,
               orientation="top",
               leaf_font_size=8,
               distance_sort='ascending',
               color_threshold=distance_threshold,
               above_threshold_color="black",
               link_color_func=lambda x: link_cols[x])
    # scipy places dendrogram leaves 10 x-units apart, so a cluster with m
    # leaves spans m * 10 units. This assumes clusters appear left-to-right in
    # label order along the dendrogram (NOTE(review): confirm for the caller's labels).
    prev_sum = 0
    for lbl_id in range(n_cluster):
        label_tmp = np.count_nonzero(label_arr == lbl_id) * 10
        coord = prev_sum + label_tmp * 0.3
        prev_sum += label_tmp
        plt.text(coord,
                 label_position,
                 "cluster %d" % (lbl_id + 1),
                 rotation=45,
                 backgroundcolor=color_palette[lbl_id])
    plt.ylim([label_position - 20, linkage[-1, -2]])
    fig_fp2 = os.path.join(
        fig_dir,
        "cluster_dendrogram_%d_with_leafs.%s" % (n_cluster, fig_format))
    fig2.savefig(fig_fp2, dpi=200)
def plot_heatmap_serie_of_each_cluster(data_dict,
                                       N_CLUSTER,
                                       cluster_lst,
                                       passed_pair_indexs,
                                       run_name,
                                       fig_format="png",
                                       TARGET_CLUSTER_IDs=None):
    '''
    Plot, per cluster, a grid of heatmaps of ``-log|prob2d|`` for its gene pairs.

    Every gene pair occupies one row of ``N_COL`` (10) panels, one per entry of
    its ``'prob2d'`` array. Each figure holds at most ``Max_NROW`` (5) gene
    pairs, so large clusters are split over several files in
    ``FIGURE_DIR/<run_name>/cluster_<k+1>/``.

    Parameters
    ----------
    data_dict : sequence
        Only ``data_dict[0]`` is used. It maps ``gene_pair_index + 1`` to a
        dict with keys ``'prob2d'`` (presumably 10 stacked 21x21 maps, TODO
        confirm) and ``'pair_name'``.
    N_CLUSTER : int
        Clusters ``0 .. N_CLUSTER-1`` are plotted unless ``TARGET_CLUSTER_IDs``
        is given. It is also printed.
    cluster_lst : sequence of int
        Cluster label of every gene pair, aligned with ``passed_pair_indexs``.
    passed_pair_indexs : sequence of int
        Gene-pair index of every entry of ``cluster_lst``.
    run_name : str
        Sub-directory of ``FIGURE_DIR`` to write into.
    fig_format : str
        File extension of the saved figures.
    TARGET_CLUSTER_IDs : iterable of int, optional
        0-based cluster labels to plot. A falsy value plots all clusters.
    '''
    print(N_CLUSTER)
    data_dict = data_dict[0]
    N_COL = 10  # one panel per stage
    Max_NROW = 5  # gene pairs (rows) per figure
    TICKS = range(0, 21, 5)
    N_SUBFIG_PER_FIG = Max_NROW * N_COL
    cluster_lst = np.array(cluster_lst)
    # Boolean-mask indexing below needs an ndarray, so plain lists are converted.
    passed_pair_indexs = np.asarray(passed_pair_indexs)
    fig_dir = os.path.join(FIGURE_DIR, run_name)
    mkdirs([fig_dir])
    if TARGET_CLUSTER_IDs:
        cluster_ids = TARGET_CLUSTER_IDs
    else:
        cluster_ids = range(N_CLUSTER)
    for cluster in cluster_ids:
        gene_pair_indexs = passed_pair_indexs[cluster_lst == cluster]
        n_gene_pairs_in_cluster = len(gene_pair_indexs)
        NFIG = int(math.ceil(float(n_gene_pairs_in_cluster) / Max_NROW))
        sub_fig_dir = os.path.join(fig_dir, "cluster_%d" % (cluster + 1))
        mkdirs([sub_fig_dir])
        for i in range(NFIG):
            if NFIG > 1:
                fig_fp = os.path.join(
                    sub_fig_dir,
                    "cluster_%d_%d.%s" % (cluster + 1, i, fig_format))
            else:
                fig_fp = os.path.join(
                    sub_fig_dir, "cluster_%d.%s" % (cluster + 1, fig_format))
            base_index = i * N_SUBFIG_PER_FIG
            N_remaining_files = n_gene_pairs_in_cluster * N_COL - base_index
            # Panels on this figure: a full page, or whatever is left over.
            SUB_FIG_RANGE = min(N_SUBFIG_PER_FIG, N_remaining_files)
            N_ROW = int(math.ceil(float(SUB_FIG_RANGE) / N_COL))

            # squeeze=False keeps axs 2-D even when only one row is needed.
            fig, axs = plt.subplots(N_ROW,
                                    N_COL,
                                    figsize=(N_COL * EACH_SUB_FIG_SIZE,
                                             N_ROW * EACH_SUB_FIG_SIZE),
                                    squeeze=False)
            plt.set_cmap('viridis_r')
            for j in range(SUB_FIG_RANGE):
                row, col = divmod(j, N_COL)
                ax = axs[row][col]
                gene_pair_id = gene_pair_indexs[i * Max_NROW + row] + 1
                prob2d = data_dict[gene_pair_id]['prob2d'][col]
                # Zero probabilities map to +inf; silence the expected warning.
                with np.errstate(divide='ignore'):
                    q_potential = -np.log(np.abs(prob2d))
                cax = ax.pcolormesh(q_potential, vmin=3, vmax=14)
                ax.set_yticks(TICKS)
                ax.set_xticks(TICKS)
                if row == 0:
                    ax.set_title("Stage %d" % Stages[col])
                if col == 0:
                    ax.set_ylabel(data_dict[gene_pair_id]['pair_name'])
                if j == 0:
                    fig.colorbar(cax, ax=ax)
            fig.savefig(fig_fp, dpi=200)
            # Release the figure; a big cluster can open many figures here.
            plt.close(fig)
            print("cluster %d" % (cluster + 1))
def plot_cluster(traj_lst,
                 cluster_lst,
                 run_name,
                 fig_format="png",
                 color_palette=None,
                 log_transformed=True,
                 cmap="gist_rainbow"):
    '''
    Plot trajectories in a 4-column grid with one subplot per distinct label.

    Labels are renumbered 0..N-1 in sorted order (``np.unique``). An outlier
    label such as -1 gets its own subplot like any other label. The figure is
    saved to ``FIGURE_DIR/<run_name>/trajactory_clusters_<N>.<fig_format>``.

    Parameters
    ----------
    traj_lst : sequence of 2-D arrays
        Trajectories; column 0 is drawn on x and column 1 on y.
    cluster_lst : sequence
        Cluster label of every trajectory.
    run_name : str
        Sub-directory of ``FIGURE_DIR`` to write into.
    fig_format : str
        File extension and matplotlib output format.
    color_palette : mapping or sequence, optional
        ``color_palette[k]`` is the title background colour of subplot k.
    log_transformed : bool
        Selects between two hard-coded axis-limit sets.
    cmap : str
        Colormap passed to ``colorline``.

    Returns
    -------
    list of int
        Renumbered (0-based) cluster index of every trajectory, in input order.
    '''
    if log_transformed:
        x_lim, y_lim = (-17.14, 65), (-22.23, 40)
    else:
        x_lim, y_lim = (-0.2, 0.6), (-0.1, 0.15)
    traj_lst = np.array(traj_lst)
    cluster_lst = np.array(cluster_lst)
    unique_labels = np.unique(cluster_lst)
    N_CLUSTER = len(unique_labels)
    cluster_mapping = {
        cluster: cid
        for cid, cluster in enumerate(unique_labels)
    }
    N_COL = 4
    N_ROW = int(math.ceil(float(N_CLUSTER) / N_COL))
    # Colour value per time point: 0.1, 0.2, ..., 1.0.
    c_arr = np.array([(time_point + 1.) / 10. for time_point in range(10)])
    # squeeze=False keeps axs 2-D even when there is a single row.
    fig, axs = plt.subplots(N_ROW,
                            N_COL,
                            figsize=(N_COL * EACH_SUB_FIG_SIZE,
                                     N_ROW * EACH_SUB_FIG_SIZE),
                            squeeze=False)
    for traj, cluster in zip(traj_lst, cluster_lst):
        row, col = divmod(cluster_mapping[cluster], N_COL)
        colorline(axs[row][col], traj[:, 0], traj[:, 1], c_arr, cmap=cmap)

    # Decorate every used subplot once instead of once per trajectory.
    for cid in range(N_CLUSTER):
        row, col = divmod(cid, N_COL)
        ax = axs[row][col]
        ax.set_xlim(*x_lim)
        ax.set_ylim(*y_lim)
        if row == N_ROW - 1:
            ax.set_xlabel("PCA Component 1")
        if col == 0:
            ax.set_ylabel("PCA Component 2")
        if color_palette:
            ax.set_title("cluster %d" % (cid + 1),
                         backgroundcolor=color_palette[cid])
        else:
            ax.set_title("cluster %d" % (cid + 1))

    fig_dir = os.path.join(FIGURE_DIR, run_name)
    mkdirs([fig_dir])
    fig_fp = os.path.join(
        fig_dir, "trajactory_clusters_%d.%s" % (N_CLUSTER, fig_format))
    # Use fig_format (not a hard-coded 'png') so content matches extension.
    fig.savefig(fig_fp,
                transparent=True,
                bbox_inches='tight',
                pad_inches=0.1,
                format=fig_format)
    return [cluster_mapping[cluster] for cluster in cluster_lst]
from matplotlib import collections as mcoll
from Util import mkdirs
#%matplotlib inline

# Module-level configuration: colour list, data/figure directories and
# plotting constants shared by the functions in this module.
# Matplotlib's default colour cycle, extended with extra named colours.
color_lst = plt.rcParams['axes.prop_cycle'].by_key()['color']
color_lst.extend([
    'firebrick', 'olive', 'indigo', 'khaki', 'teal', 'saddlebrown', 'skyblue',
    'coral', 'darkorange', 'lime', 'darkorchid', 'dimgray'
])
# Data and figure roots are resolved against the current working directory
# at import time, not against the location of this file.
DATA_DIR = os.path.join(os.path.abspath(os.curdir), 'DATA')
PICKLE_DATA = os.path.join(DATA_DIR, "pickle_data")
NPY_DATA = os.path.join(DATA_DIR, "npy_data")
GENE_PAIR_NAME_DATA = os.path.join(DATA_DIR, "gene_pair_names")
FIGURE_DIR = os.path.join(os.path.abspath(os.curdir), 'Figures')
# Import-time side effect: mkdirs (from Util) presumably creates these
# directories if they are missing.
mkdirs([PICKLE_DATA, NPY_DATA, FIGURE_DIR, GENE_PAIR_NAME_DATA])
# Stage numbers used as per-column titles ("Stage %d") of the heatmap grids.
Stages = [8, 10, 11, 12, 13, 14, 16, 18, 20, 22]
# Width and height (matplotlib figsize units, inches) allotted to each subplot.
EACH_SUB_FIG_SIZE = 5
# NOTE(review): not referenced by the visible functions, which take their own
# fig_format argument.
FIGURE_FORMAT = "png"


def make_segments(x, y):
    """
    Turn a polyline given by x/y coordinates into consecutive line segments.

    The result has shape ``(num_points - 1, 2, 2)``: entry ``i`` holds the
    start point ``(x[i], y[i])`` and end point ``(x[i+1], y[i+1])``. This is
    the layout ``matplotlib.collections.LineCollection`` expects.
    """
    # One (x, y) row per point.
    coords = np.asarray([x, y]).T.reshape(-1, 2)
    starts, ends = coords[:-1], coords[1:]
    # Stack start/end along a new middle axis -> (n_segments, 2, 2).
    return np.stack((starts, ends), axis=1)