def filtered_strf_vs_resp_batch(batch, gc, stp, LN, combined, strf, save_path,
                                good_ln=0.0, test_limit=None, stat='r_ceiling',
                                bin_count=40):
    """Run the STRF-vs-response comparison separately per improvement category.

    Splits the batch's cells into the five category lists returned by
    ``improved_cells_to_list`` and delegates each subset to
    ``_strf_resp_sub_batch``.

    Parameters
    ----------
    batch : int
        Batch number.
    gc, stp, LN, combined : str
        Model names (dataframe column keys).
    strf, save_path :
        Passed through to ``_strf_resp_sub_batch``.
    good_ln : float
        Threshold forwarded to ``improved_cells_to_list``.
    test_limit : int or None
        If given, only the first ``test_limit`` cells of each subset are used.
    stat : str
        'r_ceiling' selects the ceiling-corrected dataframe, anything else
        falls back to raw r_test.
    bin_count : int
        Histogram bin count forwarded to ``_strf_resp_sub_batch``.
    """
    either, all_cells, gc_cells, stp_cells, comb_cells = \
        improved_cells_to_list(batch, gc, stp, LN, combined, good_ln=good_ln)
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    # Ceiling-corrected scores when requested, raw r_test otherwise.
    df = df_c if stat == 'r_ceiling' else df_r
    # NOTE(review): 'neither' is paired with the *second* return value of
    # improved_cells_to_list, matching the original pairing — confirm that
    # return value really is the 'neither' group.
    subsets = {
        'either': either,
        'neither': all_cells,
        'gc': gc_cells,
        'stp': stp_cells,
        'combined': comb_cells,
    }
    for tag, cells in subsets.items():
        _strf_resp_sub_batch(cells[:test_limit], df, tag, stat, batch, gc,
                             stp, LN, combined, strf, save_path,
                             bin_count=bin_count)
def filtered_pred_matched_batch(batch, gc, stp, LN, combined, save_path,
                                good_ln=0.0, test_limit=None, stat='r_ceiling',
                                replace_existing=True):
    """Run the sigmoid prediction-matching analysis per improvement category.

    Splits the batch's cells into the five category lists returned by
    ``improved_cells_to_list`` and delegates each subset to
    ``_sigmoid_sub_batch``. Unlike the STRF variant, the 'neither' group is
    explicitly reduced to cells *without* a significant improvement.

    Parameters
    ----------
    batch : int
        Batch number.
    gc, stp, LN, combined : str
        Model names (dataframe column keys).
    save_path :
        Passed through to ``_sigmoid_sub_batch``.
    good_ln : float
        Threshold forwarded to ``improved_cells_to_list``.
    test_limit : int or None
        If given, only the first ``test_limit`` cells of each subset are used.
    stat : str
        'r_ceiling' selects the ceiling-corrected dataframe, anything else
        falls back to raw r_test.
    replace_existing : bool
        Forwarded to ``_sigmoid_sub_batch``.
    """
    either, all_cells, gc_cells, stp_cells, comb_cells = \
        improved_cells_to_list(batch, gc, stp, LN, combined, good_ln=good_ln)
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    # Ceiling-corrected scores when requested, raw r_test otherwise.
    df = df_c if stat == 'r_ceiling' else df_r
    # Remove improved cells so the 'neither' group really is "no improvement".
    all_cells = list(set(all_cells) - set(either))
    subsets = {
        'either': either,
        'neither': all_cells,
        'gc': gc_cells,
        'stp': stp_cells,
        'combined': comb_cells,
    }
    for tag, cells in subsets.items():
        _sigmoid_sub_batch(cells[:test_limit], df, tag, stat, batch, gc, stp,
                           LN, combined, save_path,
                           replace_existing=replace_existing)
def performance_table(batch1, gc, stp, LN, combined, batch2,
                      plot_stat='r_ceiling', height_scaling=3):
    """Render four model-performance tables (two batches x all/improved cells).

    For each batch, builds a pairwise model-comparison table over all cells
    and over the subset of cells with a significant nonlinear improvement,
    plus a companion table of summary statistics (mean / median / std err)
    per model.

    Parameters
    ----------
    batch1, batch2 : int
        Batch numbers (titles assume batch1 = natural stimuli,
        batch2 = vocalization in noise).
    gc, stp, LN, combined : str
        Model names (dataframe column keys).
    plot_stat : str
        'r_ceiling' selects the ceiling-corrected dataframes, anything else
        falls back to raw r_test.
    height_scaling : float
        Multiplier applied to each table cell's height for readability.

    Returns
    -------
    fig2, fig1 : matplotlib figures
        fig2 holds the summary-statistic tables, fig1 the pairwise tables
        (order preserved from the original implementation).
    """
    # 4 tables: batch 289 and 263, all cells / improved cells
    df_r1, df_c1, df_e1 = get_dataframes(batch1, gc, stp, LN, combined)
    cellids1, under_chance1, less_LN1 = get_filtered_cellids(
        df_r1, df_e1, gc, stp, LN, combined)
    df_r2, df_c2, df_e2 = get_dataframes(batch2, gc, stp, LN, combined)
    cellids2, under_chance2, less_LN2 = get_filtered_cellids(
        df_r2, df_e2, gc, stp, LN, combined)
    # Renamed from e1/a1, e2/a2 so the "all cells" lists are not shadowed by
    # the subplot axes below.
    imp1, all1, _, _, _ = improved_cells_to_list(batch1, gc, stp, LN, combined)
    imp2, all2, _, _, _ = improved_cells_to_list(batch2, gc, stp, LN, combined)

    if plot_stat == 'r_ceiling':
        df1 = df_c1
        df2 = df_c2
    else:
        df1 = df_r1
        # BUG FIX: this branch previously assigned df_c2, mixing
        # ceiling-corrected batch2 scores with raw batch1 scores whenever
        # plot_stat != 'r_ceiling'.
        df2 = df_r2

    models = [LN, gc, stp, combined]
    a289, a289_stats = _make_table(df1, df_e1, models, all1)
    a263, a263_stats = _make_table(df2, df_e2, models, all2)
    i289, i289_stats = _make_table(df1, df_e1, models, imp1)
    i263, i263_stats = _make_table(df2, df_e2, models, imp2)

    # NOTE(review): five labels for four models — _make_table presumably adds
    # a Max(GC,STP) column; confirm against its implementation.
    model_names = ['LN', 'GC', 'STP', 'GC+STP', 'Max(GC,STP)']
    fig1, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
    fig2, ((ax5, ax6), (ax7, ax8)) = plt.subplots(2, 2)
    fig1.patch.set_visible(False)
    fig2.patch.set_visible(False)
    iters = zip([ax1, ax2, ax3, ax4], [ax5, ax6, ax7, ax8],
                [a289, a263, i289, i263],
                [a289_stats, a263_stats, i289_stats, i263_stats],
                ['Natural stimuli, all cells', 'Voc. in noise, all cells',
                 'Natural stimuli, nonlinear cells',
                 'Voc. in noise, nonlinear cells'])
    for table_ax, stats_ax, table, stats, title in iters:
        # Pairwise model-comparison table.
        table_ax.axis('off')
        table_ax.axis('tight')
        table1 = table_ax.table(cellText=table, colLabels=model_names,
                                rowLabels=model_names, loc='center',
                                cellLoc='center', rowLoc='center')
        for cell in table1.properties()['child_artists']:
            cell.set_height(cell.get_height() * height_scaling)
        table_ax.set_title(title)

        # Summary statistics table: mean / median / std err per model.
        stats_ax.axis('off')
        stats_ax.axis('tight')
        row_labels = ['mean', 'median', 'std err']
        table_text = np.empty((len(row_labels), len(model_names)), dtype='U7')
        for i, _ in enumerate(row_labels):
            for j, _ in enumerate(model_names):
                table_text[i][j] = '%.5f' % stats[j][i]
        table2 = stats_ax.table(cellText=table_text, colLabels=model_names,
                                rowLabels=row_labels, loc='center',
                                cellLoc='center', rowLoc='center')
        for cell in table2.properties()['child_artists']:
            cell.set_height(cell.get_height() * height_scaling)
        stats_ax.set_title(title)

    fig1.tight_layout()
    fig2.tight_layout()
    return fig2, fig1
def performance_bar(batch, gc, stp, LN, combined, se_filter=True,
                    LN_filter=False, manual_cellids=None, abbr_yaxis=False,
                    plot_stat='r_ceiling', y_adjust=0.05, manual_y=None,
                    only_improvements=False, show_text_labels=False):
    '''Bar plot of median model performance for LN, GC, STP, and GC+STP.

    Parameters
    ----------
    batch : int
        Batch number.
    gc, stp, LN, combined : str
        Model names (dataframe column keys).
    se_filter, LN_filter : bool
        Forwarded to ``improved_cells_to_list``.
    manual_cellids : list or None
        If given, overrides every other cell filter.
    abbr_yaxis : bool
        Zoom the y-axis in around the range of the four medians.
    plot_stat : str
        'r_ceiling' selects the ceiling-corrected dataframe, else raw r_test.
    y_adjust : float
        Extra headroom added to the abbreviated y-axis upper limit.
    manual_y : (float, float) or None
        Explicit (lower, upper) y-limits when ``abbr_yaxis`` is set.
    only_improvements : bool
        Restrict to cells with a significant improvement for either model.
    show_text_labels : bool
        Print the median value inside each bar.

    Returns
    -------
    fig, fig2 : matplotlib figures
        The bar plot and a companion text figure.
    '''
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    e, a, g, s, c = improved_cells_to_list(batch, gc, stp, LN, combined,
                                           se_filter=se_filter,
                                           LN_filter=LN_filter)
    cellids = a
    if manual_cellids is not None:
        # WARNING: Will override se and ratio filters even if they are set
        cellids = manual_cellids
    elif only_improvements:
        e, a, g, s, c = improved_cells_to_list(batch, gc, stp, LN, combined,
                                               as_lists=True)
        cellids = e

    plot_df = df_c if plot_stat == 'r_ceiling' else df_r
    n_cells = len(cellids)
    gc_test = plot_df[gc][cellids]
    stp_test = plot_df[stp][cellids]
    ln_test = plot_df[LN][cellids]
    gc_stp_test = plot_df[combined][cellids]

    # Median score per model. Renamed (was gc/stp/ln/gc_stp) so the model-name
    # parameters gc and stp are no longer shadowed.
    gc_med = np.median(gc_test.values)
    stp_med = np.median(stp_test.values)
    ln_med = np.median(ln_test.values)
    gc_stp_med = np.median(gc_stp_test.values)
    largest = max(gc_med, stp_med, ln_med, gc_stp_med)

    colors = [model_colors[k] for k in ['LN', 'gc', 'stp', 'combined']]
    fig = plt.figure()
    plt.bar([1, 2, 3, 4], [ln_med, gc_med, stp_med, gc_stp_med],
            color=colors, edgecolor="black", linewidth=1)
    # BUG FIX: tick positions and labels must have equal length — was
    # [1, 2, 3, 4, 5] against four labels after the Max(GC,STP) bar was
    # commented out.
    plt.xticks([1, 2, 3, 4], ['LN', 'GC', 'STP', 'GC+STP'])
    if abbr_yaxis:
        if manual_y:
            lower, upper = manual_y
        else:
            # Round limits outward to the nearest 0.1, plus headroom on top.
            lower = np.floor(10 * min(gc_med, stp_med, ln_med,
                                      gc_stp_med)) / 10
            upper = np.ceil(10 * max(gc_med, stp_med, ln_med,
                                     gc_stp_med)) / 10 + y_adjust
        plt.ylim(ymin=lower, ymax=upper)
    else:
        plt.ylim(ymax=largest * 1.4)
    common_kwargs = {'color': 'white', 'horizontalalignment': 'center'}
    if abbr_yaxis:
        y_text = 0.5 * (lower + min(gc_med, stp_med, ln_med, gc_stp_med))
    else:
        y_text = 0.2
    if show_text_labels:
        for i, m in enumerate([ln_med, gc_med, stp_med, gc_stp_med]):
            plt.text(i + 1, y_text, "%0.04f" % m, **common_kwargs)
    xmin, xmax = plt.xlim()
    plt.xlim(xmin, xmax - 0.35)
    ax_remove_box()
    plt.tight_layout()

    fig2 = plt.figure(figsize=text_fig)
    # BUG FIX: the format string contained a stray backslash ("%d\, n:%d").
    text = "Median Performance, batch: %d, n: %d" % (batch, n_cells)
    plt.text(0.1, 0.5, text)

    return fig, fig2
def combined_vs_max(batch, gc, stp, LN, combined, se_filter=True,
                    LN_filter=False, plot_stat='r_ceiling', legend=False,
                    improved_only=True, exclude_low_snr=False, snr_path=None):
    """Scatter the GC+STP score against max(GC, STP) per cell.

    Also runs a Wilcoxon signed-rank test between the combined model and the
    per-cell maximum of the two single-mechanism models, restricted to cells
    with a significant improvement.

    Parameters
    ----------
    batch : int
        Batch number.
    gc, stp, LN, combined : str
        Model names (dataframe column keys).
    se_filter, LN_filter : bool
        Forwarded to ``improved_cells_to_list``.
    plot_stat : str
        'r_ceiling' selects the ceiling-corrected dataframe, else raw r_test.
    legend : bool
        Show the scatter legend.
    improved_only : bool
        If True, only significantly improved cells are plotted.
    exclude_low_snr : bool
        Drop cells whose SNR (read from ``snr_path``) is below the median.
    snr_path : str or None
        Pickled dataframe with an 'snr' column, indexed by cellid.

    Returns
    -------
    fig1, fig2 : matplotlib figures
        The scatter plot and a companion text figure with the statistics.
    """
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    e, a, g, s, c = improved_cells_to_list(batch, gc, stp, LN, combined,
                                           se_filter=se_filter,
                                           LN_filter=LN_filter)
    improved = c
    not_improved = list(set(a) - set(c))
    if exclude_low_snr:
        snr_df = pd.read_pickle(snr_path)
        med_snr = snr_df['snr'].median()
        high_snr = snr_df.loc[snr_df['snr'] >= med_snr]
        high_snr_cells = high_snr.index.values.tolist()
        improved = list(set(improved) & set(high_snr_cells))
        not_improved = list(set(not_improved) & set(high_snr_cells))

    plot_df = df_c if plot_stat == 'r_ceiling' else df_r
    gc_not = plot_df[gc][not_improved]
    gc_imp = plot_df[gc][improved]
    stp_not = plot_df[stp][not_improved]
    stp_imp = plot_df[stp][improved]
    gc_stp_not = plot_df[combined][not_improved]
    gc_stp_imp = plot_df[combined][improved]
    # Per-cell best of the two single-mechanism models.
    max_not = np.maximum(gc_not, stp_not)
    max_imp = np.maximum(gc_imp, stp_imp)

    imp_T, imp_p = st.wilcoxon(gc_stp_imp, max_imp)
    med_combined = np.nanmedian(gc_stp_imp)
    med_max = np.nanmedian(max_imp)
    # BUG FIX: removed leftover "import pdb; pdb.set_trace()" that halted
    # every call to this function in the debugger.

    fig1 = plt.figure()
    c_not = model_colors['LN']
    c_imp = model_colors['max']
    if not improved_only:
        plt.scatter(max_not, gc_stp_not, c=c_not, s=small_scatter,
                    label='no imp.')
    plt.scatter(max_imp, gc_stp_imp, c=c_imp, s=big_scatter,
                label='sig. imp.')
    ax = fig1.axes[0]
    # Unity line: combined == max.
    plt.plot(ax.get_xlim(), ax.get_xlim(), 'k--', linewidth=1,
             dashes=dash_spacing)
    if legend:
        plt.legend()
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.tight_layout()
    plt.axes().set_aspect('equal')
    ax_remove_box()

    fig2 = plt.figure(figsize=text_fig)
    text = ("batch: %d\n"
            "x: Max(GC,STP)\n"
            "y: GC+STP\n"
            "wilcoxon: T: %.4E, p: %.4E\n"
            "combined median: %.4E\n"
            "max median: %.4E"
            % (batch, imp_T, imp_p, med_combined, med_max))
    plt.text(0.1, 0.5, text)

    return fig1, fig2
def single_scatter(batch, gc, stp, LN, combined, compare,
                   plot_stat='r_ceiling', legend=False):
    # Scatter one model's performance against another's, coloring cells with a
    # significant improvement differently from the rest.
    #
    # Parameters:
    #   batch: batch number.
    #   gc, stp, LN, combined: model names (dataframe column keys).
    #   compare: pair of indices into [gc, stp, LN, combined] selecting the
    #       x-axis (compare[0]) and y-axis (compare[1]) models.
    #   plot_stat: 'r_ceiling' selects the ceiling-corrected dataframe,
    #       anything else falls back to raw r_test.
    #   legend: show the scatter legend.
    # Returns: (fig, fig2) — the scatter figure and a companion text figure.
    all_batch_cells = nd.get_batch_cells(batch, as_list=True)
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    e, a, g, s, c = improved_cells_to_list(batch, gc, stp, LN, combined,
                                           se_filter=True, LN_filter=False,
                                           as_lists=True)
    if plot_stat == 'r_ceiling':
        plot_df = df_c
    else:
        plot_df = df_r

    # c = cells improved by the combined model; everything else is "no imp.".
    improved = c
    not_improved = list(set(a) - set(c))
    models = [gc, stp, LN, combined]
    names = ['gc', 'stp', 'LN', 'combined']
    m1 = models[compare[0]]
    m2 = models[compare[1]]
    name1 = names[compare[0]]
    name2 = names[compare[1]]
    n_batch = len(all_batch_cells)
    n_all = len(a)
    n_imp = len(improved)
    n_not_imp = len(not_improved)
    m1_scores = plot_df[m1][not_improved]
    m1_scores_improved = plot_df[m1][improved]
    m2_scores = plot_df[m2][not_improved]
    m2_scores_improved = plot_df[m2][improved]

    fig = plt.figure()
    # Unity line for reference.
    plt.plot([0, 1], [0, 1], color='black', linewidth=1, linestyle='dashed',
             dashes=dash_spacing)
    plt.scatter(m1_scores, m2_scores, s=small_scatter, label='no imp.',
                color=model_colors['LN'])
                #color='none',
                #edgecolors='black', linewidth=0.35)
    plt.scatter(m1_scores_improved, m2_scores_improved, s=big_scatter,
                label='sig. imp.', color=model_colors['max'])
                #color='none',
                #edgecolors='black', linewidth=0.35)
    ax_remove_box()
    if legend:
        plt.legend()
    plt.xlim(0, 1)
    plt.ylim(0, 1)
    plt.tight_layout()
    plt.axes().set_aspect('equal')

    # Companion figure describing the plot contents.
    fig2 = plt.figure(figsize=text_fig)
    plt.text(0.1, 0.5,
             "batch %d\n"
             "%d/%d auditory/total cells\n"
             "%d no improvements\n"
             "%d at least one improvement\n"
             "stat: %s, x: %s, y: %s"
             % (batch, n_all, n_batch, n_not_imp, n_imp, plot_stat, name1,
                name2))

    return fig, fig2
def gc_stp_scatter(batch, gc, stp, LN, combined, se_filter=True,
                   LN_filter=False, manual_cellids=None,
                   plot_stat='r_ceiling'):
    # Scatter GC model performance against STP model performance, with
    # filtered-out (under-chance) cells shown in a separate color.
    #
    # Parameters:
    #   batch: batch number.
    #   gc, stp, LN, combined: model names (dataframe column keys).
    #   se_filter, LN_filter: forwarded to get_filtered_cellids; se_filter
    #       also toggles the cell-count text annotations.
    #   manual_cellids: optional explicit cell list (overrides filters).
    #   plot_stat: 'r_ceiling' selects the ceiling-corrected dataframe,
    #       anything else falls back to raw r_test.
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    cellids, under_chance, less_LN = get_filtered_cellids(
        df_r, df_e, gc, stp, LN, combined, se_filter, LN_filter)
    if manual_cellids is not None:
        # WARNING: Will override se and ratio filters even if they are set
        cellids = manual_cellids
    if plot_stat == 'r_ceiling':
        plot_df = df_c
    else:
        plot_df = df_r

    n_cells = len(cellids)
    # When filtering is disabled, get_filtered_cellids presumably returns the
    # same list for every group, so the dropped-cell counts collapse to 0 —
    # TODO confirm against get_filtered_cellids.
    n_under_chance = len(under_chance) if under_chance != cellids else 0
    n_less_LN = len(less_LN) if less_LN != cellids else 0
    gc_test = plot_df[gc][cellids]
    gc_test_under_chance = plot_df[gc][under_chance]
    stp_test = plot_df[stp][cellids]
    stp_test_under_chance = plot_df[stp][under_chance]

    fig = plt.figure()
    plt.scatter(gc_test, stp_test, c=wsu_gray, s=20)
    ax = fig.axes[0]
    # NOTE(review): reference line spans xlim vs ylim, so it is only a true
    # unity line when both limits agree — confirm intent.
    plt.plot(ax.get_xlim(), ax.get_ylim(), 'k--', linewidth=1,
             dashes=dash_spacing)
    plt.scatter(gc_test_under_chance, stp_test_under_chance,
                c=dropped_cell_color, s=20)
    plt.title('GC vs STP')
    plt.xlabel('GC')
    plt.ylabel('STP')
    if se_filter:
        # Annotate cell counts: total, kept, and under-chance.
        plt.text(0.90, -0.05,
                 'all = %d' % (n_cells + n_under_chance + n_less_LN),
                 ha='right', va='bottom')
        plt.text(0.90, 0.00, 'n = %d' % n_cells, ha='right', va='bottom')
        plt.text(0.90, 0.05, 'uc = %d' % n_under_chance, ha='right',
                 va='bottom', color=dropped_cell_color)
def performance_scatters(batch, gc, stp, LN, combined, se_filter=True,
                         LN_filter=False, manual_cellids=None,
                         plot_stat='r_ceiling', show_dropped=True,
                         color_improvements=False):
    '''Produce six pairwise model-performance scatter plots.

    Row 1 compares each nonlinear model (GC, STP, GC+STP) against LN; row 2
    compares the nonlinear models head-to-head. Filtered-out (under-chance)
    cells are overlaid in a separate color.

    Parameters
    ----------
    batch : int
        Batch number.
    gc, stp, LN, combined : str
        Model names (dataframe column keys).
    se_filter, LN_filter : bool
        Forwarded to the cell filters.
    manual_cellids : list or None
        If given, overrides every other cell filter.
    plot_stat : str
        'r_ceiling' selects the ceiling-corrected dataframe, else raw r_test.
    show_dropped, color_improvements : bool
        Accepted for interface compatibility; not currently used in the body.

    Returns
    -------
    fig1..fig6 : matplotlib figures
        LN-vs-GC, LN-vs-STP, LN-vs-GC+STP, GC-vs-STP, GC-vs-GC+STP,
        STP-vs-GC+STP.
    '''
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    # BUG FIX: under_chance was referenced throughout but never defined (the
    # get_filtered_cellids call had been commented out), raising NameError.
    # Restore it; the kept-cell list itself still comes from
    # improved_cells_to_list below.
    _, under_chance, less_LN = get_filtered_cellids(
        df_r, df_e, gc, stp, LN, combined, se_filter, LN_filter)
    e, a, g, s, c = improved_cells_to_list(batch, gc, stp, LN, combined,
                                           se_filter=se_filter,
                                           LN_filter=LN_filter)
    cellids = a
    if manual_cellids is not None:
        # WARNING: Will override se and ratio filters even if they are set
        cellids = manual_cellids
    if plot_stat == 'r_ceiling':
        plot_df = df_c
    else:
        plot_df = df_r

    gc_test = plot_df[gc][cellids]
    gc_test_under_chance = plot_df[gc][under_chance]
    stp_test = plot_df[stp][cellids]
    stp_test_under_chance = plot_df[stp][under_chance]
    ln_test = plot_df[LN][cellids]
    ln_test_under_chance = plot_df[LN][under_chance]
    gc_stp_test = plot_df[combined][cellids]
    gc_stp_test_under_chance = plot_df[combined][under_chance]

    # Row 1 (vs LN)
    fig1, ax = plt.subplots(1, 1)
    ax.scatter(ln_test, gc_test, c=scatter_color, s=20)
    ax.plot(ax.get_xlim(), ax.get_ylim(), 'k--', linewidth=2,
            dashes=dash_spacing)
    ax.scatter(gc_test_under_chance, ln_test_under_chance,
               c=dropped_cell_color, s=20)
    ax.set_title('LN vs GC')
    ax.set_ylabel('GC')
    ax.set_xlabel('LN')

    fig2, ax = plt.subplots(1, 1)
    ax.scatter(ln_test, stp_test, c=scatter_color, s=20)
    ax.plot(ax.get_xlim(), ax.get_ylim(), 'k--', linewidth=2,
            dashes=dash_spacing)
    ax.scatter(stp_test_under_chance, ln_test_under_chance,
               c=dropped_cell_color, s=20)
    ax.set_title('LN vs STP')
    ax.set_ylabel('STP')
    ax.set_xlabel('LN')

    fig3, ax = plt.subplots(1, 1)
    ax.scatter(ln_test, gc_stp_test, c=scatter_color, s=20)
    ax.plot(ax.get_xlim(), ax.get_ylim(), 'k--', linewidth=2,
            dashes=dash_spacing)
    ax.scatter(ln_test_under_chance, gc_stp_test_under_chance,
               c=dropped_cell_color, s=20)
    ax.set_title('LN vs GC + STP')
    ax.set_ylabel('GC + STP')
    ax.set_xlabel('LN')

    # Row 2 (head-to-head)
    fig4, ax = plt.subplots(1, 1)
    ax.scatter(gc_test, stp_test, c=scatter_color, s=20)
    ax.plot(ax.get_xlim(), ax.get_ylim(), 'k--', linewidth=2,
            dashes=dash_spacing)
    ax.scatter(gc_test_under_chance, stp_test_under_chance,
               c=dropped_cell_color, s=20)
    ax.set_title('GC vs STP')
    ax.set_xlabel('GC')
    ax.set_ylabel('STP')

    fig5, ax = plt.subplots(1, 1)
    ax.scatter(gc_test, gc_stp_test, c=scatter_color, s=20)
    ax.plot(ax.get_xlim(), ax.get_ylim(), 'k--', linewidth=2,
            dashes=dash_spacing)
    ax.scatter(gc_test_under_chance, gc_stp_test_under_chance,
               c=dropped_cell_color, s=20)
    ax.set_title('GC vs GC + STP')
    ax.set_xlabel('GC')
    ax.set_ylabel('GC + STP')

    fig6, ax = plt.subplots(1, 1)
    ax.scatter(stp_test, gc_stp_test, c=scatter_color, s=20)
    ax.plot(ax.get_xlim(), ax.get_ylim(), 'k--', linewidth=2,
            dashes=dash_spacing)
    ax.scatter(stp_test_under_chance, gc_stp_test_under_chance,
               c=dropped_cell_color, s=20)
    ax.set_title('STP vs GC + STP')
    ax.set_xlabel('STP')
    ax.set_ylabel('GC + STP')

    return fig1, fig2, fig3, fig4, fig5, fig6
def cf_vs_model_performance(batch, gc, stp, LN, combined, cf_load_path=None,
                            cf_kwargs={}, se_filter=True, LN_filter=False,
                            plot_stat='r_ceiling', include_LN=False,
                            include_combined=False, only_improvements=False):
    # Scatter each cell's characteristic frequency (CF) against GC and STP
    # model performance, with Spearman correlations in the title.
    #
    # Parameters:
    #   batch: batch number.
    #   gc, stp, LN, combined: model names (dataframe column keys).
    #   cf_load_path: pickled CF dataframe; if None, CFs are computed via
    #       cf_batch instead.
    #   cf_kwargs: extra keyword arguments for cf_batch.
    #   se_filter, LN_filter: forwarded to get_filtered_cellids.
    #   plot_stat: 'r_ceiling' selects the ceiling-corrected dataframe,
    #       anything else falls back to raw r_test.
    #   include_LN, include_combined: currently unused (their scatter calls
    #       are commented out below).
    #   only_improvements: restrict each model's cells to those improved by
    #       that model alone.
    # NOTE: mutable default cf_kwargs={} is shared across calls; it is only
    # read (** expanded), never mutated here.
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    cellids, under_chance, less_LN = get_filtered_cellids(
        df_r, df_e, gc, stp, LN, combined, se_filter, LN_filter)
    cellids = cellids  # no-op, retained from original
    if only_improvements:
        e, a, g, s, c = improved_cells_to_list(batch, gc, stp, LN, combined)
        # Cells improved by gc (or either) but not by stp or combined,
        # and vice versa — so each scatter shows model-specific improvements.
        gc_cells = list((set(cellids) & (set(e) | set(g))) - set(c) - set(s))
        stp_cells = list((set(cellids) & (set(e) | set(s))) - set(c) - set(g))
        n_gc = len(gc_cells)
        n_stp = len(stp_cells)
    else:
        gc_cells = stp_cells = cellids
        n_gc = n_stp = len(cellids)
    if plot_stat == 'r_ceiling':
        plot_df = df_c
    else:
        plot_df = df_r
    if cf_load_path is None:
        # Compute CFs from scratch (load_path is None here by construction).
        df = cf_batch(batch, LN, load_path=cf_load_path, **cf_kwargs)
    else:
        df = pd.read_pickle(cf_load_path)

    gc_cfs = df['cf'][gc_cells].values.astype('float32')
    stp_cfs = df['cf'][stp_cells].values.astype('float32')
    #ln_test = plot_df[LN][cellids].values.astype('float32')
    gc_test = plot_df[gc][gc_cells].values.astype('float32')
    stp_test = plot_df[stp][stp_cells].values.astype('float32')
    #combined_test = plot_df[combined][cellids].values.astype('float32')
    # Rank correlation between CF and performance, per model.
    r_gc, p_gc = st.spearmanr(gc_cfs, gc_test)
    r_stp, p_stp = st.spearmanr(stp_cfs, stp_test)

    plt.figure()
    # if include_LN:
    #     plt.scatter(cfs, ln_test, color='gray', alpha=0.5)
    plt.scatter(gc_cfs, gc_test, color='goldenrod', alpha=0.5)
    plt.scatter(stp_cfs, stp_test, color=wsu_crimson, alpha=0.5)
    # if include_combined:
    #     plt.scatter(cfs, combined_test, color='purple', alpha=0.5)
    # NOTE(review): 'basex' was removed in matplotlib 3.5 (now 'base') —
    # confirm the pinned matplotlib version supports it.
    plt.xscale('log', basex=np.e)
    title = ("CF vs model performance\n"
             "gc -- rho: %.4f, p: %.4E, n: %d\n"
             "stp -- rho: %.4f, p: %.4E, n: %d"
             % (r_gc, p_gc, n_gc, r_stp, p_stp, n_stp))
    plt.title(title)
def equivalence_histogram(batch, gc, stp, LN, combined, se_filter=True,
                          LN_filter=False, test_limit=None, alpha=0.05,
                          save_path=None, load_path=None,
                          equiv_key='partial_corr',
                          effect_key='performance_effect', self_equiv=False,
                          self_kwargs={}, eq_models=[], cross_kwargs={},
                          cross_models=[], use_median=True,
                          exclude_low_snr=False, snr_path=None,
                          adjust_scores=True, use_log_ratios=False):
    '''Histogram the GC/STP "equivalence" metric for improved vs other cells.

    Equivalence is the correlation between (GC - LN) and (STP - LN) prediction
    residuals per cell (or a precomputed column selected by ``equiv_key`` when
    loading from ``load_path``). Optionally overlays self-equivalence
    estimates from split-data model fits (``self_equiv``), rescaled to
    compensate for the halved estimation data (``adjust_scores``).

    model1: GC
    model2: STP
    model3: LN

    Returns (fig1, fig3): the two stacked histograms and a text figure with
    the test statistics.
    '''
    e, a, g, s, c = improved_cells_to_list(batch, gc, stp, LN, combined)
    # Second call with as_lists=False: presumably returns index masks rather
    # than cellid lists — TODO confirm; used below to subset the dataframe.
    _, cellids, _, _, _ = improved_cells_to_list(batch, gc, stp, LN, combined,
                                                 as_lists=False)
    improved = c
    not_improved = list(set(a) - set(c))

    if load_path is None:
        # Compute equivalence from scratch: load all three fitted models per
        # cell and correlate their prediction residuals relative to LN.
        df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
        rs = []
        for c in a[:test_limit]:
            xf1, ctx1 = xhelp.load_model_xform(c, batch, gc)
            xf2, ctx2 = xhelp.load_model_xform(c, batch, stp)
            xf3, ctx3 = xhelp.load_model_xform(c, batch, LN)
            gc_pred = ctx1['val'].apply_mask()['pred'].as_continuous()
            stp_pred = ctx2['val'].apply_mask()['pred'].as_continuous()
            ln_pred = ctx3['val'].apply_mask()['pred'].as_continuous()
            # Only keep time bins where all three predictions are finite.
            ff = np.isfinite(gc_pred) & np.isfinite(stp_pred) & np.isfinite(
                ln_pred)
            gcff = gc_pred[ff]
            stpff = stp_pred[ff]
            lnff = ln_pred[ff]
            rs.append(np.corrcoef(gcff - lnff, stpff - lnff)[0, 1])

        blank = np.full_like(rs, np.nan)
        results = {'cellid': a[:test_limit], 'equivalence': rs,
                   'effect_size': blank, 'corr_gc_LN': blank,
                   'corr_stp_LN': blank}
        df = pd.DataFrame.from_dict(results)
        df.set_index('cellid', inplace=True)
        if save_path is not None:
            df.to_pickle(save_path)
    else:
        df = pd.read_pickle(load_path)

    df = df[cellids]
    rs = df[equiv_key].values
    if exclude_low_snr:
        # Keep only cells at or above the median SNR.
        snr_df = pd.read_pickle(snr_path)
        med_snr = snr_df['snr'].median()
        high_snr = snr_df.loc[snr_df['snr'] >= med_snr]
        high_snr_cells = high_snr.index.values.tolist()
        improved = list(set(improved) & set(high_snr_cells))
        not_improved = list(set(not_improved) & set(high_snr_cells))

    # Split equivalence values into improved / not-improved groups, keeping
    # alignment with the ordering of a.
    imp = np.array(improved)
    not_imp = np.array(not_improved)
    imp_mask = np.isin(a, imp)
    not_mask = np.isin(a, not_imp)
    rs_not = rs[not_mask]
    rs_imp = rs[imp_mask]
    md_not = np.nanmedian(rs_not)
    md_imp = np.nanmedian(rs_imp)
    u, p = st.mannwhitneyu(rs_not, rs_imp, alternative='two-sided')
    n_not = len(not_improved)
    n_imp = len(improved)

    if self_equiv:
        # Self-equivalence: compare two independent fits of the same model
        # class (split estimation data) to bound the metric from above.
        stp1, stp2, gc1, gc2, LN1, LN2 = eq_models
        g1, s2, s1, g2, L1, L2 = cross_models
        _, ga, _, _, _ = improved_cells_to_list(batch, gc1, gc2, LN1, LN2,
                                                as_lists=True)
        _, sa, _, _, _ = improved_cells_to_list(batch, stp1, stp2, LN1, LN2,
                                                as_lists=True)
        aa = list(set(ga) & set(sa))
        if exclude_low_snr:
            snr_df = pd.read_pickle(snr_path)
            med_snr = snr_df['snr'].median()
            high_snr = snr_df.loc[snr_df['snr'] >= med_snr]
            high_snr_cells = high_snr.index.values.tolist()
            aa = list(set(aa) & set(high_snr_cells))
        eqs_stp, eqs_gc = _get_self_equivs(**self_kwargs, cellids=aa)
        md_stpeq = np.nanmedian(eqs_stp)
        md_gceq = np.nanmedian(eqs_gc)
        n_eq = eqs_stp.size
        u_gceq, p_gceq = st.mannwhitneyu(eqs_gc, rs_imp,
                                         alternative='two-sided')
        u_stpeq, p_stpeq = st.mannwhitneyu(eqs_stp, rs_imp,
                                           alternative='two-sided')
        # Cross-model equivalence on the halved estimation data.
        eqs_x1, eqs_x2 = _get_self_equivs(**cross_kwargs, cellids=aa)
        md_x1 = np.nanmedian(eqs_x1)
        md_x2 = np.nanmedian(eqs_x2)
        # Full-data equivalence restricted to the same cell subsets.
        sub1 = df.index.isin(ga) & df.index.isin(aa)
        sub2 = df.index.isin(sa) & df.index.isin(aa)
        eqs_sub1 = df[equiv_key][sub1].values
        eqs_sub2 = df[equiv_key][sub2].values
        md_sub1 = np.nanmedian(eqs_sub1)
        md_sub2 = np.nanmedian(eqs_sub2)
        if adjust_scores:
            # Rescale the self-equivalence estimates by the ratio of
            # full-data to halved-data cross-model equivalence, to undo the
            # downward bias from fitting on half the data.
            if use_median:
                md_avg1 = 0.5 * (md_x1 + md_x2)
                md_avg2 = 0.5 * (md_sub1 + md_sub2)
                ratio = md_avg2 / md_avg1
                md_stpeq *= ratio
                md_gceq *= ratio
            else:
                eqs_avg1 = 0.5 * (eqs_x1 + eqs_x2)  # between-model, halved est
                eqs_avg2 = 0.5 * (eqs_sub1 + eqs_sub2)  # between-model, full data
                ratios = np.abs(eqs_avg2 / eqs_avg1)
                if use_log_ratios:
                    log_ratios = np.abs(np.log(ratios))
                    log_ratios /= log_ratios.max()
                    stp_scaled = eqs_stp + (1 - eqs_stp) * log_ratios
                    gc_scaled = eqs_gc + (1 - eqs_gc) * log_ratios
                else:
                    stp_scaled = eqs_stp * ratios
                    gc_scaled = eqs_gc * ratios
                md_stpeq = np.nanmedian(stp_scaled)
                md_gceq = np.nanmedian(gc_scaled)

    not_color = model_colors['LN']
    imp_color = model_colors['max']
    # Normalize each histogram to cell fraction.
    weights1 = [np.ones(len(rs_not)) / len(rs_not)]
    weights2 = [np.ones(len(rs_imp)) / len(rs_imp)]
    #n_cells = rs.shape[0]
    fig1, (a1, a2) = plt.subplots(2, 1)
    a1.hist(rs_not, bins=30, range=[-0.5, 1], weights=weights1,
            fc=faded_LN, edgecolor=dark_LN, linewidth=1)
    a2.hist(rs_imp, bins=30, range=[-0.5, 1], weights=weights2,
            fc=faded_max, edgecolor=dark_max, linewidth=1)
    # Mark both group medians on both panels for comparison.
    a1.axes.axvline(md_not, color=dark_LN, linewidth=2, linestyle='dashed',
                    dashes=dash_spacing)
    a1.axes.axvline(md_imp, color=dark_max, linewidth=2, linestyle='dashed',
                    dashes=dash_spacing)
    a2.axes.axvline(md_not, color=dark_LN, linewidth=2, linestyle='dashed',
                    dashes=dash_spacing)
    a2.axes.axvline(md_imp, color=dark_max, linewidth=2, linestyle='dashed',
                    dashes=dash_spacing)
    if self_equiv:
        # Arrows marking the (rescaled) self-equivalence medians.
        a1.axes.annotate('', xy=(md_gceq, 0), xycoords='data',
                         xytext=(md_gceq, 0.07), textcoords='data',
                         arrowprops=dict(arrowstyle="->",
                                         connectionstyle="arc3"))
        a1.axes.annotate('', xy=(md_stpeq, 0), xycoords='data',
                         xytext=(md_stpeq, 0.07), textcoords='data',
                         arrowprops=dict(arrowstyle="->",
                                         connectionstyle="arc3"))
        a2.axes.annotate('', xy=(md_gceq, 0), xycoords='data',
                         xytext=(md_gceq, 0.07), textcoords='data',
                         arrowprops=dict(arrowstyle="->",
                                         connectionstyle="arc3"))
        a2.axes.annotate('', xy=(md_stpeq, 0), xycoords='data',
                         xytext=(md_stpeq, 0.07), textcoords='data',
                         arrowprops=dict(arrowstyle="->",
                                         connectionstyle="arc3"))
    # Give both panels a shared y range.
    ymin1, ymax1 = a1.get_ylim()
    ymin2, ymax2 = a2.get_ylim()
    ymax = max(ymax1, ymax2)
    a1.set_ylim(0, ymax)
    a2.set_ylim(0, ymax)
    ax_remove_box(a1)
    ax_remove_box(a2)
    plt.tight_layout()

    if equiv_key == 'equivalence':
        x_text = 'equivalence, CC(GC-LN, STP-LN)\n'
    elif equiv_key == 'partial_corr':
        x_text = 'equivalence, partial correlation\n'
    else:
        x_text = 'unknown equivalence key'

    # Companion text figure with the statistics.
    fig3 = plt.figure(figsize=text_fig)
    text2 = ("hist: equivalence of changein prediction relative to LN model\n"
             "batch: %d\n"
             "x: %s"
             "y: cell fraction\n"
             "n not imp: %d, md: %.2f\n"
             "n sig. imp: %d, md: %.2f\n"
             "st.mannwhitneyu: u: %.4E p: %.4E"
             % (batch, x_text, n_not, md_not, n_imp, md_imp, u, p))
    if self_equiv:
        text2 += ("\n\nSelf equivalence, n: %d\n"
                  "stp: md: %.2f, u: %.4E p %.4E\n"
                  "gc: md: %.2f, u: %.4E p %.4E\n"
                  "md sub1: %.2E\n"
                  "md sub2: %.2E\n"
                  "md x1: %.2E\n"
                  "md x2: %.2E"
                  % (n_eq, md_stpeq, u_stpeq, p_stpeq, md_gceq, u_gceq,
                     p_gceq, md_sub1, md_sub2, md_x1, md_x2))
    plt.text(0.1, 0.5, text2)

    return fig1, fig3
def equivalence_scatter(batch, gc, stp, LN, combined, se_filter=True,
                        LN_filter=False, plot_stat='r_ceiling',
                        enable_hover=False, manual_lims=None,
                        drop_outliers=False, color_improvements=True,
                        xmodel='GC', ymodel='STP', legend=False,
                        self_equiv=False, self_eq_models=[],
                        show_highlights=False, exclude_low_snr=False,
                        snr_path=None):
    '''Scatter per-cell improvement over LN for the GC model vs the STP model.

    x = GC - LN, y = STP - LN (per cell). Pearson correlations are computed
    before outliers are dropped for display. With ``self_equiv``, additional
    correlations between split-data fits of the same model class are
    reported in the text figure.

    model1: GC
    model2: STP
    model3: LN

    Returns (fig, fig2): the scatter figure and a companion text figure.
    '''
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    if plot_stat == 'r_ceiling':
        plot_df = df_c
    else:
        plot_df = df_r

    e, a, g, s, c = improved_cells_to_list(batch, gc, stp, LN, combined,
                                           se_filter=se_filter,
                                           LN_filter=LN_filter)
    improved = c
    not_improved = list(set(a) - set(c))
    if exclude_low_snr:
        # Keep only cells at or above the median SNR.
        snr_df = pd.read_pickle(snr_path)
        med_snr = snr_df['snr'].median()
        high_snr = snr_df.loc[snr_df['snr'] >= med_snr]
        high_snr_cells = high_snr.index.values.tolist()
        improved = list(set(improved) & set(high_snr_cells))
        not_improved = list(set(not_improved) & set(high_snr_cells))
        a = list(set(a) & set(high_snr_cells))

    # check number of animals included (cellid prefix encodes the animal;
    # note this comprehension shadows the s variable from above)
    prefixes = [s[:3] for s in a]
    n_animals = len(list(set(prefixes)))
    print('n_animals: %d' % n_animals)

    models = [gc, stp, LN]
    gc_rel_imp, stp_rel_imp = _relative_score(plot_df, models, improved)
    gc_rel_not, stp_rel_not = _relative_score(plot_df, models, not_improved)
    gc_rel_all, stp_rel_all = _relative_score(plot_df, models, a)

    # Example cells to mark on the plot: LN, STP, GC
    cells_to_highlight = ['TAR010c-40-1', 'AMT005c-20-1', 'TAR009d-22-1']
    cells_to_plot_gc_rel = []
    cells_to_plot_stp_rel = []
    for c in cells_to_highlight:
        stp_rel = plot_df[stp][c] - plot_df[LN][c]
        gc_rel = plot_df[gc][c] - plot_df[LN][c]
        cells_to_plot_gc_rel.append(gc_rel)
        cells_to_plot_stp_rel.append(stp_rel)

    # compute corr. before dropping outliers (only dropping for visualization)
    r_imp, p_imp = st.pearsonr(gc_rel_imp, stp_rel_imp)
    r_not, p_not = st.pearsonr(gc_rel_not, stp_rel_not)
    r_all, p_all = st.pearsonr(gc_rel_all, stp_rel_all)

    if self_equiv:
        # Correlate improvements between two independent fits of the same
        # model class (split estimation data).
        stp1, stp2, gc1, gc2, LN1, LN2 = self_eq_models
        _, ga, _, _, _ = improved_cells_to_list(batch, gc1, gc2, LN1, LN2,
                                                as_lists=True)
        _, sa, _, _, _ = improved_cells_to_list(batch, stp1, stp2, LN1, LN2,
                                                as_lists=True)
        aa = list(set(ga) & set(sa))
        if exclude_low_snr:
            snr_df = pd.read_pickle(snr_path)
            med_snr = snr_df['snr'].median()
            high_snr = snr_df.loc[snr_df['snr'] >= med_snr]
            high_snr_cells = high_snr.index.values.tolist()
            aa = list(set(aa) & set(high_snr_cells))
        df_r_eq = nd.batch_comp(batch, [gc1, gc2, stp1, stp2, LN1, LN2],
                                stat=plot_stat)
        df_r_eq.dropna(axis=0, how='any', inplace=True)
        df_r_eq.sort_index(inplace=True)
        df_r_eq = df_r_eq[df_r_eq.index.isin(aa)]
        gc1_rel_imp = df_r_eq[gc1].values - df_r_eq[LN1].values
        gc2_rel_imp = df_r_eq[gc2].values - df_r_eq[LN2].values
        stp1_rel_imp = df_r_eq[stp1].values - df_r_eq[LN1].values
        stp2_rel_imp = df_r_eq[stp2].values - df_r_eq[LN2].values
        r_gceq, p_gceq = st.pearsonr(gc1_rel_imp, gc2_rel_imp)
        r_stpeq, p_stpeq = st.pearsonr(stp1_rel_imp, stp2_rel_imp)
        n_eq = gc1_rel_imp.size
        # compute on same subset for full estimation data
        # to compare to cross-set
        gc_subset1 = df_r[gc][ga].values
        gc_subset2 = df_r[gc][sa].values
        stp_subset1 = df_r[stp][ga].values
        stp_subset2 = df_r[stp][sa].values
        LN_subset1 = df_r[LN][ga].values
        LN_subset2 = df_r[LN][sa].values
        gc_rel1 = gc_subset1 - LN_subset1
        gc_rel2 = gc_subset2 - LN_subset2
        stp_rel1 = stp_subset1 - LN_subset1
        stp_rel2 = stp_subset2 - LN_subset2
        r_sub1, p_sub1 = st.pearsonr(gc_rel1, stp_rel1)
        r_sub2, p_sub2 = st.pearsonr(gc_rel2, stp_rel2)

    # Drop outliers for display only (correlations were computed above).
    # NOTE(review): this runs regardless of the drop_outliers flag, which is
    # only echoed in the text figure — confirm intent.
    gc_rel_imp, stp_rel_imp = drop_common_outliers(gc_rel_imp, stp_rel_imp)
    gc_rel_not, stp_rel_not = drop_common_outliers(gc_rel_not, stp_rel_not)
    gc_rel_all, stp_rel_all = drop_common_outliers(gc_rel_all, stp_rel_all)
    n_imp = len(improved)
    n_not = len(not_improved)
    n_all = len(a)

    y_max = np.max(stp_rel_all)
    y_min = np.min(stp_rel_all)
    x_max = np.max(gc_rel_all)
    x_min = np.min(gc_rel_all)

    fig = plt.figure()
    ax = plt.gca()
    ax.axes.axhline(0, color='black', linewidth=1, linestyle='dashed',
                    dashes=dash_spacing)
    ax.axes.axvline(0, color='black', linewidth=1, linestyle='dashed',
                    dashes=dash_spacing)
    if color_improvements:
        ax.scatter(gc_rel_not, stp_rel_not, c=model_colors['LN'],
                   s=small_scatter)
        ax.scatter(gc_rel_imp, stp_rel_imp, c=model_colors['max'],
                   s=big_scatter)
        if show_highlights:
            # Number the highlighted example cells on the plot.
            for i, (g, s) in enumerate(zip(cells_to_plot_gc_rel,
                                           cells_to_plot_stp_rel)):
                plt.text(g, s, str(i + 1), fontsize=12, color='black')
    else:
        ax.scatter(gc_rel_all, stp_rel_all, c='black', s=big_scatter)

    if legend:
        plt.legend()
    if manual_lims is not None:
        ax.set_ylim(*manual_lims)
        ax.set_xlim(*manual_lims)
    else:
        # Square, symmetric limits rounded outward to the nearest 0.1.
        upper = max(y_max, x_max)
        lower = min(y_min, x_min)
        upper_lim = np.ceil(10 * upper) / 10
        lower_lim = np.floor(10 * lower) / 10
        ax.set_ylim(lower_lim, upper_lim)
        ax.set_xlim(lower_lim, upper_lim)
    plt.axes().set_aspect('equal')
    if enable_hover:
        # Interactive cellid tooltips (requires mplcursors).
        mplcursors.cursor(ax, hover=True).connect(
            "add",
            lambda sel: sel.annotation.set_text(a[sel.target.index]))
    plt.tight_layout()

    # Companion text figure with the statistics.
    fig2 = plt.figure(figsize=text_fig)
    text = ("Performance Improvements over LN\n"
            "batch: %d\n"
            "dropped outliers?: %s\n"
            "all cells: r: %.2E, p: %.2E, n: %d\n"
            "improved: r: %.2E, p: %.2E, n: %d\n"
            "not imp: r: %.2E, p: %.2E, n: %d\n"
            "x: %s - LN\n"
            "y: %s - LN"
            % (batch, drop_outliers, r_all, p_all, n_all, r_imp, p_imp,
               n_imp, r_not, p_not, n_not, xmodel, ymodel))
    if self_equiv:
        text += ("\n\nSelf Equivalence, n: %d\n"
                 "gc, r: %.2Ef, p: %.2E\n"
                 "stp, r: %.2Ef, p: %.2E\n"
                 "sub1, r: %.2E, p: %.2E\n"
                 "sub2, r: %.2E, p: %.2E\n"
                 % (n_eq, r_gceq, p_gceq, r_stpeq, p_stpeq, r_sub1, p_sub1,
                    r_sub2, p_sub2))
    plt.text(0.1, 0.5, text)
    ax_remove_box(ax)

    return fig, fig2
def gain_by_contrast_slopes(batch, gc, stp, LN, combined, se_filter=True,
                            good_LN=0, bins=30, use_exp=True):
    """Plot stacked histograms of contrast-dependent gain slopes.

    Compares the distribution of gain slopes (kappa change scaled by the
    contrast range actually spanned by the data) for cells significantly
    improved by the GC model against the remaining well-fit (LN) cells,
    and reports a Mann-Whitney U test between the two distributions.

    Parameters
    ----------
    batch : int
        NEMS batch number; used both for performance scores and for
        loading fitted parameters.
    gc, stp, LN, combined : str
        Modelnames for the gain-control, STP, LN and combined models.
        `stp` and `combined` are only used for score lookup.
    se_filter : bool
        Unused here; retained for signature consistency with siblings.
    good_LN : float
        Only used in the plot title; actual filtering uses the
        2-standard-error criterion below.
    bins : int
        Number of histogram bins.
    use_exp : bool
        Retained for signature compatibility; the exponentiated kappa
        endpoints were only consumed by plots that have been removed.

    Returns
    -------
    matplotlib.figure.Figure
        The stacked-histogram figure.
    """
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    # "good" cells: LN r_test at least 2 standard errors above zero
    cellids = df_r[LN] > df_e[LN] * 2
    gc_LN_SE = (df_e[gc] + df_e[LN])
    # GC-improved: gc model beats LN by more than their combined SE
    gc_cells = cellids & ((df_r[gc] - df_r[LN]) > gc_LN_SE)
    LN_cells = cellids & np.logical_not(gc_cells)

    meta = ['r_test', 'ctmax_val', 'ctmax_est', 'ctmin_val', 'ctmin_est']
    # BUGFIX: batch was previously hard-coded to 289 here even though the
    # function accepts a batch argument (sibling functions pass `batch`).
    gc_params = fitted_params_per_batch(batch, gc, stats_keys=[], meta=meta)

    # drop cellids that haven't been fit for all models
    gc_params_cells = gc_params.transpose().index.values.tolist()
    for c in gc_params_cells:
        if c not in LN_cells:
            LN_cells[c] = False
        if c not in gc_cells:
            gc_cells[c] = False

    # index keys are formatted like "4--dsig.d--kappa"
    mod_keys = gc.split('_')[1]
    for i, k in enumerate(mod_keys.split('-')):
        if 'dsig' in k:
            break
    k_key = f'{i}--{k}--kappa'
    ka_key = k_key + '_mod'
    meta_keys = ['meta--' + k for k in meta]
    all_keys = [k_key, ka_key] + meta_keys
    phi_dfs = [
        gc_params[gc_params.index == k].transpose()[LN_cells].transpose()
        for k in all_keys
    ]
    sep_dfs = [df.values.flatten().astype(np.float64) for df in phi_dfs]
    gc_dfs = [
        gc_params[gc_params.index == k].transpose()[gc_cells].transpose()
        for k in all_keys
    ]
    gc_sep_dfs = [df.values.flatten().astype(np.float64) for df in gc_dfs]

    low, high, r_test, ctmax_val, ctmax_est, ctmin_val, ctmin_est = sep_dfs
    gc_low, gc_high, gc_r, gc_ctmax_val, \
        gc_ctmax_est, gc_ctmin_val, gc_ctmin_est = gc_sep_dfs

    # contrast range actually spanned by est/val data
    ctmax = np.maximum(ctmax_val, ctmax_est)
    gc_ctmax = np.maximum(gc_ctmax_val, gc_ctmax_est)
    ctmin = np.minimum(ctmin_val, ctmin_est)
    gc_ctmin = np.minimum(gc_ctmin_val, gc_ctmin_est)
    ct_range = ctmax - ctmin
    gc_ct_range = gc_ctmax - gc_ctmin

    gain = (high - low) * ct_range
    gc_gain = (gc_high - gc_low) * gc_ct_range
    # test hyp. that gc_gains are more negative than LN
    gc_LN_p = st.mannwhitneyu(gc_gain, gain, alternative='two-sided')[1]
    med_gain = np.median(gain)
    gc_med_gain = np.median(gc_gain)

    smallest_slope = min(np.min(gain), np.min(gc_gain))
    largest_slope = max(np.max(gain), np.max(gc_gain))
    slope_range = (smallest_slope, largest_slope)
    bins = np.linspace(smallest_slope, largest_slope, bins)
    bar_width = bins[1] - bins[0]
    axis_locs = bins[:-1]
    hist = np.histogram(gain, bins=bins, range=slope_range)
    gc_hist = np.histogram(gc_gain, bins=bins, range=slope_range)
    raw = hist[0]
    gc_raw = gc_hist[0]

    fig1 = plt.figure()
    plt.bar(axis_locs, raw, width=bar_width, color='gray', alpha=0.8)
    # GC-improved counts stacked on top of the LN counts
    plt.bar(axis_locs, gc_raw, width=bar_width, color='maroon', alpha=0.8,
            bottom=raw)
    plt.xlabel('gain slope')
    plt.ylabel('count')
    plt.title(f'raw counts, LN > {good_LN}')
    plt.legend([
        f'LN, {len(low)}, md={med_gain:.4f}',
        f'gc, {len(gc_low)}, md={gc_med_gain:.4f}, p={gc_LN_p:.4f}'
    ])
    return fig1
def gd_ratio(batch, gc, stp, LN, combined, se_filter=True, good_LN=0, bins=30,
             use_exp=True):
    """Plot distributions of the low- vs high-contrast kappa ratio.

    Builds histograms of ``k_low / k_high`` (kappa evaluated at the
    minimum and maximum contrast spanned by the data) for all well-fit
    cells and for GC-improved cells, plus scatter plots of the ratio
    against r_test for each group.

    Parameters
    ----------
    batch : int
        NEMS batch number; used for both scores and fitted parameters.
    gc, stp, LN, combined : str
        Modelnames. `stp` and `combined` are only used for score lookup.
    se_filter : bool
        Unused here; retained for signature consistency with siblings.
    good_LN : float
        Unused here; retained for signature consistency with siblings.
    bins : int
        Number of histogram bins.
    use_exp : bool
        If True, exponentiate the kappa endpoints before taking the
        ratio (kappa is fit in log space), so the ratio becomes
        ``e^(k_low - k_high)``.

    Returns
    -------
    tuple of matplotlib.figure.Figure
        (histogram figure, all-cells scatter, gc-cells scatter).
    """
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    # "good" cells: LN r_test at least 2 standard errors above zero
    cellids = df_r[LN] > df_e[LN] * 2
    gc_LN_SE = (df_e[gc] + df_e[LN])
    # GC-improved: gc model beats LN by more than their combined SE
    gc_cells = cellids & ((df_r[gc] - df_r[LN]) > gc_LN_SE)
    LN_cells = cellids & np.logical_not(gc_cells)

    meta = ['r_test', 'ctmax_val', 'ctmax_est', 'ctmin_val', 'ctmin_est']
    # BUGFIX: batch was previously hard-coded to 289 here even though the
    # function accepts a batch argument (sibling functions pass `batch`).
    gc_params = fitted_params_per_batch(batch, gc, stats_keys=[], meta=meta)

    # drop cellids that haven't been fit for all models
    gc_params_cells = gc_params.transpose().index.values.tolist()
    for c in gc_params_cells:
        if c not in LN_cells:
            LN_cells[c] = False
        if c not in gc_cells:
            gc_cells[c] = False

    # index keys are formatted like "4--dsig.d--kappa"
    mod_keys = gc.split('_')[1]
    for i, k in enumerate(mod_keys.split('-')):
        if 'dsig' in k:
            break
    k_key = f'{i}--{k}--kappa'
    ka_key = k_key + '_mod'
    meta_keys = ['meta--' + k for k in meta]
    all_keys = [k_key, ka_key] + meta_keys
    phi_dfs = [
        gc_params[gc_params.index == k].transpose()[LN_cells].transpose()
        for k in all_keys
    ]
    sep_dfs = [df.values.flatten().astype(np.float64) for df in phi_dfs]
    gc_dfs = [
        gc_params[gc_params.index == k].transpose()[gc_cells].transpose()
        for k in all_keys
    ]
    gc_sep_dfs = [df.values.flatten().astype(np.float64) for df in gc_dfs]

    low, high, r_test, ctmax_val, ctmax_est, ctmin_val, ctmin_est = sep_dfs
    gc_low, gc_high, gc_r, gc_ctmax_val, \
        gc_ctmax_est, gc_ctmin_val, gc_ctmin_est = gc_sep_dfs

    # contrast range actually spanned by est/val data
    ctmax = np.maximum(ctmax_val, ctmax_est)
    gc_ctmax = np.maximum(gc_ctmax_val, gc_ctmax_est)
    ctmin = np.minimum(ctmin_val, ctmin_est)
    gc_ctmin = np.minimum(gc_ctmin_val, gc_ctmin_est)

    # kappa evaluated at the contrast extremes
    k_low = low + (high - low) * ctmin
    k_high = low + (high - low) * ctmax
    gc_k_low = gc_low + (gc_high - gc_low) * gc_ctmin
    gc_k_high = gc_low + (gc_high - gc_low) * gc_ctmax
    if use_exp:
        # kappa is fit in log space; exponentiate before taking the ratio
        k_low = np.exp(k_low)
        k_high = np.exp(k_high)
        gc_k_low = np.exp(gc_k_low)
        gc_k_high = np.exp(gc_k_high)

    ratio = k_low / k_high
    gc_ratio = gc_k_low / gc_k_high

    fig1, ((ax1), (ax2)) = plt.subplots(1, 2, )
    ax1.hist(ratio, bins=bins)
    ax1.set_title('all cells')
    ax2.hist(gc_ratio, bins=bins)
    ax2.set_title('gc')
    if not use_exp:
        title = 'k_low / k_high'
    else:
        title = 'e^(k_low - k_high)'
    fig1.suptitle(title)

    fig3 = plt.figure()
    plt.scatter(ratio, r_test)
    plt.title('low/high vs r_test')
    fig4 = plt.figure()
    plt.scatter(gc_ratio, gc_r)
    plt.title('low/high vs r_test, gc improvements only')
    return fig1, fig3, fig4
def stp_distributions(batch, gc, stp, LN, combined, se_filter=True, good_ln=0,
                      log_scale=False, legend=False, use_combined=False):
    """Compare fitted STP parameter (tau, u) distributions.

    Splits cells into "STP-improved" (from improved_cells_to_list) and
    "not improved" groups, plots stacked histograms of per-cell mean tau
    and u for each group, text figures with Mann-Whitney U statistics,
    and an example STP magnitude trace built from the group medians.

    Returns (fig1, fig2, fig3, fig4, fig5): tau histograms, tau stats
    text, u histograms, u stats text, median-STP response trace.
    """
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    # NOTE(review): here get_filtered_cellids is called with
    # (batch, ...) while other functions in this file pass dataframes —
    # presumably both call signatures are supported; verify.
    cellids, under_chance, less_LN = get_filtered_cellids(batch, gc, stp, LN,
                                                          combined,
                                                          as_lists=False)
    # c: cells significantly improved by the combined/stp comparison
    _, _, _, _, c = improved_cells_to_list(batch, gc, stp, LN, combined,
                                           good_ln=good_ln)

    if use_combined:
        params_model = combined
    else:
        params_model = stp
    stp_params = fitted_params_per_batch(batch, params_model, stats_keys=[],
                                         meta=[])
    # drop cellids that haven't been fit for this model
    stp_params_cells = stp_params.transpose().index.values.tolist()
    for cell in stp_params_cells:
        if cell not in cellids:
            cellids[cell] = False
    # not_c: filtered cells that are NOT in the improved set
    not_c = list(set(stp_params.transpose()[cellids].index.values) - set(c))

    # index keys are formatted like "2--stp.2--tau"
    mod_keys = stp.split('_')[1]
    for i, k in enumerate(mod_keys.split('-')):
        if 'stp' in k:
            break
    tau_key = '%d--%s--tau' % (i, k)
    u_key = '%d--%s--u' % (i, k)

    all_taus = stp_params[stp_params.index == tau_key].transpose()[cellids].transpose()
    all_us = stp_params[stp_params.index == u_key].transpose()[cellids].transpose()
    # number of STP channels per cell (each cell stores a multi-dim value)
    dims = all_taus.values.flatten()[0].shape[0]

    # convert to dims x cells array instead of cells, array w/ multidim values
    #sep_taus = _df_to_array(all_taus, dims).mean(axis=0)
    #sep_us = _df_to_array(all_us, dims).mean(axis=0)
    #med_tau = np.median(sep_taus)
    #med_u = np.median(sep_u)
    # per-cell mean over STP channels, for each group
    sep_taus = _df_to_array(all_taus[not_c], dims).mean(axis=0)
    sep_us = _df_to_array(all_us[not_c], dims).mean(axis=0)
    med_tau = np.median(sep_taus)
    med_u = np.median(sep_us)
    stp_taus = all_taus[c]
    stp_us = all_us[c]
    stp_sep_taus = _df_to_array(stp_taus, dims).mean(axis=0)
    stp_sep_us = _df_to_array(stp_us, dims).mean(axis=0)
    stp_med_tau = np.median(stp_sep_taus)
    stp_med_u = np.median(stp_sep_us)
    #tau_t, tau_p = st.ttest_ind(sep_taus, stp_sep_taus)
    #u_t, u_p = st.ttest_ind(sep_us, stp_sep_us)
    # NOTE: not actually a t statistic now, it's mann-whitney U statistic,
    # just didn't want to change all of the var names incase i revert
    tau_t, tau_p = st.mannwhitneyu(sep_taus, stp_sep_taus,
                                   alternative='two-sided')
    u_t, u_p = st.mannwhitneyu(sep_us, stp_sep_us, alternative='two-sided')
    # outliers dropped for visualization only, AFTER stats are computed
    sep_taus, sep_us = drop_common_outliers(sep_taus, sep_us)
    stp_sep_taus, stp_sep_us = drop_common_outliers(stp_sep_taus, stp_sep_us)
    not_imp_outliers = len(sep_taus)
    imp_outliers = len(stp_sep_taus)

    fig1, (a1, a2) = plt.subplots(2, 1, sharex=True, sharey=True)
    color = model_colors['LN']
    imp_color = model_colors['max']
    stp_label = 'STP ++ (%d)' % len(c)
    total_cells = len(c) + len(not_c)
    bin_count = 30
    hist_kwargs = {'linewidth': 1, 'label': ['not imp', 'stp imp']}
    plt.sca(a1)
    # weights normalize each histogram to fraction-of-group
    weights1 = [np.ones(len(sep_taus)) / len(sep_taus)]
    weights2 = [np.ones(len(stp_sep_taus)) / len(stp_sep_taus)]
    # shared bins across both groups so the panels are comparable
    upper = max(sep_taus.max(), stp_sep_taus.max())
    lower = min(sep_taus.min(), stp_sep_taus.min())
    bins = np.linspace(lower, upper, bin_count + 1)
    # if log_scale:
    #     lower_bound = min(sep_taus.min(), stp_sep_taus.min())
    #     upper_bound = max(sep_taus.max(), stp_sep_taus.max())
    #     bins = np.logspace(lower_bound, upper_bound, bin_count+1)
    #     hist_kwargs['bins'] = bins
    # plt.hist([sep_taus, stp_sep_taus], weights=weights, **hist_kwargs)
    a1.hist(sep_taus, weights=weights1, fc=faded_LN, edgecolor=dark_LN,
            bins=bins, **hist_kwargs)
    a2.hist(stp_sep_taus, weights=weights2, fc=faded_max, edgecolor=dark_max,
            bins=bins, **hist_kwargs)
    # both group medians drawn on both panels for comparison
    a1.axes.axvline(med_tau, color=dark_LN, linewidth=2, linestyle='dashed',
                    dashes=dash_spacing)
    a1.axes.axvline(stp_med_tau, color=dark_max, linewidth=2,
                    linestyle='dashed', dashes=dash_spacing)
    a2.axes.axvline(med_tau, color=dark_LN, linewidth=2, linestyle='dashed',
                    dashes=dash_spacing)
    a2.axes.axvline(stp_med_tau, color=dark_max, linewidth=2,
                    linestyle='dashed', dashes=dash_spacing)
    ax_remove_box(a1)
    ax_remove_box(a2)
    #plt.title('tau, sig diff?: p=%.4E' % tau_p)
    #plt.xlabel('tau (ms)')

    # companion text figure with counts / medians / test results for tau
    fig2 = plt.figure(figsize=text_fig)
    text = ("tau distributions, n: %d\n"
            "n stp imp (bot): %d, med: %.4f\n"
            "n not imp (top): %d, med: %.4f\n"
            "yaxes: fraction of cells\n"
            "xaxis: tau(ms)\n"
            "st.mannwhitneyu u: %.4E,\np: %.4E\n"
            "not imp after outliers: %d\n"
            "imp after outliers: %d\n"
            % (total_cells, len(c), stp_med_tau, len(not_c), med_tau, tau_t,
               tau_p, not_imp_outliers, imp_outliers))
    plt.text(0.1, 0.5, text)

    # same layout as fig1, but for u
    fig3, (a3, a4) = plt.subplots(2, 1, sharex=True, sharey=True)
    weights3 = [np.ones(len(sep_us)) / len(sep_us)]
    weights4 = [np.ones(len(stp_sep_us)) / len(stp_sep_us)]
    upper = max(sep_us.max(), stp_sep_us.max())
    lower = min(sep_us.min(), stp_sep_us.min())
    bins = np.linspace(lower, upper, bin_count + 1)
    # if log_scale:
    #     lower_bound = min(sep_us.min(), stp_sep_us.min())
    #     upper_bound = max(sep_us.max(), stp_sep_us.max())
    #     bins = np.logspace(lower_bound, upper_bound, bin_count+1)
    #     hist_kwargs['bins'] = bins
    # plt.hist([sep_us, stp_sep_us], weights=weights, **hist_kwargs)
    a3.hist(sep_us, weights=weights3, fc=faded_LN, edgecolor=dark_LN,
            bins=bins, **hist_kwargs)
    a4.hist(stp_sep_us, weights=weights4, fc=faded_max, edgecolor=dark_max,
            bins=bins, **hist_kwargs)
    a3.axes.axvline(med_u, color=dark_LN, linewidth=2, linestyle='dashed',
                    dashes=dash_spacing)
    a3.axes.axvline(stp_med_u, color=dark_max, linewidth=2,
                    linestyle='dashed', dashes=dash_spacing)
    a4.axes.axvline(med_u, color=dark_LN, linewidth=2, linestyle='dashed',
                    dashes=dash_spacing)
    a4.axes.axvline(stp_med_u, color=dark_max, linewidth=2,
                    linestyle='dashed', dashes=dash_spacing)
    ax_remove_box(a3)
    ax_remove_box(a4)
    #plt.title('u, sig diff?: p=%.4E' % u_p)
    #plt.xlabel('u (fractional change in gain \nper unit of stimulus amplitude)')
    #plt.ylabel('proportion within group')

    # companion text figure for u
    fig4 = plt.figure(figsize=text_fig)
    text = ("u distributions, n: %d\n"
            "n stp imp (bot): %d, med: %.4f\n"
            "n not imp (top): %d, med: %.4f\n"
            "yaxes: fraction of cells\n"
            "xaxis: u(fractional change in gain per unit stimulus amplitude)\n"
            "st.mannwhitneyu u: %.4E,\np: %.4E"
            % (total_cells, len(c), stp_med_u, len(not_c), med_u, u_t, u_p))
    plt.text(0.1, 0.5, text)

    # simulate STP response for the median (tau, u) of each group
    stp_mag, stp_yin, stp_out = stp_magnitude(np.array([[stp_med_tau]]),
                                              np.array([[stp_med_u]]))
    mag, yin, out = stp_magnitude(np.array([[med_tau]]), np.array([[med_u]]))
    fig5 = plt.figure(figsize=short_fig)
    plt.plot(stp_out.as_continuous().flatten(), color=imp_color,
             label='STP ++')
    plt.plot(out.as_continuous().flatten(), color=color)
    if legend:
        plt.legend()
    ax_remove_box()

    return fig1, fig2, fig3, fig4, fig5
def gc_distributions(batch, gc, stp, LN, combined, se_filter=True, good_ln=0,
                     use_combined=False):
    """Compare fitted GC sigmoid-parameter distributions.

    For each dsig parameter (base, amplitude, shift, kappa) computes the
    high-contrast minus low-contrast difference per cell, compares
    GC-improved vs not-improved cells with Mann-Whitney U tests, and
    builds stacked histograms plus median dummy-sigmoid visualizations.

    Returns a list of figures: per-parameter (histogram, stats text)
    pairs, then the non-improved median sigmoid + text, then the
    improved median sigmoid + text.
    """
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    # NOTE(review): get_filtered_cellids is called with (batch, ...) here
    # but with dataframes elsewhere in this file — presumably both call
    # signatures are supported; verify.
    cellids, under_chance, less_LN = get_filtered_cellids(batch, gc, stp, LN,
                                                          combined,
                                                          as_lists=False)
    # c: cells significantly improved by the combined comparison
    _, _, _, _, c = improved_cells_to_list(batch, gc, stp, LN, combined,
                                           good_ln=good_ln)

    if use_combined:
        params_model = combined
    else:
        params_model = gc
    gc_params = fitted_params_per_batch(batch, params_model, stats_keys=[],
                                        meta=[])
    # drop cellids that haven't been fit for this model
    gc_params_cells = gc_params.transpose().index.values.tolist()
    for cell in gc_params_cells:
        if cell not in cellids:
            cellids[cell] = False
    # not_c: filtered cells that are NOT in the improved set
    not_c = list(set(gc_params.transpose()[cellids].index.values) - set(c))

    # index keys are formatted like "4--dsig.d--kappa"
    mod_keys = params_model.split('_')[1]
    for i, k in enumerate(mod_keys.split('-')):
        if 'dsig' in k:
            break
    b_key = f'{i}--{k}--base'
    a_key = f'{i}--{k}--amplitude'
    s_key = f'{i}--{k}--shift'
    k_key = f'{i}--{k}--kappa'
    # "_mod" keys hold the parameter values under high contrast
    ka_key = k_key + '_mod'
    ba_key = b_key + '_mod'
    aa_key = a_key + '_mod'
    sa_key = s_key + '_mod'
    all_keys = [b_key, a_key, s_key, k_key, ba_key, aa_key, sa_key, ka_key]

    phi_dfs = [
        gc_params[gc_params.index == k].transpose()[cellids].transpose()
        for k in all_keys
    ]
    sep_dfs = [df[not_c].values.flatten().astype(np.float64) for df in phi_dfs]
    gc_sep_dfs = [df[c].values.flatten().astype(np.float64) for df in phi_dfs]

    # removing extreme outliers b/c kept getting one or two cells with
    # values that were multiple orders of magnitude different than all others
    # diffs = [sep_dfs[i+1] - sep_dfs[i]
    #          for i, _ in enumerate(sep_dfs[:-1])
    #          if i % 2 == 0]
    #diffs = sep_dfs[1::2] - sep_dfs[::2]
    # gc_diffs = [gc_sep_dfs[i+1] - gc_sep_dfs[i]
    #             for i, _ in enumerate(gc_sep_dfs[:-1])
    #             if i % 2 == 0]
    #gc_diffs = gc_sep_dfs[1::2] - gc_sep_dfs[::2]
    # first 4 entries are low-contrast params, last 4 their "_mod"
    # (high-contrast) counterparts, in the same order
    raw_low, raw_high = sep_dfs[:4], sep_dfs[4:]
    diffs = [high - low for low, high in zip(raw_low, raw_high)]
    medians = [np.median(d) for d in diffs]
    medians_low = [np.median(d) for d in raw_low]
    medians_high = [np.median(d) for d in raw_high]
    gc_raw_low, gc_raw_high = gc_sep_dfs[:4], gc_sep_dfs[4:]
    gc_diffs = [high - low for low, high in zip(gc_raw_low, gc_raw_high)]
    gc_medians = [np.median(d) for d in gc_diffs]
    gc_medians_low = [np.median(d) for d in gc_raw_low]
    gc_medians_high = [np.median(d) for d in gc_raw_high]
    # per-parameter significance tests, computed BEFORE outlier removal
    ts, ps = zip(*[
        st.mannwhitneyu(diff, gc_diff, alternative='two-sided')
        for diff, gc_diff in zip(diffs, gc_diffs)
    ])
    # outliers dropped for visualization only
    diffs = drop_common_outliers(*diffs)
    gc_diffs = drop_common_outliers(*gc_diffs)
    not_imp_outliers = len(diffs[0])
    imp_outliers = len(gc_diffs[0])

    color = model_colors['LN']
    c_color = model_colors['max']
    gc_label = 'GC ++ (%d)' % len(c)
    total_cells = len(c) + len(not_c)
    hist_kwargs = {'label': ['no imp', 'sig imp'], 'linewidth': 1}
    figs = []
    # one histogram figure + one stats-text figure per sigmoid parameter
    for i, name in zip([0, 1, 2, 3], ['base', 'amplitude', 'shift', 'kappa']):
        f1 = _stacked_hists(diffs[i], gc_diffs[i], medians[i], gc_medians[i],
                            color, c_color, hist_kwargs=hist_kwargs)
        f2 = plt.figure(figsize=text_fig)
        text = ("%s distributions, n: %d\n"
                "n gc imp (bot): %d, med: %.4f\n"
                "n not imp (top): %d, med: %.4f\n"
                "yaxes: fraction of cells\n"
                "xaxis: 'fractional change in parameter per unit contrast'\n"
                "st.mannwhitneyu u: %.4E,\np: %.4E\n"
                "not imp w/o outliers: %d\n"
                "imp w/o outliers: %d" %
                (name, total_cells, len(c), gc_medians[i], len(not_c),
                 medians[i], ts[i], ps[i], not_imp_outliers, imp_outliers))
        plt.text(0.1, 0.5, text)
        figs.append(f1)
        figs.append(f2)

    f3 = plt.figure(figsize=small_fig)
    # median gc effect plots
    yin1, out1 = gc_dummy_sigmoid(*medians_low, low=0.0, high=0.3)
    yin2, out2 = gc_dummy_sigmoid(*medians_high, low=0.0, high=0.3)
    plt.scatter(yin1, out1, color=color, s=big_scatter, alpha=0.3)
    plt.scatter(yin2, out2, color=color, s=big_scatter * 2)
    figs.append(f3)
    plt.tight_layout()
    ax_remove_box()

    f3a = plt.figure(figsize=text_fig)
    text = ("non improved cells\n"
            "median low contrast:\n"
            "base: %.4f, amplitude: %.4f\n"
            "shift: %.4f, kappa: %.4f\n"
            "median high contrast:\n"
            "base: %.4f, amplitude: %.4f\n"
            "shift: %.4f, kappa: %.4f\n" % (*medians_low, *medians_high))
    plt.text(0.1, 0.5, text)
    figs.append(f3a)

    f4 = plt.figure(figsize=small_fig)
    gc_yin1, gc_out1 = gc_dummy_sigmoid(*gc_medians_low, low=0.0, high=0.3)
    gc_yin2, gc_out2 = gc_dummy_sigmoid(*gc_medians_high, low=0.0, high=0.3)
    plt.scatter(gc_yin1, gc_out1, color=c_color, s=big_scatter, alpha=0.3)
    plt.scatter(gc_yin2, gc_out2, color=c_color, s=big_scatter * 2)
    figs.append(f4)
    plt.tight_layout()
    ax_remove_box()

    f4a = plt.figure(figsize=text_fig)
    text = ("improved cells\n"
            "median low contrast:\n"
            "base: %.4f, amplitude: %.4f\n"
            "shift: %.4f, kappa: %.4f\n"
            "median high contrast:\n"
            "base: %.4f, amplitude: %.4f\n"
            "shift: %.4f, kappa: %.4f\n" % (*gc_medians_low, *gc_medians_high))
    plt.text(0.1, 0.5, text)
    figs.append(f4a)

    return figs
def relative_bar_comparison(batch1, batch2, gc, stp, LN, combined,
                            se_filter=True, ln_filter=False,
                            plot_stat='r_ceiling', only_improvements=False,
                            good_ln=0.0):
    """Bar plot of mean relative-to-LN performance for two batches.

    Plots mean (model - LN) performance for GC, STP, and GC+STP in
    `batch1` (bars 1-3) and `batch2` (bars 4-6).

    Raises
    ------
    ValueError
        Always, for now: the cellid filtering below is known-stale and
        must be fixed before this plot can be trusted.
    """
    # Deliberate guard left by the author — do not remove until the
    # filtering logic below is updated.
    raise ValueError('fix cellid filters before using me')
    df_r1, df_c1, df_e1 = get_dataframes(batch1, gc, stp, LN, combined)
    cellids1, under_chance1, less_LN1 = get_filtered_cellids(
        df_r1, df_e1, gc, stp, LN, combined, se_filter, ln_filter)
    df_r2, df_c2, df_e2 = get_dataframes(batch2, gc, stp, LN, combined)
    cellids2, under_chance2, less_LN2 = get_filtered_cellids(
        df_r2, df_e2, gc, stp, LN, combined, se_filter, ln_filter)

    if only_improvements:
        # only use cells for which there was a significant improvement
        # to one or more models
        e1, n1, g1, s1, c1 = improved_cells_to_list(batch1, gc, stp, LN,
                                                    combined, good_ln=good_ln)
        filter1 = list(set(e1) | set(g1) | set(s1) | set(c1))
        e2, n2, g2, s2, c2 = improved_cells_to_list(batch2, gc, stp, LN,
                                                    combined, good_ln=good_ln)
        filter2 = list(set(e2) | set(g2) | set(s2) | set(c2))
        cellids1 = [c for c in cellids1 if c in filter1]
        cellids2 = [c for c in cellids2 if c in filter2]

    if plot_stat == 'r_ceiling':
        plot_df1 = df_c1
        plot_df2 = df_c2
    else:
        plot_df1 = df_r1
        plot_df2 = df_r2

    n_cells1 = len(cellids1)
    gc_test1 = plot_df1[gc][cellids1]
    stp_test1 = plot_df1[stp][cellids1]
    ln_test1 = plot_df1[LN][cellids1]
    gc_stp_test1 = plot_df1[combined][cellids1]
    n_cells2 = len(cellids2)
    gc_test2 = plot_df2[gc][cellids2]
    stp_test2 = plot_df2[stp][cellids2]
    ln_test2 = plot_df2[LN][cellids2]
    gc_stp_test2 = plot_df2[combined][cellids2]

    # mean improvement over LN, per model, per batch
    gc_rel1 = gc_test1 - ln_test1
    stp_rel1 = stp_test1 - ln_test1
    gc_stp_rel1 = gc_stp_test1 - ln_test1
    gc1 = np.mean(gc_rel1.values)
    stp1 = np.mean(stp_rel1.values)
    gc_stp1 = np.mean(gc_stp_rel1.values)
    gc_rel2 = gc_test2 - ln_test2
    stp_rel2 = stp_test2 - ln_test2
    gc_stp_rel2 = gc_stp_test2 - ln_test2
    gc2 = np.mean(gc_rel2.values)
    stp2 = np.mean(stp_rel2.values)
    gc_stp2 = np.mean(gc_stp_rel2.values)

    fig = plt.figure(figsize=(15, 12))
    plt.bar([1, 2, 3, 4, 5, 6], [gc1, stp1, gc_stp1, gc2, stp2, gc_stp2],
            color=[gc_color, stp_color, gc_stp_color,
                   gc_color, stp_color, gc_stp_color],
            edgecolor="black", linewidth=2)
    # BUGFIX: previously passed 5 tick positions for 6 labels/bars,
    # which mismatches the FixedLocator and breaks/misaligns the labels.
    plt.xticks([1, 2, 3, 4, 5, 6],
               ['GC %d' % batch1, 'STP', 'GC+STP',
                'GC %d' % batch2, 'STP', 'GC+STP'])

    # value labels drawn in white inside each bar
    common_kwargs = {'color': 'white', 'horizontalalignment': 'center'}
    y_text = 0.2
    plt.text(1, y_text, "%0.04f" % gc1, **common_kwargs)
    plt.text(2, y_text, "%0.04f" % stp1, **common_kwargs)
    plt.text(3, y_text, "%0.04f" % gc_stp1, **common_kwargs)
    plt.text(4, y_text, "%0.04f" % gc2, **common_kwargs)
    plt.text(5, y_text, "%0.04f" % stp2, **common_kwargs)
    plt.text(6, y_text, "%0.04f" % gc_stp2, **common_kwargs)
    plt.title("Mean Relative (to LN) Performance for GC, STP, and GC+STP\n"
              "batch %d, n: %d vs batch %d, n: %d" %
              (batch1, n_cells1, batch2, n_cells2))
    return fig
def significance(batch, gc, stp, LN, combined, manual_cellids=None,
                 plot_stat='r_ceiling', include_legend=True,
                 only_improvements=False):
    '''
    Grid plot of pairwise Wilcoxon signed-rank comparisons between
    model performance distributions.

    model1: GC
    model2: STP
    model3: LN
    model4: GC+STP

    Below-diagonal cells show the Wilcoxon W statistic, above-diagonal
    cells the p-value (color-coded by significance level). Returns the
    figure. Note: `manual_cellids` and `plot_stat` are currently unused
    in the body (df_r is always used).
    '''
    # NOTE: The comparison of max(gc, stp) to gc/stp should be
    # ignored. They're known to be different by definition
    # and there's a bug in the scipy code that causes the
    # W-statistic to be reported as 0.
    # This happens because all of the differences are one-sided
    # and the scipy code takes the minimum of either positive
    # or negative differences, one of which will always be 0
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    e, a, g, s, c = improved_cells_to_list(batch, gc, stp, LN, combined)
    cellids = a
    gc_test = df_r[gc][cellids]
    stp_test = df_r[stp][cellids]
    ln_test = df_r[LN][cellids]
    gc_stp_test = df_r[combined][cellids]
    # per-cell best of gc/stp
    max_test = np.maximum(gc_test, stp_test)
    modelnames = ['GC', 'STP', 'LN', 'GC + STP', 'Max(GC,STP)']
    models = {
        'GC': gc_test,
        'STP': stp_test,
        'LN': ln_test,
        'GC + STP': gc_stp_test,
        'Max(GC,STP)': max_test
    }
    # uninitialized; every entry is assigned in the loop below
    array = np.ndarray(shape=(len(modelnames), len(modelnames)), dtype=float)
    for i, m_one in enumerate(modelnames):
        for j, m_two in enumerate(modelnames):
            # get series of values corresponding to selected measure
            # for each model
            series_one = models[m_one]
            series_two = models[m_two]
            # TODO: no reason to convert these to lists anymore?
            first = series_one.tolist()
            second = series_two.tolist()
            # w, p only defined off-diagonal; diagonal branch below
            # never reads them
            if j != i:
                w, p = st.wilcoxon(first, second)
            if j == i:
                # if indices equal, on diagonal so no comparison
                array[i][j] = 0.00
            elif j > i:
                # if j is larger, below diagonal so get mean difference
                array[i][j] = w
            else:
                # if j is smaller, above diagonal so run t-test and
                # get p-value
                array[i][j] = p

    xticks = range(len(modelnames))
    yticks = xticks
    # minor ticks at cell boundaries, used to draw the grid lines
    minor_xticks = np.arange(-0.5, len(modelnames), 1)
    minor_yticks = np.arange(-0.5, len(modelnames), 1)
    fig = plt.figure(figsize=(12, 12))
    ax = plt.gca()
    # ripped from stackoverflow. adds text labels to the grid
    # at positions i,j (model x model) with text z (value of array at i, j)
    for (i, j), z in np.ndenumerate(array):
        # cell colors: gray diagonal, blue W statistic,
        # green shades by p-value significance, gray non-significant
        if j == i:
            color = "#EBEBEB"
        elif j > i:
            color = "#368DFF"
        else:
            if array[i][j] < 0.001:
                color = "#74E572"
            elif array[i][j] < 0.01:
                color = "#59AF57"
            elif array[i][j] < 0.05:
                color = "#397038"
            else:
                color = "#ABABAB"
        ax.add_patch(
            mpatch.Rectangle(
                xy=(j - 0.5, i - 0.5),
                width=1.0,
                height=1.0,
                angle=0.0,
                facecolor=color,
                edgecolor='black',
            ))
        if j == i:
            # don't draw text for diagonal
            continue
        # formatting = '{:.04f}'
        # if z <= 0.0001:
        #     formatting = '{:.2E}'
        formatting = '{:.2E}'
        ax.text(
            j,
            i,
            formatting.format(z),
            ha='center',
            va='center',
        )
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.set_yticks(yticks)
    ax.set_yticklabels(modelnames, fontsize=10)
    ax.set_xticks(xticks)
    ax.set_xticklabels(modelnames, fontsize=10, rotation="vertical")
    ax.set_yticks(minor_yticks, minor=True)
    ax.set_xticks(minor_xticks, minor=True)
    # NOTE(review): grid(b=...) is deprecated in newer matplotlib
    # (renamed to `visible`) — confirm the pinned matplotlib version.
    ax.grid(b=False)
    ax.grid(which='minor', color='b', linestyle='-', linewidth=0.75)
    title = "Wilcoxon Signed Test\nOnly improvements?: %s" % only_improvements
    ax.set_title(title, ha='center', fontsize=14)
    if include_legend:
        blue_patch = mpatch.Patch(color='#368DFF', label='W statistic',
                                  edgecolor='black')
        p001_patch = mpatch.Patch(color='#74E572', label='P < 0.001',
                                  edgecolor='black')
        p01_patch = mpatch.Patch(color='#59AF57', label='P < 0.01',
                                 edgecolor='black')
        p05_patch = mpatch.Patch(color='#397038', label='P < 0.05',
                                 edgecolor='black')
        nonsig_patch = mpatch.Patch(
            color='#ABABAB',
            label='Not Significant',
            edgecolor='black',
        )
        plt.legend(
            #bbox_to_anchor=(0., 1.02, 1., .102), ncol=2,
            bbox_to_anchor=(1.05, 1),
            ncol=1,
            loc=2,
            handles=[
                p05_patch,
                p01_patch,
                p001_patch,
                nonsig_patch,
                blue_patch,
            ])
    plt.tight_layout()
    return fig
def equivalence_effect_size(batch, gc, stp, LN, combined, se_filter=True,
                            LN_filter=False, save_path=None, load_path=None,
                            test_limit=None, only_improvements=False,
                            legend=False, effect_key='performance_effect',
                            equiv_key='partial_corr', enable_hover=False,
                            plot_stat='r_ceiling', highlight_cells=None):
    '''Scatter per-cell "equivalence" of GC vs STP against their effect size.

    For each cell, loads fitted GC / STP / LN model predictions, then computes:
      - equivalence: corrcoef(GC-LN, STP-LN) on the prediction time series
      - partial_corr: partial correlation of GC vs STP predictions given LN
      - effect_size: 1 - 0.5*(CC(GC,LN) + CC(STP,LN))
      - performance_effect: mean relative score improvement over LN
    Results are cached to/from a pickled DataFrame via save_path / load_path
    (load_path takes precedence: when given, the per-cell loop is skipped).

    Parameters
    ----------
    batch : int
        Data batch id.
    gc, stp, LN, combined : str
        Model name keys.
    se_filter, LN_filter : bool
        Forwarded to improved_cells_to_list for cell selection.
    save_path, load_path : str or None
        Pickle cache for the computed per-cell DataFrame.
    test_limit : int or None
        Truncate the cell list (debugging aid); None processes all cells.
    only_improvements : bool
        If True, plot only significantly-improved cells and add per-group
        correlation stats to the title text.
    legend, enable_hover, highlight_cells
        Plot cosmetics: legend toggle, mplcursors hover labels, and a list of
        cellids to circle on the scatter.
    effect_key, equiv_key : str
        Which DataFrame columns to use for the x and y axes.
    plot_stat : str
        'r_ceiling' uses df_c, otherwise df_r, for the relative-score columns.

    Returns
    -------
    (fig1, fig3) : matplotlib figures
        fig1 is the scatter; fig3 is a text-only figure with the stats title.
    '''
    e, a, g, s, c = improved_cells_to_list(batch, gc, stp, LN, combined,
                                           se_filter=se_filter,
                                           LN_filter=LN_filter)
    # Same selection but as_lists=False; presumably a boolean mask / index
    # used for row-filtering the DataFrame below — TODO confirm against
    # improved_cells_to_list.
    _, cellids, _, _, _ = improved_cells_to_list(batch, gc, stp, LN, combined,
                                                 se_filter=se_filter,
                                                 LN_filter=LN_filter,
                                                 as_lists=False)
    if load_path is None:
        # No cache: compute per-cell metrics from the fitted models.
        equivs = []
        partials = []
        gcs = []
        stps = []
        effects = []
        for cell in a[:test_limit]:
            # Load fitted GC, STP, and LN model contexts for this cell.
            xf1, ctx1 = xhelp.load_model_xform(cell, batch, gc)
            xf2, ctx2 = xhelp.load_model_xform(cell, batch, stp)
            xf3, ctx3 = xhelp.load_model_xform(cell, batch, LN)
            gc_pred = ctx1['val'].apply_mask()['pred'].as_continuous()
            stp_pred = ctx2['val'].apply_mask()['pred'].as_continuous()
            ln_pred = ctx3['val'].apply_mask()['pred'].as_continuous()
            # Keep only time bins where all three predictions are finite.
            ff = np.isfinite(gc_pred) & np.isfinite(stp_pred) & np.isfinite(
                ln_pred)
            gcff = gc_pred[ff]
            stpff = stp_pred[ff]
            lnff = ln_pred[ff]
            # Column-stack the three prediction vectors for partial_corr;
            # [0, 1] is GC vs STP controlling for LN.
            C = np.hstack((np.expand_dims(gcff, 0).transpose(),
                           np.expand_dims(stpff, 0).transpose(),
                           np.expand_dims(lnff, 0).transpose()))
            partials.append(partial_corr(C)[0, 1])
            # Equivalence: do GC and STP change the LN prediction in the
            # same way?
            equivs.append(np.corrcoef(gcff - lnff, stpff - lnff)[0, 1])
            this_gc = np.corrcoef(gcff, lnff)[0, 1]
            this_stp = np.corrcoef(stpff, lnff)[0, 1]
            gcs.append(this_gc)
            stps.append(this_stp)
            # Effect size: how much either model departs from LN on average.
            effects.append(1 - 0.5 * (this_gc + this_stp))
        df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
        if plot_stat == 'r_ceiling':
            plot_df = df_c
        else:
            plot_df = df_r
        models = [gc, stp, LN]
        gc_rel_all, stp_rel_all = _relative_score(plot_df, models, a)
        results = {
            'cellid': a[:test_limit],
            'equivalence': equivs,
            'effect_size': effects,
            'corr_gc_LN': gcs,
            'corr_stp_LN': stps,
            'partial_corr': partials,
            'performance_effect': 0.5 * (gc_rel_all[:test_limit]
                                         + stp_rel_all[:test_limit])
        }
        df = pd.DataFrame.from_dict(results)
        df.set_index('cellid', inplace=True)
        if save_path is not None:
            df.to_pickle(save_path)
    else:
        # Cache hit: skip all model loading and reuse the pickled DataFrame.
        df = pd.read_pickle(load_path)
    # NOTE(review): relies on `cellids` behaving as a row filter here
    # (boolean mask indexing); with a plain list of ids this would select
    # columns instead — confirm as_lists=False returns a mask.
    df = df[cellids]
    equivalence = df[equiv_key].values
    effect_size = df[effect_key].values
    r, p = st.pearsonr(effect_size, equivalence)
    improved = c
    not_improved = list(set(a) - set(c))
    fig1, ax = plt.subplots(1, 1)
    # Reference line at equivalence == 0.
    ax.axes.axhline(0, color='black', linewidth=1, linestyle='dashed',
                    dashes=dash_spacing)
    ax_remove_box(ax)
    extra_title_lines = []
    if only_improvements:
        # Split stats by significantly-improved vs not-improved cells.
        equivalence_imp = df[equiv_key][improved].values
        equivalence_not = df[equiv_key][not_improved].values
        effectsize_imp = df[effect_key][improved].values
        effectsize_not = df[effect_key][not_improved].values
        r_imp, p_imp = st.pearsonr(effectsize_imp, equivalence_imp)
        r_not, p_not = st.pearsonr(effectsize_not, equivalence_not)
        n_imp = len(improved)
        n_not = len(not_improved)
        lines = [
            "improved cells, r: %.4f, p: %.4E" % (r_imp, p_imp),
            "not improved, r: %.4f, p: %.4E" % (r_not, p_not)
        ]
        extra_title_lines.extend(lines)
        # plt.scatter(effectsize_not, equivalence_not, s=small_scatter,
        #             color=model_colors['LN'], label='no imp')
        plt.scatter(effectsize_imp, equivalence_imp, s=big_scatter,
                    color=model_colors['max'], label='sig. imp.')
        if enable_hover:
            # Interactive hover labels showing the cellid for each point.
            mplcursors.cursor(ax, hover=True).connect(
                "add",
                lambda sel: sel.annotation.set_text(improved[sel.target.
                                                             index]))
        if legend:
            plt.legend()
    else:
        plt.scatter(effect_size, equivalence, s=big_scatter, color='black')
        if enable_hover:
            mplcursors.cursor(ax, hover=True).connect(
                "add",
                lambda sel: sel.annotation.set_text(a[sel.target.index]))
    if highlight_cells is not None:
        # NOTE(review): comprehension variable `df` shadows the outer df
        # inside these list comprehensions (Python 3 comprehension scope, so
        # the outer df is unaffected afterwards).
        highlights = [df[df.index == h] for h in highlight_cells]
        equiv_highlights = [df[equiv_key].values for df in highlights]
        effect_highlights = [df[effect_key].values for df in highlights]
        for eq, eff in zip(equiv_highlights, effect_highlights):
            # Draw hollow circles around the highlighted cells.
            plt.scatter(eff, eq, s=big_scatter * 10, facecolors='none',
                        edgecolors='black', linewidths=1)
    #plt.ylabel('Equivalence: CC(GC-LN, STP-LN)')
    #plt.xlabel('Effect size: 1 - 0.5*(CC(GC,LN) + CC(STP,LN))')
    # Human-readable axis descriptions for the stats/title text figure.
    if equiv_key == 'equivalence':
        y_text = 'equivalence, CC(GC-LN, STP-LN)\n'
    elif equiv_key == 'partial_corr':
        y_text = 'equivalence, partial correlation\n'
    else:
        y_text = 'unknown equivalence key'
    if effect_key == 'effect_size':
        x_text = 'effect size: 1 - 0.5*(CC(GC,LN) + CC(STP,LN))'
    elif effect_key == 'performance_effect':
        x_text = 'effect size: 0.5*(rGC-rLN + rSTP-rLN)'
    else:
        x_text = 'unknown effect key'
    text = ("scatter: Equivalence of Change to Predicted PSTH\n"
            "batch: %d\n"
            "vs Effect Size\n"
            "all cells, r: %.4f, p: %.4E\n"
            "y: %s"
            "x: %s" % (batch, r, p, y_text, x_text))
    for ln in extra_title_lines:
        text += "\n%s" % ln
    plt.tight_layout()
    # fig2 = plt.figure()
    # md = np.nanmedian(equivalence)
    # n_cells = equivalence.shape[0]
    # plt.hist(equivalence, bins=30, range=[-0.5, 1], histtype='bar',
    #          color=model_colors['combined'], edgecolor='black', linewidth=1)
    # plt.plot(np.array([0,0]), np.array(fig2.axes[0].get_ylim()), 'k--',
    #          linewidth=2, dashes=dash_spacing)
    # plt.text(0.05, 0.95, 'n = %d\nmd = %.2f' % (n_cells, md),
    #          ha='left', va='top', transform=fig2.axes[0].transAxes)
    #plt.xlabel('CC, GC-LN vs STP-LN')
    #plt.title('Equivalence of Change in Prediction Relative to LN Model')
    # Separate figure holding only the stats text (kept off the scatter).
    fig3 = plt.figure()
    plt.text(0.1, 0.75, text, va='top')
    #plt.text(0.1, 0.25, text2, va='top')
    return fig1, fig3