def on_view_model(self):
    """Event handler for the "view model" button.

    Shows a temporary status-bar message, loads and evaluates the fit named by
    ``self.cellid`` / ``self.batch`` / ``self.modelname``, and hands the
    resulting context and xforms spec to ``self.launch_model_browser``.
    """
    status_timeout = 5000
    self.statusbar.showMessage('Loading model...', status_timeout)
    loaded = xform_helper.load_model_xform(
        self.cellid, self.batch, self.modelname, eval_model=True)
    xfspec, ctx = loaded
    self.launch_model_browser(ctx, xfspec)
def stp_sigmoid_pred_matched(cellid, batch, modelname, LN, include_phi=True):
    """Scatter an LN model's input-side prediction against an STP model's output.

    x axis: prediction of the ``LN`` model evaluated with ``stop=-1`` (no stp).
    y axis: full validation prediction of ``modelname`` (with stp).
    Color: ``(after - before) / (after + before)``, where ``before``/``after``
    are the channel-mean predictions of ``modelname`` evaluated up to, and one
    module past, its ``stp`` module.

    When ``include_phi`` is true the stp module's phi values are written to
    the right of the axes.

    Returns the matplotlib figure.
    """
    _, stp_ctx = xhelp.load_model_xform(cellid, batch, modelname)
    _, ln_ctx = xhelp.load_model_xform(cellid, batch, LN)

    stp_mspec = stp_ctx['modelspec']
    stp_mspec.recording = stp_ctx['val']
    stp_val = stp_ctx['val'].apply_mask()

    ln_mspec = ln_ctx['modelspec']
    ln_mspec.recording = ln_ctx['val']
    ln_val = ln_ctx['val'].apply_mask()

    def _flat_pred(rec):
        return rec['pred'].as_continuous().flatten()

    def _channel_mean_pred(rec):
        return rec['pred'].as_continuous().mean(axis=0).flatten()

    y_with_stp = _flat_pred(stp_val)
    x_no_stp = _flat_pred(ms.evaluate(ln_val, ln_mspec, stop=-1))

    stp_pos = find_module('stp', stp_mspec)
    before = _channel_mean_pred(ms.evaluate(stp_val, stp_mspec, stop=stp_pos))
    after = _channel_mean_pred(ms.evaluate(stp_val, stp_mspec, stop=stp_pos + 1))
    stp_effect = (after - before) / (after + before)

    fig = plt.figure()
    colormap = plt.get_cmap('plasma')
    plt.scatter(x_no_stp, y_with_stp, c=stp_effect, s=2, alpha=0.75,
                cmap=colormap)
    plt.title(cellid)
    plt.xlabel('pred in (no stp)')
    plt.ylabel('pred out (with stp)')

    if include_phi:
        phi_lines = ['%s: %.4E' % item
                     for item in stp_mspec.phi[stp_pos].items()]
        fig.text(0.775, 0.9, '\n'.join(phi_lines), va='top', ha='left')
        plt.subplots_adjust(right=0.775, left=0.075)

    plt.colorbar()
    return fig
def load(self):
    """Load the fit named by the dialog's batch/cell/model line edits.

    Prints what is being loaded, closes the dialog, and returns the
    ``(xfspec, ctx)`` pair from ``xhelp.load_model_xform``.
    """
    batch, cellid, modelname = (
        self.batchLE.text(), self.cellLE.text(), self.modelLE.text())
    print(f'Loading {cellid}/{batch}/{modelname}')
    result = xhelp.load_model_xform(cellid, batch, modelname)
    self.close()
    xfspec, ctx = result
    return xfspec, ctx
def get_model_preds(cellid, batch, modelname):
    """Load a fit without evaluating it, then run ``xforms.evaluate(..., stop=-1)``.

    Returns ``(xf, ctx)`` where ``ctx`` is the context returned by
    ``xforms.evaluate``; its second return value is discarded.
    """
    xf, loaded_ctx = xhelp.load_model_xform(
        cellid, batch, modelname, eval_model=False)
    evaluated_ctx, _ = xforms.evaluate(xf, loaded_ctx, stop=-1)
    return xf, evaluated_ctx
def dynamic_sigmoid_differences(batch, modelname, hist_bins=60, test_limit=None,
                                save_path=None, load_path=None,
                                use_quartiles=False, avg_bin_count=20):
    """Histogram, across cells, of low- vs high-contrast dynamic-sigmoid output.

    For each cell of ``batch`` (only the first ``test_limit`` cells when given):
    the prediction before the final module (``stop=-1``) is binned into
    ``avg_bin_count`` bins with ``_binned_xvar``; inside each bin the final
    prediction is averaged separately over low-contrast and high-contrast
    samples (split on ``ctpred``) with ``_binned_yavg``.  The per-cell value
    is ``nanmean((low - high) / (|low| + |high|))``.

    Parameters
    ----------
    hist_bins : int
        Number of histogram bins.
    test_limit : int or None
        Limit on the number of cells processed.
    save_path : path or None
        If given (and ``load_path`` is None), the ratios are saved with np.save.
    load_path : path or None
        If given, ratios are read with np.load instead of being computed.
    use_quartiles : bool
        False: split ``ctpred`` at its median.  True: compare the bottom
        quartile (``<= 25th percentile``) with the top quartile
        (``>= 75th percentile``).
    avg_bin_count : int
        Number of x bins used before averaging.

    Returns
    -------
    numpy.ndarray
        The per-cell ratios that were plotted.
    """
    if load_path is None:
        cellids = nd.get_batch_cells(batch, as_list=True)
        ratios = []
        for cellid in cellids[:test_limit]:
            xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
            val = ctx['val'].apply_mask()
            ctpred = val['ctpred'].as_continuous().flatten()
            pred_after = val['pred'].as_continuous().flatten()
            val_before = ms.evaluate(val, ctx['modelspec'], stop=-1)
            pred_before = val_before['pred'].as_continuous().flatten()

            low_mask, high_mask = _contrast_masks(ctpred, use_quartiles)

            # Bin on the x variable so low- and high-contrast outputs are
            # compared at matching inputs (their x values don't overlap).
            _, bin_masks = _binned_xvar(pred_before, avg_bin_count)
            low_avg = _binned_yavg(pred_after, low_mask, bin_masks)
            high_avg = _binned_yavg(pred_after, high_mask, bin_masks)
            ratios.append(np.nanmean(
                (low_avg - high_avg) / (np.abs(low_avg) + np.abs(high_avg))))

        ratios = np.array(ratios)
        if save_path is not None:
            np.save(save_path, ratios)
    else:
        ratios = np.load(load_path)

    plt.figure()
    plt.hist(ratios, bins=hist_bins, color=[wsu_gray_light], edgecolor='black',
             linewidth=1)
    plt.xlabel('(low - high)/(|low| + |high|)')
    plt.ylabel('cell count')
    plt.title("difference of low-contrast output and high-contrast output\n"
              "positive means low-contrast has higher firing rate on average")
    return ratios


def _contrast_masks(ctpred, use_quartiles):
    """Return boolean ``(low_mask, high_mask)`` arrays splitting ``ctpred``."""
    if use_quartiles:
        # Symmetric tails: bottom quartile vs top quartile.
        q25, q75 = np.nanpercentile(ctpred, [25, 75])
        return ctpred <= q25, ctpred >= q75
    median_ct = np.nanmedian(ctpred)
    return ctpred < median_ct, ctpred >= median_ct
def dynamic_sigmoid_range(cellid, batch, modelname, plot=True,
                          show_min_max=True):
    """Evaluate the final (dynamic sigmoid) module at the extremes of ``ctpred``.

    The final module's phi is split into base parameters ``<name>`` and their
    ``<name>_mod`` counterparts; a parameter with no ``_mod`` entry is held
    fixed.  Each parameter is set to ``theta + (theta_mod - theta) * c`` for
    ``c`` = min and max of ``ctpred``.  ``_double_exponential`` is then
    applied to the prediction from before the final module (``stop=-1``).

    Parameters
    ----------
    plot : bool
        Scatter both output curves (min-ctpred in blue, max-ctpred in red).
    show_min_max : bool
        When plotting, circle the samples where ``ctpred`` is minimal/maximal.

    Returns
    -------
    (pred_before_dsig, low_out, high_out)
        Flattened input prediction and the two output curves.
    """
    _, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    modelspec = ctx['modelspec']
    modelspec.recording = ctx['val']
    val = ctx['val'].apply_mask()
    ctpred = val['ctpred'].as_continuous().flatten()
    val_before_dsig = ms.evaluate(val, modelspec, stop=-1)
    pred_before_dsig = val_before_dsig['pred'].as_continuous().flatten()

    dsig_phi = modelspec[-1]['phi']
    thetas = {k: v for k, v in dsig_phi.items() if '_mod' not in k}
    theta_mods = {k[:-4]: v for k, v in dsig_phi.items() if '_mod' in k}

    # NaN-aware extremes so a stray NaN in ctpred doesn't poison both curves.
    ctmin_idx = np.nanargmin(ctpred)
    ctmax_idx = np.nanargmax(ctpred)
    ctmin_val = ctpred[ctmin_idx]
    ctmax_val = ctpred[ctmax_idx]

    # Pair each parameter with its _mod value by name, not by sorted position.
    lows, highs = {}, {}
    for k in sorted(thetas):
        t = thetas[k]
        t_m = theta_mods.get(k, t)
        lows[k] = t + (t_m - t) * ctmin_val
        highs[k] = t + (t_m - t) * ctmax_val

    low_out = _double_exponential(pred_before_dsig, **lows).flatten()
    high_out = _double_exponential(pred_before_dsig, **highs).flatten()

    if plot:
        plt.figure()
        plt.scatter(pred_before_dsig, low_out, color='blue', s=0.7, alpha=0.6)
        plt.scatter(pred_before_dsig, high_out, color='red', s=0.7, alpha=0.6)
        if show_min_max:
            # facecolors='none' gives hollow markers; None would fill them.
            plt.scatter(pred_before_dsig[ctmin_idx], low_out[ctmin_idx],
                        facecolors='none', edgecolors='blue', s=60)
            plt.scatter(pred_before_dsig[ctmax_idx], high_out[ctmax_idx],
                        facecolors='none', edgecolors='red', s=60)

    return pred_before_dsig, low_out, high_out
def _get_plot_contexts(cellid, batch, gc, stp, LN, combined):
    """Load and evaluate four fits of one cell and return only their contexts.

    Returns ``(gc_ctx, stp_ctx, LN_ctx, combined_ctx)`` in argument order;
    the xforms specs are discarded.
    """
    return tuple(
        xhelp.load_model_xform(cellid, batch, name, eval_model=True)[1]
        for name in (gc, stp, LN, combined)
    )
def preview(self):
    """Show the saved figure of the fit named in the dialog's line edits.

    The fit is loaded without evaluation; its ``meta['figurefile']`` path is
    printed and passed to ``self.im_window``, which is then shown.
    """
    batch, cellid, modelname = (
        self.batchLE.text(), self.cellLE.text(), self.modelLE.text())
    print(f"Viewing {batch},{cellid},{modelname}")
    _, ctx = xhelp.load_model_xform(cellid, batch, modelname, eval_model=False)
    figure_path = ctx['modelspec'].meta['figurefile']
    print('Figure file: ' + figure_path)
    viewer = self.im_window
    viewer.update_imagepath(imagepath=figure_path)
    viewer.show()
def generate_psth_correlations_pop(batch, modelnames, save_path=None,
                                   load_path=None):
    """Correlate validation predictions of three population models, per cell.

    If ``load_path`` is given, the DataFrame pickled there is returned and
    nothing else is computed.

    Otherwise each model in ``modelnames`` is loaded and evaluated once, using
    the first significant cell of ``batch`` as the cellid.  For every
    significant cell, the Pearson correlation between each pair of model
    predictions is computed.  ``modelnames`` is expected in the order
    conv2d, conv1dx2+d, LN_pop.
    TODO: if EQUIVALENCE_MODELS_POP changes, the pairs below must change too.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``cellid`` with columns ``c2d_c1d``, ``c2d_LN`` and
        ``c1d_LN``.  It is pickled to ``save_path`` once all cells are done
        when ``save_path`` is given.
    """
    if load_path is not None:
        return pd.read_pickle(load_path)

    significant_cells = get_significant_cells(batch, SIG_TEST_MODELS,
                                              as_list=True)

    # Load and evaluate each model, then pull out its masked validation pred.
    contexts = [
        xhelp.load_model_xform(significant_cells[0], batch, m,
                               eval_model=True)[1]
        for m in modelnames
    ]
    pred_signals = [c['val'].apply_mask()['pred'] for c in contexts]

    # Not every model stores channel names on its pred; borrow the first's.
    shared_chans = pred_signals[0].chans
    for sig in pred_signals[1:]:
        sig.chans = shared_chans

    pred_arrays = [
        sig.extract_channels(significant_cells).as_continuous()
        for sig in pred_signals
    ]

    # column name -> (model index a, model index b); 0: conv2d,
    # 1: conv1dx2+d, 2: LN_pop
    model_pairs = {'c2d_c1d': (0, 1), 'c2d_LN': (0, 2), 'c1d_LN': (1, 2)}
    table = {'cellid': list(significant_cells)}
    for column, (a, b) in model_pairs.items():
        table[column] = [
            np.corrcoef(pred_arrays[a][row], pred_arrays[b][row])[0, 1]
            for row in range(len(significant_cells))
        ]

    corrs = pd.DataFrame.from_dict(table)
    corrs.set_index('cellid', inplace=True)
    if save_path is not None:
        corrs.to_pickle(save_path)
    return corrs
def pred_comparison(cellid, batch, models, subtract=None):
    '''Overlay the validation predictions of several models for one cell.

    Convention for model order: gc, stp, ln, gc+stp

    subtract: None, or string modelname to subtract preds (like LN); that
        model's validation prediction is subtracted from every model's.

    Returns the list of plotted arrays (``as_continuous().T`` of each pred).
    '''
    subtract_pred = None
    if subtract is not None:
        _, sub_ctx = xhelp.load_model_xform(cellid, batch, subtract)
        # Read the pred from the 'val' recording, exactly like the per-model
        # preds below, so the subtraction lines up.
        subtract_pred = sub_ctx['val']['pred'].as_continuous().T

    preds = []
    for m in models:
        _, ctx = xhelp.load_model_xform(cellid, batch, m)
        p = ctx['val']['pred'].as_continuous().T
        if subtract_pred is not None:
            p = p - subtract_pred
        preds.append(p)

    for p in preds:
        plt.plot(p)
    # Legend labels are the model names, not the prediction arrays.
    plt.legend(list(models))
    return preds
def get_strf(cellid, batch, modelname):
    """Return the STRF of a fit as ``wc_coefs.T @ fir_coefs``.

    The fit is loaded without evaluation; the weight-channel and FIR
    coefficients are read from module index 0 of each kind via
    ``_get_wc_coefficients`` / ``_get_fir_coefficients``.
    """
    loaded_ctx = xhelp.load_model_xform(cellid, batch, modelname,
                                        eval_model=False)[1]
    mspec = loaded_ctx['modelspec']
    wc = _get_wc_coefficients(mspec, idx=0)
    fir = _get_fir_coefficients(mspec, idx=0)
    return wc.T @ fir
def test_comparison(
        cellid="TAR010c-15-5",
        batch=289,
        modelname1="ozgf.fs100.ch18.pop-loadpop.cc20.bth-norm.l1-popev_wc.18x30.g-fir.1x12x30-relu.30-wc.30x30.z-relu.30-wc.30xR.z-lvl.R-dexp.R_tfinit.n.lr1e3.et3-newtf.n.lr1e4-popspc",
        modelname2="ozgf.fs100.ch18.pop-loadpop.cc20.bth-norm.l1-popev_wc.18x40.g-stp.40.q.s-fir.1x12x40-relu.40-wc.40x30.z-relu.30-wc.30xR.z-lvl.R-dexp.R_tfinit.n.lr1e3.et3-newtf.n.lr1e4-popspc",
        modelname_ref="ozgf.fs100.ch18-ld-sev_dlog-wc.18x3.g-fir.3x15-lvl.1-dexp.1_init-basic",
        recname="val"):
    """Load two models and a reference model for one cell.

    Returns
    -------
    tuple
        ``(modelspec1, modelspec2, modelspec_ref, rec1, rec2, rec_ref)``,
        where each ``rec`` is ``ctx[recname]`` of the corresponding fit.
    """
    contexts = [
        load_model_xform(cellid, batch, name)[1]
        for name in (modelname1, modelname2, modelname_ref)
    ]
    modelspecs = [c['modelspec'] for c in contexts]
    recordings = [c[recname] for c in contexts]
    return (*modelspecs, *recordings)
def gd_ratio(cellid, batch, modelname):
    """Return ``kappa_mod / kappa`` of the fit's ``dynamic_sigmoid`` module.

    The fit is loaded without evaluation; only its modelspec is inspected.
    """
    mspec = xhelp.load_model_xform(cellid, batch, modelname,
                                   eval_model=False)[1]['modelspec']
    dsig_phi = mspec[find_module('dynamic_sigmoid', mspec)]['phi']
    kappa_mod = dsig_phi['kappa_mod']
    kappa = dsig_phi['kappa']
    return kappa_mod / kappa
def sparseness_example(batch, cellid, modelname, rec=None, ax=None, tag="",
                       colors=None):
    """Plot and print response vs prediction sparseness for one cell.

    Parameters
    ----------
    rec : recording or None
        If None, the fit is loaded with ``eval_model=True`` and its masked
        ``val`` recording is used.  Otherwise the fit is loaded without
        evaluation and ``rec`` is used as given.
    ax : matplotlib axes or None
        Axes passed to ``sparseness``; a new figure is created when None.
    tag, colors :
        Passed through to ``sparseness``.

    Returns
    -------
    (S_r, S_p)
        The first two values returned by ``sparseness``.
    """
    if rec is None:
        xf, ctx = load_model_xform(cellid, batch, modelname, eval_model=True)
        val = ctx['val']
        val = val.apply_mask()
    else:
        xf, ctx = load_model_xform(cellid, batch, modelname, eval_model=False)
        val = rec
    val_ = ctx['modelspec'].evaluate(val)

    # Both traces are multiplied by the response sampling rate fs.
    fs = val['resp'].fs
    this_resp = val['resp'].extract_channels(chans=[cellid])._data[0, :] * fs
    # NOTE(review): pred row 0 is used while resp is selected by cellid; for
    # multi-channel (population) preds confirm row 0 is this cell.
    this_pred = val_['pred']._data[0, :] * fs

    if ax is None:
        _, ax = plt.subplots()
    # sparseness() also returns c, which is printed below as r_test.
    S_r, S_p, c = sparseness(this_resp, this_pred, cellid=cellid, verbose=True,
                             ax=ax, colors=colors, tag=tag)
    print(
        f"{cellid} r_test={c:.3f} orig r_test={ctx['modelspec'].meta['r_test'][0,0]:.3f} S_r={S_r:.3f} S_p={S_p:.3f}"
    )
    return S_r, S_p
def load_models(cell, batch, models, check_db=True, site=None, site_cache=None):
    '''Load standardized psth from each model, error if not all fits exist.

    Parameters
    ----------
    cell : cellid whose prediction channel is extracted.
    batch : batch number (formatted with %d in the error message).
    models : sequence of modelnames; all are required.
    check_db : bool
        Before trying to load, check the database (``nd.batch_comp``) that a
        result exists for every model.  Set False if you know model results
        are not stored in the DB but exist in file storage.
    site, site_cache : optional
        When ``site_cache`` is a dict, contexts are cached in
        ``site_cache[site][model]`` and reused if already present.  The
        ``site`` entry is created on first use.

    Returns
    -------
    list of numpy arrays
        Each model's masked validation prediction for ``cell``, keeping only
        samples where every model's prediction is finite.

    Raises
    ------
    ValueError
        If ``check_db`` is set and any model has no stored result.
    '''
    if check_db:
        df = nd.batch_comp(batch=batch, modelnames=models, cellids=[cell])
        if np.sum(df.isna().values) > 0:
            # at least one cell wasn't fit (or at least not stored in db)
            # so skip trying to load any of them since all are required.
            raise ValueError('Not all results exist for: %s, %d' % (cell, batch))

    # Create the per-site cache on demand instead of raising KeyError.
    cache = None if site_cache is None else site_cache.setdefault(site, {})

    ctxs = []
    for model in models:
        if cache is not None and model in cache:
            log.info("Site %s is cached, skipping load...", site)
            ctx = cache[model]
        else:
            _, ctx = xhelp.load_model_xform(cell, batch, model)
            if cache is not None:
                cache[model] = ctx
        ctxs.append(ctx)

    for ctx in ctxs:
        # Some preds lack channel names; copy them from resp so
        # extract_channels([cell]) works below.
        if ctx['val']['pred'].chans is None:
            ctx['val']['pred'].chans = copy(ctx['val']['resp'].chans)

    # Pull out model predictions and remove times with nan for at least 1 model
    preds = [ctx['val'].apply_mask()['pred'].extract_channels([cell]).as_continuous()
             for ctx in ctxs]
    ff = np.isfinite(preds[0])
    for pred in preds[1:]:
        ff &= np.isfinite(pred)
    return [pred[ff] for pred in preds]
def cf_batch_rank1(batch, modelname, save_path=None, load_path=None, f_low=0.2,
                   f_high=20, nf=18, test_limit=None):
    """Estimate each cell's characteristic frequency (CF) from a rank-1 fit.

    The CF bin is taken from ``modelspec.phi[1]['mean']`` (the gaussian
    weight-channel coefficients), scaled by ``nf`` and clipped to
    ``[0, nf - 1]``.  The bin is mapped onto ``nf`` log-spaced frequencies
    between ``f_low`` and ``f_high``.

    Parameters
    ----------
    load_path : path or None
        If given, the pickled DataFrame there is returned and nothing else
        is computed.
    save_path : path or None
        If given, the resulting DataFrame is pickled there.
    test_limit : int or None
        Only the first ``test_limit`` cells of the batch are processed.

    Returns
    -------
    pandas.DataFrame
        Indexed by ``cellid`` with columns ``cf`` and ``cf_bin``.  Cells whose
        fit fails to load are skipped.
    """
    if load_path is not None:
        return pd.read_pickle(load_path)

    # Frequency axis does not depend on the cell; compute it once.
    khz_freqs = np.logspace(np.log(f_low), np.log(f_high), num=nf, base=np.e)

    cellids = []
    cfs = []
    cf_bins = []
    for cellid in nd.get_batch_cells(batch, as_list=True)[:test_limit]:
        try:
            _, ctx = xhelp.load_model_xform(cellid, batch, modelname,
                                            eval_model=False)
        except Exception:
            # cell probably not fit for this model
            continue
        modelspec = ctx['modelspec']
        # mult by nf b/c x vals in gaussian coeffs module are divided by
        # number of channels; clip b/c means outside the bin range are
        # allowed.  (np.asscalar was removed in NumPy 1.23; .item() replaces it.)
        raw_mean = np.asarray(modelspec.phi[1]['mean']).item()
        mean = min(max(0, raw_mean * nf), nf - 1)
        cf_bin = int(round(mean))
        cellids.append(cellid)
        cfs.append(khz_freqs[cf_bin])
        cf_bins.append(cf_bin)

    results = {'cellid': cellids, 'cf': cfs, 'cf_bin': cf_bins}
    df = pd.DataFrame.from_dict(results)
    df.set_index('cellid', inplace=True)
    if save_path is not None:
        df.to_pickle(save_path)
    return df
def dynamic_sigmoid_distribution(cellid, batch, modelname, sample_every=10,
                                 alpha=0.1):
    """Overlay final-module output curves for sampled ``ctpred`` values.

    Parameters of the final module are interpolated as
    ``theta + (theta_mod - theta) * c``; a parameter with no ``<name>_mod``
    entry stays fixed.  One black curve is drawn per ``sample_every``-th time
    sample of ``ctpred``.  Red and blue curves mark ``nanmax(ctpred)`` and
    ``nanmin(ctpred)``.

    Returns the matplotlib figure.
    """
    _, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    modelspec = ctx['modelspec']
    val = ctx['val'].apply_mask()
    modelspec.recording = val
    val_before_dsig = ms.evaluate(val, modelspec, stop=-1)
    pred_before_dsig = val_before_dsig['pred'].as_continuous().T
    ctpred = val_before_dsig['ctpred'].as_continuous().T

    dsig_phi = modelspec[-1]['phi']
    thetas = {k: v for k, v in dsig_phi.items() if '_mod' not in k}
    theta_mods = {k[:-4]: v for k, v in dsig_phi.items() if '_mod' in k}
    keys = sorted(thetas)

    def _phi_at(c):
        # Pair by name so a missing _mod entry can't shift the others.
        return {k: thetas[k] + (theta_mods.get(k, thetas[k]) - thetas[k]) * c
                for k in keys}

    fig = plt.figure()
    # Only whole strides are sampled, so every index is in range.
    n_samples = len(pred_before_dsig) // sample_every
    for idx in range(0, n_samples * sample_every, sample_every):
        y = _double_exponential(pred_before_dsig, **_phi_at(ctpred[idx]))
        plt.scatter(pred_before_dsig, y, color='black', alpha=alpha, s=0.01)

    max_out = _double_exponential(pred_before_dsig,
                                  **_phi_at(np.nanmax(ctpred)))
    min_out = _double_exponential(pred_before_dsig,
                                  **_phi_at(np.nanmin(ctpred)))
    plt.scatter(pred_before_dsig, max_out, color='red', s=0.1)
    plt.scatter(pred_before_dsig, min_out, color='blue', s=0.1)
    return fig
def on_selection_changed(self, event=None):
    """Show the saved figure for the currently selected cell and model.

    The fit is loaded without evaluation and its ``meta['figurefile']`` is
    passed to ``self.im_window``.  Any error is reported to stdout with a
    traceback instead of being raised, so the handler never propagates an
    exception.
    """
    import traceback

    print('on_selection_changed')
    try:
        cellid = self.cells.currentItem().text()
        modelname = self.models.currentItem().text()
        batch = self.batch
        print('Selected cell(s): ' + cellid)
        print('Selected model(s): ' + modelname)
        print('Selected batch: ' + str(batch))
        xf, ctx = xhelp.load_model_xform(cellid, batch, modelname,
                                         eval_model=False)
        figurefile = ctx['modelspec'].meta['figurefile']
        print('Figure file: ' + figurefile)
        self.im_window.update_imagepath(imagepath=figurefile)
    except Exception:
        # Best-effort UI handler: report, but let KeyboardInterrupt/SystemExit
        # through (a bare except would swallow them too).
        print('error?')
        traceback.print_exc()
def load_existing_pred(cellid=None, siteid=None, batch=None, modelname_existing=None, **kwargs):
    """
    designed to be called by xforms keyword loadpred

    cellid/siteid - one or the other required; if only siteid is given, the
        first cellid in the batch matching "<siteid>%" is used; if cellid is a
        list, its first element is used
    batch - required
    default modelname_existing = "psth.fs4.pup-ld-st.pup-hrc-psthfr-aev_sdexp2.SxR_newtf.n.lr1e4.cont.et5.i50000"

    Signals present in the loaded 'val' recording but missing from 'rec' are
    copied into 'rec'.  Makes new signal 'pred0' from a copy of the loaded
    'pred', returns in updated rec.

    returns ctx-compatible dict {'rec': nems.Recording, 'input_name': 'pred0'}

    raises ValueError if batch is missing, if both cellid and siteid are
    missing, or if no cell in the batch matches siteid.
    """
    if batch is None:
        raise ValueError("must specify cellid/siteid and batch")

    if cellid is None:
        if siteid is None:
            raise ValueError("must specify cellid/siteid and batch")
        d = nd.pd_query(
            "SELECT batch,cellid FROM Batches WHERE batch=%s AND cellid like %s",
            (batch, siteid + "%",))
        if len(d) == 0:
            raise ValueError("no cellid matching siteid %s in batch %s"
                             % (siteid, batch))
        cellid = d['cellid'].values[0]
    elif isinstance(cellid, list):
        cellid = cellid[0]

    if modelname_existing is None:
        modelname_existing = "psth.fs4.pup-ld-st.pup-hrc-psthfr-aev_sdexp2.SxR_newtf.n.lr1e4.cont.et5.i50000"

    _, ctx = xhelp.load_model_xform(cellid, batch, modelname_existing)
    rec = ctx['rec']
    val = ctx['val']
    for k in val.signals.keys():
        if k not in rec.signals.keys():
            rec.signals[k] = val.signals[k].copy()

    s = rec['pred'].copy()
    s.name = 'pred0'
    rec.add_signal(s)

    return {'rec': rec, 'input_name': 'pred0'}
def dynamic_sigmoid_pred_matched(cellid, batch, modelname, include_phi=True):
    """Scatter the final module's input against its output, colored by ctpred.

    x axis: prediction evaluated with ``stop=-1``; y axis: full validation
    prediction; color: ``ctpred`` (NaNs replaced by 0).  When ``include_phi``
    is true, the final module's phi and the per-parameter weights
    ``<name>_mod - <name>`` are written to the right of the axes.

    Returns the matplotlib figure.
    """
    _, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    modelspec = ctx['modelspec']
    modelspec.recording = ctx['val']
    val = ctx['val'].apply_mask()
    ctpred = val['ctpred'].as_continuous().flatten()
    # HACK
    # this really shouldn't happen.. but for some reason some of the
    # batch 263 cells are getting nans, so temporary fix.
    ctpred[np.isnan(ctpred)] = 0
    pred_after_dsig = val['pred'].as_continuous().flatten()
    val_before_dsig = ms.evaluate(val, modelspec, stop=-1)
    pred_before_dsig = val_before_dsig['pred'].as_continuous().flatten()

    fig = plt.figure(figsize=(12, 7))
    plasma = plt.get_cmap('plasma')
    plt.scatter(pred_before_dsig, pred_after_dsig, c=ctpred, s=2, alpha=0.75,
                cmap=plasma)
    plt.title(cellid)
    plt.xlabel('pred in')
    plt.ylabel('pred out')

    if include_phi:
        dsig_phi = modelspec.phi[-1]
        phi_string = '\n'.join(
            ['%s: %.4E' % (k, v) for k, v in dsig_phi.items()])
        # Pair each parameter with its "<name>_mod" entry by name; positional
        # pairing (every other key) misaligns when a _mod entry is absent.
        weights = {
            k: dsig_phi[k + '_mod'] - dsig_phi[k]
            for k in dsig_phi
            if not k.endswith('_mod') and k + '_mod' in dsig_phi
        }
        weights_string = 'weights:\n' + '\n'.join(
            ['%s: %.4E' % (k, v) for k, v in weights.items()])
        fig.text(0.775, 0.9, phi_string, va='top', ha='left')
        fig.text(0.775, 0.1, weights_string, va='bottom', ha='left')
        plt.subplots_adjust(right=0.775, left=0.075)

    plt.colorbar()
    return fig
def random_condition_convergence(cellid, batch, modelname,
                                 separate_figures=True):
    """Plot how each parameter moved from its initial to its fitted value.

    Reads ``meta['random_conditions']`` (a sequence of ``(initial, final)``
    phi dicts) and ``meta['best_random_idx']`` from the fitted modelspec.
    Each condition is drawn as a line from x=0 (initial) to x=1 (final).

    separate_figures : bool
        True: one figure per parameter; condition 0 is blue and labeled
        'mean', the best condition is red and labeled 'best', the rest are
        black.  False: a single figure with one random color per parameter.
    """
    _, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    meta = ctx['modelspec'].meta
    rcs = meta['random_conditions']
    best_idx = meta['best_random_idx']
    keys = list(rcs[0][0].keys())

    if not separate_figures:
        plt.figure()
        colors = [np.random.rand(3, ) for k in keys]
        for initial, final in rcs:
            starts = list(initial.values())
            ends = list(final.values())
            for i, k in enumerate(keys):
                # np.asscalar was removed in NumPy 1.23; .item() replaces it.
                plt.plot([0, 1],
                         [np.asarray(starts[i]).item(),
                          np.asarray(ends[i]).item()],
                         c=colors[i])
        plt.legend(keys)
    else:
        for k in keys:
            plt.figure()
            for i, (initial, final) in enumerate(rcs):
                start = initial[k]
                end = final[k]
                if i == 0:
                    color = 'blue'
                    label = 'mean'
                elif i == best_idx:
                    color = 'red'
                    label = 'best'
                else:
                    color = 'black'
                    label = None
                plt.plot([0, 1], np.concatenate((start, end)), color=color,
                         label=label)
            plt.legend()
            plt.title("%s, best_idx: %d" % (k, best_idx))
def compare_sims(start=0, end=None):
    """Plot toy GC, STP and LN simulations over the default cell's stimulus.

    The three toy models are evaluated on the validation recording of the
    default cell/batch/model.  Their predictions are drawn over the stimulus
    spectrogram, whose image occupies y in [2.1, 3.4].  ``start``/``end`` set
    the x limits; ``end`` defaults to the stimulus length.

    Returns the matplotlib figure.
    """
    # TODO: set up to compare on synthetic stimuli
    _, ctx = xhelp.load_model_xform(_DEFAULT_CELL, _DEFAULT_BATCH,
                                    _DEFAULT_MODEL)
    val = ctx['val']

    gc_sim = build_toy_gc_cell(0, 0, 0, -0.5)  # base, amp, shift, kappa
    gc_sim[-2]['fn_kwargs']['compute_contrast'] = True
    stp_sim = build_toy_stp_cell([0, 0.1], [0.08, 0.08])  # u, tau
    LN_sim = build_toy_LN_cell()

    stim = val['stim'].as_continuous()

    def _simulated_psth(sim):
        sim_val = sim.evaluate(val)
        sim.recording = sim_val
        return sim_val['pred'].as_continuous().flatten()

    gc_psth, stp_psth, LN_psth = [
        _simulated_psth(sim) for sim in (gc_sim, stp_sim, LN_sim)
    ]

    fig = plt.figure(figsize=wide_fig)
    n_bins = stim.shape[-1]
    stop = n_bins if end is None else end
    plt.imshow(stim, aspect='auto', cmap=spectrogram_cmap, origin='lower',
               extent=(0, n_bins, 2.1, 3.4))

    thin = 0.75
    thick = thin * 1.25
    plt.plot(LN_psth, color=model_colors['LN'], linewidth=thin)
    plt.plot(gc_psth, color=model_colors['gc'], linewidth=thick)
    plt.plot(stp_psth, color=model_colors['stp'], alpha=0.75, linewidth=thick)
    plt.ylim(-0.1, 3.4)
    plt.xlim(start, stop)
    ax_remove_box(plt.gca())
    return fig
def pop_model_example(figsize=None):
    """Plot layer outputs of a population model restricted to 50 channels.

    Loads ``ALL_FAMILY_POP[2]`` for batch 322 and keeps output channels
    ``50:100``.  Modules 8, 9 and 10 (coefficients, level, and
    base/amplitude/shift/kappa) are sliced, along with resp/pred and the
    cellid metadata.  ``r_ceiling`` is set to ``1.1 * r_test``.  Layer
    outputs are then plotted with the high-resolution stimulus as
    ``altstim``.

    Returns the figure from ``pop_models.plot_layer_outputs``.
    """
    tctx = load_high_res_stim()
    batch = 322
    cellid = "DRX006b-128-2"
    modelname = ALL_FAMILY_POP[2]
    _, ctx = load_model_xform(cellid, batch, modelname)
    modelspec = ctx['modelspec'].copy()
    val = ctx['val'].apply_mask()

    # extract a subset of channels, since 9xx is too many
    N = 50
    rr = slice(N, N * 2, 1)
    modelspec.phi[8]['coefficients'] = modelspec.phi[8]['coefficients'][rr, :]
    modelspec.phi[9]['level'] = modelspec.phi[9]['level'][rr]
    for param in ('base', 'amplitude', 'shift', 'kappa'):
        modelspec.phi[10][param] = modelspec.phi[10][param][rr]

    for sig_name in ('resp', 'pred'):
        val[sig_name] = val[sig_name]._modified_copy(data=val[sig_name][rr, :])

    modelspec.meta['cellid'] = modelspec.meta['cellids'][N]
    modelspec.meta['cellids'] = modelspec.meta['cellids'][rr]
    modelspec.meta['r_ceiling'] = modelspec.meta['r_test'][rr] * 1.1

    print(val['stim'].shape, val['resp'].shape, tctx['val']['stim'].shape)
    return pop_models.plot_layer_outputs(
        modelspec, val, index_range=np.arange(150, 600), example_idx=15,
        figsize=figsize, altstim=tctx['val']['stim'])
def mean_prior_used(batch, modelname, start=400, stop=500, print_every=25):
    """Print the fraction of fits whose selected initial condition was index 0.

    Index 0 of ``meta['best_random_idx']`` is reported as the mean prior;
    fits without that key count as 0.  Only cells ``start:stop`` of the batch
    are examined (default ``cells[400:500]``).  Cells whose load raises
    ValueError (no result) are skipped.

    Parameters
    ----------
    print_every : int
        Print a progress line every ``print_every`` cells; 0 disables it.

    Returns
    -------
    float or None
        The proportion, or None when no results were found.
    """
    choices = []
    cells = nd.get_batch_cells(batch, as_list=True)
    subset = cells[start:stop]
    for i, c in enumerate(subset):
        # Progress report every `print_every` cells (the old test
        # `25 % (i + 1) == 0` only fired at cells 1, 5 and 25).
        if print_every and (i + 1) % print_every == 0:
            print('cell %d/%d\n' % (i + 1, len(subset)))
        try:
            _, ctx = xhelp.load_model_xform(c, batch, modelname,
                                            eval_model=False)
        except ValueError:
            # no result
            continue
        choices.append(ctx['modelspec'].meta.get('best_random_idx', 0))

    if not choices:
        print('no results found')
        return None

    choices = np.array(choices).flatten()
    proportion = np.sum(choices == 0) / len(choices)
    print('proportion mean prior used: %.4f' % proportion)
    return proportion
# Script fragment: load one fit and open it in the interactive model browser.
# NOTE(review): `logging` and `xhelp` are used here but not imported in this
# fragment; presumably imported earlier in the original file — confirm.
import nems.gui.editors as gui

log = logging.getLogger(__name__)

# NAT A1 SINGLE NEURON + PUPIL
batch = 289
cellid = 'TAR009d-42-1'
modelname = "ozgf.fs100.ch18-ld-sev_dlog.f-wc.18x3.g-stp.3-fir.3x15-lvl.1-dexp.1_init-basic"

# NOTE(review): these assignments overwrite batch/cellid/modelname above,
# so the batch-289 fit is never loaded as written.
batch, cellid = 308, 'AMT018a-09-1'
modelname = 'ozgf.fs100.ch18-ld-sev_dlog-wc.18x4.g-fir.2x15x2-relu.2-wc.2x1-lvl.1-dexp.1_tf.n.rb10'
modelname2 = 'ozgf.fs100.ch18-ld-sev_dlog-wc.18x4.g-fir.2x15x2-relu.2-wc.2x1-lvl.1-dexp.1_tf.n.rb5'
# None skips loading/browsing the second model below.
modelname2 = None
# True opens the browser window(s).
GUI = True

xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)

if GUI:
    # interactive model browser (matplotlib embedded Qt)
    ex = gui.browse_xform_fit(ctx, xfspec)
    if modelname2 is not None:
        xfspec2, ctx2 = xhelp.load_model_xform(cellid, batch, modelname2)
        # The second browser receives the first browser's global controls
        # as its control_widget.
        ex2 = gui.browse_xform_fit(ctx2, xfspec2, control_widget=ex.editor.global_controls)

    #aw = browse_context(ctx, rec='val', signals=['stim', 'pred', 'resp'])
    #aw = browse_context(ctx, signals=['state', 'psth', 'pred', 'resp'])
else:
    # NOTE(review): the body of this else branch is not visible in this chunk.
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Load one saved model fit (cell BRT036b-06-1, batch 301) into ``xf``/``ctx``
for interactive inspection.

Created on Wed Aug 22 17:59:33 2018

@author: daniela
"""
import nems.xform_helper as xhelp

# Which fit to load.
batch = 301
cellid = "BRT036b-06-1"
modelname = ("psth.fs20.pup-ld-st.pup.beh-evs.tar.lic0_fir.Nx40-lvl.1-"
             "stategain.3_jk.nf10-init.st-basic")

xf, ctx = xhelp.load_model_xform(cellid, batch, modelname)
def sparseness_figs():
    """Build the sparseness example and summary figures.

    Figure 1 has four panels:
      0-1. sparseness examples for one A1 cell (1D CNNx2 and pop LN);
      2.   predicted sparseness, pop LN vs 1D CNNx2, across A1 cells;
      3.   strip + box plot of actual and predicted sparseness by area/model.
    Wilcoxon tests compare sparseness between paired model labels.

    Figure 2 shows mean +/- SEM sparseness against r_test (rounded to one
    decimal), one panel per area.

    Returns
    -------
    (f1, f2, tests, sd)
        The two figures, the list of ``[ref, test, wilcoxon_result]`` rows and
        the long-format sparseness DataFrame.
    """
    batch = 322
    cellid = 'ARM030a-40-2'
    modelname = ALL_FAMILY_MODELS[2]
    modelname_ln = ALL_FAMILY_MODELS[3]
    _, ctx = load_model_xform(cellid, batch, modelname, eval_model=True)
    val = ctx['val']
    val = val.apply_mask()

    f1, ax = plt.subplots(1, 4, figsize=double_column_shorter)  #, sharex=True, sharey=True)
    sparseness_example(batch, cellid, modelname, rec=val, ax=ax[0], tag="1D CNNx2",
                       colors=['lightgray', DOT_COLORS['1D CNNx2']])
    sparseness_example(batch, cellid, modelname_ln, rec=val, ax=ax[1], tag="pop LN",
                       colors=['lightgray', DOT_COLORS['pop LN']])
    ax[1].set_yticks([])
    ax[0].set_box_aspect(1)
    ax[1].set_box_aspect(1)

    a1_sparseness_path = int_path / str(a1) / 'sparseness_data.csv'
    peg_sparseness_path = int_path / str(peg) / 'sparseness_data.csv'
    modelnames = [
        ALL_FAMILY_MODELS[0], ALL_FAMILY_MODELS[2], ALL_FAMILY_MODELS[3]
    ]
    pop_reference_model = ALL_FAMILY_POP[2]
    batch = a1
    sparseness_data_a1 = sparseness_by_batch(
        batch, modelnames=modelnames, pop_reference_model=pop_reference_model,
        save_path=a1_sparseness_path, force_regenerate=False, rec=None)
    batch = peg
    sparseness_data_peg = sparseness_by_batch(
        batch, modelnames=modelnames, pop_reference_model=pop_reference_model,
        save_path=peg_sparseness_path, force_regenerate=False, rec=None)
    sparseness_data_a1['area'] = 'A1'
    sparseness_data_peg['area'] = 'PEG'
    sparseness_data = pd.concat([sparseness_data_a1, sparseness_data_peg],
                                ignore_index=True)

    # TODO: (maybe) filter out cellids that aren't in sig cell list, in addition to the r>0.2 check in sparseness_by_batch
    a1_cellids = get_significant_cells(322, SIG_TEST_MODELS, as_list=True)
    peg_cellids = get_significant_cells(323, SIG_TEST_MODELS, as_list=True)

    # Panel 2: per-cell predicted sparseness, model 1 (1D CNN) vs model 2 (pop LN).
    d = sparseness_data_a1.loc[sparseness_data_a1.model == 1].merge(
        sparseness_data_a1.loc[sparseness_data_a1.model == 2],
        how='inner', on='cellid', suffixes=('_dnn', '_ln'))
    ax[2].plot([0, 1], [0, 1], 'k--')
    d.plot.scatter('S_p_ln', 'S_p_dnn', s=1, c='black', ax=ax[2])
    #, c='r_test_all_dnn', ax=ax[1, 0], vmin=0.1, vmax=0.9)
    ax[2].set_aspect('equal')
    ax[2].set_xlabel('pop LN')
    ax[2].set_ylabel('1D CNNx2')

    # Long format: actual sparseness (S_r, labeled "act") plus predicted
    # sparseness (S_p) for models > 0.
    sd_r = sparseness_data[['area', 'cellid', 'model', 'S_r', 'r_test_all']].copy()
    sd_r['model'] = "act"
    sd_r.columns = ['area', 'cellid', 'model', 'S', 'r_test']
    sd_r = sd_r.drop_duplicates()
    sd_p = sparseness_data[['area', 'cellid', 'model', 'S_p', 'r_test_all']].copy()
    sd_p = sd_p.loc[sd_p['model'] > 0]
    sd_p['model'] = sd_p['model'].astype(str)
    sd_p.columns = ['area', 'cellid', 'model', 'S', 'r_test']
    sd = pd.concat([sd_p, sd_r], ignore_index=True)
    sd.loc[sd['model'] == '1', 'model'] = '1D CNN'
    sd.loc[sd['model'] == '2', 'model'] = 'pop LN'
    sd['label'] = sd['area'] + " " + sd['model']

    r_test_min = 0.0
    print(f"r_test_min={r_test_min}")
    sd_thr = sd.loc[sd['r_test'] > r_test_min]

    label_order = [
        'A1 act', 'A1 1D CNN', 'A1 pop LN', 'PEG act', 'PEG 1D CNN', 'PEG pop LN'
    ]
    sns.stripplot(
        x='label', y='S', hue='label', data=sd_thr, zorder=0,
        palette=['gray', DOT_COLORS['1D CNNx2'], DOT_COLORS['pop LN']] * 2,
        hue_order=label_order, order=label_order,
        jitter=0.2, size=2, ax=ax[3])
    sns.boxplot(x='label', y='S', data=sd_thr,
                boxprops={'facecolor': 'None', 'linewidth': 1},
                showcaps=False, showfliers=False,
                whiskerprops={'linewidth': 0},
                order=label_order, ax=ax[3])
    plt.xticks(rotation=45, fontsize=6, ha='right')
    ax[3].legend_.remove()
    ax[3].set_xlabel('')
    ax[3].set_box_aspect(1)

    ref_models = ['A1 act', 'A1 1D CNN', 'PEG act', 'PEG 1D CNN']
    test_models = ['A1 1D CNN', 'A1 pop LN', 'PEG 1D CNN', 'PEG pop LN']
    tests = [[
        m1, m2,
        st.wilcoxon(sd.loc[sd['label'] == m1, 'S'],
                    sd.loc[sd['label'] == m2, 'S'],
                    alternative='two-sided')
    ] for m1, m2 in zip(ref_models, test_models)]
    print(pd.DataFrame(tests, columns=['ref', 'test', 'Wilcoxon u,p']))
    # Aggregate numeric columns only: pandas >= 2.0 raises TypeError when
    # median/mean/sem reach the string columns (area, cellid, model).
    print(sd.groupby('label')[['S', 'r_test']].median())
    f1.tight_layout()

    sd['r_approx'] = np.round(sd['r_test'], 1)
    s_by_r = sd.groupby(['label', 'r_approx'])['S']
    mean_pivot = s_by_r.mean().reset_index().pivot(
        index='r_approx', columns='label', values='S')
    sem_pivot = s_by_r.sem().reset_index().pivot(
        index='r_approx', columns='label', values='S')

    f2, ax = plt.subplots(1, 2, figsize=column_and_half_short, sharex=True,
                          sharey=True)
    # Same model -> color mapping as figure 1 (act gray, 1D CNN and pop LN
    # in their DOT_COLORS).
    line_colors = {
        'act': 'gray',
        'pop LN': DOT_COLORS['pop LN'],
        '1D CNN': DOT_COLORS['1D CNNx2'],
    }
    for axis, area in zip(ax, ['A1', 'PEG']):
        for model in ['act', 'pop LN', '1D CNN']:
            c = f"{area} {model}"
            axis.errorbar(mean_pivot.index, mean_pivot[c].values,
                          sem_pivot[c].values, color=line_colors[model],
                          label=c)
    ax[0].legend()
    ax[1].legend()
    ax[0].set_xlabel('r_test')
    ax[0].set_ylabel('Sparseness')

    return f1, f2, tests, sd
else:
    # NOTE(review): the `if` this branch belongs to is not visible in this
    # chunk; cellid, batch, load_model_xform, ep and np come from earlier code.

    # Template modelname; "{0}" is filled with each entry of `states`.
    modelname_base = "psth.fs4.pup-ld-hrc-psthfr.z-pca.cc1.no.p-{0}-plgsm.p2-aev-rd"+\
        "_stategain.2xR.x1,3-spred-lvnorm.4xR.so.x2-inoise.4xR.x3"+\
        "_tfinit.xx0.n.lr1e4.cont.et4.i20000-lvnoise.r4-aev-ccnorm.md.t5.f0.ss3"
    states = [
        'st.pca0.pup+r1+s0,1', 'st.pca.pup+r1+s0,1', 'st.pca.pup+r1+s1',
        'st.pca.pup+r1'
    ]
    # Uses the last state string.  Besides the state it differs from the
    # template in "-rd.resp", ".i20" (vs ".i20000") and ".t1" (vs ".t5").
    resp_modelname = f"psth.fs4.pup-ld-hrc-psthfr.z-pca.cc1.no.p-{states[-1]}-plgsm.p2-aev-rd.resp"+\
        "_stategain.2xR.x1,3-spred-lvnorm.4xR.so.x2-inoise.4xR.x3"+\
        "_tfinit.xx0.n.lr1e4.cont.et4.i20-lvnoise.r4-aev-ccnorm.md.t1.f0.ss3"

    modelnames = [resp_modelname] + [modelname_base.format(s) for s in states]
    # Only the last model (template with states[-1]) is loaded here.
    xf, ctx = load_model_xform(cellid=cellid, batch=batch, modelname=modelnames[-1])

    val = ctx['val'].copy()
    # Epoch names matching "^STIM_", looked up on the rasterized response.
    resp = val['resp'].rasterize()
    epoch_regex = "^STIM_"
    epochs = ep.epoch_names_matching(resp.epochs, regex_str=epoch_regex)

    # Per-epoch extractions, each using the val 'mask' signal.
    input_name = 'pred0'
    pred0 = val[input_name].extract_epochs(epochs, mask=val['mask'])
    pred = val['pred'].extract_epochs(epochs, mask=val['mask'])
    # Rebinds `resp` from the rasterized signal to the extracted epochs.
    resp = val['resp'].extract_epochs(epochs, mask=val['mask'])
    pupil = val['pupil'].extract_epochs(epochs, mask=val['mask'])
    # NaN-ignoring median of the whole (unextracted) pupil signal.
    pmedian = np.nanmedian(val['pupil'].as_continuous())
    # Epoch names actually present after extraction.
    epochs = list(resp.keys())
# Script fragment.
# NOTE(review): val, rec, epoch, cellid, batch, nplt, gui, load_model_xform and
# fit_model_xform are defined earlier in the original script (not visible).

# Three stacked panels for one epoch: stimulus spectrogram, response time
# series, and an image of raster[:, 0, :].
plt.figure()
ax = plt.subplot(3, 1, 1)
nplt.spectrogram_from_epoch(val['stim'], epoch, ax=ax, time_offset=2)
ax = plt.subplot(3, 1, 2)
nplt.timeseries_from_epoch([val['resp']], epoch, ax=ax)
# assumes extract_epoch returns (reps, chans, time) so [:, 0, :] is channel 0
# across repetitions — TODO confirm
raster = rec['resp'].extract_epoch(epoch)
ax = plt.subplot(3, 1, 3)
plt.imshow(raster[:, 0, :])
plt.tight_layout()

# see what a "traditional" NEMS model looks like
nems_modelname = "ozgf.fs100.ch18-ld-sev_dlog-wc.18x2.g-do.2x15-lvl.1-dexp.1_init-basic"
xfspec, ctx = load_model_xform(cellid, batch=batch, modelname=nems_modelname)
nplt.quickplot(ctx)
ex = gui.browse_xform_fit(ctx, xfspec)

##
# Uses fit_model_xform (presumably fits from scratch rather than loading a
# saved result) for a different cell/batch.
batch, cellid = 308, 'AMT018a-09-1'
modelname = 'ozgf.fs100.ch18-ld-sev_dlog-wc.18x4.g-fir.2x15x2-relu.2-wc.2x1-lvl.1-dexp.1_init.tf.rb5-tf.n'
xfspec, ctx = fit_model_xform(cellid, batch=batch, modelname=modelname)
nplt.quickplot(ctx)
ex = gui.browse_xform_fit(ctx, xfspec)

###Plot complexity of model versus how effective it was
batch = 308
metric = 'r_test'
metric2 = 'n_parms'
# Script fragment: load three lvnorm model variants for one site and compute
# covariance matrices of their predictions and of the actual response.
# NOTE(review): load_model_xform, nd, np and fhelp come from earlier code.
site = 'AMT020a'
batch = 331

# LOAD RAW DATA / MODEL PREDICTIONS
# NOTE(review): indep/rlv/plv are reassigned below, so these "loadpred"
# variants are not the ones loaded; `reverything` is unused in this chunk.
indep = "psth.fs4.pup-loadpred.cpnmvm-st.pup0.pvp-plgsm.e10.sp-lvnoise.r8-aev_lvnorm.2xR.d.so-inoise.3xR_ccnorm.t5.ss1"
rlv = "psth.fs4.pup-loadpred.cpnmvm-st.pup0.pvp0-plgsm.e10.sp-lvnoise.r8-aev_lvnorm.SxR.d.so-inoise.2xR_ccnorm.t5.ss1"
plv = "psth.fs4.pup-loadpred.cpnmvm-st.pup.pvp0-plgsm.e10.sp-lvnoise.r8-aev_lvnorm.SxR.d.so-inoise.2xR_ccnorm.t5.ss1"
reverything = 'psth.fs4.pup-ld-st.pup0.pvp0-epcpn-mvm.25.2-hrc-psthfr-plgsm.e10.sp-lvnoise.r8-aev_sdexp2.SxR-lvnorm.SxR.d.so-inoise.2xR_ccnorm.r.t5.ss1'
indep = 'psth.fs4.pup-ld-st.pup0.pvp-epcpn-mvm.25.2-hrc-psthfr-plgsm.e10.sp-lvnoise.r8-aev_sdexp2.SxR-lvnorm.2xR.d.so-inoise.3xR_ccnorm.r.t5.ss1'
rlv = 'psth.fs4.pup-ld-st.pup0.pvp-epcpn-mvm.25.2-hrc-psthfr-plgsm.e10.sp-lvnoise.r8-aev_sdexp2.SxR-lvnorm.2xR.d.so-inoise.2xR_ccnorm.r.t5.ss1'
plv = 'psth.fs4.pup-ld-st.pup.pvp0-epcpn-mvm.25.2-hrc-psthfr-plgsm.e10.sp-lvnoise.r8-aev_sdexp2.SxR-lvnorm.SxR.d.so-inoise.2xR_ccnorm.r.t5.ss1'

# First try the site name as the cellid; if any load raises, fall back to the
# first batch cellid that contains the site name.
# NOTE(review): bare `except:` also catches KeyboardInterrupt/SystemExit.
try:
    cellid = site
    xf_indep, ctx_indep = load_model_xform(modelname=indep, batch=batch, cellid=cellid)
    xf_rlv, ctx_rlv = load_model_xform(modelname=rlv, batch=batch, cellid=cellid)
    xf_plv, ctx_plv = load_model_xform(modelname=plv, batch=batch, cellid=cellid)
except:
    cellid = [c for c in nd.get_batch_cells(batch).cellid if site in c][0]
    xf_indep, ctx_indep = load_model_xform(modelname=indep, batch=batch, cellid=cellid)
    xf_rlv, ctx_rlv = load_model_xform(modelname=rlv, batch=batch, cellid=cellid)
    xf_plv, ctx_plv = load_model_xform(modelname=plv, batch=batch, cellid=cellid)

# GET COV MATRICES
stim = np.arange(10)
ibg, ism = fhelp.get_cov_matrices(ctx_indep['val'].copy(), sig='pred', sub='psth_sp', stims=stim, ss=0)
rbg, rsm = fhelp.get_cov_matrices(ctx_rlv['val'].copy(), sig='pred', sub='psth_sp', stims=stim, ss=0)
pbg, psm = fhelp.get_cov_matrices(ctx_plv['val'].copy(), sig='pred', sub='psth_sp', stims=stim, ss=0)
# Actual response (sig='resp'), read from the plv context's val recording.
bg, sm = fhelp.get_cov_matrices(ctx_plv['val'].copy(),
                                sig='resp', sub='psth_sp', stims=stim, ss=0)
# NOTE(review): this is |max|, not max(|x|); if mm sets a symmetric color
# scale, np.max(np.abs(...)) may have been intended — confirm at the use site.
mm = np.abs(np.max(np.concatenate((ibg, ism, rbg, rsm, pbg, psm, bg, sm))))