def plot_3dkeypoints(pred, skeleton):
    """Show predicted 3D keypoints and their skeleton in an interactive window.

    Args:
        pred: sequence of keypoints; each element is unpacked as ``(x, y, z)``
            and also indexed as ``pred[j][0]``, ``pred[j][1]``, ``pred[j][2]``.
        skeleton: iterable of ``(joint_a, joint_b)`` index pairs into ``pred``;
            one line segment is drawn per pair.

    The figure is displayed with ``plt.show()`` on the TkAgg backend.
    """
    # Imported for its side effect of making projection='3d' available on
    # older Matplotlib releases.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    import matplotlib.pyplot as plt
    import matplotlib as mpl
    from matplotlib.axes._axes import _log as matplotlib_axes_logger
    # ``matplotlib.use`` no longer accepts ``warn`` (removed in Matplotlib 3.3),
    # so passing it raises TypeError; ``force=True`` alone switches backends.
    mpl.use("TKAgg", force=True)
    # Hide the axes-module warning about ``c`` looking like a single RGB(A)
    # sequence, which the per-point ``scatter(c=...)`` calls below trigger.
    matplotlib_axes_logger.setLevel('ERROR')

    def set_axes_radius(ax, origin, radius):
        """Set every 3D axis limit to ``origin[i] - radius .. origin[i] + radius``."""
        ax.set_xlim3d([origin[0] - radius, origin[0] + radius])
        ax.set_ylim3d([origin[1] - radius, origin[1] + radius])
        ax.set_zlim3d([origin[2] - radius, origin[2] + radius])

    def set_axes_equal(ax):
        '''Make axes of 3D plot have equal scale so that spheres appear as spheres,
        cubes as cubes, etc..  This is one possible solution to Matplotlib's
        ax.set_aspect('equal') and ax.axis('equal') not working for 3D.

        Call this function before plt.show()

        Input
          ax: a matplotlib axis, e.g., as output from plt.gca().
        '''
        limits = np.array([
            ax.get_xlim3d(),
            ax.get_ylim3d(),
            ax.get_zlim3d(),
        ])

        # Centre every axis on its current midpoint and use the largest
        # half-span for all three.
        origin = np.mean(limits, axis=1)
        radius = 0.5 * np.max(np.abs(limits[:, 1] - limits[:, 0]))
        set_axes_radius(ax, origin, radius)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Sample len(pred) + 2 colours from the rainbow colormap and keep the first
    # three channels in the order (c[2], c[1], c[0]), i.e. red and blue swapped.
    # NOTE(review): the swap looks inherited from an OpenCV (BGR) helper;
    # matplotlib reads these triples as RGB. Kept to preserve existing colours.
    cmap = plt.get_cmap('rainbow')
    colors = [cmap(i) for i in np.linspace(0, 1, len(pred) + 2)]
    colors = [np.array((c[2], c[1], c[0])) for c in colors]

    for i, (x, y, z) in enumerate(pred):
        ax.scatter(x, y, z, c=colors[i], marker='o')

    for joint_a, joint_b in skeleton:
        start, end = pred[joint_a], pred[joint_b]
        ax.plot([start[0], end[0]],
                [start[1], end[1]],
                [start[2], end[2]])

    ax.set_xlabel('X Label')
    ax.set_ylabel('Z Label')
    ax.set_zlabel('Y Label')
    # The model by default predicts x and z to be the width-wise and length-wise directions of the original image,
    # with depth added into the y axis. So this rotation shows the original 2D perspective (i.e. the 2D keypoints).
    ax.view_init(azim=-90, elev=-90)
    ax.legend()
    set_axes_equal(ax)
    plt.show()
Exemple #2
0
def plot_planning_boxes(data, plot_var_list, category, save=True):
    """Boxplot planning time, with hue according to the specified category.

    Args:
        data: DataFrame with at least ``n_errors``, ``transitions`` and
            ``category`` columns; only rows with ``n_errors == 0`` are plotted.
        plot_var_list: iterable of namedtuple-like records (they provide
            ``_asdict()``) with ``color`` and ``tick_size`` fields plus a field
            named by ``category`` that is used as the legend label.
        category: column of ``data`` / record field to group by.
        save: when True, also write a PNG under ``results/plots/<cfg.DIR>/``.

    Returns early, plotting nothing, when ``data`` is empty or
    ``plot_var_list`` has no entries.
    """
    if data.empty:
        return
    plot_var_list = list(plot_var_list)
    if not plot_var_list:
        return

    # seaborn's catplot builds its own figure and takes no ``ax`` argument, so
    # harvest legend handles from a throwaway figure first: one thick line per
    # entry, coloured and labelled according to ``category``.
    plt.figure()
    palette = []
    for plot_vars in plot_var_list:
        plt.plot(0,
                 0,
                 c=plot_vars.color,
                 label=plot_vars._asdict()[category],
                 lw=3)
        palette.append(plot_vars.color)
    handles, labels = plt.gca().get_legend_handles_labels()
    # The helper figure only exists to create the handles; close it without
    # showing it (a plt.show() here used to pop up an empty plot).
    plt.close()
    # Tick spacing comes from the last record, as before.
    tick_size = plot_var_list[-1].tick_size

    # Now make the actual plot.
    plt.rcParams.update({
        'font.size': cfg.FONTSIZE,
        'figure.figsize': cfg.FIGSIZE
    })
    matplotlib_axes_logger.setLevel('ERROR')
    catplot = sns.catplot(data=data.query('n_errors==0'),
                          y=category,
                          x='transitions',
                          kind='boxen',
                          palette=palette[::-1],
                          orient='h',
                          legend=True,
                          showfliers=False)
    catplot.despine(right=False, top=False)
    plt.ylabel('Macro-action type')
    plt.gcf().set_size_inches(*cfg.FIGSIZE)
    plt.tight_layout()
    plt.xlim(cfg.XLIM)
    ax = plt.gca()
    # Inverted so the box order matches the reversed palette.
    ax.invert_yaxis()
    ax.set_yticklabels([])
    ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_size))
    ax.set_xlabel('Generated states' + autoscale_xticks(ax, dtype=int))
    plt.tight_layout()
    plt.subplots_adjust(top=.96,
                        bottom=.19,
                        right=.95,
                        left=0.09,
                        hspace=0,
                        wspace=0)
    ax.legend(handles, labels, loc='lower right')
    if save:
        out_path = 'results/plots/{}/{}_planning_time_by_{}.png'.format(
            cfg.DIR, cfg.NAME, category)
        plt.gcf().savefig(out_path, dpi=100)
    plt.show()
Exemple #3
0
def vis_3d_skeleton(kpt_3d, kpt_3d_vis, kps_lines, filename=None):
    """Show a 3D plot of a keypoint skeleton.

    Args:
        kpt_3d: array indexed as ``kpt_3d[joint, axis]`` with axis 0/1/2 = x/y/z.
        kpt_3d_vis: array indexed as ``kpt_3d_vis[joint, 0]``; a joint (and any
            line touching it) is drawn only when this value is > 0.
        kps_lines: sequence of joint-index pairs; element 0 and 1 of each entry
            are the joints to connect.
        filename: optional plot title; ``'3D vis'`` is used when None.

    Points are placed at (x, z, -y). Presumably this makes image-style y (which
    grows downwards) point up in the 3D view; confirm with the caller.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # One colour per skeleton line (plus two spare samples) from the rainbow
    # colormap, with channels reordered as (c[2], c[1], c[0]).
    # NOTE(review): the swap looks inherited from an OpenCV (BGR) helper;
    # matplotlib reads these triples as RGB.
    cmap = plt.get_cmap('rainbow')
    colors = [cmap(i) for i in np.linspace(0, 1, len(kps_lines) + 2)]
    colors = [np.array((c[2], c[1], c[0])) for c in colors]
    from matplotlib.axes._axes import _log as matplotlib_axes_logger
    matplotlib_axes_logger.setLevel('ERROR')

    def _visible(joint):
        """Return whether ``joint``'s visibility flag is positive."""
        return kpt_3d_vis[joint, 0] > 0

    def _scatter_joint(joint, color):
        """Draw a single joint marker at (x, z, -y)."""
        ax.scatter(kpt_3d[joint, 0],
                   kpt_3d[joint, 2],
                   -kpt_3d[joint, 1],
                   c=color,
                   marker='o')

    for idx, line in enumerate(kps_lines):
        i1, i2 = line[0], line[1]
        color = colors[idx]

        if _visible(i1) and _visible(i2):
            x = np.array([kpt_3d[i1, 0], kpt_3d[i2, 0]])
            y = np.array([kpt_3d[i1, 1], kpt_3d[i2, 1]])
            z = np.array([kpt_3d[i1, 2], kpt_3d[i2, 2]])
            ax.plot(x, z, -y, c=color, linewidth=2)
        if _visible(i1):
            _scatter_joint(i1, color)
        if _visible(i2):
            _scatter_joint(i2, color)

    ax.set_title('3D vis' if filename is None else filename)

    ax.set_xlabel('X Label')
    ax.set_ylabel('Z Label')
    ax.set_zlabel('Y Label')
    ax.legend()

    plt.show()
    # NOTE(review): no OpenCV window is created here, so this waitKey is
    # probably a leftover from an OpenCV-based variant; kept to preserve
    # behaviour.
    cv2.waitKey(0)
Exemple #4
0
def plot_purity(mets, omgp=False, save_me=False):
    """Plot the mean purity trace with a +/- 2 standard-deviation band.

    Args:
        mets: sequence of equal-length purity traces. They are stacked
            row-wise with ``np.vstack``, and the mean/std are taken across
            traces at each step.
        omgp: if True, ticks every 20 steps and label the x axis as iterations
            of the marginalised variational bound; otherwise ticks every 2
            steps and label it as Boltzmann updates.
        save_me: if True, save a timestamped PDF under ``../illustrations/``.
    """
    base_size = 6
    plt.rc("text", usetex=True)
    matplotlib.rcParams.update({"font.size": 14})
    matplotlib.rcParams["figure.figsize"] = [base_size, 0.4 * base_size]
    matplotlib_axes_logger.setLevel("ERROR")
    sns.set_style("ticks", {"grid.linestyle": "--"})

    line_colour = sns.color_palette("deep")[0]
    fig, ax = plt.subplots(figsize=matplotlib.rcParams["figure.figsize"])
    ax.grid(True)

    stacked = np.vstack(mets)
    avg = stacked.mean(axis=0)
    spread = stacked.std(axis=0)
    steps = np.arange(len(avg))

    ax.fill_between(steps,
                    avg - 2 * spread,
                    avg + 2 * spread,
                    alpha=0.2,
                    color=line_colour)
    ax.plot(steps, avg, lw=2, alpha=1, c=line_colour, label="Purity")
    ax.set_xlim(0, len(avg) - 1)

    # Tick spacing and x-axis wording depend on which inference scheme
    # produced the traces.
    if omgp:
        tick_base = 20.0
        x_label = "Iterations of the marginalised variational bound"
    else:
        tick_base = 2.0
        x_label = "Number of Boltzmann updates"

    ax.set_ylim(0.4, 1.0)
    ax.xaxis.set_major_locator(plticker.MultipleLocator(base=tick_base))
    ax.set_ylabel("Purity $[-]$")
    ax.set_xlabel(x_label)

    if save_me:
        stamp = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M")
        fig.savefig(
            "../illustrations/purity_data_association_results-" + stamp +
            ".pdf",
            bbox_inches="tight",
        )

    plt.show()
Exemple #5
0
def plot_raw_gp_data(X,
                     Y,
                     with_labels=False,
                     gt_allocation=None,
                     save_me=False):
    """Scatter-plot raw GP data, optionally coloured by ground-truth population.

    Args:
        X, Y: inputs and targets; each must be exactly ``numpy.ndarray``
            (subclasses fail the assertion).
        with_labels: colour point i by ``gt_allocation[i]`` and add a
            two-entry legend; otherwise draw every point as a black cross.
        gt_allocation: per-point population indices into the "muted" palette
            (presumably 0 or 1 given the legend); required with ``with_labels``.
        save_me: if True, save a timestamped PDF under ``../illustrations/``.
    """
    plt.style.use("default")

    side = 5
    plt.rc("text", usetex=True)
    matplotlib.rcParams.update({"font.size": 16})
    matplotlib.rcParams["figure.figsize"] = [side, side / sc.golden]
    matplotlib_axes_logger.setLevel("ERROR")
    sns.set_style("ticks", {"grid.linestyle": "--"})

    for arr in (X, Y):
        assert type(arr) is np.ndarray

    fig, ax = plt.subplots(figsize=matplotlib.rcParams["figure.figsize"])

    if not with_labels:
        ax.scatter(X, Y, alpha=0.5, c="k", marker="x")
    else:
        muted = sns.color_palette("muted")
        assert gt_allocation is not None
        for point, population in enumerate(gt_allocation):
            ax.scatter(X[point], Y[point], c=muted[int(population)], alpha=0.75)

        # Hand-built legend: one patch per population colour.
        legend_patches = [
            mpatches.Patch(color=muted[0], label="Population 1"),
            mpatches.Patch(color=muted[1], label="Population 2"),
        ]
        ax.legend(handles=legend_patches, framealpha=1)

    ax.set_ylabel("$y$")
    ax.set_xlabel("$x$")
    ax.grid(True)

    if save_me:
        stamp = datetime.datetime.now().strftime("%Y-%m-%d-%H:%M")
        fig.savefig("../illustrations/gp_data_association-" + stamp + ".pdf",
                    bbox_inches="tight")

    plt.show()
Exemple #6
0
def logReg_visualization(X_set, y_set, set_type, classifier=None):
    """Plot a two-feature classifier's decision regions over the data points.

    Args:
        X_set: 2-D array; columns 0 and 1 are the two plotted features.
        y_set: class labels aligned with the rows of ``X_set``.
        set_type: appended to the title (e.g. a train/test tag).
        classifier: fitted model exposing ``predict``. When omitted, the
            module-level ``classifier`` global is used, which is what this
            function always did before the parameter existed.

    Raises:
        NameError: if ``classifier`` is omitted and no module-level
            ``classifier`` exists.
    """
    from matplotlib.axes._axes import _log as matplotlib_axes_logger
    matplotlib_axes_logger.setLevel('ERROR')

    from matplotlib.colors import ListedColormap

    if classifier is None:
        # Backward compatibility with callers that relied on a global model.
        try:
            classifier = globals()["classifier"]
        except KeyError:
            raise NameError("name 'classifier' is not defined") from None

    cmap = ListedColormap(('red', 'green'))

    # Dense grid (0.01 step) covering the data with a margin of 1 on each side.
    X1, X2 = np.meshgrid(
        np.arange(start=X_set[:, 0].min() - 1, stop=X_set[:, 0].max() + 1, step=0.01),
        np.arange(start=X_set[:, 1].min() - 1, stop=X_set[:, 1].max() + 1, step=0.01))
    grid_pred = classifier.predict(np.array([X1.ravel(), X2.ravel()]).T).reshape(X1.shape)
    plt.contourf(X1, X2, grid_pred, alpha=0.75, cmap=cmap)
    plt.xlim(X1.min(), X1.max())
    plt.ylim(X2.min(), X2.max())

    for i, j in enumerate(np.unique(y_set)):
        plt.scatter(
            X_set[y_set == j, 0], X_set[y_set == j, 1],
            c=cmap(i), label=j
        )
    plt.title('Logistic Regression ' + str(set_type))
    plt.xlabel('Predictor')
    plt.ylabel('Target')
    plt.legend()
    plt.show()
def plot_2d_scatter(x, y, value, xlabel='x', ylabel='y', title='', filename='a', show=1, save=0):
    """Scatter-plot red points whose marker size scales with ``value``.

    Args:
        x, y: per-point coordinates, indexed with ``x[i]`` / ``y[i]``.
        value: per-point size factor; marker size is ``100 * value[i]``.
        xlabel, ylabel, title: axis labels and title (Times New Roman, size 20).
        filename: output path without extension; ``.jpg`` is appended.
        show: ``plt.show()`` is called when this equals 1.
        save: a 300-dpi JPEG is written when this equals 1.

    All figures are closed before returning.
    """
    import matplotlib.pyplot as plt
    from matplotlib.axes._axes import _log as matplotlib_axes_logger
    matplotlib_axes_logger.setLevel('ERROR')
    font = 'Times New Roman'
    fig = plt.figure()
    ax = fig.add_subplot(111)
    plt.subplots_adjust(bottom=0.2, right=0.8, left=0.2)
    # Index-based loop kept so x/y/value are accessed exactly as before
    # (e.g. label-based indexing for pandas Series).
    for i in range(np.array(x).shape[0]):
        ax.scatter(x[i], y[i], marker='o', s=100 * value[i], c=(1, 0, 0))
    ax.set_title(title, fontsize=20, fontfamily=font)
    ax.set_xlabel(xlabel, fontsize=20, fontfamily=font)
    ax.set_ylabel(ylabel, fontsize=20, fontfamily=font)
    ax.tick_params(labelsize=15)
    # A plain loop: the previous list comprehension was used only for its
    # side effects and built a throwaway list.
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontname(font)
    if save == 1:
        plt.savefig(filename + '.jpg', dpi=300)
    if show == 1:
        plt.show()
    plt.close('all')
Exemple #8
0
def plot_decision_regions(X, y, classifier, resolution=0.02):
    """Plot a two-feature classifier's decision surface with the class samples.

    Args:
        X: 2-D array; columns 0 and 1 are the plotted features.
        y: class labels aligned with the rows of ``X``.
        classifier: fitted model exposing ``predict``.
        resolution: step of the mesh used to evaluate the decision surface.

    Markers and colours cycle when there are more classes than the five
    predefined styles (previously a sixth class raised IndexError).
    """
    # setup marker generator and color map
    markers = ('s', 'x', 'o', '^', 'v')
    colors = ('red', 'blue', 'lightgreen', 'gray', 'cyan')
    classes = np.unique(y)
    # One colormap entry per class, cycling the base palette if needed.
    cmap = ListedColormap([colors[i % len(colors)] for i in range(len(classes))])

    # I considered using one of the predefined color gradients, but I got too flustered, so
    # thanks to "Max Kleiner" on stackoverflow.com who recommended this workaround, which lowers the error level
    # https://stackoverflow.com/questions/55109716/c-argument-looks-like-a-single-numeric-rgb-or-rgba-sequence
    matplotlib_axes_logger.setLevel('ERROR')

    # plot the decision surface on a grid with a margin of 1 around the data
    x1_min, x1_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    x2_min, x2_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, resolution),
                           np.arange(x2_min, x2_max, resolution))
    Z = classifier.predict(np.array([xx1.ravel(), xx2.ravel()]).T)
    Z = Z.reshape(xx1.shape)

    # create the shaded regions of the data
    plt.contourf(xx1, xx2, Z, alpha=0.4, cmap=cmap)

    # set min and max values of the given data set
    plt.xlim(xx1.min(), xx1.max())
    plt.ylim(xx2.min(), xx2.max())

    # plot class samples
    for idx, cl in enumerate(classes):
        plt.scatter(x=X[y == cl, 0],
                    y=X[y == cl, 1],
                    alpha=0.8,
                    c=cmap(idx),
                    marker=markers[idx % len(markers)],
                    label=cl)

    # add in the legend, so the user knows what values are being compared
    plt.legend(loc='upper left')

    plt.show()
Exemple #9
0
class FigureService(Axes, Template, Postprocessing):
    """Holds one shared matplotlib figure/axes pair plus default plot styles.

    ``fig`` and ``axes`` are class attributes set by the ``initFigure``
    classmethod, so every instance (and the class itself) sees the figure
    created by the most recent ``initFigure`` call.
    """
    fig, axes = None, None
    # NOTE(review): mutable class attribute -- one list shared by every
    # instance; confirm this sharing is intended.
    handles = []
    # DEBUG = False  # In debug mode: enable a three-second display and show all warning messages
    DEBUG = True  # In debug mode: enable a three-second display and show all warning messages
    # Outside debug mode, raise the matplotlib axes logger to ERROR so its
    # warnings are hidden. This runs once, when the class body executes.
    # (The "three-second display" mentioned above is not implemented in this
    # class body -- presumably handled elsewhere; verify.)
    if not DEBUG:
        from matplotlib.axes._axes import _log as matplotlib_axes_logger
        matplotlib_axes_logger.setLevel('ERROR')
    # Per-series alpha values (six series).
    ALPHA = [1, 1, 1, 1, 1, 1]
    # Six line colours picked from the 'tab20c' colormap's colour list.
    COLOR = [plt.get_cmap('tab20c').colors[i] for i in [0, 4, 8, 12, 16, 18]]
    # Per-series marker styles.
    MARKER = ['^', 'o', 's', '*', '+', 'D']
    # Marker colours from 'tab20c'; the first two indices (1, 5) differ from
    # COLOR's (0, 4), the remaining four are the same.
    MARKER_COLOR = [plt.get_cmap('tab20c').colors[i] for i in [1, 5, 8, 12, 16, 18]]

    def __init__(self, **kwargs):
        """Create the shared figure via ``initFigure``.

        Keyword Args:
            num: figure number passed to ``plt.figure`` (default 1).
            figsize: figure size passed to ``plt.figure`` (default ``(10, 8)``).

        NOTE(review): base-class ``__init__`` methods (e.g. ``Axes``) are not
        called here -- confirm that is intentional.
        """
        self.initFigure(num=kwargs.get("num", 1),
                        figsize=kwargs.get("figsize", (10, 8)))

    @classmethod
    def initFigure(cls, **kwargs):
        """Create a new figure and a single subplot and store both on the class.

        Keyword Args:
            num: figure number for ``plt.figure`` (None when omitted).
            figsize: figure size for ``plt.figure`` (None when omitted).
        """
        cls.fig = plt.figure(num=kwargs.get("num"),
                             figsize=kwargs.get("figsize"), )
        # plt.subplot() with no arguments adds/returns a single subplot on the
        # current figure, i.e. the one just created above.
        cls.axes = plt.subplot()
Exemple #10
0
def plot_pca(featZ,
             meta,
             group_by,
             n_dims=2,
             var_subset=None,
             control=None,
             saveDir=None,
             kde=False,
             PCs_to_keep=10,
             n_feats2print=10,
             sns_colour_palette="Set1",
             hypercolor=False,
             label_size=15,
             figsize=None,
             sub_adj=None,
             legend_loc='upper right',
             n_colours=20,
             **kwargs):
    """ Perform principal components analysis and plot the projected data.

        - featZ : feature matrix passed to sklearn PCA; its index must equal
          ``meta.index``
        - meta : metadata DataFrame aligned row-for-row with ``featZ``
        - group_by : column in metadata to group by for plotting (colours)
        - n_dims : number of principal component dimensions to plot (2 or 3)
        - var_subset : subset list of categorical names in meta[group_by]
          (defaults to all unique values)
        - control : group drawn in blue (all others gray) when there are more
          than ``n_colours`` groups and ``hypercolor`` is False
        - saveDir : directory to save PCA results; if None, plots are shown
        - kde : overlay per-group KDE contours (2-D plot, <= n_colours groups)
        - PCs_to_keep : number of PCs to project (clipped to the number of
          components PCA produced)
        - n_feats2print : number of top features influencing PCs to store
        - sns_colour_palette : seaborn palette used for group colours
        - hypercolor : with too many groups, recycle palette colours instead of
          the control/gray scheme (a ``control`` must still be given)
        - label_size : legend font size (2-D plot)
        - figsize : 2-D figure size; defaults to [9, 8]
        - sub_adj : 'bottom'/'left'/'top'/'right' for plt.subplots_adjust
          (2-D plot); defaults to 0/0/1/1
        - legend_loc : legend location (2-D plot)
        - n_colours : maximum number of groups given distinct colours
        - **kwargs : forwarded to ``DataFrame.plot`` for each 2-D group scatter

        Returns the DataFrame of projected data (columns 'PC1', 'PC2', ...)
        indexed like ``featZ``.

        Raises IOError when there are more than ``n_colours`` groups and no
        ``control`` is given, and ValueError when ``n_dims`` is not 2 or 3.
    """

    import numpy as np
    import pandas as pd
    import seaborn as sns
    from pathlib import Path
    from sklearn.decomposition import PCA
    from matplotlib import pyplot as plt
    from matplotlib import patches
    from matplotlib.axes._axes import _log as mpl_axes_logger
    # Imported for its side effect of registering the '3d' projection on
    # older Matplotlib releases.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    # Fresh defaults per call instead of shared mutable default arguments.
    if figsize is None:
        figsize = [9, 8]
    if sub_adj is None:
        sub_adj = {'bottom': 0, 'left': 0, 'top': 1, 'right': 1}

    assert (featZ.index == meta.index).all()
    if var_subset is not None:
        assert all(
            [strain in meta[group_by].unique() for strain in var_subset])
    else:
        var_subset = list(meta[group_by].unique())

    # Perform PCA on extracted features
    print("\nPerforming Principal Components Analysis (PCA)...")

    # Fit the PCA model with the normalised data
    pca = PCA()  # OPTIONAL: pca = PCA(n_components=n_dims)
    pca.fit(featZ)

    # Plot summary data from PCA: explained variance (most important features).
    # Interactive mode only when plots are shown rather than saved.
    if saveDir:
        plt.ioff()
    else:
        plt.ion()
    important_feats, fig = pcainfo(pca=pca,
                                   zscores=featZ,
                                   PC=0,
                                   n_feats2print=n_feats2print)
    if saveDir:
        # Save plot of PCA explained variance
        pca_path = Path(saveDir) / 'PCA_explained.eps'
        pca_path.parent.mkdir(exist_ok=True, parents=True)
        plt.tight_layout()
        plt.savefig(pca_path, format='eps', dpi=300)

        # Save PCA important features list
        pca_feat_path = Path(saveDir) / 'PC_top{}_features.csv'.format(
            str(n_feats2print))
        important_feats.to_csv(pca_feat_path, index=False)
    else:
        plt.show()
        plt.pause(2)

    # Project data (zscores) onto PCs
    projected = pca.transform(featZ)  # A matrix is produced
    # NB: Could also have used pca.fit_transform() OR decomposition.TruncatedSVD().fit_transform()

    # Compute explained variance ratio of component axes
    ex_variance = np.var(
        projected, axis=0)  # PCA(n_components=n_dims).fit_transform(featZ)
    ex_variance_ratio = ex_variance / np.sum(ex_variance)

    # Store the results for the first few PCs in a dataframe. Asking for more
    # columns than PCA produced used to make the DataFrame constructor fail
    # (column names would outnumber data columns), so clip.
    n_keep = min(PCs_to_keep, projected.shape[1])
    projected_df = pd.DataFrame(
        data=projected[:, :n_keep],
        columns=['PC' + str(n + 1) for n in range(n_keep)],
        index=featZ.index)

    # Create colour palette for plot loop
    gray_palette = {}
    if len(var_subset) > n_colours:
        if not control:
            raise IOError(
                'Too many groups for plot color mapping!' +
                'Please provide a control group or subset of groups (n<20) to color plot'
            )
        elif hypercolor:
            # Recycle palette colours to make up to number of groups
            print(
                "\nWARNING: Multiple groups plotted with the same colour (too many groups)"
            )
            colour_labels = sns.color_palette(sns_colour_palette,
                                              len(var_subset))
            palette = dict(zip(var_subset, colour_labels))
        else:
            # Colour the control and make the rest gray
            palette = {
                var: "blue" if var == control else "darkgray"
                for var in meta[group_by].unique()
            }

    elif len(var_subset) <= n_colours:
        # Colour strains of interest
        colour_labels = sns.color_palette(sns_colour_palette, len(var_subset))
        palette = dict(zip(var_subset, colour_labels))

        if set(var_subset) != set(meta[group_by].unique()):
            # Make the rest gray
            gray_strains = [
                var for var in meta[group_by].unique() if var not in var_subset
            ]
            gray_palette = {
                var: 'darkgray'
                for var in gray_strains if not pd.isna(var)
            }
            palette.update(gray_palette)

    plt.close('all')
    plt.style.use(CUSTOM_STYLE)
    plt.rcParams['legend.handletextpad'] = 0.5
    sns.set_style('ticks')
    if n_dims == 2:
        fig, ax = plt.subplots(figsize=figsize)

        grouped = meta.join(projected_df).groupby(group_by)
        # Plot in reverse palette order so the first groups are drawn on top.
        for key in list(palette.keys())[::-1]:
            if pd.isna(key):
                continue
            group = grouped.get_group(key)
            group.plot(ax=ax,
                       kind='scatter',
                       x='PC1',
                       y='PC2',
                       label=key,
                       color=palette[key],
                       **kwargs)

        if len(var_subset) <= n_colours and kde:
            sns.kdeplot(
                x='PC1',
                y='PC2',
                data=meta.join(projected_df),
                hue=group_by,
                palette=palette,
                fill=True,  # fill kde plot with plain colour by group
                alpha=0.25,
                thresh=0.05,
                levels=2,
                bw_method="scott",
                bw_adjust=1)

        ax.set_xlabel('Principal Component 1 (%.1f%%)' %
                      (ex_variance_ratio[0] * 100),
                      fontsize=20,
                      labelpad=12)
        ax.set_ylabel('Principal Component 2 (%.1f%%)' %
                      (ex_variance_ratio[1] * 100),
                      fontsize=20,
                      labelpad=12)

        # Construct legend from custom handles
        if len(var_subset) <= n_colours:
            plt.tight_layout()  # rect=[0, 0, 1, 1]
            handles = []
            for key in var_subset:
                handles.append(patches.Patch(color=palette[key], label=key))
            # add 'other' for all other strains (in gray)
            if set(var_subset) != set(meta[group_by].unique()) and len(
                    gray_palette.keys()) > 0:
                other_patch = patches.Patch(color='darkgray', label='other')
                handles.append(other_patch)
            ax.legend(handles=handles,
                      frameon=True,
                      loc=legend_loc,
                      fontsize=label_size,
                      handletextpad=0.2)
        elif hypercolor:
            # Too many groups for a readable legend; drop it if one was made.
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()
        else:
            control_patch = patches.Patch(color='blue', label=control)
            other_patch = patches.Patch(color='darkgray', label='other')
            ax.legend(handles=[control_patch, other_patch])
        ax.grid(False)

        # adjust subplots for figure legend
        plt.subplots_adjust(top=sub_adj['top'],
                            bottom=sub_adj['bottom'],
                            left=sub_adj['left'],
                            right=sub_adj['right'])
    elif n_dims == 3:
        fig = plt.figure(figsize=[10, 10])
        mpl_axes_logger.setLevel(
            'ERROR')  # Work-around for 3D plot colour warnings
        # Axes3D(fig) no longer attaches itself to the figure in recent
        # Matplotlib (the plot came out empty); add the 3-D axes explicitly.
        ax = fig.add_subplot(111, projection='3d')

        for g_var in var_subset:
            g_var_projected_df = projected_df[meta[group_by] == g_var]
            ax.scatter(xs=g_var_projected_df['PC1'],
                       ys=g_var_projected_df['PC2'],
                       zs=g_var_projected_df['PC3'],
                       zdir='z',
                       s=30,
                       c=palette[g_var],
                       depthshade=False)
        ax.set_xlabel('Principal Component 1 (%.1f%%)' %
                      (ex_variance_ratio[0] * 100),
                      fontsize=15,
                      labelpad=12)
        ax.set_ylabel('Principal Component 2 (%.1f%%)' %
                      (ex_variance_ratio[1] * 100),
                      fontsize=15,
                      labelpad=12)
        ax.set_zlabel('Principal Component 3 (%.1f%%)' %
                      (ex_variance_ratio[2] * 100),
                      fontsize=15,
                      labelpad=12)
        if len(var_subset) <= n_colours:
            ax.legend(var_subset, frameon=True, fontsize=12)
        ax.grid(False)
    else:
        raise ValueError("Value for 'n_dims' must be either 2 or 3")

    # Save PCA plot
    if saveDir:
        pca_path = Path(saveDir) / ('pca_by_{}'.format(group_by) +
                                    ('_colour' if hypercolor else '') +
                                    ('.png' if n_dims == 3 else '.pdf'))
        plt.savefig(pca_path,
                    format='png' if n_dims == 3 else 'pdf',
                    dpi=600 if n_dims == 3 else 300)  # rasterized=True
    else:
        # Rotate the axes and update plot
        if n_dims == 3:
            for angle in range(0, 360):
                ax.view_init(270, angle)
                plt.draw()
                plt.pause(0.0001)
        else:
            plt.show()

    return projected_df
Exemple #11
0
def reg_label(fixed, moving, label, atlaslab, coord, trace, dff, bind,
              meanimglist, reg_type, mode):
    #===============================================================================
    """
    This function builds a coordinate image from the pixel coordinates in ants space. 
    
    Inputs:
    ant_img (ants image): pre-registered, rotated ants image
    coords (np array): X x Y x Z coordinates of rotated cells
    to_del (np array): 1d vector of cell indeces that are outside image
    
    Returns:
    lab_img (ants image): ants image with each pixel labelled by cell number
    """

    import numpy as np
    import matplotlib.pyplot as plt
    import random
    from matplotlib import cm
    from matplotlib.axes._axes import _log as matplotlib_axes_logger
    matplotlib_axes_logger.setLevel('ERROR')

    if coord.shape[0] != trace.shape[0] or coord.shape[0] != dff.shape[
            0] or coord.shape[0] != bind.shape[0]:
        print(
            'Input files not the same shape - check you are loading the correct fish files'
        )
        return ()
    else:

        #Perform registration of fish image to atlas image
        warp_img = ants.registration(fixed, moving, type_of_transform=reg_type)

        #Inspect registration
        fishplot(fixed, warp_img['warpedmovout'], orient='axial', al=0.7)

        #Rotate coordinates
        rot_coords, to_del = rotate_coords(coord, meanimglist)

        #Build coordinate image from pixel coordinates in ants space
        coord_img = lab_Img(moving, rot_coords, to_del)

        #Apply transformation to coordinate image
        warp_coord_img = ants.apply_transforms(fixed,
                                               coord_img,
                                               warp_img['fwdtransforms'],
                                               interpolator='nearestNeighbor')

        #Map warped pixel coordinates to old pixel coordinates
        reg_coord = match_pix2cells(rot_coords, warp_coord_img)

        #Create final coordinate and trace arrays - remove cells outside of image
        loc = np.where(reg_coord[:, 0] == 0)[0]
        fin_coord = np.delete(reg_coord, loc, 0)
        fin_trace = np.delete(trace, loc, 0)
        fin_dff = np.delete(dff, loc, 0)
        fin_bind = np.delete(bind, loc, 0)

        #label coordinates
        new_x = fin_coord[:, 0] // 2
        new_y = fin_coord[:, 1] // 2
        new_z = fin_coord[:, 2] // 2
        lab_xyz = (np.column_stack((new_x, new_y, new_z))).astype(int)
        coarse_reg = list(range(lab_xyz.shape[0]))
        gran_reg = list(range(lab_xyz.shape[0]))

        #Label each cell according to atlas labels
        for i in range(len(gran_reg)):
            curr_val = int(
                label[lab_xyz[i][1], lab_xyz[i][0],
                      lab_xyz[i][2]])  #Transpose to match with label ants img
            gran_reg[i] = np.array(atlaslab[1])[curr_val]
            coarse_reg[i] = np.array(atlaslab[2])[curr_val]

        lab_coord = np.column_stack((fin_coord, gran_reg, coarse_reg))

        #Visualise ouputs
        if mode == 'check':
            print(meanimglist[0])
            print('Rotate suite2p coords onto pre-reg fish brain')
            plane_num = 6
            plt.figure(figsize=(10, 10))
            plt.matshow(moving[:, :, plane_num])
            ploc = np.where(rot_coords[:, 2] == plane_num)
            plt.scatter(rot_coords[:, 0][ploc],
                        rot_coords[:, 1][ploc],
                        s=0.8,
                        c='red',
                        alpha=1)
            plt.show()

            print(meanimglist[0])
            print(
                'Build coordinate image and stack in z over pre-reg fish brain'
            )
            pre_stackplot = np.zeros((coord_img[:, :, 0].shape))
            for i in range(10):
                pre_stackplot += coord_img[:, :, i]

            #Check that labelled image has been built correctly
            fig, axarr = plt.subplots(figsize=(5, 5))
            axarr.matshow(moving[:, :, plane_num])
            axarr.scatter(
                np.where(pre_stackplot > 0)[1],
                np.where(pre_stackplot > 0)[0],
                s=0.1,
                c='red'
            )  #Antsimage and suite2p coordinates are transpose of eachother - Transpose to match
            plt.show()

            print(meanimglist[0])
            print(
                'Plot newly warped cell coordinates over atlas (left) and warped fish image (right)'
            )

            #Check that all neurons are overlaid correctly over brain - postregistration
            xnum = 180
            curr_warped = warp_img['warpedmovout']
            fig, axarr = plt.subplots(1, 2, figsize=(10, 10))
            axarr[0].matshow(fixed[:, :, xnum])
            axarr[0].scatter(reg_coord[:, 0],
                             reg_coord[:, 1],
                             s=3,
                             alpha=0.2,
                             c='red')
            axarr[1].matshow(curr_warped[:, :, xnum])
            axarr[1].scatter(reg_coord[:, 0],
                             reg_coord[:, 1],
                             s=2,
                             alpha=0.1,
                             c='red')
            plt.show()

            print(meanimglist[0])
            print(
                'Plot newly warped cell coordinates over atlas/warped fish brain, by individual planes'
            )
            #Check neurons plane by plane
            curr_warped = warp_img['warpedmovout']
            curr_points = reg_coord.astype(int)
            znumlist = np.arange(100, 200, 20)
            xnumlist = np.arange(200, 300, 20)
            for num in range(len(znumlist)):
                fig, axarr = plt.subplots(1, 4, figsize=(15, 15))
                axarr[0].matshow(fixed[:, :, znumlist[num]])
                axarr[0].scatter(
                    curr_points[:, 0][curr_points[:, 2] == znumlist[num]],
                    curr_points[:, 1][curr_points[:, 2] == znumlist[num]],
                    s=2,
                    c='red')
                axarr[1].matshow(curr_warped[:, :, znumlist[num]])
                axarr[1].scatter(
                    curr_points[:, 0][curr_points[:, 2] == znumlist[num]],
                    curr_points[:, 1][curr_points[:, 2] == znumlist[num]],
                    s=1,
                    c='red')
                axarr[2].matshow(fixed[:, xnumlist[num], :])
                axarr[2].scatter(
                    curr_points[:, 2][curr_points[:, 0] == xnumlist[num]],
                    curr_points[:, 1][curr_points[:, 0] == xnumlist[num]],
                    s=3,
                    c='red')
                axarr[3].matshow(curr_warped[:, xnumlist[num], :])
                axarr[3].scatter(
                    curr_points[:, 2][curr_points[:, 0] == xnumlist[num]],
                    curr_points[:, 1][curr_points[:, 0] == xnumlist[num]],
                    s=2,
                    c='red')
                plt.show()

            print(meanimglist[0])
            print(
                'Plot same cells for pre-reg (left) and post-reg (right) coordinates - are cell positions retained?'
            )
            #Check that cell ids are correctly retained

            old_points = rot_coords  #coord

            n_cells = 10
            xnum = 150

            fig, axarr = plt.subplots(1, 2, figsize=(10, 7))
            axarr[0].scatter(old_points[:, 0],
                             old_points[:, 1],
                             s=2,
                             c='k',
                             alpha=0.1)
            axarr[1].matshow(fixed[:, :, xnum])
            axarr[1].scatter(reg_coord[:, 0],
                             reg_coord[:, 1],
                             s=2,
                             alpha=0.1,
                             c='k')
            for i in range(n_cells):
                colors = cm.Spectral(np.linspace(0, 1, n_cells))
                choose = random.randint(1, len(old_points) + 1)
                axarr[0].scatter(old_points[:, 0][choose],
                                 old_points[:, 1][choose],
                                 s=40,
                                 c=colors[i],
                                 alpha=1)
                axarr[1].scatter(reg_coord[:, 0][choose],
                                 reg_coord[:, 1][choose],
                                 s=20,
                                 c=colors[i],
                                 alpha=1)
            plt.show()

            print(meanimglist[0])
            print(
                'Plot cells from the pre-reg matrix and traces (green, left) and post-reg matrix and traces (orange, right) - are traces mapped onto correct cells?'
            )
            old_points = coord
            n_cells = 5
            for i in range(n_cells):
                choose = random.randint(1, len(old_points) + 1)
                if sum(choose == loc) == 0:
                    to_min = sum(choose >= loc)
                    old_val = choose
                    new_val = old_val - to_min
                    old_trace = trace[old_val]
                    new_trace = fin_trace[new_val]
                    print(old_val)
                    fig, axarr = plt.subplots(figsize=(8, 1))
                    plt.plot(old_trace, c='green')
                    plt.show()
                    print(new_val)
                    fig, axarr = plt.subplots(figsize=(8, 1))
                    plt.plot(new_trace, c='orangered')
                    plt.show()

                    fig, axarr = plt.subplots(1, 2, figsize=(10, 7))
                    axarr[0].scatter(old_points[:, 0],
                                     old_points[:, 1],
                                     s=2,
                                     c='k',
                                     alpha=0.1)
                    axarr[1].matshow(fixed[:, :, xnum])
                    axarr[1].scatter(fin_coord[:, 0],
                                     fin_coord[:, 1],
                                     s=2,
                                     alpha=0.1,
                                     c='k')
                    axarr[0].scatter(old_points[:, 0][old_val],
                                     old_points[:, 1][old_val],
                                     s=40,
                                     c='green',
                                     alpha=1)
                    axarr[1].scatter(fin_coord[:, 0][new_val],
                                     fin_coord[:, 1][new_val],
                                     s=20,
                                     c='orangered',
                                     alpha=1)
                    plt.show()

            #Check that all neurons are overlaid correctly over brain - postregistration
            xnum = 150
            curr_warped = warp_img['warpedmovout']
            fig, axarr = plt.subplots(1, 2, figsize=(10, 10))
            axarr[0].matshow(label[:, :, np.int(xnum / 2)])
            axarr[0].scatter(lab_xyz[:, 0],
                             lab_xyz[:, 1],
                             s=2,
                             alpha=0.2,
                             c='red')
            axarr[1].matshow(label[:, :, np.int(xnum / 2)])
            axarr[1].scatter(
                lab_xyz[:, 0][np.array(coarse_reg) == 'Telencephalon'],
                lab_xyz[:, 1][np.array(coarse_reg) == 'Telencephalon'],
                s=2,
                alpha=1,
                c='orange')
            axarr[1].scatter(
                lab_xyz[:, 0][np.array(coarse_reg) == 'Diencephalon'],
                lab_xyz[:, 1][np.array(coarse_reg) == 'Diencephalon'],
                s=2,
                alpha=1,
                c='green')
            axarr[1].scatter(lab_xyz[:, 0][np.array(coarse_reg) == 'Midbrain'],
                             lab_xyz[:, 1][np.array(coarse_reg) == 'Midbrain'],
                             s=2,
                             alpha=1,
                             c='cyan')
            axarr[1].scatter(lab_xyz[:,
                                     0][np.array(coarse_reg) == 'Hindbrain'],
                             lab_xyz[:,
                                     1][np.array(coarse_reg) == 'Hindbrain'],
                             s=2,
                             alpha=1,
                             c='violet')
            axarr[1].scatter(lab_xyz[:, 0][np.array(coarse_reg) == 'nan'],
                             lab_xyz[:, 1][np.array(coarse_reg) == 'nan'],
                             s=2,
                             alpha=1,
                             c='red')
            axarr[1].scatter(lab_xyz[:,
                                     0][np.array(coarse_reg) == 'Peripheral'],
                             lab_xyz[:,
                                     1][np.array(coarse_reg) == 'Peripheral'],
                             s=2,
                             alpha=1,
                             c='black')
            axarr[1].scatter(lab_xyz[:,
                                     0][np.array(coarse_reg) == 'Unspecified'],
                             lab_xyz[:,
                                     1][np.array(coarse_reg) == 'Unspecified'],
                             s=2,
                             alpha=1,
                             c='black')
            plt.show()

            #Check that all neurons are overlaid correctly over brain - postregistration
            xnum = 200
            curr_warped = warp_img['warpedmovout']
            fig, axarr = plt.subplots(1, 2, figsize=(10, 10))

            axarr[0].matshow(label[:, np.int(xnum / 2), :])
            axarr[0].scatter(lab_xyz[:, 2],
                             lab_xyz[:, 1],
                             s=2,
                             alpha=0.2,
                             c='red')
            axarr[1].matshow(label[:, np.int(xnum / 2), :])
            axarr[1].scatter(
                lab_xyz[:, 2][np.array(coarse_reg) == 'Telencephalon'],
                lab_xyz[:, 1][np.array(coarse_reg) == 'Telencephalon'],
                s=2,
                alpha=1,
                c='orange')
            axarr[1].scatter(
                lab_xyz[:, 2][np.array(coarse_reg) == 'Diencephalon'],
                lab_xyz[:, 1][np.array(coarse_reg) == 'Diencephalon'],
                s=2,
                alpha=1,
                c='green')
            axarr[1].scatter(lab_xyz[:, 2][np.array(coarse_reg) == 'Midbrain'],
                             lab_xyz[:, 1][np.array(coarse_reg) == 'Midbrain'],
                             s=2,
                             alpha=1,
                             c='cyan')
            axarr[1].scatter(lab_xyz[:,
                                     2][np.array(coarse_reg) == 'Hindbrain'],
                             lab_xyz[:,
                                     1][np.array(coarse_reg) == 'Hindbrain'],
                             s=2,
                             alpha=1,
                             c='violet')
            plt.show()

        return (fin_coord, lab_coord, fin_trace, fin_bind, fin_dff)
Exemple #12
0
def visualize_feature_importance(
        models,
        columns: Union[None, List[str]] = None,
        plot_type='bar',
        ax: Union[None, plt.Axes] = None,
        top_n: Union[None, int] = None,
        feature_extractor: Union[None, Callable[[BaseEstimator],
                                                np.ndarray]] = None,
        **plot_kwgs) -> Tuple[plt.Figure, plt.Axes, pd.DataFrame]:
    """
    Plot feature importance from one or more learned models.

    Currently the following models are supported.
        * scikit-learn's
            * linear model
            * random forest
        * xgboost (sklearn interface)
        * lightgbm (sklearn interface)

    Importances are retrieved with `extract_importance` unless
    `feature_extractor` is given. A `PrePostProcessModel` is unwrapped to its
    `fitted_model_` before extraction.

    Args:
        models:
            List of trained models (one row group per model in the result).
        columns:
            List of feature names. If None, integer positions are used.
        plot_type:
            Importance plot style: `"bar"` calls seaborn.barplot and
            `"boxen"` calls seaborn.boxenplot.
        ax:
            matplotlib Axes to draw on. If None, a new figure and axes are
            created.
        top_n:
            When an int is given, plot only the top n features ranked by
            importance summed across models.
        feature_extractor:
            Overrides the importance-extraction function, e.g. for a model
            that is not supported out of the box. Must take a model and
            return an array-like of importances (it is flattened with
            reshape(-1)).
        **plot_kwgs:
            Extra keyword arguments passed to seaborn.boxenplot / barplot;
            they override the defaults set here.

    Returns:
        (fig, ax, importance_df). `fig` is None when `ax` was passed in.
        `importance_df` has the columns 'feature_importance', 'column' and
        'fold' (1-based index of the model in `models`).

    Raises:
        ValueError: if plot_type is neither "bar" nor "boxen".
    """

    # set matplotlib loglevel 'ERROR' (avoid 'c' argument looks like a single numeric RGB or RGBA sequence)
    from matplotlib.axes._axes import _log as matplotlib_axes_logger
    matplotlib_axes_logger.setLevel('ERROR')

    # Validate up front so an invalid value does not leave an orphan figure open.
    if plot_type not in ('boxen', 'bar'):
        raise ValueError(
            'plot_type must be in boxen or bar. Actually, {}'.format(
                plot_type))

    if feature_extractor is None:
        feature_extractor = extract_importance

    frames = []
    for i, model in enumerate(models):
        # Unwrap the pre/post-processing wrapper to reach the fitted estimator.
        if isinstance(model, PrePostProcessModel):
            clf = model.fitted_model_
        else:
            clf = model

        importance = feature_extractor(clf)
        _df = pd.DataFrame()
        _df['feature_importance'] = np.array(importance).reshape(-1)
        _df['column'] = columns if columns is not None else range(len(_df))
        _df['fold'] = i + 1
        frames.append(_df)

    # Concatenate once rather than growing a DataFrame inside the loop
    # (which copies all previous rows on every iteration).
    if frames:
        importance_df = pd.concat(frames, axis=0, ignore_index=True)
    else:
        importance_df = pd.DataFrame()

    # Rank features by their importance summed over all models.
    order = (importance_df.groupby('column')['feature_importance']
             .sum()
             .sort_values(ascending=False)
             .index)

    if isinstance(top_n, int):
        order = order[:top_n]

    if ax is None:
        # Grow the figure height with the number of plotted features.
        h = max(len(order) * .2, 5)
        fig = plt.figure(figsize=(7, h))
        ax = fig.add_subplot(111)
    else:
        fig = None
    params = {
        'data': importance_df,
        'x': 'feature_importance',
        'y': 'column',
        'order': order,
        'ax': ax,
        'orient': 'h',
        'palette': 'viridis'
    }
    params.update(plot_kwgs)

    if plot_type == 'boxen':
        sns.boxenplot(**params)
    else:
        sns.barplot(**params)
    ax.tick_params(axis='x', rotation=90)

    if fig is not None:
        fig.tight_layout()
    return fig, ax, importance_df
Exemple #13
0
else: 
    ax.set_title('All features 2-Component PCA', fontsize=20)
plt.tight_layout(rect=[0.04, 0, 0.84, 0.96])
ax.legend(bacterial_strains, frameon=False, loc=(1, 0.1), fontsize=15)
ax.grid()

# Save scatterplot of first 2 PCs
plotpath = os.path.join(plotroot, 'PCA_2PCs.eps')
savefig(plotpath, tight_layout=True, tellme=True, saveFormat='eps')

plt.show()
plt.pause(2)

#%% 3D Plot - first 3 PCs - ALL FOODS

# Work-around for 3D plot colour warnings
mpl_axes_logger.setLevel('ERROR')

# Plot first 3 principal components
plt.close('all')
plt.rc('xtick', labelsize=12)
plt.rc('ytick', labelsize=12)
fig = plt.figure(figsize=[10, 10])
# Axes3D(fig) no longer attaches itself to the figure in recent matplotlib
# (deprecated in 3.4), which leaves the figure empty; add_subplot registers it.
ax = fig.add_subplot(111, projection='3d')

# Create colour palette for plot loop
palette = itertools.cycle(sns.color_palette("gist_rainbow", len(bacterial_strains)))

# One scatter call per food type, each with the next palette colour.
# NOTE(review): `depthshade` must be defined earlier in the script — confirm.
for food in bacterial_strains:
    food_projected_df = projected_df[projected_df['food_type'].str.upper() == food]
    ax.scatter(xs=food_projected_df['PC1'], ys=food_projected_df['PC2'], zs=food_projected_df['PC3'],
               zdir='z', s=50, c=next(palette), depthshade=depthshade)
Exemple #14
0
    def print_net(self):
        """Render an animated sequence of frames of the rail network with moving trains.

        On the first call the networkx graph is generated from
        ``self.network`` and laid out with ``nx.spring_layout``. One extra node
        per train, named ``"<line name>(<train number>)"``, is added at
        position (0, 0).

        Every call then renders 20 frames. In frame ``k``, every train that has
        started (``minutes >= start_minute``) and is not waiting is placed at
        ``interpolate(current_station_pos, target_station_pos, k / 20)``.

        The limits stored in ``self.xlim`` / ``self.ylim`` are re-applied to the
        axes each frame and read back after drawing, so the user's zoom and pan
        persist. When the visible area is below 0.25, node labels are drawn.
        Only nodes inside the visible area are drawn; all nodes are drawn when
        none are visible or no limits have been stored yet.
        """
        import matplotlib.pyplot as plt
        from matplotlib.axes._axes import _log as matplotlib_axes_logger
        matplotlib_axes_logger.setLevel('ERROR')
        import networkx as nx

        def node_name(tr):
            # Key used for a train in both the graph and the positions map.
            return tr.line.name + "(" + str(tr.number) + ")"

        if self.graph is None:
            self.graph = self.network.generate_networkx()
            self.positions = nx.spring_layout(self.graph)
            for tr in self.trains:
                self.graph.add_node(node_name(tr), label=tr.line.name)
                self.positions[node_name(tr)] = (0, 0)

        colors_of_trains = self.get_colors_of_trains()
        for frame in range(20):
            # Move each active train a fraction frame/20 along its current edge.
            for tr in self.trains:
                active = tr.minutes >= tr.start_minute and not tr.waiting
                if not active:
                    continue
                origin = self.positions[tr.current_station.name]
                target = self.positions[tr.target_station.name]
                self.positions[node_name(tr)] = interpolate(
                    origin, target, frame * 1.0 / 20.0)

            axes = plt.gca()
            plt.cla()
            with_labels = False
            visible_nodes = None
            if self.xlim and self.ylim:
                # Re-apply the user's zoom/pan from the previous frame.
                axes.set_xlim(self.xlim)
                axes.set_ylim(self.ylim)
                area = (self.xlim[1] - self.xlim[0]) * (self.ylim[1] - self.ylim[0])
                if area < 0.25:
                    with_labels = True
                visible_nodes = [
                    name for name, pos in self.positions.items()
                    if is_visible(pos, self.xlim, self.ylim)
                ]

            draw_kwargs = {'with_labels': with_labels}
            if visible_nodes:
                draw_kwargs['nodelist'] = visible_nodes
            nx.draw_networkx(self.graph, self.positions, **draw_kwargs)

            # Redraw each node group from get_colors_of_trains() in its colour.
            for color, stations in colors_of_trains.items():
                nx.draw_networkx_nodes(self.graph, self.positions,
                                       nodelist=stations, node_color=color)

            plt.title('Graph Representation of Rail Map', size=15)
            plt.draw()
            plt.pause(0.001)
            # Remember the (possibly user-modified) limits for the next frame.
            self.ylim = axes.get_ylim()
            self.xlim = axes.get_xlim()
    bbox[0] = c_x - bbox[2] / 2.
    bbox[1] = c_y - bbox[3] / 2.
    return bbox


def pixel2cam(pixel_coord, f, c):
    """Back-project pixel coordinates with depth into camera coordinates.

    Each row of ``pixel_coord`` is ``(u, v, z)``. The focal lengths are
    ``f = (fx, fy)`` and the principal point is ``c = (cx, cy)``. Each output
    row is computed as:

        X = (u - cx) / fx * z
        Y = (v - cy) / fy * z
        Z = z

    Returns an array of shape (N, 3) whose rows are ``(X, Y, Z)``.
    """
    depth = pixel_coord[:, 2]
    cam_x = (pixel_coord[:, 0] - c[0]) / f[0] * depth
    cam_y = (pixel_coord[:, 1] - c[1]) / f[1] * depth
    return np.stack((cam_x, cam_y, depth), axis=1)


from matplotlib.axes._axes import _log as matplotlib_axes_logger  # provisional...

# Raise the threshold of matplotlib's axes logger to ERROR so warnings emitted
# through it are hidden. This changes that shared logger for every caller, not
# only this module.
# NOTE(review): `matplotlib.axes._axes._log` is a private name and may change
# between matplotlib versions.
matplotlib_axes_logger.setLevel('ERROR')  # provisional...


def vis_keypoints(img, kps, kps_lines, kp_thresh=0.4, alpha=1):

    # Convert from plt 0-1 RGBA colors to 0-255 BGR colors for opencv.
    cmap = plt.get_cmap('rainbow')
    colors = [cmap(i) for i in np.linspace(0, 1, len(kps_lines) + 2)]
    colors = [(c[2] * 255, c[1] * 255, c[0] * 255) for c in colors]

    # Perform the drawing on a copy of the image, to allow for blending.
    kp_mask = np.copy(img)

    # Draw the keypoints.
    for l in range(len(kps_lines)):
        i1 = kps_lines[l][0]
Exemple #16
0
from time import time
import wntr
import numpy as np
import math
import pandas as pd
import csv
import matplotlib.pyplot as plt
from matplotlib.axes._axes import _log as matplotlib_axes_logger

from clustering import KMeans
from clustering import Node

matplotlib_axes_logger.setLevel('ERROR')  # only show ERROR-level messages from matplotlib's axes logger
# From here on, do not warn on NumPy floating-point division by zero or invalid
# operations (e.g. x/0, 0/0). Such results still become inf/nan silently.
np.seterr(divide='ignore', invalid='ignore')


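# Run a hydraulic simulation to obtain node pressure data, and return a numpy array containing each node's coordinates and pressure information.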
def sim_data(inp_file='data/Net3/Net3.inp'):
    print('start')
    start = time()
    wn = wntr.network.WaterNetworkModel(inp_file)
    sim = wntr.sim.WNTRSimulator(wn, mode='PDD')

    result = sim.run_sim()
    end = time()
    print('end')
    run_time = end - start
    print('Run time for no burst:' + "%.2f" % run_time + 's')

    pressure_all_time = result.node['pressure']
    pressure_mean = pressure_all_time.mean()  # 各时段压力值平均
Exemple #17
0
from matplotlib import cm, ticker
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib.colors import ListedColormap, LogNorm
from matplotlib.axes._axes import _log as matplotlib_axes_logger
from mpl_toolkits.mplot3d import Axes3D  # <--- This is important for 3d plotting
import seaborn as sns
import numpy as np
import random
import decimal
from functions import Rosenbrock, dRosenbrock, Rastrigin, dRastrigin, Paraboloid, dParaboloid, Easom, Eggholder, dParaboloid
import time
from PSO_Gradient_Descent import Gradient_Descent
from tqdm import tqdm

# Hide warnings emitted through matplotlib's axes logger (e.g. colour-argument
# warnings); only ERROR-level messages remain.
matplotlib_axes_logger.setLevel('ERROR')  ## To suppress Warning related to color

#### Define functions from functions.py #####
# Menu key -> (objective function, derivative). Presumably a derivative of 0
# marks a function with no analytic gradient imported here — confirm.
functs = {
    # 'Gradient' : Gradient_Descent,
    '0': (Paraboloid, dParaboloid),
    '1': (Rastrigin, dRastrigin),
    '2': (Rosenbrock, dRosenbrock),
    '3': (Easom, 0),
    '4': (Eggholder, 0)
}

######### Settings params
#### Ask User for input ####
# The literal has no placeholders, so the original `.format()` call was a no-op.
print("Press [Enter] for default option")
no_particles = int(input("Number of particles - (default 20): \n>>")
Exemple #18
0
Licensed under GNU Lesser General Public License v3.0
"""

import os
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
from matplotlib.axes._axes import _log as matplotlib_axes_logger
from tqdm import tqdm

from deeplabcut.utils import auxiliaryfunctions
from deeplabcut.utils import auxiliaryfunctions_3d

# Hide warnings emitted through matplotlib's (private) axes logger; only
# ERROR-level messages remain.
matplotlib_axes_logger.setLevel("ERROR")


def triangulate(
    config,
    video_path,
    videotype="avi",
    filterpredictions=True,
    filtertype="median",
    gputouse=None,
    destfolder=None,
    save_as_csv=False,
):
    """
    This function triangulates the detected DLC-keypoints from the two camera views
    using the camera matrices (derived from calibration) to calculate 3D predictions.
#!/usr/bin/env python
# coding: utf-8

# In[22]:

import matplotlib.pyplot as plt
from matplotlib import style
import numpy as np
from scipy.stats import norm
# Seed NumPy's global random generator so np.random draws are reproducible.
np.random.seed(0)
import pickle
from matplotlib.axes._axes import _log as matplotlib_axes_logger
# Only ERROR-level messages from matplotlib's axes logger are shown.
matplotlib_axes_logger.setLevel('ERROR')
import glob
import errno
from scipy.stats import multivariate_normal

# In[23]:


def load(name):
    """Load and return the pickled object stored in the file at path ``name``.

    The file is closed even if unpickling raises.

    Security: ``pickle.load`` can execute arbitrary code while unpickling;
    only load files from trusted sources.
    """
    with open(name, 'rb') as file:
        return pickle.load(file)


# In[24]:


def save(data, name):
Exemple #20
0
def plotPCA(projected_df, grouping_variable, var_subset=None, savepath=None,
            title=None, n_component_axes=2, rotate=False):
    """ A function to plot PCA projections and colour by a given categorical
        variable (eg. grouping_variable = 'food type').
        Optionally, a subset of the data can be plotted for the grouping variable
        (eg. var_subset=[list of foods]).

    Args:
        projected_df: DataFrame with columns 'PC1', 'PC2' (and 'PC3' for a 3D
            plot) plus the `grouping_variable` column.
        grouping_variable: column whose values define the colour groups.
        var_subset: group values to plot; if None or empty, every unique
            value of `grouping_variable` is plotted.
        savepath: if given, the figure is saved with `savefig` (EPS for the 2D
            plot, PNG for the 3D plot).
        title: optional axes title.
        n_component_axes: 2 or 3 — number of principal components to plot.
            Any other value only prints a message.
        rotate: 3D only — if True, spin the view through 360 degrees instead
            of calling plt.show().
    """

    # TODO: Plot features that have greatest influence on PCA (eg. PC1)

    # `is None` rather than `== None`: an array-like var_subset would otherwise
    # be compared element-wise, and the truth test would raise.
    if var_subset is None or len(var_subset) == 0:
        var_subset = list(projected_df[grouping_variable].unique())

    plt.close('all')

    # OPTION 1: Plot PCA - 2 principal components
    if n_component_axes == 2:
        plt.rc('xtick', labelsize=15)
        plt.rc('ytick', labelsize=15)
        sns.set_style("whitegrid")
        fig, ax = plt.subplots(figsize=[10, 10])

        # Create colour palette for plot loop
        palette = itertools.cycle(
            sns.color_palette("gist_rainbow", len(var_subset)))

        for g_var in var_subset:
            g_var_projected_df = projected_df[projected_df[grouping_variable]
                                              == g_var]
            # x/y are keyword-only in seaborn >= 0.12; draw on the axes created above.
            sns.scatterplot(x=g_var_projected_df['PC1'],
                            y=g_var_projected_df['PC2'],
                            color=next(palette),
                            s=50,
                            ax=ax)
        ax.set_xlabel('Principal Component 1', fontsize=20, labelpad=12)
        ax.set_ylabel('Principal Component 2', fontsize=20, labelpad=12)
        if title:
            ax.set_title(
                title,
                fontsize=20)  # title = 'Top256 features 2-Component PCA'
        # Only draw a legend when it stays readable.
        if len(var_subset) <= 15:
            plt.tight_layout(rect=[0.04, 0, 0.84, 0.96])
            ax.legend(var_subset, frameon=False, loc=(1, 0.85), fontsize=15)
        ax.grid()

        # Save PCA scatterplot of first 2 PCs
        if savepath:
            savefig(savepath,
                    tight_layout=False,
                    tellme=True,
                    saveFormat='eps')  # rasterized=True

        plt.show()
        plt.pause(2)

    # OPTION 2: Plot PCA - 3 principal components
    elif n_component_axes == 3:
        # Work-around for 3D plot colour warnings
        mpl_axes_logger.setLevel('ERROR')

        plt.rc('xtick', labelsize=12)
        plt.rc('ytick', labelsize=12)
        fig = plt.figure(figsize=[10, 10])
        # Axes3D(fig) no longer attaches itself to the figure in recent
        # matplotlib (deprecated in 3.4), which leaves the figure empty.
        ax = fig.add_subplot(111, projection='3d')

        # Create colour palette for plot loop
        palette = itertools.cycle(
            sns.color_palette("gist_rainbow", len(var_subset)))

        for g_var in var_subset:
            g_var_projected_df = projected_df[projected_df[grouping_variable]
                                              == g_var]
            ax.scatter(xs=g_var_projected_df['PC1'], ys=g_var_projected_df['PC2'], zs=g_var_projected_df['PC3'],
                       zdir='z', s=30, c=next(palette), depthshade=False)
        ax.set_xlabel('Principal Component 1', fontsize=15, labelpad=12)
        ax.set_ylabel('Principal Component 2', fontsize=15, labelpad=12)
        ax.set_zlabel('Principal Component 3', fontsize=15, labelpad=12)
        if title:
            ax.set_title(title, fontsize=20)
        if len(var_subset) <= 15:
            ax.legend(var_subset, frameon=False, fontsize=12)
        ax.grid()

        # Save PCA scatterplot of first 3 PCs
        if savepath:
            savefig(savepath,
                    tight_layout=False,
                    tellme=True,
                    saveFormat='png')  # rasterized=True

        # Rotate the axes and update plot
        if rotate:
            for angle in range(0, 360):
                ax.view_init(270, angle)
                plt.draw()
                plt.pause(0.001)
        else:
            plt.show()
            plt.pause(2)
    else:
        print("Please select from n_component_axes = 2 or 3.")