def fit_to_simulation(fit_model, simulation_spec):
    '''
    Fit a model to a response simulated from an existing modelspec.

    The 'resp' signal of the default recording is overwritten with the
    'pred' signal obtained by evaluating `simulation_spec` on that
    recording, and the requested model is then fit to the result.

    Parameters:
    -----------
    fit_model : str
        Modelname to fit to the simulation. The first two '-'-separated
        fields are dropped before the xforms spec is generated.
    simulation_spec : NEMS ModelSpec
        Modelspec to base simulation on. When find_module locates a
        'contrast_kernel' module in it, that module's
        'evaluate_contrast' keyword argument is set to True in place.

    Returns:
    --------
    ctx : dict
        Xforms context. See nems.xforms.

    '''
    recording = get_default_ctx()['rec']

    kernel_idx = find_module('contrast_kernel', simulation_spec)
    if kernel_idx is not None:
        kernel_kwargs = simulation_spec[kernel_idx]['fn_kwargs']
        kernel_kwargs['evaluate_contrast'] = True

    # The simulated prediction becomes the "response" that gets fit.
    recording['resp'] = simulation_spec.evaluate(recording)['pred']

    # Strip the first two fields of the modelname (original note:
    # "replace ozgf and ld with ldm").
    name_fields = fit_model.split('-')
    trimmed_name = '-'.join(name_fields[2:])

    spec = xhelp.generate_xforms_spec(modelname=trimmed_name)
    fitted_ctx, _ = xforms.evaluate(spec, context={'rec': recording})
    return fitted_ctx
def fit_model(recording_uri, modelstring, destination):
    '''
    Fit a single model and save the analysis to `destination`.

    recording_uri : URI of the single recording to load
    modelstring : keyword string passed to init_from_keywords
    destination : location handed to xforms.save_analysis
    '''
    load_step = ['nems.xforms.load_recordings',
                 {'recording_uri_list': [recording_uri]}]
    average_step = ['nems.xforms.add_average_sig',
                    {'signal_to_average': 'resp',
                     'new_signalname': 'resp',
                     'epoch_regex': '^STIM_'}]
    split_step = ['nems.xforms.split_by_occurrence_counts',
                  {'epoch_regex': '^STIM_'}]
    init_step = ['nems.xforms.init_from_keywords',
                 {'keywordstring': modelstring}]
    fit_step = ['nems.xforms.fit_basic', {}]
    plot_step = ['nems.xforms.plot_summary', {}]

    # Steps that are deliberately left out of the script for now:
    #   ['nems.xforms.set_random_phi', {}]          (before the fit)
    #   ['nems.xforms.add_summary_statistics', {}]  (after the fit)
    xfspec = [load_step, average_step, split_step,
              init_step, fit_step, plot_step]

    ctx, eval_log = xforms.evaluate(xfspec)

    xforms.save_analysis(destination,
                         recording=ctx['rec'],
                         modelspecs=ctx['modelspecs'],
                         xfspec=xfspec,
                         figures=ctx['figures'],
                         log=eval_log)
def get_model_preds(cellid, batch, modelname):
    """
    Load a saved model without evaluating it, then run its xforms script
    with stop=-1.

    Returns the (xfspec, context) pair, where the context is the one
    produced by the evaluation.
    """
    spec, loaded_ctx = xhelp.load_model_xform(cellid, batch, modelname,
                                              eval_model=False)
    evaluated_ctx, _eval_log = xforms.evaluate(spec, loaded_ctx, stop=-1)
    return spec, evaluated_ctx
def load_model_baphy_xform(
        cellid="chn020f-b1",
        batch=271,
        modelname="ozgf100ch18_wc18x1_fir15x1_lvl1_dexp1_fit01",
        eval=True):
    """
    Reload a saved fit by looking up its results path in the database and
    re-running its xforms script in reload mode.

    The first 'modelpath' entry returned by nd.get_results_file is used
    as a string prefix for 'xfspec.json' and 'modelspec.0000.json'.

    NOTE: `eval` is accepted but not used anywhere in this function.

    Returns the context produced by xforms.evaluate.
    """
    logging.info('Loading modelspecs...')

    results = nd.get_results_file(batch, [modelname], [cellid])
    model_dir = results['modelpath'][0]

    spec = xforms.load_xform(model_dir + 'xfspec.json')

    modelspec_uri = model_dir + 'modelspec.0000.json'
    reload_ctx = xforms.load_modelspecs([], uris=[modelspec_uri],
                                        IsReload=False)
    # Flag the context so xforms steps know this is a reload.
    reload_ctx['IsReload'] = True

    final_ctx, _log_xf = xforms.evaluate(spec, reload_ctx)
    return final_ctx
def reload_model(model_uri):
    '''
    Reload an xfspec and modelspec that were saved under `model_uri` and
    re-run the xfspec, recreating the context that occurred during the
    fit.

    The initial context carries {'IsReload': True}, which xforms should
    react to if they are not intended to be run on a reload.

    `model_uri` is used as a plain string prefix: 'xfspec.json' and
    'modelspec.0000.json' are appended to it directly.
    '''
    # TODO: instead of just reading the first modelspec, read ALL of the
    # modelspecs. It is unclear how to know how many there are without a
    # directory listing.
    saved_xfspec = load_resource(model_uri + 'xfspec.json')
    first_modelspec = load_resource(model_uri + 'modelspec.0000.json')

    initial_ctx = {'IsReload': True,
                   'modelspecs': [first_modelspec]}
    ctx, _reload_log = xforms.evaluate(saved_xfspec, initial_ctx)
    return ctx
def context():
    """
    Build and evaluate a small xforms script on the dummy loader
    ('fir.1x15-lvl.1' model, 20% validation split) and return the
    resulting context.
    """
    modelspecname = 'fir.1x15-lvl.1'
    meta = {'cellid': "DUMMY01", 'batch': 0,
            'modelname': modelspecname, 'exptid': "DUMMY"}

    xfspec = [
        ['nems.xforms.load_recording_wrapper',
         {'load_command': 'nems.demo.loaders.dummy_loader',
          'exptid': meta['exptid'],
          'save_cache': False}],
        ['nems.xforms.split_at_time', {'valfrac': 0.2}],
        ['nems.xforms.init_from_keywords',
         {'keywordstring': modelspecname, 'meta': meta}],
        ['nems.xforms.fit_basic', {}],
        ['nems.xforms.predict', {}],
        ['nems.xforms.add_summary_statistics', {}],
    ]

    # Start from an empty context; every signal comes from the script.
    result_ctx, _xf_log = xforms.evaluate(xfspec, {})
    return result_ctx
def fit_model_xforms_baphy(cellid, batch, modelname,
                           autoPlot=True, saveInDB=False):
    """
    DEPRECATED -- always raises NotImplementedError.

    This used to fit a single NEMS model using data from baphy/celldb,
    e.g. 'ozgf100ch18_wc18x1_lvl1_fir15x1_dexp1_fit01' generated a
    modelspec with 'wc18x1_lvl1_fir15x1_dexp1'. According to the original
    note it "now should work for xhelp.fit_model_xform()".

    The previous implementation sat after an unconditional raise and could
    never execute, so it has been removed. The signature is kept so that
    existing callers fail with the same exception as before.

    :param cellid: unused
    :param batch: unused
    :param modelname: unused
    :param autoPlot: unused
    :param saveInDB: unused
    :raises NotImplementedError: always
    """
    raise NotImplementedError("Replaced by xhelper function?")
'epoch_regex': '^STIM_' }]) # xfspec.append(['nems.xforms.average_away_stim_occurrences', {}]) xfspec.append([ 'nems.xforms.init_from_keywords', { 'keywordstring': modelspec_name, 'meta': meta } ]) xfspec.append(['nems.xforms.fit_basic_init', {}]) xfspec.append(['nems.xforms.fit_basic', {}]) xfspec.append(['nems.xforms.predict', {}]) xfspec.append(['nems.xforms.add_summary_statistics', {}]) xfspec.append(['nems.xforms.plot_summary', {}]) ctx, log_xf = xforms.evaluate(xfspec) modelspecs = ctx['modelspecs'] destination = '/auto/data/nems_db/results/{0}/{1}/{2}/'.format( batch, cellid, ms.get_modelspec_longname(modelspecs[0])) modelspecs[0][0]['meta']['modelpath'] = destination modelspecs[0][0]['meta']['figurefile'] = destination + 'figure.0000.png' modelspecs[0][0]['meta'].update(meta) save_data = xforms.save_analysis(destination, recording=ctx['rec'], modelspecs=modelspecs, xfspec=xfspec, figures=ctx['figures'], log=log_xf) savepath = save_data['savepath']
def fit_model_xforms(recording_uri, modelname, fitter_kwargs=None,
                     autoPlot=True):
    """
    Fit a single NEMS model and return the resulting xforms context.

    `modelname` is split on '_': the first field selects the loader, the
    last field selects the fitter, and everything in between (re-joined
    with '_') is the modelspec keyword string. E.g.
    'ozgf100ch18_wc18x1_lvl1_fir15x1_dexp1_fit01' has loader
    'ozgf100ch18', fitter 'fit01' and modelspec
    'wc18x1_lvl1_fir15x1_dexp1'.

    Based on the fit_model function in nems/scripts/fit_model.py.
    Example xfspec:
      xfspec = [
        ['nems.xforms.load_recordings', {'recording_uri_list': recordings}],
        ['nems.xforms.add_average_sig', {'signal_to_average': 'resp',
                                         'new_signalname': 'resp',
                                         'epoch_regex': '^STIM_'}],
        ['nems.xforms.split_by_occurrence_counts', {'epoch_regex': '^STIM_'}],
        ['nems.xforms.init_from_keywords', {'keywordstring': modelspecname}],
        ['nems.xforms.set_random_phi', {}],
        ['nems.xforms.fit_basic', {}],
        # ['nems.xforms.add_summary_statistics', {}],
        ['nems.xforms.plot_summary', {}],
        # ['nems.xforms.save_recordings', {'recordings': ['est', 'val']}],
        ['nems.xforms.fill_in_default_metadata', {}],
      ]
    """
    log.info('Initializing modelspec(s) for recording/model {0}/{1}...'.format(
        recording_uri, modelname))

    # Break the modelname into loader / modelspec / fitter pieces.
    name_parts = modelname.split("_")
    loader = name_parts[0]
    fitter = name_parts[-1]
    modelspecname = "_".join(name_parts[1:-1])

    # TODO: username, labgroup, public, githash and recording should be
    # added to meta by nems_db after ctx is returned.
    meta = {
        'modelname': modelname,
        'loader': loader,
        'fitter': fitter,
        'modelspecname': modelspecname,
    }

    # The xfspec is the sequence of events to run (a packaged-up script).
    # 1) load the data
    xfspec = generate_loader_xfspec(loader, recording_uri)

    # 2) generate a modelspec
    init_step = ['nems.xforms.init_from_keywords',
                 {'keywordstring': modelspecname, 'meta': meta}]
    xfspec.append(init_step)

    # 3) fit the data
    xfspec.extend(generate_fitter_xfspec(fitter, fitter_kwargs))

    # 4) add some performance statistics
    stats_step = ['nems.analysis.api.standard_correlation', {},
                  ['est', 'val', 'modelspecs'], ['modelspecs']]
    xfspec.append(stats_step)

    # 5) generate plots
    if autoPlot:
        log.info('Generating summary plot...')
        xfspec.append(['nems.xforms.plot_summary', {}])

    # Run the assembled script to get the fitted modelspec, evaluated
    # recording, etc., all packaged up in the context dictionary.
    fitted_ctx, _log_xf = xforms.evaluate(xfspec)
    return fitted_ctx
def fit_model_xforms_baphy(cellid, batch, modelname,
                           autoPlot=True, saveInDB=False):
    """
    Fit a single NEMS model using data from baphy/celldb, save the
    analysis to disk and return the xforms context.

    `modelname` is split on '_' into loader (first field), fitter (last
    field) and modelspec keywords (everything in between), e.g.
    'ozgf100ch18_wc18x1_lvl1_fir15x1_dexp1_fit01' generates a modelspec
    with 'wc18x1_lvl1_fir15x1_dexp1'.

    Supported fitter strings are "fit01", "fitjk01" and "fit02"; any
    other value raises ValueError.

    Based on the fit_model function in nems/scripts/fit_model.py.
    """
    log.info('Initializing modelspec(s) for cell/batch {0}/{1}...'.format(
        cellid, batch))

    # Parse the modelname.
    name_parts = modelname.split("_")
    loader = name_parts[0]
    fitter = name_parts[-1]
    modelspecname = "_".join(name_parts[1:-1])

    # The xfspec defines the sequence of events: load data, generate the
    # modelspec, fit, plot and save.
    xfspec = generate_loader_xfspec(cellid, batch, loader)
    xfspec.append(['nems.xforms.init_from_keywords',
                   {'keywordstring': modelspecname}])

    # Fit recipes: log message plus the steps each fitter string adds.
    # All of them use gradient descent on the whole data set (fast).
    fit_recipes = {
        # prefit strf
        "fit01": ("Prefitting STRF without other modules...",
                  [['nems.xforms.fit_basic_init', {}],
                   ['nems.xforms.fit_basic', {}]]),
        # prefit strf, then jackknife
        "fitjk01": ("Prefitting STRF without NL then JK...",
                    [['nems.xforms.fit_basic_init', {}],
                     ['nems.xforms.split_for_jackknife', {'njacks': 10}],
                     ['nems.xforms.fit_basic', {}]]),
        # no pre-fit
        "fit02": ("Performing full fit...",
                  [['nems.xforms.fit_basic', {}]]),
    }
    if fitter not in fit_recipes:
        raise ValueError('unknown fitter string')
    message, fit_steps = fit_recipes[fitter]
    log.info(message)
    xfspec.extend(fit_steps)

    xfspec.append(['nems.xforms.add_summary_statistics', {}])

    if autoPlot:
        log.info('Generating summary plot...')
        xfspec.append(['nems.xforms.plot_summary', {}])

    # Actually do the fit.
    ctx, log_xf = xforms.evaluate(xfspec)

    # Save some extra metadata on the first module of the first modelspec.
    modelspecs = ctx['modelspecs']
    meta = {
        'batch': batch,
        'cellid': cellid,
        'modelname': modelname,
        'loader': loader,
        'fitter': fitter,
        'modelspecname': modelspecname,
        'username': '******',
        'labgroup': 'lbhb',
        'public': 1,
        'githash': os.environ.get('CODEHASH', ""),
        'recording': loader,
    }
    if 'meta' in modelspecs[0][0].keys():
        modelspecs[0][0]['meta'].update(meta)
    else:
        modelspecs[0][0]['meta'] = meta

    destination = '/auto/data/tmp/modelspecs/{0}/{1}/{2}/'.format(
        batch, cellid, ms.get_modelspec_longname(modelspecs[0]))
    modelspecs[0][0]['meta']['modelpath'] = destination
    modelspecs[0][0]['meta']['figurefile'] = destination + 'figure.0000.png'

    # Save results.
    xforms.save_analysis(destination,
                         recording=ctx['rec'],
                         modelspecs=modelspecs,
                         xfspec=xfspec,
                         figures=ctx['figures'],
                         log=log_xf)
    log.info('Saved modelspec(s) to {0} ...'.format(destination))

    # Save in database as well.
    if saveInDB:
        # TODO : db results
        nd.update_results_table(modelspecs[0])

    return ctx
def _model_step_plot(cellid, batch, modelnames, factors, state_colors=None):
    """
    Load four related model fits for one cell, plot state-dependent PSTHs
    for a control model and the full model, and collect summary
    statistics.

    cellid, batch : identify the dataset; passed to
        xhelp.load_model_xform and used in the figure title and stats.
    modelnames : sequence of exactly four model names, unpacked in the
        order (p0b0, p0b, pb0, pb). The "pb" model is treated as the
        full model; "p0b" and "pb0" are used as control predictions.
    factors : sequence of exactly three entries (it is unpacked into
        three names). factors[0] labels the top panel; the remaining
        entries are the state channels that get their own columns.
        NOTE(review): if 'each_passive' is in `factors`, the list is
        modified IN PLACE (that entry is removed and every state channel
        starting with 'FILE_' is appended), so the caller's list changes
        and must be a mutable list.
    state_colors : N x 2 list
        color spec for high/low lines in each of the N states. Defaults
        to [None, None] for every column.

    Returns (fh, stats): the matplotlib figure and a dict of summary
    values (r_test / se_test / r_floor for the four models, prediction
    modulation indices, 'g' and 'd' phi values of the first module, and
    mean reference/target responses).
    """
    # NOTE(review): these globals are declared but not referenced anywhere
    # in this function body.
    global line_colors
    global fill_colors

    modelname_p0b0, modelname_p0b, modelname_pb0, modelname_pb = \
        modelnames
    # Only factor0 is used below; the unpacking also enforces len == 3.
    factor0, factor1, factor2 = factors

    # Load each model without evaluating, then run its xforms script with
    # start=0, stop=-2.
    xf_p0b0, ctx_p0b0 = xhelp.load_model_xform(cellid, batch, modelname_p0b0,
                                               eval_model=False)
    # ctx_p0b0, l = xforms.evaluate(xf_p0b0, ctx_p0b0, stop=-2)
    ctx_p0b0, l = xforms.evaluate(xf_p0b0, ctx_p0b0, start=0, stop=-2)

    xf_p0b, ctx_p0b = xhelp.load_model_xform(cellid, batch, modelname_p0b,
                                             eval_model=False)
    ctx_p0b, l = xforms.evaluate(xf_p0b, ctx_p0b, start=0, stop=-2)

    xf_pb0, ctx_pb0 = xhelp.load_model_xform(cellid, batch, modelname_pb0,
                                             eval_model=False)
    #ctx_pb0['rec'] = ctx_p0b0['rec'].copy()
    ctx_pb0, l = xforms.evaluate(xf_pb0, ctx_pb0, start=0, stop=-2)

    xf_pb, ctx_pb = xhelp.load_model_xform(cellid, batch, modelname_pb,
                                           eval_model=False)
    #ctx_pb['rec'] = ctx_p0b0['rec'].copy()
    ctx_pb, l = xforms.evaluate(xf_pb, ctx_pb, start=0, stop=-2)

    # organize predictions by different models: start from the full
    # model's validation recording and attach the control predictions
    val = ctx_pb['val'][0].copy()

    # val['pred_p0b0'] = ctx_p0b0['val'][0]['pred'].copy()
    val['pred_p0b'] = ctx_p0b['val'][0]['pred'].copy()
    val['pred_pb0'] = ctx_pb0['val'][0]['pred'].copy()

    state_var_list = val['state'].chans

    # One row per state channel. Rows stay zero for channels whose state
    # signal has zero standard deviation.
    pred_mod = np.zeros([len(state_var_list), 2])
    pred_mod_full = np.zeros([len(state_var_list), 2])
    resp_mod_full = np.zeros([len(state_var_list), 1])

    state_std = np.nanstd(val['state'].as_continuous(), axis=1, keepdims=True)
    for i, var in enumerate(state_var_list):
        if state_std[i]:
            # actual response modulation index for each state var
            resp_mod_full[i] = state_mod_index(val, epoch='REFERENCE',
                                               psth_name='resp',
                                               state_chan=var)

            mod2_p0b = state_mod_index(val, epoch='REFERENCE',
                                       psth_name='pred_p0b', state_chan=var)
            mod2_pb0 = state_mod_index(val, epoch='REFERENCE',
                                       psth_name='pred_pb0', state_chan=var)
            mod2_pb = state_mod_index(val, epoch='REFERENCE',
                                      psth_name='pred', state_chan=var)

            # Difference between full-model modulation and each control.
            pred_mod[i] = np.array([mod2_pb - mod2_p0b, mod2_pb - mod2_pb0])
            pred_mod_full[i] = np.array([mod2_pb0, mod2_p0b])

    # Normalize by state std; the added (state_std == 0) term makes the
    # divisor 1 where std is zero, avoiding division by zero.
    pred_mod_norm = pred_mod / (state_std + (state_std == 0).astype(float))
    pred_mod_full_norm = pred_mod_full / (state_std +
                                          (state_std == 0).astype(float))

    if 'each_passive' in factors:
        # Replace the 'each_passive' placeholder with one factor per
        # 'FILE_' state channel (mutates the caller's list).
        psth_names_ctl = ["pred_p0b"]
        factors.remove('each_passive')
        for v in state_var_list:
            if v.startswith('FILE_'):
                factors.append(v)
                psth_names_ctl.append("pred_pb0")
    else:
        psth_names_ctl = ["pred_p0b", "pred_pb0"]

    # One column per factor after the first.
    col_count = len(factors) - 1
    if state_colors is None:
        state_colors = [[None, None]] * col_count

    fh = plt.figure(figsize=(8, 8))

    # Row 1: state variable time series for the full model.
    ax = plt.subplot(4, 1, 1)
    nplt.state_vars_timeseries(val, ctx_pb['modelspecs'][0],
                               state_colors=[s[1] for s in state_colors])
    ax.set_title("{}/{} - {}".format(cellid, batch, modelname_pb))
    ax.set_ylabel("{} r={:.3f}".format(
        factor0, ctx_p0b0['modelspecs'][0][0]['meta']['r_test'][0]))
    nplt.ax_remove_box(ax)

    for i, var in enumerate(factors[1:]):
        # Drop the 'FILE_' prefix for display.
        if var.startswith('FILE_'):
            varlbl = var[5:]
        else:
            varlbl = var

        # Row 2: response vs. control-model prediction.
        ax = plt.subplot(4, col_count, col_count + i + 1)

        nplt.state_var_psth_from_epoch(val, epoch="REFERENCE",
                                       psth_name="resp",
                                       psth_name2=psth_names_ctl[i],
                                       state_chan=var, ax=ax,
                                       colors=state_colors[i])
        if i == 0:
            ax.set_ylabel("Control model")
            ax.set_title("{} ctl r={:.3f}".format(
                varlbl.lower(),
                ctx_p0b['modelspecs'][0][0]['meta']['r_test'][0]),
                fontsize=6)
        else:
            ax.yaxis.label.set_visible(False)
            ax.set_title("{} ctl r={:.3f}".format(
                varlbl.lower(),
                ctx_pb0['modelspecs'][0][0]['meta']['r_test'][0]),
                fontsize=6)
        if ax.legend_:
            ax.legend_.remove()
        ax.xaxis.label.set_visible(False)
        nplt.ax_remove_box(ax)

        # Row 3: response vs. full-model prediction.
        ax = plt.subplot(4, col_count, col_count * 2 + i + 1)
        nplt.state_var_psth_from_epoch(val, epoch="REFERENCE",
                                       psth_name="resp",
                                       psth_name2="pred",
                                       state_chan=var, ax=ax,
                                       colors=state_colors[i])
        if i == 0:
            ax.set_ylabel("Full Model")
        else:
            ax.yaxis.label.set_visible(False)
        if ax.legend_:
            ax.legend_.remove()

        # Column of the modulation arrays matching this panel's control.
        if psth_names_ctl[i] == "pred_p0b":
            j = 0
        else:
            j = 1

        # NOTE(review): rows are indexed with i + 1, which presumably
        # assumes the state channel order lines up with `factors` offset
        # by one -- confirm against the state signal layout.
        ax.set_title("r={:.3f} rawmod={:.3f} umod={:.3f}".format(
            ctx_pb['modelspecs'][0][0]['meta']['r_test'][0],
            pred_mod_full_norm[i + 1][j], pred_mod_norm[i + 1][j]),
            fontsize=6)

        if var == 'active':
            ax.legend(('pas', 'act'))
        elif var == 'pupil':
            ax.legend(('small', 'large'))
        elif var == 'PRE_PASSIVE':
            ax.legend(('act+post', 'pre'))
        elif var.startswith('FILE_'):
            ax.legend(('this', 'others'))
        nplt.ax_remove_box(ax)

    # EXTRA PANELS
    # figure out some basic aspects of tuning/selectivity for target vs.
    # reference:
    r = ctx_pb['rec']['resp']
    e = r.epochs
    fs = r.fs

    passive_epochs = r.get_epoch_indices("PASSIVE_EXPERIMENT")

    # Target responses, restricted to passive epochs and scaled by fs.
    tar_names = ep.epoch_names_matching(e, "^TAR_")
    tar_resp = {}
    for tarname in tar_names:
        t = r.get_epoch_indices(tarname)
        t = ep.epoch_intersection(t, passive_epochs)
        tar_resp[tarname] = r.extract_epoch(t) * fs

    # only plot tar responses with max SNR or probe SNR
    # (selected here as names ending in '0' or '2')
    keys = []
    for k in list(tar_resp.keys()):
        if k.endswith('0') | k.endswith('2'):
            keys.append(k)
    keys.sort()

    # assume the reference with most reps is the one overlapping the target
    groups = ep.group_epochs_by_occurrence_counts(e, '^STIM_')
    l = np.array(list(groups.keys()))
    hi = np.max(l)
    ref_name = groups[hi][0]
    t = r.get_epoch_indices(ref_name)
    t = ep.epoch_intersection(t, passive_epochs)
    ref_resp = r.extract_epoch(t) * fs

    t = r.get_epoch_indices('REFERENCE')
    t = ep.epoch_intersection(t, passive_epochs)
    all_ref_resp = r.extract_epoch(t) * fs

    # Pre/post-stimulus silence lengths (in bins) from the first epoch of
    # each type.
    prestimsilence = r.get_epoch_indices('PreStimSilence')
    prebins = prestimsilence[0, 1] - prestimsilence[0, 0]
    poststimsilence = r.get_epoch_indices('PostStimSilence')
    postbins = poststimsilence[0, 1] - poststimsilence[0, 0]
    durbins = ref_resp.shape[-1] - prebins

    # Spontaneous rate = mean over the pre-stimulus bins; the evoked means
    # below are taken over bins prebins:durbins with spont subtracted.
    spont = np.nanmean(all_ref_resp[:, 0, :prebins])
    ref_mean = np.nanmean(ref_resp[:, 0, prebins:durbins]) - spont
    all_ref_mean = np.nanmean(all_ref_resp[:, 0, prebins:durbins]) - spont
    #print(spont)
    #print(np.nanmean(ref_resp[:,0,prebins:-postbins]))

    # Row 4, left: reference PSTHs.
    ax1 = plt.subplot(4, 2, 7)
    ref_psth = [
        np.nanmean(ref_resp[:, 0, :], axis=0),
        np.nanmean(all_ref_resp[:, 0, :], axis=0)
    ]
    ll = [
        "{} {:.1f}".format(ref_name, ref_mean),
        "all refs {:.1f}".format(all_ref_mean)
    ]
    nplt.timeseries_from_vectors(ref_psth, fs=fs, legend=ll, ax=ax1,
                                 time_offset=prebins / fs)

    # Row 4, right: target PSTHs.
    ax2 = plt.subplot(4, 2, 8)
    ll = []
    # At least two slots (NaN-filled) so tar_mean[0] and tar_mean[1] can
    # always be read for the stats dict.
    tar_mean = np.zeros(np.max([2, len(keys)])) * np.nan
    tar_psth = []
    for ii, k in enumerate(keys):
        tar_psth.append(np.nanmean(tar_resp[k][:, 0, :], axis=0))
        tar_mean[ii] = np.nanmean(tar_resp[k][:, 0, prebins:durbins]) - spont
        ll.append("{} {:.1f}".format(k, tar_mean[ii]))
    nplt.timeseries_from_vectors(tar_psth, fs=fs, legend=ll, ax=ax2,
                                 time_offset=prebins / fs)
    # plt.legend(ll, fontsize=6)

    # Give both bottom panels the same y-range.
    ymin = np.min([ax1.get_ylim()[0], ax2.get_ylim()[0]])
    ymax = np.max([ax1.get_ylim()[1], ax2.get_ylim()[1]])
    ax1.set_ylim([ymin, ymax])
    ax2.set_ylim([ymin, ymax])
    nplt.ax_remove_box(ax1)
    nplt.ax_remove_box(ax2)

    plt.tight_layout()

    # Per-model arrays below are ordered (p0b0, p0b, pb0, pb).
    stats = {
        'cellid': cellid,
        'batch': batch,
        'modelnames': modelnames,
        'state_vars': state_var_list,
        'factors': factors,
        'r_test': np.array([
            ctx_p0b0['modelspecs'][0][0]['meta']['r_test'][0],
            ctx_p0b['modelspecs'][0][0]['meta']['r_test'][0],
            ctx_pb0['modelspecs'][0][0]['meta']['r_test'][0],
            ctx_pb['modelspecs'][0][0]['meta']['r_test'][0]
        ]),
        'se_test': np.array([
            ctx_p0b0['modelspecs'][0][0]['meta']['se_test'][0],
            ctx_p0b['modelspecs'][0][0]['meta']['se_test'][0],
            ctx_pb0['modelspecs'][0][0]['meta']['se_test'][0],
            ctx_pb['modelspecs'][0][0]['meta']['se_test'][0]
        ]),
        'r_floor': np.array([
            ctx_p0b0['modelspecs'][0][0]['meta']['r_floor'][0],
            ctx_p0b['modelspecs'][0][0]['meta']['r_floor'][0],
            ctx_pb0['modelspecs'][0][0]['meta']['r_floor'][0],
            ctx_pb['modelspecs'][0][0]['meta']['r_floor'][0]
        ]),
        'pred_mod': pred_mod.T,
        'pred_mod_full': pred_mod_full.T,
        'pred_mod_norm': pred_mod_norm.T,
        'pred_mod_full_norm': pred_mod_full_norm.T,
        'g': np.array([
            ctx_p0b0['modelspecs'][0][0]['phi']['g'],
            ctx_p0b['modelspecs'][0][0]['phi']['g'],
            ctx_pb0['modelspecs'][0][0]['phi']['g'],
            ctx_pb['modelspecs'][0][0]['phi']['g']
        ]),
        # NOTE: the 'b' entry is read from phi['d'].
        'b': np.array([
            ctx_p0b0['modelspecs'][0][0]['phi']['d'],
            ctx_p0b['modelspecs'][0][0]['phi']['d'],
            ctx_pb0['modelspecs'][0][0]['phi']['d'],
            ctx_pb['modelspecs'][0][0]['phi']['d']
        ]),
        'ref_all_resp': all_ref_mean,
        'ref_common_resp': ref_mean,
        'tar_max_resp': tar_mean[0],
        'tar_probe_resp': tar_mean[1]
    }

    return fh, stats
def equiv_vs_self(cellid, batch, modelname, LN_model, random_seed=1234):
    """
    Estimate a model's equivalence with itself across two halves of the
    estimation data.

    The estimation set of a previously saved fit is split into two
    disjoint sets of 'STIM_' epochs (a random half, reproducible via
    `random_seed`). The test model and the LN model are each re-fit on
    both halves, and the partial correlation between the two test-model
    validation predictions is computed twice, once with each LN
    prediction as the third variable.

    Parameters
    ----------
    cellid, batch, modelname :
        Identify the saved fit whose est/val split is reused.
    LN_model : str
        Modelname of the LN model.
    random_seed : int
        Seed used only while drawing the stimulus split; the global NumPy
        RNG state is restored afterwards.

    Returns
    -------
    float
        Mean of the two partial correlations.
    """
    # Evaluate the old fit just to get est/val already split up.
    xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    est = ctx['est']
    val = ctx['val']

    # Further divide est into two datasets by picking stimulus epochs at
    # random.
    epochs = est['stim'].epochs
    stims = np.array(ep.epoch_names_matching(epochs, 'STIM_'))
    # Built-in int/bool replace np.int/np.bool, which were removed from
    # NumPy in 1.24 and raise AttributeError there.
    indices = np.linspace(0, len(stims) - 1, len(stims), dtype=int)

    # Seed temporarily; always restore the caller's RNG state, even if the
    # draw fails.
    st0 = np.random.get_state()
    try:
        np.random.seed(random_seed)
        set1_idx = np.random.choice(indices, round(len(stims) / 2),
                                    replace=False)
    finally:
        np.random.set_state(st0)

    mask = np.zeros_like(stims, bool)
    mask[set1_idx] = True
    set1_stims = stims[mask].tolist()
    set2_stims = stims[~mask].tolist()

    est1, est2 = est.split_by_epochs(set1_stims, set2_stims)

    # Re-fit on the smaller est sets. Each fit gets its own context so
    # that one fit cannot modify the data seen by another; the LN fits use
    # the deep copies taken here, before any fitting has happened.
    # (The original created these copies but then passed ctx1/ctx2 to the
    # LN fits as well, leaving the copies unused.)
    ctx1 = {'est': est1, 'val': val.copy()}
    ctx2 = {'est': est2, 'val': val.copy()}
    LN_ctx1 = copy.deepcopy(ctx1)
    LN_ctx2 = copy.deepcopy(ctx2)

    # Replace the loader field of each modelname with 'none', since
    # est/val are supplied directly through the context.
    # modelname = modelname.replace('-sev', '')
    # LN_model = LN_model.replace('-sev', '')
    tm = 'none_' + '_'.join(modelname.split('_')[1:])
    lm = 'none_' + '_'.join(LN_model.split('_')[1:])

    def _fit_and_predict(name, context):
        """Fit `name` starting from `context`; return the flattened
        validation prediction."""
        spec = xhelp.generate_xforms_spec(modelname=name)
        fitted, _ = xforms.evaluate(spec, context=context)
        return fitted['val']['pred'].as_continuous().flatten()

    test_pred1 = _fit_and_predict(tm, ctx1)     # test model, est1
    test_pred2 = _fit_and_predict(tm, ctx2)     # test model, est2
    LN_pred1 = _fit_and_predict(lm, LN_ctx1)    # LN model, est1
    LN_pred2 = _fit_and_predict(lm, LN_ctx2)    # LN model, est2

    # Test equivalence on the new fits: one column per prediction, then
    # take the test1-vs-test2 entry of the partial correlation matrix.
    C1 = np.column_stack((test_pred1, test_pred2, LN_pred1))
    p1 = partial_corr(C1)[0, 1]

    C2 = np.column_stack((test_pred1, test_pred2, LN_pred2))
    p2 = partial_corr(C2)[0, 1]

    return 0.5 * (p1 + p2)
def fit_xfspec(xfspec):
    """
    Run an already-assembled xfspec and return the resulting context.

    Evaluating the script yields the fitted modelspec, the evaluated
    recording, etc., all packaged up in the context dictionary. The
    evaluation log is discarded.
    """
    result_ctx, _eval_log = xforms.evaluate(xfspec)
    return result_ctx
def fit_model_xform(cellid, batch, modelname, autoPlot=True, saveInDB=False,
                    returnModel=False, recording_uri=None,
                    initial_context=None):
    """
    Fit a single NEMS model using data stored in database. First generates
    an xforms script based on modelname parameter and then evaluates it.

    :param cellid: cellid and batch specific dataset in database. May be a
        list of cellids, in which case the first entry is used to derive
        the site id and the results directory name.
    :param batch: batch number (converted with int() where needed)
    :param modelname: string specifying model architecture, preprocessing
        and fit method ('_'-separated: loader, model keywords, fitter)
    :param autoPlot: generate summary plot when complete
    :param saveInDB: save results to Results table
    :param returnModel: boolean (default False). If False, return savepath
        if True return xfspec, ctx tuple (and skip saving)
    :param recording_uri: passed through to generate_xforms_spec
    :param initial_context: optional dict merged into the xforms init
        context (on top of cellid and batch) before the spec is generated
    :return: savepath = path to saved results or (xfspec, ctx) tuple
    """
    startime = time.time()
    log.info('Initializing modelspec(s) for cell/batch %s/%d...',
             cellid, int(batch))

    # Segment modelname for meta information
    kws = escaped_split(modelname, '_')
    modelspecname = escaped_join(kws[1:-1], '-')
    loadkey = kws[0]
    fitkey = kws[-1]

    meta = {
        'batch': batch,
        'cellid': cellid,
        'modelname': modelname,
        'loader': loadkey,
        'fitkey': fitkey,
        'modelspecname': modelspecname,
        'username': '******',
        'labgroup': 'lbhb',
        'public': 1,
        'githash': os.environ.get('CODEHASH', ''),
        'recording': loadkey
    }

    # isinstance (rather than an exact type comparison) so that list
    # subclasses are handled the same way as plain lists.
    cellid_is_list = isinstance(cellid, list)
    if cellid_is_list:
        # Site id = first seven characters of the first cellid.
        meta['siteid'] = cellid[0][:7]

    # registry_args = {'cellid': cellid, 'batch': int(batch)}
    registry_args = {}
    xforms_init_context = {'cellid': cellid, 'batch': int(batch)}
    if initial_context is not None:
        xforms_init_context.update(initial_context)

    log.info("TODO: simplify generate_xforms_spec parameters")
    xfspec = generate_xforms_spec(recording_uri=recording_uri,
                                  modelname=modelname,
                                  meta=meta,
                                  xforms_kwargs=registry_args,
                                  xforms_init_context=xforms_init_context,
                                  autoPlot=autoPlot)
    log.debug(xfspec)

    # Actually do the loading, preprocessing, fit. The initial context
    # reaches the script through xforms_init_context above, so nothing is
    # passed to evaluate() directly.
    ctx, log_xf = xforms.evaluate(xfspec)

    # save some extra metadata
    modelspec = ctx['modelspec']

    if cellid_is_list:
        cell_name = cellid[0].split("-")[0]
    else:
        cell_name = cellid

    if 'modelpath' not in modelspec.meta:
        prefix = get_setting('NEMS_RESULTS_DIR')
        destination = os.path.join(prefix, str(batch), cell_name,
                                   modelspec.get_longname())
        log.info(f'Setting modelpath to "{destination}"')
        modelspec.meta['modelpath'] = destination
        modelspec.meta['figurefile'] = os.path.join(destination,
                                                    'figure.0000.png')
    else:
        destination = modelspec.meta['modelpath']

    # figure out URI for location to save results (either file or http,
    # depending on USE_NEMS_BAPHY_API)
    if get_setting('USE_NEMS_BAPHY_API'):
        prefix = ('http://' + get_setting('NEMS_BAPHY_API_HOST') + ":"
                  + str(get_setting('NEMS_BAPHY_API_PORT')) + '/results')
        save_loc = str(batch) + '/' + cell_name + '/' + \
            modelspec.get_longname()
        save_destination = prefix + '/' + save_loc
        # set the modelspec meta save locations to be the filesystem and
        # not baphy
        modelspec.meta['modelpath'] = \
            get_setting('NEMS_RESULTS_DIR') + '/' + save_loc
        modelspec.meta['figurefile'] = \
            modelspec.meta['modelpath'] + '/' + 'figure.0000.png'
    else:
        save_destination = destination

    modelspec.meta['runtime'] = int(time.time() - startime)
    modelspec.meta.update(meta)

    if returnModel:
        # return fit, skip save!
        return xfspec, ctx

    # save results
    log.info('Saving modelspec(s) to {0} ...'.format(save_destination))
    # Not every script produces figures; fall back to an empty list.
    figs = ctx.get('figures', [])
    save_data = xforms.save_analysis(save_destination,
                                     recording=ctx.get('rec'),
                                     modelspec=modelspec,
                                     xfspec=xfspec,
                                     figures=figs,
                                     log=log_xf,
                                     update_meta=False)

    # save in database as well
    if saveInDB:
        nd.update_results_table(modelspec)

    return save_data['savepath']
# Script: load two saved model fits for one cell and pull out the
# response and the two validation predictions for comparison.
import nems.recording as recording
import nems.epoch as ep
import nems.xforms as xforms
import nems.xform_helper as xhelp
import nems_lbhb.xform_wrappers as nw

# Dataset to compare models on.
cellid = "chn002h-a1"
batch = 259
# NOTE: `uri` is not referenced again in the visible part of this script.
uri = nw.generate_recording_uri(cellid, batch, "env.fs100")

# Model 1 is the reference; model 2 is the alternative being compared.
modelname1 = "env100_dlog_fir2x15_lvl1_dexp1_basic"
#modelname2="env100_dlog_stp2_fir2x15_lvl1_dexp1_basic"
# Only the last assignment to modelname2 takes effect; the one directly
# below is immediately overwritten.
modelname2 = "env100_dlog_wcc2x2_stp2_fir2x15_lvl1_dexp1_basic"
modelname2 = "env100_dlog_wcc2x3_stp3_fir3x15_lvl1_dexp1_basic-shr"

# Load each model without evaluating it, then run its xforms script with
# stop=-1. The evaluation log `l` is not used afterwards.
xf1, ctx1 = xhelp.load_model_xform(cellid, batch, modelname1,
                                   eval_model=False)
ctx1, l = xforms.evaluate(xf1, ctx1, stop=-1)

xf2, ctx2 = xhelp.load_model_xform(cellid, batch, modelname2,
                                   eval_model=False)
ctx2, l = xforms.evaluate(xf2, ctx2, stop=-1)

# The recording (and therefore the response) is taken from model 1's
# context; each model contributes the first entry of its 'val' list.
rec = ctx1['rec']
val1 = ctx1['val'][0]
val2 = ctx2['val'][0]

resp = rec['resp'].rasterize()
pred1 = val1['pred']
pred2 = val2['pred']

# Names of all epochs in the response that match the stimulus pattern.
epoch_regex = "^STIM_"
stim_epochs = ep.epoch_names_matching(resp.epochs, epoch_regex)
def fit_pop_model_xforms_baphy(cellid, batch, modelname, saveInDB=False):
    """
    DEPRECATED -- always raises NotImplementedError.

    This used to fit a NEMS population model using baphy data. According
    to the original note it "now should work for
    xhelp.fit_model_xform()".

    The previous implementation sat after an unconditional raise and could
    never execute, so it has been removed. The signature is kept so that
    existing callers fail with the same exception as before.

    :param cellid: unused
    :param batch: unused
    :param modelname: unused
    :param saveInDB: unused
    :raises NotImplementedError: always
    """
    raise NotImplementedError("Replaced by xhelper function?")
#modelkeywords = 'dlog-wc.18x2.g-stp.2-fir.2x15-dexp.1' meta = {'cellids': ['TAR010c-18-1'], 'batch': 271, 'modelname': modelkeywords} xfspec = [[ 'load_recordings', { 'recording_uri_list': recordings, 'meta': meta } ], ['split_val_and_average_reps', { 'epoch_regex': '^STIM_' }], ['init_from_keywords', { 'keywordstring': modelkeywords }], ['fit_basic_init', {}], ['fit_basic', {}], ['predict', {}], ['add_summary_statistics', {}], ['plot_summary', {}]] ctx, log_xf = xforms.evaluate(xfspec) # evaluate the fit script #xforms.save_context(dest='/data/results', ctx=ctx, xfspec=xfspec, log=log_xf) """ Simplified: import nems.recording as recording import nems.xforms as xforms recording.get_demo_recordings("/data/recordings/") recordings = ["/data/recordings/TAR010c-18-1.tgz"] modelkeywords = 'dlog-wc.18x1.g-fir.1x15-lvl.1-dexp.1' #modelkeywords = 'dlog-wc.18x2.g-fir.2x15-lvl.1-dexp.1' meta = {'cellid': 'TAR010c-18-1', 'batch': 271, 'modelname': modelkeywords} xfspec = [] xfspec.append(['nems.xforms.load_recordings',
def model_per_time_wrapper(cellid, batch=307,
                           loader="psth.fs20.pup-ld-",
                           fitter="_jk.nf20-basic",
                           basemodel="-ref-psthfr_stategain.S",
                           state_list=None, plot_halves=True,
                           colors=None):
    """
    Load one saved model per entry of `state_list` and plot, in a single
    figure, the state time series (top panel, taken from the last model)
    followed by one response/prediction PSTH panel per model.

    Each modelname is built as loader + state + basemodel + fitter.

    batch = 307  # A1 SUA and MUA
    batch = 309  # IC SUA and MUA

    alternatives:
        basemodels = ["-ref-psthfr.s_stategain.S",
                      "-ref-psthfr.s_sdexp.S",
                      "-ref.a-psthfr.s_sdexp.S"]
        state_list = ['st.pup0.hlf0','st.pup0.hlf','st.pup.hlf0','st.pup.hlf']
        state_list = ['st.pup0.far0.hit0.hlf0','st.pup0.far0.hit0.hlf',
                      'st.pup.far.hit.hlf0','st.pup.far.hit.hlf']
        state_list = ['st.pup0.fil0','st.pup0.fil','st.pup.fil0','st.pup.fil']
    """
    # Default comparison: pup vs. active/passive.
    if state_list is None:
        state_list = ['st.pup0.hlf0', 'st.pup0.hlf',
                      'st.pup.hlf0', 'st.pup.hlf']

    modelnames = []
    contexts = []
    for state_kw in state_list:
        name = loader + state_kw + basemodel + fitter
        modelnames.append(name)

        spec, loaded = xhelp.load_model_xform(cellid, batch, name,
                                              eval_model=False)
        loaded, _eval_log = xforms.evaluate(spec, loaded, start=0, stop=-2)

        # Add a per-half state signal used for grouping in the PSTH plots.
        loaded['val'] = preproc.make_state_signal(
            loaded['val'], state_signals=['each_half'],
            new_signalname='state_f')
        contexts.append(loaded)

    plt.figure()

    files_only = bool(plot_halves)
    epoch = "REFERENCE"
    row_count = len(contexts) + 1
    last_idx = len(contexts) - 1

    for idx, model_ctx in enumerate(contexts):
        modelspec = model_ctx['modelspec']
        rec = ms.evaluate(model_ctx['val'].apply_mask(), modelspec)

        if idx == last_idx:
            # Top panel is drawn from the final model in the list.
            top_ax = plt.subplot(row_count, 1, 1)
            nplt.state_vars_timeseries(rec, modelspec, ax=top_ax)
            top_ax.set_title('{} {}'.format(cellid, modelnames[-1]))

        panel_ax = plt.subplot(row_count, 1, 2 + idx)
        nplt.state_vars_psth_all(rec, epoch, psth_name='resp',
                                 psth_name2='pred', state_sig='state_f',
                                 colors=colors, channel=None, decimate_by=1,
                                 ax=panel_ax, files_only=files_only,
                                 modelspec=modelspec)
        panel_ax.set_ylabel(state_list[idx])
        panel_ax.set_xticks([])