예제 #1
0
def plot_demo_factor_dist(results,
                          c,
                          figsize=12,
                          dpi=300,
                          ext='png',
                          plot_dir=None):
    """ Plots histograms of the DA factor scores

    Args:
        results: a results object exposing a DA attribute
        c: the number of components, passed to DA.get_scores and used in
            the output filename
        figsize: scalar - the height of the figure; the width is 1.3x larger
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If falsy, do not save
    """
    demographics = results.DA
    # the x-label reports the mean of the 'Sex' column as a percentage
    # (presumably Sex is coded 1 for female - confirm against the dataset)
    sex_column = demographics.raw_data['Sex']
    female_label = "{0:0.1f}%".format(np.mean(sex_column) * 100)
    factor_scores = demographics.get_scores(c)

    hist_axes = factor_scores.hist(bins=40,
                                   grid=False,
                                   figsize=(figsize * 1.3, figsize))
    hist_axes = hist_axes.flatten()
    fig = plt.gcf()
    # strip the right/top box lines from every histogram
    for hist_ax in hist_axes:
        for side in ('right', 'top'):
            hist_ax.spines[side].set_visible(False)
    summary = 'N: %s, Female Percent: %s' % (len(factor_scores), female_label)
    hist_axes[-1].set_xlabel(summary, labelpad=20)

    if not plot_dir:
        return
    out_name = 'factor_distributions_DA%s.%s' % (c, ext)
    save_kws = {'bbox_inches': 'tight', 'dpi': dpi}
    save_figure(fig, path.join(plot_dir, out_name), save_kws)
    plt.close()
def plot_polar_factors(results,
                       c,
                       color_by_group=True,
                       rotate='oblimin',
                       dpi=300,
                       ext='png',
                       plot_dir=None):
    """ Plots factor analytic results as polar plots
    
    Args:
        results: a dimensional structure results object
        c: the number of components to use
        color_by_group: whether to color the polar plot by factor groups. Groups
            are defined by the factor each measurement loads most highly on
        rotate: the rotation passed to EFA.get_loading
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    EFA = results.EFA
    loadings = EFA.get_loading(c, rotate=rotate)
    groups = get_factor_groups(loadings)
    # colors=None defers the color choice to visualize_factors (presumably
    # one color per group); otherwise every factor is drawn in blue
    if color_by_group:
        colors = None
    else:
        colors = ['b'] * len(loadings.columns)
    # plot polar plot factor visualization for metric loadings
    fig = visualize_factors(loadings, n_rows=2, groups=groups, colors=colors)
    if plot_dir is not None:
        filename = 'factor_polar_EFA%s.%s' % (c, ext)
        save_figure(fig, path.join(plot_dir, filename), {
            'bbox_inches': 'tight',
            'dpi': dpi
        })
        plt.close()
def plot_GAM(gams, X, Y, size=4, dpi=300, ext='png', filename=None):
    """ Plots the fitted terms of a set of GAMs as a grid of subplots

    One row of subplots is drawn per entry of gams and one column per
    column of X.

    Args:
        gams: dictionary mapping a name to a dictionary holding a 'model'
            entry (exposing statistics_['p_values']) and a 'scores_cv' entry
        X: dataframe of predictors; its columns provide the subplot titles
        Y: 2D object of targets; its number of columns sets the number of rows
        size: scalar - the width and height of each subplot
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        filename: the output path without extension. If none, do not save
    """
    cols = X.shape[1]
    rows = Y.shape[1]
    colors = sns.color_palette(n_colors=rows)
    # figsize is passed directly so the global rcParams are left untouched;
    # squeeze=False keeps a 2D axes array even with a single row or column
    fig, mat_axs = plt.subplots(rows, cols,
                                figsize=(cols*size, rows*size),
                                squeeze=False)
    titles = X.columns
    for j, (name, out) in enumerate(gams.items()):
        axs = mat_axs[j]
        gam = out['model']
        R2 = get_avg_score(out['scores_cv'])
        p_vals = gam.statistics_['p_values']
        for i, ax in enumerate(axs):
            plot_term(gam, i, ax, colors[j], size=size)
            ax.set_xlabel('')
            ax.text(.5, .95, 'p< %s' % format_num(p_vals[i]), va='center',
                    fontsize=size*3, transform=ax.transAxes)
            # column titles are drawn on every other row
            if j % 2 == 0:
                ax.set_title(titles[i], fontsize=size*4)
            # only the first column carries the row label (name and R2)
            if i == 0:
                ax.set_ylabel(name + ' (%s)' % format_num(R2),
                              fontsize=size*4)
            else:
                ax.set_ylabel('')

    fig.subplots_adjust(hspace=.4)
    if filename is not None:
        save_figure(fig, '%s.%s' % (filename, ext),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
def plot_EFA_retest(combined, size=4.6, dpi=300,
                    ext='png', plot_dir=None):
    """ Plots the correlation matrix of combined as a heatmap

    Divider lines are drawn at the midpoint of the matrix in both directions.

    Args:
        combined: dataframe whose column correlations are plotted
        size: scalar - the width and height of the (square) image
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    corr = combined.corr()
    largest_abs = abs(corr).max().max()
    palette = sns.diverging_palette(220, 15, n=100, as_cmap=True)

    fig = plt.figure(figsize=(size, size))
    heat_ax = fig.add_axes([.1, .1, .8, .8])
    cbar_ax = fig.add_axes([.92, .15, .04, .7])
    sns.heatmap(corr,
                square=True,
                ax=heat_ax,
                cbar_ax=cbar_ax,
                vmin=-1,
                vmax=1,
                cmap=palette,
                cbar_kws={'orientation': 'vertical',
                          'ticks': [-1, 0, 1]})
    heat_ax.set_xticklabels(heat_ax.get_xticklabels(), rotation=90)
    heat_ax.set_yticklabels(heat_ax.get_yticklabels(), rotation=0)
    # shrink the tick labels as the matrix grows
    heat_ax.tick_params(labelsize=size/len(corr)*40)

    # format cbar axis: the three ticks are relabeled with the largest
    # absolute correlation found in the matrix
    cbar_labels = [format_num(-largest_abs), 0, format_num(largest_abs)]
    cbar_ax.set_yticklabels(cbar_labels)
    cbar_ax.tick_params(labelsize=size, length=0, pad=size/2)
    cbar_ax.set_ylabel('Factor Loading', rotation=-90,
                       fontsize=size, labelpad=size/2)

    # set divider lines
    n = corr.shape[1]
    midpoint = n//2
    divider_width = size/3
    heat_ax.axvline(midpoint, 0, n, color='k', linewidth=divider_width)
    heat_ax.axhline(midpoint, 0, n, color='k', linewidth=divider_width)

    if plot_dir is None:
        return
    out_path = path.join(plot_dir, 'EFA_test_retest_heatmap.%s' % ext)
    save_figure(fig, out_path, {'bbox_inches': 'tight', 'dpi': dpi})
    plt.close()
def plot_factor_fingerprint(results,
                            classifier='ridge',
                            rotate='oblimin',
                            change=False,
                            size=4.6,
                            dpi=300,
                            ext='png',
                            plot_dir=None):
    """ Plots one polar plot per predictor showing its importance per target

    Args:
        results: a results object exposing ID, DA and load_prediction_object
        classifier: the classifier name used to load the prediction object
            and to build the output filename
        rotate: the rotation used to load the prediction object
        change: if true, the ' Change' targets are used
        size: scalar - the width of the figure
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    colors = ref_colors[results.ID.split('_')[0]]
    DA = results.DA
    reorder_vec = DA.get_factor_reorder(DA.results['num_factors'])
    loading_columns = DA.get_loading().columns
    suffix = ' Change' if change else ''
    targets = [loading_columns[idx] + suffix for idx in reorder_vec]

    prediction_object = results.load_prediction_object(EFA=True,
                                                       change=change,
                                                       classifier=classifier,
                                                       rotate=rotate)
    if prediction_object is None:
        print('No prediction object found!')
        return
    predictions = prediction_object['data']
    # the predictor names are read from the first target's entry
    factors = predictions[targets[0]]['predvars']
    # one row of importances per target
    importances = np.vstack(
        [predictions[target]['importances'] for target in targets])

    ncols = 3
    nrows = math.ceil(len(factors) / ncols)
    f, axes = plt.subplots(nrows,
                           ncols,
                           figsize=(size, size * nrows / ncols),
                           subplot_kw={'projection': 'polar'})
    plt.subplots_adjust(wspace=.5, hspace=.5)
    axes = f.get_axes()
    # shared radial limit: the largest importance rounded up to one decimal,
    # plus 10%
    radial_max = math.ceil(np.max(importances) * 10) / 10 * 1.1
    for col, factor in enumerate(factors):
        visualize_importance([targets, importances[:, col]],
                             axes[col],
                             yticklabels=False,
                             xticklabels=True,
                             title=factor,
                             label_size=size * 1.2,
                             label_scale=.2,
                             color=colors[0],
                             ymax=radial_max)

    if plot_dir is None:
        return
    changestr = '_change' if change else ''
    filename = 'EFA%s_%s_factor_fingerprint.%s' % (changestr, classifier, ext)
    save_figure(f, path.join(plot_dir, filename),
                {'bbox_inches': 'tight', 'dpi': dpi})
    plt.close()
def plot_polar_factors(results, c, color_by_group=True, rotate='oblimin',
                       dpi=300, ext='png', plot_dir=None):
    """ Plots factor analytic results as polar plots
    
    Args:
        results: a dimensional structure results object
        c: the number of components to use
        color_by_group: whether to color the polar plot by factor groups. Groups
            are defined by the factor each measurement loads most highly on
        rotate: the rotation passed to EFA.get_loading
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    EFA = results.EFA
    loadings = EFA.get_loading(c, rotate=rotate)
    groups = get_factor_groups(loadings)
    # colors=None defers the color choice to visualize_factors (presumably
    # one color per group); otherwise every factor is drawn in blue
    colors = None if color_by_group else ['b']*len(loadings.columns)
    # plot polar plot factor visualization for metric loadings
    fig = visualize_factors(loadings, n_rows=2, groups=groups, colors=colors)
    if plot_dir is not None:
        filename = 'factor_polar_EFA%s.%s' % (c, ext)
        save_figure(fig, path.join(plot_dir, filename),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
 def scale_plot(input_data, data_colors=None, cluster_colors=None,
                cluster_sizes=None, dissimilarity='euclidean', filey=None):
     """ Plot MDS of data and clusters

     The first n_clusters rows of the MDS embedding are drawn as numbered
     stars and the remaining rows as plain points.

     NOTE(review): n_clusters is not a parameter; it must be provided by an
     enclosing or module scope that is not visible here - confirm.

     Args:
         input_data: data passed to MDS.fit_transform
         data_colors: color of the data points, defaults to 'r'
         cluster_colors: color of the cluster stars, defaults to 'b'
         cluster_sizes: marker size of the cluster stars, defaults to 2200
         dissimilarity: the dissimilarity argument passed to MDS
         filey: the path to save the figure. If none, do not save
     """
     data_colors = 'r' if data_colors is None else data_colors
     cluster_colors = 'b' if cluster_colors is None else cluster_colors
     cluster_sizes = 2200 if cluster_sizes is None else cluster_sizes

     # scale
     embedding = MDS(dissimilarity=dissimilarity).fit_transform(input_data)

     with sns.axes_style('white'):
         f = plt.figure(figsize=(14,14))
         data_points = embedding[n_clusters:]
         plt.scatter(data_points[:,0], data_points[:,1],
                     s=75, color=data_colors)
         cluster_points = embedding[:n_clusters]
         plt.scatter(cluster_points[:,0], cluster_points[:,1],
                     marker='*', s=cluster_sizes, color=cluster_colors,
                     edgecolor='black', linewidth=2)
         # plot cluster number; two-digit numbers are shifted twice as far
         # to the left
         offset = .011
         font_dict = {'fontsize': 17, 'color':'white'}
         for idx, (x, y) in enumerate(cluster_points):
             x_shift = offset if idx < 9 else offset*2
             plt.text(x-x_shift, y-offset, idx+1, font_dict)
     if filey is None:
         return
     # the title is the file name with its last four characters removed
     plt.title(path.basename(filey)[:-4], fontsize=20)
     save_figure(f, filey)
     plt.close()
def plot_nesting(results, thresh=.5, rotate='oblimin', title=True,
                 dpi=300, figsize=12, ext='png', plot_dir=None):
    """ Plots nesting of factor solutions
    
    Args:
        results: a dimensional structure results object
        thresh: the threshold to pass to EFA.get_nesting_matrix
        rotate: the rotation to pass to EFA.get_nesting_matrix
        title: whether to draw the plot title
        dpi: the final dpi for the image
        figsize: scalar - the width and height of the (square) image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    EFA = results.EFA
    explained_scores, sum_explained = EFA.get_nesting_matrix(thresh, 
                                                             rotate=rotate)

    # plot lower nesting; cells whose score equals -1 are masked out
    fig, ax = plt.subplots(1, 1, figsize=(figsize, figsize))
    cbar_ax = fig.add_axes([.905, .3, .05, .3])
    sns.heatmap(sum_explained, annot=explained_scores,
                fmt='.2f', mask=(explained_scores==-1), square=True,
                ax = ax, vmin=.2, cbar_ax=cbar_ax,
                xticklabels = range(1,sum_explained.shape[1]+1),
                yticklabels = range(1,sum_explained.shape[0]+1))
    ax.set_xlabel('Higher Factors (Explainer)', fontsize=25)
    ax.set_ylabel('Lower Factors (Explainee)', fontsize=25)
    # the title parameter was previously accepted but ignored
    if title:
        ax.set_title('Nesting of Lower Level Factors based on R2', fontsize=30)
    if plot_dir is not None:
        filename = 'lower_nesting_heatmap.%s' % ext
        save_figure(fig, path.join(plot_dir, filename), 
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
def plot_predictors_comparison(R2_df, size=2, dpi=300, filename=None):
    """ Plots a pairplot comparing the cross-validated scores in R2_df

    Only the rows of R2_df whose label matches 'CV' are used. Each panel gets
    a dotted diagonal; panels below the diagonal are annotated with the
    Spearman correlation and panels above the diagonal are hidden.

    Args:
        R2_df: dataframe of scores with a 'Target_Cat' column used as hue
        size: scalar - the height of each panel; also scales the annotations
        dpi: the final dpi for the image
        filename: the path to save the figure. If none, do not save

    Returns:
        the seaborn grid when filename is none, otherwise None
    """
    CV_df = R2_df.filter(regex='CV', axis=0)
    CV_corr = CV_df.corr(method='spearman')

    # common axis limit: the largest numeric value rounded to one decimal
    max_R2 = round(CV_df.max(numeric_only=True).max(), 1)
    limits = [0, max_R2]
    # the size parameter is used as given (it was previously overwritten by
    # a hard-coded 2)
    grid = sns.pairplot(CV_df, hue='Target_Cat', height=size)
    for i, row in enumerate(grid.axes):
        for j, ax in enumerate(row):
            ax.set_xlim(limits)
            ax.set_ylim(limits)
            ax.plot(limits, limits, ls=":", c=".5", zorder=-1)
            # re-apply the limits after drawing the diagonal
            ax.set_xlim(limits)
            ax.set_ylim(limits)
            if j < i:
                ax.text(.5,
                        1,
                        r'$\rho$ = %s' % format_num(CV_corr.iloc[i, j]),
                        ha='center',
                        fontsize=size * 7,
                        fontweight='bold',
                        transform=ax.transAxes)
            if j > i:
                ax.set_visible(False)
    if filename is not None:
        save_figure(grid.fig, filename, {'bbox_inches': 'tight', 'dpi': dpi})
    else:
        return grid
def plot_RSA(corr, cluster=False, size=8, dpi=300, filename=None):
    """ plots similarity of ontological fingerprints between outcomes

    Args:
        corr: the similarity matrix to plot
        cluster: if true, draw a seaborn clustermap instead of a heatmap
        size: scalar - the width and height of the (square) image
        dpi: the final dpi for the image
        filename: the path to save the figure. If none, do not save

    Returns:
        corr as given, or the clustermap's reordered data (data2d) when
        cluster is true
    """
    figsize = (size, size)
    cmap = sns.diverging_palette(220, 15, n=100, as_cmap=True)
    if not cluster:
        f = plt.figure(figsize=figsize)
        ax = sns.heatmap(corr, square=True, cmap=cmap)
    else:
        f = sns.clustermap(corr, cmap=cmap, figsize=figsize)
        ax = f.ax_heatmap
        corr = f.data2d
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
    if filename is not None:
        save_figure(f, filename, {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
    return corr
def plot_clustering_similarity(results, plot_dir=None, verbose=False, ext='png'):
    """ Plots the agreement between every pair of clustering solutions

    Two heatmaps are drawn: the similarity of the cluster labels (adjusted
    Rand in the upper triangle, adjusted mutual information in the lower
    triangle) and the correlation between the distance matrices.

    Args:
        results: a results object exposing an HCA attribute
        plot_dir: the directory to save the figures. If none, do not save
        verbose: if true, print the correlation between the two label scores
        ext: the extension for the saved figures
    """
    HCA = results.HCA
    # get all clustering solutions
    clusterings = HCA.results.items()
    # the frames are labeled with the same shortened name that is used when
    # filling them below; labeling with the full key made .loc add new
    # rows/columns whenever a key contained '-'
    names = [k.split('-')[-1] for k, v in clusterings]
    n_clusterings = len(names)
    cluster_similarity = pd.DataFrame(np.zeros((n_clusterings, n_clusterings)),
                                      index=names,
                                      columns=names)
    distance_similarity = pd.DataFrame(np.zeros((n_clusterings, n_clusterings)),
                                       index=names,
                                       columns=names)
    for clustering1, clustering2 in combinations(clusterings, 2):
        name1 = clustering1[0].split('-')[-1]
        name2 = clustering2[0].split('-')[-1]
        # record similarity of distance_df
        dist_corr = np.corrcoef(squareform(clustering1[1]['distance_df']),
                                squareform(clustering2[1]['distance_df']))[1,0]
        distance_similarity.loc[name1, name2] = dist_corr
        distance_similarity.loc[name2, name1] = dist_corr
        # record similarity of clustering of dendrogram
        clusters1 = clustering1[1]['labels']
        clusters2 = clustering2[1]['labels']
        rand_score = adjusted_rand_score(clusters1, clusters2)
        MI_score = adjusted_mutual_info_score(clusters1, clusters2)
        # upper triangle holds the Rand score, lower triangle the MI score
        cluster_similarity.loc[name1, name2] = rand_score
        cluster_similarity.loc[name2, name1] = MI_score

    with sns.plotting_context(context='notebook', font_scale=1.4):
        clust_fig = plt.figure(figsize = (12,12))
        sns.heatmap(cluster_similarity, square=True)
        plt.title('Cluster Similarity: TRIL: Adjusted MI, TRIU: Adjusted Rand',
                  y=1.02)

        dist_fig = plt.figure(figsize = (12,12))
        sns.heatmap(distance_similarity, square=True)
        plt.title('Distance Similarity, metric: %s' % HCA.dist_metric,
                  y=1.02)

    if plot_dir is not None:
        save_figure(clust_fig, path.join(plot_dir,
                                   'cluster_similarity_across_measures.%s' % ext),
                    {'bbox_inches': 'tight'})
        save_figure(dist_fig, path.join(plot_dir,
                                   'distance_similarity_across_measures.%s' % ext),
                    {'bbox_inches': 'tight'})
        plt.close(clust_fig)
        plt.close(dist_fig)

    if verbose:
        # assess relationship between two measurements
        upper = np.triu_indices_from(cluster_similarity, k=1)
        rand_scores = cluster_similarity.values[upper]
        MI_scores = cluster_similarity.T.values[upper]
        score_consistency = np.corrcoef(rand_scores, MI_scores)[0,1]
        print('Correlation between measures of cluster consistency: %.2f'
              % score_consistency)
def plot_cross_within_prediction(prediction_loc,
                                 size=4.6,
                                 dpi=300,
                                 ext='png',
                                 plot_dir=None):
    """ Plots violin plots of within- and across-domain prediction scores

    Four stacked panels are drawn: within tasks, within surveys, and the two
    across-domain predictions.

    Args:
        prediction_loc: path to a pickled dictionary with 'within' ('task',
            'survey') and 'across' ('task_to_survey', 'survey_to_task') entries
        size: scalar - the width of the figure; the height is 1.5x larger
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save

    Returns:
        the figure when plot_dir is none, otherwise None
    """
    # NOTE: pickle can execute arbitrary code on load; only pass trusted files
    with open(prediction_loc, 'rb') as prediction_file:
        predictions = pickle.load(prediction_file)

    titles = [
        'Within Tasks', 'Within Surveys', 'Survey-By-Tasks', 'Task-By-Surveys'
    ]
    colors = [
        sns.color_palette('Blues_d', 3)[0],
        sns.color_palette('Reds_d', 3)[0], [.4, .4, .4], [.4, .4, .4]
    ]

    with sns.axes_style('whitegrid'):
        f, axes = plt.subplots(4, 1, figsize=(size, size * 1.5))

    panels = [
        predictions['within']['task'], predictions['within']['survey'],
        predictions['across']['task_to_survey'],
        predictions['across']['survey_to_task']
    ]
    for i, vals in enumerate(panels):
        sns.violinplot(list(vals.values()),
                       orient='h',
                       color=colors[i],
                       ax=axes[i],
                       width=.5,
                       linewidth=size * .3)

    # all panels share the smallest lower x-limit and an upper limit of 1
    min_x = min([ax.get_xlim()[0] for ax in axes])
    xticks = np.arange(math.floor(min_x * 10) / 10, 1, .2)
    for i, ax in enumerate(axes):
        for spine in ax.spines.values():
            spine.set_linewidth(size * .3)
        ax.grid(linewidth=size * .15, which='both')
        ax.set_xlim([min_x, 1])
        ax.text(min_x + (1 - min_x) * .02,
                -.34,
                titles[i],
                color=colors[i],
                ha='left',
                fontsize=size * 3.5)
        ax.set_xticks(xticks)
        # only the bottom panel keeps its tick labels
        if i != (len(axes) - 1):
            ax.set_xticklabels([])
        else:
            ax.tick_params(labelsize=size * 2.5, pad=size, length=0)
    axes[-1].set_xlabel(r'$R^2$', fontsize=size * 5)
    plt.subplots_adjust(hspace=0)
    if plot_dir is not None:
        save_figure(f, path.join(plot_dir, 'cross_prediction.%s' % ext), {
            'dpi': dpi,
            'transparent': True
        })
        plt.close()
    else:
        return f
예제 #13
0
def plot_cluster_factors(results,
                         c,
                         rotate='oblimin',
                         ext='png',
                         plot_dir=None):
    """ Plots the factor loadings of each cluster as polar plots

    Args:
        results: a results object exposing HCA and EFA attributes
        c: number of components for EFA; only used in the output filename
        rotate: the rotation used to look up cluster loadings and cluster DVs
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    # set up variables
    HCA = results.HCA
    EFA = results.EFA

    loading_items = HCA.get_cluster_loading(EFA, rotate=rotate).items()
    names, loadings_only = zip(*loading_items)
    cluster_DVs = HCA.get_cluster_DVs(inp='EFA%s_%s' % (EFA.get_c(), rotate))
    # pair each cluster's DVs with its loading
    cluster_loadings = [(cluster_DVs[name], loading)
                        for name, loading in zip(names, loadings_only)]
    max_loading = max(max(abs(loading)) for _, loading in cluster_loadings)

    # plot
    n_clusters = len(cluster_loadings)
    colors = sns.hls_palette(n_clusters)
    ncols = min(5, n_clusters)
    nrows = ceil(n_clusters / ncols)
    f, axes = plt.subplots(nrows,
                           ncols,
                           figsize=(ncols * 10, nrows * (8 + nrows)),
                           subplot_kw={'projection': 'polar'})
    axes = f.get_axes()
    for idx, (measures, loading) in enumerate(cluster_loadings):
        cluster_ax = axes[idx]
        plot_loadings(cluster_ax,
                      loading,
                      kind='line',
                      offset=.5,
                      plot_kws={'alpha': .8, 'c': colors[idx]})
        cluster_ax.set_title('Cluster %s' % idx, y=1.14, fontsize=25)
        # set tick labels: major ticks split the circle evenly and minor
        # ticks sit halfway between them
        n_entries = len(loading)
        xtick_locs = np.arange(0.0, 2 * np.pi, 2 * np.pi / n_entries)
        cluster_ax.set_xticks(xtick_locs)
        cluster_ax.set_xticks(xtick_locs + np.pi / n_entries, minor=True)
        # only the first and last column of every other row are labeled;
        # the radial limit is set on those same axes only
        if idx % (ncols * 2) in (0, ncols - 1):
            cluster_ax.set_xticklabels(loading.index, y=.08, minor=True)
            # set ylim
            cluster_ax.set_ylim(top=max_loading)
    # hide the axes left over at the end of the grid
    for leftover_ax in axes[n_clusters:]:
        leftover_ax.set_visible(False)
    plt.subplots_adjust(hspace=.5, wspace=.5)

    filename = 'polar_factors_EFA%s_%s.%s' % (c, rotate, ext)
    if plot_dir is None:
        return
    save_figure(f, path.join(plot_dir, filename), {'bbox_inches': 'tight'})
    plt.close()
def plot_prediction_scatter(results, target_order=None, EFA=True, change=False,
                            classifier='ridge', rotate='oblimin',
                            normalize=False, metric='R2', size=4.6,
                            dpi=300, ext='png', plot_dir=None):
    """ Plots predicted against target factor scores, one panel per target

    Args:
        results: a results object exposing load_prediction_object, EFA, DA,
            data and dataset
        target_order: currently unused
        EFA: if true, predict from EFA scores, otherwise from results.data
        change: if true, the change scores returned by DA.get_change are
            used as targets
        classifier: the classifier name used to load the prediction object
        rotate: the rotation used to load the prediction object
        normalize: currently unused
        metric: currently unused
        size: scalar - the width of the figure
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    predictions = results.load_prediction_object(EFA=EFA,
                                                 change=change,
                                                 classifier=classifier,
                                                 rotate=rotate)
    if predictions is None:
        print('No prediction object found!')
        return
    predictions = predictions['data']
    if EFA:
        predictors = results.EFA.get_scores()
    else:
        predictors = results.data
    if change:
        target_factors, _ = results.DA.get_change(results.dataset.replace('Complete', 'Retest'))
        # keep only the rows that have a change score
        predictors = predictors.loc[target_factors.index]
    else:
        target_factors = results.DA.get_scores()

    sns.set_style('whitegrid')
    n_cols = 2
    n_rows = math.ceil(len(target_factors.columns)/n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(size, size/n_cols*n_rows))
    axes = fig.get_axes()
    for i,v in enumerate(target_factors.columns):
        MAE = format_num(predictions[v]['scores_cv'][0]['MAE'])
        R2 = format_num(predictions[v]['scores_cv'][0]['R2'])
        axes[i].set_title('%s: R2: %s, MAE: %s' % (v, R2, MAE),
            fontweight='bold', fontsize=size*1.5)
        clf=predictions[v]['clf']
        axes[i].scatter(target_factors[v], clf.predict(predictors), s=size*3)
        axes[i].tick_params(length=0, labelsize=0)
        if i%2==0:
            axes[i].set_ylabel('Predicted Factor Score', fontsize=size*1.5)
    # label the x axis of the last two plotted panels
    axes[i].set_xlabel('Target Factor Score', fontsize=size*1.5)
    axes[i-1].set_xlabel('Target Factor Score', fontsize=size*1.5)

    # hide unused panels; the guard matters because axes[-0:] is every axes,
    # which would blank the whole figure when the grid is exactly full
    empty_plots = n_cols*n_rows - len(target_factors.columns)
    if empty_plots > 0:
        for ax in axes[-empty_plots:]:
            ax.set_visible(False)
    plt.subplots_adjust(hspace=.4, wspace=.3)

    if plot_dir is not None:
        changestr = '_change' if change else ''
        if EFA:
            filename = 'EFA%s_%s_prediction_scatter.%s' % (changestr, classifier, ext)
        else:
            filename = 'IDM%s_%s_prediction_scatter.%s' % (changestr, classifier, ext)
        save_figure(fig, path.join(plot_dir, filename),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
def plot_prediction_scatter(predictions,
                            predictors,
                            targets,
                            target_order=None,
                            metric='R2',
                            size=4.6,
                            dpi=300,
                            filename=None):
    """ Plots predicted against target scores, one panel per target

    Args:
        predictions: dictionary mapping each target name to a dictionary with
            'scores_cv' and 'clf' entries
        predictors: dataframe of predictors; subset to the rows of targets
        targets: dataframe of target scores
        target_order: the targets to plot, in order. If none, every key of
            predictions is plotted
        metric: currently unused
        size: scalar - the width of the figure
        dpi: the final dpi for the image
        filename: the path to save the figure. If none, do not save
    """
    # subset predictors
    predictors = predictors.loc[targets.index]
    if target_order is None:
        target_order = list(predictions.keys())

    sns.set_style('white')
    n_cols = 4
    n_rows = math.ceil(len(target_order) / n_cols)
    fig, axes = plt.subplots(n_rows,
                             n_cols,
                             figsize=(size, size / n_cols * n_rows))
    axes = fig.get_axes()
    for i, v in enumerate(target_order):
        MAE = format_num(predictions[v]['scores_cv'][0]['MAE'])
        R2 = format_num(predictions[v]['scores_cv'][0]['R2'])
        # each word of the target name goes on its own title line
        axes[i].set_title('%s\nR2: %s, MAE: %s' %
                          ('\n'.join(v.split()), R2, MAE),
                          fontweight='bold',
                          fontsize=size * 1)
        clf = predictions[v]['clf']
        axes[i].scatter(targets[v],
                        clf.predict(predictors),
                        s=size * 2.5,
                        edgecolor='white',
                        linewidth=size / 30)
        axes[i].tick_params(length=0, labelsize=0)
        # add diagonal, then restore the limits the scatter produced
        xlim = axes[i].get_xlim()
        ylim = axes[i].get_ylim()
        axes[i].plot(xlim, ylim, ls="-", c=".5", zorder=-1)
        axes[i].set_xlim(xlim)
        axes[i].set_ylim(ylim)
        for spine in ['top', 'right']:
            axes[i].spines[spine].set_visible(False)
        if i % n_cols == 0:
            axes[i].set_ylabel('Predicted Score', fontsize=size * 1.2)
    for ax in axes[-(len(target_order) + 1):]:
        ax.set_xlabel('Target Score', fontsize=size * 1.2)

    # the grid is sized from target_order, so the unused panels must be
    # counted from it as well (not from the columns of targets)
    empty_plots = n_cols * n_rows - len(target_order)
    if empty_plots > 0:
        for ax in axes[-empty_plots:]:
            ax.set_visible(False)
    plt.subplots_adjust(hspace=.6, wspace=.3)
    if filename is not None:
        save_figure(fig, filename, {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
def plot_BIC(all_results, size=4.6, dpi=300, ext='png', plot_dir=None):
    """ Plots BIC and SABIC curves
    
    One panel is drawn for the 'task' results and one for the 'survey'
    results; the selected number of factors is marked on each curve.

    Args:
        all_results: a dictionary of dimensional structure results objects
            with 'task' and 'survey' entries
        size: scalar - the width of the figure
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    palette_names = ('Blues_d', 'Reds_d', 'Greens_d', 'Oranges_d')
    all_colors = [sns.color_palette(palette, 3)[0:3]
                  for palette in palette_names]
    n_panels = len(all_results)
    height = size*.75/n_panels
    with sns.axes_style('white'):
        fig, axes = plt.subplots(1, n_panels, figsize=(size, height))
    selected = [all_results[key] for key in ['task', 'survey']]
    for panel, results in enumerate(selected):
        ax1 = axes[panel]
        name = results.ID.split('_')[0].title()
        EFA = results.EFA
        colors = all_colors[panel]
        # Plot BIC and SABIC curves
        with sns.axes_style('white'):
            x = list(EFA.results['cscores_metric-BIC'].keys())
            # score keys
            keys = [k for k in EFA.results.keys() if 'cscores' in k]
            for key in keys:
                metric = key.split('-')[-1]
                BIC_scores = [EFA.results[key][n_factors] for n_factors in x]
                BIC_c = EFA.results['c_metric-%s' % metric]
                ax1.plot(x, BIC_scores, 'o-', c=colors[0], lw=size/6,
                         label=metric, markersize=height*2)
                # mark the selected number of factors with a hollow dot
                ax1.plot(BIC_c, BIC_scores[BIC_c-1], '.', color='white',
                         markeredgecolor=colors[0], markeredgewidth=height/2,
                         markersize=height*4)
            # only the first panel carries a y label (and a legend when
            # more than one metric is plotted)
            if panel == 0:
                if len(keys) > 1:
                    ax1.set_ylabel('Score', fontsize=height*3)
                    leg = ax1.legend(loc='center right',
                                     fontsize=height*3, markerscale=0)
                    beautify_legend(leg, colors=colors)
                else:
                    ax1.set_ylabel(metric, fontsize=height*4)
            ax1.set_xlabel('# Factors', fontsize=height*4)
            ax1.set_xticks(x)
            ax1.set_xticklabels(x)
            ax1.tick_params(labelsize=height*2, pad=size/4, length=0)
            ax1.set_title(name, fontsize=height*4, y=1.01)
            ax1.grid(linewidth=size/8)
            for spine in ax1.spines.values():
                spine.set_linewidth(size*.1)
    if plot_dir is None:
        return
    save_figure(fig, path.join(plot_dir, 'BIC_curves.%s' % ext),
                {'bbox_inches': 'tight', 'dpi': dpi})
    plt.close()
def importance_bar_plots(predictions,
                         target_order=None,
                         show_sign=True,
                         colorbar=True,
                         size=5,
                         dpi=300,
                         filename=None):
    """ Plots the predictor importances of each target as grouped bars

    Args:
        predictions: dictionary mapping each target name to a dictionary with
            'importances', 'predvars' and 'scores_cv' entries
        target_order: the targets to plot, in order. If none, every key of
            predictions is plotted
        show_sign: currently unused
        colorbar: currently unused
        size: scalar - the width of the figure; the height is .67x the width
        dpi: the final dpi for the image
        filename: the path to save the figure. If none, do not save

    Returns:
        the figure when filename is none, otherwise None
    """
    if target_order is None:
        target_order = predictions.keys()
    first_target = list(target_order)[0]
    n_predictors = len(predictions[first_target]['importances'][0])
    # set up color styling: one shade per predictor
    palette = sns.color_palette('Blues_d', n_predictors)
    entries = [predictions[target] for target in target_order]
    # get max r2 (computed as in the original; not used further below)
    max_r2 = max(0, max([entry['scores_cv'][0]['R2'] for entry in entries]))
    # one row of importances per target, columns named after the first
    # entry's predictors
    prediction_df = pd.DataFrame(
        [entry['importances'][0] for entry in entries],
        columns=entries[0]['predvars'],
        index=target_order)
    # order the predictors by their importance for the first target
    prediction_df = prediction_df.sort_values(axis=1,
                                              by=prediction_df.index[0],
                                              ascending=False)

    # plot
    sns.set_style('white')
    ax = prediction_df.plot(kind='bar',
                            edgecolor=None,
                            linewidth=0,
                            figsize=(size, size * .67),
                            color=palette)
    fig = ax.get_figure()
    ax.tick_params(labelsize=size)
    ax.set_ylabel(r'Standardized $\beta$', fontsize=size * 1.5)
    # set up legend and other aesthetic
    ax.grid(axis='y', linewidth=size / 10)
    leg = ax.legend(frameon=False,
                    fontsize=size * 1.5,
                    bbox_to_anchor=(1.25, .8),
                    handlelength=0,
                    handletextpad=0,
                    framealpha=1)
    beautify_legend(leg, colors=palette)
    for spine in ax.spines.values():
        spine.set_visible(False)
    if filename is None:
        return fig
    save_figure(fig, filename, {'bbox_inches': 'tight', 'dpi': dpi})
    plt.close()
def plot_task_factors(results, c, rotate='oblimin',
                      task_sublists=None, normalize_loadings=False,
                      figsize=10,  dpi=300, ext='png', plot_dir=None):
    """ Plots task factors as polar plots
    
    One figure is drawn per task.

    Args:
        results: a dimensional structure results object
        c: the number of components to use
        rotate: the rotation passed to EFA.get_loading
        task_sublists: a dictionary whose values are sets of tasks, and 
                        whose keywords are labels for those lists. If none,
                        tasks are split by whether their name contains
                        'survey'
        normalize_loadings: if true, each row of loadings is divided by its
                        largest absolute value
        dpi: the final dpi for the image
        figsize: scalar - the width and height of each (square) figure
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    EFA = results.EFA
    # plot task factor loading
    loadings = EFA.get_loading(c, rotate=rotate)
    max_loading = abs(loadings).max().max()
    # the task name is the part of each row label before the first '.'
    tasks = np.unique([label.split('.')[0] for label in loadings.index])

    if task_sublists is None:
        task_sublists = {'surveys': [t for t in tasks if 'survey' in t],
                         'tasks': [t for t in tasks if 'survey' not in t]}

    for sublist_name, task_sublist in task_sublists.items():
        for task in task_sublist:
            # plot loading distributions. Each measure is scaled so absolute
            # comparisons are impossible. Only the distributions can be compared
            f, ax = plt.subplots(1, 1,
                                 figsize=(figsize, figsize),
                                 subplot_kw={'projection': 'polar'})
            task_loadings = loadings.filter(regex='^%s' % task, axis=0)
            task_loadings.index = format_variable_names(task_loadings.index)
            if normalize_loadings:
                row_max = abs(task_loadings).max(1)
                task_loadings = (task_loadings.T/row_max).T
            # format variable names
            # NOTE(review): the names were already formatted above; this
            # second pass is kept to preserve the existing behavior
            task_loadings.index = format_variable_names(task_loadings.index)
            # plot
            visualize_task_factors(task_loadings, ax, ymax=max_loading,
                                   xticklabels=True, label_size=figsize*2)
            ax.set_title(' '.join(task.split('_')),
                         y=1.14, fontsize=25)

            if plot_dir is None:
                continue
            # normalized and raw plots are saved to separate directories
            if normalize_loadings:
                function_directory = 'factor_DVnormdist_EFA%s_subset-%s' % (c, sublist_name)
            else:
                function_directory = 'factor_DVdist_EFA%s_subset-%s' % (c, sublist_name)
            out_dir = path.join(plot_dir, function_directory)
            makedirs(out_dir, exist_ok=True)
            filename = '%s.%s' % (task, ext)
            save_figure(f, path.join(plot_dir, function_directory, filename),
                        {'bbox_inches': 'tight', 'dpi': dpi})
            plt.close()
def plot_EFA_change(combined, ax=None, color_on=False, method=PCA,
                    size=4.6, dpi=300, ext='png', plot_dir=None):
    """ Plots test-retest change as line segments in a 2D projection
    
    The first half of the columns of `combined` is treated as the original
    (T1) measurement and the second half as the retest (T2) measurement. Both
    halves are stacked row-wise, projected to 2 dimensions with `method`, and
    each row is drawn as a segment from its T1 position (open marker) to its
    T2 position (filled marker).
    
    Args:
        combined: dataframe whose columns are [T1 columns, T2 columns] and
            whose index holds string labels
        ax: axis to draw on. If None, a new square figure is created
        color_on: False uses a fixed gray, True colors each segment by its T1
            position in the projection, any other value is handed to
            matplotlib as the color
        method: class called as method(2) that provides fit_transform
        size: scalar controlling figure size, marker size and font sizes
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    n = combined.shape[1]//2
    orig = combined.iloc[:,:n]
    # copy so that relabeling the retest half never touches the caller's data
    retest = combined.iloc[:,n:].copy()
    retest.columns = orig.columns
    retest.index = [i+'_retest' for i in retest.index]
    both = pd.concat([orig, retest])
    projection = method(2).fit_transform(both)
    n_rows = both.shape[0]//2
    orig_projection = projection[:n_rows,:]
    retest_projection = projection[n_rows:,:]
    
    color = [.2,.2,.2, .9]
    # get color range
    mins = np.min(orig_projection)
    ranges = np.max(orig_projection)-mins
    if ax is None:
        with sns.axes_style('white'):
            fig, ax = plt.subplots(figsize=(size,size))
    else:
        # a caller-supplied axis still needs its figure for saving below
        fig = ax.figure
    markersize = size
    markeredge = size/5
    linewidth = size/3
    for i, (start, end) in enumerate(zip(orig_projection, retest_projection)):
        # only the first segment carries legend labels
        label = ['T1 Scores', 'T2 Scores'] if i == 0 else [None, None]
        if color_on == True:
            scaled = list((start-mins)/ranges)
            # red channel follows PC 1, blue channel follows PC 2
            color = [scaled[0], 0, scaled[1]]
        elif color_on != False:
            color = color_on
        ax.plot(*zip(start, end), marker='o',
                markersize=markersize, color=color,
                markeredgewidth=markeredge, markerfacecolor='w',
                linewidth=linewidth, label=label[0])
        # filled marker drawn over the T2 end of the segment
        ax.plot(end[0], end[1], marker='o',
                markersize=markersize, color=color,
                linewidth=linewidth, label=label[1])
    ax.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
    ax.set_xlabel('PC 1', fontsize=size*2.5)
    ax.set_ylabel('PC 2', fontsize=size*2.5)
    ax.set_xlim(np.min(projection)-abs(np.min(projection))*.1, 
                np.max(projection)+abs(np.max(projection))*.1)
    ax.set_ylim(ax.get_xlim())
    ax.legend(fontsize=size*1.9)
    ax.get_legend().get_frame().set_linewidth(linewidth/2)
        
    if plot_dir is not None:
        save_figure(fig, path.join(plot_dir, 'EFA_test_retest_sticks.%s' % ext),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close(fig)
def plot_prediction_comparison(results,
                               size=4.6,
                               change=False,
                               dpi=300,
                               ext='png',
                               plot_dir=None):
    """ Plots R2 per classifier for EFA-based and IDM-based predictions
    
    For each feature set (IDM, then EFA) the most recently modified prediction
    file of each classifier is loaded, and the first 'scores_cv' R2 of every
    target is collected. NaN R2 values are replaced by 0.
    
    Args:
        results: a results object providing ID and get_prediction_files
        size: scalar controlling figure size and font sizes
        change: passed through to results.get_prediction_files
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    colors = ref_colors[results.ID.split('_')[0]]
    R2s = {}
    for EFA in [False, True]:
        feature = 'EFA' if EFA else 'IDM'
        predictions = results.get_prediction_files(EFA=EFA,
                                                   change=change,
                                                   shuffle=False)
        predictions = sorted(predictions, key=path.getmtime)
        # the classifier name is the second-to-last '_' separated token
        classifiers = np.unique([i.split('_')[-2] for i in predictions])
        # get last prediction file of each type
        for classifier in classifiers:
            filey = [i for i in predictions if classifier in i][-1]
            # NOTE: pickle can execute arbitrary code; only load trusted files
            with open(filey, 'rb') as prediction_file:
                prediction_object = pickle.load(prediction_file)['data']
            R2 = [i['scores_cv'][0]['R2'] for i in prediction_object.values()]
            R2s[feature + '_' + classifier] = np.nan_to_num(R2)
    if len(R2s) == 0:
        print('No prediction objects found')
        return
    R2s = pd.DataFrame(R2s).melt(var_name='Classifier', value_name='R2')
    # split "<feature>_<classifier>" on the first underscore only; expand=True
    # avoids iterating the .str accessor, which recent pandas no longer allows
    parts = R2s['Classifier'].str.split('_', n=1, expand=True)
    R2s['Feature'] = parts[0]
    R2s['Classifier'] = parts[1]
    f = plt.figure(figsize=(size, size * .62))
    sns.barplot(x='Classifier',
                y='R2',
                data=R2s,
                hue='Feature',
                palette=colors[:2],
                errwidth=size / 5)
    ax = plt.gca()
    ax.tick_params(axis='y', labelsize=size * 1.8)
    ax.tick_params(axis='x', labelsize=size * 1.8)
    leg = ax.legend(fontsize=size * 2, loc='upper right')
    beautify_legend(leg, colors[:2])
    plt.xlabel('Classifier', fontsize=size * 2.2, labelpad=size / 2)
    plt.ylabel('R2', fontsize=size * 2.2, labelpad=size / 2)
    plt.title('Comparison of Prediction Methods', fontsize=size * 2.5, y=1.05)

    if plot_dir is not None:
        filename = 'prediction_comparison.%s' % ext
        save_figure(f, path.join(plot_dir, filename), {
            'bbox_inches': 'tight',
            'dpi': dpi
        })
        plt.close(f)
def plot_cross_silhouette(all_results,
                          rotate,
                          size=4.6,
                          dpi=300,
                          ext='png',
                          plot_dir=None):
    """ Plots silhouette analyses for several results objects in one figure
    
    One row of two panels is drawn per entry of all_results, using
    plot_silhouette with the clustering input 'EFA<c>_<rotate>'. The left
    panels share a common x range and every panel is tagged with a letter.
    
    Args:
        all_results: mapping from a name to a results object
        rotate: the rotation used to build the clustering input name
        size: scalar controlling figure size, line widths and font sizes
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    n_rows = len(all_results)
    label_size = size * 1.2
    with sns.axes_style('white'):
        fig, axes = plt.subplots(n_rows,
                                 2,
                                 figsize=(size, size * .375 * n_rows))
    # flat list: even positions are left panels, odd positions right panels
    axes = fig.get_axes()
    letters = [chr(code).upper() for code in range(ord('a'), ord('z') + 1)]

    for row, (name, results) in enumerate(all_results.items()):
        left = axes[row * 2]
        right = axes[row * 2 + 1]
        inp = 'EFA%s_%s' % (results.EFA.get_c(), rotate)
        plot_silhouette(results, inp=inp, axes=(left, right), size=size)
        left.set_ylabel('%s cluster separated DVs' % name.title(),
                        fontsize=label_size)
        right.set_ylabel('%s average silhouette score' % name.title(),
                         fontsize=label_size)
        if row == 0:
            # the first row keeps its titles but gets no x labels
            left.set_xlabel('')
            right.set_xlabel('')
        else:
            left.set_xlabel('Silhouette score', fontsize=label_size)
            right.set_xlabel('Number of clusters', fontsize=label_size)
            left.set_title('')
            right.set_title('')
        for spine in left.spines.values():
            spine.set_linewidth(size * .1)
        for spine in right.spines.values():
            spine.set_linewidth(size * .1)
    plt.subplots_adjust(hspace=.2)
    # common x range across the left-hand panels
    left_limits = [panel.get_xlim() for panel in axes[::2]]
    max_x = max([limits[1] for limits in left_limits])
    min_x = min([limits[0] for limits in left_limits])
    letter_size = size * 9 / 4.6
    for row in range(n_rows):
        left = axes[row * 2]
        right = axes[row * 2 + 1]
        left.set_xlim([min_x, max_x])
        place_letter(left, letters[row * 2], fontsize=letter_size)
        place_letter(right, letters[row * 2 + 1], fontsize=letter_size)

    if plot_dir is not None:
        out_file = path.join(plot_dir, rotate, 'silhouette_analysis.%s' % ext)
        save_figure(fig, out_file, {'dpi': dpi})
        plt.close()
def plot_factor_fingerprint(results, classifier='ridge', rotate='oblimin', 
                            change=False, size=4.6,  
                            dpi=300, ext='png', plot_dir=None):
    """ Plots one polar importance plot per predictor of the prediction object
    
    Targets are the columns of results.DA.get_loading(), reordered with
    results.DA.get_factor_reorder. Each polar panel shows the importance of
    one predictor across all targets. Prints a message and returns if no
    prediction object can be loaded.
    
    Args:
        results: a results object providing ID, DA and load_prediction_object
        classifier: passed to load_prediction_object, also used in the filename
        rotate: passed to load_prediction_object
        change: if True, ' Change' is appended to every target name and
            '_change' is added to the filename
        size: scalar controlling figure size and label size
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    palette = ref_colors[results.ID.split('_')[0]]
    order = results.DA.get_factor_reorder(results.DA.results['num_factors'])
    loading_columns = results.DA.get_loading().columns
    targets = [loading_columns[position] for position in order]
    if change:
        targets = [target+' Change' for target in targets]
        
    prediction_object = results.load_prediction_object(EFA=True, 
                                                       change=change,
                                                       classifier=classifier,
                                                       rotate=rotate)
    if prediction_object is None:
        print('No prediction object found!')
        return
    predictions = prediction_object['data']
    factors = predictions[targets[0]]['predvars']
    # rows follow targets, columns follow the predictors
    importances = np.vstack([predictions[target]['importances']
                             for target in targets])

    ncols = 3
    nrows = math.ceil(len(factors)/ncols)
    f, axes = plt.subplots(nrows, ncols, figsize=(size, size*nrows/ncols), 
                           subplot_kw={'projection':'polar'})
    plt.subplots_adjust(wspace=.5, hspace=.5)
    axes = f.get_axes()
    # shared radial limit: largest importance rounded up to one decimal + 10%
    ymax = math.ceil(np.max(importances)*10)/10*1.1
    for column, factor in enumerate(factors):
        visualize_importance([targets, importances[:,column]], axes[column],
                             yticklabels=False,
                             xticklabels=True,
                             title=factor,
                             label_size=size*1.2,
                             label_scale=.2,
                             color=palette[0],
                             ymax=ymax)
    
    if plot_dir is not None:
        changestr = '_change' if change else ''
        filename = 'EFA%s_%s_factor_fingerprint.%s' % (changestr, classifier, ext)
        save_figure(f, path.join(plot_dir, filename), 
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
def plot_BIC_SABIC(results, size=2.3, dpi=300, ext='png', plot_dir=None):
    """ Plots BIC and SABIC curves
    
    The BIC curve is read from EFA.results['cscores_metric-BIC'] and its
    selected number of factors from EFA.results['c_metric-BIC']. If
    'cscores_metric-SABIC' is present, the SABIC curve is added on a twin
    y axis together with a legend.
    
    Args:
        results: a dimensional structure results object
        size: scalar controlling figure size, marker sizes and font sizes
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    EFA = results.EFA
    # Plot BIC and SABIC curves
    colors = ['c', 'm']
    with sns.axes_style('white'):
        fig, ax1 = plt.subplots(1,1, figsize=(size, size*.75))
        # BIC
        BIC_by_c = EFA.results['cscores_metric-BIC']
        x = sorted(BIC_by_c)
        BIC_scores = [BIC_by_c[i] for i in x]
        BIC_c = EFA.results['c_metric-BIC']
        ax1.plot(x, BIC_scores,  'o-', c=colors[0], lw=3, label='BIC',
                 markersize=size*2)
        ax1.set_xlabel('# Factors', fontsize=size*3)
        ax1.set_ylabel('BIC', fontsize=size*3)
        # look the selected score up by its factor count rather than by list
        # position, so the marker is right even if the counts do not start at 1
        ax1.plot(BIC_c, BIC_by_c[BIC_c], '.', color='white',
                 markeredgecolor=colors[0], markeredgewidth=size/2, 
                 markersize=size*4)
        ax1.tick_params(labelsize=size*2)
        if 'cscores_metric-SABIC' in EFA.results.keys():
            # SABIC
            ax2 = ax1.twinx()
            SABIC_by_c = EFA.results['cscores_metric-SABIC']
            # order the scores by factor count so x and y always line up
            SABIC_x = sorted(SABIC_by_c)
            SABIC_scores = [SABIC_by_c[i] for i in SABIC_x]
            SABIC_c = EFA.results['c_metric-SABIC']
            ax2.plot(SABIC_x, SABIC_scores, c=colors[1], lw=3, label='SABIC',
                     markersize=size*2)
            ax2.set_ylabel('SABIC', fontsize=size*4)
            ax2.plot(SABIC_c, SABIC_by_c[SABIC_c],'k.',
                 markeredgecolor=colors[0], markeredgewidth=size/2, 
                 markersize=size*4)
            # set up legend: an empty line on ax1 so both curves appear in
            # the single legend attached to ax1
            ax1.plot(np.nan, c='m', lw=3, label='SABIC')
            # 'center right' is the valid matplotlib spelling of this location
            leg = ax1.legend(loc='center right')
            beautify_legend(leg, colors=colors)
        if plot_dir is not None:
            save_figure(fig, path.join(plot_dir, 'BIC_SABIC_curves.%s' % ext),
                        {'bbox_inches': 'tight', 'dpi': dpi})
            plt.close()
예제 #24
0
def plot_nesting(results,
                 thresh=.5,
                 rotate='oblimin',
                 title=True,
                 dpi=300,
                 figsize=12,
                 ext='png',
                 plot_dir=None):
    """ Plots nesting of factor solutions
    
    Draws the second output of EFA.get_nesting_matrix as a heatmap, annotated
    with the first output. Cells whose annotation value equals -1 are masked.
    
    Args:
        results: a dimensional structure results object
        thresh: the threshold to pass to EFA.get_nesting_matrix
        rotate: the rotation to pass to EFA.get_nesting_matrix
        title: accepted but not used by this function
        dpi: the final dpi for the image
        figsize: scalar - the width and height of the (square) image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    nesting = results.EFA.get_nesting_matrix(thresh, rotate=rotate)
    explained_scores, sum_explained = nesting

    # plot lower nesting
    fig, ax = plt.subplots(1, 1, figsize=(figsize, figsize))
    cbar_ax = fig.add_axes([.905, .3, .05, .3])
    # tick labels are 1-based counts along each axis
    column_labels = range(1, sum_explained.shape[1] + 1)
    row_labels = range(1, sum_explained.shape[0] + 1)
    hidden_cells = (explained_scores == -1)
    sns.heatmap(sum_explained,
                annot=explained_scores,
                fmt='.2f',
                mask=hidden_cells,
                square=True,
                ax=ax,
                vmin=.2,
                cbar_ax=cbar_ax,
                xticklabels=column_labels,
                yticklabels=row_labels)
    ax.set_xlabel('Higher Factors (Explainer)', fontsize=25)
    ax.set_ylabel('Lower Factors (Explainee)', fontsize=25)
    ax.set_title('Nesting of Lower Level Factors based on R2', fontsize=30)
    if plot_dir is None:
        return
    out_file = path.join(plot_dir, 'lower_nesting_heatmap.%s' % ext)
    save_figure(fig, out_file, {'bbox_inches': 'tight', 'dpi': dpi})
    plt.close()
def plot_EFA_retest(combined, size=4.6, dpi=300, ext='png', plot_dir=None):
    """ Plots a heatmap of the correlations between the columns of combined
    
    Black divider lines are drawn at the midpoint of both axes, splitting the
    matrix into four quadrants.
    
    Args:
        combined: dataframe whose column correlations are plotted
        size: scalar controlling figure size, line widths and font sizes
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    corr = combined.corr()
    max_val = abs(corr).max().max()

    fig = plt.figure(figsize=(size, size))
    ax = fig.add_axes([.1, .1, .8, .8])
    cbar_ax = fig.add_axes([.92, .15, .04, .7])
    sns.heatmap(corr,
                square=True,
                ax=ax,
                cbar_ax=cbar_ax,
                vmin=-1,
                vmax=1,
                cmap=sns.diverging_palette(220, 15, n=100, as_cmap=True),
                cbar_kws={
                    'orientation': 'vertical',
                    'ticks': [-1, 0, 1]
                })
    ax.set_xticklabels(ax.get_xticklabels(), rotation=90)
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
    # shrink tick labels as the number of variables grows
    ax.tick_params(labelsize=size / len(corr) * 40)

    # format cbar axis
    cbar_ax.set_yticklabels([format_num(-max_val), 0, format_num(max_val)])
    cbar_ax.tick_params(labelsize=size, length=0, pad=size / 2)
    cbar_ax.set_ylabel('Factor Loading',
                       rotation=-90,
                       fontsize=size,
                       labelpad=size / 2)

    # set divider lines
    n = corr.shape[1]
    ax.axvline(n // 2, 0, n, color='k', linewidth=size / 3)
    ax.axhline(n // 2, 0, n, color='k', linewidth=size / 3)

    if plot_dir is not None:
        save_figure(fig, path.join(plot_dir,
                                   'EFA_test_retest_heatmap.%s' % ext), {
                                       'bbox_inches': 'tight',
                                       'dpi': dpi
                                   })
        plt.close()
def plot_cluster_factors(results, c, rotate='oblimin',  ext='png', plot_dir=None):
    """ Plots the loading of each cluster as a polar line plot
    
    Cluster loadings come from results.HCA.get_cluster_loading and are drawn
    on a grid with at most 5 columns. Axes left over in the grid are hidden.
    
    Args:
        results: a results object providing HCA and EFA
        c: number of components, used only to build the filename
        rotate: the rotation passed to HCA.get_cluster_loading and used to
            build the clustering input name and the filename
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    # set up variables
    HCA = results.HCA
    EFA = results.EFA
    
    loading_by_name = HCA.get_cluster_loading(EFA, rotate=rotate)
    names, loadings = zip(*loading_by_name.items())
    cluster_DVs = HCA.get_cluster_DVs(inp='EFA%s_%s' % (EFA.get_c(), rotate))
    # pair each cluster's DVs with that cluster's loading
    cluster_loadings = [(cluster_DVs[name], loading)
                        for name, loading in zip(names, loadings)]
    max_loading = max([max(abs(loading)) for _, loading in cluster_loadings])
    # plot
    n_clusters = len(cluster_loadings)
    colors = sns.hls_palette(n_clusters)
    ncols = min(5, n_clusters)
    nrows = ceil(n_clusters/ncols)
    f, axes = plt.subplots(nrows, ncols, 
                           figsize=(ncols*10,nrows*(8+nrows)),
                           subplot_kw={'projection': 'polar'})
    axes = f.get_axes()
    label_period = ncols*2
    for idx, (_, loading) in enumerate(cluster_loadings):
        polar_ax = axes[idx]
        plot_loadings(polar_ax, loading, kind='line', offset=.5,
                      plot_kws={'alpha': .8, 'c': colors[idx]})
        polar_ax.set_title('Cluster %s' % idx, y=1.14, fontsize=25)
        # set tick labels: major ticks at the sector edges, minor ticks at
        # the sector centers
        n_sectors = len(loading)
        xtick_locs = np.arange(0.0, 2*np.pi, 2*np.pi/n_sectors)
        polar_ax.set_xticks(xtick_locs)
        polar_ax.set_xticks(xtick_locs+np.pi/n_sectors, minor=True)
        if idx % label_period in (0, ncols-1):
            polar_ax.set_xticklabels(loading.index,  y=.08, minor=True)
            # set ylim
            polar_ax.set_ylim(top=max_loading)
    # hide the grid cells that did not receive a cluster
    for spare_ax in axes[n_clusters:]:
        spare_ax.set_visible(False)
    plt.subplots_adjust(hspace=.5, wspace=.5)
    
    filename = 'polar_factors_EFA%s_%s.%s' % (c, rotate, ext)
    if plot_dir is not None:
        save_figure(f, path.join(plot_dir, filename),
                    {'bbox_inches': 'tight'})
        plt.close()
def plot_glasso_edge_strength(all_results, graph_loc,  size=4.6, 
                             dpi=300, ext='png', plot_dir=None):
    """ Plots strip plots of graph edge weights within and between groups
    
    The graph dataframe is split at the number of columns of
    all_results['task'].data: the leading block is labeled "Within Tasks",
    the trailing block "Within Surveys", and the off-diagonal block
    "Between Tasks And Surveys".
    
    Args:
        all_results: mapping with a 'task' entry whose .data has one column
            per task variable
        graph_loc: path to a pickled graph object providing graph_to_dataframe
        size: scalar controlling figure size, line widths and font sizes
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, the figure is
            returned instead
    
    Returns:
        the figure if plot_dir is None, otherwise None
    """
    task_length = all_results['task'].data.shape[1]
    # NOTE: pickle can execute arbitrary code; only load graph files you trust
    with open(graph_loc, 'rb') as graph_file:
        g = pickle.load(graph_file)
    # subset graph; build the dataframe once and slice it three ways
    graph_df = g.graph_to_dataframe()
    task_within = squareform(graph_df.iloc[:task_length, :task_length])
    survey_within = squareform(graph_df.iloc[task_length:, task_length:])
    across = graph_df.iloc[:task_length, task_length:].values.flatten()

    titles = ['Within Tasks', 'Within Surveys', 'Between Tasks And Surveys']
    colors = [sns.color_palette('Blues_d',3)[0],
              sns.color_palette('Reds_d',3)[0],
              [0,0,0]]
    
    with sns.axes_style('whitegrid'):
        f, axes = plt.subplots(3,1, figsize=(size,size*1.5))

    for i, corr in enumerate([task_within, survey_within, across]):
        sns.stripplot(corr, jitter=.2, alpha=.5, orient='h', ax=axes[i],
                      color=colors[i], s=size/2)
        
    # shared x range with 10% headroom
    max_x = max([ax.get_xlim()[1] for ax in axes])*1.1
    for i, ax in enumerate(axes):
        for spine in ax.spines.values():
            spine.set_linewidth(size*.3)
        ax.grid(linewidth=size*.15)
        ax.set_xlim([0, max_x])
        ax.text(max_x*.02, -.35, titles[i], color=colors[i], ha='left',
                fontsize=size*3.5)
        ax.set_xticks(np.arange(0, round(max_x*10)/10,.1))
        # only the bottom panel shows tick labels
        if i!=(len(axes)-1):
            ax.set_xticklabels([])
        else:
            ax.tick_params(labelsize=size*2.5, pad=size, length=0)
    axes[-1].set_xlabel('Edge Weight', fontsize=size*5)
    plt.subplots_adjust(hspace=0)
    if plot_dir is not None:
        save_figure(f, path.join(plot_dir, 'glasso_edge_strength.%s' % ext),
                                {'dpi': dpi,
                                 'transparent': True})   
        plt.close()
    else:
        return f
def plot_cross_within_prediction(prediction_loc, size=4.6, 
                                 dpi=300, ext='png', plot_dir=None):
    """ Plots violin plots of within- and across-group prediction values
    
    Four stacked panels show predictions['within']['task'],
    predictions['within']['survey'], predictions['across']['task_to_survey']
    and predictions['across']['survey_to_task'], in that order.
    
    Args:
        prediction_loc: path to a pickled dictionary with 'within' and
            'across' entries
        size: scalar controlling figure size, line widths and font sizes
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, the figure is
            returned instead
    
    Returns:
        the figure if plot_dir is None, otherwise None
    """
    # NOTE: pickle can execute arbitrary code; only load files you trust
    with open(prediction_loc, 'rb') as prediction_file:
        predictions = pickle.load(prediction_file)

    titles = ['Within Tasks', 'Within Surveys', 'Survey-By-Tasks', 'Task-By-Surveys']
    colors = [sns.color_palette('Blues_d',3)[0],
              sns.color_palette('Reds_d',3)[0],
              [.4,.4,.4],
              [.4,.4,.4]]
    
    with sns.axes_style('whitegrid'):
        f, axes = plt.subplots(4,1, figsize=(size,size*1.5))

    for i, vals in enumerate([predictions['within']['task'],
                              predictions['within']['survey'],
                              predictions['across']['task_to_survey'],
                              predictions['across']['survey_to_task']]):
        sns.violinplot(list(vals.values()), orient='h', color=colors[i],
                    ax=axes[i], width=.5, linewidth=size*.3)
        
    # shared lower x limit; the upper limit is fixed at 1
    min_x = min([ax.get_xlim()[0] for ax in axes])
    for i, ax in enumerate(axes):
        for spine in ax.spines.values():
            spine.set_linewidth(size*.3)
        ax.grid(linewidth=size*.15, which='both')
        ax.set_xlim([min_x, 1])
        ax.text(min_x+(1-min_x)*.02, -.34, titles[i], color=colors[i], ha='left',
                fontsize=size*3.5)
        xticks = np.arange(math.floor(min_x*10)/10,1,.2)
        ax.set_xticks(xticks)
        # only the bottom panel shows tick labels
        if i!=(len(axes)-1):
            ax.set_xticklabels([])
        else:
            ax.tick_params(labelsize=size*2.5, pad=size, length=0)
    axes[-1].set_xlabel(r'$R^2$', fontsize=size*5)
    plt.subplots_adjust(hspace=0)
    if plot_dir is not None:
        save_figure(f, path.join(plot_dir, 'cross_prediction.%s' % ext),
                                {'dpi': dpi,
                                 'transparent': True})   
        plt.close()
    else:
        return f
예제 #29
0
    def scale_plot(input_data,
                   data_colors=None,
                   cluster_colors=None,
                   cluster_sizes=None,
                   dissimilarity='euclidean',
                   filey=None):
        """ Plot MDS of data and clusters

        The first n_clusters rows of the MDS output are drawn as numbered
        stars and the remaining rows as plain points. n_clusters is read from
        the enclosing scope.

        Args:
            input_data: data passed to MDS.fit_transform
            data_colors: color of the data points. Defaults to 'r'
            cluster_colors: color of the cluster stars. Defaults to 'b'
            cluster_sizes: marker size of the cluster stars. Defaults to 2200
            dissimilarity: the dissimilarity argument passed to MDS
            filey: file to save the figure to. If none, do not save
        """
        data_colors = 'r' if data_colors is None else data_colors
        cluster_colors = 'b' if cluster_colors is None else cluster_colors
        cluster_sizes = 2200 if cluster_sizes is None else cluster_sizes

        # scale
        embedding = MDS(dissimilarity=dissimilarity).fit_transform(input_data)
        cluster_points = embedding[:n_clusters]
        data_points = embedding[n_clusters:]

        with sns.axes_style('white'):
            f = plt.figure(figsize=(14, 14))
            plt.scatter(data_points[:, 0],
                        data_points[:, 1],
                        s=75,
                        color=data_colors)
            plt.scatter(cluster_points[:, 0],
                        cluster_points[:, 1],
                        marker='*',
                        s=cluster_sizes,
                        color=cluster_colors,
                        edgecolor='black',
                        linewidth=2)
            # plot cluster number
            offset = .011
            font_dict = {'fontsize': 17, 'color': 'white'}
            for idx, (x, y) in enumerate(cluster_points):
                # numbers with two digits are shifted further to the left
                x_shift = offset if idx < 9 else offset * 2
                plt.text(x - x_shift, y - offset, idx + 1, font_dict)
        if filey is None:
            return
        plt.title(path.basename(filey)[:-4], fontsize=20)
        save_figure(f, filey)
        plt.close()
def plot_factor_correlation(results, c, rotate='oblimin', title=True,
                            DA=False, size=4.6, dpi=300, ext='png', plot_dir=None):
    """ Plots the factor correlation matrix (Phi) as a heatmap
    
    The full matrix is drawn first, then drawn again with the cells below the
    diagonal masked and annotations switched on, so only the diagonal and the
    upper triangle carry numbers.
    
    Args:
        results: a results object providing ID, EFA and DA
        c: the number of components to use
        rotate: the rotation used to look up the loadings and the Phi matrix
        title: if True, add a title built from the first part of results.ID
        DA: if True, use results.DA in place of results.EFA
        size: scalar controlling figure size and font sizes
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    if DA:
        EFA = results.DA
    else:
        EFA = results.EFA
    loading = EFA.get_loading(c, rotate=rotate)
    # get factor correlation matrix
    reorder_vec = EFA.get_factor_reorder(c)
    phi = get_attr(EFA.results['factor_tree_Rout_%s' % rotate][c],'Phi')
    phi = pd.DataFrame(phi, columns=loading.columns, index=loading.columns)
    phi = phi.iloc[reorder_vec, reorder_vec]
    # mask the cells strictly below the diagonal
    mask = np.zeros_like(phi)
    mask[np.tril_indices_from(mask, -1)] = True
    cmap = sns.diverging_palette(220,15,n=100,as_cmap=True)
    # enter both context managers; joining them with `and` would only enter
    # the second one and silently drop the plotting context
    with sns.plotting_context('notebook', font_scale=2), sns.axes_style('white'):
        f = plt.figure(figsize=(size*5/4, size))
        ax1 = f.add_axes([0,0,.9,.9])
        cbar_ax = f.add_axes([.91, .05, .03, .8])
        sns.heatmap(phi, ax=ax1, square=True, vmax=1, vmin=-1,
                    cbar_ax=cbar_ax, cmap=cmap)
        sns.heatmap(phi, ax=ax1, square=True, vmax=1, vmin=-1,
                    cbar_ax=cbar_ax, annot=True, annot_kws={"size": size/c*15},
                    cmap=cmap, mask=mask)
        yticklabels = ax1.get_yticklabels()
        ax1.set_yticklabels(yticklabels, rotation=0, ha="right")
        ax1.set_xticklabels(ax1.get_xticklabels(), rotation=90)
        if title == True:
            ax1.set_title('%s Factor Correlations' % results.ID.split('_')[0].title(),
                      weight='bold', y=1.05, fontsize=size*3)
        ax1.tick_params(labelsize=size*3)
        # format cbar
        cbar_ax.tick_params(axis='y', length=0)
        cbar_ax.tick_params(labelsize=size*2)
        cbar_ax.set_ylabel('Pearson Correlation', rotation=-90, labelpad=size*4, fontsize=size*3)
    
    if plot_dir:
        filename = 'factor_correlations_EFA%s.%s' % (c, ext)
        save_figure(f, path.join(plot_dir, filename), 
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
def plot_cross_silhouette(all_results, rotate, size=4.6,  dpi=300, 
                    ext='png', plot_dir=None):
    """ Plots a two-panel silhouette analysis row for each results object
    
    Every entry of all_results is passed to plot_silhouette with the
    clustering input 'EFA<c>_<rotate>'. Afterwards the left-hand panels are
    given a common x range and all panels are tagged with consecutive
    capital letters.
    
    Args:
        all_results: mapping from a name to a results object
        rotate: the rotation used to build the clustering input name and the
            output subdirectory
        size: scalar controlling figure size, line widths and font sizes
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    n_results = len(all_results)
    with sns.axes_style('white'):
        fig, axes =  plt.subplots(n_results, 2, 
                                  figsize=(size, size*.375*n_results))
    axes = fig.get_axes()
    letters = [chr(code).upper() for code in range(ord('a'),ord('z')+1)]
    axis_fontsize = size*1.2
    spine_width = size*.1
    
    for idx, (name, results) in enumerate(all_results.items()):
        score_ax = axes[idx*2]
        average_ax = axes[idx*2+1]
        inp = 'EFA%s_%s' % (results.EFA.get_c(), rotate)
        plot_silhouette(results, inp=inp, axes=(score_ax,average_ax), size=size)
        pretty_name = name.title()
        score_ax.set_ylabel('%s cluster separated DVs' % pretty_name,
                            fontsize=axis_fontsize)
        average_ax.set_ylabel('%s average silhouette score' % pretty_name,
                              fontsize=axis_fontsize)
        first_row = (idx == 0)
        if first_row:
            score_ax.set_xlabel('')
            average_ax.set_xlabel('')
        else:
            score_ax.set_xlabel('Silhouette score', fontsize=axis_fontsize)
            average_ax.set_xlabel('Number of clusters', fontsize=axis_fontsize)
        if not first_row:
            # titles are kept on the first row only
            score_ax.set_title('')
            average_ax.set_title('')
        for panel in (score_ax, average_ax):
            for spine in panel.spines.values():
                spine.set_linewidth(spine_width)
    plt.subplots_adjust(hspace=.2)
    # the left-hand panels sit at the even positions of the flat axes list
    max_x = max([panel.get_xlim()[1] for panel in axes[::2]])
    min_x = min([panel.get_xlim()[0] for panel in axes[::2]])
    letter_fontsize = size*9/4.6
    for idx in range(n_results):
        score_ax = axes[idx*2]
        average_ax = axes[idx*2+1]
        score_ax.set_xlim([min_x, max_x])
        place_letter(score_ax, letters[idx*2], fontsize=letter_fontsize)
        place_letter(average_ax, letters[idx*2+1], fontsize=letter_fontsize)
        
    if plot_dir is not None:
        save_figure(fig, path.join(plot_dir, rotate,
                                   'silhouette_analysis.%s' % ext),
                    {'dpi': dpi})
        plt.close()
def plot_prediction_comparison(results, size=4.6, change=False,
                               dpi=300, ext='png', plot_dir=None):
    """ Plots R2 per classifier for EFA-based and IDM-based predictions
    
    For each feature set (IDM, then EFA) the most recently modified prediction
    file of each classifier is loaded, and the first 'scores_cv' R2 of every
    target is collected. NaN R2 values are replaced by 0.
    
    Args:
        results: a results object providing ID and get_prediction_files
        size: scalar controlling figure size and font sizes
        change: passed through to results.get_prediction_files
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    colors = ref_colors[results.ID.split('_')[0]]
    R2s = {}
    for EFA in [False, True]:
        feature = 'EFA' if EFA else 'IDM'
        predictions = results.get_prediction_files(EFA=EFA, change=change, 
                                                   shuffle=False)
        predictions = sorted(predictions, key = path.getmtime)
        # the classifier name is the second-to-last '_' separated token
        classifiers = np.unique([i.split('_')[-2] for i in predictions])
        # get last prediction file of each type
        for classifier in classifiers:
            filey = [i for i in predictions if classifier in i][-1]
            # NOTE: pickle can execute arbitrary code; only load trusted files
            with open(filey, 'rb') as prediction_file:
                prediction_object = pickle.load(prediction_file)['data']
            R2 = [i['scores_cv'][0]['R2'] for i in prediction_object.values()]
            R2s[feature+'_'+classifier] = np.nan_to_num(R2)
    if len(R2s) == 0:
        print('No prediction objects found')
        return
    R2s = pd.DataFrame(R2s).melt(var_name='Classifier', value_name='R2')
    # split "<feature>_<classifier>" on the first underscore only; expand=True
    # avoids iterating the .str accessor, which recent pandas no longer allows
    parts = R2s['Classifier'].str.split('_', n=1, expand=True)
    R2s['Feature'] = parts[0]
    R2s['Classifier'] = parts[1]
    f = plt.figure(figsize=(size, size*.62))
    sns.barplot(x='Classifier', y='R2', data=R2s, hue='Feature',
                palette=colors[:2], errwidth=size/5)
    ax = plt.gca()
    ax.tick_params(axis='y', labelsize=size*1.8)
    ax.tick_params(axis='x', labelsize=size*1.8)
    leg = ax.legend(fontsize=size*2, loc='upper right')
    beautify_legend(leg, colors[:2])
    plt.xlabel('Classifier', fontsize=size*2.2, labelpad=size/2)
    plt.ylabel('R2', fontsize=size*2.2, labelpad=size/2)
    plt.title('Comparison of Prediction Methods', fontsize=size*2.5, y=1.05)
    
    if plot_dir is not None:
        filename = 'prediction_comparison.%s' % ext
        save_figure(f, path.join(plot_dir, filename), 
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close(f)
def plot_prediction_relevance(results,
                              EFA=True,
                              classifier='ridge',
                              rotate='oblimin',
                              change=False,
                              size=4.6,
                              dpi=300,
                              ext='png',
                              plot_dir=None):
    """ Plots the relative relevance of each factor for predicting all outcomes
    
    Absolute importances are min-max scaled within each target, converted to
    proportions that sum to 1 per target, and shown as one box per predictor.
    Prints a message and returns if no prediction object can be loaded.
    
    Args:
        results: a results object providing load_prediction_object
        EFA: passed to load_prediction_object
        classifier: passed to load_prediction_object
        rotate: passed to load_prediction_object
        change: passed to load_prediction_object
        size: accepted but not used by this function
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    prediction_object = results.load_prediction_object(EFA=EFA,
                                                       change=change,
                                                       classifier=classifier,
                                                       rotate=rotate)
    if prediction_object is None:
        print('No prediction object found!')
        return
    predictions = prediction_object['data']

    targets = list(predictions.keys())
    predictors = predictions[targets[0]]['predvars']
    importances = abs(
        np.vstack([predictions[k]['importances'] for k in targets]))
    # scale to 0-1; the scaler works per column, so transpose to scale
    # within each target (row)
    scaler = MinMaxScaler()
    scaled_importances = scaler.fit_transform(importances.T).T
    # make proportion
    scaled_importances = scaled_importances / np.expand_dims(
        scaled_importances.sum(1), 1)
    # convert to dataframe
    scaled_df = pd.DataFrame(scaled_importances,
                             index=targets,
                             columns=predictors)
    melted = scaled_df.melt(var_name='Factor', value_name='Importance')
    # keep the Figure: sns.boxplot returns an Axes, which cannot be saved
    fig = plt.figure(figsize=(8, 12))
    sns.boxplot(y='Factor', x='Importance', data=melted, width=.5)
    if plot_dir is not None:
        filename = 'prediction_relevance.%s' % ext
        save_figure(fig, path.join(plot_dir, filename), {
            'bbox_inches': 'tight',
            'dpi': dpi
        })
        plt.close(fig)
def plot_outcome_ontological_similarity(results, EFA=True, classifier='ridge', 
                                        rotate='oblimin', change=False, size=4.6, 
                                        dpi=300, ext='png',  plot_dir=None):
    """ Plots similarity of ontological fingerprints between outcomes
    
    Builds a targets x predictors dataframe of importances and draws a
    clustermap of the correlations between targets. Prints a message and
    returns if no prediction object can be loaded.
    
    Args:
        results: a results object providing load_prediction_object
        EFA: passed to load_prediction_object
        classifier: passed to load_prediction_object
        rotate: passed to load_prediction_object
        change: passed to load_prediction_object
        size: accepted but not used by this function
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    prediction_object = results.load_prediction_object(EFA=EFA, 
                                                       change=change,
                                                       classifier=classifier,
                                                       rotate=rotate)
    if prediction_object is None:
        print('No prediction object found!')
        return
    predictions = prediction_object['data']

    targets = list(predictions.keys())
    predictors = predictions[targets[0]]['predvars']
    importances = np.vstack([predictions[k]['importances'] for k in targets])
    # convert to dataframe
    df = pd.DataFrame(importances, index=targets, columns=predictors)
    # clustermap creates its own figure, so no figure is opened beforehand
    f=sns.clustermap(df.T.corr(),
                     cmap=sns.diverging_palette(220,15,n=100,as_cmap=True))
    ax = f.ax_heatmap
    ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
    if plot_dir is not None:
        # own filename so this plot no longer overwrites the output of
        # plot_prediction_relevance
        filename = 'outcome_ontological_similarity.%s' % ext
        save_figure(f, path.join(plot_dir, filename), 
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
예제 #35
0
def plot_GAM(gams, X, Y, size=4, dpi=300, ext='png', filename=None):
    """ Plots the terms of several GAMs on a grid of axes
    
    The grid has one row per column of Y and one column per column of X.
    Row j shows the terms of the j-th entry of gams; each panel is annotated
    with the corresponding entry of gam.statistics_['p_values'].
    
    Args:
        gams: mapping from a name to a dictionary with a 'model' entry (the
            fitted GAM) and a 'scores_cv' entry
        X: dataframe of predictors; its columns provide the panel titles
        Y: array-like whose number of columns sets the number of rows
        size: scalar - the width and height of each panel, also scales fonts
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        filename: the file to save to, without extension. If none, do not save
    """
    cols = X.shape[1]
    rows = Y.shape[1]
    colors = sns.color_palette(n_colors=rows)
    # size this figure directly rather than changing the global rcParams, and
    # keep the axes 2D so a single row or column can be indexed the same way
    fig, mat_axs = plt.subplots(rows, cols,
                                figsize=(cols * size, rows * size),
                                squeeze=False)
    titles = X.columns
    for j, (name, out) in enumerate(gams.items()):
        axs = mat_axs[j]
        gam = out['model']
        R2 = get_avg_score(out['scores_cv'])
        p_vals = gam.statistics_['p_values']
        for i, ax in enumerate(axs):
            plot_term(gam, i, ax, colors[j], size=size)
            ax.set_xlabel('')
            ax.text(.5,
                    .95,
                    'p< %s' % format_num(p_vals[i]),
                    va='center',
                    fontsize=size * 3,
                    transform=ax.transAxes)
            # column titles are only drawn on even-numbered rows
            if j % 2 == 0:
                ax.set_title(titles[i], fontsize=size * 4)
            # the first panel of each row carries the row label
            if i == 0:
                ax.set_ylabel(name + ' (%s)' % format_num(R2),
                              fontsize=size * 4)
            else:
                ax.set_ylabel('')

    plt.subplots_adjust(hspace=.4)
    if filename is not None:
        save_figure(fig, '%s.%s' % (filename, ext), {
            'bbox_inches': 'tight',
            'dpi': dpi
        })
        plt.close()
예제 #36
0
def plot_silhouette(results,
                    inp='data',
                    labels=None,
                    axes=None,
                    size=4.6,
                    dpi=300,
                    ext='png',
                    plot_dir=None):
    """ Plots a silhouette analysis of one HCA clustering

    The left panel shows the sorted per-sample silhouette scores of each
    cluster. The right panel shows the average silhouette score of the
    constant-height cuts together with the average scores obtained with the
    supplied/dynamic labels on the `inp` clustering and on the raw ('data')
    clustering.

    Args:
        results: a results object exposing `HCA.results` and
            `HCA.get_cluster_names`
        inp: key of the clustering to analyze in `HCA.results`
        labels: cluster labels to evaluate. If None, the labels stored with
            the clustering are used. Clusters are looked up as 1..n_clusters
        axes: optional pair of axes (silhouette panel, summary panel) to draw
            on. If None, a new figure is created
        size: figure width, also used to scale fonts and line widths
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If None, do not save
    """
    HCA = results.HCA
    clustering = HCA.results[inp]
    name = inp
    sample_scores, avg_score = silhouette_analysis(clustering, labels)
    # raw clustering for comparison
    raw_clustering = HCA.results['data']
    _, raw_avg_score = silhouette_analysis(raw_clustering, labels)

    if labels is None:
        labels = clustering['labels']
    n_clusters = len(np.unique(labels))
    colors = sns.hls_palette(n_clusters)
    if axes is None:
        fig, (ax, ax2) = plt.subplots(1, 2, figsize=(size, size * .375))
    else:
        ax, ax2 = axes
        # recover the parent figure so that saving also works when the
        # caller supplies the axes
        fig = ax.get_figure()
    y_lower = 5
    ax.grid(False)
    ax2.grid(linewidth=size / 10)
    cluster_names = HCA.get_cluster_names(inp=inp)
    for i in range(n_clusters):
        # Aggregate the silhouette scores for samples belonging to
        # cluster i, and sort them
        ith_cluster_silhouette_values = sample_scores[labels == i + 1]
        # skip "clusters" with one value
        if len(ith_cluster_silhouette_values) == 1:
            continue
        ith_cluster_silhouette_values.sort()
        size_cluster_i = ith_cluster_silhouette_values.shape[0]
        # update y range and plot
        y_upper = y_lower + size_cluster_i
        ax.fill_betweenx(np.arange(y_lower, y_upper),
                         0,
                         ith_cluster_silhouette_values,
                         alpha=0.7,
                         color=colors[i],
                         linewidth=size / 10)
        # Label the silhouette plots with their cluster names
        ax.text(-0.02,
                y_lower + 0.25 * size_cluster_i,
                cluster_names[i],
                fontsize=size / 1.7,
                ha='right')
        # Compute the new y_lower for next plot, leaving a gap of 5 rows
        y_lower = y_upper + 5
    ax.axvline(x=avg_score, color="red", linestyle="--", linewidth=size * .1)
    ax.set_xlabel('Silhouette score', fontsize=size, labelpad=5)
    ax.set_ylabel('Cluster Separated DVs', fontsize=size)
    ax.tick_params(pad=size / 4,
                   length=size / 4,
                   labelsize=size * .8,
                   width=size / 10,
                   left=False,
                   labelleft=False,
                   bottom=True)
    ax.set_title('Dynamic tree cut', fontsize=size * 1.2, y=1.02)
    ax.set_xlim(-1, 1)
    # plot silhouettes for constant thresholds
    _, scores, _ = get_constant_height_labels(clustering)
    ax2.plot(*zip(*scores),
             'o',
             color='b',
             markeredgecolor='white',
             markeredgewidth=size * .1,
             markersize=size * .5,
             label='Fixed Height Cut')
    # plot the dynamic tree cut point
    ax2.plot(n_clusters,
             avg_score,
             'o',
             color='r',
             markeredgecolor='white',
             markeredgewidth=size * .1,
             markersize=size * .75,
             label='EFA Dynamic Cut')
    ax2.plot(n_clusters,
             raw_avg_score,
             'o',
             color='k',
             markeredgecolor='white',
             markeredgewidth=size * .1,
             markersize=size * .75,
             label='Raw Dynamic Cut')
    ax2.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax2.set_xlabel('Number of clusters', fontsize=size)
    ax2.set_ylabel('Average Silhouette Score', fontsize=size)
    ax2.set_title('Single cut height', fontsize=size * 1.2, y=1.02)
    ax2.tick_params(labelsize=size * .8,
                    pad=size / 4,
                    length=size / 4,
                    width=size / 10,
                    bottom=True)
    ax2.legend(loc='center right', fontsize=size * .8)
    plt.subplots_adjust(wspace=.3)
    if plot_dir is not None:
        save_figure(
            fig, path.join(plot_dir,
                           'silhouette_analysis_%s.%s' % (name, ext)), {
                               'bbox_inches': 'tight',
                               'dpi': dpi
                           })
        plt.close(fig)
def plot_dendrogram(loading, clustering, title=None, 
                    break_lines=True, drop_list=None, double_drop_list=None,
                    absolute_loading=False,  size=4.6,  dpi=300, 
                    filename=None):
    """ Plots HCA results as dendrogram with loadings underneath
    
    Args:
        loading: pandas df, a results EFA loading matrix (DVs x factors)
        clustering: dict-like HCA clustering. Must provide 'linkage',
            'clustered_df', 'labels' and 'reorder_vec'; 'cluster_names' is
            optional and enables the cluster labels under the heatmap
        title (optional): str, title to plot
        break_lines: whether to separate EFA heatmap based on clusters, default=True
        drop_list (optional): list of cluster indices whose label is lowered
            (with a longer connecting line)
        double_drop_list (optional): list of cluster indices whose label is
            lowered twice as far; takes precedence over drop_list
        absolute_loading: whether to plot the absolute loading value, default False
        size: figure width, also used to scale fonts and line widths
        dpi: the final dpi for the image
        filename: if set, where to save the plot
        
    Returns:
        The figure if filename is None. Otherwise the figure is saved and
        closed, and None is returned.
    """


    c = loading.shape[1]
    # extract cluster vars
    link = clustering['linkage']
    DVs = clustering['clustered_df'].columns
    ordered_loading = loading.loc[DVs]
    if absolute_loading:
        ordered_loading = abs(ordered_loading)
    # get cluster sizes (labels are looked up as 1..max(labels))
    labels=clustering['labels']
    cluster_sizes = [np.sum(labels==(i+1)) for i in range(max(labels))]
    link_function, colors = get_dendrogram_color_fun(link, clustering['reorder_vec'],
                                                     labels)
    
    # set figure properties
    figsize = (size, size*.6)
    # set up axes' size 
    # heatmap height grows with the number of factors (columns before the
    # transpose below)
    heatmap_height = ordered_loading.shape[1]*.035
    # [bottom, height] of the heatmap in figure coordinates
    heat_size = [.1, heatmap_height]
    # [bottom, height] of the dendrogram: it sits directly above the heatmap
    dendro_size=[np.sum(heat_size), .3]
    # set up plot axes as [left, bottom, width, height]
    dendro_size = [.15,dendro_size[0], .78, dendro_size[1]]
    heatmap_size = [.15,heat_size[0],.78,heat_size[1]]
    cbar_size = [.935,heat_size[0],.015,heat_size[1]]
    # factors become rows, DVs become columns (aligned with dendrogram leaves)
    ordered_loading = ordered_loading.T

    with sns.axes_style('white'):
        fig = plt.figure(figsize=figsize)
        ax1 = fig.add_axes(dendro_size) 
        # **********************************
        # plot dendrogram
        # **********************************
        with plt.rc_context({'lines.linewidth': size*.125}):
            dendrogram(link, ax=ax1, link_color_func=link_function,
                       orientation='top')
        # change axis properties
        ax1.tick_params(axis='x', which='major', labelsize=14,
                        labelbottom=False)
        ax1.get_yaxis().set_visible(False)
        ax1.spines['top'].set_visible(False)
        ax1.spines['right'].set_visible(False)
        ax1.spines['bottom'].set_visible(False)
        ax1.spines['left'].set_visible(False)
        # **********************************
        # plot loadings as heatmap below
        # **********************************
        ax2 = fig.add_axes(heatmap_size)
        cbar_ax = fig.add_axes(cbar_size)
        # symmetric color range based on the largest absolute loading
        max_val = np.max(abs(loading.values))
        # bring to closest .25
        max_val = ceil(max_val*4)/4
        sns.heatmap(ordered_loading, ax=ax2, 
                    cbar=True, cbar_ax=cbar_ax,
                    yticklabels=True,
                    xticklabels=True,
                    vmax =  max_val, vmin = -max_val,
                    cbar_kws={'orientation': 'vertical',
                              'ticks': [-max_val, 0, max_val]},
                    cmap=sns.diverging_palette(220,15,n=100,as_cmap=True))
        ax2.set_yticklabels(ax2.get_yticklabels(), rotation=0)
        ax2.tick_params(axis='y', labelsize=size*heat_size[1]*30/c, pad=size/4, length=0)
        # format cbar axis
        cbar_ax.set_yticklabels([format_num(-max_val), 0, format_num(max_val)])
        cbar_ax.tick_params(labelsize=size*heat_size[1]*25/c, length=0, pad=size/2)
        cbar_ax.set_ylabel('Factor Loading', rotation=-90, 
                       fontsize=size*heat_size[1]*30/c, labelpad=size*2)
        # add lines to heatmap to distinguish clusters
        if break_lines == True:
            xlim = ax2.get_xlim(); 
            ylim = ax2.get_ylim()
            # width of one DV column in data coordinates
            step = xlim[1]/len(labels)
            cluster_breaks = [i*step for i in np.cumsum(cluster_sizes)]
            # the last break is the right edge of the heatmap, so skip it
            ax2.vlines(cluster_breaks[:-1], ylim[0], ylim[1], linestyles='dashed',
                       linewidth=size*.1, colors=[.5,.5,.5], zorder=10)
        # **********************************
        # plot cluster names
        # **********************************
        # first column of each cluster and the tick position near its center
        beginnings = np.hstack([[0],np.cumsum(cluster_sizes)[:-1]])
        centers = beginnings+np.array(cluster_sizes)//2+.5
        offset = .07
        if 'cluster_names' in clustering.keys():
            ax2.tick_params(axis='x', reset=True, top=False, bottom=False, width=size/8, length=0)
            names = [transform_name(i) for i in clustering['cluster_names']]
            ax2.set_xticks(centers)
            ax2.set_xticklabels(names, rotation=0, ha='center', 
                                fontsize=heatmap_size[2]*size*1)
            # every other tick line - presumably the bottom tick of each
            # major tick; confirm against the matplotlib version in use
            ticks = ax2.xaxis.get_ticklines()[::2]
            for i, label in enumerate(ax2.get_xticklabels()):
                if label.get_text() != '':
                    # colored bar spanning the cluster, drawn just below the
                    # heatmap (outside the axes, hence clip_on=False)
                    ax2.hlines(c+offset,beginnings[i]+.5,beginnings[i]+cluster_sizes[i]-.5, 
                               clip_on=False, color=colors[i], linewidth=size/5)
                    label.set_color(colors[i])
                    ticks[i].set_color(colors[i])
                    # default label drop; larger drops avoid overlapping
                    # neighbouring labels
                    y_drop = .005
                    line_drop = .3
                    if drop_list and i in drop_list:
                        y_drop = .05
                        line_drop = 1.6
                    if double_drop_list and i in double_drop_list:
                        y_drop = .1
                        line_drop = 2.9
                    label.set_y(-(y_drop/heatmap_height+heatmap_height/c*offset))
                    # connector from the cluster bar down towards the label
                    ax2.vlines(beginnings[i]+cluster_sizes[i]/2, 
                               c+offset, c+offset+line_drop,
                               clip_on=False, color=colors[i], 
                               linewidth=size/7.5)

        # add title
        if title:
            ax1.set_title(title, fontsize=size*2, y=1.05)
            
    if filename is not None:
        save_figure(fig, filename,
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
    else:
        return fig
def plot_subbranch(target_color, cluster_i, tree, loading, cluster_sizes, title=None,
                   size=2.3, dpi=300, plot_loc=None):
    """ Plots one cluster of a dendrogram with its loadings underneath

    The subbranch is located by scanning `tree['color_list']` for the run of
    links colored `target_color`. If no such run is found, or its size does
    not match `cluster_sizes[cluster_i]`, nothing is plotted and None is
    returned.

    Args:
        target_color: color of the subbranch; compared to the tree colors
            after conversion with to_hex
        cluster_i: index of the cluster in cluster_sizes
        tree: dict with a 'color_list' entry (one color per link), also
            passed to plot_tree. Looks like the output of scipy's
            dendrogram - confirm against the caller
        loading: pandas df of loadings (DVs x factors), with the DVs ordered
            as in the dendrogram
        cluster_sizes: number of DVs in each cluster, in dendrogram order
        title (optional): str, title to plot above the dendrogram
        size: figure width, also used to scale fonts and line widths
        dpi: the final dpi for the image
        plot_loc: if set, where to save the plot

    Returns:
        The figure if plot_loc is None and the subbranch was found. Otherwise
        None (after saving and closing the figure when plot_loc is set).
    """
    sns.set_style('white')
    colormap = sns.diverging_palette(220,15,n=100,as_cmap=True)
    # get variables in subbranch based on coloring
    curr_color = tree['color_list'][0]
    start = 0
    # stays None if the tree never changes color, i.e. no branch was isolated
    end = None
    for i, color in enumerate(tree['color_list']):
        if color != curr_color:
            end = i
            if curr_color == to_hex(target_color):
                break
            # grey ("#808080") links do not start a new cluster
            if color != "#808080":
                start = i
            curr_color = color

    # a run of k links joins k+1 DVs; bail out if that is not this cluster
    if end is None or (end-start)+1 != cluster_sizes[cluster_i]:
        return

    # get subset of loading
    cumsizes = np.cumsum(cluster_sizes)
    if cluster_i==0:
        loading_start = 0
    else:
        loading_start = cumsizes[cluster_i-1]
    subset_loading = loading.T.iloc[:,loading_start:cumsizes[cluster_i]]

    # plotting
    N = subset_loading.shape[1]
    length = N*.05
    # axes rectangles as [left, bottom, width, height]
    dendro_size = [0,.746,length,.12]
    heatmap_size = [0,.5,length,.25]
    fig = plt.figure(figsize=(size,size*2))
    dendro_ax = fig.add_axes(dendro_size)
    heatmap_ax = fig.add_axes(heatmap_size)
    cbar_size = [length+.22, .5, .05, .25]
    factor_avg_size = [length+.01,.5,.2,.25]
    factor_avg_ax = fig.add_axes(factor_avg_size)
    cbar_ax = fig.add_axes(cbar_size)
    plot_tree(tree, range(start, end), dendro_ax, linewidth=size/2)
    dendro_ax.set_xticklabels('')

    # symmetric color range: use the largest absolute loading so that strong
    # negative loadings are not clipped
    max_val = np.max(np.abs(loading.values))
    # if max_val is high, just make it 1
    if max_val > .9:
        max_val = 1
    sns.heatmap(subset_loading, ax=heatmap_ax,
                cbar=True,
                cbar_ax=cbar_ax,
                cbar_kws={'ticks': [-max_val, 0, max_val]},
                yticklabels=True,
                vmin=-max_val,
                vmax=max_val,
                cmap=colormap,)
    yn = subset_loading.shape[0]
    tick_label_size = size*30/max(yn, 8)
    heatmap_ax.tick_params(labelsize=tick_label_size, length=size*.5,
                           width=size/5, pad=size)
    heatmap_ax.set_yticklabels(heatmap_ax.get_yticklabels(), rotation=0)
    # number the DV columns; the full names are listed below the heatmap
    heatmap_ax.set_xticks([i+.5 for i in range(0,N)])
    heatmap_ax.set_xticklabels([str(i) for i in range(1,N+1)],
                                size=size*2, rotation=0, ha='center')

    avg_factors = abs(subset_loading).mean(1)
    # format cbar axis
    cbar_ax.set_yticklabels([format_num(-max_val), 0, format_num(max_val)])
    cbar_ax.tick_params(axis='y', length=0)
    cbar_ax.tick_params(labelsize=size*3)
    cbar_ax.set_ylabel('Factor Loading', rotation=-90, fontsize=size*3,
                       labelpad=size*2)
    # add the DV names as a numbered list below the heatmap
    text_ax = fig.add_axes([-.22,.44-.02*N,.4,.02*N])
    for spine in ['top','right','bottom','left']:
        text_ax.spines[spine].set_visible(False)
    for i, label in enumerate(subset_loading.columns):
        text_ax.text(0, 1-i/N, str(i+1)+'.', fontsize=size*2.8, ha='right')
        text_ax.text(.1, 1-i/N, label, fontsize=size*3)
    text_ax.tick_params(which='both', labelbottom=False, labelleft=False,
                        bottom=False, left=False)
    # average absolute loading per factor, reversed so the bars line up with
    # the heatmap rows
    avg_factors[::-1].plot(kind='barh', ax = factor_avg_ax, width=.7,
                     color= tree['color_list'][start])
    factor_avg_ax.set_xlim(0, max_val)
    factor_avg_ax.set_xticklabels('')
    factor_avg_ax.set_yticklabels('')
    factor_avg_ax.tick_params(length=0)
    factor_avg_ax.spines['top'].set_visible(False)
    factor_avg_ax.spines['bottom'].set_visible(False)
    factor_avg_ax.spines['left'].set_visible(False)
    factor_avg_ax.spines['right'].set_visible(False)

    # title and axes styling of dendrogram
    if title:
        dendro_ax.set_title(title, fontsize=size*3, y=1.05, fontweight='bold')
    dendro_ax.get_yaxis().set_visible(False)
    dendro_ax.spines['top'].set_visible(False)
    dendro_ax.spines['right'].set_visible(False)
    dendro_ax.spines['bottom'].set_visible(False)
    dendro_ax.spines['left'].set_visible(False)
    if plot_loc is not None:
        save_figure(fig, plot_loc, {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
    else:
        return fig
def plot_outcome_ontological_similarity(results,
                                        EFA=True,
                                        classifier='ridge',
                                        rotate='oblimin',
                                        change=False,
                                        size=4.6,
                                        dpi=300,
                                        ext='png',
                                        plot_dir=None):
    """ Plots how similar the predictor importances are across outcomes

    The importances of every predicted outcome are stacked into an
    (outcome x predictor) table, clustered with an 'abscorrelation' metric,
    and the resulting similarity (1 - clustered distance) is drawn as a
    heatmap. Only the upper triangle and the diagonal are annotated.

    Args:
        results: a results object exposing `load_prediction_object`
        EFA, classifier, rotate, change: forwarded to
            `results.load_prediction_object` to select the prediction run
        size: figure height, also used to scale fonts
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If None, do not save
    """
    prediction_object = results.load_prediction_object(EFA=EFA,
                                                       change=change,
                                                       classifier=classifier,
                                                       rotate=rotate)
    predictions = prediction_object['data']

    # build the outcome x predictor importance table
    outcome_names = list(predictions.keys())
    predictor_names = predictions[outcome_names[0]]['predvars']
    stacked = np.vstack(
        [predictions[outcome]['importances'] for outcome in outcome_names])
    importance_df = pd.DataFrame(stacked,
                                 index=outcome_names,
                                 columns=predictor_names)

    clustering = hierarchical_cluster(importance_df,
                                      pdist_kws={'metric': 'abscorrelation'})
    similarity = 1 - clustering['clustered_df']
    # hide the annotations strictly below the diagonal
    lower_mask = np.zeros_like(similarity)
    lower_mask[np.tril_indices_from(lower_mask, -1)] = True
    n_outcomes = len(similarity)

    f = plt.figure(figsize=(size * 5 / 4, size))
    heat_ax = f.add_axes([0, 0, .9, .9])
    cbar_ax = f.add_axes([.91, .05, .03, .8])
    shared_kws = {
        'ax': heat_ax,
        'square': True,
        'vmax': 1,
        'vmin': 0,
        'cbar_ax': cbar_ax,
        'linewidth': 2
    }
    annotated_kws = {
        'annot': True,
        'annot_kws': {"size": size / n_outcomes * 15},
        'mask': lower_mask
    }
    # first pass colors every cell, second pass adds the annotations on the
    # unmasked cells
    for pass_kws in ({}, annotated_kws):
        sns.heatmap(similarity,
                    cmap=sns.light_palette((15, 75, 50),
                                           input='husl',
                                           n_colors=100,
                                           as_cmap=True),
                    **shared_kws,
                    **pass_kws)

    heat_ax.set_yticklabels(heat_ax.get_yticklabels(), rotation=0, ha="right")
    heat_ax.set_xticklabels(heat_ax.get_xticklabels(), rotation=90)
    heat_ax.tick_params(labelsize=size * 2)

    # format cbar
    cbar_ax.tick_params(axis='y', length=0)
    cbar_ax.tick_params(labelsize=size * 2)
    cbar_ax.set_ylabel('Pearson Correlation',
                       rotation=-90,
                       labelpad=size * 4,
                       fontsize=size * 3)

    if plot_dir is None:
        return
    out_name = 'ontological_similarity.%s' % ext
    save_figure(f, path.join(plot_dir, out_name), {
        'bbox_inches': 'tight',
        'dpi': dpi
    })
    plt.close()
def plot_corr_heatmap(all_results,
                      EFA=False,
                      size=4.6,
                      dpi=300,
                      ext='png',
                      plot_dir=None):
    """ Plots the absolute correlations between all task and survey DVs

    Variables are ordered by the 'reorder_vec' of each dataset's clustering,
    with task variables first and survey variables second. Dashed lines mark
    the boundary between the two groups.

    Args:
        all_results: mapping with 'task' and 'survey' results objects
        EFA: if False, correlate the raw data; otherwise correlate the
            transposed EFA loadings
        size: figure edge length, also used to scale fonts and line widths
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If None, do not save

    Returns:
        The figure if plot_dir is None, otherwise None.
    """
    # keep the original equality test so that only values equal to False
    # select the raw data
    use_raw = EFA == False

    def clustering_of(results):
        if use_raw:
            return results.HCA.results['data']
        n_factors = results.EFA.results['num_factors']
        return results.HCA.results['EFA%s_oblimin' % n_factors]

    survey_results = all_results['survey']
    task_results = all_results['task']
    survey_order = clustering_of(survey_results)['reorder_vec']
    task_order = clustering_of(task_results)['reorder_vec']

    if use_raw:
        task_block = task_results.data.iloc[:, task_order]
        survey_block = survey_results.data.iloc[:, survey_order]
    else:
        task_block = task_results.EFA.get_loading().T.iloc[:, task_order]
        survey_block = survey_results.EFA.get_loading().T.iloc[:,
                                                               survey_order]
    all_data = pd.concat([task_block, survey_block], axis=1)
    n_task = len(task_order)
    n_vars = all_data.shape[1]

    f = plt.figure(figsize=(size, size))
    ax = f.add_axes([.05, .05, .8, .8])
    cbar_ax = f.add_axes([.86, .1, .04, .7])
    sns.heatmap(abs(all_data.corr()),
                square=True,
                ax=ax,
                cbar_ax=cbar_ax,
                xticklabels=False,
                yticklabels=False,
                vmax=1,
                vmin=0,
                cbar_kws={'ticks': [0, 1]},
                cmap=ListedColormap(sns.color_palette('Reds', 100)))

    # add separating lines; the horizontal boundary depends on whether the
    # y axis is inverted (task rows at the top) or not
    line_kws = {'lw': size / 4, 'color': 'k', 'linestyle': '--'}
    y_inverted = ax.get_ylim()[0] > ax.get_ylim()[1]
    y_boundary = n_task if y_inverted else len(survey_order)
    ax.hlines(y_boundary, 0, n_vars, **line_kws)
    ax.vlines(n_task, 0, n_vars, **line_kws)

    # format cbar
    cbar_ax.tick_params(axis='y', length=0)
    cbar_ax.set_yticklabels([0, 1])
    cbar_ax.tick_params(labelsize=size * 2, pad=size / 2)
    cbar_ax.set_ylabel('Pearson Correlation',
                       rotation=-90,
                       labelpad=size * 2,
                       fontsize=size * 2)

    # invisible side axes that carry the category labels
    left_ax = f.add_axes([.01, .05, .04, .8])
    bottom_ax = f.add_axes([.05, 0.01, .8, .04])
    left_ax.axis('off')
    bottom_ax.axis('off')
    perc_task = n_task / n_vars
    label_size = size * 3
    left_ax.text(0, (1 - perc_task / 2),
                 'Task DVs',
                 rotation=90,
                 va='center',
                 fontsize=label_size)
    left_ax.text(0, ((1 - perc_task) / 2),
                 'Survey DVs',
                 rotation=90,
                 va='center',
                 fontsize=label_size)
    bottom_ax.text(perc_task / 2,
                   0,
                   'Task DVs',
                   ha='center',
                   fontsize=label_size)
    bottom_ax.text((1 - (1 - perc_task) / 2),
                   0,
                   'Survey DVs',
                   ha='center',
                   fontsize=label_size)

    if plot_dir is None:
        return f
    save_figure(f, path.join(plot_dir, 'data_correlations.%s' % ext), {
        'dpi': dpi,
        'transparent': True
    })
    plt.close()
def plot_communality(results, c, rotate='oblimin', retest_threshold=.2,
                     size=4.6, dpi=300, ext='png', plot_dir=None):
    """ Plots communality as bars and, with retest data, as a distribution

    When retest rows are available, three bar panels are drawn (communality,
    test-retest, adjusted communality), followed by a second figure with the
    density of raw and adjusted communality. Without retest rows only the
    communality bars are drawn. If no retest dataset can be loaded a message
    is printed and nothing is plotted.

    Args:
        results: a results object exposing `EFA` and `dataset`
        c: the number of components to use
        rotate: the EFA rotation to use
        retest_threshold: passed to get_adjusted_communality
        size: figure height, also used to scale fonts and line widths
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figures. If falsy, do not save
    """
    communality = get_communality(results.EFA, rotate, c)
    # load retest data
    retest_name = results.dataset.replace('Complete','Retest')
    retest_data = get_retest_data(dataset=retest_name)
    if retest_data is None:
        print('No retest data found for datafile: %s' % results.dataset)
        return

    # reorder data in line with communality, then reformat variable names
    retest_data = retest_data.loc[communality.index]
    communality.index = format_variable_names(communality.index)
    retest_data.index = format_variable_names(retest_data.index)
    has_retest = len(retest_data) > 0

    # communality bars
    if has_retest:
        adjusted_communality, correlation, noise_ceiling = \
            get_adjusted_communality(communality,
                                     retest_data,
                                     retest_threshold)
        bar_width = size/10
        f, axes = plt.subplots(1, 3, figsize=(3*bar_width, size))
        panels = [(communality, True, 'Communality'),
                  (noise_ceiling, False, 'Test-Retest'),
                  (adjusted_communality, False, 'Adjusted Communality')]
        for panel_ax, (values, show_rows, panel_title) in zip(axes, panels):
            plot_bar_factor(values, panel_ax, width=bar_width, height=size,
                            label_rows=show_rows, title=panel_title)
    else:
        f = plot_bar_factor(communality, label_rows=True,
                            width=size/3, height=size*2, title='Communality')
    if plot_dir:
        bars_name = 'communality_bars-EFA%s.%s' % (c, ext)
        save_figure(f, path.join(plot_dir, bars_name),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()

    # the distribution plot needs the adjusted communality
    if not has_retest:
        return

    with sns.axes_style('white'):
        colors = sns.color_palette(n_colors=2, desat=.75)
        f, ax = plt.subplots(1,1,figsize=(size,size))
        curves = [(communality, 'Communality', colors[0]),
                  (adjusted_communality, 'Adjusted Communality', colors[1])]
        for values, curve_label, curve_color in curves:
            sns.kdeplot(values, linewidth=size/4,
                        shade=True, label=curve_label, color=curve_color)
        # dashed line at the mean of each distribution
        y_min, y_max = ax.get_ylim()
        for values, _, curve_color in curves:
            ax.vlines(np.mean(values), y_min, y_max,
                      color=curve_color, linewidth=size/4, linestyle='--')
        leg = ax.legend(fontsize=size*2, loc='upper right')
        beautify_legend(leg, colors)
        plt.xlabel('Communality', fontsize=size*2)
        plt.ylabel('Normalized Density', fontsize=size*2)
        ax.set_yticks([])
        ax.tick_params(labelsize=size)
        ax.set_ylim(0, ax.get_ylim()[1])
        ax.set_xlim(0, ax.get_xlim()[1])
        for side in ('right', 'top'):
            ax.spines[side].set_visible(False)
        # add correlation
        correlation_text = format_num(np.mean(correlation))
        ax.text(1.1, 1.25,
                'Correlation Between Communality \nand Test-Retest: %s' % correlation_text,
                size=size*2)

    if plot_dir:
        dist_name = 'communality_dist-EFA%s.%s' % (c, ext)
        save_figure(f, path.join(plot_dir, dist_name),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
def plot_silhouette(results, inp='data', labels=None, axes=None,
                    size=4.6,  dpi=300,  ext='png', plot_dir=None):
    """ Plots a silhouette analysis of one HCA clustering

    The left panel shows the sorted per-sample silhouette scores of each
    cluster. The right panel shows the average silhouette score of the
    constant-height cuts together with the average scores obtained with the
    supplied/dynamic labels on the `inp` clustering and on the raw ('data')
    clustering.

    Args:
        results: a results object exposing `HCA.results` and
            `HCA.get_cluster_names`
        inp: key of the clustering to analyze in `HCA.results`
        labels: cluster labels to evaluate. If None, the labels stored with
            the clustering are used. Clusters are looked up as 1..n_clusters
        axes: optional pair of axes (silhouette panel, summary panel) to draw
            on. If None, a new figure is created
        size: figure width, also used to scale fonts and line widths
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If None, do not save
    """
    HCA = results.HCA
    clustering = HCA.results[inp]
    name = inp
    sample_scores, avg_score = silhouette_analysis(clustering, labels)
    # raw clustering for comparison
    raw_clustering = HCA.results['data']
    _, raw_avg_score = silhouette_analysis(raw_clustering, labels)

    if labels is None:
        labels = clustering['labels']
    n_clusters = len(np.unique(labels))
    colors = sns.hls_palette(n_clusters)
    if axes is None:
        fig, (ax, ax2) =  plt.subplots(1, 2, figsize=(size, size*.375))
    else:
        ax, ax2 = axes
        # recover the parent figure so that saving also works when the
        # caller supplies the axes
        fig = ax.get_figure()
    y_lower = 5
    ax.grid(False)
    ax2.grid(linewidth=size/10)
    cluster_names = HCA.get_cluster_names(inp=inp)
    for i in range(n_clusters):
        # Aggregate the silhouette scores for samples belonging to
        # cluster i, and sort them
        ith_cluster_silhouette_values = sample_scores[labels == i+1]
        # skip "clusters" with one value
        if len(ith_cluster_silhouette_values) == 1:
            continue
        ith_cluster_silhouette_values.sort()
        size_cluster_i = ith_cluster_silhouette_values.shape[0]
        # update y range and plot
        y_upper = y_lower + size_cluster_i
        ax.fill_betweenx(np.arange(y_lower, y_upper),
                          0, ith_cluster_silhouette_values,
                          alpha=0.7, color=colors[i],
                          linewidth=size/10)
        # Label the silhouette plots with their cluster names
        ax.text(-0.02, y_lower + 0.25 * size_cluster_i, cluster_names[i], fontsize=size/1.7, ha='right')
        # Compute the new y_lower for next plot, leaving a gap of 5 rows
        y_lower = y_upper + 5
    ax.axvline(x=avg_score, color="red", linestyle="--", linewidth=size*.1)
    ax.set_xlabel('Silhouette score', fontsize=size, labelpad=5)
    ax.set_ylabel('Cluster Separated DVs', fontsize=size)
    ax.tick_params(pad=size/4, length=size/4, labelsize=size*.8, width=size/10,
                   left=False, labelleft=False, bottom=True)
    ax.set_title('Dynamic tree cut', fontsize=size*1.2, y=1.02)
    ax.set_xlim(-1, 1)
    # plot silhouettes for constant thresholds
    _, scores, _ = get_constant_height_labels(clustering)
    ax2.plot(*zip(*scores), 'o', color='b',
             markeredgecolor='white', markeredgewidth=size*.1, markersize=size*.5,
             label='Fixed Height Cut')
    # plot the dynamic tree cut point
    ax2.plot(n_clusters, avg_score, 'o', color ='r',
             markeredgecolor='white', markeredgewidth=size*.1, markersize=size*.75,
             label='EFA Dynamic Cut')
    ax2.plot(n_clusters, raw_avg_score, 'o', color ='k',
             markeredgecolor='white', markeredgewidth=size*.1, markersize=size*.75,
             label='Raw Dynamic Cut')
    ax2.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax2.set_xlabel('Number of clusters', fontsize=size)
    ax2.set_ylabel('Average Silhouette Score', fontsize=size)
    ax2.set_title('Single cut height', fontsize=size*1.2, y=1.02)
    ax2.tick_params(labelsize=size*.8, pad=size/4, length=size/4, width=size/10, bottom=True)
    ax2.legend(loc='center right', fontsize=size*.8)
    plt.subplots_adjust(wspace=.3)
    if plot_dir is not None:
        save_figure(fig, path.join(plot_dir,
                                         'silhouette_analysis_%s.%s' % (name, ext)),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close(fig)
def plot_corr_hist(all_results,
                   reps=100,
                   size=4.6,
                   dpi=300,
                   ext='png',
                   plot_dir=None):
    """ Plots the distribution of absolute correlations within and across
    surveys and tasks

    Three density panels are drawn (within surveys, within tasks, surveys x
    tasks). Each panel is annotated with an average cross-validated R2 from
    ridge regressions and with a dashed line at the 95th percentile of the
    correlations obtained after shuffling every variable independently.

    Args:
        all_results: mapping with 'task' and 'survey' results objects, each
            exposing a `data` DataFrame
        reps: number of shuffles used to build the null distribution
        size: currently unused; the figure size is fixed at (10, 4)
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If None, do not save
    """
    colors = sns.color_palette('Blues_d', 3)[0:2] + sns.color_palette(
        'Reds_d', 2)[:1]
    survey_data = all_results['survey'].data
    task_data = all_results['task'].data
    survey_corr = abs(survey_data.corr())
    task_corr = abs(task_data.corr())
    all_data = pd.concat([task_data, survey_data], axis=1)
    # order matters: it has to match the order of the panels below
    datasets = [('survey', survey_data), ('task', task_data),
                ('all', all_data)]
    # get cross corr
    cross_corr = abs(all_data.corr()).loc[survey_corr.columns,
                                          task_corr.columns]

    # get shuffled 95% correlation, pooled over all repetitions
    shuffled_95 = []
    for label, df in datasets:
        null_corrs = []
        for _ in range(reps):
            # shuffle every column independently to break all associations
            shuffled = df.copy()
            for col in shuffled:
                shuffle_vec = shuffled[col].sample(len(shuffled)).tolist()
                shuffled.loc[:, col] = shuffle_vec
            shuffled_corr = abs(shuffled.corr())
            if label == 'all':
                # only the survey x task block, as in the plotted panel
                cross_block = shuffled_corr.loc[survey_corr.columns,
                                                task_corr.columns]
                null_corrs.append(cross_block.values.flatten())
            else:
                # lower triangle only, which excludes the diagonal of 1s
                null_corrs.append(extract_tril(shuffled_corr.values, -1))
        if null_corrs:
            shuffled_95.append(np.percentile(np.concatenate(null_corrs), 95))
        else:
            # reps == 0: no null distribution available
            shuffled_95.append(np.nan)

    # get cross_validated r2
    average_r2 = {}
    for (slabel, source), (tlabel, target) in product(datasets[:-1], repeat=2):
        scores = []
        # .items() works across pandas versions (iteritems was removed in 2.0)
        for var, values in target.items():
            # never predict a variable from itself
            if var in source.columns:
                predictors = source.drop(var, axis=1)
            else:
                predictors = source
            lr = RidgeCV()
            cv_score = np.mean(cross_val_score(lr, predictors, values, cv=10))
            scores.append(cv_score)
        average_r2[(slabel, tlabel)] = np.mean(scores)

    # bring everything together
    plot_elements = [
        (extract_tril(survey_corr.values,
                      -1), 'Within Surveys', average_r2[('survey', 'survey')]),
        (extract_tril(task_corr.values,
                      -1), 'Within Tasks', average_r2[('task', 'task')]),
        (cross_corr.values.flatten(), 'Surveys x Tasks', average_r2[('survey',
                                                                     'task')])
    ]

    with sns.axes_style('white'):
        f, axes = plt.subplots(1, 3, figsize=(10, 4))
        plt.subplots_adjust(wspace=.3)
        for i, (corr, label, r2) in enumerate(plot_elements):
            sns.kdeplot(corr,
                        ax=axes[i],
                        color=colors[i],
                        shade=True,
                        label=label,
                        linewidth=3)
            axes[i].text(.4, axes[i].get_ylim()[1] * .5,
                         'CV-R2: {0:.2f}'.format(r2))
        for i, ax in enumerate(axes):
            ax.vlines(shuffled_95[i],
                      *ax.get_ylim(),
                      color=[.2, .2, .2],
                      linewidth=2,
                      linestyle='dashed',
                      zorder=10)
            ax.set_xlim(0, 1)
            ax.set_ylim(0, ax.get_ylim()[1])
            ax.set_xticks([0, .5, 1])
            ax.set_xticklabels([0, .5, 1], fontsize=16)
            ax.set_yticks([])
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            leg = ax.legend(fontsize=14, loc='upper center')
            beautify_legend(leg, [colors[i]])
        axes[1].set_xlabel('Pearson Correlation', fontsize=20, labelpad=10)
        axes[0].set_ylabel('Normalized Density', fontsize=20, labelpad=10)

    # save
    if plot_dir is not None:
        save_figure(f,
                    path.join(plot_dir, 'within-across_correlations.%s' % ext),
                    {
                        'bbox_inches': 'tight',
                        'dpi': dpi
                    })
def plot_corr_hist(all_results, reps=100, size=4.6, 
                   dpi=300, ext='png', plot_dir=None):
    """ Plots density curves of absolute correlations within and across domains
    
    Three panels are drawn: correlations among survey variables, among task
    variables, and between survey and task variables. Each panel also shows
    the 95th percentile of a shuffled (null) correlation distribution as a
    dashed line, and an average cross-validated RidgeCV score (labelled
    'CV-R2') as text.
    
    Args:
        all_results: mapping with 'survey' and 'task' entries, each exposing
            its variables as a pandas DataFrame in `.data`
        reps: number of column-wise shuffles pooled into each null distribution
        size: unused; kept for signature compatibility
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    colors = sns.color_palette('Blues_d',3)[0:2] + sns.color_palette('Reds_d',2)[:1]
    survey_data = all_results['survey'].data
    task_data = all_results['task'].data
    survey_corr = abs(survey_data.corr())
    task_corr = abs(task_data.corr())
    all_data = pd.concat([task_data, survey_data], axis=1)
    # order matters: shuffled_95[i] is drawn on panel i (survey, task, cross)
    datasets = [('survey', survey_data), 
                ('task', task_data), 
                ('all', all_data)]
    # get cross corr
    cross_corr = abs(all_data.corr()).loc[survey_corr.columns,
                                          task_corr.columns]
    
    # get shuffled 95% correlation
    shuffled_95 = []
    for label, df in datasets:
        # pool the null correlations of every repetition; np.append returns a
        # new array, so results must be collected explicitly
        null_corrs = []
        for _ in range(reps):
            # shuffle each column independently to break all associations
            shuffled = df.copy()
            for col in shuffled:
                shuffle_vec = shuffled[col].sample(len(shuffled)).tolist()
                shuffled.loc[:,col] = shuffle_vec
            shuffled_corr = abs(shuffled.corr())
            if label == 'all':
                # survey x task block only, mirroring cross_corr above
                cross_block = shuffled_corr.loc[survey_corr.columns,
                                                task_corr.columns]
                null_corrs.append(cross_block.values.flatten())
            else:
                # unique off-diagonal pairs only, mirroring the plotted data
                null_corrs.append(extract_tril(shuffled_corr.values,-1))
        if null_corrs:
            shuffled_95.append(np.percentile(np.concatenate(null_corrs),95))
        else:
            # reps == 0: no null distribution, so no threshold to draw
            shuffled_95.append(np.nan)
    
    # get cross_validated r2
    average_r2 = {}
    for (slabel, source), (tlabel, target) in product(datasets[:-1], repeat=2):
        scores = []
        # .items() replaces .iteritems(), which was removed in pandas 2.0
        for var, values in target.items():
            # never predict a variable from itself
            if var in source.columns:
                predictors = source.drop(var, axis=1)
            else:
                predictors = source
            lr = RidgeCV()  
            cv_score = np.mean(cross_val_score(lr, predictors, values, cv=10))
            scores.append(cv_score)
        average_r2[(slabel, tlabel)] = np.mean(scores)

    # bring everything together
    plot_elements = [(extract_tril(survey_corr.values,-1), 'Within Surveys', 
                      average_r2[('survey','survey')]),
                     (extract_tril(task_corr.values,-1), 'Within Tasks',
                      average_r2[('task','task')]),
                     (cross_corr.values.flatten(), 'Surveys x Tasks',
                      average_r2[('survey', 'task')])]
    
    with sns.axes_style('white'):
        f, axes = plt.subplots(1,3, figsize=(10,4))
        plt.subplots_adjust(wspace=.3)
        for i, (corr, label, r2) in enumerate(plot_elements):
            sns.kdeplot(corr, ax=axes[i], color=colors[i], shade=True,
                        label=label, linewidth=3)
            axes[i].text(.4, axes[i].get_ylim()[1]*.5, 'CV-R2: {0:.2f}'.format(r2))
        for i, ax in enumerate(axes):
            ax.vlines(shuffled_95[i], *ax.get_ylim(), color=[.2,.2,.2], 
                      linewidth=2, linestyle='dashed', zorder=10)
            ax.set_xlim(0,1)
            ax.set_ylim(0, ax.get_ylim()[1])
            ax.set_xticks([0,.5,1])
            ax.set_xticklabels([0,.5,1], fontsize=16)
            ax.set_yticks([])
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            leg=ax.legend(fontsize=14, loc='upper center')
            beautify_legend(leg, [colors[i]])
        axes[1].set_xlabel('Pearson Correlation', fontsize=20, labelpad=10)
        axes[0].set_ylabel('Normalized Density', fontsize=20, labelpad=10)
    
    # save
    if plot_dir is not None:
        save_figure(f, path.join(plot_dir, 'within-across_correlations.%s' % ext),
                                {'bbox_inches': 'tight', 'dpi': dpi})
def plot_prediction_scatter(results,
                            target_order=None,
                            EFA=True,
                            change=False,
                            classifier='ridge',
                            rotate='oblimin',
                            normalize=False,
                            metric='R2',
                            size=4.6,
                            dpi=300,
                            ext='png',
                            plot_dir=None):
    """ Plots predicted against target factor scores, one panel per target

    Args:
        results: a results object providing load_prediction_object, EFA, DA,
            data and dataset
        target_order: unused; kept for signature compatibility
        EFA: if True, predict from EFA scores, otherwise from results.data
        change: if True, use change scores from the retest dataset as targets
        classifier: name of the classifier whose predictions are loaded
        rotate: rotation passed to load_prediction_object
        normalize: unused; kept for signature compatibility
        metric: unused; kept for signature compatibility
        size: figure width; fonts and markers are scaled from it
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    predictions = results.load_prediction_object(EFA=EFA,
                                                 change=change,
                                                 classifier=classifier,
                                                 rotate=rotate)
    if predictions is None:
        print('No prediction object found!')
        return
    predictions = predictions['data']

    if EFA:
        predictors = results.EFA.get_scores()
    else:
        predictors = results.data
    if change:
        target_factors, _ = results.DA.get_change(
            results.dataset.replace('Complete', 'Retest'))
        # keep only the rows that have a change score
        predictors = predictors.loc[target_factors.index]
    else:
        target_factors = results.DA.get_scores()

    sns.set_style('whitegrid')
    n_cols = 2
    n_rows = math.ceil(len(target_factors.columns) / n_cols)
    fig, axes = plt.subplots(n_rows,
                             n_cols,
                             figsize=(size, size / n_cols * n_rows))
    axes = fig.get_axes()
    for i, v in enumerate(target_factors.columns):
        MAE = format_num(predictions[v]['scores_cv'][0]['MAE'])
        R2 = format_num(predictions[v]['scores_cv'][0]['R2'])
        axes[i].set_title('%s: R2: %s, MAE: %s' % (v, R2, MAE),
                          fontweight='bold',
                          fontsize=size * 1.5)
        clf = predictions[v]['clf']
        axes[i].scatter(target_factors[v], clf.predict(predictors), s=size * 3)
        axes[i].tick_params(length=0, labelsize=0)
        # only the left column carries a y label
        if i % 2 == 0:
            axes[i].set_ylabel('Predicted Factor Score', fontsize=size * 1.5)
    # label the x axis of the last two populated panels
    axes[i].set_xlabel('Target Factor Score', fontsize=size * 1.5)
    axes[i - 1].set_xlabel('Target Factor Score', fontsize=size * 1.5)

    # hide unused grid cells; when the grid is exactly full the count is 0
    # and axes[-0:] would select (and hide) every panel, so guard the slice
    empty_plots = n_cols * n_rows - len(target_factors.columns)
    if empty_plots > 0:
        for ax in axes[-empty_plots:]:
            ax.set_visible(False)
    plt.subplots_adjust(hspace=.4, wspace=.3)

    if plot_dir is not None:
        changestr = '_change' if change else ''
        prefix = 'EFA' if EFA else 'IDM'
        filename = '%s%s_%s_prediction_scatter.%s' % (prefix, changestr,
                                                      classifier, ext)
        save_figure(fig, path.join(plot_dir, filename), {
            'bbox_inches': 'tight',
            'dpi': dpi
        })
        plt.close()
예제 #46
0
def plot_subbranch(target_color,
                   cluster_i,
                   tree,
                   loading,
                   cluster_sizes,
                   title=None,
                   size=2.3,
                   dpi=300,
                   plot_loc=None):
    """ Plots one colored sub-branch of a dendrogram with its loadings

    The figure stacks the sub-branch of the tree, a heatmap of the matching
    slice of the transposed loading matrix, a bar of the mean absolute
    loading per heatmap row, a colorbar, and a numbered key listing the
    variable names.

    Args:
        target_color: color of the sub-branch to plot; converted with to_hex
            and compared against the entries of tree['color_list']
        cluster_i: index of the cluster in cluster_sizes
        tree: mapping with a 'color_list' entry, also passed to plot_tree
            (presumably scipy dendrogram output -- TODO confirm)
        loading: pandas df of loadings; it is transposed and sliced by column
        cluster_sizes: sequence of cluster sizes, in the column order of the
            transposed loading matrix
        title (optional): str, title placed above the dendrogram
        size: base size that scales the figure, fonts and line widths
        dpi: the final dpi for the image
        plot_loc: if set, where to save the plot

    Returns:
        None if the color run found does not match the cluster size or if
        plot_loc is set (the figure is saved and closed); otherwise the figure
    """
    sns.set_style('white')
    colormap = sns.diverging_palette(220, 15, n=100, as_cmap=True)
    # get variables in subbranch based on coloring
    # scan for the run of consecutive entries colored target_color;
    # [start, end) are its bounds in tree['color_list']
    # NOTE(review): `end` is only assigned on a color change, so a
    # color_list without any change leaves it unbound at the check below
    curr_color = tree['color_list'][0]
    start = 0
    for i, color in enumerate(tree['color_list']):
        if color != curr_color:
            end = i
            # the run that just ended was the target: stop scanning
            if curr_color == to_hex(target_color):
                break
            # a new run only starts on a non-"#808080" color
            if color != "#808080":
                start = i
            curr_color = color

    # bail out when the run found does not correspond to the requested cluster
    if (end - start) + 1 != cluster_sizes[cluster_i]:
        return

    # get subset of loading
    # columns of the transposed loading belonging to cluster_i
    cumsizes = np.cumsum(cluster_sizes)
    if cluster_i == 0:
        loading_start = 0
    else:
        loading_start = cumsizes[cluster_i - 1]
    subset_loading = loading.T.iloc[:, loading_start:cumsizes[cluster_i]]

    # plotting
    # axes rectangles are [left, bottom, width, height] in figure fractions;
    # the width grows with the number of variables in the cluster
    N = subset_loading.shape[1]
    length = N * .05
    dendro_size = [0, .746, length, .12]
    heatmap_size = [0, .5, length, .25]
    fig = plt.figure(figsize=(size * 2, size * 4))
    dendro_ax = fig.add_axes(dendro_size)
    heatmap_ax = fig.add_axes(heatmap_size)
    cbar_size = [length + .22, .5, .05, .25]
    factor_avg_size = [length + .01, .5, .2, .25]
    factor_avg_ax = fig.add_axes(factor_avg_size)
    cbar_ax = fig.add_axes(cbar_size)
    #subset_loading.columns = [col.replace(': ',':\n', 1) for col in subset_loading.columns]
    plot_tree(tree, range(start, end), dendro_ax, linewidth=size / 2)
    dendro_ax.set_xticklabels('')

    # color limits come from the full loading matrix, not just the subset
    max_val = np.max(loading.values)
    # if max_val is high, just make it 1
    if max_val > .9:
        max_val = 1
    sns.heatmap(
        subset_loading,
        ax=heatmap_ax,
        cbar=True,
        cbar_ax=cbar_ax,
        cbar_kws={'ticks': [-max_val, 0, max_val]},
        yticklabels=True,
        vmin=-max_val,
        vmax=max_val,
        cmap=colormap,
    )
    yn, xn = subset_loading.shape
    # shrink row labels once there are more than 8 rows
    tick_label_size = size * 30 / max(yn, 8)
    heatmap_ax.tick_params(labelsize=tick_label_size,
                           length=size * .5,
                           width=size / 5,
                           pad=size)
    heatmap_ax.set_yticklabels(heatmap_ax.get_yticklabels(), rotation=0)
    # columns are labelled 1..N; the names are written in the key below
    heatmap_ax.set_xticks([i + .5 for i in range(0, subset_loading.shape[1])])
    heatmap_ax.set_xticklabels(
        [str(i) for i in range(1, subset_loading.shape[1] + 1)],
        size=size * 2,
        rotation=0,
        ha='center')

    # mean absolute loading of each heatmap row across the cluster's columns
    avg_factors = abs(subset_loading).mean(1)
    # format cbar axis
    cbar_ax.set_yticklabels([format_num(-max_val), 0, format_num(max_val)])
    cbar_ax.tick_params(axis='y', length=0)
    cbar_ax.tick_params(labelsize=size * 3)
    cbar_ax.set_ylabel('Factor Loading',
                       rotation=-90,
                       fontsize=size * 3,
                       labelpad=size * 2)
    # add axis labels as text above
    # numbered key mapping each column number to its variable name
    text_ax = fig.add_axes([-.22, .44 - .02 * N, .4, .02 * N])
    for spine in ['top', 'right', 'bottom', 'left']:
        text_ax.spines[spine].set_visible(False)
    for i, label in enumerate(subset_loading.columns):
        text_ax.text(0,
                     1 - i / N,
                     str(i + 1) + '.',
                     fontsize=size * 2.8,
                     ha='right')
        text_ax.text(.1, 1 - i / N, label, fontsize=size * 3)
    text_ax.tick_params(which='both',
                        labelbottom=False,
                        labelleft=False,
                        bottom=False,
                        left=False)
    # average factor bar
    # reversed so the bar order follows the heatmap rows from top to bottom
    avg_factors[::-1].plot(kind='barh',
                           ax=factor_avg_ax,
                           width=.7,
                           color=tree['color_list'][start])
    factor_avg_ax.set_xlim(0, max_val)
    #factor_avg_ax.set_xticks([max(avg_factors)])
    #factor_avg_ax.set_xticklabels([format_num(max(avg_factors))])
    factor_avg_ax.set_xticklabels('')
    factor_avg_ax.set_yticklabels('')
    factor_avg_ax.tick_params(length=0)
    factor_avg_ax.spines['top'].set_visible(False)
    factor_avg_ax.spines['bottom'].set_visible(False)
    factor_avg_ax.spines['left'].set_visible(False)
    factor_avg_ax.spines['right'].set_visible(False)

    # title and axes styling of dendrogram
    if title:
        dendro_ax.set_title(title,
                            fontsize=size * 3,
                            y=1.05,
                            fontweight='bold')
    dendro_ax.get_yaxis().set_visible(False)
    dendro_ax.spines['top'].set_visible(False)
    dendro_ax.spines['right'].set_visible(False)
    dendro_ax.spines['bottom'].set_visible(False)
    dendro_ax.spines['left'].set_visible(False)
    if plot_loc is not None:
        # a ValueError while saving is reported and swallowed; the figure is
        # closed either way
        try:
            print('about to crash? - dpi: ' + str(dpi))
            save_figure(fig, plot_loc, {'bbox_inches': 'tight', 'dpi': dpi})
            plt.close()
        except ValueError:
            print('something when wrong with that plot')
            plt.close()
    else:
        return fig
def plot_corr_heatmap(all_results, EFA=False, size=4.6, 
                   dpi=300, ext='png', plot_dir=None):
    """ Plots a heatmap of absolute correlations across task and survey DVs
    
    Variables are ordered by each domain's HCA 'reorder_vec', with tasks
    first and surveys second; dashed lines separate the two domains.
    
    Args:
        all_results: mapping with 'survey' and 'task' results objects
        EFA: if False, correlate the raw data; otherwise correlate the
            transposed EFA loadings
        size: figure width and height; fonts and line widths scale with it
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, return the figure
    """
    use_raw_data = (EFA == False)

    def lookup_clustering(results):
        # clustering of the raw data, or of the oblimin EFA solution
        if use_raw_data:
            return results.HCA.results['data']
        n_factors = results.EFA.results['num_factors']
        return results.HCA.results['EFA%s_oblimin' % n_factors]

    survey_results = all_results['survey']
    task_results = all_results['task']
    survey_order = lookup_clustering(survey_results)['reorder_vec']
    task_order = lookup_clustering(task_results)['reorder_vec']

    # tasks first, surveys second
    if use_raw_data:
        task_block = task_results.data.iloc[:, task_order]
        survey_block = survey_results.data.iloc[:, survey_order]
    else:
        task_block = task_results.EFA.get_loading().T.iloc[:, task_order]
        survey_block = survey_results.EFA.get_loading().T.iloc[:, survey_order]
    all_data = pd.concat([task_block, survey_block], axis=1)
    n_vars = all_data.shape[1]
    n_task = len(task_order)

    f = plt.figure(figsize=(size,size))
    ax = f.add_axes([.05,.05,.8,.8])
    cbar_ax = f.add_axes([.86,.1,.04,.7])
    corr = abs(all_data.corr())
    sns.heatmap(corr, square=True, ax=ax, cbar_ax=cbar_ax,
                xticklabels=False, yticklabels=False,
                vmax=1, vmin=0,
                cbar_kws={'ticks': [0, 1]},
                cmap=ListedColormap(sns.color_palette('Reds', 100)))

    # add separating lines; the horizontal split depends on whether the
    # y axis is inverted
    line_style = dict(lw=size/4, color='k', linestyle='--')
    y_first, y_second = ax.get_ylim()
    if y_first > y_second:
        horizontal_split = n_task
    else:
        horizontal_split = len(survey_order)
    ax.hlines(horizontal_split, 0, n_vars, **line_style)
    ax.vlines(n_task, 0, n_vars, **line_style)

    # format cbar
    cbar_ax.tick_params(axis='y', length=0)
    cbar_ax.set_yticklabels([0, 1])
    cbar_ax.tick_params(labelsize=size*2, pad=size/2)
    cbar_ax.set_ylabel('Pearson Correlation', rotation=-90, labelpad=size*2, fontsize=size*2)

    # add invisible axes that hold the category labels
    left_ax = f.add_axes([.01,.05,.04,.8])
    bottom_ax = f.add_axes([.05,0.01,.8,.04])
    left_ax.axis('off')
    bottom_ax.axis('off')
    perc_task = n_task/n_vars
    label_font = size*3
    # add labels, each centered on its own domain's share of the axis
    left_ax.text(0, (1-perc_task/2), 'Task DVs', rotation=90, va='center', fontsize=label_font)
    left_ax.text(0, ((1-perc_task)/2), 'Survey DVs', rotation=90, va='center', fontsize=label_font)
    bottom_ax.text(perc_task/2, 0, 'Task DVs', ha='center', fontsize=label_font)
    bottom_ax.text((1-(1-perc_task)/2), 0, 'Survey DVs', ha='center', fontsize=label_font)

    if plot_dir is None:
        return f
    save_figure(f, path.join(plot_dir, 'data_correlations.%s' % ext),
                {'dpi': dpi, 'transparent': True})
    plt.close()
예제 #48
0
def plot_dendrogram(loading,
                    clustering,
                    title=None,
                    break_lines=True,
                    drop_list=None,
                    double_drop_list=None,
                    absolute_loading=False,
                    size=4.6,
                    dpi=300,
                    filename=None):
    """ Plots HCA results as dendrogram with loadings underneath
    
    Args:
        loading: pandas df, a results EFA loading matrix
        clustering: a results HCA clustering, indexed by 'linkage',
            'clustered_df', 'labels', 'reorder_vec' and, optionally,
            'cluster_names'
        title (optional): str, title to plot
        break_lines: whether to separate EFA heatmap based on clusters, default=True
        drop_list (optional): list of cluster indices whose cluster label is
            dropped lower than the default position
        double_drop_list (optional): list of cluster indices whose cluster
            label is dropped lower still; takes precedence over drop_list
        absolute_loading: whether to plot the absolute loading value, default False
        size: figure width; fonts and line widths are scaled from it
        dpi: the final dpi for the image
        filename: if set, where to save the plot

    Returns:
        the figure if filename is None; otherwise the figure is saved and
        closed and nothing is returned
    """

    # number of factors (columns of the loading matrix)
    c = loading.shape[1]
    # extract cluster vars
    link = clustering['linkage']
    DVs = clustering['clustered_df'].columns
    # rows reordered to follow the clustered variable order
    ordered_loading = loading.loc[DVs]
    if absolute_loading:
        ordered_loading = abs(ordered_loading)
    # get cluster sizes
    # counts members of labels 1..max(labels)
    labels = clustering['labels']
    cluster_sizes = [np.sum(labels == (i + 1)) for i in range(max(labels))]
    link_function, colors = get_dendrogram_color_fun(link,
                                                     clustering['reorder_vec'],
                                                     labels)

    # set figure properties
    figsize = (size, size * .6)
    # set up axes' size
    # heatmap height grows with the number of factors
    heatmap_height = ordered_loading.shape[1] * .035
    heat_size = [.1, heatmap_height]
    dendro_size = [np.sum(heat_size), .3]
    # set up plot axes
    # rectangles are [left, bottom, width, height] in figure fractions
    dendro_size = [.15, dendro_size[0], .78, dendro_size[1]]
    heatmap_size = [.15, heat_size[0], .78, heat_size[1]]
    cbar_size = [.935, heat_size[0], .015, heat_size[1]]
    # after transposing, heatmap rows are factors and columns are variables
    ordered_loading = ordered_loading.T

    with sns.axes_style('white'):
        fig = plt.figure(figsize=figsize)
        ax1 = fig.add_axes(dendro_size)
        # **********************************
        # plot dendrogram
        # **********************************
        with plt.rc_context({'lines.linewidth': size * .125}):
            dendrogram(link,
                       ax=ax1,
                       link_color_func=link_function,
                       orientation='top')
        # change axis properties
        ax1.tick_params(axis='x',
                        which='major',
                        labelsize=14,
                        labelbottom=False)
        ax1.get_yaxis().set_visible(False)
        ax1.spines['top'].set_visible(False)
        ax1.spines['right'].set_visible(False)
        ax1.spines['bottom'].set_visible(False)
        ax1.spines['left'].set_visible(False)
        # **********************************
        # plot loadings as heatmap below
        # **********************************
        ax2 = fig.add_axes(heatmap_size)
        cbar_ax = fig.add_axes(cbar_size)
        max_val = np.max(abs(loading.values))
        # round up to the next multiple of .25
        max_val = ceil(max_val * 4) / 4
        sns.heatmap(ordered_loading,
                    ax=ax2,
                    cbar=True,
                    cbar_ax=cbar_ax,
                    yticklabels=True,
                    xticklabels=True,
                    vmax=max_val,
                    vmin=-max_val,
                    cbar_kws={
                        'orientation': 'vertical',
                        'ticks': [-max_val, 0, max_val]
                    },
                    cmap=sns.diverging_palette(220, 15, n=100, as_cmap=True))
        ax2.set_yticklabels(ax2.get_yticklabels(), rotation=0)
        ax2.tick_params(axis='y',
                        labelsize=size * heat_size[1] * 30 / c,
                        pad=size / 4,
                        length=0)
        # format cbar axis
        cbar_ax.set_yticklabels([format_num(-max_val), 0, format_num(max_val)])
        cbar_ax.tick_params(labelsize=size * heat_size[1] * 25 / c,
                            length=0,
                            pad=size / 2)
        cbar_ax.set_ylabel('Factor Loading',
                           rotation=-90,
                           fontsize=size * heat_size[1] * 30 / c,
                           labelpad=size * 2)

        # add lines to heatmap to distinguish clusters
        if break_lines == True:
            xlim = ax2.get_xlim()
            ylim = ax2.get_ylim()
            # x-axis width of one variable
            step = xlim[1] / len(labels)
            cluster_breaks = [i * step for i in np.cumsum(cluster_sizes)]
            # the last break is the right edge of the heatmap, so skip it
            ax2.vlines(cluster_breaks[:-1],
                       ylim[0],
                       ylim[1],
                       linestyles='dashed',
                       linewidth=size * .1,
                       colors=[.5, .5, .5],
                       zorder=10)
        # **********************************
        # plot cluster names
        # **********************************
        # first column index and label position of each cluster
        beginnings = np.hstack([[0], np.cumsum(cluster_sizes)[:-1]])
        centers = beginnings + np.array(cluster_sizes) // 2 + .5
        # gap between the bottom of the heatmap and the cluster bars
        offset = .07
        if 'cluster_names' in clustering.keys():
            ax2.tick_params(axis='x',
                            reset=True,
                            top=False,
                            bottom=False,
                            width=size / 8,
                            length=0)
            names = [transform_name(i) for i in clustering['cluster_names']]
            ax2.set_xticks(centers)
            ax2.set_xticklabels(names,
                                rotation=0,
                                ha='center',
                                fontsize=heatmap_size[2] * size * 1)
            # every other tick line, so that ticks[i] pairs with label i
            ticks = ax2.xaxis.get_ticklines()[::2]
            for i, label in enumerate(ax2.get_xticklabels()):
                # clusters with an empty name get no bar, stem or color
                if label.get_text() != '':
                    # colored bar under the heatmap spanning the cluster
                    ax2.hlines(c + offset,
                               beginnings[i] + .5,
                               beginnings[i] + cluster_sizes[i] - .5,
                               clip_on=False,
                               color=colors[i],
                               linewidth=size / 5)
                    label.set_color(colors[i])
                    ticks[i].set_color(colors[i])
                    # default label drop; larger drops move the label (and
                    # lengthen its stem) for clusters in the drop lists
                    y_drop = .005
                    line_drop = .3
                    if drop_list and i in drop_list:
                        y_drop = .05
                        line_drop = 1.6
                    if double_drop_list and i in double_drop_list:
                        y_drop = .1
                        line_drop = 2.9
                    label.set_y(-(y_drop / heatmap_height +
                                  heatmap_height / c * offset))
                    # stem from the middle of the cluster bar towards the label
                    ax2.vlines(beginnings[i] + cluster_sizes[i] / 2,
                               c + offset,
                               c + offset + line_drop,
                               clip_on=False,
                               color=colors[i],
                               linewidth=size / 7.5)

        # add title
        if title:
            ax1.set_title(title, fontsize=size * 2, y=1.05)

    if filename is not None:
        save_figure(fig, filename, {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
    else:
        return fig
def plot_BIC(all_results, size=4.6, dpi=300, ext='png', plot_dir=None):
    """ Plots model-selection score curves (e.g. BIC, SABIC) per domain
    
    One panel is drawn for 'task' and one for 'survey'. Every 'cscores'
    entry of the EFA results becomes a curve over the number of factors, and
    the selected number of factors is marked with a hollow dot.
    
    Args:
        all_results: a dimensional structure all_results object
        size: figure width; the panel height is derived from it
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    palette_names = ['Blues_d', 'Reds_d', 'Greens_d', 'Oranges_d']
    all_colors = [sns.color_palette(palette, 3)[0:3]
                  for palette in palette_names]
    n_panels = len(all_results)
    height = size * .75 / n_panels
    with sns.axes_style('white'):
        fig, axes = plt.subplots(1, n_panels, figsize=(size, height))

    domain_results = [all_results['task'], all_results['survey']]
    for panel, results in enumerate(domain_results):
        ax1 = axes[panel]
        name = results.ID.split('_')[0].title()
        efa_results = results.EFA.results
        colors = all_colors[panel]
        with sns.axes_style('white'):
            # candidate numbers of factors
            x = list(efa_results['cscores_metric-BIC'].keys())
            # one curve per score key
            score_keys = [k for k in efa_results.keys() if 'cscores' in k]
            for score_key in score_keys:
                metric = score_key.split('-')[-1]
                curve = [efa_results[score_key][n_factors] for n_factors in x]
                chosen = efa_results['c_metric-%s' % metric]
                ax1.plot(x,
                         curve,
                         'o-',
                         c=colors[0],
                         lw=size / 6,
                         label=metric,
                         markersize=height * 2)
                # hollow marker on the selected number of factors
                ax1.plot(chosen,
                         curve[chosen - 1],
                         '.',
                         color='white',
                         markeredgecolor=colors[0],
                         markeredgewidth=height / 2,
                         markersize=height * 4)
            # only the first panel gets a y label (and a legend when several
            # metrics are shown)
            if panel == 0:
                if len(score_keys) > 1:
                    ax1.set_ylabel('Score', fontsize=height * 3)
                    leg = ax1.legend(loc='center right',
                                     fontsize=height * 3,
                                     markerscale=0)
                    beautify_legend(leg, colors=colors)
                else:
                    ax1.set_ylabel(metric, fontsize=height * 4)
            ax1.set_xlabel('# Factors', fontsize=height * 4)
            ax1.set_xticks(x)
            ax1.set_xticklabels(x)
            ax1.tick_params(labelsize=height * 2, pad=size / 4, length=0)
            ax1.set_title(name, fontsize=height * 4, y=1.01)
            ax1.grid(linewidth=size / 8)
            for spine in ax1.spines.values():
                spine.set_linewidth(size * .1)

    if plot_dir is None:
        return
    save_figure(fig, path.join(plot_dir, 'BIC_curves.%s' % ext), {
        'bbox_inches': 'tight',
        'dpi': dpi
    })
    plt.close()
def plot_prediction(predictions, comparison_predictions, 
                    colors=None, EFA=None, comparison_label=None,
                    target_order=None,  metric='R2', size=4.6,  
                    dpi=300, filename=None):
    """ Plots prediction success per target as bars, with importance polar plots
    
    For every target, the cross-validated and insample scores are drawn as
    bars, overlaid with dashed outlines at the 95th percentile of the
    comparison (e.g. shuffled) scores. If an EFA object is given, a small
    polar plot of predictor importances is placed above each target whose
    cross-validated score beats its comparison, and the two best-predicted
    targets get larger, labeled polar plots.
    
    Args:
        predictions: mapping from target to a dict with 'scores_cv',
            'scores_insample', 'predvars' and 'importances'
        comparison_predictions: mapping from target to a dict whose
            'scores_cv' and 'scores_insample' are lists of score dicts
        colors: two colors (cross-validated, insample). Defaults to purples
        EFA: if given, its factor order is used for the importance polar plots
        comparison_label: legend label of the dashed comparison bars
        target_order: targets to plot, in order. Defaults to predictions' keys
        metric: key of the score to plot
        size: figure width; fonts and line widths are scaled from it
        dpi: the final dpi for the image
        filename: if set, where to save the plot
    """
    if colors is None:
        colors = [sns.color_palette('Purples_d', 4)[i] for i in [1,3]]
    if comparison_label is None:
        comparison_label = '95% shuffled prediction'
    basefont = max(size, 5)
    sns.set_style('white')
    if target_order is None:
        target_order = predictions.keys()
    # get prediction success
    r2s = [[k,predictions[k]['scores_cv'][0][metric]] for k in target_order]
    insample_r2s = [[k, predictions[k]['scores_insample'][0][metric]] for k in target_order]
    # get the 95th percentile of the comparison scores
    shuffled_r2s = []
    insample_shuffled_r2s = []
    for k in target_order:
        cv_scores = [s[metric] for s in comparison_predictions[k]['scores_cv']]
        shuffled_r2s.append((k, np.percentile(cv_scores, 95)))
        # and insample
        insample_scores = [s[metric] for s in comparison_predictions[k]['scores_insample']]
        insample_shuffled_r2s.append((k, np.percentile(insample_scores, 95)))
        
    # convert nans to 0 (nan is the only value not equal to itself)
    r2s = [(i, k) if k==k else (i,0) for i, k in r2s]
    insample_r2s = [(i, k) if k==k else (i,0) for i, k in insample_r2s]
    shuffled_r2s = [(i, k) if k==k else (i,0) for i, k in shuffled_r2s]
    
    # plot
    shuffled_grey = [.3,.3,.3]
    # plot variables
    figsize = (size, size*.75)
    fig = plt.figure(figsize=figsize)
    # plot bars
    ind = np.arange(len(r2s))
    width=.25
    ax1 = fig.add_axes([0,0,1,.5]) 
    ax1.bar(ind, [i[1] for i in r2s], width, 
            label='Cross-validated prediction', color=colors[0])
    ax1.bar(ind+width, [i[1] for i in insample_r2s], width, 
            label='Insample prediction', color=colors[1])
    # plot shuffled values above
    ax1.bar(ind, [i[1] for i in shuffled_r2s], width, 
             color='none', edgecolor=shuffled_grey, 
            linewidth=size/10, linestyle='--', label=comparison_label)
    ax1.bar(ind+width, [i[1] for i in insample_shuffled_r2s], width, 
            color='none', edgecolor=shuffled_grey, 
            linewidth=size/10, linestyle='--')
    
    ax1.set_xticks(np.arange(0,len(r2s))+width/2)
    ax1.set_xticklabels([i[0] for i in r2s], rotation=15, fontsize=basefont*1.4)
    ax1.tick_params(axis='y', labelsize=size*1.2)
    ax1.tick_params(length=size/4, width=size/10, pad=size/2, left=True, bottom=True)
    xlow, xhigh = ax1.get_xlim()
    if metric == 'R2':
        ax1.set_ylabel(r'$R^2$', fontsize=basefont*1.5, labelpad=size*1.5)
    else:
        ax1.set_ylabel(metric, fontsize=basefont*1.5, labelpad=size*1.5)
    # add a legend
    leg = ax1.legend(fontsize=basefont*1.4, loc='upper left', framealpha=1,
                     frameon=True, handlelength=0, handletextpad=0)
    leg.get_frame().set_linewidth(size/10)
    beautify_legend(leg, colors[:2]+[shuffled_grey])
    # change y extents: leave headroom above the tallest bar
    ylim = ax1.get_ylim()
    r2_max = max(max(r2s, key=lambda x: x[1])[1],
                 max(insample_r2s, key=lambda x: x[1])[1])
    ymax = r2_max*1.5
    ax1.set_ylim(ylim[0], ymax)
    # change yticks
    if ymax<.15:
        ax1.set_ylim(ylim[0], .15)
        ax1.yaxis.set_major_locator(ticker.MultipleLocator(.025))
    else:
        ax1.yaxis.set_major_locator(ticker.MultipleLocator(.05))
        ax1.set_yticks(np.append([0, .025, .05, .075, .1, .125], np.arange(.15, .45, .05)))
    # draw grid
    ax1.set_axisbelow(True)
    plt.grid(axis='y', linestyle='dotted', linewidth=size/6)
    plt.setp(list(ax1.spines.values()), linewidth=size/10)
    # Plot Polar Plots for importances
    if EFA is not None:
        reorder_vec = EFA.get_factor_reorder(EFA.results['num_factors'])
        reorder_fun = lambda x: [x[i] for i in reorder_vec]
        # get importances as (predictor names, importance values) pairs
        vals = [predictions[i] for i in target_order]
        importances = [(reorder_fun(i['predvars']), 
                        reorder_fun(i['importances'][0])) for i in vals]
        # plot
        axes=[]
        N = len(importances)
        # (index, (target, score)) pairs sorted by ascending score
        best_predictors = sorted(enumerate(r2s), key = lambda x: x[1][1])
        ylim = ax1.get_ylim(); yrange = np.sum(np.abs(ylim))
        zero_place = abs(ylim[0])/yrange
        # height of the tallest bar of each target, as a fraction of the
        # y range; targets with a non-positive score sit at zero
        plot_heights = [int(r2s[i][1]>0)
                        *(max(r2s[i][1],
                              insample_r2s[i][1],
                              shuffled_r2s[i][1],
                              insample_shuffled_r2s[i][1])/yrange)
                        for i in range(len(r2s))]
        # the bar axes occupy the bottom half of the figure
        plot_heights = [(h+zero_place+.02)*.5 for h in plot_heights]
        # mask heights: only targets that beat their comparison get a plot
        plot_heights = [plot_heights[i] if r2s[i][1]>max(shuffled_r2s[i][1],0) else np.nan
                        for i in range(len(plot_heights))]
        plot_x = (ax1.get_xticks()-xlow)/(xhigh-xlow)-(1/N/2)
        for i, importance in enumerate(importances):
            if pd.isnull(plot_heights[i]):
                continue
            axes.append(fig.add_axes([plot_x[i], plot_heights[i], 1/N,1/N], projection='polar'))
            visualize_importance(importance, axes[-1],
                                 yticklabels=False, xticklabels=False,
                                 label_size=figsize[1]*1,
                                 color=colors[0],
                                 axes_linewidth=size/10)
        # plot top 2 predictions, labeled, keeping their left-to-right order
        if best_predictors[-1][0] < best_predictors[-2][0]:
            locs = [.32, .68]
        else:
            locs = [.68, .32]
        label_importance = importances[best_predictors[-1][0]]
        # write abbreviation key. Only predictors that actually have an
        # abbreviation are listed: a missing entry is None and cannot be
        # concatenated with ':'
        pad = 0
        key_entries = []
        for name in label_importance[0]:
            abbreviation = shortened_factors.get(name, None)
            if abbreviation is not None:
                key_entries.append((name, abbreviation))
        if key_entries:
            # shift the labeled polar plots left to make room for the key
            pad = .05
            text_ax = fig.add_axes([.8,.56,.1,.34]) 
            text_ax.tick_params(labelleft=False, left=False, 
                                labelbottom=False, bottom=False)
            for spine in ['top','right','bottom','left']:
                text_ax.spines[spine].set_visible(False)
            for i, (val, abr) in enumerate(key_entries):
                text_ax.text(0, i/len(key_entries), abr+':', fontsize=size*1.2)
                text_ax.text(.5, i/len(key_entries), val, fontsize=size*1.2)
                
        ratio = figsize[1]/figsize[0]
        # best target first, then the runner-up
        for rank, loc in zip([-1, -2], locs):
            target_index, (target_name, _) = best_predictors[rank]
            axes.append(fig.add_axes([loc-.2*ratio-pad,.56,.3*ratio,.3], projection='polar'))
            visualize_importance(importances[target_index], axes[-1],
                                 yticklabels=False,
                                 xticklabels=True,
                                 label_size=max(figsize[1]*1.5, 5),
                                 label_scale=.22,
                                 title=target_name,
                                 color=colors[0],
                                 axes_linewidth=size/10)
    if filename is not None:
        save_figure(fig, filename, 
            {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
def plot_prediction(predictions,
                    comparison_predictions,
                    colors=None,
                    EFA=None,
                    comparison_label=None,
                    target_order=None,
                    metric='R2',
                    size=4.6,
                    dpi=300,
                    filename=None):
    """Plot cross-validated and in-sample prediction scores as grouped bars.

    For every target two solid bars are drawn (cross-validated and in-sample
    score) and, on top of each, a dashed outline at the 95th percentile of
    the corresponding comparison (e.g. shuffled) scores. When ``EFA`` is
    given, a small unlabeled polar plot of predictor importances is placed
    above every target whose cross-validated score beats its comparison
    value, and larger labeled polar plots are drawn for the best (and, if
    available, second best) cross-validated targets.

    Args:
        predictions: mapping target -> dict. ``['scores_cv'][0][metric]`` and
            ``['scores_insample'][0][metric]`` are read for every target;
            ``['predvars']`` and ``['importances'][0]`` are read as well when
            ``EFA`` is given.
        comparison_predictions: mapping target -> dict whose ``'scores_cv'``
            and ``'scores_insample'`` entries are iterables of records
            indexable by ``metric``.
        colors: two colors (cross-validated, in-sample). Defaults to two
            shades of seaborn's 'Purples_d' palette.
        EFA: optional object providing ``get_factor_reorder`` and
            ``results['num_factors']``; enables the polar importance plots.
        comparison_label: legend label for the dashed comparison bars.
        target_order: order of the targets along the x axis. Defaults to the
            key order of ``predictions``.
        metric: key of the score to plot.
        size: figure width; line widths and font sizes scale with it.
        dpi: resolution used when saving.
        filename: if not None, the figure is saved there and then closed.
    """
    if colors is None:
        colors = [sns.color_palette('Purples_d', 4)[i] for i in [1, 3]]
    if comparison_label is None:
        comparison_label = '95% shuffled prediction'
    basefont = max(size, 5)
    sns.set_style('white')
    if target_order is None:
        target_order = predictions.keys()

    def nan_to_zero(pairs):
        # NaN is the only value that does not compare equal to itself
        return [(name, val) if val == val else (name, 0)
                for name, val in pairs]

    def comparison_cutoffs(score_key):
        # 95th percentile of the comparison distribution for each target
        cutoffs = []
        for target in target_order:
            null_scores = [
                rec[metric] for rec in comparison_predictions[target][score_key]
            ]
            cutoffs.append((target, np.percentile(null_scores, 95)))
        return cutoffs

    # observed scores and comparison cutoffs; all four lists are scrubbed of
    # NaNs so that bar heights and the max() calls below stay well defined
    r2s = nan_to_zero(
        [(k, predictions[k]['scores_cv'][0][metric]) for k in target_order])
    insample_r2s = nan_to_zero(
        [(k, predictions[k]['scores_insample'][0][metric])
         for k in target_order])
    shuffled_r2s = nan_to_zero(comparison_cutoffs('scores_cv'))
    insample_shuffled_r2s = nan_to_zero(comparison_cutoffs('scores_insample'))

    shuffled_grey = [.3, .3, .3]
    figsize = (size, size * .75)
    fig = plt.figure(figsize=figsize)
    # bar chart occupies the lower half of the figure
    ind = np.arange(len(r2s))
    width = .25
    ax1 = fig.add_axes([0, 0, 1, .5])
    ax1.bar(ind, [v for _, v in r2s],
            width,
            label='Cross-validated prediction',
            color=colors[0])
    ax1.bar(ind + width, [v for _, v in insample_r2s],
            width,
            label='Insample prediction',
            color=colors[1])
    # dashed, unfilled outlines marking the comparison cutoffs
    ax1.bar(ind, [v for _, v in shuffled_r2s],
            width,
            color='none',
            edgecolor=shuffled_grey,
            linewidth=size / 10,
            linestyle='--',
            label=comparison_label)
    ax1.bar(ind + width, [v for _, v in insample_shuffled_r2s],
            width,
            color='none',
            edgecolor=shuffled_grey,
            linewidth=size / 10,
            linestyle='--')

    ax1.set_xticks(np.arange(0, len(r2s)) + width / 2)
    ax1.set_xticklabels([name for name, _ in r2s],
                        rotation=15,
                        fontsize=basefont * 1.4)
    ax1.tick_params(axis='y', labelsize=size * 1.2)
    ax1.tick_params(length=size / 4,
                    width=size / 10,
                    pad=size / 2,
                    left=True,
                    bottom=True)
    # x limits are captured here; they position the polar insets later
    xlow, xhigh = ax1.get_xlim()
    ylabel = r'$R^2$' if metric == 'R2' else metric
    ax1.set_ylabel(ylabel, fontsize=basefont * 1.5, labelpad=size * 1.5)
    # add a legend
    leg = ax1.legend(fontsize=basefont * 1.4,
                     loc='upper left',
                     framealpha=1,
                     frameon=True,
                     handlelength=0,
                     handletextpad=0)
    leg.get_frame().set_linewidth(size / 10)
    beautify_legend(leg, colors[:2] + [shuffled_grey])
    # leave headroom above the tallest observed bar
    ylim = ax1.get_ylim()
    r2_max = max(max(v for _, v in r2s), max(v for _, v in insample_r2s))
    ymax = r2_max * 1.5
    ax1.set_ylim(ylim[0], ymax)
    # change yticks
    if ymax < .15:
        ax1.set_ylim(ylim[0], .15)
        ax1.yaxis.set_major_locator(ticker.MultipleLocator(.025))
    else:
        ax1.yaxis.set_major_locator(ticker.MultipleLocator(.05))
        ax1.set_yticks(
            np.append([0, .025, .05, .075, .1, .125],
                      np.arange(.15, .45, .05)))
    # draw grid
    ax1.set_axisbelow(True)
    ax1.grid(axis='y', linestyle='dotted', linewidth=size / 6)
    plt.setp(list(ax1.spines.values()), linewidth=size / 10)

    # Plot Polar Plots for importances
    if EFA is not None:
        reorder_vec = EFA.get_factor_reorder(EFA.results['num_factors'])

        def reorder(seq):
            return [seq[j] for j in reorder_vec]

        importances = [(reorder(predictions[k]['predvars']),
                        reorder(predictions[k]['importances'][0]))
                       for k in target_order]
        axes = []
        N = len(importances)
        # ascending by cross-validated score: the best target is last
        best_predictors = sorted(enumerate(r2s), key=lambda x: x[1][1])
        ylim = ax1.get_ylim()
        yrange = np.sum(np.abs(ylim))
        zero_place = abs(ylim[0]) / yrange
        plot_heights = []
        for i in range(N):
            cv_score = r2s[i][1]
            if not cv_score > max(shuffled_r2s[i][1], 0):
                # no inset for targets that do not beat the comparison
                plot_heights.append(np.nan)
                continue
            tallest = max(cv_score, insample_r2s[i][1], shuffled_r2s[i][1],
                          insample_shuffled_r2s[i][1])
            # convert to figure coordinates (bar axes span the lower half)
            bar_top = int(cv_score > 0) * (tallest / yrange)
            plot_heights.append((bar_top + zero_place + .02) * .5)
        plot_x = (ax1.get_xticks() - xlow) / (xhigh - xlow) - (1 / N / 2)
        for i, importance in enumerate(importances):
            if pd.isnull(plot_heights[i]):
                continue
            axes.append(
                fig.add_axes([plot_x[i], plot_heights[i], 1 / N, 1 / N],
                             projection='polar'))
            visualize_importance(importance,
                                 axes[-1],
                                 yticklabels=False,
                                 xticklabels=False,
                                 label_size=figsize[1] * 1,
                                 color=colors[0],
                                 axes_linewidth=size / 10)

        # plot top 2 predictions, labeled (only one if there is one target)
        top_predictors = best_predictors[::-1][:2]
        if len(top_predictors) == 2 and \
                top_predictors[0][0] > top_predictors[1][0]:
            # keep the labeled plots in the same left/right order as the bars
            locs = [.68, .32]
        else:
            locs = [.32, .68]
        # write abbreviation key for the best target's predictors; factors
        # without an abbreviation are left out instead of crashing on
        # ``None + ':'``
        pad = 0
        if top_predictors:
            best_labels = importances[top_predictors[0][0]][0]
            key_entries = [(label, shortened_factors.get(label, None))
                           for label in best_labels]
            key_entries = [(label, abr) for label, abr in key_entries
                           if abr is not None]
            if key_entries:
                pad = .05
                text_ax = fig.add_axes([.8, .56, .1, .34])
                text_ax.tick_params(labelleft=False,
                                    left=False,
                                    labelbottom=False,
                                    bottom=False)
                for spine in ['top', 'right', 'bottom', 'left']:
                    text_ax.spines[spine].set_visible(False)
                for row, (val, abr) in enumerate(key_entries):
                    y_pos = row / len(key_entries)
                    text_ax.text(0, y_pos, abr + ':', fontsize=size * 1.2)
                    text_ax.text(.5, y_pos, val, fontsize=size * 1.2)

        ratio = figsize[1] / figsize[0]
        for loc, (target_idx, (target_name, _)) in zip(locs, top_predictors):
            axes.append(
                fig.add_axes([loc - .2 * ratio - pad, .56, .3 * ratio, .3],
                             projection='polar'))
            visualize_importance(importances[target_idx],
                                 axes[-1],
                                 yticklabels=False,
                                 xticklabels=True,
                                 label_size=max(figsize[1] * 1.5, 5),
                                 label_scale=.22,
                                 title=target_name,
                                 color=colors[0],
                                 axes_linewidth=size / 10)
    if filename is not None:
        save_figure(fig, filename, {'bbox_inches': 'tight', 'dpi': dpi})
        # close this figure explicitly rather than whichever one is current
        plt.close(fig)
def plot_cross_EFA_retest(all_results, rotate='oblimin', size=4.6, dpi=300,
                          EFA_retest_fun=None, plot_factor_corr=True,
                          annot_heatmap=False, add_patch=False,
                          ext='png', plot_dir=None):
    """Plot test-retest factor correlations, one figure row per result set.

    Each row holds a correlation heatmap (left) and the plot produced by
    ``plot_EFA_change`` (right). The heatmap shows the lower triangle of the
    correlations between the first and second half of the columns of
    ``combined`` (labelled "Test (T1)" and "Retest (T2)"). When
    ``plot_factor_corr`` is set, the upper triangle of the correlations among
    the first half of the columns is overlaid on a slightly offset
    transparent axes.

    Args:
        all_results: mapping name -> results object passed to
            ``EFA_retest_fun``. Names 'survey' and 'task' get dedicated
            colors.
        rotate: rotation name forwarded to ``EFA_retest_fun``; also used as
            a sub-directory when saving.
        size: figure width; fonts and line widths scale with it.
        dpi: resolution used when saving.
        EFA_retest_fun: callable invoked as ``fun(results, rotate=rotate)``
            whose first return value is the ``combined`` DataFrame. Defaults
            to ``calc_EFA_retest``.
        plot_factor_corr: overlay the within-T1 factor correlations.
        annot_heatmap: annotate heatmap cells with their values.
        add_patch: draw a light background rectangle behind each row.
        ext: file extension for the saved figure.
        plot_dir: directory to save into. If None, nothing is saved.
    """
    if EFA_retest_fun is None:
        EFA_retest_fun = calc_EFA_retest
    colors = {'survey': sns.color_palette('Reds_d', 3)[0],
              'task': sns.color_palette('Blues_d', 3)[0]}
    letters = [chr(i).upper() for i in range(ord('a'), ord('z') + 1)]
    num_cols = 2
    num_rows = math.ceil(len(all_results) * 2 / num_cols)
    with sns.axes_style('white'):
        fig, axes = plt.subplots(num_rows, num_cols,
                                 figsize=(size, size / 2 * num_rows * 1.1))
    plt.subplots_adjust(hspace=.35)
    # flat, row-major list of the subplot axes (taken before the colorbar
    # axes is added so that it is not part of the list)
    axes = fig.get_axes()
    cbar_ax = fig.add_axes([.2, .03, .2, .02])
    cmap = sns.diverging_palette(220, 15, n=100, as_cmap=True)
    annot = bool(annot_heatmap)

    def draw_heatmap(data, target_ax, mask, fontsize, with_cbar, **extra):
        # shared heatmap styling; the colorbar always goes into ``cbar_ax``
        if with_cbar:
            extra.update(cbar_ax=cbar_ax,
                         cbar_kws={'orientation': 'horizontal',
                                   'ticks': [-1, 0, 1]})
        else:
            extra.update(cbar=False)
        sns.heatmap(data, square=True, ax=target_ax, vmin=-1, vmax=1,
                    cmap=cmap, annot=annot, mask=mask,
                    annot_kws={'fontsize': fontsize}, **extra)

    for i, (name, results) in enumerate(all_results.items()):
        combined, *_ = EFA_retest_fun(results, rotate=rotate)
        color = list(colors.get(name, [.2, .2, .2])) + [.8]
        # left axes of row i holds the heatmap, right axes the change plot.
        # (The previous offset ``num_rows//2`` only equalled 1 for two or
        # three rows and aliased both plots onto one axes for a single row.)
        ax2 = axes[i * num_cols]
        ax = axes[i * num_cols + 1]
        plot_EFA_change(combined=combined, color_on=color, ax=ax,
                        size=size / 2)
        ax.set_xlabel('PC 1', fontsize=size * 1.8)
        ax.set_ylabel('PC 2', fontsize=size * 1.8)
        # plot corr between test and retest
        num_labels = combined.shape[1] // 2
        full_corr = combined.corr()
        corr = full_corr.iloc[:num_labels, num_labels:].copy()
        # plot factor correlations if flagged
        if plot_factor_corr:
            factor_corr = full_corr.iloc[:num_labels, :num_labels]
            upper_mask = np.triu(factor_corr, 1) == 0
            lower_mask = np.tril(corr) == 0
            corr.iloc[:, :] = np.tril(corr) + np.triu(factor_corr, 1)
        else:
            # nothing is overlaid, so show every test-retest cell (an
            # all-ones mask would hide the entire heatmap)
            lower_mask = np.zeros(corr.shape, dtype=bool)
            factor_corr = None
        annot_fontsize = size / num_labels * 7
        is_last = i == len(all_results) - 1
        draw_heatmap(corr, ax2, lower_mask, annot_fontsize, with_cbar=is_last)
        if is_last:
            cbar_ax.set_xlabel('Pearson Correlation', fontsize=size * 1.5)
            cbar_ax.tick_params(labelsize=size, pad=size / 2, length=0)
        factor_corr_ax = None
        if factor_corr is not None:
            pos1 = ax2.get_position()  # get the original position
            pos2 = [pos1.x0 + 0.01, pos1.y0 + 0.01, pos1.width, pos1.height]
            factor_corr_ax = fig.add_axes(pos2)
            factor_corr_ax.patch.set_alpha(0)
            draw_heatmap(factor_corr, factor_corr_ax, upper_mask,
                         annot_fontsize, with_cbar=True,
                         xticklabels=False, yticklabels=False)

        ax2.set_xticklabels('')
        ax2.set_yticks(np.arange(.5, num_labels + .5))
        ax2.set_yticklabels(combined.columns[:num_labels], rotation=0,
                            va='center')
        ax2.tick_params(axis='y',
                        labelsize=min(size / num_labels / num_rows * 24,
                                      size * 1.6),
                        pad=size / 2, length=0)
        ax2.tick_params(axis='x', length=0, pad=size / 2)
        ax2.set_xlabel('Retest (T2)', fontsize=size * 1.8)
        if factor_corr_ax is not None:
            # the overlay axes only exists when factor correlations are drawn
            factor_corr_ax.set_title('T1 Factor Correlations',
                                     fontsize=size * 1.8, x=.6)
        ax2.set_ylabel('Test (T1)', fontsize=size * 1.8)
        # add text for measurement category
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        ax.text(x=xlim[1] + (xlim[1] - xlim[0]) * 0.05,
                y=ylim[0] + (ylim[1] - ylim[0]) / 2,
                s=name.title(),
                rotation=-90,
                size=size / num_rows * 5,
                fontweight='bold')
        place_letter(ax2, letters.pop(0), fontsize=size * 9 / 4.6)
        place_letter(ax, letters.pop(0), fontsize=size * 9 / 4.6)
        for spine in ax.spines.values():
            spine.set_linewidth(size * .1)
        for spine in ax2.spines.values():
            spine.set_linewidth(size * .1)
        if add_patch:
            # add row patch
            ax2.add_patch(plt.Rectangle([-.6, -.15],
                                        width=3, height=1.31, zorder=-100,
                                        facecolor='#F8F8F8',
                                        edgecolor='white',
                                        transform=ax2.transAxes,
                                        linewidth=1, clip_on=False))

    if plot_dir is not None:
        filename = 'EFA_test_retest'
        if annot_heatmap:
            filename += '_annot'
        save_figure(fig,
                    path.join(plot_dir, rotate, '%s.%s' % (filename, ext)),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        # close this figure explicitly rather than whichever one is current
        plt.close(fig)
예제 #53
0
def plot_clustering_similarity(results,
                               plot_dir=None,
                               verbose=False,
                               ext='png'):
    """Compare all clustering solutions stored in ``results.HCA.results``.

    Two square heatmaps are produced. The first holds the adjusted Rand
    score (upper triangle) and adjusted mutual information (lower triangle)
    between the ``'labels'`` of every pair of solutions. The second holds
    the correlation between the condensed ``'distance_df'`` matrices of
    every pair. Diagonals are left at 0.

    Args:
        results: object whose ``HCA.results`` maps a solution key to a dict
            with ``'distance_df'`` and ``'labels'`` entries. Only the part
            of each key after the last '-' is used as its display name.
        plot_dir: directory to save both figures into (and close them). If
            None, nothing is saved.
        verbose: print the correlation between the Rand and MI scores.
        ext: file extension for the saved figures.
    """
    HCA = results.HCA
    # get all clustering solutions, keyed by the same shortened name that
    # is used to fill the frames below. (Indexing the frames by the full key
    # while writing with the shortened one silently enlarged them whenever a
    # key contained '-'.)
    solutions = [(key.split('-')[-1], solution)
                 for key, solution in HCA.results.items()]
    names = [name for name, _ in solutions]
    n_solutions = len(solutions)
    cluster_similarity = pd.DataFrame(np.zeros((n_solutions, n_solutions)),
                                      index=names,
                                      columns=names)
    distance_similarity = pd.DataFrame(np.zeros((n_solutions, n_solutions)),
                                       index=names,
                                       columns=names)
    for (name1, solution1), (name2, solution2) in combinations(solutions, 2):
        # record similarity of distance_df
        dist_corr = np.corrcoef(squareform(solution1['distance_df']),
                                squareform(solution2['distance_df']))[1, 0]
        distance_similarity.loc[name1, name2] = dist_corr
        distance_similarity.loc[name2, name1] = dist_corr
        # record similarity of clustering of dendrogram: Rand above the
        # diagonal, mutual information below it
        clusters1 = solution1['labels']
        clusters2 = solution2['labels']
        cluster_similarity.loc[name1, name2] = \
            adjusted_rand_score(clusters1, clusters2)
        cluster_similarity.loc[name2, name1] = \
            adjusted_mutual_info_score(clusters1, clusters2)

    with sns.plotting_context(context='notebook', font_scale=1.4):
        clust_fig = plt.figure(figsize=(12, 12))
        sns.heatmap(cluster_similarity, square=True)
        plt.title('Cluster Similarity: TRIL: Adjusted MI, TRIU: Adjusted Rand',
                  y=1.02)

        dist_fig = plt.figure(figsize=(12, 12))
        sns.heatmap(distance_similarity, square=True)
        plt.title('Distance Similarity, metric: %s' % HCA.dist_metric, y=1.02)

    if plot_dir is not None:
        save_figure(
            clust_fig,
            path.join(plot_dir, 'cluster_similarity_across_measures.%s' % ext),
            {'bbox_inches': 'tight'})
        save_figure(
            dist_fig,
            path.join(plot_dir,
                      'distance_similarity_across_measures.%s' % ext),
            {'bbox_inches': 'tight'})
        plt.close(clust_fig)
        plt.close(dist_fig)

    if verbose:
        # assess relationship between two measurements: the upper triangle
        # holds the Rand scores, the transposed upper triangle the MI scores
        upper = np.triu_indices_from(cluster_similarity, k=1)
        rand_scores = cluster_similarity.values[upper]
        MI_scores = cluster_similarity.T.values[upper]
        score_consistency = np.corrcoef(rand_scores, MI_scores)[0, 1]
        print('Correlation between measures of cluster consistency: %.2f' \
              % score_consistency)
def plot_glasso_edge_strength(all_results,
                              graph_loc,
                              size=4.6,
                              dpi=300,
                              ext='png',
                              plot_dir=None):
    """Strip-plot the edge weights of a pickled graph in three groups.

    The graph's dataframe is split at ``all_results['task'].data.shape[1]``:
    the leading block is plotted as "Within Tasks", the trailing block as
    "Within Surveys" and the off-diagonal block as "Between Tasks And
    Surveys".

    Args:
        all_results: mapping with a 'task' entry whose ``data.shape[1]``
            gives the number of leading rows/columns treated as tasks.
        graph_loc: path to a pickled object providing
            ``graph_to_dataframe()``.
        size: figure width; fonts and line widths scale with it.
        dpi: resolution used when saving.
        ext: file extension for the saved figure.
        plot_dir: directory to save into. If None, nothing is saved.

    Returns:
        The figure when ``plot_dir`` is None; otherwise None (the figure is
        saved and closed).
    """
    task_length = all_results['task'].data.shape[1]
    # NOTE: pickle can execute arbitrary code on load; only point graph_loc
    # at trusted files.
    with open(graph_loc, 'rb') as graph_file:
        g = pickle.load(graph_file)
    # subset graph (the dataframe is built once and sliced three times)
    graph_df = g.graph_to_dataframe()
    task_within = squareform(graph_df.iloc[:task_length, :task_length])
    survey_within = squareform(graph_df.iloc[task_length:, task_length:])
    across = graph_df.iloc[:task_length, task_length:].values.flatten()

    titles = ['Within Tasks', 'Within Surveys', 'Between Tasks And Surveys']
    colors = [
        sns.color_palette('Blues_d', 3)[0],
        sns.color_palette('Reds_d', 3)[0], [0, 0, 0]
    ]

    with sns.axes_style('whitegrid'):
        f, axes = plt.subplots(3, 1, figsize=(size, size * 1.5))

    edge_groups = [task_within, survey_within, across]
    for weights, ax, color in zip(edge_groups, axes, colors):
        sns.stripplot(weights,
                      jitter=.2,
                      alpha=.5,
                      orient='h',
                      ax=ax,
                      color=color,
                      s=size / 2)

    # share one x range across the three panels
    max_x = max([ax.get_xlim()[1] for ax in axes]) * 1.1
    last_panel = len(axes) - 1
    for panel, ax in enumerate(axes):
        for spine in ax.spines.values():
            spine.set_linewidth(size * .3)
        ax.grid(linewidth=size * .15)
        ax.set_xlim([0, max_x])
        ax.text(max_x * .02,
                -.35,
                titles[panel],
                color=colors[panel],
                ha='left',
                fontsize=size * 3.5)
        ax.set_xticks(np.arange(0, round(max_x * 10) / 10, .1))
        if panel != last_panel:
            # only the bottom panel keeps its tick labels
            ax.set_xticklabels([])
        else:
            ax.tick_params(labelsize=size * 2.5, pad=size, length=0)
    axes[-1].set_xlabel('Edge Weight', fontsize=size * 5)
    plt.subplots_adjust(hspace=0)
    if plot_dir is not None:
        save_figure(f, path.join(plot_dir, 'glasso_edge_strength.%s' % ext), {
            'dpi': dpi,
            'transparent': True
        })
        # close this figure explicitly rather than whichever one is current
        plt.close(f)
    else:
        return f
def plot_cross_EFA_retest(all_results,
                          rotate='oblimin',
                          size=4.6,
                          dpi=300,
                          EFA_retest_fun=None,
                          plot_factor_corr=True,
                          annot_heatmap=False,
                          add_patch=False,
                          ext='png',
                          plot_dir=None):
    """Plot test-retest factor correlations, one figure row per result set.

    Each row holds a correlation heatmap (left) and the plot produced by
    ``plot_EFA_change`` (right). The heatmap shows the lower triangle of the
    correlations between the first and second half of the columns of
    ``combined`` (labelled "Test (T1)" and "Retest (T2)"). When
    ``plot_factor_corr`` is set, the upper triangle of the correlations among
    the first half of the columns is overlaid on a slightly offset
    transparent axes. Progress information (result-set names and heatmap
    labels) is printed while plotting.

    Args:
        all_results: mapping name -> results object passed to
            ``EFA_retest_fun``. Names 'survey' and 'task' get dedicated
            colors.
        rotate: rotation name forwarded to ``EFA_retest_fun``; also used as
            a sub-directory when saving.
        size: figure width; fonts and line widths scale with it.
        dpi: resolution used when saving.
        EFA_retest_fun: callable invoked as ``fun(results, rotate=rotate)``
            whose first return value is the ``combined`` DataFrame. Defaults
            to ``calc_EFA_retest``.
        plot_factor_corr: overlay the within-T1 factor correlations.
        annot_heatmap: annotate heatmap cells with their values.
        add_patch: draw a light background rectangle behind each row.
        ext: file extension for the saved figure.
        plot_dir: directory to save into. If None, nothing is saved.
    """
    if EFA_retest_fun is None:
        EFA_retest_fun = calc_EFA_retest
    colors = {
        'survey': sns.color_palette('Reds_d', 3)[0],
        'task': sns.color_palette('Blues_d', 3)[0]
    }
    letters = [chr(i).upper() for i in range(ord('a'), ord('z') + 1)]
    num_cols = 2
    num_rows = math.ceil(len(all_results) * 2 / num_cols)
    with sns.axes_style('white'):
        fig, axes = plt.subplots(num_rows,
                                 num_cols,
                                 figsize=(size, size / 2 * num_rows * 1.1))
    plt.subplots_adjust(hspace=.35)
    # flat, row-major list of the subplot axes (taken before the colorbar
    # axes is added so that it is not part of the list)
    axes = fig.get_axes()
    cbar_ax = fig.add_axes([.2, .03, .2, .02])
    cmap = sns.diverging_palette(220, 15, n=100, as_cmap=True)
    annot = bool(annot_heatmap)

    def draw_heatmap(data, target_ax, mask, fontsize, with_cbar, **extra):
        # shared heatmap styling; the colorbar always goes into ``cbar_ax``
        if with_cbar:
            extra.update(cbar_ax=cbar_ax,
                         cbar_kws={
                             'orientation': 'horizontal',
                             'ticks': [-1, 0, 1]
                         })
        else:
            extra.update(cbar=False)
        sns.heatmap(data,
                    square=True,
                    ax=target_ax,
                    vmin=-1,
                    vmax=1,
                    cmap=cmap,
                    annot=annot,
                    mask=mask,
                    annot_kws={'fontsize': fontsize},
                    **extra)

    print('CROSS PLOT')
    print('*' * 79)
    for i, (name, results) in enumerate(all_results.items()):
        print(name)
        combined, *_ = EFA_retest_fun(results, rotate=rotate)
        color = list(colors.get(name, [.2, .2, .2])) + [.8]
        # left axes of row i holds the heatmap, right axes the change plot.
        # (The previous offset ``num_rows // 2`` only equalled 1 for two or
        # three rows and aliased both plots onto one axes for a single row.)
        ax2 = axes[i * num_cols]
        ax = axes[i * num_cols + 1]
        plot_EFA_change(combined=combined,
                        color_on=color,
                        ax=ax,
                        size=size / 2)
        ax.set_xlabel('PC 1', fontsize=size * 1.8)
        ax.set_ylabel('PC 2', fontsize=size * 1.8)
        # plot corr between test and retest
        num_labels = combined.shape[1] // 2
        full_corr = combined.corr()
        corr = full_corr.iloc[:num_labels, num_labels:].copy()
        # plot factor correlations if flagged
        if plot_factor_corr:
            factor_corr = full_corr.iloc[:num_labels, :num_labels]
            upper_mask = np.triu(factor_corr, 1) == 0
            lower_mask = np.tril(corr) == 0
            corr.iloc[:, :] = np.tril(corr) + np.triu(factor_corr, 1)
        else:
            # nothing is overlaid, so show every test-retest cell (an
            # all-ones mask would hide the entire heatmap)
            lower_mask = np.zeros(corr.shape, dtype=bool)
            factor_corr = None
        annot_fontsize = size / num_labels * 7
        is_last = i == len(all_results) - 1
        draw_heatmap(corr, ax2, lower_mask, annot_fontsize, with_cbar=is_last)
        if is_last:
            cbar_ax.set_xlabel('Pearson Correlation', fontsize=size * 1.5)
            cbar_ax.tick_params(labelsize=size, pad=size / 2, length=0)
        factor_corr_ax = None
        if factor_corr is not None:
            pos1 = ax2.get_position()  # get the original position
            pos2 = [pos1.x0 + 0.01, pos1.y0 + 0.01, pos1.width, pos1.height]
            factor_corr_ax = fig.add_axes(pos2)
            factor_corr_ax.patch.set_alpha(0)
            draw_heatmap(factor_corr,
                         factor_corr_ax,
                         upper_mask,
                         annot_fontsize,
                         with_cbar=True,
                         xticklabels=False,
                         yticklabels=False)

        ax2.set_xticklabels('')
        ax2.set_yticks(np.arange(.5, num_labels + .5))
        # tick labels are re-applied from the axes itself rather than taken
        # from combined.columns; the column names are printed for reference
        ax2.set_yticklabels(ax2.get_yticklabels(), rotation=0, va='center')
        print('heatmap labels:')
        print(combined.columns[:num_labels])
        ax2.tick_params(axis='y',
                        labelsize=min(size / num_labels / num_rows * 24,
                                      size * 1.6),
                        pad=size / 2,
                        length=0)
        ax2.tick_params(axis='x', length=0, pad=size / 2)
        ax2.set_xlabel('Retest (T2)', fontsize=size * 1.8)
        if factor_corr_ax is not None:
            # the overlay axes only exists when factor correlations are drawn
            factor_corr_ax.set_title('T1 Factor Correlations',
                                     fontsize=size * 1.8,
                                     x=.6)
        ax2.set_ylabel('Test (T1)', fontsize=size * 1.8)
        # add text for measurement category
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        ax.text(x=xlim[1] + (xlim[1] - xlim[0]) * 0.05,
                y=ylim[0] + (ylim[1] - ylim[0]) / 2,
                s=name.title(),
                rotation=-90,
                size=size / num_rows * 5,
                fontweight='bold')
        print(name.title())
        place_letter(ax2, letters.pop(0), fontsize=size * 9 / 4.6)
        place_letter(ax, letters.pop(0), fontsize=size * 9 / 4.6)
        for spine in ax.spines.values():
            spine.set_linewidth(size * .1)
        for spine in ax2.spines.values():
            spine.set_linewidth(size * .1)
        if add_patch:
            # add row patch
            ax2.add_patch(
                plt.Rectangle([-.6, -.15],
                              width=3,
                              height=1.31,
                              zorder=-100,
                              facecolor='#F8F8F8',
                              edgecolor='white',
                              transform=ax2.transAxes,
                              linewidth=1,
                              clip_on=False))

    if plot_dir is not None:
        filename = 'EFA_test_retest'
        if annot_heatmap:
            filename += '_annot'
        save_figure(fig, path.join(plot_dir, rotate,
                                   '%s.%s' % (filename, ext)), {
                                       'bbox_inches': 'tight',
                                       'dpi': dpi
                                   })
        # close this figure explicitly rather than whichever one is current
        plt.close(fig)
def plot_heatmap_factors(results, c, size=4.6, thresh=75, rotate='oblimin',
                     DA=False, dpi=300, ext='png', plot_dir=None):
    """ Plots factor analytic results as a heatmap of loadings
    
    Variables (rows) are ordered by the groups returned by
    ``get_factor_groups`` and dashed lines separate the groups.
    
    Args:
        results: a dimensional structure results object
        c: the number of components to use
        size: scalar - the width of the figure. The height is twice the width
        thresh: percentile (0-100) of absolute loadings. Loadings at or below
            that percentile are set to 0 and variables left without any
            non-zero loading are dropped. 0 disables thresholding
        rotate: rotation passed to get_loading and reorder_factors
        DA: if True use results.DA instead of results.EFA
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    EFA = results.DA if DA else results.EFA
    loadings = EFA.get_loading(c, rotate=rotate)
    loadings = EFA.reorder_factors(loadings, rotate=rotate)
    grouping = get_factor_groups(loadings)
    flattened_factor_order = [var for _, group in grouping for var in group]
    loadings = loadings.loc[flattened_factor_order]
    # get threshold for loadings
    if thresh > 0:
        thresh_val = np.percentile(abs(loadings).values, thresh)
        print('Thresholding all loadings less than %s' % np.round(thresh_val, 3))
        loadings = loadings.mask(abs(loadings) <= thresh_val, 0)
        # remove variables that don't cross the threshold for any factor.
        # Test for a surviving loading directly: a row mean of 0 would also
        # drop variables whose surviving loadings happen to cancel out
        kept_vars = list(loadings.index[(loadings != 0).any(axis=1)])
        print('%s Variables out of %s are kept after threshold' % (len(kept_vars), loadings.shape[0]))
        loadings = loadings.loc[kept_vars]
        # remove masked variables from grouping
        kept_lookup = set(kept_vars)
        grouping = [[factor, [var for var in group if var in kept_lookup]]
                    for factor, group in grouping]
    # change variable names to make them more readable
    loadings.index = format_variable_names(loadings.index)
    # set up plot variables; the divisor is floored at 1 so that a single
    # remaining variable does not divide by zero
    DV_fontsize = size*2/max(loadings.shape[0]//2, 1)*30
    figsize = (size, size*2)

    f = plt.figure(figsize=figsize)
    # the heatmap gets a fixed width per factor; the colorbar sits beside it
    heatmap_width = .08*loadings.shape[1]
    ax = f.add_axes([0, 0, heatmap_width, 1])
    cbar_ax = f.add_axes([heatmap_width+.02, 0, .04, 1])

    # symmetric color scale around 0
    max_val = abs(loadings).max().max()
    sns.heatmap(loadings, ax=ax, cbar_ax=cbar_ax,
                vmax=max_val, vmin=-max_val,
                cbar_kws={'ticks': [-max_val, -max_val/2, 0, max_val/2, max_val]},
                linecolor='white', linewidth=.01,
                cmap=sns.diverging_palette(220, 15, n=100, as_cmap=True))
    ax.set_yticks(np.arange(.5, loadings.shape[0]+.5, 1))
    ax.set_yticklabels(loadings.index, fontsize=DV_fontsize, rotation=0)
    ax.set_xticklabels(loadings.columns,
                       fontsize=min(size*3, DV_fontsize*1.5),
                       ha='center',
                       rotation=90)
    ax.tick_params(length=size*.5, width=size/10)
    # format cbar; labels mirror the tick positions given to the heatmap
    # (the upper half tick was previously labelled with the negative value)
    cbar_ax.set_yticklabels([format_num(-max_val, 2),
                             format_num(-max_val/2, 2),
                             0,
                             format_num(max_val/2, 2),
                             format_num(max_val, 2)])
    cbar_ax.tick_params(axis='y', length=0)
    cbar_ax.tick_params(labelsize=DV_fontsize*1.5)
    cbar_ax.set_ylabel('Factor Loading', rotation=-90, fontsize=DV_fontsize*2)

    # draw lines separating groups
    if grouping is not None:
        factor_breaks = np.cumsum([len(group) for _, group in grouping])[:-1]
        for y_val in factor_breaks:
            ax.hlines(y_val, 0, loadings.shape[1], lw=size/5,
                      color='grey', linestyle='dashed')

    if plot_dir:
        filename = 'factor_heatmap_EFA%s.%s' % (c, ext)
        save_figure(f, path.join(plot_dir, filename),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        # close this figure explicitly rather than whichever one is current
        plt.close(f)
def plot_cross_communality(all_results, rotate='oblimin', retest_threshold=.2,
                           size=4.6, dpi=300, ext='png', plot_dir=None):
    """ Plots communality and retest-adjusted communality densities
    
    One panel is drawn per entry of all_results on a two-column grid. Each
    panel shows a kernel density of the communalities and of the
    communalities divided by the "pearson" column of the retest data, with
    dashed lines marking the mean of each. All panels share one y-range.
    
    Args:
        all_results: dictionary mapping a panel name to a results object
        rotate: the rotation used to look up the loadings and factor solution
        retest_threshold: retest values below this are replaced with nan
            before adjusting. If falsy, nothing is removed
        size: scaling factor for figure size, font sizes and line widths
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    retest_data = None
    num_cols = 2
    num_rows = math.ceil(len(all_results)/num_cols)
    with sns.axes_style('white'):
        # squeeze=False always returns a 2D array, so flattening gives one
        # axes per panel regardless of how many rows the grid has
        f, axes = plt.subplots(num_rows, num_cols, squeeze=False,
                               figsize=(size, size/2*num_rows))
    axes = axes.flatten()
    colors = sns.color_palette(n_colors=2, desat=.75)
    max_y = 0
    for i, (name, results) in enumerate(all_results.items()):
        if retest_data is None:
            # load retest data
            retest_data = get_retest_data(dataset=results.dataset.replace('Complete','Retest'))
            if retest_data is None:
                print('No retest data found for datafile: %s' % results.dataset)
                # nothing to adjust by: leave this panel empty and retry
                # loading with the next results object
                continue
        EFA = results.EFA
        c = EFA.get_c()
        loading = EFA.get_loading(c, rotate=rotate)
        # get communality from psych out
        fa = EFA.results['factor_tree_Rout_%s' % rotate][c]
        communality = get_attr(fa, 'communalities')
        communality = pd.Series(communality, index=loading.index)
        # alternative calculation
        #communality = (loading**2).sum(1).sort_values()
        communality.index = [var.replace('.logTr','') for var in communality.index]
        
        # reorder data in line with communality
        retest_subset = retest_data.loc[communality.index]
        if len(retest_subset) == 0:
            continue
        # reformat variable names
        communality.index = format_variable_names(communality.index)
        retest_subset.index = format_variable_names(retest_subset.index)
        
        # noise ceiling
        noise_ceiling = retest_subset.pearson
        # remove very low reliabilities
        if retest_threshold:
            # mask builds a new Series rather than assigning into the
            # column of retest_subset
            noise_ceiling = noise_ceiling.mask(noise_ceiling < retest_threshold)
        # adjust
        adjusted_communality = communality/noise_ceiling
        
        # plot communality densities
        ax = axes[i]
        ax.set_title(name.title(), fontweight='bold', fontsize=size*2)
        sns.kdeplot(communality, linewidth=size/4, ax=ax, vertical=True,
                    shade=True, label='Communality', color=colors[0])
        sns.kdeplot(adjusted_communality, linewidth=size/4, ax=ax, vertical=True,
                    shade=True, label='Adjusted Communality', color=colors[1])
        xlim = ax.get_xlim()
        ax.hlines(np.mean(communality), xlim[0], xlim[1],
                  color=colors[0], linewidth=size/4, linestyle='--')
        ax.hlines(np.mean(adjusted_communality), xlim[0], xlim[1],
                  color=colors[1], linewidth=size/4, linestyle='--')
        ax.set_xticks([])
        ax.tick_params(labelsize=size*1.2)
        ax.set_ylim(0, ax.get_ylim()[1])
        ax.set_xlim(0, ax.get_xlim()[1])
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        if (i+1) == len(all_results):
            # the last panel carries the legend
            ax.set_xlabel('Normalized Density', fontsize=size*2)
            leg = ax.legend(fontsize=size*1.5, loc='upper right',
                            bbox_to_anchor=(1.2, 1.0), 
                            handlelength=0, handletextpad=0)
            beautify_legend(leg, colors)
        elif i >= len(all_results)-2:
            ax.set_xlabel('Normalized Density', fontsize=size*2)
            ax.legend().set_visible(False)
        else:
            ax.legend().set_visible(False)
        if i%2 == 0:
            # left column gets the y label and tick labels
            ax.set_ylabel('Communality', fontsize=size*2)
            ax.tick_params(labelleft=True, left=True, 
                           length=size/4, width=size/8)
        else:
            ax.tick_params(labelleft=False, left=True, 
                           length=0, width=size/8)
        # update max_y
        max_y = max(max_y, ax.get_ylim()[1])
        ax.grid(False)
        for spine in ax.spines.values():
            spine.set_linewidth(size*.1)
    
    # Equalise the y-range only after every panel is drawn: an explicit
    # set_ylim turns off autoscaling, so doing this inside the loop would
    # clip later panels to the first panel's range
    if max_y > 0:
        for ax in axes:
            ax.set_ylim((0, max_y))
    plt.subplots_adjust(wspace=0)
    
    # save once, and close this figure specifically
    if plot_dir:
        filename = 'communality_adjustment.%s' % ext
        save_figure(f, path.join(plot_dir, rotate, filename), 
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close(f)
def plot_cross_communality(all_results,
                           rotate='oblimin',
                           retest_threshold=.2,
                           size=4.6,
                           dpi=300,
                           ext='png',
                           plot_dir=None):
    """ Plots communality and retest-adjusted communality densities

    One panel is drawn per entry of all_results on a two-column grid. Each
    panel shows a kernel density of the communalities and of the
    communalities divided by the "pearson" column of the retest data, with
    dashed lines marking the mean of each. All panels share one y-range.

    Args:
        all_results: dictionary mapping a panel name to a results object
        rotate: the rotation used to look up the loadings and factor solution
        retest_threshold: retest values below this are replaced with nan
            before adjusting. If falsy, nothing is removed
        size: scaling factor for figure size, font sizes and line widths
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    retest_data = None
    num_cols = 2
    num_rows = math.ceil(len(all_results) / num_cols)
    with sns.axes_style('white'):
        # squeeze=False always returns a 2D array, so flattening gives one
        # axes per panel regardless of how many rows the grid has
        f, axes = plt.subplots(num_rows,
                               num_cols,
                               squeeze=False,
                               figsize=(size, size / 2 * num_rows))
    axes = axes.flatten()
    colors = sns.color_palette(n_colors=2, desat=.75)
    max_y = 0
    for i, (name, results) in enumerate(all_results.items()):
        if retest_data is None:
            # load retest data
            retest_data = get_retest_data(
                dataset=results.dataset.replace('Complete', 'Retest'))
            if retest_data is None:
                print('No retest data found for datafile: %s' %
                      results.dataset)
                # nothing to adjust by: leave this panel empty and retry
                # loading with the next results object
                continue
        EFA = results.EFA
        c = EFA.get_c()
        loading = EFA.get_loading(c, rotate=rotate)
        # get communality from psych out
        fa = EFA.results['factor_tree_Rout_%s' % rotate][c]
        communality = get_attr(fa, 'communalities')
        communality = pd.Series(communality, index=loading.index)
        # alternative calculation
        #communality = (loading**2).sum(1).sort_values()
        communality.index = [
            var.replace('.logTr', '') for var in communality.index
        ]

        # reorder data in line with communality
        retest_subset = retest_data.loc[communality.index]
        if len(retest_subset) == 0:
            continue
        # reformat variable names
        communality.index = format_variable_names(communality.index)
        retest_subset.index = format_variable_names(retest_subset.index)

        # noise ceiling
        noise_ceiling = retest_subset.pearson
        # remove very low reliabilities
        if retest_threshold:
            # mask builds a new Series rather than assigning into the
            # column of retest_subset
            noise_ceiling = noise_ceiling.mask(
                noise_ceiling < retest_threshold)
        # adjust
        adjusted_communality = communality / noise_ceiling

        # plot communality densities
        ax = axes[i]
        ax.set_title(name.title(), fontweight='bold', fontsize=size * 2)
        sns.kdeplot(communality,
                    linewidth=size / 4,
                    ax=ax,
                    vertical=True,
                    shade=True,
                    label='Communality',
                    color=colors[0])
        sns.kdeplot(adjusted_communality,
                    linewidth=size / 4,
                    ax=ax,
                    vertical=True,
                    shade=True,
                    label='Adjusted Communality',
                    color=colors[1])
        xlim = ax.get_xlim()
        ax.hlines(np.mean(communality),
                  xlim[0],
                  xlim[1],
                  color=colors[0],
                  linewidth=size / 4,
                  linestyle='--')
        ax.hlines(np.mean(adjusted_communality),
                  xlim[0],
                  xlim[1],
                  color=colors[1],
                  linewidth=size / 4,
                  linestyle='--')
        ax.set_xticks([])
        ax.tick_params(labelsize=size * 1.2)
        ax.set_ylim(0, ax.get_ylim()[1])
        ax.set_xlim(0, ax.get_xlim()[1])
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        if (i + 1) == len(all_results):
            # the last panel carries the legend
            ax.set_xlabel('Normalized Density', fontsize=size * 2)
            leg = ax.legend(fontsize=size * 1.5,
                            loc='upper right',
                            bbox_to_anchor=(1.2, 1.0),
                            handlelength=0,
                            handletextpad=0)
            beautify_legend(leg, colors)
        elif i >= len(all_results) - 2:
            ax.set_xlabel('Normalized Density', fontsize=size * 2)
            ax.legend().set_visible(False)
        else:
            ax.legend().set_visible(False)
        if i % 2 == 0:
            # left column gets the y label and tick labels
            ax.set_ylabel('Communality', fontsize=size * 2)
            ax.tick_params(labelleft=True,
                           left=True,
                           length=size / 4,
                           width=size / 8)
        else:
            ax.tick_params(labelleft=False,
                           left=True,
                           length=0,
                           width=size / 8)
        # update max_y
        max_y = max(max_y, ax.get_ylim()[1])
        ax.grid(False)
        for spine in ax.spines.values():
            spine.set_linewidth(size * .1)

    # Equalise the y-range only after every panel is drawn: an explicit
    # set_ylim turns off autoscaling, so doing this inside the loop would
    # clip later panels to the first panel's range
    if max_y > 0:
        for ax in axes:
            ax.set_ylim((0, max_y))
    plt.subplots_adjust(wspace=0)

    # save once, and close this figure specifically
    if plot_dir:
        filename = 'communality_adjustment.%s' % ext
        save_figure(f, path.join(plot_dir, rotate, filename), {
            'bbox_inches': 'tight',
            'dpi': dpi
        })
        plt.close(f)
def plot_EFA_change(combined,
                    ax=None,
                    color_on=False,
                    method=PCA,
                    size=4.6,
                    dpi=300,
                    ext='png',
                    plot_dir=None):
    """ Plots original and retest scores as connected points in a 2D projection

    The columns of combined are split in half: the first half is treated as
    the original scores and the second half as the retest scores. Both
    halves are stacked, projected to two dimensions, and each row is drawn
    as an open marker (original) joined by a line to a filled marker
    (retest).

    Args:
        combined: dataframe with original scores in the first half of its
            columns and retest scores in the second half. The retest columns
            are relabelled with the original column names, so the two halves
            are assumed to be in matching order. The index is expected to
            hold strings ("_retest" is appended to it)
        ax: axes to draw on. If none, a new figure is created
        color_on: if True, color each pair by its position in the original
            projection. If False, use a dark grey. Any other value is
            passed to matplotlib as the color
        method: called as method(2); the result must provide fit_transform
        size: scaling factor for figure size, markers, lines and fonts
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    n = combined.shape[1] // 2
    orig = combined.iloc[:, :n]
    # copy before relabelling so the labels of `combined` are left alone
    retest = combined.iloc[:, n:].copy()
    retest.columns = orig.columns
    retest.index = [label + '_retest' for label in retest.index]
    both = pd.concat([orig, retest])
    projector = method(2)
    projection = projector.fit_transform(both)
    half = both.shape[0] // 2
    orig_projection = projection[:half, :]
    retest_projection = projection[half:, :]

    color = [.2, .2, .2, .9]
    # get color range
    mins = np.min(orig_projection)
    ranges = np.max(orig_projection) - mins
    if ax is None:
        with sns.axes_style('white'):
            fig, ax = plt.subplots(figsize=(size, size))
    else:
        # a caller-supplied axes still needs its figure for saving below;
        # without this `fig` would be undefined when plot_dir is given
        fig = ax.figure
    markersize = size
    markeredge = size / 5
    linewidth = size / 3
    point_pairs = zip(orig_projection, retest_projection)
    for i, (t1, t2) in enumerate(point_pairs):
        # only the first pair is labelled so the legend has one entry each
        label = ['T1 Scores', 'T2 Scores'] if i == 0 else [None, None]
        # equality (not identity) checks are kept on purpose so that the
        # accepted values of color_on are unchanged
        if color_on == True:
            scaled = list((t1 - mins) / ranges)
            color = [scaled[0]] + [0] + [scaled[1]]
        elif color_on != False:
            color = color_on
        # line from the original point to the retest point, open markers
        ax.plot([t1[0], t2[0]], [t1[1], t2[1]],
                marker='o',
                markersize=markersize,
                color=color,
                markeredgewidth=markeredge,
                markerfacecolor='w',
                linewidth=linewidth,
                label=label[0])
        # filled marker on top of the retest end
        ax.plot(t2[0],
                t2[1],
                marker='o',
                markersize=markersize,
                color=color,
                linewidth=linewidth,
                label=label[1])
    ax.tick_params(left=False,
                   bottom=False,
                   labelleft=False,
                   labelbottom=False)
    ax.set_xlabel('PC 1', fontsize=size * 2.5)
    ax.set_ylabel('PC 2', fontsize=size * 2.5)
    # pad the overall range by 10% and use the same limits on both axes
    ax.set_xlim(
        np.min(projection) - abs(np.min(projection)) * .1,
        np.max(projection) + abs(np.max(projection)) * .1)
    ax.set_ylim(ax.get_xlim())
    ax.legend(fontsize=size * 1.9)
    ax.get_legend().get_frame().set_linewidth(linewidth / 2)

    if plot_dir is not None:
        save_figure(fig, path.join(plot_dir,
                                   'EFA_test_retest_sticks.%s' % ext), {
                                       'bbox_inches': 'tight',
                                       'dpi': dpi
                                   })
        plt.close(fig)
def plot_prediction_relevance(results, EFA=True, classifier='ridge', 
                              rotate='oblimin', change=False, size=4.6, 
                              dpi=300, ext='png', plot_dir=None):
    """ Plots the relative relevance of each factor for predicting all outcomes
    
    Absolute importances are min-max scaled within each target, converted to
    proportions that sum to one per target, and shown as one box per
    predictor.
    
    Args:
        results: a results object providing load_prediction_object
        EFA: passed to load_prediction_object
        classifier: passed to load_prediction_object
        rotate: passed to load_prediction_object
        change: passed to load_prediction_object
        size: currently unused; the figure size is fixed at (8, 12)
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    predictions = results.load_prediction_object(EFA=EFA, 
                                                 change=change,
                                                 classifier=classifier,
                                                 rotate=rotate)['data']

    targets = list(predictions.keys())
    predictors = predictions[targets[0]]['predvars']
    importances = np.abs(np.vstack([predictions[k]['importances'] for k in targets]))
    # scale to 0-1 within each target; the transposes put targets in columns
    # for the scaler (assuming sklearn's column-wise MinMaxScaler)
    scaled_importances = MinMaxScaler().fit_transform(importances.T).T
    # make proportion
    scaled_importances = scaled_importances/np.expand_dims(scaled_importances.sum(1),1)
    # convert to dataframe
    scaled_df = pd.DataFrame(scaled_importances, index=targets, columns=predictors)
    melted = scaled_df.melt(var_name='Factor', value_name='Importance')
    # keep the figure handle: sns.boxplot returns an Axes, while the other
    # save_figure calls in this file are given a Figure
    fig = plt.figure(figsize=(8,12))
    sns.boxplot(y='Factor', x='Importance',  data=melted,
                width=.5, ax=fig.gca())
    if plot_dir is not None:
        # honour the requested extension, as the other plotting functions do
        filename = 'prediction_relevance.%s' % ext
        save_figure(fig, path.join(plot_dir, filename), 
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close(fig)