def load_snr_and_cellids(a1, peg, a1_snr_path, peg_snr_path):
    a1_significant_cells = get_significant_cells(a1, SIG_TEST_MODELS, as_list=True)
    a1_results = nd.batch_comp(batch=a1, modelnames=SIG_TEST_MODELS, stat=PLOT_STAT,
                               cellids=a1_significant_cells)
    a1_index = a1_results.index.values

    peg_significant_cells = get_significant_cells(peg, SIG_TEST_MODELS, as_list=True)
    peg_results = nd.batch_comp(batch=peg, modelnames=SIG_TEST_MODELS, stat=PLOT_STAT,
                                cellids=peg_significant_cells)
    peg_index = peg_results.index.values

    a1_snr_df = snr_by_batch(a1, 'ozgf.fs100.ch18', load_path=a1_snr_path).loc[a1_index]
    peg_snr_df = snr_by_batch(peg, 'ozgf.fs100.ch18', load_path=peg_snr_path).loc[peg_index]

    a1_snr = a1_snr_df.values.flatten()
    a1_cellids = a1_snr_df.index
    peg_snr = peg_snr_df.values.flatten()
    # sort PEG cells by SNR so that downstream matching is deterministic
    peg_idx = np.argsort(peg_snr_df.values, axis=None)
    peg_snr_sample = peg_snr[peg_idx]
    peg_cellids = peg_snr_df.index[peg_idx].values

    return a1_snr, peg_snr_sample, a1_snr_df, peg_snr_df, a1_cellids, peg_cellids
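# A minimal, self-contained sketch of the argsort step above: sorting the SNR
# frame's values and reindexing the cellids in that order. Toy values only;
# real data comes from snr_by_batch.
def _example_snr_sort():
    import numpy as np
    import pandas as pd
    snr_df = pd.DataFrame({'snr': [0.3, 0.1, 0.2]}, index=['c1', 'c2', 'c3'])
    idx = np.argsort(snr_df.values, axis=None)
    return snr_df.index[idx].values.tolist()  # -> ['c2', 'c3', 'c1']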
def get_significant_cells(batch, models, as_list=False):
    df_r = nd.batch_comp(batch, models, stat='r_test')
    df_r.dropna(axis=0, how='any', inplace=True)
    df_r.sort_index(inplace=True)
    df_e = nd.batch_comp(batch, models, stat='se_test')
    df_e.dropna(axis=0, how='any', inplace=True)
    df_e.sort_index(inplace=True)
    df_f = nd.batch_comp(batch, models, stat='r_floor')
    df_f.dropna(axis=0, how='any', inplace=True)
    df_f.sort_index(inplace=True)

    # a cell is significant for a model if r_test beats both 2*se_test and r_floor
    masks = []
    for m in models:
        mask1 = df_r[m] > df_e[m] * 2
        mask2 = df_r[m] > df_f[m]
        masks.append(mask1 & mask2)

    # keep only cells that pass the test for every model
    all_significant = masks[0]
    for m in masks[1:]:
        all_significant &= m

    if as_list:
        all_significant = all_significant[all_significant].index.values.tolist()

    return all_significant
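# A minimal, self-contained sketch of the significance criterion above: a cell
# is kept only if, for every model, r_test exceeds both 2*se_test and r_floor.
# The numbers below are illustrative, not real fit results.
def _example_significance_criterion():
    import pandas as pd
    r_test = pd.DataFrame({'m1': [0.50, 0.10], 'm2': [0.40, 0.30]},
                          index=['cellA', 'cellB'])
    se_test = pd.DataFrame({'m1': [0.05, 0.08], 'm2': [0.05, 0.05]},
                           index=['cellA', 'cellB'])
    r_floor = pd.DataFrame({'m1': [0.02, 0.02], 'm2': [0.02, 0.02]},
                           index=['cellA', 'cellB'])
    keep = ((r_test > 2 * se_test) & (r_test > r_floor)).all(axis=1)
    return keep[keep].index.tolist()  # -> ['cellA']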
def plot_heldout_a1_vs_peg(a1_snr_path, peg_snr_path, ax=None):
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=single_column_short)

    a1_snr, peg_snr, a1_snr_df, peg_snr_df, a1_cellids, peg_cellids = \
        load_snr_and_cellids(322, 323, a1_snr_path, peg_snr_path)
    cell_map_df = get_matched_snr_mapping(a1_snr, peg_snr, a1_snr_df, peg_snr_df,
                                          a1_cellids, peg_cellids)
    peg_matched_cellids = cell_map_df.PEG_cellid

    peg_r = nd.batch_comp(323, [HELDOUT_CROSSBATCH], stat=PLOT_STAT).loc[peg_matched_cellids]
    peg_r2 = nd.batch_comp(323, HELDOUT[:1], stat=PLOT_STAT).loc[peg_matched_cellids]
    test_within_peg = st.wilcoxon(peg_r.values.flatten(), peg_r2.values.flatten(),
                                  alternative='two-sided')

    peg_r = peg_r.rename(columns={HELDOUT_CROSSBATCH: 'cross-batch held'})
    peg_r2 = peg_r2.rename(columns={HELDOUT[0]: 'PEG held'})
    combined_r = peg_r.join(peg_r2, how='outer')

    # an earlier version also plotted the matched A1 cells and colored points
    # by SNR, but the coloring was hard to see and showed no pattern
    sns.stripplot(data=combined_r, color=DOT_COLORS['1Dx2-CNN'], size=2, ax=ax, zorder=0)
    plt.xticks(rotation=45, fontsize=6, ha='right')
    sns.boxplot(data=combined_r,
                boxprops={'facecolor': 'None', 'linewidth': 1},
                showcaps=False, showfliers=False,
                whiskerprops={'linewidth': 0}, ax=ax)
    plt.tight_layout()

    return peg_r, peg_r2, test_within_peg
def plot_pred_scatter(batch, modelnames, labels=None, colors=None, ax=None):
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure
    if labels is None:
        labels = ['model 1', 'model 2']
    if colors is None:
        colors = ['black', 'black']

    significant_cells = get_significant_cells(batch, SIG_TEST_MODELS, as_list=True)
    sig_scores = nd.batch_comp(batch, modelnames, cellids=significant_cells, stat='r_test')
    se_scores = nd.batch_comp(batch, modelnames, cellids=significant_cells, stat='se_test')
    ceiling_scores = nd.batch_comp(batch, modelnames, cellids=significant_cells, stat=PLOT_STAT)
    nonsig_cells = list(set(sig_scores.index) - set(significant_cells))

    # flag units whose error bars (r_test +/- se_test) do not overlap between models
    sig = (sig_scores[modelnames[1]] - se_scores[modelnames[1]] > sig_scores[modelnames[0]] + se_scores[modelnames[0]]) | \
          (sig_scores[modelnames[0]] - se_scores[modelnames[0]] > sig_scores[modelnames[1]] + se_scores[modelnames[1]])

    group1 = (ceiling_scores.loc[~sig, modelnames[0]].values,
              ceiling_scores.loc[~sig, modelnames[1]].values)
    group2 = (ceiling_scores.loc[sig, modelnames[0]].values,
              ceiling_scores.loc[sig, modelnames[1]].values)
    n_nonsig = group1[0].size
    n_sig = group2[0].size

    scatter_groups([group1, group2], ['lightgray', 'black'], ax=ax,
                   labels=['N.S.', 'p < 0.05'])
    ax.set_xlabel(f'{labels[0]}\n(median r={ceiling_scores[modelnames[0]].median():.3f})',
                  color=colors[0])
    ax.set_ylabel(f'{labels[1]}\n(median r={ceiling_scores[modelnames[1]].median():.3f})',
                  color=colors[1])

    return fig, n_sig, n_nonsig
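# A minimal sketch of the per-cell model-difference test above: a difference
# counts as significant when the two models' error bars (r +/- se) do not
# overlap. Toy values only.
def _example_difference_test():
    import pandas as pd
    r = pd.DataFrame({'m1': [0.50, 0.40], 'm2': [0.30, 0.38]},
                     index=['cellA', 'cellB'])
    se = pd.DataFrame({'m1': [0.05, 0.05], 'm2': [0.05, 0.05]},
                      index=['cellA', 'cellB'])
    sig = ((r['m2'] - se['m2'] > r['m1'] + se['m1']) |
           (r['m1'] - se['m1'] > r['m2'] + se['m2']))
    return sig  # cellA True (0.45 > 0.35), cellB False (bars overlap)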
def plot_conv_scatters(batch):
    significant_cells = get_significant_cells(batch, SIG_TEST_MODELS, as_list=True)
    sig_scores = nd.batch_comp(batch, ALL_FAMILY_MODELS, cellids=significant_cells, stat='r_test')
    se_scores = nd.batch_comp(batch, ALL_FAMILY_MODELS, cellids=significant_cells, stat='se_test')
    ceiling_scores = nd.batch_comp(batch, ALL_FAMILY_MODELS, cellids=significant_cells, stat=PLOT_STAT)
    nonsig_cells = list(set(sig_scores.index) - set(significant_cells))

    fig, ax = plt.subplots(1, 3, figsize=(12, 4))

    # pop-LN vs single-cell DNN
    plot_pred_scatter(batch, [ALL_FAMILY_MODELS[3], ALL_FAMILY_MODELS[4]],
                      labels=['pop-LN', 'single DNN'], ax=ax[0])

    # pop-LN vs conv1dx2+d
    plot_pred_scatter(batch, [ALL_FAMILY_MODELS[3], ALL_FAMILY_MODELS[2]],
                      labels=['pop-LN', '1Dx2-CNN'], ax=ax[1])

    # conv2d vs conv1dx2+d
    sig = (sig_scores[ALL_FAMILY_MODELS[0]] - se_scores[ALL_FAMILY_MODELS[0]] > sig_scores[ALL_FAMILY_MODELS[2]] + se_scores[ALL_FAMILY_MODELS[2]]) | \
          (sig_scores[ALL_FAMILY_MODELS[2]] - se_scores[ALL_FAMILY_MODELS[2]] > sig_scores[ALL_FAMILY_MODELS[0]] + se_scores[ALL_FAMILY_MODELS[0]])
    group3 = (ceiling_scores.loc[~sig, ALL_FAMILY_MODELS[0]].values,
              ceiling_scores.loc[~sig, ALL_FAMILY_MODELS[2]].values)
    group4 = (ceiling_scores.loc[sig, ALL_FAMILY_MODELS[0]].values,
              ceiling_scores.loc[sig, ALL_FAMILY_MODELS[2]].values)
    scatter_groups([group3, group4], ['lightgray', 'black'], ax=ax[2])
    ax[2].set_title('batch %d, %s, conv2d vs conv1dx2+d' % (batch, PLOT_STAT))
    ax[2].set_xlabel(f'DNN (2D conv) pred. correlation ({ceiling_scores[ALL_FAMILY_MODELS[0]].mean():.3f})')
    ax[2].set_ylabel(f'DNN (1D conv) pred. correlation ({ceiling_scores[ALL_FAMILY_MODELS[2]].mean():.3f})')

    plt.tight_layout()
    return fig
def pop_selector(recording_uri_list, batch=None, cellid=None, rand_match=False,
                 cell_count=20, best_cells=False, **context):
    rec = load_recording(recording_uri_list[0])
    all_cells = rec.meta['cellid']
    this_site = cellid
    cellid = [c for c in all_cells if c.split("-")[0] == this_site]
    site_cellid = cellid.copy()

    # rank cells at this site by single-cell LN model performance
    pmodelname = "ozgf.fs100.ch18-ld-sev_dlog-wc.18x3.g-fir.3x15-lvl.1-dexp.1_init-basic"
    single_perf = nd.batch_comp(batch=batch, modelnames=[pmodelname],
                                cellids=all_cells, stat='r_test')
    this_perf = np.array([single_perf[single_perf.index == c][pmodelname].values[0]
                          for c in cellid])
    sidx = np.argsort(this_perf)
    if best_cells:
        keepidx = (this_perf >= this_perf[sidx[-cell_count]])
        cellid = list(np.array(cellid)[keepidx])
        this_perf = this_perf[keepidx]
    else:
        cellid = cellid[:cell_count]
        this_perf = this_perf[:cell_count]

    if rand_match:
        # greedily match each in-site cell to the out-of-site cell with the
        # closest single-cell performance
        out_cellid = [c for c in all_cells if c.split("-")[0] != this_site]
        out_perf = np.array([single_perf[single_perf.index == c][pmodelname].values[0]
                             for c in out_cellid])
        alt_cellid = []
        alt_perf = []
        for i, c in enumerate(cellid):
            d = np.abs(out_perf - this_perf[i])
            w = np.argmin(d)
            alt_cellid.append(out_cellid[w])
            alt_perf.append(out_perf[w])
            out_perf[w] = 100  # prevent cell from getting picked again
        log.info("Rand matched cellids: %s", alt_cellid)
        log.info("Mean actual: %.3f", np.mean(this_perf))
        log.info(this_perf)
        log.info("Mean rand: %.3f", np.mean(np.array(alt_perf)))
        log.info(np.array(alt_perf))
        rec['resp'] = rec['resp'].extract_channels(alt_cellid)
        # note: meta keeps the original in-site ids, not the matched ones
        rec.meta['cellid'] = cellid
    else:
        rec['resp'] = rec['resp'].extract_channels(cellid)
        rec.meta['cellid'] = cellid

    return {'rec': rec}
def get_population_cellids(batch, modelname=None):
    if modelname is None:
        modelname = ref_modelname

    def get_siteid(s):
        return s.split("-")[0]

    all_siteids = {}
    rep_cellids = {}
    d = nd.batch_comp(batch=batch, modelnames=[modelname])
    d['siteid'] = d.index.map(get_siteid)
    siteids = list(set(d['siteid'].tolist()))
    all_siteids[batch] = siteids

    # pick one representative cellid per site
    site_cellids = [d.loc[d.index.str.startswith(s)].index.values[0] for s in siteids]
    site_cellids.sort()
    rep_cellids[batch] = site_cellids

    return all_siteids, rep_cellids
def get_dataframes(batch, gc, stp, LN, combined):
    df_r = nd.batch_comp(batch, [gc, stp, LN, combined], stat='r_test')
    df_c = nd.batch_comp(batch, [gc, stp, LN, combined], stat='r_ceiling')
    df_e = nd.batch_comp(batch, [gc, stp, LN, combined], stat='se_test')

    # Remove any cellids that have NaN for 1 or more models,
    # sort indexes, and double check that all are equal
    df_r.dropna(axis=0, how='any', inplace=True)
    df_e.dropna(axis=0, how='any', inplace=True)
    df_c.dropna(axis=0, how='any', inplace=True)
    df_r.sort_index(inplace=True)
    df_e.sort_index(inplace=True)
    df_c.sort_index(inplace=True)
    if (not np.all(df_r.index.values == df_e.index.values)) \
            or (not np.all(df_r.index.values == df_c.index.values)):
        raise ValueError('index mismatch in dataframes')

    return df_r, df_c, df_e
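# A minimal sketch of the alignment invariant enforced above: after dropping
# NaNs and sorting, all frames must index exactly the same cells. Toy data only.
def _example_index_check():
    import numpy as np
    import pandas as pd
    a = pd.DataFrame({'m': [1.0, np.nan, 2.0]},
                     index=['c1', 'c2', 'c3']).dropna().sort_index()
    b = pd.DataFrame({'m': [0.1, 0.2, np.nan]},
                     index=['c1', 'c3', 'c2']).dropna().sort_index()
    return bool(np.all(a.index.values == b.index.values))  # -> True here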
def get_rceiling_correction(batch):
    LN_model = MODELGROUPS['LN'][3]
    significant_cells = get_significant_cells(batch, SIG_TEST_MODELS, as_list=True)
    rtest = nd.batch_comp(batch, [LN_model], cellids=significant_cells, stat='r_test')
    rceiling = nd.batch_comp(batch, [LN_model], cellids=significant_cells, stat='r_ceiling')

    rceiling_ratios = rceiling[LN_model] / rtest[LN_model]
    # the correction should never shrink a score, so clip ratios below 1
    rceiling_ratios.loc[rceiling_ratios < 1] = 1

    return rceiling_ratios
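# A minimal sketch of the ratio-with-floor logic above, on toy values:
def _example_rceiling_ratio():
    import pandas as pd
    rtest = pd.Series([0.40, 0.50], index=['c1', 'c2'])
    rceiling = pd.Series([0.50, 0.45], index=['c1', 'c2'])
    ratio = rceiling / rtest
    ratio.loc[ratio < 1] = 1
    return ratio  # c1 -> 1.25, c2 -> 1.0 (0.9 clipped up to 1)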
def get_filtered_cellids(batch, gc, stp, LN, combined, se_filter=True,
                         LN_filter=False, as_lists=True):
    df_r, df_c, df_e = get_dataframes(batch, gc, stp, LN, combined)
    df_f = nd.batch_comp(batch, [gc, stp, LN, combined], stat='r_floor')
    df_f.dropna(axis=0, how='any', inplace=True)
    df_f.sort_index(inplace=True)
    if not np.all(df_f.index.values == df_r.index.values):
        raise ValueError('index mismatch in dataframes')

    cellids = df_r.index.values.tolist()
    gc_test, gc_se, gc_floor = [d[gc] for d in [df_r, df_e, df_f]]
    stp_test, stp_se, stp_floor = [d[stp] for d in [df_r, df_e, df_f]]
    ln_test, ln_se, ln_floor = [d[LN] for d in [df_r, df_e, df_f]]
    gc_stp_test, gc_stp_se, gc_stp_floor = [d[combined] for d in [df_r, df_e, df_f]]

    if se_filter:
        # Remove if performance not significant at all
        good_cells = ((gc_test > gc_se * 2) & (gc_test > gc_floor) &
                      (stp_test > stp_se * 2) & (stp_test > stp_floor) &
                      (ln_test > ln_se * 2) & (ln_test > ln_floor) &
                      (gc_stp_test > gc_stp_se * 2) & (gc_stp_test > gc_stp_floor))
    else:
        # Comparison with np.nan is always True, so this yields an all-True
        # series and no cells are skipped
        good_cells = (gc_test != np.nan)

    if LN_filter:
        # Remove if performance significantly worse than LN
        bad_cells = ((gc_test + gc_se < ln_test - ln_se) |
                     (stp_test + stp_se < ln_test - ln_se) |
                     (gc_stp_test + gc_stp_se < ln_test - ln_se))
    else:
        # Comparison with np.nan is always False, so this yields an all-False
        # series and no cells are skipped
        bad_cells = (gc_test == np.nan)

    keep = good_cells & ~bad_cells
    if as_lists:
        cellids = df_r[keep].index.values.tolist()
        under_chance = df_r[~good_cells].index.values.tolist()
        less_LN = df_r[bad_cells].index.values.tolist()
        return cellids, under_chance, less_LN
    else:
        return keep, ~good_cells, bad_cells
def scatter_titan(batch):
    PLOT_STAT = 'r_test'
    TITAN_MODEL = 'ozgf.fs100.ch18.pop-loadpop-norm.l1-popev_wc.18x70.g-fir.1x15x70-relu.70.f-wc.70x80-fir.1x10x80-relu.80.f-wc.80x100-relu.100-wc.100xR-lvl.R-dexp.R_tfinit.n.mc50.lr1e3.es20-newtf.n.mc100.lr1e4.exa'
    REFERENCE_MODEL = 'ozgf.fs100.ch18.pop-loadpop-norm.l1-popev_wc.18x70.g-fir.1x15x70-relu.70.f-wc.70x80-fir.1x10x80-relu.80.f-wc.80x100-relu.100-wc.100xR-lvl.R-dexp.R_tfinit.n.lr1e3.et3.es20-newtf.n.lr1e4'
    MODELS = [REFERENCE_MODEL, TITAN_MODEL]

    significant_cells = get_significant_cells(batch, MODELS, as_list=True)
    sig_scores = nd.batch_comp(batch, MODELS, cellids=significant_cells, stat='r_test')
    se_scores = nd.batch_comp(batch, MODELS, cellids=significant_cells, stat='se_test')
    ceiling_scores = nd.batch_comp(batch, MODELS, cellids=significant_cells, stat=PLOT_STAT)
    nonsig_cells = list(set(sig_scores.index) - set(significant_cells))

    fig, ax = plt.subplots(1, 1, figsize=(4, 4))

    # flag units whose error bars do not overlap between the two models
    sig = (sig_scores[MODELS[1]] - se_scores[MODELS[1]] > sig_scores[MODELS[0]] + se_scores[MODELS[0]]) | \
          (sig_scores[MODELS[0]] - se_scores[MODELS[0]] > sig_scores[MODELS[1]] + se_scores[MODELS[1]])
    group1 = (ceiling_scores.loc[~sig, MODELS[0]].values,
              ceiling_scores.loc[~sig, MODELS[1]].values)
    group2 = (ceiling_scores.loc[sig, MODELS[0]].values,
              ceiling_scores.loc[sig, MODELS[1]].values)

    scatter_groups([group1, group2], ['lightgray', 'black'], ax=ax)
    ax.set_title('batch %d, %s, Conv1d vs Titan' % (batch, PLOT_STAT))
    ax.set_xlabel(f'Single stim ({ceiling_scores[MODELS[0]].mean():.3f})', color='orange')
    ax.set_ylabel(f'Titan ({ceiling_scores[MODELS[1]].mean():.3f})', color='lightgreen')

    return fig
def get_heldout_results(batch, significant_cells, short_names):
    modelgroups = {}
    for i, name in enumerate(short_names):
        modelgroups[name + ' held'] = HELDOUT[i]
        modelgroups[name + ' match'] = MATCHED[i]

    r_ceilings = {}
    for n in short_names:
        # index by model name to collapse the one-column frame to a Series
        heldout_r_ceiling = nd.batch_comp(batch, [modelgroups[n + ' held']],
                                          cellids=significant_cells,
                                          stat=PLOT_STAT)[modelgroups[n + ' held']]
        matched_r_ceiling = nd.batch_comp(batch, [modelgroups[n + ' match']],
                                          cellids=significant_cells,
                                          stat=PLOT_STAT)[modelgroups[n + ' match']]
        r_ceilings[n + ' held'] = heldout_r_ceiling
        r_ceilings[n + ' match'] = matched_r_ceiling

    return r_ceilings
def equivalence_effect_size(batch, models, performance_stat='r_ceiling',
                            manual_cellids=None):
    df = nd.batch_comp(batch=batch, modelnames=models, stat=performance_stat,
                       cellids=manual_cellids)
    stats = [df[model].values for model in models]
    # models[2] is treated as the baseline; the effect size is the mean
    # improvement of the first two models over it
    rel_1 = stats[0] - stats[2]
    rel_2 = stats[1] - stats[2]
    effect_sizes = 0.5 * (rel_1 + rel_2)

    results = {'performance_effect': effect_sizes, 'cellid': df.index.values}
    df = pd.DataFrame.from_dict(results)
    df.set_index('cellid', inplace=True)

    return df
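# A worked example of the effect-size computation above: with two alternative
# models plus a baseline in the last slot, the reported effect is
# 0.5 * ((r_1 - r_3) + (r_2 - r_3)). Toy values only.
def _example_effect_size():
    import numpy as np
    r1, r2, r3 = np.array([0.50]), np.array([0.46]), np.array([0.40])
    effect = 0.5 * ((r1 - r3) + (r2 - r3))
    return effect  # -> array([0.08])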
def count_fits(models, batch=None):
    if batch is None:
        batches = [322, 323]
    else:
        batches = [batch]

    for batch in batches:
        df_r = nd.batch_comp(batch, models, stat='r_test')
        for i, c in enumerate(df_r.columns):
            parts = c.split("_")
            if i == 0:
                print(f"BATCH {batch} -- {parts[0]}_XX_{parts[2]}")
            print(f'{parts[1]}: {df_r[c].count()}')
def scatter_bar(batches, modelnames, stest=SIG_TEST_MODELS, axes=None):
    if axes is None:
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(3.5, 6))
    else:
        ax1, ax2 = axes

    cellids = [get_significant_cells(batch, stest, as_list=True) for batch in batches]
    r_values = [nd.batch_comp(batch, modelnames, cellids=cells, stat=PLOT_STAT)
                for batch, cells in zip(batches, cellids)]
    all_r_values = pd.concat(r_values)

    # NOTE: if ALL_FAMILY_MODELS changes, the 3, 2 indices will need to be updated
    # Scatter Plot -- LN vs c1dx2+d
    ax1.scatter(all_r_values[modelnames[3]], all_r_values[modelnames[2]], s=2, c='black')
    ax1.plot([0, 1], [0, 1], c='black', linestyle='dashed', linewidth=1)  # unity line
    ax_remove_box(ax1)
    ax1.set_ylim(0, 1)
    ax1.set_xlim(0, 1)
    ax1.set_xlabel('LN_pop prediction accuracy')
    ax1.set_ylabel('conv1Dx2 prediction accuracy')

    # Bar Plot -- Median for each model
    # NOTE: ordering of names assumes ALL_FAMILY_MODELS is being used and has not changed.
    short_names = ['conv2d', 'conv1d', 'conv1dx2+d', 'LN_pop', 'dnn1_single']
    bar_colors = [DOT_COLORS[k] for k in short_names]
    ax2.bar(np.arange(0, len(modelnames)), all_r_values.median(axis=0).values,
            color=bar_colors, tick_label=short_names)
    ax_remove_box(ax2)
    ax2.set_ylabel('Median Prediction Accuracy')
    ax2.set_xticklabels(ax2.get_xticklabels(), rotation=45, ha='right')

    fig = plt.gcf()
    fig.tight_layout()

    return ax1, ax2
def bar_mean(batch, modelnames, stest=SIG_TEST_MODELS, ax=None):
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.figure

    cellids = get_significant_cells(batch, stest, as_list=True)
    r_values = nd.batch_comp(batch, modelnames, cellids=cellids, stat=PLOT_STAT)

    # Bar Plot -- Median for each model
    # NOTE: ordering of names assumes ALL_FAMILY_MODELS is being used and has not changed.
    bar_colors = [DOT_COLORS[k] for k in shortnames]
    medians = r_values.median(axis=0).values
    ax.bar(np.arange(0, len(modelnames)), medians, color=bar_colors,
           edgecolor='black', linewidth=1, tick_label=shortnames)
    ax.set_ylabel('Median prediction\ncorrelation')
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')

    # Test significance for all pairwise model comparisons
    stats_results = {}
    reduced_modelnames = modelnames.copy()
    reduced_shortnames = shortnames.copy()
    for m1, s1 in zip(modelnames, shortnames):
        # drop m1 from the reduced lists so each unordered pair is tested once
        reduced_modelnames.pop(0)
        reduced_shortnames.pop(0)
        for m2, s2 in zip(reduced_modelnames, reduced_shortnames):
            stats_test = st.wilcoxon(r_values[m1], r_values[m2],
                                     alternative='two-sided')
            stats_results[f'{s1} vs {s2}'] = stats_test

    return ax, medians, stats_results
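# The nested loop above enumerates each unordered model pair exactly once;
# itertools.combinations expresses the same pairing directly. A minimal sketch
# on toy names:
def _example_all_pairs(shortnames=('a', 'b', 'c')):
    from itertools import combinations
    pairs = [f'{s1} vs {s2}' for s1, s2 in combinations(shortnames, 2)]
    return pairs  # -> ['a vs b', 'a vs c', 'b vs c']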
def _matching_cells(batch=289, siteid=None, alt_cells_available=None,
                    cell_count=None, best_cells=False):
    pmodelname = "ozgf.fs100.ch18-ld-sev_dlog-wc.18x3.g-fir.3x15-lvl.1-dexp.1_init-basic"
    single_perf = nd.batch_comp(batch=batch, modelnames=[pmodelname], stat='r_test')

    if alt_cells_available is not None:
        all_cells = alt_cells_available
    else:
        all_cells = list(single_perf.index)

    cellid = [c for c in all_cells if c.split("-")[0] == siteid]
    this_perf = np.array([single_perf[single_perf.index == c][pmodelname].values[0]
                          for c in cellid])

    if cell_count is None:
        pass
    elif best_cells:
        sidx = np.argsort(this_perf)
        keepidx = (this_perf >= this_perf[sidx[-cell_count]])
        cellid = list(np.array(cellid)[keepidx])
        this_perf = this_perf[keepidx]
    else:
        cellid = cellid[:cell_count]
        this_perf = this_perf[:cell_count]

    out_cellid = [c for c in all_cells if c.split("-")[0] != siteid]
    out_perf = np.array([single_perf[single_perf.index == c][pmodelname].values[0]
                         if c in single_perf.index else 0.0
                         for c in out_cellid])

    # greedily match each in-site cell to the closest-performing out-of-site cell
    alt_cellid = []
    alt_perf = []
    for i, c in enumerate(cellid):
        d = np.abs(out_perf - this_perf[i])
        w = np.argmin(d)
        alt_cellid.append(out_cellid[w])
        alt_perf.append(out_perf[w])
        out_perf[w] = 100  # prevent cell from getting picked again

    log.info("Rand matched cellids: %s", alt_cellid)
    log.info("Mean actual: %.3f", np.mean(this_perf))
    log.info(this_perf)
    log.info("Mean rand: %.3f", np.mean(np.array(alt_perf)))
    log.info(np.array(alt_perf))

    return cellid, this_perf, alt_cellid, alt_perf
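# A minimal sketch of the greedy performance-matching above: for each in-site
# cell, pick the unused out-of-site cell with the closest r_test; setting the
# chosen entry to 100 removes it from later picks. Toy values only.
def _example_greedy_match():
    import numpy as np
    this_perf = np.array([0.30, 0.31])
    out_perf = np.array([0.29, 0.50])
    picks = []
    for p in this_perf:
        w = int(np.argmin(np.abs(out_perf - p)))
        picks.append(w)
        out_perf[w] = 100  # prevent cell from getting picked again
    return picks  # -> [0, 1]: the second cell must settle for the 0.50 match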
def run_fun(self):
    cellids = self._selected_cells()
    modelnames = self._selected_modelnames()
    batch = self.batch

    d = nd.batch_comp(batch=batch, modelnames=modelnames, cellids=cellids, stat='r_test')
    cellids = d.index
    siteids = [c.split("-")[0] for c in cellids]
    d['siteids'] = siteids

    mean_r_test = d[modelnames].mean().values
    shortened, prefix, suffix = find_common(modelnames)
    modelcount = len(modelnames)

    plt.figure()
    ax = plt.subplot(1, 1, 1)
    site_r_test = d.groupby(['siteids']).mean().values.T
    usiteids = d.groupby(['siteids']).groups.keys()
    ax.bar(np.arange(len(mean_r_test)), mean_r_test, color='lightgray')
    ax.plot(site_r_test)

    print('{} -- {}'.format(prefix, suffix))
    for i, m in enumerate(modelnames):
        print("{} (n={}): {:.3f}".format(shortened[i], d[m].count(), d[m].mean()))
        s = shortened[i].replace("_", "\n") + "\n{:.3f}".format(mean_r_test[i])
        ax.text(i, 0, s, rotation=90, color='black', ha='left', va='bottom', fontsize=7)
    for i, s in enumerate(usiteids):
        ax.text(modelcount - 0.2, site_r_test[-1, i], s)

    plt.title('{} -- {}'.format(prefix, suffix))
    ax.set_xticks(np.arange(modelcount))
    ax.set_xticklabels([])
    nplt.ax_remove_box(ax)
def load_models(cell, batch, models, check_db=True, site=None, site_cache=None):
    '''Load standardized psth from each model, error if not all fits exist.'''
    if check_db:
        # Before trying to load, check database to see if a result exists.
        # Should set False if you know model results are not stored in DB,
        # but exist in file storage.
        df = nd.batch_comp(batch=batch, modelnames=models, cellids=[cell])
        if np.sum(df.isna().values) > 0:
            # at least one cell wasn't fit (or at least not stored in db),
            # so skip trying to load any of them since all are required.
            raise ValueError('Not all results exist for: %s, %d' % (cell, batch))

    # Load all models
    ctxs = []
    for model in models:
        if site_cache is None:
            xf, ctx = xhelp.load_model_xform(cell, batch, model)
        elif model in site_cache[site]:
            log.info("Site %s is cached, skipping load...", site)
            ctx = site_cache[site][model]
        else:
            xf, ctx = xhelp.load_model_xform(cell, batch, model)
            site_cache[site][model] = ctx
        ctxs.append(ctx)

    for ctx in ctxs:
        if ctx['val']['pred'].chans is None:
            ctx['val']['pred'].chans = copy(ctx['val']['resp'].chans)

    # Pull out model predictions and remove times with nan for at least 1 model
    preds = [ctx['val'].apply_mask()['pred'].extract_channels([cell]).as_continuous()
             for ctx in ctxs]
    ff = np.isfinite(preds[0])
    for pred in preds[1:]:
        ff &= np.isfinite(pred)
    no_nans = [pred[ff] for pred in preds]

    return no_nans
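# A minimal sketch of the joint NaN-masking step above: a time bin is kept
# only if every model's prediction is finite there. Toy arrays only.
def _example_joint_finite_mask():
    import numpy as np
    preds = [np.array([1.0, np.nan, 3.0]), np.array([1.0, 2.0, np.nan])]
    ff = np.isfinite(preds[0])
    for pred in preds[1:]:
        ff &= np.isfinite(pred)
    return [pred[ff] for pred in preds]  # only the first bin survives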
def stp_v_beh():
    batch1 = 274
    batch2 = 275
    modelnames = [
        "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-fir.1x15-lvl.1-dexp.1_jk.nf5-init.st-basic",
        "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-fir.1x15-lvl.1-rep.2-dexp.2-mrg_jk.nf5-init.st-basic",
        "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-rep.2-fir.1x15x2-lvl.2-dexp.2-mrg_jk.nf5-init.st-basic",
        "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-stp.1-fir.1x15-lvl.1-dexp.1_jk.nf5-init.st-basic",
        "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-stp.1-fir.1x15-lvl.1-rep.2-dexp.2-mrg_jk.nf5-init.st-basic",
        "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-stp.1-rep.2-fir.1x15x2-lvl.2-dexp.2-mrg_jk.nf5-init.st-basic",
        "env.fs100-ld-st.beh-ref_dlog.f-wc.2x1.c-rep.2-stp.2-fir.1x15x2-lvl.2-dexp.2-mrg_jk.nf5-init.st-basic",
    ]
    fileprefix = "fig8.stp_v_beh"
    n1 = modelnames[0]
    n2 = modelnames[-3]
    xc_range = [-0.05, 0.6]

    df1 = nd.batch_comp(batch1, modelnames, stat='r_test').reset_index()
    df1_e = nd.batch_comp(batch1, modelnames, stat='se_test').reset_index()
    df2 = nd.batch_comp(batch2, modelnames, stat='r_test').reset_index()
    df2_e = nd.batch_comp(batch2, modelnames, stat='se_test').reset_index()
    # DataFrame.append is deprecated; pd.concat is the equivalent
    df = pd.concat([df1, df2])
    df_e = pd.concat([df1_e, df2_e])
    cellcount = len(df)

    beta1 = df[n1]
    beta2 = df[n2]
    beta1_test = df[n1]
    beta2_test = df[n2]
    se1 = df_e[n1]
    se2 = df_e[n2]
    beta1[beta1 > 1] = 1
    beta2[beta2 > 1] = 1

    # test for significant improvement
    improvedcells = (beta2_test - se2 > beta1_test + se1)
    # test for significant prediction at all
    goodcells = ((beta2_test > se2 * 3) | (beta1_test > se1 * 3))

    fh = plt.figure()
    ax = plt.subplot(2, 2, 1)
    stateplots.beta_comp(beta1[goodcells], beta2[goodcells],
                         n1='LN STRF', n2='STP+BEH LN STRF',
                         hist_range=xc_range, ax=ax,
                         highlight=improvedcells[goodcells])

    # LN vs. STP:
    beta1b = df[modelnames[3]]
    beta1a = df[modelnames[0]]
    beta1 = beta1b - beta1a
    se1a = df_e[modelnames[3]]
    se1b = df_e[modelnames[0]]
    b1 = 4
    b0 = 3
    beta2b = df[modelnames[b1]]
    beta2a = df[modelnames[b0]]
    beta2 = beta2b - beta2a
    se2a = df_e[modelnames[b1]]
    se2b = df_e[modelnames[b0]]

    stpgood = (beta1 > se1a + se1b)
    behgood = (beta2 > se2a + se2b)
    neither_good = np.logical_not(stpgood) & np.logical_not(behgood)
    both_good = stpgood & behgood
    stp_only_good = stpgood & np.logical_not(behgood)
    beh_only_good = np.logical_not(stpgood) & behgood

    xc_range = np.array([-0.05, 0.15])
    beta1[beta1 < xc_range[0]] = xc_range[0]
    beta2[beta2 < xc_range[0]] = xc_range[0]
    zz = np.zeros(2)
    ax = plt.subplot(2, 2, 2)
    ax.plot(xc_range, zz, 'k--', linewidth=0.5)
    ax.plot(zz, xc_range, 'k--', linewidth=0.5)
    ax.plot(xc_range, xc_range, 'k--', linewidth=0.5)
    l = ax.plot(beta1[neither_good], beta2[neither_good], '.', color='lightgray') + \
        ax.plot(beta1[beh_only_good], beta2[beh_only_good], '.', color='purple') + \
        ax.plot(beta1[stp_only_good], beta2[stp_only_good], '.', color='orange') + \
        ax.plot(beta1[both_good], beta2[both_good], '.', color='black')
    ax_remove_box(ax)
    ax.set_aspect('equal', 'box')
    ax.set_xlim(xc_range)
    ax.set_ylim(xc_range)
    ax.set_xlabel('delta(stp)')
    ax.set_ylabel('delta(beh)')

    # shuffle test: how much overlap between STP and BEH effects by chance
    olap = np.zeros(100)
    a = stpgood.values.copy()
    b = behgood.values.copy()
    for i in range(100):
        np.random.shuffle(a)
        olap[i] = np.sum(a & b)

    ll = [np.sum(neither_good), np.sum(beh_only_good),
          np.sum(stp_only_good), np.sum(both_good)]
    ax.legend(l, ll)

    ax = plt.subplot(2, 2, 3)
    m = np.array(df.loc[goodcells].mean()[modelnames])
    xc_range = [-0.02, 0.2]
    plt.bar(np.arange(len(modelnames)), m, color='black')
    plt.plot(np.array([-1, len(modelnames)]), np.array([0, 0]), 'k--', linewidth=0.5)
    plt.ylim(xc_range)
    plt.title("batches {}+{}, n={}/{} good cells".format(
        batch1, batch2, np.sum(goodcells), len(goodcells)))
    plt.ylabel('mean pred corr')
    plt.xlabel('model architecture')
    ax_remove_box(ax)

    for i in range(len(modelnames) - 1):
        d1 = np.array(df[modelnames[i]])
        d2 = np.array(df[modelnames[i + 1]])
        s, p = ss.wilcoxon(d1, d2)
        plt.text(i + 0.5, m[i + 1], "{:.1e}".format(p), ha='center', fontsize=6)
    plt.xticks(np.arange(len(m)), np.round(m, 3))

    return fh, df[stpgood]['cellid'].tolist()
pca_file = '/auto/users/svd/python/scripts/NAT_pop_models/dpc322.csv'
waveform_labels = pd.read_csv('phototag_waveform_labels.csv', index_col=0)
# waveform_labels = pd.read_csv(pl.Path('/auto/users/mateo/nems_db/nems_lbhb/projects/phototag/phototag_waveform_labels.csv'), index_col=0)

runclass = "NAT"
sql = "select sCellFile.*,gSingleCell.siteid,gSingleCell.phototag from gSingleCell INNER JOIN sCellFile ON gSingleCell.id=sCellFile.singleid" + \
      " INNER JOIN gRunClass on gRunClass.id=sCellFile.runclassid" + \
      f" WHERE gRunClass.name='{runclass}' AND not(isnull(phototag))"
d = nd.pd_query(sql)
d['parmfile'] = d['stimpath'] + d['stimfile']
print(f'cell/file combos with phototag labels={len(d)}')
dtag = d[['cellid', 'phototag']].groupby('cellid').max()

dpred = nd.batch_comp(batch=batch, modelnames=modelnames, stat="r_ceiling")
dpred.columns = shortnames
dpred['siteid'] = dpred.index
dpred['siteid'] = dpred['siteid'].apply(nd.get_siteid)
dpred['diff'] = dpred[shortnames[1]] - dpred[shortnames[0]]
dpred = dpred.merge(dtag, how='inner', left_index=True, right_index=True)
dpred = dpred.merge(waveform_labels[['cellid', 'wshape']], how='left',
                    left_index=True, right_on='cellid').set_index('cellid')
dpred['pw'] = dpred['phototag'] + " " + dpred['wshape']

dpc = pd.read_csv(pca_file, index_col=0)
dpc = dpc.loc[dpc['type'] == 'avg']
d = nems_db.params.fitted_params_per_batch(
    batch, modelname, meta=['r_test', 'r_fit', 'se_test', 'r_ceiling'],
    stats_keys=[], multi='first')
if modelname0 is not None:
    d0 = nems_db.params.fitted_params_per_batch(
        batch, modelname0, meta=['r_test', 'r_fit', 'se_test', 'r_ceiling'],
        stats_keys=[], multi='first')
df_r = nd.batch_comp(batch, [modelname0, modelname], stat='r_test')
df_e = nd.batch_comp(batch, [modelname0, modelname], stat='se_test')

# display bounds for the STP parameter histograms
u_bounds = np.array([-0.6, 2.1])
tau_bounds = np.array([-0.1, 1.5])
str_bounds = np.array([-0.25, 0.55])
amp_bounds = np.array([-1, 1.5])
amp0_bounds = np.array([-0.5, 0.75])

indices = list(d.index)
for ind in indices:
    if '--u' in ind:
        u_index = ind
    elif '--tau' in ind:
        tau_index = ind
import numpy as np
import matplotlib.pyplot as plt

import svdplots

batch = 303
modelname1 = "psth20pup0prebeh_stategain4_basic-nf"
modelname2 = "psth20pupprebeh_stategain4_basic-nf"

bins = 20
range = [-1, 1]  # note: shadows the builtin range
n1 = 'pre gain'
n2 = 'active gain'
i1 = 2
i2 = 3

res = nd.batch_comp(batch, [modelname1, modelname2])
d1 = nems_db.params.fitted_params_per_batch(batch, modelname1, stats_keys=[])
d2 = nems_db.params.fitted_params_per_batch(batch, modelname2, stats_keys=[])

# parse modelname
kws = modelname1.split("_")
modname = kws[1]
statecount = int(modname[-1])

g1 = np.zeros([0, statecount])
b1 = np.zeros([0, statecount])
g2 = np.zeros([0, statecount])
b2 = np.zeros([0, statecount])
sig = np.zeros(len(d1.columns))
i = 0
relative_changes_per_model = [means[0][i] / means[1][i] for i, _ in enumerate(labels)]
relative_change_pareto = '\n'.join([
    f'{n}: {changes.mean()}'
    for n, changes in zip(labels, relative_changes_per_model)
])
stats_tests.append('\n\nPareto plot, relative change a1/peg:')
stats_tests.append(relative_change_pareto)

LN_single = MODELGROUPS['LN'][4]
pop_LN = ALL_FAMILY_MODELS[3]
sig_cells_A1 = get_significant_cells(322, SIG_TEST_MODELS, as_list=True)
r1 = nd.batch_comp(322, [LN_single, pop_LN], cellids=sig_cells_A1, stat=PLOT_STAT)
LN_test1 = st.wilcoxon(getattr(r1, LN_single), getattr(r1, pop_LN),
                       alternative='two-sided')

sig_cells_PEG = get_significant_cells(323, SIG_TEST_MODELS, as_list=True)
r2 = nd.batch_comp(323, [LN_single, pop_LN], cellids=sig_cells_PEG, stat=PLOT_STAT)
LN_test2 = st.wilcoxon(getattr(r2, LN_single), getattr(r2, pop_LN),
                       alternative='two-sided')

stats_tests.append('\nLN single vs pop-LN:')
stats_tests.append(f'A1: {LN_test1}')
modelname_filter = POP_MODELS[1]
mfb = {322: modelname_filter, 323: modelname_filter.replace('.ver2', '')}

# ROUND 1, all families pop
if 0:
    modelnames = ALL_FAMILY_POP[:-1]
    useGPU = True
    for batch in batches:
        if useGPU and (batch == 322):
            c = ['NAT4v2']
        elif useGPU:
            c = ['NAT4']
        else:
            c = nd.batch_comp(modelnames=[modelname_filter], batch=batch).index.to_list()
        enqueue_exacloud_models(
            cellist=c, batch=batch, modellist=modelnames,
            user=lbhb_user, linux_user=user, force_rerun=force_rerun,
            priority=1, executable_path=executable_path_exa,
            script_path=script_path_exa, useGPU=useGPU)

if 0:
    # dnn single, round 1
    modelnames = DNN_SINGLE_MODELS[:5]
    # ln single, only 1 round
    modelnames = LN_SINGLE_MODELS[:10]
    # dnn single, round 2
    modelnames = DNN_SINGLE_STAGE2[:5]
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
}
plt.rcParams.update(params)

batch = 259

# shrinkage
modelname1 = "env.fs100-ld-sev_dlog.f-fir.2x15-lvl.1-dexp.1_init-mt.shr-basic"
modelname2 = "env.fs100-ld-sev_dlog.f-wc.2x3.c.n-stp.3-fir.3x15-lvl.1-dexp.1_init-mt.shr-basic"
# regular
modelname1 = "env.fs100-ld-sev_dlog.f-fir.2x15-lvl.1-dexp.1_init-basic"
modelname2 = "env.fs100-ld-sev_dlog.f-wc.2x3.c.n-stp.3-fir.3x15-lvl.1-dexp.1_init-basic"
modelnames = [modelname1, modelname2]

df = nd.batch_comp(batch, modelnames)
df['diff'] = df[modelname2] - df[modelname1]
df['cellid'] = df.index
df.sort_values('cellid', inplace=True, ascending=True)

m = df['cellid'].str.startswith('por07') & (df[modelname2] > 0.3)
for index, c in df[m].iterrows():
    print("{} {:.3f} - {:.3f} = {:.3f}".format(index, c[modelname2],
                                               c[modelname1], c['diff']))

plt.close('all')

outpath = "/auto/users/svd/docs/current/two_band_spn/eps_rev/"

if 0:
    #cellid="por077a-c1"
    cellid = "por074b-d2"
    cellid = "por020a-c1"
def gd_scatter(batch, model1, model2, se_filter=True, gd_threshold=0,
               param='kappa', log_gd=False):
    df_r = nd.batch_comp(batch, [model1, model2], stat='r_ceiling')
    df_e = nd.batch_comp(batch, [model1, model2], stat='se_test')
    # Remove any cellids that have NaN for 1 or more models
    df_r.dropna(axis=0, how='any', inplace=True)
    df_e.dropna(axis=0, how='any', inplace=True)

    cellids = df_r.index.values.tolist()
    gc_test = df_r[model1]
    gc_se = df_e[model1]
    ln_test = df_r[model2]
    ln_se = df_e[model2]

    if se_filter:
        # Remove if performance not significant at all
        good_cells = ((gc_test > gc_se * 2) & (ln_test > ln_se * 2))
    else:
        # Comparison with np.nan is always True, so none are skipped
        good_cells = (gc_test != np.nan)

    df1 = fitted_params_per_batch(batch, model1, stats_keys=[])
    df2 = fitted_params_per_batch(batch, model2, stats_keys=[])

    # fill in missing cellids w/ nan
    celldata = nd.get_batch_cells(batch=batch)
    cellids = celldata['cellid'].tolist()
    # keep only cells that passed the filter; plain `c in good_cells` would
    # test index membership and also keep cells where the filter is False
    cellids = [c for c in cellids if c in good_cells[good_cells].index]

    nrows = len(df1.index.values.tolist())
    df1_cells = df1.loc['meta--r_test'].index.values.tolist()[5:]
    df2_cells = df2.loc['meta--r_test'].index.values.tolist()[5:]
    nan_series = pd.Series(np.full((nrows), np.nan))
    df1_nans = 0
    df2_nans = 0
    for c in cellids:
        if c not in df1_cells:
            df1[c] = nan_series
            df1_nans += 1
        if c not in df2_cells:
            df2[c] = nan_series
            df2_nans += 1
    print("# missing cells: %d, %d" % (df1_nans, df2_nans))

    # Force same cellid order now that missing cols are filled in
    df1 = df1[cellids]
    df2 = df2[cellids]

    gc_vs_ln = df1.loc['meta--r_test'].values / df2.loc['meta--r_test'].values
    gc_vs_ln = gc_vs_ln.astype('float32')
    kappa_mod = df1[df1.index.str.contains('%s_mod' % param)]
    kappa = df1[df1.index.str.contains('%s$' % param)]
    gd_ratio = (np.abs(kappa_mod.values / kappa.values)).astype('float32').flatten()

    ff = np.isfinite(gc_vs_ln) & np.isfinite(gd_ratio)
    gc_vs_ln = gc_vs_ln[ff]
    gd_ratio = gd_ratio[ff]
    if log_gd:
        gd_ratio = np.log(gd_ratio)

    # drop cells with excessively large/small gd_ratio or gc_vs_ln
    gcd_big = gd_ratio > 10
    gc_vs_ln_big = gc_vs_ln > 10
    gc_vs_ln_small = gc_vs_ln < 0.1
    keep = ~gcd_big & ~gc_vs_ln_big & ~gc_vs_ln_small
    gd_ratio = gd_ratio[keep]
    gc_vs_ln = gc_vs_ln[keep]

    r = np.corrcoef(gc_vs_ln, gd_ratio)[0, 1]
    n = gc_vs_ln.size

    # Separately do the same comparison but only with cells that had a
    # Gd ratio at least a little greater than 1 (i.e. had *some* GC effect)
    gd_ratio2 = copy.deepcopy(gd_ratio)
    gc_vs_ln2 = copy.deepcopy(gc_vs_ln)
    if log_gd:
        gd_threshold = np.log(gd_threshold)
    thresholded = (gd_ratio2 > gd_threshold)
    gd_ratio2 = gd_ratio2[thresholded]
    gc_vs_ln2 = gc_vs_ln2[thresholded]
    r2 = np.corrcoef(gc_vs_ln2, gd_ratio2)[0, 1]
    n2 = gc_vs_ln2.size

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8, 9))
    ax1.scatter(gd_ratio, gc_vs_ln, c='black', s=1)
    ax1.set_ylabel("GC/LN R")
    ax1.set_xlabel("Gd ratio")
    ax1.set_title("Performance Improvement vs Gd ratio\nr: %.02f, n: %d" % (r, n))
    ax2.hist(gd_ratio, bins=30, histtype='bar', color=['gray'])
    ax2.set_title('Gd ratio distribution')
    ax2.set_xlabel('Gd ratio')
    ax2.set_ylabel('Count')
    ax3.scatter(gd_ratio2, gc_vs_ln2, c='black', s=1)
    ax3.set_ylabel("GC/LN R")
    ax3.set_xlabel("Gd ratio")
    ax3.set_title("Same, only cells w/ Gd > %.02f\nr: %.02f, n: %d"
                  % (gd_threshold, r2, n2))
    ax4.hist(gd_ratio2, bins=30, histtype='bar', color=['gray'])
    ax4.set_title('Gd ratio distribution, only Gd > %.02f' % gd_threshold)
    ax4.set_xlabel('Gd ratio')
    ax4.set_ylabel('Count')
    fig.suptitle('param: %s' % param)
    fig.tight_layout()
def model_comp_pareto(batch, modelgroups, ax, cellids, nparms_modelgroups=None,
                      dot_colors=None, dot_markers=None, fill_styles=None,
                      plot_stat='r_test', plot_medians=False,
                      labeled_models=None, show_legend=True):
    if labeled_models is None:
        labeled_models = []
    if nparms_modelgroups is None:
        nparms_modelgroups = copy.copy(modelgroups)
    if fill_styles is None:
        fill_styles = {k: 'full' for (k, v) in dot_colors.items()}

    mean_cells_per_site = len(cellids)  # NAT4 dataset, so all cellids are used
    overall_min = 100
    overall_max = -100
    all_model_means = []
    labeled_data = []
    for k, modelnames in modelgroups.items():
        np_modelnames = nparms_modelgroups[k]
        b_ceiling = nd.batch_comp(batch, modelnames, cellids=cellids, stat=plot_stat)
        b_n = nd.batch_comp(batch, np_modelnames, cellids=cellids, stat='n_parms')

        if not plot_medians:
            model_mean = b_ceiling.mean()
        else:
            model_mean = b_ceiling.median()
        b_m = np.array(model_mean)
        print(f"{k} modelcount {len(modelnames)} fits per model: {b_n.count().values}")

        n_parms = np.array([np.mean(b_n[m]) for m in np_modelnames])
        # don't divide by cells per site if only one cell was fit
        if ('single' not in k) and (k != 'LN') and (k != 'stp'):
            n_parms = n_parms / mean_cells_per_site

        y_max = b_m.max() * 1.05
        y_min = b_m.min() * 0.9
        overall_max = max(overall_max, y_max)
        overall_min = min(overall_min, y_min)

        ax.plot(n_parms, b_m, color=dot_colors[k], marker=dot_markers[k],
                label=k, markersize=4.5, fillstyle=fill_styles[k])

        for m in labeled_models:
            if m in modelnames:
                i = modelnames.index(m)
                labeled_data.append([n_parms[i], b_m[i]])

        all_model_means.append(model_mean)

    # circle the specifically labeled models
    ax.plot(*list(zip(*labeled_data)), 's', color='black', marker='o',
            fillstyle='none', markersize=10)

    handles, labels = ax.get_legend_handles_labels()
    if show_legend:
        ax.legend(handles, labels, loc='lower right', fontsize=7, frameon=False)
    ax.set_xlabel('Free parameters per neuron')
    ax.set_ylabel('Median prediction correlation')
    ax.set_ylim((overall_min, overall_max))
    ax_remove_box(ax)
    plt.tight_layout()

    # note: b_ceiling is the frame from the last model group in the loop
    return ax, b_ceiling, all_model_means, labels
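# A minimal sketch of the pareto x-axis normalization above: population models
# share parameters across all cells at a site, so their parameter count is
# divided by cells per site before plotting. Values are illustrative only.
def _example_parms_per_neuron():
    import numpy as np
    n_parms = np.array([4000.0, 6000.0])  # total free parameters, two pop models
    mean_cells_per_site = 50               # assumed cell count per site
    return n_parms / mean_cells_per_site   # -> array([ 80., 120.])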
    n2 = modelnames[-2]
elif 1:
    batch = 289
    modelnames = [
        "ozgf.fs100.ch18-ld-sev_dlog-wc.18x3-fir.3x15-lvl.1-dexp.1_init-basic",
        "ozgf.fs100.ch18-ld-sev_dlog-wc.18x3-stp.3-fir.3x15-lvl.1-dexp.1_init-basic",
    ]
    #modelnames = ["ozgf.fs100.ch18-ld-sev_dlog-wc.18x3.g-fir.3x15-lvl.1-dexp.1_init-basic",
    #              "ozgf.fs100.ch18-ld-sev_dlog-wc.18x3.g-stp.3-fir.3x15-lvl.1-dexp.1_init-basic"]
    n1 = modelnames[0]
    n2 = modelnames[1]
    fileprefix = "fig9.NAT"
    xc_range = [-0.05, 1.1]

df = nd.batch_comp(batch, modelnames, stat='r_ceiling')
df_r = nd.batch_comp(batch, modelnames, stat='r_test')
df_e = nd.batch_comp(batch, modelnames, stat='se_test')
cellcount = len(df)

beta1 = df[n1]
beta2 = df[n2]
beta1_test = df_r[n1]
beta2_test = df_r[n2]
se1 = df_e[n1]
se2 = df_e[n2]
beta1[beta1 > 1] = 1
beta2[beta2 > 1] = 1
def sparseness_by_batch(batch, modelnames=None, pop_reference_model=ALL_FAMILY_POP[2],
                        save_path=None, force_regenerate=False, rec=None):
    if (save_path is not None) and (not force_regenerate):
        if os.path.exists(save_path):
            sparseness_data = pd.read_csv(save_path, index_col=0)
            return sparseness_data

    if modelnames is None:
        modelnames = [ALL_FAMILY_MODELS[0], ALL_FAMILY_MODELS[2], ALL_FAMILY_MODELS[3]]

    d = nd.batch_comp(batch, modelnames)
    cellids = get_significant_cells(batch, SIG_TEST_MODELS, as_list=True)

    if rec is None:
        # evaluate the reference model once to get a masked validation recording
        xf0, ctx0 = load_model_xform(cellids[0], batch, pop_reference_model,
                                     eval_model=True)
        val = ctx0['val']
        val = val.apply_mask()
        del ctx0
    else:
        val = rec

    rows = []
    for j, cellid in enumerate(cellids):
        r_test_all = d.loc[cellid].values.max()
        if r_test_all > 0:
            for i, m in enumerate(modelnames):
                xf, ctx = load_model_xform(cellid, batch, m, eval_model=False)
                val_ = ctx['modelspec'].evaluate(val)
                fs = val['resp'].fs
                this_resp = val['resp'].extract_channels(chans=[cellid])._data[0, :] * fs
                this_pred = val_['pred']._data[0, :] * fs

                S_r, S_p, c = sparseness(this_resp, this_pred, verbose=False)
                print(f"{j} {cellid} r_test={c:.3f}"
                      f" orig r_test={d.loc[cellid].values[i]:.3f}"
                      f" S_r={S_r:.3f} S_p={S_p:.3f}")
                rows.append({'cellid': cellid, 'S_r': S_r, 'S_p': S_p,
                             'r_test': c, 'r_test_all': r_test_all, 'model': i})

    # DataFrame.append is deprecated; build the frame from a list of rows instead
    sparseness_data = pd.DataFrame(rows)

    if save_path is not None:
        print(f"Saving sparseness_data to {save_path}")
        sparseness_data.to_csv(save_path)

    return sparseness_data