def plot_EFA_relationships(all_results):
    """Quantify and plot how EFA factor scores relate across result sets.

    For every pair of result sets, a linear regression predicting one set's
    factor scores from the other's is cross-validated (10 folds) and the mean
    score (R^2, LinearRegression's default score) is printed. Then each result
    set's factor scores are regressed onto the first two principal components
    of the 'task' factor scores. Each factor's coefficient pair is drawn as a
    line from the origin.

    Args:
        all_results: dict mapping a name to a results object exposing an
            ``EFA`` attribute with a ``get_scores()`` method. Must contain a
            'task' entry.
    """
    from itertools import cycle

    EFA_all_results = {k: v.EFA for k, v in all_results.items()}
    scores = {k: v.get_scores() for k, v in EFA_all_results.items()}
    # quantify relationships using linear regression
    for name1, name2 in combinations(scores.keys(), 2):
        lr = LinearRegression()
        cv_score = np.mean(cross_val_score(lr, scores[name1], scores[name2],
                                           cv=10))
        print(name1, name2, cv_score)
    # plot scores in task PCA space
    task_pca = PCA(2).fit_transform(scores['task'])
    # Palettes are used last-to-first (Greens, Blues, Reds) as before, but are
    # cycled so that more than three result sets no longer raise IndexError.
    palettes = cycle(reversed(['Reds', 'Blues', 'Greens']))
    all_colors = []
    _, ax = plt.subplots(figsize=[12, 8])
    ax.set_facecolor('white')
    for k, v in scores.items():
        palette = sns.color_palette(next(palettes), n_colors=len(v.columns))
        all_colors += palette
        lr = LinearRegression()
        lr.fit(task_pca, v)
        # one coefficient row per column (factor) of v; the two entries are
        # the weights on the two task PCA components
        for i, coef in enumerate(lr.coef_):
            ax.plot([0, coef[0]], [0, coef[1]], linewidth=3,
                    c=palette[i], label=k + '_' + str(v.columns[i]))
    leg = ax.legend(bbox_to_anchor=(.8, .5))
    leg.get_frame().set_color('black')
    beautify_legend(leg, all_colors)
def plot_EFA_relationships(all_results):
    """Quantify and plot how EFA factor scores relate across result sets.

    For every pair of result sets, a linear regression predicting one set's
    factor scores from the other's is cross-validated (10 folds) and the mean
    score (R^2, LinearRegression's default score) is printed. Then each result
    set's factor scores are regressed onto the first two principal components
    of the 'task' factor scores. Each factor's coefficient pair is drawn as a
    line from the origin.

    Args:
        all_results: dict mapping a name to a results object exposing an
            ``EFA`` attribute with a ``get_scores()`` method. Must contain a
            'task' entry.
    """
    from itertools import cycle

    EFA_all_results = {k: v.EFA for k, v in all_results.items()}
    scores = {k: v.get_scores() for k, v in EFA_all_results.items()}
    # quantify relationships using linear regression
    for name1, name2 in combinations(scores.keys(), 2):
        lr = LinearRegression()
        cv_score = np.mean(cross_val_score(lr, scores[name1], scores[name2],
                                           cv=10))
        print(name1, name2, cv_score)
    # plot scores in task PCA space
    task_pca = PCA(2).fit_transform(scores['task'])
    # Palettes are used last-to-first (Greens, Blues, Reds) as before, but are
    # cycled so that more than three result sets no longer raise IndexError.
    palettes = cycle(reversed(['Reds', 'Blues', 'Greens']))
    all_colors = []
    _, ax = plt.subplots(figsize=[12, 8])
    ax.set_facecolor('white')
    for k, v in scores.items():
        palette = sns.color_palette(next(palettes), n_colors=len(v.columns))
        all_colors += palette
        lr = LinearRegression()
        lr.fit(task_pca, v)
        # one coefficient row per column (factor) of v; the two entries are
        # the weights on the two task PCA components
        for i, coef in enumerate(lr.coef_):
            ax.plot([0, coef[0]], [0, coef[1]],
                    linewidth=3,
                    c=palette[i],
                    label=k + '_' + str(v.columns[i]))
    leg = ax.legend(bbox_to_anchor=(.8, .5))
    leg.get_frame().set_color('black')
    beautify_legend(leg, all_colors)
def visualize_task_factors(task_loadings, ax, xticklabels=True, label_size=12,
                           yticklabels=False, pad=0, ymax=None, legend=True):
    """Plot task loadings on one axis.

    Each row of ``task_loadings`` is drawn as a line of absolute loadings via
    ``plot_loadings`` (kind='line'). The columns are spread around the axis,
    with x ticks spanning 0..2*pi.

    Args:
        task_loadings: DataFrame; rows are plotted as separate lines, columns
            (factors) define the tick positions and labels.
        ax: axis to draw on (presumably polar, given the 0..2*pi ticks).
        xticklabels: if True, write column names as curved text on an
            overlay axis.
        label_size: font size of the curved column labels.
        yticklabels: if True, label the y ticks, blanking every other one.
        pad: constant added to the absolute loadings before plotting.
        ymax: optional upper y limit.
        legend: if True, add a legend of row names below the axis.
    """
    n_measures = len(task_loadings)
    # number of factors; previously read from the loop variable DV, which is
    # undefined when task_loadings has no rows
    n_factors = task_loadings.shape[1]
    colors = sns.hls_palette(n_measures, l=.4, s=.8)
    for i, (name, DV) in enumerate(task_loadings.iterrows()):
        plot_loadings(ax, abs(DV) + pad, width_scale=1 / n_measures,
                      colors=[colors[i]], offset=i + .5,
                      kind='line',
                      plot_kws={'label': name, 'alpha': .8})
    # set up yticks
    if ymax:
        ax.set_ylim(top=ymax)
    ytick_locs = ax.yaxis.get_ticklocs()
    new_yticks = np.linspace(0, ytick_locs[-1], 7)
    ax.set_yticks(new_yticks)
    if yticklabels:
        labels = np.round(new_yticks, 2)
        # blank out every other tick label to reduce clutter
        replace_dict = {i: '' for i in labels[::2]}
        labels = [replace_dict.get(i, i) for i in labels]
        ax.set_yticklabels(labels)
    # set up x ticks: one major tick per factor, minor ticks halfway between
    xtick_locs = np.arange(0.0, 2 * np.pi, 2 * np.pi / n_factors)
    ax.set_xticks(xtick_locs)
    ax.set_xticks(xtick_locs + np.pi / n_factors, minor=True)
    if xticklabels:
        labels = task_loadings.columns
        if not isinstance(labels[0], str):
            labels = ['Fac %s' % str(i) for i in labels]
        scale = 1.2
        size = ax.get_position().expanded(scale, scale)
        ax2 = ax.get_figure().add_axes(size, zorder=2)
        max_var_length = max(len(v) for v in labels)
        n_labels = len(labels)
        # angular shift applied to the start/end of every label's arc
        offset = .3 * 25 / n_labels ** 2
        for i, var in enumerate(labels):
            start = (i - offset) * 2 * np.pi / n_labels
            end = (i + (1 - offset)) * 2 * np.pi / n_labels
            angles = np.linspace(start, end, 100)
            curve = [np.cos(angles), np.sin(angles)]
            # invisible line on ax2 (explicitly, rather than whatever axes is
            # current); presumably there so ax2's limits cover the arcs
            ax2.plot(*curve, alpha=0)
            # pad strings to longest length
            num_spaces = max_var_length - len(var)
            var = (' ' * (num_spaces // 2) + var
                   + ' ' * (num_spaces - num_spaces // 2))
            CurvedText(
                x=curve[0][::-1],
                y=curve[1][::-1],
                text=var,
                va='top',
                axes=ax2,
                fontsize=label_size  # calls ax.add_artist in __init__
            )
        ax2.axis('off')
    if legend:
        leg = ax.legend(loc='upper center', bbox_to_anchor=(.5, -.15),
                        frameon=False)
        beautify_legend(leg, colors[:n_measures])
def importance_bar_plots(predictions,
                         target_order=None,
                         show_sign=True,
                         colorbar=True,
                         size=5,
                         dpi=300,
                         filename=None):
    """Grouped bar plot of predictor importances for each prediction target.

    Args:
        predictions: dict mapping target name to a result dict holding
            'predvars' (predictor names) and 'importances' (whose first
            element has one value per predictor).
        target_order: optional iterable of targets to plot, in order.
            Defaults to the keys of ``predictions``.
        show_sign: currently unused; kept for interface compatibility.
        colorbar: currently unused; kept for interface compatibility.
        size: base figure size; fonts and line widths scale with it.
        dpi: resolution used when saving.
        filename: if given, save the figure there and close it.

    Returns:
        The figure when ``filename`` is None, otherwise None.
    """
    if target_order is None:
        target_order = predictions.keys()
    target_order = list(target_order)
    vals = [predictions[i] for i in target_order]
    n_predictors = len(vals[0]['importances'][0])
    # set up color styling: one colour per predictor (bar series)
    palette = sns.color_palette('Blues_d', n_predictors)
    importances = [(i['predvars'], i['importances'][0]) for i in vals]
    prediction_df = pd.DataFrame([i[1] for i in importances],
                                 columns=importances[0][0],
                                 index=target_order)
    # order predictors by their importance for the first target
    prediction_df.sort_values(axis=1,
                              by=prediction_df.index[0],
                              inplace=True,
                              ascending=False)

    # plot
    sns.set_style('white')
    ax = prediction_df.plot(kind='bar',
                            edgecolor=None,
                            linewidth=0,
                            figsize=(size, size * .67),
                            color=palette)
    fig = ax.get_figure()
    ax.tick_params(labelsize=size)
    ax.set_ylabel(r'Standardized $\beta$', fontsize=size * 1.5)
    # set up legend and other aesthetic
    ax.grid(axis='y', linewidth=size / 10)
    leg = ax.legend(frameon=False,
                    fontsize=size * 1.5,
                    bbox_to_anchor=(1.25, .8),
                    handlelength=0,
                    handletextpad=0,
                    framealpha=1)
    beautify_legend(leg, colors=palette)
    for spine in ax.spines.values():
        spine.set_visible(False)
    if filename is not None:
        save_figure(fig, filename, {'bbox_inches': 'tight', 'dpi': dpi})
        # close this figure specifically, not whichever one is current
        plt.close(fig)
    else:
        return fig
def plot_BIC(all_results, size=4.6, dpi=300, ext='png', plot_dir=None):
    """ Plots BIC and SABIC curves

    One subplot per entry of ``all_results`` ('task' first, then 'survey',
    then any others). Each subplot draws every 'cscores_metric-*' curve in
    that entry's EFA results against the number of factors. It also circles
    the number of factors stored under the matching 'c_metric-*' key.

    Args:
        all_results: a dimensional structure all_results object (dict of
            results objects, each with ``ID`` and ``EFA`` attributes)
        size: figure width; height, fonts and line widths scale with it
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    # one colour family per subplot, one colour per metric within a family
    all_colors = [sns.color_palette('Blues_d', 3)[0:3],
                  sns.color_palette('Reds_d', 3)[0:3],
                  sns.color_palette('Greens_d', 3)[0:3],
                  sns.color_palette('Oranges_d', 3)[0:3]]
    # 'task' then 'survey' as before, followed by any remaining entries so
    # every created subplot is drawn (extras used to be left blank, and a
    # missing 'task'/'survey' raised KeyError)
    ordered_keys = [k for k in ['task', 'survey'] if k in all_results]
    ordered_keys += [k for k in all_results if k not in ordered_keys]
    height = size * .75 / len(all_results)
    with sns.axes_style('white'):
        # squeeze=False keeps `axes` indexable even with a single subplot
        fig, axes = plt.subplots(1, len(all_results), figsize=(size, height),
                                 squeeze=False)
    axes = axes[0]
    for i, key in enumerate(ordered_keys):
        results = all_results[key]
        ax1 = axes[i]
        name = results.ID.split('_')[0].title()
        EFA = results.EFA
        colors = all_colors[i % len(all_colors)]
        with sns.axes_style('white'):
            x = sorted(EFA.results['cscores_metric-BIC'].keys())
            # score keys
            score_keys = [k for k in EFA.results.keys() if 'cscores' in k]
            for j, score_key in enumerate(score_keys):
                metric = score_key.split('-')[-1]
                # a distinct colour per metric, matching the legend text
                # colours set by beautify_legend below
                color = colors[j % len(colors)]
                metric_scores = EFA.results[score_key]
                best_c = EFA.results['c_metric-%s' % metric]
                ax1.plot(x, [metric_scores[n] for n in x], 'o-', c=color,
                         lw=size / 6, label=metric, markersize=height * 2)
                # circle the selected c, looked up by key rather than by
                # list position
                ax1.plot(best_c, metric_scores[best_c], '.', color='white',
                         markeredgecolor=color, markeredgewidth=height / 2,
                         markersize=height * 4)
            if i == 0:
                if len(score_keys) > 1:
                    ax1.set_ylabel('Score', fontsize=height * 3)
                    leg = ax1.legend(loc='center right',
                                     fontsize=height * 3, markerscale=0)
                    beautify_legend(leg, colors=colors)
                else:
                    ax1.set_ylabel(metric, fontsize=height * 4)
            ax1.set_xlabel('# Factors', fontsize=height * 4)
            ax1.set_xticks(x)
            ax1.set_xticklabels(x)
            ax1.tick_params(labelsize=height * 2, pad=size / 4, length=0)
            ax1.set_title(name, fontsize=height * 4, y=1.01)
            ax1.grid(linewidth=size / 8)
            for spine in ax1.spines.values():
                spine.set_linewidth(size * .1)
    if plot_dir is not None:
        save_figure(fig, path.join(plot_dir, 'BIC_curves.%s' % ext),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close(fig)
def plot_prediction_comparison(results,
                               size=4.6,
                               change=False,
                               dpi=300,
                               ext='png',
                               plot_dir=None):
    """Bar plot comparing cross-validated R2 across classifiers and features.

    For EFA=False (labelled 'IDM') and EFA=True (labelled 'EFA'), the most
    recently modified prediction file of each classifier is loaded. Its
    per-target cross-validated R2 values (NaN -> 0) are plotted as bars
    grouped by classifier and coloured by feature set.

    Args:
        results: results object with ``ID`` and ``get_prediction_files``.
        size: base figure size; fonts scale with it.
        change: forwarded to ``get_prediction_files``.
        dpi: resolution used when saving.
        ext: file extension used when saving.
        plot_dir: directory to save 'prediction_comparison.<ext>' into; if
            None the figure is left open.
    """
    colors = ref_colors[results.ID.split('_')[0]]
    R2s = {}
    for EFA in [False, True]:
        predictions = results.get_prediction_files(EFA=EFA,
                                                   change=change,
                                                   shuffle=False)
        predictions = sorted(predictions, key=path.getmtime)
        # the classifier is the second-to-last '_'-separated token of a path
        classifiers = np.unique([i.split('_')[-2] for i in predictions])
        feature = 'EFA' if EFA else 'IDM'
        # get last prediction file of each type
        for classifier in classifiers:
            # exact token match; a substring test could also pick up files of
            # other classifiers or matching directory names
            filey = [i for i in predictions
                     if i.split('_')[-2] == classifier][-1]
            # NOTE: pickle can execute code on load - only use on trusted,
            # locally generated prediction files
            with open(filey, 'rb') as fh:
                prediction_object = pickle.load(fh)['data']
            R2 = [i['scores_cv'][0]['R2'] for i in prediction_object.values()]
            R2s[feature + '_' + classifier] = np.nan_to_num(R2)
    if len(R2s) == 0:
        print('No prediction objects found')
        return
    R2s = pd.DataFrame(R2s).melt(var_name='Classifier', value_name='R2')
    # split e.g. 'EFA_lasso' into Feature='EFA', Classifier='lasso'; the old
    # tuple-unpacking of the `.str` accessor fails on recent pandas
    split = R2s['Classifier'].str.split('_', n=1, expand=True)
    R2s['Feature'] = split[0]
    R2s['Classifier'] = split[1]
    f = plt.figure(figsize=(size, size * .62))
    sns.barplot(x='Classifier',
                y='R2',
                data=R2s,
                hue='Feature',
                palette=colors[:2],
                errwidth=size / 5)
    ax = plt.gca()
    ax.tick_params(axis='y', labelsize=size * 1.8)
    ax.tick_params(axis='x', labelsize=size * 1.8)
    leg = ax.legend(fontsize=size * 2, loc='upper right')
    beautify_legend(leg, colors[:2])
    plt.xlabel('Classifier', fontsize=size * 2.2, labelpad=size / 2)
    plt.ylabel('R2', fontsize=size * 2.2, labelpad=size / 2)
    plt.title('Comparison of Prediction Methods', fontsize=size * 2.5, y=1.05)

    if plot_dir is not None:
        filename = 'prediction_comparison.%s' % ext
        save_figure(f, path.join(plot_dir, filename), {
            'bbox_inches': 'tight',
            'dpi': dpi
        })
        plt.close(f)
# Example #7
# 0
def plot_vars(tasks, contrasts, axes=None, xlabel='Value', standardize=False):
    """Strip-plot each task's variables on its own axis.

    For each task, the columns of ``contrasts`` whose names match
    '^' + task are optionally scaled by their standard deviation and drawn
    as jittered points. Each variable's mean is marked with a large dot, and
    the legend labels include the mean value.

    Args:
        tasks: iterable of task-name prefixes.
        contrasts: DataFrame whose columns look like '<task>.<variable>'.
        axes: sequence of axes, one per task. If None, a one-column grid of
            subplots is created (previously None crashed).
        xlabel: x axis label.
        standardize: if True, divide each variable by its standard deviation.
    """
    tasks = list(tasks)
    if axes is None:
        # squeeze=False so a single task still yields an indexable array
        _, axes = plt.subplots(max(len(tasks), 1), 1, squeeze=False)
        axes = axes[:, 0]
    for i, task in enumerate(tasks):
        ax = axes[i]
        subset = contrasts.filter(regex='^' + task)
        if subset.shape[1] != 0:
            if standardize:
                subset = subset / subset.std()
            subset.columns = [c.split('.')[1] for c in subset.columns]
            subset.columns = format_variable_names(subset.columns)
            means = subset.mean()
            N = len(means)
            # hls_palette(4) as before; grown only when a task has more than
            # 4 variables, where too few colours broke scatter/stripplot
            colors = sns.hls_palette(max(4, N))
            desat_colors = [sns.desaturate(c, .5) for c in colors]
            # add mean value to columns (these become the legend labels)
            subset.columns = [name + ': %s' % format_num(mean)
                              for name, mean in zip(subset.columns, means)]
            subset = subset.melt(var_name='Variable', value_name='Value')

            sns.stripplot(x='Value',
                          y='Variable',
                          hue='Variable',
                          ax=ax,
                          data=subset,
                          palette=desat_colors,
                          jitter=True,
                          alpha=.75)
            # plot central tendency
            ax.scatter(means,
                       range(N),
                       s=200,
                       c=colors[:N],
                       edgecolors='white',
                       linewidths=2,
                       zorder=3)

            # add legend (seaborn may not create one, depending on version)
            leg = ax.get_legend()
            if leg is not None:
                leg.set_title('')
                beautify_legend(leg, colors=colors, fontsize=14)
            # symmetric x limits around zero
            max_val = subset.Value.abs().max()
            ax.set_xlim(-max_val, max_val)
            ax.set_xlabel(xlabel, fontsize=16)
            ax.set_ylabel('')
            ax.set_yticklabels('')
        ax.set_title(format_variable_names([task])[0].title(),
                     fontsize=20)
    plt.subplots_adjust(hspace=.3)
def plot_DDM(results,
             c,
             rotate='oblimin',
             dpi=300,
             figsize=(20, 8),
             ext='png',
             plot_dir=None):
    """3-D scatter of absolute loadings on 'Speeded IP', 'Caution' and
    'Perc / Resp', coloured by a category derived from each variable name.

    Categories: names containing 'drift' -> 'Drift', 'thresh' -> 'Thresh',
    'non_decision' -> 'Non-Decision', otherwise 'Misc'.

    Args:
        results: results object with an ``EFA`` attribute.
        c: passed to ``EFA.get_loading`` (presumably the number of factors).
        rotate: rotation passed to ``EFA.get_loading``.
        dpi: resolution used when saving.
        figsize: currently unused - the figure is always 12x12; kept for
            interface compatibility.
        ext: file extension used when saving.
        plot_dir: directory to save 'DDM_factors.<ext>' into; if None the
            figure is left open.
    """
    EFA = results.EFA
    loading = abs(EFA.get_loading(c, rotate=rotate))
    cats = []
    for var in loading.index:
        if 'drift' in var:
            cats.append('Drift')
        elif 'thresh' in var:
            cats.append('Thresh')
        elif 'non_decision' in var:
            cats.append('Non-Decision')
        else:
            cats.append('Misc')
    loading.insert(0, 'category', cats)
    # plotting
    colors = sns.color_palette("Set1", 8, .75)
    color_map = {v: i for i, v in enumerate(loading.category.unique())}

    fig = plt.figure(figsize=(12, 12))
    ax = fig.add_subplot(111, projection='3d')
    legend_colors = []
    for name, group in loading.groupby('category'):
        color = colors[color_map[name]]
        legend_colors.append(color)
        # color= rather than c=: a single RGB tuple passed as c is treated as
        # values to colour-map when a group has exactly 3 points
        ax.scatter(group['Speeded IP'],
                   group['Caution'],
                   group['Perc / Resp'],
                   marker='o',
                   s=150,
                   color=color,
                   label=name)
    ax.tick_params(labelsize=0, length=0)
    ax.set_xlabel('Speeded IP', fontsize=20)
    ax.set_ylabel('Caution', fontsize=20)
    ax.set_zlabel('Perc / Resp', fontsize=20)
    ax.view_init(30, 30)
    leg = ax.legend(fontsize=20)
    # legend entries follow groupby (sorted) order, so colour the legend text
    # with the colours actually used, in that same order
    beautify_legend(leg, legend_colors)
    if plot_dir is not None:
        fig.savefig(path.join(plot_dir, 'DDM_factors.%s' % ext),
                    bbox_inches='tight',
                    dpi=dpi)
        plt.close(fig)
def plot_BIC_SABIC(results, size=2.3, dpi=300, ext='png', plot_dir=None):
    """ Plots BIC and SABIC curves

    BIC is drawn on the left y axis. If SABIC scores are present, they are
    drawn on a twin right y axis. The number of factors selected by each
    metric ('c_metric-*') is circled.

    Args:
        results: a dimensional structure results object
        size: figure width; height, fonts and markers scale with it
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    EFA = results.EFA
    # Plot BIC and SABIC curves
    colors = ['c', 'm']
    with sns.axes_style('white'):
        fig, ax1 = plt.subplots(1, 1, figsize=(size, size * .75))
        bic = EFA.results['cscores_metric-BIC']
        x = sorted(bic.keys())
        # BIC
        BIC_scores = [bic[n] for n in x]
        BIC_c = EFA.results['c_metric-BIC']
        ax1.plot(x, BIC_scores, 'o-', c=colors[0], lw=3, label='BIC',
                 markersize=size * 2)
        ax1.set_xlabel('# Factors', fontsize=size * 3)
        ax1.set_ylabel('BIC', fontsize=size * 3)
        # circle the selected c, looked up by key rather than list position
        ax1.plot(BIC_c, bic[BIC_c], '.', color='white',
                 markeredgecolor=colors[0], markeredgewidth=size / 2,
                 markersize=size * 4)
        ax1.tick_params(labelsize=size * 2)
        if 'cscores_metric-SABIC' in EFA.results:
            # SABIC
            sabic = EFA.results['cscores_metric-SABIC']
            ax2 = ax1.twinx()
            # align with the same sorted x used for BIC (previously taken in
            # dict order)
            SABIC_scores = [sabic[n] for n in x]
            SABIC_c = EFA.results['c_metric-SABIC']
            ax2.plot(x, SABIC_scores, c=colors[1], lw=3, label='SABIC',
                     markersize=size * 2)
            ax2.set_ylabel('SABIC', fontsize=size * 4)
            # circle the selected c. Previously indexed SABIC_scores[SABIC_c]
            # (off by one relative to BIC) and drew it with the BIC colour.
            ax2.plot(SABIC_c, sabic[SABIC_c], '.', color='white',
                     markeredgecolor=colors[1], markeredgewidth=size / 2,
                     markersize=size * 4)
            # set up legend: invisible proxy line so ax1's legend lists SABIC
            ax1.plot(np.nan, c=colors[1], lw=3, label='SABIC')
            # 'center right' is the valid location string ('right center'
            # is rejected by matplotlib)
            leg = ax1.legend(loc='center right')
            beautify_legend(leg, colors=colors)
        if plot_dir is not None:
            save_figure(fig, path.join(plot_dir, 'BIC_SABIC_curves.%s' % ext),
                        {'bbox_inches': 'tight', 'dpi': dpi})
            plt.close(fig)
def plot_DDM(results, c, rotate='oblimin',
             dpi=300, figsize=(20, 8), ext='png', plot_dir=None):
    """3-D scatter of absolute loadings on 'Speeded IP', 'Caution' and
    'Perc / Resp', coloured by a category derived from each variable name.

    Categories: names containing 'drift' -> 'Drift', 'thresh' -> 'Thresh',
    'non_decision' -> 'Non-Decision', otherwise 'Misc'.

    Args:
        results: results object with an ``EFA`` attribute.
        c: passed to ``EFA.get_loading`` (presumably the number of factors).
        rotate: rotation passed to ``EFA.get_loading``.
        dpi: resolution used when saving.
        figsize: currently unused - the figure is always 12x12; kept for
            interface compatibility.
        ext: file extension used when saving.
        plot_dir: directory to save 'DDM_factors.<ext>' into; if None the
            figure is left open.
    """
    EFA = results.EFA
    loading = abs(EFA.get_loading(c, rotate=rotate))
    cats = []
    for var in loading.index:
        if 'drift' in var:
            cats.append('Drift')
        elif 'thresh' in var:
            cats.append('Thresh')
        elif 'non_decision' in var:
            cats.append('Non-Decision')
        else:
            cats.append('Misc')
    loading.insert(0, 'category', cats)
    # plotting
    colors = sns.color_palette("Set1", 8, .75)
    color_map = {v: i for i, v in enumerate(loading.category.unique())}

    fig = plt.figure(figsize=(12, 12))
    ax = fig.add_subplot(111, projection='3d')
    legend_colors = []
    for name, group in loading.groupby('category'):
        color = colors[color_map[name]]
        legend_colors.append(color)
        # color= rather than c=: a single RGB tuple passed as c is treated as
        # values to colour-map when a group has exactly 3 points
        ax.scatter(group['Speeded IP'],
                   group['Caution'],
                   group['Perc / Resp'],
                   marker='o',
                   s=150,
                   color=color,
                   label=name)
    ax.tick_params(labelsize=0, length=0)
    ax.set_xlabel('Speeded IP', fontsize=20)
    ax.set_ylabel('Caution', fontsize=20)
    ax.set_zlabel('Perc / Resp', fontsize=20)
    ax.view_init(30, 30)
    leg = ax.legend(fontsize=20)
    # legend entries follow groupby (sorted) order, so colour the legend text
    # with the colours actually used, in that same order
    beautify_legend(leg, legend_colors)
    if plot_dir is not None:
        fig.savefig(path.join(plot_dir, 'DDM_factors.%s' % ext),
                    bbox_inches='tight', dpi=dpi)
        plt.close(fig)
def plot_prediction_comparison(results, size=4.6, change=False,
                               dpi=300, ext='png', plot_dir=None):
    """Bar plot comparing cross-validated R2 across classifiers and features.

    For EFA=False (labelled 'IDM') and EFA=True (labelled 'EFA'), the most
    recently modified prediction file of each classifier is loaded. Its
    per-target cross-validated R2 values (NaN -> 0) are plotted as bars
    grouped by classifier and coloured by feature set.

    Args:
        results: results object with ``ID`` and ``get_prediction_files``.
        size: base figure size; fonts scale with it.
        change: forwarded to ``get_prediction_files``.
        dpi: resolution used when saving.
        ext: file extension used when saving.
        plot_dir: directory to save 'prediction_comparison.<ext>' into; if
            None the figure is left open.
    """
    colors = ref_colors[results.ID.split('_')[0]]
    R2s = {}
    for EFA in [False, True]:
        predictions = results.get_prediction_files(EFA=EFA, change=change,
                                                   shuffle=False)
        predictions = sorted(predictions, key=path.getmtime)
        # the classifier is the second-to-last '_'-separated token of a path
        classifiers = np.unique([i.split('_')[-2] for i in predictions])
        feature = 'EFA' if EFA else 'IDM'
        # get last prediction file of each type
        for classifier in classifiers:
            # exact token match; a substring test could also pick up files of
            # other classifiers or matching directory names
            filey = [i for i in predictions
                     if i.split('_')[-2] == classifier][-1]
            # NOTE: pickle can execute code on load - only use on trusted,
            # locally generated prediction files
            with open(filey, 'rb') as fh:
                prediction_object = pickle.load(fh)['data']
            R2 = [i['scores_cv'][0]['R2'] for i in prediction_object.values()]
            R2s[feature + '_' + classifier] = np.nan_to_num(R2)
    if len(R2s) == 0:
        print('No prediction objects found')
        return
    R2s = pd.DataFrame(R2s).melt(var_name='Classifier', value_name='R2')
    # split e.g. 'EFA_lasso' into Feature='EFA', Classifier='lasso'; the old
    # tuple-unpacking of the `.str` accessor fails on recent pandas
    split = R2s['Classifier'].str.split('_', n=1, expand=True)
    R2s['Feature'] = split[0]
    R2s['Classifier'] = split[1]
    f = plt.figure(figsize=(size, size * .62))
    sns.barplot(x='Classifier', y='R2', data=R2s, hue='Feature',
                palette=colors[:2], errwidth=size / 5)
    ax = plt.gca()
    ax.tick_params(axis='y', labelsize=size * 1.8)
    ax.tick_params(axis='x', labelsize=size * 1.8)
    leg = ax.legend(fontsize=size * 2, loc='upper right')
    beautify_legend(leg, colors[:2])
    plt.xlabel('Classifier', fontsize=size * 2.2, labelpad=size / 2)
    plt.ylabel('R2', fontsize=size * 2.2, labelpad=size / 2)
    plt.title('Comparison of Prediction Methods', fontsize=size * 2.5, y=1.05)

    if plot_dir is not None:
        filename = 'prediction_comparison.%s' % ext
        save_figure(f, path.join(plot_dir, filename),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close(f)
def plot_prediction(predictions,
                    shuffled_predictions,
                    target_order=None,
                    metric='R2',
                    size=4.6,
                    dpi=300,
                    filename=None):
    """ Plots predictions resulting from "run_prediction" function

    Draws one group of bars per target, with one bar per key of
    ``predictions``. Each bar is overlaid with a translucent grey bar at the
    95th percentile of the matching shuffled scores. NaN scores are drawn
    as 0.

    Args:
        predictions: dictionary of run_prediction results (key -> dict of
            target -> result)
        shuffled_predictions: dictionary of run_prediction shuffled results,
            with the same keys as ``predictions``
        target_order: (optional) a list of targets to order the plot
        metric: which metric from the output of run_prediction to use
        size: figure size
        dpi: dpi to use for saving
        filename: if provided, save to this location

    Returns:
        The figure when ``filename`` is None, otherwise None.
    """
    basefont = max(size, 5)
    sns.set_style('white')
    if target_order is None:
        target_order = predictions.keys()
    target_order = list(target_order)
    prediction_keys = list(predictions.keys())
    n_sets = len(prediction_keys)
    # 5 'Blues_d' colours as before; extended only when there are more
    # prediction sets (which previously raised IndexError)
    colors = list(sns.color_palette('Blues_d', max(5, n_sets)))
    shuffled_grey = [.3, .3, .3, .3]  # RGBA
    figsize = (size, size * .75)
    fig = plt.figure(figsize=figsize)
    # plot bars
    width = 1 / (n_sets + 1)
    ax1 = fig.add_axes([0, 0, 1, .5])
    # shift bar positions so each target's group is centred on its x tick
    ind = np.arange(len(target_order)) - (width * (n_sets / 2 - 1))

    def _nan_to_zero(value):
        # NaN is the only value that compares unequal to itself
        return value if value == value else 0

    for predictor_i, key in enumerate(prediction_keys):
        prediction = predictions[key]
        shuffled_prediction = shuffled_predictions[key]
        r2s = [_nan_to_zero(prediction[k]['scores_cv'][0][metric])
               for k in target_order]
        # 95th percentile of the shuffled scores of each target
        shuffled_r2s = [
            _nan_to_zero(np.percentile(
                [s[metric] for s in shuffled_prediction[k]['scores_cv']], 95))
            for k in target_order
        ]
        ax1.bar(ind + width * predictor_i, r2s,
                width,
                label='%s Prediction' % ' '.join(key.title().split('_')),
                linewidth=0,
                color=colors[predictor_i])
        # plot shuffled values above; label them once (on the last set)
        if predictor_i == n_sets - 1:
            shuffled_label = '95% shuffled prediction'
        else:
            shuffled_label = None
        ax1.bar(ind + width * predictor_i, shuffled_r2s,
                width,
                color=shuffled_grey,
                linewidth=0,
                label=shuffled_label)

    ax1.set_xticks(np.arange(0, len(target_order)) + width / 2)
    ax1.set_xticklabels(['\n'.join(k.split()) for k in target_order],
                        rotation=90,
                        fontsize=basefont * .75,
                        ha='center')
    ax1.tick_params(axis='y', labelsize=size * 1.2)
    ax1.tick_params(length=size / 2,
                    width=size / 10,
                    pad=size / 2,
                    bottom=True,
                    left=True)
    if metric == 'R2':
        ax1.set_ylabel(r'$R^2$', fontsize=basefont * 1.5, labelpad=size * 1.5)
    else:
        ax1.set_ylabel(metric, fontsize=basefont * 1.5, labelpad=size * 1.5)
    # add a legend
    leg = ax1.legend(fontsize=basefont * 1.4,
                     loc='upper right',
                     bbox_to_anchor=(1.3, 1.1),
                     frameon=True,
                     handlelength=0,
                     handletextpad=0,
                     framealpha=1)
    beautify_legend(leg, colors[:n_sets] + [shuffled_grey])
    # draw grid
    ax1.set_axisbelow(True)
    ax1.grid(axis='y', linestyle='dotted', linewidth=size / 6)
    plt.setp(list(ax1.spines.values()), linewidth=size / 10)
    if filename is not None:
        save_figure(fig, filename, {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close(fig)
    else:
        return fig
def plot_prediction(predictions, comparison_predictions,
                    colors=None, EFA=None, comparison_label=None,
                    target_order=None, metric='R2', size=4.6,
                    dpi=300, filename=None):
    """Bar plot of cross-validated and in-sample prediction scores per target.

    For each target, the first cross-validated and first in-sample
    ``metric`` are drawn as adjacent bars. Each bar is overlaid with a
    dashed outline at the 95th percentile of the corresponding scores in
    ``comparison_predictions``. NaNs are drawn as 0.

    If ``EFA`` is given, small polar importance plots are added above every
    target whose CV score beats both its comparison and 0. Larger labelled
    polar plots are added for the (up to) two best CV-scored targets, plus
    an abbreviation key for factor names listed in ``shortened_factors``.

    Args:
        predictions: dict target -> result dict with 'scores_cv' and
            'scores_insample' (lists of metric dicts), plus 'predvars' and
            'importances' when ``EFA`` is given.
        comparison_predictions: dict target -> result dict whose
            'scores_cv' / 'scores_insample' lists hold one metric dict per
            comparison run.
        colors: two colours (cross-validated, in-sample).
        EFA: optional EFA object used to reorder the importances.
        comparison_label: legend label of the comparison outlines.
        target_order: optional iterable of targets to plot, in order.
        metric: key of the score to plot.
        size: base figure size; fonts and line widths scale with it.
        dpi: resolution used when saving.
        filename: if given, save the figure there and close it.
    """
    if colors is None:
        colors = [sns.color_palette('Purples_d', 4)[i] for i in [1, 3]]
    colors = list(colors)
    if comparison_label is None:
        comparison_label = '95% shuffled prediction'
    basefont = max(size, 5)
    sns.set_style('white')
    if target_order is None:
        target_order = predictions.keys()
    target_order = list(target_order)

    def _nan_to_zero(pairs):
        # NaN is the only value that compares unequal to itself
        return [(k, v) if v == v else (k, 0) for k, v in pairs]

    # get prediction success
    r2s = [[k, predictions[k]['scores_cv'][0][metric]] for k in target_order]
    insample_r2s = [[k, predictions[k]['scores_insample'][0][metric]]
                    for k in target_order]
    # 95th percentile of the comparison scores, per target
    shuffled_r2s = []
    insample_shuffled_r2s = []
    for k in target_order:
        cv_scores = [s[metric] for s in comparison_predictions[k]['scores_cv']]
        shuffled_r2s.append((k, np.percentile(cv_scores, 95)))
        insample_scores = [s[metric] for s in
                           comparison_predictions[k]['scores_insample']]
        insample_shuffled_r2s.append((k, np.percentile(insample_scores, 95)))
    # convert nans to 0 (the in-sample comparison used to be left as NaN)
    r2s = _nan_to_zero(r2s)
    insample_r2s = _nan_to_zero(insample_r2s)
    shuffled_r2s = _nan_to_zero(shuffled_r2s)
    insample_shuffled_r2s = _nan_to_zero(insample_shuffled_r2s)

    # plot
    shuffled_grey = [.3, .3, .3]
    figsize = (size, size * .75)
    fig = plt.figure(figsize=figsize)
    # plot bars
    ind = np.arange(len(r2s))
    width = .25
    ax1 = fig.add_axes([0, 0, 1, .5])
    ax1.bar(ind, [i[1] for i in r2s], width,
            label='Cross-validated prediction', color=colors[0])
    ax1.bar(ind + width, [i[1] for i in insample_r2s], width,
            label='Insample prediction', color=colors[1])
    # comparison values as dashed outlines over the matching bars
    ax1.bar(ind, [i[1] for i in shuffled_r2s], width,
            color='none', edgecolor=shuffled_grey,
            linewidth=size / 10, linestyle='--', label=comparison_label)
    ax1.bar(ind + width, [i[1] for i in insample_shuffled_r2s], width,
            color='none', edgecolor=shuffled_grey,
            linewidth=size / 10, linestyle='--')

    ax1.set_xticks(np.arange(0, len(r2s)) + width / 2)
    ax1.set_xticklabels([i[0] for i in r2s], rotation=15,
                        fontsize=basefont * 1.4)
    ax1.tick_params(axis='y', labelsize=size * 1.2)
    ax1.tick_params(length=size / 4, width=size / 10, pad=size / 2,
                    left=True, bottom=True)
    # x extent in data coordinates; used below to place the polar plots
    xlow, xhigh = ax1.get_xlim()
    if metric == 'R2':
        ax1.set_ylabel(r'$R^2$', fontsize=basefont * 1.5, labelpad=size * 1.5)
    else:
        ax1.set_ylabel(metric, fontsize=basefont * 1.5, labelpad=size * 1.5)
    # add a legend
    leg = ax1.legend(fontsize=basefont * 1.4, loc='upper left', framealpha=1,
                     frameon=True, handlelength=0, handletextpad=0)
    leg.get_frame().set_linewidth(size / 10)
    beautify_legend(leg, colors[:2] + [shuffled_grey])
    # change y extents
    ylim = ax1.get_ylim()
    r2_max = max(max(r2s, key=lambda x: x[1])[1],
                 max(insample_r2s, key=lambda x: x[1])[1])
    ymax = r2_max * 1.5
    ax1.set_ylim(ylim[0], ymax)
    # change yticks
    if ymax < .15:
        ax1.set_ylim(ylim[0], .15)
        ax1.yaxis.set_major_locator(ticker.MultipleLocator(.025))
    else:
        ax1.yaxis.set_major_locator(ticker.MultipleLocator(.05))
        ax1.set_yticks(np.append([0, .025, .05, .075, .1, .125],
                                 np.arange(.15, .45, .05)))
    # draw grid
    ax1.set_axisbelow(True)
    ax1.grid(axis='y', linestyle='dotted', linewidth=size / 6)
    plt.setp(list(ax1.spines.values()), linewidth=size / 10)
    # Plot Polar Plots for importances
    if EFA is not None:
        reorder_vec = EFA.get_factor_reorder(EFA.results['num_factors'])

        def reorder_fun(x):
            # reorder factors by the index vector from EFA.get_factor_reorder
            return [x[i] for i in reorder_vec]

        # get importances
        vals = [predictions[k] for k in target_order]
        importances = [(reorder_fun(v['predvars']),
                        reorder_fun(v['importances'][0])) for v in vals]
        # plot
        axes = []
        N = len(importances)
        # (position, (target, cv score)) pairs, worst to best
        best_predictors = sorted(enumerate(r2s), key=lambda x: x[1][1])
        ylim = ax1.get_ylim()
        yrange = np.sum(np.abs(ylim))
        # fraction of the y range lying below zero
        zero_place = abs(ylim[0]) / yrange
        # tallest bar/outline of each target as a fraction of the y range;
        # zero for targets whose CV score is not positive
        plot_heights = [int(r2s[i][1] > 0)
                        * (max(r2s[i][1],
                               insample_r2s[i][1],
                               shuffled_r2s[i][1],
                               insample_shuffled_r2s[i][1]) / yrange)
                        for i in range(N)]
        # to figure coordinates (ax1 occupies the bottom half of the figure)
        plot_heights = [(h + zero_place + .02) * .5 for h in plot_heights]
        # mask heights: only draw where CV beats the comparison and 0
        plot_heights = [h if r2s[i][1] > max(shuffled_r2s[i][1], 0) else np.nan
                        for i, h in enumerate(plot_heights)]
        plot_x = (ax1.get_xticks() - xlow) / (xhigh - xlow) - (1 / N / 2)
        for i, importance in enumerate(importances):
            if pd.isnull(plot_heights[i]):
                continue
            axes.append(fig.add_axes([plot_x[i], plot_heights[i], 1 / N, 1 / N],
                                     projection='polar'))
            visualize_importance(importance, axes[-1],
                                 yticklabels=False, xticklabels=False,
                                 label_size=figsize[1] * 1,
                                 color=colors[0],
                                 axes_linewidth=size / 10)
        # plot top 2 predictions, labeled (fewer when there are fewer targets)
        if best_predictors:
            # the best target goes left when it precedes the runner-up on x
            if (len(best_predictors) > 1
                    and best_predictors[-1][0] < best_predictors[-2][0]):
                locs = [.32, .68]
            else:
                locs = [.68, .32]
            label_importance = importances[best_predictors[-1][0]]
            # write abbreviation key, only for factors that have one
            # (a None abbreviation used to crash on `abr + ':'`)
            pad = 0
            text = [(factor, shortened_factors.get(factor, None))
                    for factor in label_importance[0]]
            text = [(val, abr) for val, abr in text if abr is not None]
            if text:
                pad = .05
                text_ax = fig.add_axes([.8, .56, .1, .34])
                text_ax.tick_params(labelleft=False, left=False,
                                    labelbottom=False, bottom=False)
                for spine in ['top', 'right', 'bottom', 'left']:
                    text_ax.spines[spine].set_visible(False)
                for i, (val, abr) in enumerate(text):
                    text_ax.text(0, i / len(text), abr + ':',
                                 fontsize=size * 1.2)
                    text_ax.text(.5, i / len(text), val, fontsize=size * 1.2)

            ratio = figsize[1] / figsize[0]
            # best first, then second best (if any)
            for (idx, (target, _)), loc in zip(best_predictors[::-1][:2], locs):
                axes.append(fig.add_axes([loc - .2 * ratio - pad, .56,
                                          .3 * ratio, .3],
                                         projection='polar'))
                visualize_importance(importances[idx], axes[-1],
                                     yticklabels=False,
                                     xticklabels=True,
                                     label_size=max(figsize[1] * 1.5, 5),
                                     label_scale=.22,
                                     title=target,
                                     color=colors[0],
                                     axes_linewidth=size / 10)
    if filename is not None:
        save_figure(fig, filename,
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close(fig)
# Example #14
# 0
# Relies on plot_df, shift_df, colors, axes (at least 2 axes), f, base_dir and
# ext having been defined beforehand; none of them are defined here.

# Two Step Task: bar plot of 'switch' by last feedback and last transition
sns.barplot(x='feedback_last',
            y='switch',
            hue='stage_transition_last',
            data=plot_df,
            order=['Rewarded', 'Unrewarded'],
            hue_order=['Common', 'Rare'],
            palette=colors,
            ax=axes[0])
axes[0].set_xlabel('')
axes[0].set_ylabel('Stay Probability', fontsize=24)
axes[0].set_title('Two Step Task', y=1.04, fontsize=30)
axes[0].set_ylim([.5, 1])
axes[0].tick_params(labelsize=20)
leg = axes[0].get_legend()
leg.set_title('')
beautify_legend(leg, colors=colors, fontsize=20)

# Shift task: 'correct' by trials since the last switch.
# x and y are passed by keyword: seaborn >= 0.12 accepts only `data`
# positionally, so the old positional form conflicts with data=.
sns.pointplot(x='trials_since_switch', y='correct', data=shift_df,
              ax=axes[1])
axes[1].set_xticks(range(0, 25, 5))
axes[1].set_xticklabels(range(0, 25, 5))
axes[1].set_xlabel('Trials After Change-Point', fontsize=24)
axes[1].set_ylabel('Percent Correct', fontsize=24)
axes[1].set_title('Shift Task', y=1.04, fontsize=30)
axes[1].tick_params(labelsize=20)
# full output file path (despite the variable name)
save_dir = path.join(base_dir, 'Results', 'replication', 'Plots',
                     'successful_learning_tasks.%s' % ext)
f.savefig(save_dir, dpi=300, bbox_inches='tight')
plt.close(f)

# *************************************************************************
def plot_corr_hist(all_results, reps=100, size=4.6,
                   dpi=300, ext='png', plot_dir=None):
    """Plot distributions of absolute pairwise correlations.

    Three panels are drawn: |r| within survey measures, within task measures,
    and between surveys and tasks. Each panel shows a KDE of the observed
    absolute correlations, a dashed line at the 95th percentile of a
    permutation null (columns shuffled independently), and the mean
    cross-validated R^2 (RidgeCV, 10 folds) for predicting each target
    variable from the source dataset.

    Args:
        all_results: dict with 'survey' and 'task' entries, each exposing a
            ``data`` DataFrame
        reps: number of shuffles used to build each null distribution
        size: currently unused (figure size is fixed); kept for interface
            compatibility
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If None, do not save
    """
    colors = sns.color_palette('Blues_d', 3)[0:2] + sns.color_palette('Reds_d', 2)[:1]
    survey_data = all_results['survey'].data
    task_data = all_results['task'].data
    survey_corr = abs(survey_data.corr())
    task_corr = abs(task_data.corr())
    all_data = pd.concat([task_data, survey_data], axis=1)
    datasets = [('survey', survey_data),
                ('task', task_data),
                ('all', all_data)]
    # cross correlations: rows are survey variables, columns task variables
    cross_corr = abs(all_data.corr()).loc[survey_corr.columns,
                                          task_corr.columns]

    # 95th percentile of |r| under a permutation null. Values from *every*
    # repetition are pooled, and each null uses the same entries as the
    # matching observed panel (lower triangle within a dataset, the full
    # survey x task block across datasets).
    shuffled_95 = []
    for label, df in datasets:
        null_values = []
        for _ in range(reps):
            # shuffle each column independently to break cross-variable links
            shuffled = df.copy()
            for col in shuffled:
                shuffled.loc[:, col] = shuffled[col].sample(len(shuffled)).tolist()
            shuffled_corr = abs(shuffled.corr())
            if label == 'all':
                cross = shuffled_corr.loc[survey_corr.columns, task_corr.columns]
                null_values.append(np.ravel(cross.values))
            else:
                null_values.append(np.ravel(extract_tril(shuffled_corr.values, -1)))
        shuffled_95.append(np.percentile(np.concatenate(null_values), 95))

    # mean cross-validated R^2 for predicting every target variable from the
    # source dataset; when source == target the predicted variable is dropped
    # from the predictors. Only the pairs that are plotted are computed.
    data_by_label = dict(datasets)
    average_r2 = {}
    for slabel, tlabel in [('survey', 'survey'), ('task', 'task'), ('survey', 'task')]:
        source = data_by_label[slabel]
        target = data_by_label[tlabel]
        scores = []
        # DataFrame.iteritems was removed in pandas 2.0; items() is equivalent
        for var, values in target.items():
            predictors = source.drop(var, axis=1) if var in source.columns else source
            cv_score = np.mean(cross_val_score(RidgeCV(), predictors, values, cv=10))
            scores.append(cv_score)
        average_r2[(slabel, tlabel)] = np.mean(scores)

    # bring everything together
    plot_elements = [(extract_tril(survey_corr.values, -1), 'Within Surveys',
                      average_r2[('survey', 'survey')]),
                     (extract_tril(task_corr.values, -1), 'Within Tasks',
                      average_r2[('task', 'task')]),
                     (cross_corr.values.flatten(), 'Surveys x Tasks',
                      average_r2[('survey', 'task')])]

    with sns.axes_style('white'):
        f, axes = plt.subplots(1, 3, figsize=(10, 4))
        plt.subplots_adjust(wspace=.3)
        for i, (corr, label, r2) in enumerate(plot_elements):
            sns.kdeplot(corr, ax=axes[i], color=colors[i], shade=True,
                        label=label, linewidth=3)
            axes[i].text(.4, axes[i].get_ylim()[1] * .5, 'CV-R2: {0:.2f}'.format(r2))
        for i, ax in enumerate(axes):
            # dashed line: 95th percentile of the shuffled null
            ax.vlines(shuffled_95[i], *ax.get_ylim(), color=[.2, .2, .2],
                      linewidth=2, linestyle='dashed', zorder=10)
            ax.set_xlim(0, 1)
            ax.set_ylim(0, ax.get_ylim()[1])
            ax.set_xticks([0, .5, 1])
            ax.set_xticklabels([0, .5, 1], fontsize=16)
            ax.set_yticks([])
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            leg = ax.legend(fontsize=14, loc='upper center')
            beautify_legend(leg, [colors[i]])
        axes[1].set_xlabel('Pearson Correlation', fontsize=20, labelpad=10)
        axes[0].set_ylabel('Normalized Density', fontsize=20, labelpad=10)

    # save
    if plot_dir is not None:
        save_figure(f, path.join(plot_dir, 'within-across_correlations.%s' % ext),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
예제 #16
0
    int)

# KNNR random-subset panel (axes[0]): 'corr_score' against the number of
# available measures, one line per population size, no confidence intervals.
sns.pointplot(x='num_available_measures',
              y='corr_score',
              hue='pop_size',
              data=KNNRpartial_var_summary,
              palette=colors,
              ax=axes[0],
              ci=None,
              scale=1.4)
# compact legend: no marker handles, text coloured via beautify_legend
leg = axes[0].legend(loc='best',
                     frameon=False,
                     handlelength=0,
                     handletextpad=0,
                     fontsize=size * 1.5)
beautify_legend(leg, colors=colors)
leg.get_title().set_fontsize(size * 1.5)

# axis labels and ticks scale with the module-level `size`
axes[0].set_ylabel('Reconstruction Score', fontsize=size * 3)
axes[0].set_xlabel('# of Measures', fontsize=size * 2)
axes[0].set_title('KNNR with Random Subset', fontsize=size * 3)
axes[0].tick_params(width=2, length=2, labelsize=size * 1.8)

# efficiency subset
closest_files = glob(
    path.join(ontology_results_dir, 'KNNRclosest_correlation_summary.pkl'))
closest_summary = pd.read_pickle(closest_files[0])
sns.pointplot(x='num_available_measures',
              y='mean',
              hue='pop_size',
              data=plot_df,
def plot_prediction(predictions,
                    comparison_predictions,
                    colors=None,
                    EFA=None,
                    comparison_label=None,
                    target_order=None,
                    metric='R2',
                    size=4.6,
                    dpi=300,
                    filename=None):
    """Bar plot of prediction performance with optional importance polar plots.

    For each target, the cross-validated and in-sample scores are drawn as
    bars. Dashed outline bars mark the 95th percentile of the comparison
    scores. If ``EFA`` is given, a small polar plot of predictor importances
    is placed above every target whose CV score beats both 0 and its
    comparison threshold. The (up to) two best targets also get larger,
    titled polar plots, plus an abbreviation key for factors found in
    ``shortened_factors``.

    Args:
        predictions: dict mapping target name to a result dict with
            'scores_cv' and 'scores_insample' (lists whose first element is a
            dict of metrics) and, when EFA is given, 'predvars' and
            'importances'
        comparison_predictions: dict mapping target name to a result dict
            whose 'scores_cv' / 'scores_insample' lists hold one metric dict
            per comparison run
        colors: colors for the CV and in-sample bars
        EFA: optional EFA results object used to reorder factors
        comparison_label: legend label for the comparison bars
        target_order: order of targets along the x axis (defaults to the
            key order of ``predictions``)
        metric: key of the metric to plot
        size: base figure size; fonts and line widths scale with it
        dpi: the final dpi for the image
        filename: path to save the figure to. If None, do not save
    """
    if colors is None:
        colors = [sns.color_palette('Purples_d', 4)[i] for i in [1, 3]]
    if comparison_label is None:
        comparison_label = '95% shuffled prediction'
    basefont = max(size, 5)
    sns.set_style('white')
    if target_order is None:
        target_order = list(predictions.keys())

    def nan_to_zero(pairs):
        # NaN is the only value that is not equal to itself
        return [(k, v) if v == v else (k, 0) for k, v in pairs]

    # observed prediction success
    r2s = [(k, predictions[k]['scores_cv'][0][metric]) for k in target_order]
    insample_r2s = [(k, predictions[k]['scores_insample'][0][metric])
                    for k in target_order]
    # 95th percentile of the comparison (e.g. shuffled) scores
    shuffled_r2s = []
    insample_shuffled_r2s = []
    for k in target_order:
        comparison = comparison_predictions[k]
        cv_scores = [s[metric] for s in comparison['scores_cv']]
        shuffled_r2s.append((k, np.percentile(cv_scores, 95)))
        insample_scores = [s[metric] for s in comparison['scores_insample']]
        insample_shuffled_r2s.append((k, np.percentile(insample_scores, 95)))

    # convert nans to 0 in all four series so bars and heights agree
    r2s = nan_to_zero(r2s)
    insample_r2s = nan_to_zero(insample_r2s)
    shuffled_r2s = nan_to_zero(shuffled_r2s)
    insample_shuffled_r2s = nan_to_zero(insample_shuffled_r2s)

    # plot
    shuffled_grey = [.3, .3, .3]
    figsize = (size, size * .75)
    fig = plt.figure(figsize=figsize)
    # plot bars
    ind = np.arange(len(r2s))
    width = .25
    ax1 = fig.add_axes([0, 0, 1, .5])
    ax1.bar(ind, [v for _, v in r2s],
            width,
            label='Cross-validated prediction',
            color=colors[0])
    ax1.bar(ind + width, [v for _, v in insample_r2s],
            width,
            label='Insample prediction',
            color=colors[1])
    # comparison thresholds as dashed outlines over the bars
    ax1.bar(ind, [v for _, v in shuffled_r2s],
            width,
            color='none',
            edgecolor=shuffled_grey,
            linewidth=size / 10,
            linestyle='--',
            label=comparison_label)
    ax1.bar(ind + width, [v for _, v in insample_shuffled_r2s],
            width,
            color='none',
            edgecolor=shuffled_grey,
            linewidth=size / 10,
            linestyle='--')

    ax1.set_xticks(np.arange(0, len(r2s)) + width / 2)
    ax1.set_xticklabels([k for k, _ in r2s],
                        rotation=15,
                        fontsize=basefont * 1.4)
    ax1.tick_params(axis='y', labelsize=size * 1.2)
    ax1.tick_params(length=size / 4,
                    width=size / 10,
                    pad=size / 2,
                    left=True,
                    bottom=True)
    xlow, xhigh = ax1.get_xlim()
    if metric == 'R2':
        ax1.set_ylabel(r'$R^2$', fontsize=basefont * 1.5, labelpad=size * 1.5)
    else:
        ax1.set_ylabel(metric, fontsize=basefont * 1.5, labelpad=size * 1.5)
    # add a legend
    leg = ax1.legend(fontsize=basefont * 1.4,
                     loc='upper left',
                     framealpha=1,
                     frameon=True,
                     handlelength=0,
                     handletextpad=0)
    leg.get_frame().set_linewidth(size / 10)
    beautify_legend(leg, colors[:2] + [shuffled_grey])
    # leave headroom above the tallest bar for the polar plots
    ylim = ax1.get_ylim()
    r2_max = max(
        max(r2s, key=lambda x: x[1])[1],
        max(insample_r2s, key=lambda x: x[1])[1])
    ymax = r2_max * 1.5
    ax1.set_ylim(ylim[0], ymax)
    # change yticks
    if ymax < .15:
        ax1.set_ylim(ylim[0], .15)
        ax1.yaxis.set_major_locator(ticker.MultipleLocator(.025))
    else:
        ax1.yaxis.set_major_locator(ticker.MultipleLocator(.05))
        ax1.set_yticks(
            np.append([0, .025, .05, .075, .1, .125], np.arange(.15, .45,
                                                                .05)))
    # draw grid
    ax1.set_axisbelow(True)
    plt.grid(axis='y', linestyle='dotted', linewidth=size / 6)
    plt.setp(list(ax1.spines.values()), linewidth=size / 10)

    # Plot Polar Plots for importances
    if EFA is not None:
        reorder_vec = EFA.get_factor_reorder(EFA.results['num_factors'])

        def reorder(values):
            return [values[j] for j in reorder_vec]

        importances = [(reorder(predictions[k]['predvars']),
                        reorder(predictions[k]['importances'][0]))
                       for k in target_order]
        axes = []
        N = len(importances)
        # (index, (target, score)) pairs in ascending order of CV score
        best_predictors = sorted(enumerate(r2s), key=lambda x: x[1][1])
        ylim = ax1.get_ylim()
        yrange = np.sum(np.abs(ylim))
        zero_place = abs(ylim[0]) / yrange
        # place each polar plot just above the highest bar of its target,
        # in figure coordinates (ax1 occupies the bottom half of the figure)
        plot_heights = [
            int(r2s[i][1] > 0) *
            (max(r2s[i][1], insample_r2s[i][1], shuffled_r2s[i][1],
                 insample_shuffled_r2s[i][1]) / yrange)
            for i in range(N)
        ]
        plot_heights = [(h + zero_place + .02) * .5 for h in plot_heights]
        # only show importances for targets that beat both 0 and the null
        plot_heights = [
            h if r2s[i][1] > max(shuffled_r2s[i][1], 0) else np.nan
            for i, h in enumerate(plot_heights)
        ]
        plot_x = (ax1.get_xticks() - xlow) / (xhigh - xlow) - (1 / N / 2)
        for i, importance in enumerate(importances):
            if pd.isnull(plot_heights[i]):
                continue
            axes.append(
                fig.add_axes([plot_x[i], plot_heights[i], 1 / N, 1 / N],
                             projection='polar'))
            visualize_importance(importance,
                                 axes[-1],
                                 yticklabels=False,
                                 xticklabels=False,
                                 label_size=figsize[1] * 1,
                                 color=colors[0],
                                 axes_linewidth=size / 10)

        # larger, labelled polar plots for the (up to) two best targets, laid
        # out left-to-right in the same order as their bars
        top = best_predictors[::-1][:2]
        if len(top) == 1:
            locs = [.5]
        elif top[0][0] < top[1][0]:
            locs = [.32, .68]
        else:
            locs = [.68, .32]
        label_importance = importances[top[0][0]]
        # abbreviation key, only for factors that have a shortened name
        # (previously a factor without one raised TypeError on None + ':')
        pad = 0
        text = [(l, shortened_factors.get(l, None))
                for l in label_importance[0]]
        text = [(l, abr) for l, abr in text if abr is not None]
        if text:
            pad = .05
            text_ax = fig.add_axes([.8, .56, .1, .34])
            text_ax.tick_params(labelleft=False,
                                left=False,
                                labelbottom=False,
                                bottom=False)
            for spine in ['top', 'right', 'bottom', 'left']:
                text_ax.spines[spine].set_visible(False)
            for i, (val, abr) in enumerate(text):
                text_ax.text(0, i / len(text), abr + ':', fontsize=size * 1.2)
                text_ax.text(.5, i / len(text), val, fontsize=size * 1.2)

        ratio = figsize[1] / figsize[0]
        for loc, (idx, (target_name, _)) in zip(locs, top):
            axes.append(
                fig.add_axes([loc - .2 * ratio - pad, .56, .3 * ratio, .3],
                             projection='polar'))
            visualize_importance(importances[idx],
                                 axes[-1],
                                 yticklabels=False,
                                 xticklabels=True,
                                 label_size=max(figsize[1] * 1.5, 5),
                                 label_scale=.22,
                                 title=target_name,
                                 color=colors[0],
                                 axes_linewidth=size / 10)
    if filename is not None:
        save_figure(fig, filename, {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
def plot_communality(results, c, rotate='oblimin', retest_threshold=.2,
                     size=4.6, dpi=300, ext='png', plot_dir=None):
    """Plot EFA communalities, optionally alongside test-retest reliability.

    Communalities come from ``get_communality(EFA, rotate, c)``. Retest data
    are looked up by replacing 'Complete' with 'Retest' in
    ``results.dataset``. If none are found, a message is printed and the
    function returns without plotting.

    When retest rows are available, two figures are produced:
      1. bar plots of communality, the noise ceiling (titled 'Test-Retest')
         and adjusted communality, as returned by
         ``get_adjusted_communality``;
      2. KDEs of communality and adjusted communality with dashed lines at
         their means, annotated with the mean of the returned correlation.
    Otherwise only a single communality bar plot is drawn.

    Args:
        results: results object exposing ``EFA`` and ``dataset``
        c: EFA dimensionality (number of factors) to plot; also used in the
            output filenames
        rotate: rotation name passed to ``get_communality``
        retest_threshold: passed through to ``get_adjusted_communality``
        size: base figure size; widths, fonts and line widths scale with it
        dpi: the final dpi for the image
        ext: the extension for the saved figures
        plot_dir: directory to save the figures to. If falsy, do not save
    """
    EFA = results.EFA
    communality = get_communality(EFA, rotate, c)
    # load retest data
    retest_data = get_retest_data(dataset=results.dataset.replace('Complete','Retest'))
    if retest_data is None:
        print('No retest data found for datafile: %s' % results.dataset)
        return
    
    # reorder data in line with communality
    # NOTE(review): with pandas >= 1.0, .loc raises KeyError if any
    # communality variable is missing from retest_data
    retest_data = retest_data.loc[communality.index]
    # reformat variable names
    communality.index = format_variable_names(communality.index)
    retest_data.index = format_variable_names(retest_data.index)
    # adjusted_communality, correlation and noise_ceiling only exist when
    # retest rows are present; every later use is guarded by the same check
    if len(retest_data) > 0:
        adjusted_communality,correlation, noise_ceiling = \
                get_adjusted_communality(communality, 
                                         retest_data,
                                         retest_threshold)
        
    # plot communality bars: three side-by-side panels with retest data,
    # otherwise a single communality panel
    if len(retest_data)>0:
        f, axes = plt.subplots(1, 3, figsize=(3*(size/10), size))
    
        plot_bar_factor(communality, axes[0], width=size/10, height=size,
                        label_rows=True,  title='Communality')
        plot_bar_factor(noise_ceiling, axes[1], width=size/10, height=size,
                        label_rows=False,  title='Test-Retest')
        plot_bar_factor(adjusted_communality, axes[2], width=size/10, height=size,
                        label_rows=False,  title='Adjusted Communality')
    else:
        f = plot_bar_factor(communality, label_rows=True, 
                            width=size/3, height=size*2, title='Communality')
    if plot_dir:
        filename = 'communality_bars-EFA%s.%s' % (c, ext)
        save_figure(f, path.join(plot_dir, filename), 
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
    
    # plot communality histogram
    if len(retest_data) > 0:
        with sns.axes_style('white'):
            colors = sns.color_palette(n_colors=2, desat=.75)
            f, ax = plt.subplots(1,1,figsize=(size,size))
            # the kdeplot and plt.xlabel/ylabel calls below pass no axes and
            # rely on `ax` being pyplot's current axes (just created above)
            sns.kdeplot(communality, linewidth=size/4, 
                        shade=True, label='Communality', color=colors[0])
            sns.kdeplot(adjusted_communality, linewidth=size/4, 
                        shade=True, label='Adjusted Communality', color=colors[1])
            # dashed vertical lines at the mean of each distribution
            ylim = ax.get_ylim()
            ax.vlines(np.mean(communality), ylim[0], ylim[1],
                      color=colors[0], linewidth=size/4, linestyle='--')
            ax.vlines(np.mean(adjusted_communality), ylim[0], ylim[1],
                      color=colors[1], linewidth=size/4, linestyle='--')
            leg=ax.legend(fontsize=size*2, loc='upper right')
            beautify_legend(leg, colors)
            plt.xlabel('Communality', fontsize=size*2)
            plt.ylabel('Normalized Density', fontsize=size*2)
            ax.set_yticks([])
            ax.tick_params(labelsize=size)
            ax.set_ylim(0, ax.get_ylim()[1])
            ax.set_xlim(0, ax.get_xlim()[1])
            ax.spines['right'].set_visible(False)
            #ax.spines['left'].set_visible(False)
            ax.spines['top'].set_visible(False)
            # annotate with the mean of the correlation returned by
            # get_adjusted_communality (placed in data coordinates)
            correlation = format_num(np.mean(correlation))
            ax.text(1.1, 1.25, 'Correlation Between Communality \nand Test-Retest: %s' % correlation,
                    size=size*2)

        if plot_dir:
            filename = 'communality_dist-EFA%s.%s' % (c, ext)
            save_figure(f, path.join(plot_dir, filename), 
                        {'bbox_inches': 'tight', 'dpi': dpi})
            plt.close()
def plot_corr_hist(all_results,
                   reps=100,
                   size=4.6,
                   dpi=300,
                   ext='png',
                   plot_dir=None):
    """Plot distributions of absolute pairwise correlations.

    Three panels are drawn: |r| within survey measures, within task measures,
    and between surveys and tasks. Each panel shows a KDE of the observed
    absolute correlations, a dashed line at the 95th percentile of a
    permutation null (columns shuffled independently), and the mean
    cross-validated R^2 (RidgeCV, 10 folds) for predicting each target
    variable from the source dataset.

    Args:
        all_results: dict with 'survey' and 'task' entries, each exposing a
            ``data`` DataFrame
        reps: number of shuffles used to build each null distribution
        size: currently unused (figure size is fixed); kept for interface
            compatibility
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If None, do not save
    """
    colors = sns.color_palette('Blues_d', 3)[0:2] + sns.color_palette(
        'Reds_d', 2)[:1]
    survey_data = all_results['survey'].data
    task_data = all_results['task'].data
    survey_corr = abs(survey_data.corr())
    task_corr = abs(task_data.corr())
    all_data = pd.concat([task_data, survey_data], axis=1)
    datasets = [('survey', survey_data), ('task', task_data),
                ('all', all_data)]
    # cross correlations: rows are survey variables, columns task variables
    cross_corr = abs(all_data.corr()).loc[survey_corr.columns,
                                          task_corr.columns]

    # 95th percentile of |r| under a permutation null. Values from *every*
    # repetition are pooled, and each null uses the same entries as the
    # matching observed panel (lower triangle within a dataset, the full
    # survey x task block across datasets).
    shuffled_95 = []
    for label, df in datasets:
        null_values = []
        for _ in range(reps):
            # shuffle each column independently to break cross-variable links
            shuffled = df.copy()
            for col in shuffled:
                shuffled.loc[:, col] = shuffled[col].sample(
                    len(shuffled)).tolist()
            shuffled_corr = abs(shuffled.corr())
            if label == 'all':
                cross = shuffled_corr.loc[survey_corr.columns,
                                          task_corr.columns]
                null_values.append(np.ravel(cross.values))
            else:
                null_values.append(
                    np.ravel(extract_tril(shuffled_corr.values, -1)))
        shuffled_95.append(np.percentile(np.concatenate(null_values), 95))

    # mean cross-validated R^2 for predicting every target variable from the
    # source dataset; when source == target the predicted variable is dropped
    # from the predictors. Only the pairs that are plotted are computed.
    data_by_label = dict(datasets)
    average_r2 = {}
    for slabel, tlabel in [('survey', 'survey'), ('task', 'task'),
                           ('survey', 'task')]:
        source = data_by_label[slabel]
        target = data_by_label[tlabel]
        scores = []
        # DataFrame.iteritems was removed in pandas 2.0; items() is equivalent
        for var, values in target.items():
            if var in source.columns:
                predictors = source.drop(var, axis=1)
            else:
                predictors = source
            cv_score = np.mean(
                cross_val_score(RidgeCV(), predictors, values, cv=10))
            scores.append(cv_score)
        average_r2[(slabel, tlabel)] = np.mean(scores)

    # bring everything together
    plot_elements = [
        (extract_tril(survey_corr.values, -1), 'Within Surveys',
         average_r2[('survey', 'survey')]),
        (extract_tril(task_corr.values, -1), 'Within Tasks',
         average_r2[('task', 'task')]),
        (cross_corr.values.flatten(), 'Surveys x Tasks',
         average_r2[('survey', 'task')]),
    ]

    with sns.axes_style('white'):
        f, axes = plt.subplots(1, 3, figsize=(10, 4))
        plt.subplots_adjust(wspace=.3)
        for i, (corr, label, r2) in enumerate(plot_elements):
            sns.kdeplot(corr,
                        ax=axes[i],
                        color=colors[i],
                        shade=True,
                        label=label,
                        linewidth=3)
            axes[i].text(.4, axes[i].get_ylim()[1] * .5,
                         'CV-R2: {0:.2f}'.format(r2))
        for i, ax in enumerate(axes):
            # dashed line: 95th percentile of the shuffled null
            ax.vlines(shuffled_95[i],
                      *ax.get_ylim(),
                      color=[.2, .2, .2],
                      linewidth=2,
                      linestyle='dashed',
                      zorder=10)
            ax.set_xlim(0, 1)
            ax.set_ylim(0, ax.get_ylim()[1])
            ax.set_xticks([0, .5, 1])
            ax.set_xticklabels([0, .5, 1], fontsize=16)
            ax.set_yticks([])
            ax.spines['right'].set_visible(False)
            ax.spines['top'].set_visible(False)
            leg = ax.legend(fontsize=14, loc='upper center')
            beautify_legend(leg, [colors[i]])
        axes[1].set_xlabel('Pearson Correlation', fontsize=20, labelpad=10)
        axes[0].set_ylabel('Normalized Density', fontsize=20, labelpad=10)

    # save
    if plot_dir is not None:
        save_figure(f,
                    path.join(plot_dir, 'within-across_correlations.%s' % ext),
                    {
                        'bbox_inches': 'tight',
                        'dpi': dpi
                    })
        plt.close()
def plot_cross_communality(all_results, rotate='oblimin', retest_threshold=.2,
                           size=4.6, dpi=300, ext='png', plot_dir=None):
    """Plot raw and retest-adjusted communality distributions per result set.

    Each entry of ``all_results`` gets one panel (two panels per row). A panel
    holds vertical KDEs of two distributions:
      - the communalities from the psych ``fa`` output stored in
        ``EFA.results['factor_tree_Rout_<rotate>']``;
      - those communalities divided by test-retest reliability (the noise
        ceiling).
    Dashed lines mark each mean, and all panels share one y range.

    Args:
        all_results: dict mapping a result-set name to a results object that
            exposes ``EFA`` and ``dataset``
        rotate: factor rotation whose loadings/communalities are used
        retest_threshold: reliabilities below this value are set to NaN
            before adjusting; a falsy value disables the threshold
        size: base figure width; fonts and line widths scale with it
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: directory to save the figure to (inside a ``rotate``
            subdirectory). If falsy, do not save
    """
    num_cols = 2
    num_rows = math.ceil(len(all_results) / num_cols)
    with sns.axes_style('white'):
        f, axes = plt.subplots(num_rows, num_cols,
                               figsize=(size, size / 2 * num_rows))
    # plt.subplots returns a 2D array when num_rows > 1; flatten it so that
    # axes[i] is always a single Axes
    axes = np.ravel(axes)
    retest_data = None
    max_y = 0
    for i, (name, results) in enumerate(all_results.items()):
        if retest_data is None:
            # retest data are loaded once, from the first result set's dataset
            retest_data = get_retest_data(
                dataset=results.dataset.replace('Complete', 'Retest'))
            if retest_data is None:
                # nothing to adjust against (previously this fell through and
                # crashed on retest_data.loc)
                print('No retest data found for datafile: %s' % results.dataset)
                plt.close(f)
                return
        EFA = results.EFA
        c = EFA.get_c()
        loading = EFA.get_loading(c, rotate=rotate)
        # get communality from psych out
        fa = EFA.results['factor_tree_Rout_%s' % rotate][c]
        communality = pd.Series(get_attr(fa, 'communalities'),
                                index=loading.index)
        communality.index = [v.replace('.logTr', '') for v in communality.index]

        # reorder retest data in line with communality
        retest_subset = retest_data.loc[communality.index]
        # reformat variable names
        communality.index = format_variable_names(communality.index)
        retest_subset.index = format_variable_names(retest_subset.index)
        if len(retest_subset) == 0:
            continue

        # noise ceiling = test-retest reliability
        noise_ceiling = retest_subset.pearson
        if retest_threshold:
            # remove very low reliabilities; mask() returns a new Series
            # instead of writing through a chained assignment
            noise_ceiling = noise_ceiling.mask(noise_ceiling < retest_threshold)
        adjusted_communality = communality / noise_ceiling

        # plot communality distributions (vertical KDEs)
        ax = axes[i]
        ax.set_title(name.title(), fontweight='bold', fontsize=size * 2)
        colors = sns.color_palette(n_colors=2, desat=.75)
        sns.kdeplot(communality, linewidth=size / 4, ax=ax, vertical=True,
                    shade=True, label='Communality', color=colors[0])
        sns.kdeplot(adjusted_communality, linewidth=size / 4, ax=ax, vertical=True,
                    shade=True, label='Adjusted Communality', color=colors[1])
        # dashed horizontal lines at the mean of each distribution
        xlim = ax.get_xlim()
        ax.hlines(np.mean(communality), xlim[0], xlim[1],
                  color=colors[0], linewidth=size / 4, linestyle='--')
        ax.hlines(np.mean(adjusted_communality), xlim[0], xlim[1],
                  color=colors[1], linewidth=size / 4, linestyle='--')
        ax.set_xticks([])
        ax.tick_params(labelsize=size * 1.2)
        ax.set_ylim(0, ax.get_ylim()[1])
        ax.set_xlim(0, ax.get_xlim()[1])
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        # legend only on the last panel; x label on the bottom row
        if (i + 1) == len(all_results):
            ax.set_xlabel('Normalized Density', fontsize=size * 2)
            leg = ax.legend(fontsize=size * 1.5, loc='upper right',
                            bbox_to_anchor=(1.2, 1.0),
                            handlelength=0, handletextpad=0)
            beautify_legend(leg, colors)
        elif i >= len(all_results) - 2:
            ax.set_xlabel('Normalized Density', fontsize=size * 2)
            ax.legend().set_visible(False)
        else:
            ax.legend().set_visible(False)
        # y label and tick labels only on the left column
        if i % 2 == 0:
            ax.set_ylabel('Communality', fontsize=size * 2)
            ax.tick_params(labelleft=True, left=True,
                           length=size / 4, width=size / 8)
        else:
            ax.tick_params(labelleft=False, left=True,
                           length=0, width=size / 8)
        # track the largest y extent so all panels can share it
        max_y = max(max_y, ax.get_ylim()[1])
        ax.grid(False)
        for spine in ax.spines.values():
            spine.set_linewidth(size * .1)

    # shared y range, applied once after every panel has been drawn
    for ax in axes:
        ax.set_ylim((0, max_y))
    plt.subplots_adjust(wspace=0)

    if plot_dir:
        filename = 'communality_adjustment.%s' % ext
        save_figure(f, path.join(plot_dir, rotate, filename),
                    {'bbox_inches': 'tight', 'dpi': dpi})
        plt.close()
# axes[2] (reaction-time panel) formatting: one tick per trial type
axes[2].set_ylabel(r'Mean $\pm$ SEM reaction time', fontsize=20)
axes[2].set_xticks(range(rt_stats.shape[0]))
axes[2].set_xticklabels(['AX', 'AY', 'BX', 'BY'])
axes[2].set_xlabel('Trial Type', fontsize=20)
axes[2].grid(axis='y')

# plot literature
# Published accuracy values on axes[0], with error bars = std / sqrt(lit_N).
# NOTE(review): x positions come from lit_rt_stats.shape[0] although the
# plotted values are lit_acc_stats; presumably both have one row per trial
# type — confirm.
axes[0].errorbar(range(lit_rt_stats.shape[0]),
                 lit_acc_stats.loc[:, 'mean'],
                 yerr=lit_acc_stats.loc[:, 'std'] / (lit_N**.5),
                 color='#29A6F0',
                 linewidth=5,
                 elinewidth=3,
                 label='Lopez-Garcia et al')
# legend text colours: '#D3244F' presumably matches the series plotted
# earlier on axes[0], '#29A6F0' the literature line above
leg = axes[0].legend(handlelength=0)
beautify_legend(leg, colors=['#D3244F', '#29A6F0'])
# plot reaction time
# published reaction times on axes[2], error bars = std / sqrt(lit_N)
axes[2].errorbar(range(lit_rt_stats.shape[0]),
                 lit_rt_stats.loc[:, 'mean'],
                 yerr=lit_rt_stats.loc[:, 'std'] / (lit_N**.5),
                 color='#29A6F0',
                 linewidth=5,
                 elinewidth=3)
# plot comparison to literature
# axes[1]: literature mean accuracy (x) vs. our mean accuracy (y), with a
# dashed identity line spanning the larger of the two axis ranges
axes[1].scatter(lit_acc_stats.loc[:, 'mean'],
                acc_stats.loc[:, 'mean'],
                color='k')
max_val = max(max(axes[1].get_xlim()), max(axes[1].get_ylim()))
axes[1].plot([0, max_val], [0, max_val], linestyle='--', color='k')
axes[1].set_ylabel('Our Values', fontsize=20)
def plot_BIC(all_results, size=4.6, dpi=300, ext='png', plot_dir=None):
    """ Plots BIC and SABIC curves

    One panel per entry of ``all_results``: 'task' first, then 'survey', then
    any other result sets. Each panel plots every 'cscores_metric-<metric>'
    curve in ``EFA.results`` against the number of factors and highlights the
    chosen number of factors ('c_metric-<metric>') with an open marker.

    Args:
        all_results: a dimensional structure all_results object (dict mapping
            result-set names to results objects)
        size: base figure width; heights, fonts and line widths scale with it
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    all_colors = [
        sns.color_palette('Blues_d', 3)[0:3],
        sns.color_palette('Reds_d', 3)[0:3],
        sns.color_palette('Greens_d', 3)[0:3],
        sns.color_palette('Oranges_d', 3)[0:3]
    ]
    # 'task' and 'survey' first (in that order), then any remaining entries,
    # so every subplot created below gets filled
    preferred = ('task', 'survey')
    ordered_keys = [k for k in preferred if k in all_results]
    ordered_keys += [k for k in all_results if k not in preferred]
    height = size * .75 / len(all_results)
    with sns.axes_style('white'):
        # squeeze=False keeps an indexable array even for a single panel
        fig, axes = plt.subplots(1, len(all_results), figsize=(size, height),
                                 squeeze=False)
    axes = axes[0]
    for i, result_key in enumerate(ordered_keys):
        results = all_results[result_key]
        ax1 = axes[i]
        name = results.ID.split('_')[0].title()
        EFA = results.EFA
        # Plot BIC and SABIC curves
        colors = all_colors[i % len(all_colors)]
        with sns.axes_style('white'):
            # sort so curves are drawn left to right regardless of dict order
            x = sorted(EFA.results['cscores_metric-BIC'].keys())
            # score keys
            keys = [k for k in EFA.results.keys() if 'cscores' in k]
            for score_key in keys:
                metric = score_key.split('-')[-1]
                metric_scores = EFA.results[score_key]
                BIC_scores = [metric_scores[n] for n in x]
                BIC_c = EFA.results['c_metric-%s' % metric]
                ax1.plot(x,
                         BIC_scores,
                         'o-',
                         c=colors[0],
                         lw=size / 6,
                         label=metric,
                         markersize=height * 2)
                # highlight the chosen number of factors; look its score up by
                # key rather than assuming the factor counts are exactly 1..n
                ax1.plot(BIC_c,
                         metric_scores[BIC_c],
                         '.',
                         color='white',
                         markeredgecolor=colors[0],
                         markeredgewidth=height / 2,
                         markersize=height * 4)
            if i == 0:
                if len(keys) > 1:
                    ax1.set_ylabel('Score', fontsize=height * 3)
                    leg = ax1.legend(loc='center right',
                                     fontsize=height * 3,
                                     markerscale=0)
                    beautify_legend(leg, colors=colors)
                else:
                    ax1.set_ylabel(metric, fontsize=height * 4)
            ax1.set_xlabel('# Factors', fontsize=height * 4)
            ax1.set_xticks(x)
            ax1.set_xticklabels(x)
            ax1.tick_params(labelsize=height * 2, pad=size / 4, length=0)
            ax1.set_title(name, fontsize=height * 4, y=1.01)
            ax1.grid(linewidth=size / 8)
            for spine in ax1.spines.values():
                spine.set_linewidth(size * .1)
    if plot_dir is not None:
        save_figure(fig, path.join(plot_dir, 'BIC_curves.%s' % ext), {
            'bbox_inches': 'tight',
            'dpi': dpi
        })
        plt.close()
예제 #23
0
def plot_BIC_SABIC(results, size=2.3, dpi=300, ext='png', plot_dir=None):
    """ Plots BIC and SABIC curves

    Draws the BIC score for every candidate number of factors on the left
    axis and, when the EFA results contain SABIC scores, the SABIC score on
    a twin right axis. The number of factors selected by each metric is
    highlighted with a larger hollow marker in that metric's colour.

    Args:
        results: a dimensional structure results object. ``results.EFA.results``
            must contain 'cscores_metric-BIC' (mapping #factors -> score) and
            'c_metric-BIC' (the selected #factors); the matching SABIC keys
            are optional.
        size: width of the figure in inches; also scales fonts and markers
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If none, do not save
    """
    EFA = results.EFA
    # Plot BIC and SABIC curves
    colors = ['c', 'm']
    with sns.axes_style('white'):
        fig, ax1 = plt.subplots(1, 1, figsize=(size, size * .75))
        # candidate numbers of factors, ascending; every curve follows this order
        x = sorted(EFA.results['cscores_metric-BIC'].keys())
        # BIC
        BIC_scores = [EFA.results['cscores_metric-BIC'][i] for i in x]
        BIC_c = EFA.results['c_metric-BIC']
        ax1.plot(x,
                 BIC_scores,
                 'o-',
                 c=colors[0],
                 lw=3,
                 label='BIC',
                 markersize=size * 2)
        ax1.set_xlabel('# Factors', fontsize=size * 3)
        ax1.set_ylabel('BIC', fontsize=size * 3)
        # highlight the chosen #factors; look it up by position in x rather
        # than assuming x starts at 1 with no gaps
        ax1.plot(BIC_c,
                 BIC_scores[x.index(BIC_c)],
                 '.',
                 color='white',
                 markeredgecolor=colors[0],
                 markeredgewidth=size / 2,
                 markersize=size * 4)
        ax1.tick_params(labelsize=size * 2)
        if 'cscores_metric-SABIC' in EFA.results:
            # SABIC, aligned to the same x ordering as BIC (missing -> NaN)
            ax2 = ax1.twinx()
            SABIC_dict = EFA.results['cscores_metric-SABIC']
            SABIC_scores = [SABIC_dict.get(i, np.nan) for i in x]
            SABIC_c = EFA.results['c_metric-SABIC']
            ax2.plot(x,
                     SABIC_scores,
                     c=colors[1],
                     lw=3,
                     label='SABIC',
                     markersize=size * 2)
            ax2.set_ylabel('SABIC', fontsize=size * 3)
            ax2.tick_params(labelsize=size * 2)
            ax2.plot(SABIC_c,
                     SABIC_scores[x.index(SABIC_c)],
                     '.',
                     color='white',
                     markeredgecolor=colors[1],
                     markeredgewidth=size / 2,
                     markersize=size * 4)
            # set up legend: an empty proxy line on ax1 so both metrics
            # appear in a single legend
            ax1.plot([], [], c=colors[1], lw=3, label='SABIC')
            leg = ax1.legend(loc='center right')
            beautify_legend(leg, colors=colors)
        if plot_dir is not None:
            save_figure(fig, path.join(plot_dir, 'BIC_SABIC_curves.%s' % ext),
                        {
                            'bbox_inches': 'tight',
                            'dpi': dpi
                        })
            plt.close()
# Figure: successful replications of the learning tasks.
# Left panel: stay probability in the Two Step task; right panel: accuracy
# after a change-point in the Shift task. Relies on plot_df, shift_df,
# colors, base_dir and ext being defined earlier in the script.
f, axes = plt.subplots(1, 2, figsize=(20, 8))
# two stage
sns.barplot(x='feedback_last', y='switch', hue='stage_transition_last',
            data=plot_df,
            order=['Rewarded', 'Unrewarded'],
            hue_order=['Common', 'Rare'],
            palette=colors,
            ax=axes[0])
axes[0].set_xlabel('')
axes[0].set_ylabel('Stay Probability', fontsize=24)
axes[0].set_title('Two Step Task', y=1.04, fontsize=30)
axes[0].set_ylim([.5, 1])
axes[0].tick_params(labelsize=20)
leg = axes[0].get_legend()
leg.set_title('')
beautify_legend(leg, colors=colors, fontsize=20)

# shift
# x/y passed by keyword: seaborn >= 0.12 accepts only `data` positionally
sns.pointplot(x='trials_since_switch', y='correct', data=shift_df,
              ax=axes[1])
# pointplot places categories at positions 0..N-1; these labels assume the
# trial counts start at 0 with no gaps — TODO confirm for shift_df
axes[1].set_xticks(range(0, 25, 5))
axes[1].set_xticklabels(range(0, 25, 5))
axes[1].set_xlabel('Trials After Change-Point', fontsize=24)
axes[1].set_ylabel('Percent Correct', fontsize=24)
axes[1].set_title('Shift Task', y=1.04, fontsize=30)
axes[1].tick_params(labelsize=20)
# full output file path (not a directory)
save_dir = path.join(base_dir, 'Results', 'replication', 'Plots',
                     'successful_learning_tasks.%s' % ext)
f.savefig(save_dir, dpi=300, bbox_inches='tight')
plt.close()

# *************************************************************************
# Unsuccessful replications
def plot_cross_communality(all_results,
                           rotate='oblimin',
                           retest_threshold=.2,
                           size=4.6,
                           dpi=300,
                           ext='png',
                           plot_dir=None):
    """Plot raw and reliability-adjusted communality for several EFA results.

    Each entry of ``all_results`` gets its own panel (two panels per row)
    showing a vertical KDE of the EFA communalities (read from the psych
    ``fa`` output) and of the communalities divided by test-retest
    reliability. Dashed horizontal lines mark each distribution's mean, and
    all panels share the same y-range.

    Args:
        all_results: dict mapping a name (used as the panel title) to a
            results object exposing ``.EFA`` and ``.dataset``
        rotate: factor rotation used to select the EFA solution
        retest_threshold: test-retest reliabilities below this value are
            treated as missing before adjusting; a falsy value disables this
        size: figure width in inches; also scales fonts and line widths
        dpi: the final dpi for the image
        ext: the extension for the saved figure
        plot_dir: the directory to save the figure. If falsy, do not save
    """
    if not all_results:
        return
    retest_data = None
    num_cols = 2
    num_rows = math.ceil(len(all_results) / num_cols)
    with sns.axes_style('white'):
        f, axes = plt.subplots(num_rows,
                               num_cols,
                               figsize=(size, size / 2 * num_rows))
    # plt.subplots returns a 2D array once there is more than one row;
    # flatten so axes[i] is always a single Axes and every panel is visited
    axes = np.ravel(axes)
    max_y = 0
    for i, (name, results) in enumerate(all_results.items()):
        if retest_data is None:
            # load retest data once; it is reused for every later result
            retest_data = get_retest_data(
                dataset=results.dataset.replace('Complete', 'Retest'))
            if retest_data is None:
                print('No retest data found for datafile: %s' %
                      results.dataset)
                # nothing to adjust against: leave this panel empty and try
                # again with the next result's dataset
                continue
        EFA = results.EFA
        c = EFA.get_c()
        loading = EFA.get_loading(c, rotate=rotate)
        # get communality from psych out
        fa = EFA.results['factor_tree_Rout_%s' % rotate][c]
        communality = get_attr(fa, 'communalities')
        communality = pd.Series(communality, index=loading.index)
        # (an alternative calculation is (loading**2).sum(1))
        communality.index = [
            var.replace('.logTr', '') for var in communality.index
        ]

        # reorder data in line with communality
        retest_subset = retest_data.loc[communality.index]
        # reformat variable names
        communality.index = format_variable_names(communality.index)
        retest_subset.index = format_variable_names(retest_subset.index)
        if len(retest_subset) == 0:
            continue

        # noise ceiling: test-retest reliability of each variable
        noise_ceiling = retest_subset.pearson
        if retest_threshold:
            # treat very low reliabilities as missing instead of letting
            # them blow up the adjusted communality
            noise_ceiling = noise_ceiling.mask(
                noise_ceiling < retest_threshold)
        adjusted_communality = communality / noise_ceiling

        # plot communality histogram
        ax = axes[i]
        ax.set_title(name.title(), fontweight='bold', fontsize=size * 2)
        colors = sns.color_palette(n_colors=2, desat=.75)
        sns.kdeplot(communality,
                    linewidth=size / 4,
                    ax=ax,
                    vertical=True,
                    shade=True,
                    label='Communality',
                    color=colors[0])
        sns.kdeplot(adjusted_communality,
                    linewidth=size / 4,
                    ax=ax,
                    vertical=True,
                    shade=True,
                    label='Adjusted Communality',
                    color=colors[1])
        xlim = ax.get_xlim()
        ax.hlines(np.mean(communality),
                  xlim[0],
                  xlim[1],
                  color=colors[0],
                  linewidth=size / 4,
                  linestyle='--')
        ax.hlines(np.mean(adjusted_communality),
                  xlim[0],
                  xlim[1],
                  color=colors[1],
                  linewidth=size / 4,
                  linestyle='--')
        ax.set_xticks([])
        ax.tick_params(labelsize=size * 1.2)
        ax.set_ylim(0, ax.get_ylim()[1])
        ax.set_xlim(0, ax.get_xlim()[1])
        ax.spines['right'].set_visible(False)
        ax.spines['top'].set_visible(False)
        if (i + 1) == len(all_results):
            # last panel carries the only legend
            ax.set_xlabel('Normalized Density', fontsize=size * 2)
            leg = ax.legend(fontsize=size * 1.5,
                            loc='upper right',
                            bbox_to_anchor=(1.2, 1.0),
                            handlelength=0,
                            handletextpad=0)
            beautify_legend(leg, colors)
        elif i >= len(all_results) - 2:
            # bottom row: x label, no legend
            ax.set_xlabel('Normalized Density', fontsize=size * 2)
            ax.legend().set_visible(False)
        else:
            ax.legend().set_visible(False)
        if i % 2 == 0:
            # left column shows the y axis label and ticks
            ax.set_ylabel('Communality', fontsize=size * 2)
            ax.tick_params(labelleft=True,
                           left=True,
                           length=size / 4,
                           width=size / 8)
        else:
            ax.tick_params(labelleft=False,
                           left=True,
                           length=0,
                           width=size / 8)
        # track the largest y-limit so all panels can share it
        if ax.get_ylim()[1] > max_y:
            max_y = ax.get_ylim()[1]
        ax.grid(False)
        for spine in ax.spines.values():
            spine.set_linewidth(size * .1)

    # finalize once every panel has been drawn
    if max_y > 0:
        for ax in axes:
            ax.set_ylim((0, max_y))
    plt.subplots_adjust(wspace=0)

    if plot_dir:
        filename = 'communality_adjustment.%s' % ext
        save_figure(f, path.join(plot_dir, rotate, filename), {
            'bbox_inches': 'tight',
            'dpi': dpi
        })
        plt.close()
예제 #26
0
def plot_communality(results,
                     c,
                     rotate='oblimin',
                     retest_threshold=.2,
                     size=4.6,
                     dpi=300,
                     ext='png',
                     plot_dir=None):
    """Plot communality, test-retest reliability and their adjusted ratio.

    Produces up to two figures for the ``c``-factor EFA solution:

    1. Bar panels: communality alone if no matching retest rows were found,
       otherwise communality, test-retest reliability and adjusted
       communality side by side (saved as ``communality_bars-EFA<c>.<ext>``).
    2. When retest rows exist, overlaid KDEs of raw and adjusted
       communality with dashed lines at their means and a text note giving
       the communality / test-retest correlation (saved as
       ``communality_dist-EFA<c>.<ext>``).

    Returns early (after printing a message) when no retest data exists for
    the dataset.

    Args:
        results: results object exposing ``.EFA`` and ``.dataset``
        c: number of factors of the EFA solution to use
        rotate: factor rotation used to select the solution
        retest_threshold: forwarded to ``get_adjusted_communality``
        size: base figure size in inches; also scales fonts and lines
        dpi: the final dpi for saved images
        ext: the extension for saved figures
        plot_dir: directory to save figures to. If falsy, nothing is saved
    """
    EFA = results.EFA
    communality = get_communality(EFA, rotate, c)
    # retest data lives in the sibling "Retest" dataset
    retest_source = results.dataset.replace('Complete', 'Retest')
    retest_data = get_retest_data(dataset=retest_source)
    if retest_data is None:
        print('No retest data found for datafile: %s' % results.dataset)
        return

    # align the retest rows with the communality ordering, then prettify
    # the variable names on both
    retest_data = retest_data.loc[communality.index]
    communality.index = format_variable_names(communality.index)
    retest_data.index = format_variable_names(retest_data.index)
    has_retest = len(retest_data) > 0

    # --- bar plots -------------------------------------------------------
    if has_retest:
        adjusted_communality, correlation, noise_ceiling = \
            get_adjusted_communality(communality,
                                     retest_data,
                                     retest_threshold)
        panel_width = size / 10
        f, axes = plt.subplots(1, 3, figsize=(3 * panel_width, size))
        panels = (
            (communality, True, 'Communality'),
            (noise_ceiling, False, 'Test-Retest'),
            (adjusted_communality, False, 'Adjusted Communality'),
        )
        for panel_ax, (values, show_rows, panel_title) in zip(axes, panels):
            plot_bar_factor(values,
                            panel_ax,
                            width=panel_width,
                            height=size,
                            label_rows=show_rows,
                            title=panel_title)
    else:
        f = plot_bar_factor(communality,
                            label_rows=True,
                            width=size / 3,
                            height=size * 2,
                            title='Communality')
    save_kwargs = {'bbox_inches': 'tight', 'dpi': dpi}
    if plot_dir:
        bars_name = 'communality_bars-EFA%s.%s' % (c, ext)
        save_figure(f, path.join(plot_dir, bars_name), save_kwargs)
        plt.close()

    # --- communality distributions ---------------------------------------
    if not has_retest:
        return
    with sns.axes_style('white'):
        colors = sns.color_palette(n_colors=2, desat=.75)
        f, ax = plt.subplots(1, 1, figsize=(size, size))
        curves = (
            (communality, 'Communality', colors[0]),
            (adjusted_communality, 'Adjusted Communality', colors[1]),
        )
        for values, curve_label, curve_color in curves:
            sns.kdeplot(values,
                        ax=ax,
                        linewidth=size / 4,
                        shade=True,
                        label=curve_label,
                        color=curve_color)
        # dashed vertical line at each distribution's mean, spanning the
        # y-range reached after both KDEs are drawn
        y_low, y_high = ax.get_ylim()
        for values, _, curve_color in curves:
            ax.vlines(np.mean(values),
                      y_low,
                      y_high,
                      color=curve_color,
                      linewidth=size / 4,
                      linestyle='--')
        leg = ax.legend(fontsize=size * 2, loc='upper right')
        beautify_legend(leg, colors)
        ax.set_xlabel('Communality', fontsize=size * 2)
        ax.set_ylabel('Normalized Density', fontsize=size * 2)
        ax.set_yticks([])
        ax.tick_params(labelsize=size)
        ax.set_ylim(0, ax.get_ylim()[1])
        ax.set_xlim(0, ax.get_xlim()[1])
        for side in ('right', 'top'):
            ax.spines[side].set_visible(False)
        # annotate with the communality / test-retest correlation; the text
        # position is in data coordinates
        correlation_text = format_num(np.mean(correlation))
        ax.text(1.1,
                1.25,
                'Correlation Between Communality \nand Test-Retest: %s' %
                correlation_text,
                size=size * 2)

    if plot_dir:
        dist_name = 'communality_dist-EFA%s.%s' % (c, ext)
        save_figure(f, path.join(plot_dir, dist_name), save_kwargs)
        plt.close()