def display_closest_DVs(consensus, n_closest=10):
    """Tabulate, for every variable, its ``n_closest`` most similar variables.

    Similarity is ``1 - distance`` taken from the consensus cluster's
    ``distance_df``. Each cell reads ``"<variable>: <similarity>%"`` (the
    percentage is truncated to an int) and is colour-coded by that value.

    Args:
        consensus: object whose ``get_consensus_cluster()`` returns a mapping
            holding a square distance DataFrame under ``'distance_df'``.
        n_closest: number of neighbours listed per variable. Ranks beyond
            ten get generic ordinal headers ("11th", "12th", ...).

    Returns:
        pandas ``Styler`` for a DataFrame indexed by variable with one column
        per neighbour rank. Rows with fewer than ``n_closest`` other variables
        are padded with empty strings.
    """
    nth = {
        1: "first",
        2: "second",
        3: "third",
        4: "fourth",
        5: "fifth",
        6: "sixth",
        7: "seventh",
        8: "eighth",
        9: "ninth",
        10: "tenth",
    }

    def header(rank):
        # Column label for a neighbour rank; the original plain nth[rank]
        # lookup raised KeyError for n_closest > 10.
        if rank in nth:
            return nth[rank]
        if 10 <= rank % 100 <= 20:
            suffix = 'th'
        else:
            suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(rank % 10, 'th')
        return '%d%s' % (rank, suffix)

    df = consensus.get_consensus_cluster()['distance_df']
    # NOTE(review): these assignments relabel the DataFrame returned by
    # get_consensus_cluster() itself, not a copy. If that object is cached,
    # later callers see formatted labels -- confirm nothing relies on that
    # before switching to a copy.
    df.index = format_variable_names(df.index)
    df.columns = format_variable_names(df.columns)

    rows = []
    for name in df.index:
        # smallest distances first == largest similarities first
        similarity = 1 - df.loc[name].drop(name).sort_values()[:n_closest]
        # Series.items() replaces iteritems(), which pandas 2.0 removed.
        cells = ['%s: %s%%' % (other, int(sim * 100))
                 for other, sim in similarity.items()]
        cells += [''] * (n_closest - len(cells))
        rows.append(cells)
    # Build the string table directly instead of writing strings into a
    # float frame of zeros (an incompatible-dtype write in newer pandas).
    sorted_df = pd.DataFrame(
        rows, index=df.index,
        columns=[header(rank) for rank in range(1, n_closest + 1)])

    def magnify():
        return [
            dict(selector="tr:hover",
                 props=[("border-top", "2pt solid black"),
                        ("border-bottom", "2pt solid black")]),
            dict(selector="th:hover", props=[("font-size", "10pt")]),
            dict(selector="td", props=[('padding', "0em 0em")]),
            dict(selector="tr:hover td:hover",
                 props=[('max-width', '200px'), ('font-weight', 'bold'),
                        ('color', 'black'), ('font-size', '9pt')])
        ]

    cm = sns.diverging_palette(220, 15, n=161)

    def color_cell(val):
        if val == '':
            return ''  # padding cell: leave unstyled
        percent = int(val[val.rindex(': ') + 2:val.rindex('%')])
        # Same +30 offset as before, but clamped: a negative index would
        # silently pick a colour from the opposite end of the palette.
        index = min(max(percent + 30, 0), len(cm) - 1)
        return 'background-color: %s' % to_hex(cm[index])

    styler = sorted_df.style
    styler.applymap(color_cell)
    styler.set_properties(**{'max-width': '100px', 'font-size': '10pt',
                             'border-color': 'white'})
    try:
        # Styler.set_precision was removed in pandas 2.0.
        styler.format(precision=2)
    except TypeError:
        # pandas < 1.3: Styler.format has no ``precision`` keyword.
        styler.set_precision(2)
    styler.set_table_styles(magnify())
    return styler
def display_closest_DVs(consensus, n_closest=10):
    """Tabulate, for every variable, its ``n_closest`` most similar variables.

    Similarity is ``1 - distance`` taken from the consensus cluster's
    ``distance_df``. Each cell reads ``"<variable>: <similarity>%"`` (the
    percentage is truncated to an int) and is colour-coded by that value.

    Args:
        consensus: object whose ``get_consensus_cluster()`` returns a mapping
            holding a square distance DataFrame under ``'distance_df'``.
        n_closest: number of neighbours listed per variable. Ranks beyond
            ten get generic ordinal headers ("11th", "12th", ...).

    Returns:
        pandas ``Styler`` for a DataFrame indexed by variable with one column
        per neighbour rank. Rows with fewer than ``n_closest`` other variables
        are padded with empty strings.
    """
    nth = {
        1: "first",
        2: "second",
        3: "third",
        4: "fourth",
        5: "fifth",
        6: "sixth",
        7: "seventh",
        8: "eighth",
        9: "ninth",
        10: "tenth",
    }

    def header(rank):
        # Column label for a neighbour rank; the original plain nth[rank]
        # lookup raised KeyError for n_closest > 10.
        if rank in nth:
            return nth[rank]
        if 10 <= rank % 100 <= 20:
            suffix = 'th'
        else:
            suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(rank % 10, 'th')
        return '%d%s' % (rank, suffix)

    df = consensus.get_consensus_cluster()['distance_df']
    # NOTE(review): these assignments relabel the DataFrame returned by
    # get_consensus_cluster() itself, not a copy. If that object is cached,
    # later callers see formatted labels -- confirm nothing relies on that
    # before switching to a copy.
    df.index = format_variable_names(df.index)
    df.columns = format_variable_names(df.columns)

    rows = []
    for name in df.index:
        # smallest distances first == largest similarities first
        similarity = 1 - df.loc[name].drop(name).sort_values()[:n_closest]
        # Series.items() replaces iteritems(), which pandas 2.0 removed.
        cells = ['%s: %s%%' % (other, int(sim * 100))
                 for other, sim in similarity.items()]
        cells += [''] * (n_closest - len(cells))
        rows.append(cells)
    # Build the string table directly instead of writing strings into a
    # float frame of zeros (an incompatible-dtype write in newer pandas).
    sorted_df = pd.DataFrame(
        rows, index=df.index,
        columns=[header(rank) for rank in range(1, n_closest + 1)])

    def magnify():
        return [
            dict(selector="tr:hover",
                 props=[("border-top", "2pt solid black"),
                        ("border-bottom", "2pt solid black")]),
            dict(selector="th:hover", props=[("font-size", "10pt")]),
            dict(selector="td", props=[('padding', "0em 0em")]),
            dict(selector="tr:hover td:hover",
                 props=[('max-width', '200px'), ('font-weight', 'bold'),
                        ('color', 'black'), ('font-size', '9pt')])
        ]

    cm = sns.diverging_palette(220, 15, n=161)

    def color_cell(val):
        if val == '':
            return ''  # padding cell: leave unstyled
        percent = int(val[val.rindex(': ') + 2:val.rindex('%')])
        # Same +30 offset as before, but clamped: a negative index would
        # silently pick a colour from the opposite end of the palette.
        index = min(max(percent + 30, 0), len(cm) - 1)
        return 'background-color: %s' % to_hex(cm[index])

    styler = sorted_df.style
    styler.applymap(color_cell)
    styler.set_properties(**{'max-width': '100px', 'font-size': '10pt',
                             'border-color': 'white'})
    try:
        # Styler.set_precision was removed in pandas 2.0.
        styler.format(precision=2)
    except TypeError:
        # pandas < 1.3: Styler.format has no ``precision`` keyword.
        styler.set_precision(2)
    styler.set_table_styles(magnify())
    return styler
def plot_task_factors(results, c, rotate='oblimin',
                      task_sublists=None, normalize_loadings=False,
                      figsize=10, dpi=300, ext='png', plot_dir=None):
    """ Plots task factors as polar plots, one figure per task.

    Tasks are the unique prefixes (text before the first '.') of the
    loading index.

    Args:
        results: a dimensional structure results object
        c: the number of components to use
        rotate: rotation passed to ``EFA.get_loading``
        task_sublists: a dictionary whose keys are labels and whose values
            are collections of task names. Defaults to 'surveys' (task name
            contains 'survey') and 'tasks' (everything else).
        normalize_loadings: if True, divide each variable's loadings by its
            largest absolute loading before plotting
        figsize: scalar - width and height of each figure; also scales labels
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If None, do not save
    """
    EFA = results.EFA
    loadings = EFA.get_loading(c, rotate=rotate)
    # global max |loading|, passed to every plot as its ymax
    max_loading = abs(loadings).max().max()
    tasks = np.unique([var.split('.')[0] for var in loadings.index])

    if task_sublists is None:
        task_sublists = {'surveys': [t for t in tasks if 'survey' in t],
                         'tasks': [t for t in tasks if 'survey' not in t]}

    prefix = 'factor_DVnormdist' if normalize_loadings else 'factor_DVdist'
    for sublist_name, task_sublist in task_sublists.items():
        function_directory = '%s_EFA%s_subset-%s' % (prefix, c, sublist_name)
        for task in task_sublist:
            # plot loading distributions. Each measure is scaled so absolute
            # comparisons are impossible. Only the distributions can be compared
            f, ax = plt.subplots(1, 1, figsize=(figsize, figsize),
                                 subplot_kw={'projection': 'polar'})
            task_loadings = loadings.filter(regex='^%s' % task, axis=0)
            if normalize_loadings:
                task_loadings = (task_loadings.T /
                                 abs(task_loadings).max(1)).T
            # Format names exactly once (previously applied twice, before
            # and after normalization).
            task_loadings.index = format_variable_names(task_loadings.index)
            # NOTE(review): ymax stays the raw global maximum even when rows
            # are normalized to max |loading| == 1; confirm that
            # visualize_task_factors does not clip such values.
            visualize_task_factors(task_loadings, ax, ymax=max_loading,
                                   xticklabels=True, label_size=figsize * 2)
            ax.set_title(task.replace('_', ' '), y=1.14, fontsize=25)

            if plot_dir is not None:
                out_dir = path.join(plot_dir, function_directory)
                makedirs(out_dir, exist_ok=True)
                save_figure(f, path.join(out_dir, '%s.%s' % (task, ext)),
                            {'bbox_inches': 'tight', 'dpi': dpi})
                plt.close()
# Example #4
# 0
def plot_vars(tasks, contrasts, axes=None, xlabel='Value', standardize=False):
    """Strip-plot each task's variables, one task per axis.

    Args:
        tasks: iterable of task names. Columns of ``contrasts`` matching the
            regex ``'^' + task`` are plotted; their names are expected to be
            ``'<task>.<variable>'``.
        contrasts: DataFrame whose columns are the variables to plot.
        axes: sequence of matplotlib Axes, one per task. If None, a vertical
            stack of subplots is created.
        xlabel: label for each x axis.
        standardize: if True, divide each variable by its standard deviation.
    """
    tasks = list(tasks)
    if axes is None:
        # The original indexed ``None`` and crashed; create the axes instead.
        _, axes = plt.subplots(max(len(tasks), 1), 1, squeeze=False)
        axes = axes.ravel()
    for i, task in enumerate(tasks):
        ax = axes[i]
        subset = contrasts.filter(regex='^' + task)
        if subset.shape[1] != 0:
            if standardize:
                subset = subset / subset.std()
            subset.columns = [col.split('.')[1] for col in subset.columns]
            subset.columns = format_variable_names(subset.columns)
            # add mean value to columns
            means = subset.mean()
            subset.columns = [col + ': %s' % format_num(mean)
                              for col, mean in zip(subset.columns, means)]
            subset = subset.melt(var_name='Variable', value_name='Value')

            # At least four colours as before, but one per variable when a
            # task has more (a 4-colour palette broke scatter/stripplot).
            N = len(means)
            colors = sns.hls_palette(max(4, N))
            desat_colors = [sns.desaturate(col, .5) for col in colors]

            sns.stripplot(x='Value',
                          y='Variable',
                          hue='Variable',
                          ax=ax,
                          data=subset,
                          palette=desat_colors,
                          jitter=True,
                          alpha=.75)
            # plot central tendency
            ax.scatter(means,
                       range(N),
                       s=200,
                       c=colors[:N],
                       edgecolors='white',
                       linewidths=2,
                       zorder=3)

            # add legend (some seaborn versions add none for this plot)
            leg = ax.get_legend()
            if leg is not None:
                leg.set_title('')
                beautify_legend(leg, colors=colors, fontsize=14)
            # symmetric x limits around zero
            max_val = subset.Value.abs().max()
            ax.set_xlim(-max_val, max_val)
            ax.set_xlabel(xlabel, fontsize=16)
            ax.set_ylabel('')
            ax.set_yticklabels('')
        ax.set_title(format_variable_names([task])[0].title(), fontsize=20)
    plt.subplots_adjust(hspace=.3)
def plot_factor_df(EFA, rotate='oblimin'):
    """Return a styled table of factor loadings, rows grouped by factor.

    Cells are coloured by loading. Cell text is rendered at 0pt and enlarged
    on hover.

    Args:
        EFA: EFA results object (``get_c``, ``get_loading``, ``reorder_factors``).
        rotate: rotation used for the loadings.

    Returns:
        pandas ``Styler`` over the reordered loading matrix.
    """
    c = EFA.get_c()
    loadings = EFA.get_loading(c, rotate=rotate)
    loadings = EFA.reorder_factors(loadings, rotate=rotate)
    grouping = get_factor_groups(loadings)
    # concatenate each factor group's variables, in grouping order
    flattened_factor_order = [var for entry in grouping for var in entry[1]]
    loadings = loadings.loc[flattened_factor_order]
    loadings.index = format_variable_names(loadings.index)
    loadings.columns = loadings.columns.map(lambda x: str(x).ljust(15))

    # visualization functions
    def magnify():
        return [
            dict(selector="tr:hover",
                 props=[("border-top", "2pt solid black"),
                        ("border-bottom", "2pt solid black")]),
            dict(selector="th:hover", props=[("font-size", "10pt")]),
            dict(selector="td", props=[('padding', "0em 0em")]),
            dict(selector="tr:hover td:hover",
                 props=[('max-width', '200px'), ('font-size', '16pt')])
        ]

    cm = sns.diverging_palette(220, 15, n=200)

    def color_cell(val):
        if pd.isnull(val):
            return ''  # int(NaN) would raise
        # Loadings in [-1, 1) map to palette indices 0..199. Clamp so that
        # |loading| >= 1 neither indexes past the end (IndexError) nor
        # wraps to the opposite colour via a negative index.
        index = min(max(int(val * 100) + 100, 0), len(cm) - 1)
        return 'background-color: %s' % to_hex(cm[index])

    styler = loadings.style
    styler.applymap(color_cell)
    styler.set_properties(**{'max-width': '100px', 'font-size': '0pt',
                             'border-color': 'white'})
    try:
        # Styler.set_precision was removed in pandas 2.0.
        styler.format(precision=2)
    except TypeError:
        # pandas < 1.3: Styler.format has no ``precision`` keyword.
        styler.set_precision(2)
    styler.set_table_styles(magnify())
    return styler
def plot_factor_df(EFA, rotate='oblimin'):
    """Return a styled table of factor loadings, rows grouped by factor.

    Cells are coloured by loading. Cell text is rendered at 0pt and enlarged
    on hover.

    Args:
        EFA: EFA results object (``get_c``, ``get_loading``, ``reorder_factors``).
        rotate: rotation used for the loadings.

    Returns:
        pandas ``Styler`` over the reordered loading matrix.
    """
    c = EFA.get_c()
    loadings = EFA.get_loading(c, rotate=rotate)
    loadings = EFA.reorder_factors(loadings, rotate=rotate)
    grouping = get_factor_groups(loadings)
    # concatenate each factor group's variables, in grouping order
    flattened_factor_order = [var for entry in grouping for var in entry[1]]
    loadings = loadings.loc[flattened_factor_order]
    loadings.index = format_variable_names(loadings.index)
    loadings.columns = loadings.columns.map(lambda x: str(x).ljust(15))

    # visualization functions
    def magnify():
        return [
            dict(selector="tr:hover",
                 props=[("border-top", "2pt solid black"),
                        ("border-bottom", "2pt solid black")]),
            dict(selector="th:hover", props=[("font-size", "10pt")]),
            dict(selector="td", props=[('padding', "0em 0em")]),
            dict(selector="tr:hover td:hover",
                 props=[('max-width', '200px'), ('font-size', '16pt')])
        ]

    cm = sns.diverging_palette(220, 15, n=200)

    def color_cell(val):
        if pd.isnull(val):
            return ''  # int(NaN) would raise
        # Loadings in [-1, 1) map to palette indices 0..199. Clamp so that
        # |loading| >= 1 neither indexes past the end (IndexError) nor
        # wraps to the opposite colour via a negative index.
        index = min(max(int(val * 100) + 100, 0), len(cm) - 1)
        return 'background-color: %s' % to_hex(cm[index])

    styler = loadings.style
    styler.applymap(color_cell)
    styler.set_properties(**{'max-width': '100px', 'font-size': '0pt',
                             'border-color': 'white'})
    try:
        # Styler.set_precision was removed in pandas 2.0.
        styler.format(precision=2)
    except TypeError:
        # pandas < 1.3: Styler.format has no ``precision`` keyword.
        styler.set_precision(2)
    styler.set_table_styles(magnify())
    return styler
def plot_communality(results, c, rotate='oblimin', retest_threshold=.2,
                     size=4.6, dpi=300, ext='png', plot_dir=None):
    """Plot EFA communalities, alongside test-retest data when available.

    Draws a bar figure: communality, test-retest noise ceiling and adjusted
    communality, or communality alone when no retest rows match. When retest
    rows match, it also draws a KDE figure comparing communality with
    adjusted communality.

    Args:
        results: a dimensional structure results object (uses ``EFA`` and
            ``dataset``)
        c: the number of components to use
        rotate: rotation passed to ``get_communality``
        retest_threshold: passed to ``get_adjusted_communality``
        size: scalar controlling figure and font sizes
        dpi: the final dpi for saved images
        ext: the extension for saved figures
        plot_dir: the directory to save figures. If falsy, do not save
    """
    EFA = results.EFA
    communality = get_communality(EFA, rotate, c)
    # load retest data
    retest_data = get_retest_data(
        dataset=results.dataset.replace('Complete', 'Retest'))
    if retest_data is None:
        print('No retest data found for datafile: %s' % results.dataset)
        return

    # Reorder retest data in line with communality. ``.loc`` raises KeyError
    # when none of the labels exist, which made the communality-only branch
    # below unreachable; treat "no overlap" as "no retest data" instead.
    if communality.index.isin(retest_data.index).any():
        retest_data = retest_data.loc[communality.index]
    else:
        retest_data = retest_data.iloc[0:0]
    # reformat variable names
    communality.index = format_variable_names(communality.index)
    retest_data.index = format_variable_names(retest_data.index)
    has_retest = len(retest_data) > 0

    # plot communality bars
    if has_retest:
        adjusted_communality, correlation, noise_ceiling = \
            get_adjusted_communality(communality, retest_data,
                                     retest_threshold)
        f, axes = plt.subplots(1, 3, figsize=(3 * (size / 10), size))
        panels = [(communality, True, 'Communality'),
                  (noise_ceiling, False, 'Test-Retest'),
                  (adjusted_communality, False, 'Adjusted Communality')]
        for ax, (values, label_rows, title) in zip(axes, panels):
            plot_bar_factor(values, ax, width=size / 10, height=size,
                            label_rows=label_rows, title=title)
    else:
        f = plot_bar_factor(communality, label_rows=True,
                            width=size / 3, height=size * 2,
                            title='Communality')
    if plot_dir:
        filename = 'communality_bars-EFA%s.%s' % (c, ext)
        save_figure(f, path.join(plot_dir, filename),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()

    if not has_retest:
        return

    # plot communality histogram
    with sns.axes_style('white'):
        colors = sns.color_palette(n_colors=2, desat=.75)
        f, ax = plt.subplots(1, 1, figsize=(size, size))
        sns.kdeplot(communality, ax=ax, linewidth=size / 4,
                    shade=True, label='Communality', color=colors[0])
        sns.kdeplot(adjusted_communality, ax=ax, linewidth=size / 4,
                    shade=True, label='Adjusted Communality', color=colors[1])
        ylim = ax.get_ylim()
        ax.vlines(np.mean(communality), ylim[0], ylim[1],
                  color=colors[0], linewidth=size / 4, linestyle='--')
        ax.vlines(np.mean(adjusted_communality), ylim[0], ylim[1],
                  color=colors[1], linewidth=size / 4, linestyle='--')
        leg = ax.legend(fontsize=size * 2, loc='upper right')
        beautify_legend(leg, colors)
        plt.xlabel('Communality', fontsize=size * 2)
        plt.ylabel('Normalized Density', fontsize=size * 2)
        ax.set_yticks([])
        ax.tick_params(labelsize=size)
        ax.set_ylim(0, ax.get_ylim()[1])
        ax.set_xlim(0, ax.get_xlim()[1])
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        # add correlation
        correlation = format_num(np.mean(correlation))
        ax.text(1.1, 1.25,
                'Correlation Between Communality \nand Test-Retest: %s' % correlation,
                size=size * 2)

    if plot_dir:
        filename = 'communality_dist-EFA%s.%s' % (c, ext)
        save_figure(f, path.join(plot_dir, filename),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
def plot_heatmap_factors(results, c, size=4.6, thresh=75, rotate='oblimin',
                         DA=False, dpi=300, ext='png', plot_dir=None):
    """ Plots factor analytic results as a heatmap, variables grouped by factor

    Args:
        results: a dimensional structure results object
        c: the number of components to use
        size: scalar - the width of the plot; the height is twice the width
        thresh: percentile (0-100) of absolute loadings that are zeroed out;
            variables with no remaining nonzero loading are dropped.
            0 disables thresholding
        rotate: rotation passed to ``get_loading`` / ``reorder_factors``
        DA: if True, plot ``results.DA`` instead of ``results.EFA``
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    EFA = results.DA if DA else results.EFA
    loadings = EFA.get_loading(c, rotate=rotate)
    loadings = EFA.reorder_factors(loadings, rotate=rotate)
    grouping = get_factor_groups(loadings)
    # concatenate each factor group's variables, in grouping order
    flattened_factor_order = [var for entry in grouping for var in entry[1]]
    loadings = loadings.loc[flattened_factor_order]
    # get threshold for loadings
    if thresh > 0:
        thresh_val = np.percentile(abs(loadings).values, thresh)
        print('Thresholding all loadings less than %s' % np.round(thresh_val, 3))
        loadings = loadings.mask(abs(loadings) <= thresh_val, 0)
        # Keep variables with at least one surviving loading. The previous
        # ``mean(1) != 0`` test also dropped rows whose surviving loadings
        # cancelled out (e.g. +0.5 and -0.5).
        kept_vars = list(loadings.index[(loadings != 0).any(axis=1)])
        print('%s Variables out of %s are kept after threshold'
              % (len(kept_vars), loadings.shape[0]))
        loadings = loadings.loc[kept_vars]
        # remove masked variables from grouping
        kept = set(kept_vars)
        grouping = [[factor, [x for x in group if x in kept]]
                    for factor, group in grouping]
    # change variable names to make them more readable
    loadings.index = format_variable_names(loadings.index)
    # set up plot variables; max(..., 1) avoids ZeroDivisionError when fewer
    # than two variables survive thresholding
    n_vars, n_factors = loadings.shape
    DV_fontsize = size * 2 / max(n_vars // 2, 1) * 30
    figsize = (size, size * 2)

    f = plt.figure(figsize=figsize)
    ax = f.add_axes([0, 0, .08 * n_factors, 1])
    cbar_ax = f.add_axes([.08 * n_factors + .02, 0, .04, 1])

    max_val = abs(loadings).max().max()
    sns.heatmap(loadings, ax=ax, cbar_ax=cbar_ax,
                vmax=max_val, vmin=-max_val,
                cbar_kws={'ticks': [-max_val, -max_val / 2, 0,
                                    max_val / 2, max_val]},
                linecolor='white', linewidth=.01,
                cmap=sns.diverging_palette(220, 15, n=100, as_cmap=True))
    ax.set_yticks(np.arange(.5, n_vars + .5, 1))
    ax.set_yticklabels(loadings.index, fontsize=DV_fontsize, rotation=0)
    ax.set_xticklabels(loadings.columns,
                       fontsize=min(size * 3, DV_fontsize * 1.5),
                       ha='center',
                       rotation=90)
    ax.tick_params(length=size * .5, width=size / 10)
    # format cbar: labels match the five ticks set above (the fourth label
    # was previously -max_val/2 instead of +max_val/2)
    cbar_ax.set_yticklabels([format_num(-max_val, 2),
                             format_num(-max_val / 2, 2),
                             0,
                             format_num(max_val / 2, 2),
                             format_num(max_val, 2)])
    cbar_ax.tick_params(axis='y', length=0)
    cbar_ax.tick_params(labelsize=DV_fontsize * 1.5)
    cbar_ax.set_ylabel('Factor Loading', rotation=-90, fontsize=DV_fontsize * 2)

    # draw lines separating groups
    if grouping is not None:
        factor_breaks = np.cumsum([len(entry[1]) for entry in grouping])[:-1]
        for y_val in factor_breaks:
            ax.hlines(y_val, 0, n_factors, lw=size / 5,
                      color='grey', linestyle='dashed')

    if plot_dir:
        # NOTE(review): the filename says EFA even when DA=True, so a DA
        # heatmap overwrites the EFA one -- confirm whether that is intended.
        filename = 'factor_heatmap_EFA%s.%s' % (c, ext)
        save_figure(f, path.join(plot_dir, filename),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
def plot_bar_factors(results, c, size=4.6, thresh=75, rotate='oblimin',
                     bar_kws=None, dpi=300, ext='png', plot_dir=None):
    """ Plots factor analytic results as bars, one subplot per factor

    Args:
        results: a dimensional structure results object
        c: the number of components to use
        size: scalar - the width of the plot; the height is twice the width
        thresh: percentile (0-100) of absolute loadings that are zeroed out;
            variables with no remaining nonzero loading are dropped.
            0 disables thresholding
        rotate: rotation used for loadings, factor ordering and bootstrap CIs
        bar_kws: extra keyword arguments forwarded to ``plot_bar_factor``
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    EFA = results.EFA
    # rotate passed to reorder_factors too, as every other caller does
    loadings = EFA.reorder_factors(EFA.get_loading(c, rotate=rotate),
                                   rotate=rotate)
    grouping = get_factor_groups(loadings)
    # concatenate each factor group's variables, in grouping order
    flattened_factor_order = [var for entry in grouping for var in entry[1]]
    loadings = loadings.loc[flattened_factor_order]
    # bootstrap CI: 1.96 * bootstrap standard deviation
    bootstrap_CI = EFA.get_boot_stats(c, rotate=rotate)
    if bootstrap_CI is not None:
        bootstrap_CI = bootstrap_CI['sds'] * 1.96
        bootstrap_CI = bootstrap_CI.loc[flattened_factor_order]
    # get threshold for loadings
    if thresh > 0:
        thresh_val = np.percentile(abs(loadings).values, thresh)
        print('Thresholding all loadings less than %s' % np.round(thresh_val, 3))
        loadings = loadings.mask(abs(loadings) <= thresh_val, 0)
        # Keep variables with at least one surviving loading. The previous
        # ``mean(1) != 0`` test also dropped rows whose surviving loadings
        # cancelled out (e.g. +0.5 and -0.5).
        kept_vars = list(loadings.index[(loadings != 0).any(axis=1)])
        print('%s Variables out of %s are kept after threshold'
              % (len(kept_vars), loadings.shape[0]))
        loadings = loadings.loc[kept_vars]
        if bootstrap_CI is not None:
            # subset first so the mask aligns row-for-row with loadings
            bootstrap_CI = bootstrap_CI.loc[kept_vars]
            bootstrap_CI = bootstrap_CI.mask(abs(loadings) <= thresh_val, 0)
        # remove masked variables from grouping
        kept = set(kept_vars)
        grouping = [[factor, [x for x in group if x in kept]]
                    for factor, group in grouping]
    # change variable names to make them more readable
    loadings.index = format_variable_names(loadings.index)
    if bootstrap_CI is not None:
        bootstrap_CI.index = format_variable_names(bootstrap_CI.index)
    # plot; squeeze=False keeps ``axes`` indexable with a single factor
    n_factors = len(loadings.columns)
    f, axes = plt.subplots(1, n_factors, figsize=(size, size * 2),
                           squeeze=False)
    axes = axes[0]
    if bar_kws is None:
        bar_kws = {}
    for i, k in enumerate(loadings.columns):
        bootstrap_err = bootstrap_CI[k] if bootstrap_CI is not None else None
        plot_bar_factor(loadings[k],
                        axes[i],
                        bootstrap_err,
                        width=size / n_factors,
                        height=size * 2,
                        grouping=grouping,
                        label_rows=(i == 0),  # only first panel labels rows
                        title=k,
                        **bar_kws)
    if plot_dir:
        filename = 'factor_bars_EFA%s.%s' % (c, ext)
        save_figure(f, path.join(plot_dir, filename),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
def plot_bar_factors(results,
                     c,
                     size=4.6,
                     thresh=75,
                     rotate='oblimin',
                     bar_kws=None,
                     dpi=300,
                     ext='png',
                     plot_dir=None):
    """ Plots factor analytic results as bars, one subplot per factor

    Args:
        results: a dimensional structure results object
        c: the number of components to use
        size: scalar - the width of the plot; the height is twice the width
        thresh: percentile (0-100) of absolute loadings that are zeroed out;
            variables with no remaining nonzero loading are dropped.
            0 disables thresholding
        rotate: rotation used for loadings, factor ordering and bootstrap CIs
        bar_kws: extra keyword arguments forwarded to ``plot_bar_factor``
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    EFA = results.EFA
    # rotate passed to reorder_factors too, as every other caller does
    loadings = EFA.reorder_factors(EFA.get_loading(c, rotate=rotate),
                                   rotate=rotate)
    grouping = get_factor_groups(loadings)
    # concatenate each factor group's variables, in grouping order
    flattened_factor_order = [var for entry in grouping for var in entry[1]]
    loadings = loadings.loc[flattened_factor_order]
    # bootstrap CI: 1.96 * bootstrap standard deviation
    bootstrap_CI = EFA.get_boot_stats(c, rotate=rotate)
    if bootstrap_CI is not None:
        bootstrap_CI = bootstrap_CI['sds'] * 1.96
        bootstrap_CI = bootstrap_CI.loc[flattened_factor_order]
    # get threshold for loadings
    if thresh > 0:
        thresh_val = np.percentile(abs(loadings).values, thresh)
        print('Thresholding all loadings less than %s' %
              np.round(thresh_val, 3))
        loadings = loadings.mask(abs(loadings) <= thresh_val, 0)
        # Keep variables with at least one surviving loading. The previous
        # ``mean(1) != 0`` test also dropped rows whose surviving loadings
        # cancelled out (e.g. +0.5 and -0.5).
        kept_vars = list(loadings.index[(loadings != 0).any(axis=1)])
        print('%s Variables out of %s are kept after threshold' %
              (len(kept_vars), loadings.shape[0]))
        loadings = loadings.loc[kept_vars]
        if bootstrap_CI is not None:
            # subset first so the mask aligns row-for-row with loadings
            bootstrap_CI = bootstrap_CI.loc[kept_vars]
            bootstrap_CI = bootstrap_CI.mask(abs(loadings) <= thresh_val, 0)
        # remove masked variables from grouping
        kept = set(kept_vars)
        grouping = [[factor, [x for x in group if x in kept]]
                    for factor, group in grouping]
    # change variable names to make them more readable
    loadings.index = format_variable_names(loadings.index)
    if bootstrap_CI is not None:
        bootstrap_CI.index = format_variable_names(bootstrap_CI.index)
    # plot; squeeze=False keeps ``axes`` indexable with a single factor
    n_factors = len(loadings.columns)
    f, axes = plt.subplots(1, n_factors, figsize=(size, size * 2),
                           squeeze=False)
    axes = axes[0]
    if bar_kws is None:
        bar_kws = {}
    for i, k in enumerate(loadings.columns):
        bootstrap_err = bootstrap_CI[k] if bootstrap_CI is not None else None
        plot_bar_factor(loadings[k],
                        axes[i],
                        bootstrap_err,
                        width=size / n_factors,
                        height=size * 2,
                        grouping=grouping,
                        label_rows=(i == 0),  # only first panel labels rows
                        title=k,
                        **bar_kws)
    if plot_dir:
        filename = 'factor_bars_EFA%s.%s' % (c, ext)
        save_figure(f, path.join(plot_dir, filename), {
            'bbox_inches': 'tight',
            'dpi': dpi
        })
        plt.close()
def plot_subbranches(results, rotate='oblimin', EFA_clustering=True,
                     cluster_range=None, absolute_loading=False,
                     size=2.3, dpi=300, ext='png', plot_dir=None):
    """ Plots one HCA dendrogram sub-branch per cluster, with loadings underneath

    Args:
        results: results object exposing ``HCA`` and ``EFA``
        rotate: rotation used for the loadings and, when ``EFA_clustering``
            is True, to select the HCA result keyed ``'EFA<c>_<rotate>'``
        EFA_clustering: if True use the HCA result keyed
            ``'EFA<c>_<rotate>'`` (c from ``EFA.get_c()``), otherwise the one
            keyed ``'data'``
        cluster_range: iterable of cluster indices to plot; defaults to all
        absolute_loading: if True, plot absolute loadings
        size: passed through to ``plot_subbranch``
        dpi: currently unused
        ext: the extension for saved figures
        plot_dir: if set, each figure is saved to
            ``<plot_dir>/subbranches_input-<inp>/cluster_<NN>.<ext>``

    Returns:
        list of the truthy values returned by ``plot_subbranch``
    """
    HCA = results.HCA
    EFA = results.EFA
    loading = EFA.reorder_factors(EFA.get_loading(rotate=rotate), rotate=rotate)
    loading.index = format_variable_names(loading.index)
    inp = 'EFA%s_%s' % (EFA.get_c(), rotate) if EFA_clustering else 'data'
    clustering = HCA.results[inp]

    # extract cluster vars, ordering loadings like the clustered columns
    link = clustering['linkage']
    labels = format_variable_names(clustering['clustered_df'].columns)
    ordered_loading = loading.loc[labels]
    if absolute_loading:
        ordered_loading = abs(ordered_loading)
    # get cluster sizes
    cluster_labels, DVs = list(zip(*HCA.get_cluster_DVs(inp=inp).items()))
    cluster_sizes = [len(members) for members in DVs]
    link_function, colors = get_dendrogram_color_fun(link,
                                                     clustering['reorder_vec'],
                                                     clustering['labels'])
    tree = dendrogram(link, link_color_func=link_function, no_plot=True,
                      no_labels=True)

    # Create the output directory under the same condition used for saving
    # (previously ``is not None`` here but truthiness below).
    out_dir = None
    if plot_dir:
        out_dir = path.join(plot_dir, 'subbranches_input-%s' % inp)
        makedirs(out_dir, exist_ok=True)

    if cluster_range is None:
        cluster_range = range(len(cluster_labels))
    figs = []
    for cluster_i in cluster_range:
        plot_loc = None
        if out_dir:
            filey = 'cluster_%s.%s' % (str(cluster_i).zfill(2), ext)
            plot_loc = path.join(out_dir, filey)
        fig = plot_subbranch(colors[cluster_i], cluster_i, tree,
                             ordered_loading, cluster_sizes,
                             title=cluster_labels[cluster_i],
                             size=size, plot_loc=plot_loc)
        if fig:
            figs.append(fig)
    return figs
def display_cluster_DVs(consensus, results):
    """Tabulate each variable's similarity to the other members of its cluster.

    Clusters come from ``results.HCA.get_cluster_DVs`` for the input
    ``'EFA<c>_oblimin'``, with ``c = results.EFA.get_c()``. Similarity is
    ``1 - distance`` from the consensus cluster's ``distance_df``. Each cell
    reads ``"<variable>: <similarity>%"`` (int-truncated), sorted most
    similar first and colour-coded by that value.

    Args:
        consensus: object whose ``get_consensus_cluster()`` returns a mapping
            holding a square distance DataFrame under ``'distance_df'``.
        results: results object exposing ``EFA`` and ``HCA``.

    Returns:
        pandas ``Styler`` over a table indexed like ``distance_df``. It has
        at least 20 columns, more if a cluster needs them; unused cells are
        empty strings.
    """
    nth = {
        1: "first",
        2: "second",
        3: "third",
        4: "fourth",
        5: "fifth",
        6: "sixth",
        7: "seventh",
        8: "eighth",
        9: "ninth",
        10: "tenth",
        11: "11th",
        12: "12th",
        13: "13th",
        14: "14th",
        15: "15th",
        16: "16th",
        17: "17th",
        18: "18th",
        19: "19th",
        20: "20th",
    }

    def header(rank):
        # Ordinal column label (previously ``nth`` was defined but unused);
        # generic ordinals past 20.
        if rank in nth:
            return nth[rank]
        if 10 <= rank % 100 <= 20:
            suffix = 'th'
        else:
            suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(rank % 10, 'th')
        return '%d%s' % (rank, suffix)

    c = results.EFA.get_c()
    cluster_DVs = results.HCA.get_cluster_DVs(inp='EFA%s_oblimin' % c)
    df = consensus.get_consensus_cluster()['distance_df']
    # Format every cluster's members once instead of once per row.
    formatted_clusters = [format_variable_names(v) for v in cluster_DVs.values()]
    # Originally fixed at 20 columns, which raised for clusters > 21 members.
    width = max([20] + [len(members) - 1 for members in formatted_clusters])

    rows = []
    for name in df.index:
        neighbors = [members for members in formatted_clusters
                     if name in members][0]
        similarity = 1 - df.loc[name, neighbors].drop(name).sort_values()
        # Series.items() replaces iteritems(), which pandas 2.0 removed.
        cells = ['%s: %s%%' % (other, int(sim * 100))
                 for other, sim in similarity.items()]
        cells += [''] * (width - len(cells))
        rows.append(cells)
    sorted_df = pd.DataFrame(
        rows, index=df.index,
        columns=[header(rank) for rank in range(1, width + 1)])

    def magnify():
        return [
            dict(selector="tr:hover",
                 props=[("border-top", "2pt solid black"),
                        ("border-bottom", "2pt solid black")]),
            dict(selector="th:hover", props=[("font-size", "10pt")]),
            dict(selector="td", props=[('padding', "0em 0em")]),
            dict(selector="tr:hover td:hover",
                 props=[('max-width', '200px'), ('font-weight', 'bold'),
                        ('color', 'black'), ('font-size', '9pt')])
        ]

    cm = sns.diverging_palette(220, 15, n=161)

    def color_cell(val):
        if val == '':
            return 'background-color: None'
        num = int(val[val.rindex(': ') + 2:val.rindex('%')])
        # Same +30 offset as before, clamped so a negative index cannot wrap
        # to the opposite end of the palette.
        index = min(max(num + 30, 0), len(cm) - 1)
        return 'background-color: %s' % to_hex(cm[index])

    styler = sorted_df.style
    styler.applymap(color_cell)
    styler.set_properties(**{'max-width': '100px', 'font-size': '10pt',
                             'border-color': 'white'})
    try:
        # Styler.set_precision was removed in pandas 2.0.
        styler.format(precision=2)
    except TypeError:
        # pandas < 1.3: Styler.format has no ``precision`` keyword.
        styler.set_precision(2)
    styler.set_table_styles(magnify())
    return styler
def plot_task_factors(results,
                      c,
                      rotate='oblimin',
                      task_sublists=None,
                      normalize_loadings=False,
                      figsize=10,
                      dpi=300,
                      ext='png',
                      plot_dir=None):
    """ Plots task factors as polar plots

    One polar figure is drawn per task, showing the factor loadings of every
    variable belonging to that task.

    Args:
        results: a dimensional structure results object
        c: the number of components to use
        rotate: the rotation of the EFA solution whose loadings are plotted
        task_sublists: a dictionary whose values are sets of tasks, and
                        whose keywords are labels for those lists. If None,
                        tasks are split into 'surveys' (name contains
                        'survey') and 'tasks' (everything else)
        normalize_loadings: if True, each variable's loadings are divided
                        by that variable's largest absolute loading
        figsize: scalar - a width multiplier for the plot
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
                  (the figures are then left open)
    """
    EFA = results.EFA
    # plot task factor loading
    loadings = EFA.get_loading(c, rotate=rotate)
    # one shared radial limit so the plots of different tasks are comparable
    max_loading = abs(loadings).max().max()
    # a variable's task is the part of its name before the first '.'
    var_tasks = np.array([i.split('.')[0] for i in loadings.index])
    tasks = np.unique(var_tasks)

    if task_sublists is None:
        task_sublists = {
            'surveys': [t for t in tasks if 'survey' in t],
            'tasks': [t for t in tasks if 'survey' not in t]
        }

    dist_kind = 'DVnormdist' if normalize_loadings else 'DVdist'
    for sublist_name, task_sublist in task_sublists.items():
        function_directory = 'factor_%s_EFA%s_subset-%s' % (dist_kind, c,
                                                             sublist_name)
        for task in task_sublist:
            # plot loading distributions. Each measure is scaled so absolute
            # comparisons are impossible. Only the distributions can be compared
            f, ax = plt.subplots(1,
                                 1,
                                 figsize=(figsize, figsize),
                                 subplot_kw={'projection': 'polar'})
            # exact task match (same rule used to build `tasks`), so tasks
            # sharing a name prefix are not mixed together
            task_loadings = loadings.loc[var_tasks == task].copy()
            if normalize_loadings:
                task_loadings = task_loadings.div(
                    abs(task_loadings).max(axis=1), axis=0)
            # format variable names (once, after any normalization)
            task_loadings.index = format_variable_names(task_loadings.index)
            # plot
            visualize_task_factors(task_loadings,
                                   ax,
                                   ymax=max_loading,
                                   xticklabels=True,
                                   label_size=figsize * 2)
            ax.set_title(' '.join(task.split('_')), y=1.14, fontsize=25)

            if plot_dir is not None:
                makedirs(path.join(plot_dir, function_directory),
                         exist_ok=True)
                filename = '%s.%s' % (task, ext)
                save_figure(f, path.join(plot_dir, function_directory,
                                         filename), {
                                             'bbox_inches': 'tight',
                                             'dpi': dpi
                                         })
                plt.close(f)
def plot_heatmap_factors(results,
                         c,
                         size=4.6,
                         thresh=75,
                         rotate='oblimin',
                         DA=False,
                         dpi=300,
                         ext='png',
                         plot_dir=None):
    """ Plots factor analytic results as a heatmap of loadings

    Variables are ordered group by group (see get_factor_groups), and
    dashed lines separate the groups.

    Args:
        results: a dimensional structure results object
        c: the number of components to use
        size: scalar - the width of the plot. The height is twice the width
        thresh: percentile (0-100) of the absolute loadings. Loadings at or
            below this percentile are set to 0, and variables left with no
            non-zero loading are dropped. Values <= 0 disable thresholding
        rotate: the rotation of the factor solution
        DA: if True, plot results.DA instead of results.EFA
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    EFA = results.DA if DA else results.EFA
    loadings = EFA.get_loading(c, rotate=rotate)
    loadings = EFA.reorder_factors(loadings, rotate=rotate)
    grouping = get_factor_groups(loadings)
    # order the variables group by group
    flattened_factor_order = [var for entry in grouping for var in entry[1]]
    loadings = loadings.loc[flattened_factor_order]
    # get threshold for loadings
    if thresh > 0:
        thresh_val = np.percentile(abs(loadings).values, thresh)
        print('Thresholding all loadings less than %s' %
              np.round(thresh_val, 3))
        loadings = loadings.mask(abs(loadings) <= thresh_val, 0)
        # keep variables with at least one surviving loading (testing the row
        # mean against 0 would also drop rows whose loadings cancel out)
        kept_vars = list(loadings.index[(loadings != 0).any(axis=1)])
        print('%s Variables out of %s are kept after threshold' %
              (len(kept_vars), loadings.shape[0]))
        loadings = loadings.loc[kept_vars]
        # remove masked variables from grouping
        kept = set(kept_vars)
        grouping = [[factor, [x for x in group if x in kept]]
                    for factor, group in grouping]
    # change variable names to make them more readable
    loadings.index = format_variable_names(loadings.index)
    n_vars, n_factors = loadings.shape
    # set up plot variables; max(..., 1) avoids dividing by zero when fewer
    # than two variables remain after thresholding
    DV_fontsize = size * 2 / max(n_vars // 2, 1) * 30
    figsize = (size, size * 2)

    f = plt.figure(figsize=figsize)
    ax = f.add_axes([0, 0, .08 * n_factors, 1])
    cbar_ax = f.add_axes([.08 * n_factors + .02, 0, .04, 1])

    max_val = abs(loadings).max().max()
    sns.heatmap(
        loadings,
        ax=ax,
        cbar_ax=cbar_ax,
        vmax=max_val,
        vmin=-max_val,
        cbar_kws={'ticks': [-max_val, -max_val / 2, 0, max_val / 2, max_val]},
        linecolor='white',
        linewidth=.01,
        cmap=sns.diverging_palette(220, 15, n=100, as_cmap=True))
    ax.set_yticks(np.arange(.5, n_vars + .5, 1))
    ax.set_yticklabels(loadings.index, fontsize=DV_fontsize, rotation=0)
    ax.set_xticklabels(loadings.columns,
                       fontsize=min(size * 3, DV_fontsize * 1.5),
                       ha='center',
                       rotation=90)
    ax.tick_params(length=size * .5, width=size / 10)
    # format cbar; labels must match the ticks set in cbar_kws above
    cbar_ax.set_yticklabels([
        format_num(-max_val, 2),
        format_num(-max_val / 2, 2), 0,
        format_num(max_val / 2, 2),
        format_num(max_val, 2)
    ])
    cbar_ax.tick_params(axis='y', length=0)
    cbar_ax.tick_params(labelsize=DV_fontsize * 1.5)
    cbar_ax.set_ylabel('Factor Loading',
                       rotation=-90,
                       fontsize=DV_fontsize * 2)

    # draw lines separating groups
    if grouping is not None:
        factor_breaks = np.cumsum([len(i[1]) for i in grouping])[:-1]
        for y_val in factor_breaks:
            ax.hlines(y_val,
                      0,
                      n_factors,
                      lw=size / 5,
                      color='grey',
                      linestyle='dashed')

    if plot_dir:
        filename = 'factor_heatmap_EFA%s.%s' % (c, ext)
        save_figure(f, path.join(plot_dir, filename), {
            'bbox_inches': 'tight',
            'dpi': dpi
        })
        plt.close()
# Example #15
# 0
def plot_subbranches(results,
                     rotate='oblimin',
                     EFA_clustering=True,
                     cluster_range=None,
                     absolute_loading=False,
                     size=2.3,
                     dpi=300,
                     ext='png',
                     plot_dir=None):
    """ Plots HCA results as dendrogram subbranches with loadings underneath

    One figure is produced per cluster (via plot_subbranch).

    Args:
        results: results object
        rotate: the rotation of the EFA solution used for the loadings and,
            if EFA_clustering is True, to select the clustering solution
        EFA_clustering: if True, use the HCA solution computed on the EFA
            scores ('EFA<c>_<rotate>'); otherwise the one computed on 'data'
        cluster_range: iterable of cluster indices to plot. Defaults to all
            clusters
        absolute_loading: if True, plot absolute loadings
        size: scalar - passed to plot_subbranch to size the figures
        dpi: currently unused (not passed to plot_subbranch)
        ext: the extension for the saved figures
        plot_dir: if set, figures are saved into a
            'subbranches_input-<input>' subdirectory of this directory

    Returns:
        list of the (truthy) figures returned by plot_subbranch
    """
    HCA = results.HCA
    EFA = results.EFA
    loading = EFA.reorder_factors(EFA.get_loading(rotate=rotate),
                                  rotate=rotate)
    loading.index = format_variable_names(loading.index)
    inp = 'EFA%s_%s' % (EFA.get_c(), rotate) if EFA_clustering else 'data'
    clustering = HCA.results[inp]

    # extract cluster vars
    link = clustering['linkage']
    labels = format_variable_names(clustering['clustered_df'].columns)
    ordered_loading = loading.loc[labels]
    if absolute_loading:
        ordered_loading = abs(ordered_loading)
    # get cluster labels and sizes (an empty result yields no plots)
    cluster_DVs = HCA.get_cluster_DVs(inp=inp)
    cluster_labels = list(cluster_DVs.keys())
    cluster_sizes = [len(DVs) for DVs in cluster_DVs.values()]
    link_function, colors = get_dendrogram_color_fun(link,
                                                     clustering['reorder_vec'],
                                                     clustering['labels'])
    tree = dendrogram(link,
                      link_color_func=link_function,
                      no_plot=True,
                      no_labels=True)

    if plot_dir:
        function_directory = 'subbranches_input-%s' % inp
        makedirs(path.join(plot_dir, function_directory), exist_ok=True)

    if cluster_range is None:
        cluster_range = range(len(cluster_labels))
    figs = []
    for cluster_i in cluster_range:
        plot_loc = None
        if plot_dir:
            filename = 'cluster_%s.%s' % (str(cluster_i).zfill(2), ext)
            plot_loc = path.join(plot_dir, function_directory, filename)
        fig = plot_subbranch(colors[cluster_i],
                             cluster_i,
                             tree,
                             ordered_loading,
                             cluster_sizes,
                             title=cluster_labels[cluster_i],
                             size=size,
                             plot_loc=plot_loc)
        if fig:
            figs.append(fig)
    return figs
def plot_cross_communality(all_results,
                           rotate='oblimin',
                           retest_threshold=.2,
                           size=4.6,
                           dpi=300,
                           ext='png',
                           plot_dir=None):
    """ Plots raw and retest-adjusted communality for several results objects

    Each results object gets one panel (two panels per row) with vertical
    KDEs of its EFA communalities and of the communalities divided by each
    variable's test-retest reliability ("noise ceiling"). Dashed lines mark
    the means. All panels share the same y-limits.

    Args:
        all_results: dict mapping a panel title to a dimensional structure
            results object
        rotate: the rotation of the EFA solution to use
        retest_threshold: retest reliabilities below this value are set to
            NaN before adjusting. A falsy value disables this
        size: scalar - the width of the figure (height scales with the
            number of rows)
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure (into its `rotate`
            subdirectory). If none, do not save
    """
    if not all_results:
        return
    num_cols = 2
    num_rows = math.ceil(len(all_results) / num_cols)
    with sns.axes_style('white'):
        f, axes = plt.subplots(num_rows,
                               num_cols,
                               figsize=(size, size / 2 * num_rows))
    # plt.subplots returns a 2D array of axes when num_rows > 1; flatten it
    # so that axes[i] is always a single Axes
    axes = np.asarray(axes).ravel()
    retest_data = None
    max_y = 0
    for i, (name, results) in enumerate(all_results.items()):
        if retest_data is None:
            # retest data is loaded once, based on the first results' dataset
            retest_data = get_retest_data(
                dataset=results.dataset.replace('Complete', 'Retest'))
            if retest_data is None:
                print('No retest data found for datafile: %s' %
                      results.dataset)
                plt.close(f)
                return
        EFA = results.EFA
        c = EFA.get_c()
        loading = EFA.get_loading(c, rotate=rotate)
        # get communality from psych out
        fa = EFA.results['factor_tree_Rout_%s' % rotate][c]
        communality = get_attr(fa, 'communalities')
        communality = pd.Series(communality, index=loading.index)
        # alternative calculation
        #communality = (loading**2).sum(1).sort_values()
        # strip '.logTr' so the names can be looked up in retest_data
        communality.index = [
            var.replace('.logTr', '') for var in communality.index
        ]

        # reorder data in line with communality
        retest_subset = retest_data.loc[communality.index]
        # reformat variable names
        communality.index = format_variable_names(communality.index)
        retest_subset.index = format_variable_names(retest_subset.index)
        if len(retest_subset) == 0:
            continue
        # noise ceiling with very low reliabilities removed; mask() returns a
        # new Series instead of writing into a column of retest_subset
        noise_ceiling = retest_subset.pearson
        if retest_threshold:
            noise_ceiling = noise_ceiling.mask(
                noise_ceiling < retest_threshold)
        # adjust
        adjusted_communality = communality / noise_ceiling

        # plot communality histogram
        ax = axes[i]
        ax.set_title(name.title(), fontweight='bold', fontsize=size * 2)
        colors = sns.color_palette(n_colors=2, desat=.75)
        sns.kdeplot(communality,
                    linewidth=size / 4,
                    ax=ax,
                    vertical=True,
                    shade=True,
                    label='Communality',
                    color=colors[0])
        sns.kdeplot(adjusted_communality,
                    linewidth=size / 4,
                    ax=ax,
                    vertical=True,
                    shade=True,
                    label='Adjusted Communality',
                    color=colors[1])
        xlim = ax.get_xlim()
        ax.hlines(np.mean(communality),
                  xlim[0],
                  xlim[1],
                  color=colors[0],
                  linewidth=size / 4,
                  linestyle='--')
        ax.hlines(np.mean(adjusted_communality),
                  xlim[0],
                  xlim[1],
                  color=colors[1],
                  linewidth=size / 4,
                  linestyle='--')
        ax.set_xticks([])
        ax.tick_params(labelsize=size * 1.2)
        ax.set_ylim(0, ax.get_ylim()[1])
        ax.set_xlim(0, ax.get_xlim()[1])
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        if (i + 1) == len(all_results):
            # legend only on the last panel
            ax.set_xlabel('Normalized Density', fontsize=size * 2)
            leg = ax.legend(fontsize=size * 1.5,
                            loc='upper right',
                            bbox_to_anchor=(1.2, 1.0),
                            handlelength=0,
                            handletextpad=0)
            beautify_legend(leg, colors)
        elif i >= len(all_results) - 2:
            # bottom row gets an x label
            ax.set_xlabel('Normalized Density', fontsize=size * 2)
            ax.legend().set_visible(False)
        else:
            ax.legend().set_visible(False)
        if i % 2 == 0:
            ax.set_ylabel('Communality', fontsize=size * 2)
            ax.tick_params(labelleft=True,
                           left=True,
                           length=size / 4,
                           width=size / 8)
        else:
            ax.tick_params(labelleft=False,
                           left=True,
                           length=0,
                           width=size / 8)
        # track the largest y limit so all panels can share it
        max_y = max(max_y, ax.get_ylim()[1])
        ax.grid(False)
        for spine in ax.spines.values():
            spine.set_linewidth(size * .1)

    # finalize once, after every panel has been drawn
    for ax in axes:
        ax.set_ylim((0, max_y))
    plt.subplots_adjust(wspace=0)

    if plot_dir:
        filename = 'communality_adjustment.%s' % ext
        save_figure(f, path.join(plot_dir, rotate, filename), {
            'bbox_inches': 'tight',
            'dpi': dpi
        })
        plt.close()
def display_cluster_DVs(consensus, results):
    """ Displays, for each DV, the other members of its cluster by closeness

    Clusters come from results.HCA.get_cluster_DVs with input
    'EFA<c>_oblimin', where c = results.EFA.get_c(). Each row of the returned
    table corresponds to a DV of the consensus distance matrix. Its cells
    hold the other DVs of that DV's cluster, ordered from most to least
    similar and formatted as "<name>: <similarity>%", where similarity is
    1 - consensus distance. Rows are padded with empty strings to at least
    20 columns, labelled "first", "second", ...

    Args:
        consensus: object whose get_consensus_cluster()['distance_df'] is a
            square distance DataFrame indexed by (formatted) DV names
        results: a dimensional structure results object

    Returns:
        a pandas Styler coloring each cell by its similarity
    """
    nth = {
        1: "first",
        2: "second",
        3: "third",
        4: "fourth",
        5: "fifth",
        6: "sixth",
        7: "seventh",
        8: "eighth",
        9: "ninth",
        10: "tenth",
    }

    def ordinal(n):
        """Column label for the n-th closest neighbor (1-based)."""
        if n in nth:
            return nth[n]
        if 10 <= n % 100 <= 20:
            suffix = 'th'
        else:
            suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
        return '%s%s' % (n, suffix)

    c = results.EFA.get_c()
    cluster_DVs = results.HCA.get_cluster_DVs(inp='EFA%s_oblimin' % c)
    df = consensus.get_consensus_cluster()['distance_df']

    # map each formatted DV name to the formatted members of the first
    # cluster containing it (formats every cluster once, not once per row)
    cluster_lookup = {}
    for members in cluster_DVs.values():
        formatted = list(format_variable_names(members))
        for member in formatted:
            cluster_lookup.setdefault(member, formatted)

    rows = []
    for name in df.index:
        neighbors = cluster_lookup[name]
        closest = 1 - df.loc[name, neighbors].drop(name).sort_values()
        rows.append(
            ['%s: %s%%' % (i, int(b * 100)) for i, b in closest.items()])

    # at least 20 columns as before; wider if some cluster is larger
    n_cols = max([20] + [len(row) for row in rows])
    rows = [row + [''] * (n_cols - len(row)) for row in rows]
    sorted_df = pd.DataFrame(rows,
                             index=df.index,
                             columns=[ordinal(i + 1) for i in range(n_cols)])

    def magnify():
        return [
            dict(selector="tr:hover",
                 props=[("border-top", "2pt solid black"),
                        ("border-bottom", "2pt solid black")]),
            dict(selector="th:hover", props=[("font-size", "10pt")]),
            dict(selector="td", props=[('padding', "0em 0em")]),
            dict(selector="tr:hover td:hover",
                 props=[('max-width', '200px'), ('font-weight', 'bold'),
                        ('color', 'black'), ('font-size', '9pt')])
        ]

    cm = sns.diverging_palette(220, 15, n=161)

    def color_cell(val):
        # padding cells have no similarity to color
        if val == '':
            return 'background-color: None'
        num = val[val.rindex(': ') + 2:val.rindex('%')]
        color = to_hex(cm[int(num) + 30])
        return 'background-color: %s' % color

    # every cell is a string, so no numeric precision setting is needed
    # (Styler.set_precision was also removed in pandas 2.0)
    styler = sorted_df.style
    styler \
        .applymap(color_cell) \
        .set_properties(**{'max-width': '100px', 'font-size': '10pt', 'border-color': 'white'})\
        .set_table_styles(magnify())
    return styler
def plot_cross_communality(all_results,
                           rotate='oblimin',
                           retest_threshold=.2,
                           size=4.6,
                           dpi=300,
                           ext='png',
                           plot_dir=None):
    """ Plots raw and retest-adjusted communality for several results objects

    Each results object gets one panel (two panels per row) with vertical
    KDEs of its EFA communalities and of the communalities divided by each
    variable's test-retest reliability ("noise ceiling"). Dashed lines mark
    the means. All panels share the same y-limits.

    Args:
        all_results: dict mapping a panel title to a dimensional structure
            results object
        rotate: the rotation of the EFA solution to use
        retest_threshold: retest reliabilities below this value are set to
            NaN before adjusting. A falsy value disables this
        size: scalar - the width of the figure (height scales with the
            number of rows)
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure (into its `rotate`
            subdirectory). If none, do not save
    """
    if not all_results:
        return
    num_cols = 2
    num_rows = math.ceil(len(all_results) / num_cols)
    with sns.axes_style('white'):
        f, axes = plt.subplots(num_rows,
                               num_cols,
                               figsize=(size, size / 2 * num_rows))
    # plt.subplots returns a 2D array of axes when num_rows > 1; flatten it
    # so that axes[i] is always a single Axes
    axes = np.asarray(axes).ravel()
    retest_data = None
    max_y = 0
    for i, (name, results) in enumerate(all_results.items()):
        if retest_data is None:
            # retest data is loaded once, based on the first results' dataset
            retest_data = get_retest_data(
                dataset=results.dataset.replace('Complete', 'Retest'))
            if retest_data is None:
                print('No retest data found for datafile: %s' %
                      results.dataset)
                plt.close(f)
                return
        EFA = results.EFA
        c = EFA.get_c()
        loading = EFA.get_loading(c, rotate=rotate)
        # get communality from psych out
        fa = EFA.results['factor_tree_Rout_%s' % rotate][c]
        communality = get_attr(fa, 'communalities')
        communality = pd.Series(communality, index=loading.index)
        # alternative calculation
        #communality = (loading**2).sum(1).sort_values()
        # strip '.logTr' so the names can be looked up in retest_data
        communality.index = [
            var.replace('.logTr', '') for var in communality.index
        ]

        # reorder data in line with communality
        retest_subset = retest_data.loc[communality.index]
        # reformat variable names
        communality.index = format_variable_names(communality.index)
        retest_subset.index = format_variable_names(retest_subset.index)
        if len(retest_subset) == 0:
            continue
        # noise ceiling with very low reliabilities removed; mask() returns a
        # new Series instead of writing into a column of retest_subset
        noise_ceiling = retest_subset.pearson
        if retest_threshold:
            noise_ceiling = noise_ceiling.mask(
                noise_ceiling < retest_threshold)
        # adjust
        adjusted_communality = communality / noise_ceiling

        # plot communality histogram
        ax = axes[i]
        ax.set_title(name.title(), fontweight='bold', fontsize=size * 2)
        colors = sns.color_palette(n_colors=2, desat=.75)
        sns.kdeplot(communality,
                    linewidth=size / 4,
                    ax=ax,
                    vertical=True,
                    shade=True,
                    label='Communality',
                    color=colors[0])
        sns.kdeplot(adjusted_communality,
                    linewidth=size / 4,
                    ax=ax,
                    vertical=True,
                    shade=True,
                    label='Adjusted Communality',
                    color=colors[1])
        xlim = ax.get_xlim()
        ax.hlines(np.mean(communality),
                  xlim[0],
                  xlim[1],
                  color=colors[0],
                  linewidth=size / 4,
                  linestyle='--')
        ax.hlines(np.mean(adjusted_communality),
                  xlim[0],
                  xlim[1],
                  color=colors[1],
                  linewidth=size / 4,
                  linestyle='--')
        ax.set_xticks([])
        ax.tick_params(labelsize=size * 1.2)
        ax.set_ylim(0, ax.get_ylim()[1])
        ax.set_xlim(0, ax.get_xlim()[1])
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        if (i + 1) == len(all_results):
            # legend only on the last panel
            ax.set_xlabel('Normalized Density', fontsize=size * 2)
            leg = ax.legend(fontsize=size * 1.5,
                            loc='upper right',
                            bbox_to_anchor=(1.2, 1.0),
                            handlelength=0,
                            handletextpad=0)
            beautify_legend(leg, colors)
        elif i >= len(all_results) - 2:
            # bottom row gets an x label
            ax.set_xlabel('Normalized Density', fontsize=size * 2)
            ax.legend().set_visible(False)
        else:
            ax.legend().set_visible(False)
        if i % 2 == 0:
            ax.set_ylabel('Communality', fontsize=size * 2)
            ax.tick_params(labelleft=True,
                           left=True,
                           length=size / 4,
                           width=size / 8)
        else:
            ax.tick_params(labelleft=False,
                           left=True,
                           length=0,
                           width=size / 8)
        # track the largest y limit so all panels can share it
        max_y = max(max_y, ax.get_ylim()[1])
        ax.grid(False)
        for spine in ax.spines.values():
            spine.set_linewidth(size * .1)

    # finalize once, after every panel has been drawn
    for ax in axes:
        ax.set_ylim((0, max_y))
    plt.subplots_adjust(wspace=0)

    if plot_dir:
        filename = 'communality_adjustment.%s' % ext
        save_figure(f, path.join(plot_dir, rotate, filename), {
            'bbox_inches': 'tight',
            'dpi': dpi
        })
        plt.close()
def plot_communality(results,
                     c,
                     rotate='oblimin',
                     retest_threshold=.2,
                     size=4.6,
                     dpi=300,
                     ext='png',
                     plot_dir=None):
    """ Plots EFA communalities alongside test-retest reliability

    Draws bar plots of the communality, the retest reliability ("noise
    ceiling") and the reliability-adjusted communality, then a second
    figure overlaying KDEs of raw and adjusted communality, annotated with
    their correlation (all from get_adjusted_communality). If the retest
    data has no rows for these variables, only the communality bars are
    drawn.

    Args:
        results: a dimensional structure results object
        c: the number of components to use
        rotate: the rotation of the EFA solution
        retest_threshold: passed on to get_adjusted_communality
        size: scalar - controls figure dimensions and font sizes
        dpi: the final dpi for the image
        ext: the extension for the saved figures
        plot_dir: the directory to save the figures. If none, do not save

    Returns None without plotting when no retest data file is found.
    """
    communality = get_communality(results.EFA, rotate, c)
    retest_data = get_retest_data(
        dataset=results.dataset.replace('Complete', 'Retest'))
    if retest_data is None:
        print('No retest data found for datafile: %s' % results.dataset)
        return

    # align retest rows with the communality ordering, then prettify names
    retest_data = retest_data.loc[communality.index]
    communality.index = format_variable_names(communality.index)
    retest_data.index = format_variable_names(retest_data.index)
    has_retest = len(retest_data) > 0
    save_kws = {'bbox_inches': 'tight', 'dpi': dpi}

    # --- bar plots ---
    if has_retest:
        adjusted_communality, correlation, noise_ceiling = \
            get_adjusted_communality(communality,
                                     retest_data,
                                     retest_threshold)
        bar_width = size / 10
        f, axes = plt.subplots(1, 3, figsize=(3 * bar_width, size))
        panels = [(communality, True, 'Communality'),
                  (noise_ceiling, False, 'Test-Retest'),
                  (adjusted_communality, False, 'Adjusted Communality')]
        for panel_ax, (values, label_rows, title) in zip(axes, panels):
            plot_bar_factor(values,
                            panel_ax,
                            width=bar_width,
                            height=size,
                            label_rows=label_rows,
                            title=title)
    else:
        f = plot_bar_factor(communality,
                            label_rows=True,
                            width=size / 3,
                            height=size * 2,
                            title='Communality')
    if plot_dir:
        bars_name = 'communality_bars-EFA%s.%s' % (c, ext)
        save_figure(f, path.join(plot_dir, bars_name), save_kws)
        plt.close()

    if not has_retest:
        return

    # --- distribution plot ---
    with sns.axes_style('white'):
        colors = sns.color_palette(n_colors=2, desat=.75)
        f, ax = plt.subplots(1, 1, figsize=(size, size))
        curves = [(communality, 'Communality', colors[0]),
                  (adjusted_communality, 'Adjusted Communality', colors[1])]
        for values, label, color in curves:
            sns.kdeplot(values,
                        linewidth=size / 4,
                        shade=True,
                        label=label,
                        color=color)
        y_low, y_high = ax.get_ylim()
        # dashed vertical line at each distribution's mean
        for values, _, color in curves:
            ax.vlines(np.mean(values),
                      y_low,
                      y_high,
                      color=color,
                      linewidth=size / 4,
                      linestyle='--')
        leg = ax.legend(fontsize=size * 2, loc='upper right')
        beautify_legend(leg, colors)
        ax.set_xlabel('Communality', fontsize=size * 2)
        ax.set_ylabel('Normalized Density', fontsize=size * 2)
        ax.set_yticks([])
        ax.tick_params(labelsize=size)
        ax.set_ylim(0, ax.get_ylim()[1])
        ax.set_xlim(0, ax.get_xlim()[1])
        for side in ('right', 'top'):
            ax.spines[side].set_visible(False)
        # annotate with the correlation reported by get_adjusted_communality
        correlation_text = format_num(np.mean(correlation))
        ax.text(1.1,
                1.25,
                'Correlation Between Communality \nand Test-Retest: %s' %
                correlation_text,
                size=size * 2)

    if plot_dir:
        dist_name = 'communality_dist-EFA%s.%s' % (c, ext)
        save_figure(f, path.join(plot_dir, dist_name), save_kws)
        plt.close()