def reconsitute_rec(batch, cellid_list, modelname):
    '''
    Takes a group of single cell recordings (from cells of a population
    recording), including their model predictions, and builds a recording with
    signals containing the responses and predictions of all the cells in the
    population.

    This makes the recordings compatible with downstream dispersion analysis,
    or any analysis working with signals of neuronal populations.

    :param batch: int batch number
    :param cellid_list: [str, str ...] list of cell IDs
    :param modelname: str. modelname
    :return: NEMS Recording object
    :raises ValueError: if no result paths are found for the requested cells.
    '''
    result_paths = _get_result_paths(batch, cellid_list, modelname)

    cell_resp_dict = dict()
    cell_pred_dict = dict()  # insertion ordered: channel order follows result_paths
    rec = None

    for filepath in result_paths:
        # load only the first two xforms steps; the prediction is evaluated below
        xfspec, ctx = xforms.load_analysis(filepath=filepath, eval_model=False,
                                           only=slice(0, 2, 1))
        modelspecs = ctx['modelspecs'][0]
        cellid = modelspecs[0]['meta']['cellid']

        # recording containing signals for resp and pred
        rec = ms.evaluate(ctx['rec'].copy(), modelspecs)

        # Keep the raw data organized per cell for the later concatenation.
        # PointProcess signals already hold _data as a dict, hence update();
        cell_resp_dict.update(rec['resp']._data)
        # Rasterized signals hold _data as a matrix, hence an explicit key.
        cell_pred_dict[cellid] = rec['pred']._data

    if rec is None:
        # without at least one fit there is no template recording to build on
        raise ValueError('no results found for batch {} and model {}'.format(
            batch, modelname))

    # Create the population recording: take stim from the last single cell,
    # copy the resp/pred signal metadata from its signals, and swap in the
    # data and channel list of all cells.
    resp_chans = list(cell_resp_dict.keys())
    pop_resp = rec['resp']._modified_copy(data=cell_resp_dict, chans=resp_chans,
                                          nchans=len(resp_chans))

    pred_chans = list(cell_pred_dict.keys())
    stack_data = np.concatenate(list(cell_pred_dict.values()), axis=0)
    pop_pred = rec['pred']._modified_copy(data=stack_data, chans=pred_chans,
                                          nchans=len(pred_chans))

    reconstituted_recording = rec.copy()
    reconstituted_recording['resp'] = pop_resp
    reconstituted_recording['pred'] = pop_pred

    # drop the state signals; tolerate recordings that never had them
    for signal_name in ('state', 'state_raw'):
        if signal_name in reconstituted_recording.signals:
            del reconstituted_recording.signals[signal_name]

    return reconstituted_recording
def load_model_xform(cellid, batch=271,
                     modelname="ozgf100ch18_wcg18x2_fir15x2_lvl1_dexp1_fit01",
                     eval_model=True, only=None):
    '''
    Load a model that was previously fit via fit_model_xforms

    Parameters
    ----------
    cellid : str
        cellid in celldb database
    batch : int
        batch number in celldb database
    modelname : str
        modelname in celldb database
    eval_model : boolean
        If true, the entire xfspec will be re-evaluated after loading.
    only : int or slice
        Index of single xfspec step to evaluate if eval_model is False.
        For example, only=0 will typically just load the recording.

    Returns
    -------
    xfspec, ctx : nested list, dictionary

    Raises
    ------
    NotImplementedError
        If modelname uses the old (pre-keyword) format, which needs the
        oxf library and is not supported here.
    '''
    kws = escaped_split(modelname, '_')

    # Old-format modelnames have more than three '_'-separated sections, or
    # three sections whose middle one is an old-style 'stategain' keyword
    # (i.e. without a trailing '.').
    old = (len(kws) > 3) or ((len(kws) == 3)
                             and kws[1].startswith('stategain')
                             and not kws[1].startswith('stategain.'))
    if old:
        log.info("Using old modelname format ... ")

    d = nd.get_results_file(batch, [modelname], [cellid])
    filepath = d['modelpath'][0]

    # TODO add BAPHY_API support. Not implemented on nems_baphy yet?
    # hack: hard-coded assumption that server will use this data root
    uri = filepath.replace('/auto/data/nems_db/results',
                           get_setting('NEMS_RESULTS_DIR'))

    if old:
        raise NotImplementedError("need to use oxf library.")

    xfspec, ctx = xforms.load_analysis(uri, eval_model=eval_model, only=only)
    return xfspec, ctx
def load_model_baphy_xform(
        cellid, batch=271,
        modelname="ozgf100ch18_wcg18x2_fir15x2_lvl1_dexp1_fit01",
        eval_model=True, only=None):
    '''
    DEPRECATED. Migrated to xhelp.load_model_xform()

    Load a model that was previously fit via fit_model_xforms_baphy.
    Calling this emits a DeprecationWarning.

    Parameters
    ----------
    cellid : str
        cellid in celldb database
    batch : int
        batch number in celldb database
    modelname : str
        modelname in celldb database
    eval_model : boolean
        If true, the entire xfspec will be re-evaluated after loading.
    only : int
        Index of single xfspec step to evaluate if eval_model is False.
        For example, only=0 will typically just load the recording.

    Returns
    -------
    xfspec, ctx : nested list, dictionary
    '''
    import warnings
    warnings.warn("load_model_baphy_xform is deprecated; use "
                  "xhelp.load_model_xform() instead",
                  DeprecationWarning, stacklevel=2)

    kws = nems.utils.escaped_split(modelname, '_')

    # Old-format modelnames have more than three '_'-separated sections, or
    # three sections with an old-style 'stategain' keyword (no trailing '.').
    old = False
    if (len(kws) > 3) or ((len(kws) == 3) and kws[1].startswith('stategain')
                          and not kws[1].startswith('stategain.')):
        log.info("Using old modelname format ... ")
        old = True

    d = nd.get_results_file(batch, [modelname], [cellid])
    filepath = d['modelpath'][0]

    if old:
        xfspec, ctx = oxf.load_analysis(filepath, eval_model=eval_model)
    else:
        xfspec, ctx = xforms.load_analysis(filepath, eval_model=eval_model,
                                           only=only)
    return xfspec, ctx
def dstrf_sample(ctx=None, cellid='TAR010c-18-2', savepath=None):
    """
    Compute the dSTRF of one cell of a fitted model and plot selected samples.

    Parameters
    ----------
    ctx : dict, optional
        xforms context with 'val' (recording) and 'modelspec'. If None, it is
        loaded from ``savepath``.
    cellid : str
        Channel of ``ctx['val']['resp']`` to analyze.
    savepath : str, optional
        Saved analysis to load when ``ctx`` is None.

    Returns
    -------
    f, dstrf
        The figure produced by dstrf_details and the output of compute_dstrf.

    Raises
    ------
    ValueError
        If neither ctx nor savepath is given, or cellid is not a resp channel.
    """
    if ctx is None:
        if savepath is None:
            raise ValueError("dstrf_sample needs either ctx or savepath")
        xf, ctx = load_analysis(savepath)

    rec = ctx['val']
    modelspec = ctx['modelspec']

    cellids = list(rec['resp'].chans)
    if cellid not in cellids:
        raise ValueError(f"{cellid} is not a channel of ctx['val']['resp']")
    c = cellids.index(cellid)

    maxbins = 1000
    stepbins = 3   # only reported in the log message below
    memory = 15

    # analyze only the output channel of the selected cell
    out_channel = [c]
    channel_count = len(out_channel)

    # only the length of this vector is used: it bounds the time indices
    stim_mag = rec['stim'].as_continuous()[:, :maxbins].sum(axis=0)
    index_range = np.arange(0, len(stim_mag), 1)
    log.info(
        'Calculating dstrf for %d channels, %d timepoints (%d steps), memory=%d',
        channel_count, len(index_range), stepbins, memory)
    dstrf = compute_dstrf(modelspec, rec.copy(), out_channel=out_channel,
                          memory=memory, index_range=index_range)

    # time range and sample indices shown in the detail plot
    rr = (150, 550)
    dindex = [52, 54, 56, 65, 133, 165, 215, 220, 233, 240, 250]

    f = dstrf_details(modelspec, rec, cellid, rr, dindex, dstrf=dstrf,
                      dpcs=None, maxbins=maxbins)
    f.suptitle(f"{cellid} r_test={modelspec.meta['r_test'][c][0]:.3f}")

    return f, dstrf
def kamiak_to_database(cellids, batch, modelnames, source_path,
                       executable_path=None, script_path=None):
    """
    Import fits computed on the Kamiak cluster into the NEMS database.

    For every (cellid, batch, modelname) combination, this does the following:

    - loads the saved analysis from ``source_path/<batch>/<cellid>/<modelname>``;
    - reloads its figures and re-saves it with xforms.save_analysis;
    - updates the results table;
    - makes sure tQueue has a completed entry for the job. A dead entry
      (complete == 2) is set to complete, and a missing entry is inserted as
      complete.

    Combinations whose fit directory does not exist are logged and skipped.

    Parameters
    ----------
    cellids : list of str
    batch : int or str
    modelnames : list of str
    source_path : str
        Root directory holding the copied Kamiak results.
    executable_path, script_path : str, optional
        Only used to build the progname stored in tQueue. They default to the
        DEFAULT_EXEC_PATH / DEFAULT_SCRIPT_PATH settings.
    """
    user = '******'
    linux_user = '******'
    allowqueuemaster = 1
    waitid = 0
    parmstring = ''
    rundataid = 0
    priority = 1
    reserve_gb = 0
    codeHash = 'kamiak'

    if executable_path in [None, 'None', 'NONE', '']:
        executable_path = get_setting('DEFAULT_EXEC_PATH')
    if script_path in [None, 'None', 'NONE', '']:
        script_path = get_setting('DEFAULT_SCRIPT_PATH')

    combined = list(itertools.product(cellids, [batch], modelnames))
    notes = ['%s/%s/%s' % (c, b, m) for c, b, m in combined]
    commandPrompts = ["%s %s %s %s %s" % (executable_path, script_path, c, b, m)
                      for c, b, m in combined]

    engine = nd.Engine()
    for (c, b, m), note, commandPrompt in zip(combined, notes, commandPrompts):
        # os.path.join only accepts str/path components; batch is often an int
        path = os.path.join(source_path, str(batch), c, m)
        if not os.path.exists(path):
            log.warning("missing fit for: \n%s\n%s\n%s\n"
                        "using path: %s\n", batch, c, m, path)
            continue

        xfspec, ctx = xforms.load_analysis(path, eval_model=False)
        preview = ctx['modelspec'].meta.get('figurefile', None)
        if 'log' not in ctx:
            ctx['log'] = 'missing log'
        figures_to_load = ctx['figures_to_load']
        ctx['figures'] = [xforms.load_resource(f) for f in figures_to_load]
        xforms.save_analysis(None, None, ctx['modelspec'], xfspec,
                             ctx['figures'], ctx['log'])
        nd.update_results_table(ctx['modelspec'], preview=preview)

        # NOTE(review): SQL below is built by string concatenation from
        # cellid/modelname values; switch to bound parameters if these can
        # ever contain quotes or come from outside the lab.
        conn = engine.connect()
        try:
            sql = 'SELECT * FROM tQueue WHERE note="' + note + '"'
            r = conn.execute(sql)
            if r.rowcount > 0:
                # existing job: only a dead job (complete == 2) needs changing;
                # complete == 1 is already done, -1/0 are running or queued
                x = r.fetchone()
                queueid = x['id']
                if x['complete'] == 2:
                    sql = ("UPDATE tQueue SET complete=1, killnow=0 "
                           "WHERE id={}".format(queueid))
                    conn.execute(sql)
            else:
                # new job: insert it already marked complete
                sql = ("INSERT INTO tQueue (rundataid,progname,priority,"
                       "reserve_gb,parmstring,allowqueuemaster,user,"
                       "linux_user,note,waitid,codehash,queuedate,complete) VALUES"
                       " ({},'{}',{},"
                       "{},'{}',{},'{}',"
                       "'{}','{}',{},'{}',NOW(),1)")
                sql = sql.format(rundataid, commandPrompt, priority, reserve_gb,
                                 parmstring, allowqueuemaster, user, linux_user,
                                 note, waitid, codeHash)
                conn.execute(sql)
        finally:
            # always release the connection, even if a query fails
            conn.close()
all_cells = nd.get_batch_cells(batch=310).cellid.tolist()
goodcell = 'BRT037b-39-1'
best_model = 'wc.2x2.c-stp.2-fir.2x15-lvl.1-stategain.18-dexp.1'
test_path = '/auto/data/nems_db/results/310/BRT037b-39-1/BRT037b-39-1.wc.2x2.c_stp.2_fir.2x15_lvl.1_stategain.18_dexp.1.fit_basic.2018-11-14T093820/'

rerun = False

# Goodness-of-fit summary across models. By default the cached table is
# reused; with rerun set, it is rebuilt by loading and evaluating every fit
# in result_paths (result_paths must already be defined when this runs).
if not rerun:
    meta_DF = jl.load(
        '/home/mateo/code/context_probe_analysis/pickles/summary_metrics')
else:
    subset_keys = ['cellid', 'modelname', 'r_test', 'r_fit']
    population_metas = list()
    for filepath in result_paths:
        _, ctx = xforms.load_analysis(filepath=filepath, eval_model=True, only=None)
        meta = ctx['modelspecs'][0][0]['meta']
        # keep only the summary fields, in the order they appear in meta
        d = {k: v for k, v in meta.items() if k in subset_keys}
        # the correlation values are stored as sequences; keep the first entry
        for metric in ('r_test', 'r_fit'):
            d[metric] = d[metric][0]
        population_metas.append(d)
    meta_DF = pd.DataFrame(population_metas)
    jl.dump(meta_DF, '/home/mateo/code/context_probe_analysis/pickles/summary_metrics')
def load_xfrom_from_folder(self, directory, eval_model=True):
    """
    Load a saved xforms spec and context from a directory.

    :param directory: path (str or path-like) of the saved analysis folder
    :param eval_model: forwarded to xforms.load_analysis
    :return: (xfspec, ctx) tuple
    """
    folder = str(directory)
    loaded_xfspec, loaded_ctx = xforms.load_analysis(folder, eval_model=eval_model)
    return loaded_xfspec, loaded_ctx
def __init__(self, batch, cellids, modelnames, parent=None):
    '''
    Build a tabbed comparison of model predictions against the response.

    Each saved fit for every (cellid, modelname) pair is loaded and evaluated.
    The fits are stored in self.contexts, a nested dictionary with the format:

        contexts = {
            cellid1: {'model1': ctx_a, 'model2': ctx_b},
            cellid2: {'model1': ctx_c, 'model2': ctx_d},
            ...
        }

    One ComparisonFrame tab is added per cell with at least one loaded model.
    It shows the validation 'resp' followed by each model's 'pred'.
    Missing fits are reported with print() and skipped.
    '''
    # NOTE(review): super(qw.QWidget, self) starts the MRO lookup *after*
    # QWidget, and `parent` is never forwarded. This is left unchanged because
    # the class header is not visible here; confirm whether
    # super().__init__(parent) is intended.
    super(qw.QWidget, self).__init__()

    d = nd.get_results_file(batch=batch, cellids=cellids, modelnames=modelnames)

    contexts = {}
    for c in cellids:
        cell_contexts = {}
        for m in modelnames:
            # Use one combined mask. Chaining d[mask1][mask2] re-aligns mask2
            # against the filtered frame and makes pandas warn.
            modelpaths = d.loc[(d.cellid == c) & (d.modelname == m), 'modelpath'].values
            if len(modelpaths) == 0:
                print("Coudln't find modelpath for cell: %s model: %s" % (c, m))
                continue
            filepath = modelpaths[0] + '/'
            xfspec, ctx = xforms.load_analysis(filepath, eval_model=True)
            cell_contexts[m] = ctx
        contexts[c] = cell_contexts

    self.contexts = contexts
    self.batch = batch
    self.cellids = cellids
    self.modelnames = modelnames

    self.time_scroller = TimeScroller(self)
    self.layout = qw.QVBoxLayout()
    self.tabs = qw.QTabWidget()
    self.comparison_tabs = []

    for k, v in self.contexts.items():
        if not v:
            # no model could be loaded for this cell: nothing to compare
            continue
        names = ['Response'] + list(v.keys())

        # the response (and its time axis) is taken from the first model's val set
        resp = next(iter(v.values()))['val']['resp']
        times = np.linspace(0, resp.shape[-1] / resp.fs, resp.shape[-1])

        signals = [resp.as_continuous().T]
        signals.extend(model_ctx['val']['pred'].as_continuous().T
                       for model_ctx in v.values())

        tab = ComparisonFrame(signals, names, times, self)
        self.comparison_tabs.append(tab)
        self.tabs.addTab(tab, k)

    self.time_scroller._update_max_time()
    self.layout.addWidget(self.tabs)
    self.layout.addWidget(self.time_scroller)
    self.setLayout(self.layout)
#batch, cellid = 269, 'chn019a-a1' #batch, cellid = 269, 'oys042c-d1' #batch, cellid = 273, 'chn041d-b1' #batch, cellid = 273, 'zee027b-c1' batch, cellid = 269, 'btn144a-a1' #modelspec = 'RDTwcg18x2-RDTfir2x15_RDTstreamgain_lvl1_dexp1' #keywordstring = 'dlog-wc.18x1.g-fir.1x15-lvl.1' #keywordstring = 'rdtwc.18x1.g-rdtfir.1x15-rdtgain.relative.NTARGETS-lvl.1' keywordstring = 'rdtgain.gen.NTARGETS-rdtmerge.stim-wc.18x1.g-fir.1x15-lvl.1' modelname = 'rdtld-rdtshf.rep-rdtsev-rdtfmt_' + keywordstring + '_init-basic' savefile = nw.fit_model_xforms_baphy(cellid, batch, modelname, saveInDB=False) #xf,ctx = nw.load_model_baphy_xform(cellid,batch,modelname) xf, ctx = xforms.load_analysis(savefile) # browse_context(ctx, 'val', signals=['stim', 'resp', 'fg_sf', 'bg_sf', 'state']) """ # database-free version recording_uri = '/Users/svd/python/nems/recordings/chn019a_e3a6a2e25b582125a7a6ee98d8f8461557ae0cf7.tgz' #recording_uri = '/Users/svd/python/nems/recordings/chn019a_16e888cad7fef05b2f51c58874bd07040ae80903.tgz' shuff_streams=False shuff_rep=False xfspec = [ ('nems.xforms.init_context', {'batch': batch, 'cellid': cellid, 'keywordstring': keywordstring, 'recording_uri': recording_uri}), ('nems_lbhb.rdt.io.load_recording', {}), ('nems_lbhb.rdt.preprocessing.rdt_shuffle', {'shuff_streams': shuff_streams, 'shuff_rep': shuff_rep}), ('nems_lbhb.rdt.preprocessing.split_est_val', {}), ('nems_lbhb.rdt.xforms.format_keywordstring', {}), ('nems.xforms.init_from_keywords', {}),
# (tail of a method whose header is above this excerpt)
# register the new layer area and show it in a collapsible dock
self.plot_container[id(layer_area)] = layer_area
self.add_collapsible_dock(layer_area, window_title=f'{recording}:{signal}')
# layer_area.plotWidget.update_plot(y_data=signal_data, y_data_name=signal)
layer_area.parent().parent().set_toggle(True)

# add the plot types to the combo box
# (Qt signals are blocked while filling it, so no change handlers fire)
layer_area.comboBox.blockSignals(True)
layer_area.comboBox.addItems(PG_PLOTS.keys())
layer_area.comboBox.blockSignals(False)

# only link if shapes match
## TODO: make this a try in the plot widget
# if self.ctx['val']['stim']._data.shape[-1] == signal_data.shape[-1]:
#     self.link_together(layer_area.plotWidget)

layer_area.update_plot()


if __name__ == '__main__':
    from nems import xforms

    # demo: load one saved fit and open it in the main window
    demo_model = '/auto/data/nems_db/results/322/ARM030a-28-2/ozgf.fs100.ch18-ld-sev.dlog-wc.18x3.g-fir.3x15-lvl.1-dexp.1.tfinit.n.lr1e3.rb5.es20-newtf.n.lr1e4.es20.2021-06-15T212246'
    xfspec, ctx = xforms.load_analysis(demo_model)
    #xfspec, ctx = xforms.load_context(r'C:\Users\Alex\PycharmProjects\NEMS\results\temp_xform')

    app = QApplication(sys.argv)
    window = MainWindow(ctx=ctx, xfspec=xfspec)
    app.exec_()
def generate_state_corrected_psth(batch=None, modelname=None, cellids=None, siteid=None,
                                  movement_mask=False, gain_only=False, dc_only=False,
                                  cache_path=None, recache=False):
    """
    Build a recording whose psth signal is the prediction of ``modelname``.

    Designed with stategain models in mind. CRH.

    The saved fit of each cell (looked up in the results table for ``batch``
    and ``modelname``) is reloaded, and its validation recording is used.
    Fits that are missing or cannot be loaded are logged and skipped. This
    function does not fit models.

    Parameters
    ----------
    batch : int
    modelname : str
    cellids : list of str
        Cells whose signals are concatenated channel-wise.
    siteid : str
        Site name, used in the cache filename.
    movement_mask : bool
        Currently unused.
    gain_only : bool
        For every phi key not containing '_g', keep only the first coefficient,
        zero the rest, and re-evaluate the model to get the psth.
    dc_only : bool
        Same as gain_only, for phi keys not containing '_d'. Ignored when
        gain_only is True.
    cache_path : str, optional
        Prefix (string-concatenated with the filename) for caching the result
        as a .tgz recording.
    recache : bool
        If True, rebuild and overwrite an existing cache file.

    Returns
    -------
    Recording
        Contains psth_sp, psth, resp, pupil and mask, plus pupil_raw, rem and
        pupil_eyespeed when the first cell's recording has them.

    Raises
    ------
    ValueError
        If required arguments are missing, no fit could be loaded, or the
        per-cell signals cannot be concatenated.
    """
    if siteid is None:
        raise ValueError("must specify siteid!")

    fn = None
    if cache_path is not None:
        if modelname is None:
            raise ValueError('Must specify batch and modelname!')
        fn = cache_path + siteid + '_{}.tgz'.format(modelname.split('.')[1])
        if gain_only:
            fn = fn.replace('.tgz', '_gonly.tgz')
        elif dc_only:
            # dc_only predictions differ from the full model, so they need
            # their own cache file instead of sharing the full-model name
            fn = fn.replace('.tgz', '_donly.tgz')
        if 'mvm' in modelname:
            fn = fn.replace('.tgz', '_mvm.tgz')
        if os.path.isfile(fn) and not recache:
            return Recording.load(fn)

    if batch is None or modelname is None:
        raise ValueError('Must specify batch and modelname!')
    if cellids is None:
        raise ValueError('Must specify cellids!')

    preds, mspecs = _load_cell_fits(batch, modelname, cellids)
    if not preds:
        raise ValueError('no fits found for batch {} / model {}'.format(batch, modelname))

    # The preds must have the same length to be concatenated. This could be
    # violated if, for example, one cell existed in a prepassive run but
    # another didn't, and so they were fit differently. For these batches,
    # mask every recording down to the FILE epochs shared by all cells.
    if int(batch) in (307, 294):
        shared_files = _shared_file_epochs(preds)
        preds = [p.and_mask(shared_files).apply_mask(reset_epochs=True) for p in preds]

    new_psth_sp = new_psth = new_resp = None
    for i, (p, mspec) in enumerate(zip(preds, mspecs)):
        if gain_only:
            _keep_first_state_coef(mspec, '_g')
            pred = mspec.evaluate(p)['pred']
        elif dc_only:
            _keep_first_state_coef(mspec, '_d')
            pred = mspec.evaluate(p)['pred']
        else:
            pred = p['pred']

        resp = p['resp'].rasterize()
        if i == 0:
            new_psth_sp, new_psth, new_resp = p['psth_sp'], pred, resp
            continue
        try:
            new_psth_sp = new_psth_sp.concatenate_channels([new_psth_sp, p['psth_sp']])
            new_psth = new_psth.concatenate_channels([new_psth, pred])
            new_resp = new_resp.concatenate_channels([new_resp, resp])
        except ValueError as err:
            raise ValueError(
                'could not concatenate signals of loaded fit #{}; the cell '
                'recordings probably differ in length'.format(i)) from err

    # state / mask signals are taken from the first cell's recording
    rec0 = preds[0]
    sigs = {'pupil': rec0['pupil']}
    if 'pupil_raw' in rec0.signals.keys():
        sigs['pupil_raw'] = rec0['pupil_raw']
    if 'mask' in rec0.signals:
        sigs['mask'] = rec0['mask']
    else:
        sigs['mask'] = rec0.create_mask(True)['mask']
    for name in ('rem', 'pupil_eyespeed'):
        if name in rec0.signals.keys():
            sigs[name] = rec0[name]

    new_psth_sp.name = 'psth_sp'
    new_psth.name = 'psth'
    new_resp.name = 'resp'
    sigs['psth_sp'] = new_psth_sp
    sigs['psth'] = new_psth
    sigs['resp'] = new_resp

    new_rec = Recording(sigs, meta=rec0.meta)

    # make sure mask is cast to bool
    new_rec['mask'] = new_rec['mask']._modified_copy(new_rec['mask']._data.astype(bool))

    if fn is not None:
        log.info('caching {}'.format(fn))
        new_rec.save_targz(fn)

    return new_rec


def _load_cell_fits(batch, modelname, cellids):
    """Return (val recordings, modelspecs) of each cell's saved fit, skipping missing ones."""
    results_table = nd.get_results_file(batch, modelnames=[modelname])
    preds = []
    mspecs = []
    for cell in cellids:
        log.info(cell)
        modelpaths = results_table.loc[results_table['cellid'] == cell, 'modelpath'].values
        if len(modelpaths) == 0 or not os.path.isdir(modelpaths[0]):
            log.warning("WARNING: fit doesn't exist for cell {0}".format(cell))
            continue
        try:
            xfspec, ctx = xforms.load_analysis(modelpaths[0])
            pred, mspec = ctx['val'], ctx['modelspec']
        except Exception:
            # best effort: an unreadable fit is reported and skipped rather
            # than aborting the whole site
            log.exception("WARNING: could not load fit for cell {0}".format(cell))
            continue
        preds.append(pred)
        mspecs.append(mspec)
    return preds, mspecs


def _shared_file_epochs(preds):
    """Return FILE epoch names whose total count equals len(preds), i.e. present in every recording."""
    file_epochs = []
    for pr in preds:
        file_epochs += [ep for ep in pr.epochs.name if ep.startswith('FILE')]
    return [str(f) for f in np.unique(file_epochs) if file_epochs.count(f) == len(preds)]


def _keep_first_state_coef(mspec, keep_substr):
    """For phi keys lacking keep_substr, keep the first coefficient and zero the rest (in place)."""
    phi = mspec[0]['phi']
    for k in [k for k in phi.keys() if keep_substr not in k]:
        phi[k] = np.append(phi[k][0, 0], np.zeros(phi[k].shape[-1] - 1))[np.newaxis, :]
#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ Created on Tue Jan 23 10:11:40 2018 @author: svd """ import nems.plots.api as nplt import nems.db as nd import nems.xforms as xforms from nems.gui.recording_browser import browse_recording, browse_context cellid = 'TAR010c-18-1' batch = 271 #modelname = 'wc.18x1.g-fir.1x15-lvl.1' modelname = 'dlog-wc.18x1.g-fir.1x15-lvl.1' #modelname = 'dlog-wc.18x1.g-stp.1-fir.1x15-lvl.1-dexp.1' d = nd.get_results_file(batch=batch, cellids=[cellid], modelnames=[modelname]) filepath = d['modelpath'][0] + '/' xfspec, ctx = xforms.load_analysis(filepath, eval_model=False) ctx, log_xf = xforms.evaluate(xfspec, ctx) #nplt.quickplot(ctx) ctx['modelspec'].quickplot(ctx['val']) aw = browse_context(ctx, signals=['stim', 'pred', 'resp'])
# (continuation of a gain-weight scatter plot set up before this excerpt:
#  ax, f, model and lp_model are defined above)
ax.set_xlabel(model, fontsize=6)
ax.set_ylabel(lp_model, fontsize=6)
# unity line plus zero axes for reference
ax.plot([-1, 1], [-1, 1], 'k--', zorder=-1)
ax.axhline(0, linestyle='--', color='k')
ax.axvline(0, linestyle='--', color='k')
ax.axis('equal')
ax.set_title("gain weights for LV")
f.tight_layout()
plt.show()

import nems.xforms as xforms
from nems.xform_helper import load_model_xform

# load an old model
modelpath = '/auto/data/nems_db/results/331/AMT020a/psth.fs4.pup-loadpred.cpnmvm-st.pup.pvp0-plgsm.e10.sp-lvnoise.r8-aev.lvnorm.SxR.d.so-inoise.2xR.ccnorm.t5.ss1.2021-06-21T174425'
xf, ctx = xforms.load_analysis(modelpath)

# load a new model (cellid = first batch cell whose name contains the site)
site = 'AMT020a'
batch = 331
modelname = "psth.fs4.pup-ld-st.pup.pvp0-epcpn.old-mvm.t25.w1-hrc-psthfr-plgsm.e10.sp-aev_sdexp2.2xR-lvnorm.SxR.d.so-inoise.2xR_init.xx1.it50000-lvnoise.r8-aev-ccnorm.f0.ss1"
xfn, ctxn = load_model_xform(
    cellid=[c for c in nd.get_batch_cells(batch).cellid if site in c][0],
    batch=batch, modelname=modelname)

# old/new loadpred models
# NOTE(review): here the site name itself is passed as cellid, unlike the call
# above - presumably these loadpred models are stored per site; confirm.
modelname = "psth.fs4.pup-loadpred.cpnOldmvm,t25,w1-st.pup.pvp0-plgsm.e10.sp-lvnoise.r8-aev_lvnorm.SxR.d.so-inoise.2xR_ccnorm.t5.ss1"
xfnn, ctxnn = load_model_xform(cellid=site, batch=batch, modelname=modelname)
# fast LV with pupil (gain model with sigmoid nonlinearity) modelname = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-ev-addmeta_slogsig.SxR-lv.1xR.f.pred-lvlogsig.2xR_jk.nf5.p-pupLVbasic.constrLVonly.af0:2.sc' # single module for LV (testing LV modeling architectures) #modelname = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-ev-addmeta_puplvmodel.pred.step.dc.R_jk.nf5.p-pupLVbasic.constrLVonly.af0:3.sc.rb10' #modelname0 = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-ev-addmeta_puplvmodel.dc.pupOnly.R_jk.nf5.p-pupLVbasic.constrLVonly.af0:0.sc' # without jackknifing (or cross validation) modelname = 'ns.fs4.pup-ld-st.pup-epsig-hrc-apm-pbal-psthfr-ev-addmeta-aev_puplvmodel.pred.step.g.dc.R_pupLVbasic.constrNC.af0:1.sc.rb2' modelname0 = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-ev-addmeta-aev_puplvmodel.g.dc.pupOnly.R_pupLVbasic.constrLVonly.af0:0.sc.rb2' xforms_model = 'ns.fs4.pup-ld-st.pup-hrc-apm-pbal-psthfr-addmeta-aev_puplvmodel.pred.step.pfix.dc.R_pupLVbasic.constrLVonly.af0:0.sc.rb10' if load: c = [c for c in nd.get_batch_cells(batch).cellid if cellid in c][0] mp = nd.get_results_file(batch, [modelname], [c]).modelpath[0] _, ctx = xforms.load_analysis(mp) mp = nd.get_results_file(batch, [modelname0], [c]).modelpath[0] _, ctx2 = xforms.load_analysis(mp) else: ctx = xfit.fit_xforms_model(batch, cellid, modelname, save_analysis=False) ctx2 = xfit.fit_xforms_model(batch, cellid, modelname0, save_analysis=False) # plot lv, pupil, PC1 timecourses if '.g.' in modelname: key1 = 'pg' key2 = 'lvg' if '.dc.' in modelname: key1 = 'pd'
# xfspec.append(['nems.xforms.average_away_stim_occurrences', {}]) xfspec.append([ 'nems.xforms.init_from_keywords', { 'keywordstring': modelspec_name, 'meta': meta } ]) xfspec.append(['nems.xforms.fit_basic_init', {}]) xfspec.append(['nems.xforms.fit_basic', {}]) xfspec.append(['nems.xforms.predict', {}]) xfspec.append(['nems.xforms.add_summary_statistics', {}]) xfspec.append(['nems.xforms.plot_summary', {}]) ctx, log_xf = xforms.evaluate(xfspec) modelspecs = ctx['modelspecs'] destination = '/auto/data/nems_db/results/{0}/{1}/{2}/'.format( batch, cellid, ms.get_modelspec_longname(modelspecs[0])) modelspecs[0][0]['meta']['modelpath'] = destination modelspecs[0][0]['meta']['figurefile'] = destination + 'figure.0000.png' modelspecs[0][0]['meta'].update(meta) save_data = xforms.save_analysis(destination, recording=ctx['rec'], modelspecs=modelspecs, xfspec=xfspec, figures=ctx['figures'], log=log_xf) savepath = save_data['savepath'] loaded = xforms.load_analysis(filepath=savepath, eval_model=True, only=None)
datapath = '/auto/users/svd/projects/reward_training/nems_export/torcs/'

# find all models, knowing that the folder names should contain "2020"
d = glob.glob(datapath + '*2020*')

# Folder names look like "<cellid>_<modelname>.2020...". Split the basename,
# not the full path: datapath itself contains underscores ("reward_training",
# "nems_export"), which would otherwise be mistaken for the separator.
basenames = [os.path.basename(os.path.normpath(f)) for f in d]
cellids = list(set([b.partition("_")[0] for b in basenames]))
modelnames = list(set([b.partition("_")[2].split('.2020')[0] for b in basenames]))

# load an example model
cellid = 'NMK020c-29-1'
modelinfo = 'st.pup.fil-'
i, = np.where([(cellid in f) and (modelinfo in f) for f in d])
if len(i) == 0:
    raise FileNotFoundError('no model for {} matching {!r} in {}'.format(
        cellid, modelinfo, datapath))
i = i[0]  # np.where gives an array; take the first match (expected unique)

xf, ctx = load_analysis(d[i], eval_model=False)

modelspec = ctx['modelspec']
print('Loaded model {}'.format(os.path.basename(d[i])))
print('Cellid: {}\nModel name: {}'.format(modelspec.meta['cellid'], modelspec.meta['modelname']))
print('r_test: {:.3f}'.format(modelspec.meta['r_test'][0][0]))

# per state channel: offset (phi 'd'), gain (phi 'g') and modulation index
state_channels = modelspec.meta['state_chans']
for i, s in enumerate(state_channels):
    print("{}: offset={:.3f} gain={:.3f} MI={:.3f}".format(
        s, modelspec.phi[0]['d'][0, i], modelspec.phi[0]['g'][0, i],
        modelspec.meta['state_mod'][i]))
# (loop body over sites: site, batch and the F/Fm/S/Sm lists come from before this excerpt)
best_alpha = pd.read_csv(
    '/auto/users/hellerc/code/projects/nat_pupil_ms_final/dprime/best_alpha.csv',
    index_col=0)
# NOTE(review): `[0]` on a Series is positional-fallback indexing, which newer
# pandas deprecates; `.iloc[0]` would be explicit.
alpha = best_alpha.loc[site][0]
# the stored value is a string like "(a, b)"; parse it into two floats
alpha = (float(alpha.split(',')[0].replace('(', '')),
         float(alpha.split(',')[1].replace(')', '')))
# encode the alphas as modelname keyword options ('.' -> ':')
a = 'af{0}.as{1}.sc.rb10'.format(
    str(alpha[0]).replace('.', ':'),
    str(alpha[1]).replace('.', ':'))
modelname = 'ns.fs4.pup-ld-hrc-apm-pbal-psthfr-ev-residual-addmeta_lv.2xR.f.s-lvlogsig.3xR.ipsth_jk.nf5.p-pupLVbasic.constrLVonly.{}'.format(
    a)
cellid = [c for c in nd.get_batch_cells(batch).cellid if site in c][0]
mp = nd.get_results_file(batch, [modelname], [cellid]).modelpath[0]
xfspec, ctx = xforms.load_analysis(mp)

r = ctx['val'].apply_mask()
fs = r['resp'].fs
fast = r['lv'].extract_channels(['lv_fast'])._data.squeeze()
slow = r['lv'].extract_channels(['lv_slow'])._data.squeeze()
pupil = r['pupil']._data.squeeze()

# assuming ss is scipy.signal: periodogram returns (frequencies, power).
# Store the power spectrum and the frequency of its peak.
o = ss.periodogram(fast, fs=fs)
F.append(o[1].squeeze())
Fm.append(o[0][np.argmax(o[1].squeeze())])
o = ss.periodogram(slow, fs=fs)
S.append(o[1].squeeze())
Sm.append(o[0][np.argmax(o[1].squeeze())])
# NC constraint modelname = [ 'ns.fs4.pup-ld-st.pup-hrc-apm-psthfr-ev_slogsig.SxR-lv.1xR-lvlogsig.2xR_jk.nf5.p-pupLVbasic.constrNC.a0:35' ] modelname = [ 'ns.fs4.pup-ld-st.pup-hrc-apm-psthfr-ev-residual_slogsig.SxR-lv.1xR-lvlogsig.2xR_jk.nf2.p-pupLVbasic.constrNC.a0:05' ] modelname2 = ['ns.fs4.pup-ld-st.pup-hrc-psthfr-ev_slogsig.SxR_jk.nf5.p-basic'] batch = 289 cellids = nd.get_batch_cells(batch).cellid cellid = [[c for c in cellids if site in c][0]] mp = nd.get_results_file(batch, modelname, cellid).modelpath[0] mp2 = nd.get_results_file(batch, modelname2, cellid).modelpath[0] xfspec, ctx = xforms.load_analysis(mp) xfspec2, ctx2 = xforms.load_analysis(mp2) # plot summary of fit ctx2['modelspec'].quickplot() ctx['modelspec'].quickplot() rec = ctx['val'].apply_mask(reset_epochs=True).copy() # raw recording rec['lv'] = rec['lv']._modified_copy(rec['lv']._data[1, :][np.newaxis, :]) rec1 = ctx2['val'].apply_mask( reset_epochs=True).copy() # first order regression rec12 = rec.copy() # first / second order regression #rec1 = preproc.regress_state(rec1, state_sigs=['pupil'], regress=['pupil']) #rec12 = preproc.regress_state(rec12, state_sigs=['pupil', 'lv'], regress=['pupil', 'lv']) psth = preproc.generate_psth(rec12)