def load_model_xform(cellid, batch=271,
                     modelname="ozgf100ch18_wcg18x2_fir15x2_lvl1_dexp1_fit01",
                     eval_model=True, only=None):
    '''
    Load a model that was previously fit via fit_model_xforms

    Parameters
    ----------
    cellid : str
        cellid in celldb database
    batch : int
        batch number in celldb database
    modelname : str
        modelname in celldb database
    eval_model : boolean
        If true, the entire xfspec will be re-evaluated after loading.
    only : int
        Index of single xfspec step to evaluate if eval_model is False.
        For example, only=0 will typically just load the recording.

    Returns
    -------
    xfspec, ctx : nested list, dictionary

    Raises
    ------
    NotImplementedError
        If `modelname` uses the old naming format (loading those fits would
        require the `oxf` library, which is not wired up here).
    '''
    kws = escaped_split(modelname, '_')

    # Old-format modelnames have more than three '_'-separated groups, or
    # exactly three where the middle group starts with 'stategain' but not
    # with the newer 'stategain.' keyword form.
    old = (len(kws) > 3) or (
        (len(kws) == 3)
        and kws[1].startswith('stategain')
        and not kws[1].startswith('stategain.'))
    if old:
        log.info("Using old modelname format ... ")
        # Fail before touching the results database: the old-format loader
        # was never supported by this function.
        raise NotImplementedError("need to use oxf library.")

    d = nd.get_results_file(batch, [modelname], [cellid])
    filepath = d['modelpath'][0]

    # TODO add BAPHY_API support. Not implemented on nems_baphy yet?
    # hack: hard-coded assumption that server will use this data root
    uri = filepath.replace('/auto/data/nems_db/results',
                           get_setting('NEMS_RESULTS_DIR'))

    xfspec, ctx = xforms.load_analysis(uri, eval_model=eval_model, only=only)
    return xfspec, ctx
def load_model_baphy_xform(
        cellid, batch=271,
        modelname="ozgf100ch18_wcg18x2_fir15x2_lvl1_dexp1_fit01",
        eval_model=True, only=None):
    '''
    DEPRECATED. Migrated to xhelp.load_model_xform()

    Reload a fit produced by fit_model_xforms_baphy, straight from the
    model path stored in the results database.

    Parameters
    ----------
    cellid : str
        cellid in celldb database
    batch : int
        batch number in celldb database
    modelname : str
        modelname in celldb database
    eval_model : boolean
        Re-evaluate the whole xfspec after loading when True.
    only : int
        When eval_model is False, index of the single xfspec step to run
        (e.g. only=0 usually just loads the recording). Ignored for
        old-format modelnames.

    Returns
    -------
    xfspec, ctx : nested list, dictionary
    '''
    keywords = nems.utils.escaped_split(modelname, '_')
    n_groups = len(keywords)

    # three groups with an old-style (dot-less) stategain keyword in the middle
    legacy_stategain = (n_groups == 3
                        and keywords[1].startswith('stategain')
                        and not keywords[1].startswith('stategain.'))
    uses_old_format = n_groups > 3 or legacy_stategain
    if uses_old_format:
        log.info("Using old modelname format ... ")

    results = nd.get_results_file(batch, [modelname], [cellid])
    model_path = results['modelpath'][0]

    if not uses_old_format:
        xfspec, ctx = xforms.load_analysis(model_path, eval_model=eval_model,
                                           only=only)
    else:
        xfspec, ctx = oxf.load_analysis(model_path, eval_model=eval_model)
    return xfspec, ctx
def _get_result_paths(batch, cellid_list, modelname):
    """
    Return the stored model paths for every cell in `cellid_list` fit with `modelname`.

    Parameters
    ----------
    batch : int
        Batch id passed to ``nd.get_results_file``.
    cellid_list : list of str
        Cells whose fits are required.
    modelname : str
        Model whose fits are required.

    Returns
    -------
    list of str
        The 'modelpath' entries of the matching rows, in results-table order.

    Raises
    ------
    ValueError
        If no requested cell has a fit for `modelname`, or if the number of
        paths found differs from ``len(cellid_list)``.
    """
    results_file = nd.get_results_file(batch)
    ff_cellid = results_file.cellid.isin(cellid_list)
    ff_modelname = results_file.modelname == modelname
    result_paths = results_file.loc[ff_cellid & ff_modelname, 'modelpath'].tolist()

    if not result_paths:
        raise ValueError(
            'no cells in cellid_list fit with model {}'.format(modelname))
    if len(result_paths) != len(cellid_list):
        raise ValueError(
            'inconsistent cells in cellid_list and in loading paths: '
            '{} cells requested, {} model paths found for model {}'.format(
                len(cellid_list), len(result_paths), modelname))
    return result_paths
def load_batch_modelpaths(batch, modelnames, cellids=None, eval_model=True):
    """
    Return the stored model paths for `modelnames` in `batch`.

    Parameters
    ----------
    batch : int
        Batch id passed to ``nd.get_results_file``.
    modelnames : str or list of str
        A single modelname, or a list of modelnames.
    cellids : list of str, optional
        Restrict the query to these cells (None queries all cells).
    eval_model : bool
        Unused; kept for backward compatibility.

    Returns
    -------
    list of str
        The 'modelpath' column of the matching results rows.
    """
    # Previously the argument was always wrapped in a list, which nested a
    # list of modelnames inside another list. Accept either form.
    if isinstance(modelnames, str):
        modelnames = [modelnames]
    d = nd.get_results_file(batch, list(modelnames), cellids=cellids)
    return d['modelpath'].tolist()
# Per-site step (the enclosing loop is outside this view): load the LV model
# fit with this site's best alpha pair and collect periodograms of the fast
# and slow latent variables.
# NOTE(review): `site`, `batch`, `F`, `Fm`, `S` and `ss` are defined outside
# this view; `ss.periodogram` presumably refers to scipy.signal -- confirm.

# Default alpha keyword; it is overwritten below by the per-site best alpha.
a = 'af0:4.as0:4.sc.rb10'
# NOTE(review): the CSV is re-read on every call of this fragment; it could
# be hoisted out of the loop.
best_alpha = pd.read_csv(
    '/auto/users/hellerc/code/projects/nat_pupil_ms_final/dprime/best_alpha.csv',
    index_col=0)
alpha = best_alpha.loc[site][0]
# The stored value is a "(x, y)" string: strip the parentheses and split on ','.
alpha = (float(alpha.split(',')[0].replace('(', '')),
         float(alpha.split(',')[1].replace(')', '')))
# Modelname keywords encode the decimal point as ':'.
a = 'af{0}.as{1}.sc.rb10'.format(
    str(alpha[0]).replace('.', ':'), str(alpha[1]).replace('.', ':'))
modelname = 'ns.fs4.pup-ld-hrc-apm-pbal-psthfr-ev-residual-addmeta_lv.2xR.f.s-lvlogsig.3xR.ipsth_jk.nf5.p-pupLVbasic.constrLVonly.{}'.format(
    a)
# First cell in the batch whose cellid contains the site name.
cellid = [c for c in nd.get_batch_cells(batch).cellid if site in c][0]
mp = nd.get_results_file(batch, [modelname], [cellid]).modelpath[0]
xfspec, ctx = xforms.load_analysis(mp)
r = ctx['val'].apply_mask()
fs = r['resp'].fs
fast = r['lv'].extract_channels(['lv_fast'])._data.squeeze()
slow = r['lv'].extract_channels(['lv_slow'])._data.squeeze()
# not used further within this view
pupil = r['pupil']._data.squeeze()
# o[0] = frequencies, o[1] = power (unpacking order of periodogram's return)
o = ss.periodogram(fast, fs=fs)
F.append(o[1].squeeze())
# frequency with the largest power in the fast LV
Fm.append(o[0][np.argmax(o[1].squeeze())])
o = ss.periodogram(slow, fs=fs)
S.append(o[1].squeeze())
import joblib as jl
import nems.db as nd
import numpy as np
import pandas as pd
import nems.xforms as xforms
import matplotlib.pyplot as plt
import itertools as itt

# Batch under inspection; every query below uses this value.
batch = 310

results_file = nd.get_results_file(batch)
all_models = results_file.modelname.unique().tolist()
result_paths = results_file.modelpath.tolist()
# modelnames with '-' replaced by '_'
mod_modelnames = [ss.replace('-', '_') for ss in all_models]

# short labels for the four architectures fit in this batch
models_shortname = {
    'wc.2x2.c-fir.2x15-lvl.1-dexp.1': 'LN',
    'wc.2x2.c-stp.2-fir.2x15-lvl.1-dexp.1': 'STP',
    'wc.2x2.c-fir.2x15-lvl.1-stategain.18-dexp.1': 'pop',
    'wc.2x2.c-stp.2-fir.2x15-lvl.1-stategain.18-dexp.1': 'STP_pop'
}

# Use the `batch` variable rather than a second hard-coded 310, so the cell
# list stays in sync with the results above if the batch is changed.
all_cells = nd.get_batch_cells(batch=batch).cellid.tolist()

goodcell = 'BRT037b-39-1'
best_model = 'wc.2x2.c-stp.2-fir.2x15-lvl.1-stategain.18-dexp.1'
test_path = '/auto/data/nems_db/results/310/BRT037b-39-1/BRT037b-39-1.wc.2x2.c_stp.2_fir.2x15_lvl.1_stategain.18_dexp.1.fit_basic.2018-11-14T093820/'

rerun = False
# name of script that you'd like to run script_path = '/auto/users/mateo/context_probe_analysis/cluster_script.py' # parameters that will be passed to script. force_rerun = True batch = 310 # define modelspec_name modelnames = ['wc.2x2.c-fir.2x15-lvl.1-dexp.1', 'wc.2x2.c-stp.2-fir.2x15-lvl.1-dexp.1', 'wc.2x2.c-fir.2x15-lvl.1-stategain.S-dexp.1', 'wc.2x2.c-stp.2-fir.2x15-lvl.1-stategain.S-dexp.1'] # define cellids batch_cells = set(nd.get_batch_cells(batch=batch).cellid) full_analysis = nd.get_results_file(batch) already_analyzed = full_analysis.cellid.unique().tolist() # batch_cells = ['BRT037b-39-1'] # best cell # iterates over every mode, checks what cells have not been fitted with it and runs the fit command. for model in modelnames: ff_model = full_analysis.modelname == model already_fitted_cells = set(full_analysis.loc[ff_model, 'cellid']) cells_to_fit = list(batch_cells.difference(already_fitted_cells)) print('model {}, cells to fit:\n{}'.format(model, cells_to_fit)) out = nd.enqueue_models(celllist=cells_to_fit, batch=batch, modellist=[model], user='******', force_rerun=force_rerun, executable_path=executable_path, script_path=script_path)
def __init__(self, batch, cellids, modelnames, parent=None):
    '''
    Load every (cellid, modelname) fit in `batch` and build one comparison
    tab per cell showing the actual response and each model's prediction.

    Parameters
    ----------
    batch : int
        Batch id passed to ``nd.get_results_file``.
    cellids : list of str
        Cells to load; each cell with at least one loaded fit gets a tab.
    modelnames : list of str
        Models to load for every cell; each contributes one predicted trace.
    parent : optional
        Currently unused.

    self.contexts is a nested dictionary with the format:
        contexts = {
            cellid1: {'model1': ctx_a, 'model2': ctx_b},
            cellid2: {'model1': ctx_c, 'model2': ctx_d},
            ...
        }
    Fits with no modelpath in the results table are reported and skipped.
    '''
    # NOTE(review): super(qw.QWidget, self) starts the MRO lookup *after*
    # QWidget, and `parent` is never forwarded -- confirm this is intended.
    super(qw.QWidget, self).__init__()

    d = nd.get_results_file(batch=batch, cellids=cellids, modelnames=modelnames)

    contexts = {}
    for c in cellids:
        cell_contexts = {}
        for m in modelnames:
            # A single combined mask. Chaining d[mask1][mask2] makes pandas
            # reindex the second mask and emit a UserWarning.
            paths = d.loc[(d.cellid == c) & (d.modelname == m), 'modelpath'].values
            if len(paths) == 0:
                print("Couldn't find modelpath for cell: %s model: %s" % (c, m))
                continue
            _, ctx = xforms.load_analysis(paths[0] + '/', eval_model=True)
            cell_contexts[m] = ctx
        contexts[c] = cell_contexts

    self.contexts = contexts
    self.batch = batch
    self.cellids = cellids
    self.modelnames = modelnames

    self.time_scroller = TimeScroller(self)
    self.layout = qw.QVBoxLayout()
    self.tabs = qw.QTabWidget()
    self.comparison_tabs = []

    for k, v in self.contexts.items():
        # Trace order: the recorded response first, then one prediction per model.
        names = ['Response'] + list(v.keys())
        signals = []
        times = None
        for i, m in enumerate(v):
            if i == 0:
                # The response and time axis come from the first model's
                # validation recording only.
                resp = v[m]['val']['resp']
                times = np.linspace(0, resp.shape[-1] / resp.fs, resp.shape[-1])
                signals.append(resp.as_continuous().T)
            signals.append(v[m]['val']['pred'].as_continuous().T)

        # cells for which no fit could be loaded get no tab
        if signals:
            tab = ComparisonFrame(signals, names, times, self)
            self.comparison_tabs.append(tab)
            self.tabs.addTab(tab, k)

    self.time_scroller._update_max_time()
    self.layout.addWidget(self.tabs)
    self.layout.addWidget(self.time_scroller)
    self.setLayout(self.layout)
else:
    # No regression. Load any model.
    xforms_models = [xforms_models[0]]

# Load the validation recording from each selected xforms fit for `site`.
# NOTE(review): `site`, `batch`, `chan_nums`, `lv_regress`, `pupil_regressed`
# and `regression_method1` are defined outside this view.
for xforms_modelname in xforms_models:
    log.info("Load recording from xforms model {}".format(xforms_modelname))
    if chan_nums is not None:
        # First cell at this site whose channel number (the second
        # '-'-separated field of the cellid) is in chan_nums.
        # NOTE(review): `&` does not short-circuit, so c.split('-')[1] is also
        # evaluated for cellids not at this site; a cellid without '-' would
        # raise IndexError.
        cellid = [[
            c for c in nd.get_batch_cells(batch).cellid
            if (site in c) & (c.split('-')[1] in chan_nums)
        ][0]]
    else:
        cellid = [[c for c in nd.get_batch_cells(batch).cellid if site in c][0]]
    mp = nd.get_results_file(batch, [xforms_modelname], cellid).modelpath[0]
    xfspec, ctx = xforms.load_analysis(mp)

    # apply hrc (and balanced, if exists) and evoked masks from xforms fit
    rec = ctx['val'].apply_mask(reset_epochs=True)

    # only necessary if using correction method 1
    if lv_regress:
        if rec['lv']._data.shape[0] > 1:
            # Keep LV rows 1: (drop the first row). The new signal is built
            # with rec['pupil'] as the template.
            # NOTE(review): the result presumably inherits pupil's name/channel
            # metadata rather than lv's -- confirm this is intended.
            rec['lv'] = rec['pupil']._modified_copy(rec['lv']._data[1:, :])

    # NOTE(review): the nesting of this line is inferred from a collapsed
    # source; confirm it belongs at loop level.
    rec = rec.create_mask(True)

    # filtering / pupil regression must always go first!
    if pupil_regressed & lv_regress:
        if regression_method1:
def _ddr_find_site_cellid(site, batch):
    """Return the first cellid in `batch` whose name contains `site`, or None."""
    matches = [c for c in db.get_batch_cells(batch).cellid if site in c]
    return matches[0] if matches else None


def _ddr_load_results(batch, modelname, cellid):
    """
    Load the saved decoding_TDR.pickle results for one (model, cell) fit.

    Returns the loaded results object and a copy of its ``numeric_results``
    restricted to the evoked stimulus pairs (second index level == 2).
    """
    modelpath = db.get_results_file(batch=batch, modelnames=[modelname],
                                    cellids=[cellid]).iloc[0]["modelpath"]
    loader = decoding.DecodingResults()
    raw_res = loader.load_results(os.path.join(modelpath, "decoding_TDR.pickle"))
    raw_df = raw_res.numeric_results
    raw_df = raw_df.loc[pd.IndexSlice[raw_res.evoked_stimulus_pairs, 2], :].copy()
    return raw_res, raw_df


def ddr_pred_site_sim(site, batch=None, modelname_base=None,
                      just_return_mean_dp=False, save_fig=False,
                      skip_plot=False):
    """
    Compare dDR decoding d-prime of actual data vs. model simulations for one site.

    The first model returned by ``parse_modelname_base`` is the "actual"
    reference. Every later model is a simulation compared against it on
    small-/big-pupil d-prime and on the big-minus-small difference (raw and
    normalized by the sum).

    Parameters
    ----------
    site : str
        Site id. The first cell in `batch` whose cellid contains `site`
        identifies the saved decoding results.
    batch : int, optional
        Batch id. If omitted, batch 331 is tried first and batch 322 second.
    modelname_base : str, optional
        Template modelname passed to ``parse_modelname_base``. A
        batch-specific default is used when omitted.
    just_return_mean_dp : bool
        If True, skip plotting and return ``(labels, dp)``. ``dp[i]`` holds,
        for model i, the mean of sp_dp/(sp_dp+bp_dp) and of
        bp_dp/(sp_dp+bp_dp).
    save_fig : bool
        Save the summary figure to a hard-coded project directory.
    skip_plot : bool
        Close the figure before returning.

    Returns
    -------
    labels, dp
        When `just_return_mean_dp` is True.
    labels[1:], cc, mse, pupil_range
        Otherwise. ``cc`` and ``mse`` have shape (n_models - 1, 3). Columns
        are [pooled big/small dp, normalized delta dp, raw delta dp].
        ``mse`` actually holds the standard deviation of the prediction error.

    Raises
    ------
    ValueError
        If no cell matching `site` is found.
    """
    # Previously `[...][0]` raised IndexError before the emptiness check could
    # run, so the fallback to batch 322 never happened.
    batch_given = batch is not None
    if not batch_given:
        batch = 331
    cellid = _ddr_find_site_cellid(site, batch)
    if cellid is None and not batch_given:
        batch = 322
        cellid = _ddr_find_site_cellid(site, batch)
    if cellid is None:
        raise ValueError(f"No match for site {site} batch {batch}")

    if modelname_base is None:
        if batch == 331:
            modelname_base = "psth.fs4.pup-ld-epcpn-hrc-psthfr.z-pca.cc1.no.p-{0}-plgsm.p2-aev-rd" + \
                             "_stategain.2xR.x1,3-spred-lvnorm.4xR.so.x2-inoise.4xR.x3" + \
                             "_tfinit.xx0.n.lr1e4.cont.et4.i50000-lvnoise.r8-aev-ccnorm.t4.f0.ss3"
        else:
            modelname_base = "psth.fs4.pup-ld-hrc-psthfr.z-pca.cc1.no.p-{0}-plgsm.p2-aev-rd" + \
                             "_stategain.2xR.x1,3-spred-lvnorm.4xR.so.x2-inoise.4xR.x3" + \
                             "_tfinit.xx0.n.lr1e4.cont.et4.i50000-lvnoise.r8-aev-ccnorm.md.t5.f0.ss3"
    log.info(f"site {site} modelname_base: {modelname_base}")

    modelnames, states = parse_modelname_base(modelname_base)
    labels = ['actual'] + states

    if just_return_mean_dp:
        dp = np.zeros((len(modelnames), 2))
        for i, m in enumerate(modelnames):
            _, raw_df = _ddr_load_results(batch, m, cellid)
            total = raw_df["sp_dp"] + raw_df["bp_dp"]
            small_frac = raw_df["sp_dp"] / total
            big_frac = raw_df["bp_dp"] / total
            # the original stored the bound method `l.mean` instead of calling it
            dp[i, :] = [small_frac.mean(), big_frac.mean()]
        return labels, dp

    mse = np.zeros((len(modelnames) - 1, 3))
    cc = np.zeros((len(modelnames) - 1, 3))

    f, ax = plt.subplots(4, len(modelnames), figsize=(8, 6),
                         sharex='row', sharey='row')
    for i, m in enumerate(modelnames):
        raw_res, raw_df = _ddr_load_results(batch, m, cellid)

        # row 0: big- vs small-pupil d-prime for every model
        if i == 0:
            mmraw0 = raw_df[["sp_dp", "bp_dp"]].values.max()
            ax[0, i].plot([0, mmraw0], [0, mmraw0], 'k--', lw=0.5)
        ax[0, i].scatter(raw_df["sp_dp"], raw_df["bp_dp"], s=3)

        if i == 0:
            # Actual data: store the reference values that every simulated
            # model is compared against.
            raw_df.loc[:, 'delta_act_raw'] = (raw_df["bp_dp"] - raw_df["sp_dp"])
            raw_df.loc[:, 'delta_act'] = (raw_df["bp_dp"] - raw_df["sp_dp"]) / (
                raw_df["bp_dp"] + raw_df["sp_dp"])
            raw_df.loc[:, 'bp_dp_act'] = raw_df["bp_dp"]
            raw_df.loc[:, 'sp_dp_act'] = raw_df["sp_dp"]
            resp_df = raw_df[[
                'sp_dp_act', 'bp_dp_act', 'delta_act_raw', 'delta_act'
            ]]
            ax[1, 0].set_axis_off()
            ax[2, 0].set_axis_off()
            ax[3, 0].set_axis_off()
            mmraw = np.max(np.abs(raw_df[['delta_act_raw']].values))
            mmnorm = np.max(np.abs(raw_df[['delta_act']].values))
            ax[0, i].set_title(f"{labels[i]} n={len(raw_df)}")
        else:
            raw_df.loc[:, 'delta_pred_raw'] = (raw_df["bp_dp"] - raw_df["sp_dp"])
            raw_df.loc[:, 'delta_pred'] = (raw_df["bp_dp"] - raw_df["sp_dp"]) / (
                raw_df["bp_dp"] + raw_df["sp_dp"])
            # align predicted and actual rows on stimulus pair
            raw_df = raw_df.merge(resp_df, how='inner', left_index=True,
                                  right_index=True)

            # row 1: predicted vs actual dp (cc/error pool big and small
            # pupil; the scatter shows big pupil only)
            x = np.concatenate((raw_df['bp_dp'], raw_df['sp_dp']))
            y = np.concatenate((raw_df['bp_dp_act'], raw_df['sp_dp_act']))
            cc[i - 1, 0] = np.corrcoef(x, y)[0, 1]
            mse[i - 1, 0] = np.std(x - y)
            ax[1, i].scatter(raw_df['bp_dp'], raw_df['bp_dp_act'], s=3, alpha=0.4)
            ax[1, i].set_title(f"{cc[i-1, 0]:.3f}")

            # row 2: normed delta dp
            a, b = 'delta_pred', 'delta_act'
            ax[2, i].plot([-mmnorm, mmnorm], [-mmnorm, mmnorm], 'k--', lw=0.5)
            ax[2, i].scatter(raw_df[a], raw_df[b], s=3, alpha=0.4)
            cc[i - 1, 1] = np.corrcoef(raw_df[a], raw_df[b])[0, 1]
            mse[i - 1, 1] = np.std(raw_df[a] - raw_df[b])
            ax[2, i].set_title(f"e={mse[i-1,1]:.1f} cc={cc[i-1,1]:.3f}")

            # row 3: raw delta dp
            a, b = 'delta_pred_raw', 'delta_act_raw'
            ax[3, i].plot([-mmraw, mmraw], [-mmraw, mmraw], 'k--', lw=0.5)
            ax[3, i].scatter(raw_df[a], raw_df[b], s=3, alpha=0.4)
            cc[i - 1, 2] = np.corrcoef(raw_df[a], raw_df[b])[0, 1]
            mse[i - 1, 2] = np.std(raw_df[a] - raw_df[b])
            ax[3, i].set_title(f"e={mse[i-1,2]:.3f} cc={cc[i-1,2]:.3f}")

            ax[0, i].set_title(f"{labels[i]}")

    ax[0, 0].set_ylabel('big pupil dp')
    ax[0, 0].set_xlabel('small pupil dp')
    ax[1, 1].set_ylabel('actual big dp')
    ax[1, 1].set_xlabel('pred big dp')
    # row 2 plots the normalized delta and row 3 the raw delta
    # (the original labels were swapped)
    ax[2, 1].set_xlabel('pred delta norm')
    ax[2, 1].set_ylabel('act delta norm')
    ax[3, 1].set_xlabel('pred delta raw')
    ax[3, 1].set_ylabel('act delta raw')

    # pupil range is taken from the last model's results
    pupil_range = raw_res.pupil_range['range'].mean()
    f.suptitle(f"{site} - {batch} - puprange {pupil_range:.3f}")
    plt.tight_layout()

    if save_fig:
        f.savefig(
            f'/auto/users/svd/projects/pop_state/ddr_pred_{site}_{batch}.jpg')
    if skip_plot:
        plt.close(f)

    return labels[1:], cc, mse, pupil_range
def _zero_state_terms_except(mspec, keep_tag):
    """
    In place, reduce every phi entry of ``mspec[0]`` whose key lacks `keep_tag`
    to a single row that keeps element [0, 0] and zeros the remaining columns.
    """
    phi = mspec[0]['phi']
    for k in [k for k in phi.keys() if keep_tag not in k]:
        phi[k] = np.append(phi[k][0, 0],
                           np.zeros(phi[k].shape[-1] - 1))[np.newaxis, :]


def generate_state_corrected_psth(batch=None, modelname=None, cellids=None,
                                  siteid=None, movement_mask=False,
                                  gain_only=False, dc_only=False,
                                  cache_path=None, recache=False):
    """
    Build a recording whose 'psth' signal is the prediction of `modelname`.

    Designed with stategain models in mind (CRH). Each cell's saved xforms fit
    is reloaded and its validation recording used. Cells without a saved fit
    are skipped with a warning; this function does not fit missing models.
    The per-cell predictions, 'psth_sp' and rasterized 'resp' are concatenated
    channel-wise into a new Recording. That recording also carries 'pupil',
    'mask' and, if present, 'pupil_raw', 'rem' and 'pupil_eyespeed' taken
    from the first loaded cell.

    Parameters
    ----------
    batch : int
        Batch id (required unless a cached result is returned).
    modelname : str
        Model whose saved fits are loaded (required).
    cellids : iterable of str
        Cells to include.
    siteid : str
        Site id, used in the cache filename (required).
    movement_mask : bool
        Currently unused.
    gain_only : bool
        Re-evaluate each model after reducing every first-module phi entry
        whose key lacks '_g' to its [0, 0] element (other terms zeroed).
    dc_only : bool
        Same as `gain_only`, for keys lacking '_d'. Ignored if `gain_only`.
    cache_path : str, optional
        String prefix prepended directly to the cache filename (include the
        trailing path separator). An existing cache file is loaded unless
        `recache`, and a fresh result is saved there.
    recache : bool
        Rebuild and overwrite the cache even when it exists.

    Returns
    -------
    Recording

    Raises
    ------
    ValueError
        If siteid, batch or modelname is missing, if no fit could be loaded,
        or if the per-cell signals cannot be concatenated.
    """
    from collections import Counter

    if siteid is None:
        raise ValueError("must specify siteid!")

    if cache_path is not None:
        if modelname is None:
            raise ValueError('Must specify batch and modelname!')
        fn = cache_path + siteid + '_{}.tgz'.format(modelname.split('.')[1])
        # Match the prediction-variant precedence used below, so gain-only
        # and dc-only results don't share a cache file with the full model.
        if gain_only:
            fn = fn.replace('.tgz', '_gonly.tgz')
        elif dc_only:
            fn = fn.replace('.tgz', '_dconly.tgz')
        if 'mvm' in modelname:
            fn = fn.replace('.tgz', '_mvm.tgz')
        if os.path.isfile(fn) and not recache:
            return Recording.load(fn)

    if batch is None or modelname is None:
        raise ValueError('Must specify batch and modelname!')

    results_table = nd.get_results_file(batch, modelnames=[modelname])

    preds = []
    ms = []
    for cell in cellids:
        log.info(cell)
        paths = results_table.loc[results_table['cellid'] == cell, 'modelpath'].values
        if len(paths) == 0 or not os.path.isdir(paths[0]):
            log.warning("WARNING: fit doesn't exist for cell {0}".format(cell))
            continue
        try:
            _, ctx = xforms.load_analysis(paths[0])
        except Exception:
            # Best-effort loading: skip cells whose fit cannot be loaded, but
            # record why (the original bare except hid every error).
            log.warning("WARNING: fit doesn't exist for cell {0}".format(cell),
                        exc_info=True)
            continue
        preds.append(ctx['val'])
        ms.append(ctx['modelspec'])

    if not preds:
        raise ValueError('No fits found for modelname {} in batch {}'.format(
            modelname, batch))

    # Preds can differ in length when cells were fit on different files. For
    # example, one cell may exist in a prepassive run while another does not.
    # Find the FILE epochs present in every pred.
    file_counts = Counter(ep for pr in preds for ep in pr.epochs.name
                          if ep.startswith('FILE'))
    shared_files = [str(f) for f in sorted(file_counts)
                    if file_counts[f] == len(preds)]

    # mask all file epochs for all preds with the shared file epochs
    # and adjust epochs
    if int(batch) in (307, 294):
        for i, p in enumerate(preds):
            preds[i] = p.and_mask(shared_files).apply_mask(reset_epochs=True)

    for i, p in enumerate(preds):
        if gain_only:
            _zero_state_terms_except(ms[i], '_g')
            pred = ms[i].evaluate(p)['pred']
        elif dc_only:
            _zero_state_terms_except(ms[i], '_d')
            pred = ms[i].evaluate(p)['pred']
        else:
            pred = p['pred']

        if i == 0:
            new_psth_sp = p['psth_sp']
            new_psth = pred
            new_resp = p['resp'].rasterize()
        else:
            try:
                new_psth_sp = new_psth_sp.concatenate_channels([new_psth_sp, p['psth_sp']])
                new_psth = new_psth.concatenate_channels([new_psth, pred])
                new_resp = new_resp.concatenate_channels([new_resp, p['resp'].rasterize()])
            except ValueError as err:
                # Previously this dropped into pdb and then silently continued.
                raise ValueError(
                    'Could not concatenate signals of prediction {} '
                    '(mismatched recordings?)'.format(i)) from err

    sigs = {}
    sigs['pupil'] = preds[0]['pupil']
    if 'pupil_raw' in preds[0].signals.keys():
        sigs['pupil_raw'] = preds[0]['pupil_raw']
    if 'mask' in preds[0].signals:
        sigs['mask'] = preds[0]['mask']
    else:
        sigs['mask'] = preds[0].create_mask(True)['mask']
    if 'rem' in preds[0].signals.keys():
        sigs['rem'] = preds[0]['rem']
    if 'pupil_eyespeed' in preds[0].signals.keys():
        sigs['pupil_eyespeed'] = preds[0]['pupil_eyespeed']

    new_psth_sp.name = 'psth_sp'
    new_psth.name = 'psth'
    new_resp.name = 'resp'
    sigs['psth_sp'] = new_psth_sp
    sigs['psth'] = new_psth
    sigs['resp'] = new_resp

    new_rec = Recording(sigs, meta=preds[0].meta)

    # make sure mask is cast to bool
    new_rec['mask'] = new_rec['mask']._modified_copy(new_rec['mask']._data.astype(bool))

    if cache_path is not None:
        log.info('caching {}'.format(fn))
        new_rec.save_targz(fn)

    return new_rec
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Jan 23 10:11:40 2018

@author: svd

Reload one saved xforms fit from the results database, re-run its xfspec,
show the modelspec summary plot and open the recording browser on the
stim/pred/resp signals.
"""

import nems.plots.api as nplt
import nems.db as nd
import nems.xforms as xforms
from nems.gui.recording_browser import browse_recording, browse_context

# --- which fit to inspect ---------------------------------------------------
cellid = 'TAR010c-18-1'
batch = 271
# alternatives:
#modelname = 'wc.18x1.g-fir.1x15-lvl.1'
#modelname = 'dlog-wc.18x1.g-stp.1-fir.1x15-lvl.1-dexp.1'
modelname = 'dlog-wc.18x1.g-fir.1x15-lvl.1'

# --- locate and reload it ---------------------------------------------------
d = nd.get_results_file(batch=batch, modelnames=[modelname], cellids=[cellid])
filepath = '{}/'.format(d['modelpath'][0])

# load without evaluating, then run the full xfspec explicitly
xfspec, ctx = xforms.load_analysis(filepath, eval_model=False)
ctx, log_xf = xforms.evaluate(xfspec, ctx)

# --- inspect ----------------------------------------------------------------
#nplt.quickplot(ctx)
ctx['modelspec'].quickplot(ctx['val'])

# keep a reference to the browser window
aw = browse_context(ctx, signals=['stim', 'pred', 'resp'])
load_dprime = False # fast LV with pupil (gain model with sigmoid nonlinearity) modelname = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-ev-addmeta_slogsig.SxR-lv.1xR.f.pred-lvlogsig.2xR_jk.nf5.p-pupLVbasic.constrLVonly.af0:2.sc' # single module for LV (testing LV modeling architectures) #modelname = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-ev-addmeta_puplvmodel.pred.step.dc.R_jk.nf5.p-pupLVbasic.constrLVonly.af0:3.sc.rb10' #modelname0 = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-ev-addmeta_puplvmodel.dc.pupOnly.R_jk.nf5.p-pupLVbasic.constrLVonly.af0:0.sc' # without jackknifing (or cross validation) modelname = 'ns.fs4.pup-ld-st.pup-epsig-hrc-apm-pbal-psthfr-ev-addmeta-aev_puplvmodel.pred.step.g.dc.R_pupLVbasic.constrNC.af0:1.sc.rb2' modelname0 = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-ev-addmeta-aev_puplvmodel.g.dc.pupOnly.R_pupLVbasic.constrLVonly.af0:0.sc.rb2' xforms_model = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-addmeta-aev_puplvmodel.pred.step.pfix.dc.R_pupLVbasic.constrLVonly.af0:0.sc.rb10' if load: c = [c for c in nd.get_batch_cells(batch).cellid if cellid in c][0] mp = nd.get_results_file(batch, [modelname], [c]).modelpath[0] _, ctx = xforms.load_analysis(mp) mp = nd.get_results_file(batch, [modelname0], [c]).modelpath[0] _, ctx2 = xforms.load_analysis(mp) else: ctx = xfit.fit_xforms_model(batch, cellid, modelname, save_analysis=False) ctx2 = xfit.fit_xforms_model(batch, cellid, modelname0, save_analysis=False) # plot lv, pupil, PC1 timecourses if '.g.' in modelname: key1 = 'pg' key2 = 'lvg' if '.dc.' in modelname:
# Select the cells fit with every model in `mns` and set up click-callback
# arguments for an interactive comparison plot.
# NOTE(review): every assignment to `mns` in this view is commented out, so
# `mns` must be defined earlier in the file. `batch`, `imageax` and `imageax2`
# also come from outside this view -- confirm.

# mns = [
#     'env.fs100-SPOld-stSPO.nb-SPOsev_dlog-stategain.2x2.g.o1.b0d001:5-fir.2x15.z-lvl.1_SDB-init.rb5-basic.t7-SPOpf.Exa',
#     'env.fs100-SPOld-stSPO.nb-SPOsev_dlog-stategain.2x2.g.o1.b0d001:5-fir.2x15.z-lvl.1_SDB-init.rb10-basic.t7-SPOpf.Exa'
# ]
# mns=['env.fs200-SPOld-stSPO.nb-SPOsev-shuf.st_dlog-stategain.2x2.g.o1.b0d001:5-fir.2x30.z-lvl.1-dexp.1_SDB-init.t5.rb5-basic.t6-SPOpf.Exa',
#      'env.fs200-SPOld-stSPO.nb-SPOsev_dlog-stategain.2x2.g.o1.b0d001:5-fir.2x30.z-lvl.1-dexp.1_SDB-init.t5.rb5-basic.t6-SPOpf.Exa',
#      'env.fs200-SPOld-stSPO.nb-SPOsev-shuf.st_dlog-stategain.2x2.g.o1.b0d001:5-wc.2x2.c-stp.2-fir.2x30.z-lvl.1-dexp.1_SDB-init.t5.rb5-basic.t6-SPOpf.Exa',
#      'env.fs200-SPOld-stSPO.nb-SPOsev_dlog-stategain.2x2.g.o1.b0d001:5-wc.2x2.c-stp.2-fir.2x30.z-lvl.1-dexp.1_SDB-init.t5.rb5-basic.t6-SPOpf.Exa'
#     ]
# mns = [mn.replace('fs200','fs100').replace('2x30','2x15') for mn in mns]
# comparisons=((0,1),(2,3),(0,2))

## Get df of model fits
#cells = sp.get_significant_cells(batch,mns,as_list=True) #Sig cells across all models
#cells = sp.get_significant_cells(batch,mns[:1],as_list=True) #Sig cells in the first model
dfc = nd.get_results_file(batch,mns[:1]); cells = list(dfc['cellid'].values) # All cells fit in the first model
#cells = [cell for cell in cells if 'fre' not in cell]; print('Keeping only fred cells')
# The list loaded here replaces `cells` from the line above.
# (allow_pickle=True unpickles the file; fine only for trusted files like this one.)
cells = list(np.load('/auto/users/luke/Projects/SPS/NEMS/fre_oldfit_newfit_common_subset.npy', allow_pickle=True))
print(f'{len(cells)} cells')
df = nd.get_results_file(batch,mns,cells)
# keep only cells that have a fit for every model in mns
cells_fit_in_all = set(df.loc[df['modelname']==mns[0],'cellid'])
for mn_ in mns[1:]:
    cells_fit_in_all = cells_fit_in_all.intersection(set(df.loc[df['modelname']==mn_,'cellid']))
df = df[df['cellid'].isin(cells_fit_in_all)]
# float division: equals cells per model only if each cell has one row per model
print(f'Dropped not fit, down to {len(df)/len(mns)} cells per model')

## Define fnargs, arguments that will be passed to a function called when you click on a point
fnargs = [{'ax': imageax, 'ft': 5, 'data_series_dict': 'dsx'},
          {'ax': imageax2, 'ft': 5, 'data_series_dict': 'dsy'}]
modelname = [ 'ns.fs4.pup-ld-st.pup-hrc-pbal-psthfr-ev_slogsig.SxR-lv.1xR-lvlogsig.2xR_jk.nf5.p-pupLVbasic.a0:009' ] # NC constraint modelname = [ 'ns.fs4.pup-ld-st.pup-hrc-apm-psthfr-ev_slogsig.SxR-lv.1xR-lvlogsig.2xR_jk.nf5.p-pupLVbasic.constrNC.a0:35' ] modelname = [ 'ns.fs4.pup-ld-st.pup-hrc-apm-psthfr-ev-residual_slogsig.SxR-lv.1xR-lvlogsig.2xR_jk.nf2.p-pupLVbasic.constrNC.a0:05' ] modelname2 = ['ns.fs4.pup-ld-st.pup-hrc-psthfr-ev_slogsig.SxR_jk.nf5.p-basic'] batch = 289 cellids = nd.get_batch_cells(batch).cellid cellid = [[c for c in cellids if site in c][0]] mp = nd.get_results_file(batch, modelname, cellid).modelpath[0] mp2 = nd.get_results_file(batch, modelname2, cellid).modelpath[0] xfspec, ctx = xforms.load_analysis(mp) xfspec2, ctx2 = xforms.load_analysis(mp2) # plot summary of fit ctx2['modelspec'].quickplot() ctx['modelspec'].quickplot() rec = ctx['val'].apply_mask(reset_epochs=True).copy() # raw recording rec['lv'] = rec['lv']._modified_copy(rec['lv']._data[1, :][np.newaxis, :]) rec1 = ctx2['val'].apply_mask( reset_epochs=True).copy() # first order regression rec12 = rec.copy() # first / second order regression