def fit_to_simulation(fit_model, simulation_spec):
    """Fit a model to a response simulated from another modelspec.

    The default context's recording has its 'resp' signal replaced by the
    prediction of ``simulation_spec``. Then ``fit_model`` (with its first two
    '-'-separated tokens removed) is fit to that recording through xforms.

    Parameters:
    -----------
    fit_model : str
        Modelname to fit to the simulation.
    simulation_spec : NEMS ModelSpec
        Modelspec to base simulation on. If it contains a 'contrast_kernel'
        module, that module's fn_kwargs['evaluate_contrast'] is set to True
        in place (the caller's modelspec is modified).

    Returns:
    --------
    ctx : dict
        Xforms context. See nems.xforms.
    """
    recording = get_default_ctx()['rec']

    contrast_idx = find_module('contrast_kernel', simulation_spec)
    if contrast_idx is not None:
        # Mutates the caller's modelspec so the simulation evaluates contrast.
        simulation_spec[contrast_idx]['fn_kwargs']['evaluate_contrast'] = True

    # Use the simulated prediction as the "true" response to be fit.
    simulated_resp = simulation_spec.evaluate(recording)['pred']
    recording['resp'] = simulated_resp

    # Drop the first two '-'-separated tokens of fit_model. The original
    # note says this replaces 'ozgf' and 'ld' with 'ldm'; confirm against
    # the keyword registry.
    name_parts = fit_model.split('-')
    trimmed_modelname = '-'.join(name_parts[2:])

    xfspec = xhelp.generate_xforms_spec(modelname=trimmed_modelname)
    fitted_ctx, _unused_log = xforms.evaluate(xfspec, context={'rec': recording})
    return fitted_ctx
def fit_pop_model_xforms_baphy(cellid, batch, modelname, saveInDB=False):
    """
    Fits a NEMS population model using baphy data

    DEPRECATED ? Now should work for xhelp.fit_model_xform()

    NOTE: this function is disabled. It raises NotImplementedError
    immediately, so every statement after the first ``raise`` is unreachable.
    That code is kept only as a reference for what the replacement must do.

    Parameters
    ----------
    cellid : str or list of str
        Cell id, or a list of cell ids (joined with '_' for display/paths).
    batch : int
        Batch number.
    modelname : str
        '_'-separated modelname: loader key, modelspec keywords, fit key.
    saveInDB : bool
        (unreachable) If True, would also write results via
        nd.update_results_table.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("Replaced by xhelper function?")

    # ---- Unreachable below this point; retained for reference only. ----
    log.info("Preparing pop model: ({0},{1},{2})".format(
            cellid, batch, modelname))

    # Segment modelname for meta information
    kws = modelname.split("_")
    modelspecname = "-".join(kws[1:-1])

    loadkey = kws[0]
    fitkey = kws[-1]
    # A list of cellids is flattened into one '_'-joined label.
    if type(cellid) is list:
        disp_cellid="_".join(cellid)
    else:
        disp_cellid=cellid

    meta = {'batch': batch, 'cellid': disp_cellid, 'modelname': modelname,
            'loader': loadkey, 'fitkey': fitkey,
            'modelspecname': modelspecname,
            'username': '******', 'labgroup': 'lbhb', 'public': 1,
            'githash': os.environ.get('CODEHASH', ''),
            'recording': loadkey}

    uri_key = nems.utils.escaped_split(loadkey, '-')[0]
    recording_uri = generate_recording_uri(cellid, batch, uri_key)

    # pass cellid information to xforms so that loader knows which cells
    # to load from recording_uri
    xfspec = xhelp.generate_xforms_spec(recording_uri, modelname, meta,
                                        xforms_kwargs={'cellid': cellid})

    # actually do the fit
    ctx, log_xf = xforms.evaluate(xfspec)

    # save some extra metadata
    modelspec = ctx['modelspec']

    destination = '/auto/data/nems_db/results/{0}/{1}/{2}/'.format(
            batch, disp_cellid, ms.get_modelspec_longname(modelspec))
    modelspec.meta['modelpath'] = destination
    modelspec.meta['figurefile'] = destination+'figure.0000.png'
    modelspec.meta.update(meta)

    # extra thing to save for pop model
    modelspec.meta['cellids'] = ctx['val']['resp'].chans

    # save results
    log.info('Saving modelspec(s) to {0} ...'.format(destination))
    save_data = xforms.save_analysis(destination,
                                     recording=ctx['rec'],
                                     modelspec=modelspec,
                                     xfspec=xfspec,
                                     figures=ctx['figures'],
                                     log=log_xf)
    savepath = save_data['savepath']

    if saveInDB:
        # save in database as well
        nd.update_results_table(modelspec)

    return savepath
def fit_model_xforms_baphy(cellid, batch, modelname,
                           autoPlot=True, saveInDB=False):
    """
    DEPRECATED ? Now should work for xhelp.fit_model_xform()

    NOTE: this function is disabled. It raises NotImplementedError
    immediately, so the second ``raise`` and everything after it are
    unreachable. That code is kept only as a reference.

    Fit a single NEMS model using data from baphy/celldb
    eg, 'ozgf100ch18_wc18x1_lvl1_fir15x1_dexp1_fit01'
    generates modelspec with 'wc18x1_lvl1_fir15x1_dexp1'

    based on this function in nems/scripts/fit_model.py

       def fit_model(recording_uri, modelstring, destination):

        xfspec = [
        ['nems.xforms.load_recordings', {'recording_uri_list': recordings}],
        ['nems.xforms.add_average_sig', {'signal_to_average': 'resp',
                                         'new_signalname': 'resp',
                                         'epoch_regex': '^STIM_'}],
        ['nems.xforms.split_by_occurrence_counts', {'epoch_regex': '^STIM_'}],
        ['nems.xforms.init_from_keywords', {'keywordstring': modelspecname}],
        ['nems.xforms.set_random_phi',  {}],
        ['nems.xforms.fit_basic',       {}],
        # ['nems.xforms.add_summary_statistics',    {}],
        ['nems.xforms.plot_summary',    {}],
        # ['nems.xforms.save_recordings', {'recordings': ['est', 'val']}],
        ['nems.xforms.fill_in_default_metadata',    {}],
       ]

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("Replaced by xhelper function?")
    # Unreachable: the raise above always fires first.
    raise DeprecationWarning("Replaced by xhelp.fit_model_xforms")

    # ---- Unreachable below this point; retained for reference only. ----
    log.info('Initializing modelspec(s) for cell/batch %s/%d...',
             cellid, int(batch))

    # Segment modelname for meta information
    kws = nems.utils.escaped_split(modelname, '_')

    old = False
    # Old-format names have >3 '_' sections, or exactly 3 with an old-style
    # 'stategain' keyword (one not written as 'stategain.').
    if (len(kws) > 3) or ((len(kws) == 3) and kws[1].startswith('stategain')
                          and not kws[1].startswith('stategain.')):
        # Check if modelname uses old format.
        log.info("Using old modelname format ... ")
        old = True
        modelspecname = nems.utils.escaped_join(kws[1:-1], '_')
    else:
        modelspecname = nems.utils.escaped_join(kws[1:-1], '-')
    loadkey = kws[0]
    fitkey = kws[-1]

    meta = {'batch': batch, 'cellid': cellid, 'modelname': modelname,
            'loader': loadkey, 'fitkey': fitkey,
            'modelspecname': modelspecname,
            'username': '******', 'labgroup': 'lbhb', 'public': 1,
            'githash': os.environ.get('CODEHASH', ''),
            'recording': loadkey}

    if old:
        # Build the xfspec by hand using the legacy (old_xforms) helpers.
        recording_uri = ogru(cellid, batch, loadkey)
        xfspec = oxfh.generate_loader_xfspec(loadkey, recording_uri)
        xfspec.append(['nems_lbhb.old_xforms.xforms.init_from_keywords',
                       {'keywordstring': modelspecname, 'meta': meta}])
        xfspec.extend(oxfh.generate_fitter_xfspec(fitkey))
        xfspec.append(['nems.analysis.api.standard_correlation', {},
                       ['est', 'val', 'modelspec', 'rec'], ['modelspec']])
        if autoPlot:
            log.info('Generating summary plot ...')
            xfspec.append(['nems.xforms.plot_summary', {}])
    else:
        # uri_key = nems.utils.escaped_split(loadkey, '-')[0]
        # recording_uri = generate_recording_uri(cellid, batch, uri_key)
        log.info("DONE? Moved handling of registry_args to xforms_init_context")
        recording_uri = None

        # registry_args = {'cellid': cellid, 'batch': int(batch)}
        registry_args = {}
        # cellid/batch are passed via the initial xforms context instead.
        xforms_init_context = {'cellid': cellid, 'batch': int(batch)}

        xfspec = xhelp.generate_xforms_spec(recording_uri, modelname, meta,
                                            xforms_kwargs=registry_args,
                                            xforms_init_context=xforms_init_context)
        log.info(xfspec)

    # actually do the loading, preprocessing, fit
    ctx, log_xf = xforms.evaluate(xfspec)

    # save some extra metadata
    modelspec = ctx['modelspec']

    # this code may not be necessary any more.
    destination = '/auto/data/nems_db/results/{0}/{1}/{2}/'.format(
            batch, cellid, ms.get_modelspec_longname(modelspec))
    modelspec.meta['modelpath'] = destination
    modelspec.meta['figurefile'] = destination+'figure.0000.png'
    modelspec.meta.update(meta)

    # save results
    log.info('Saving modelspec(s) to {0} ...'.format(destination))
    save_data = xforms.save_analysis(destination,
                                     recording=ctx['rec'],
                                     modelspec=modelspec,
                                     xfspec=xfspec,
                                     figures=ctx['figures'],
                                     log=log_xf)
    savepath = save_data['savepath']

    # save in database as well
    if saveInDB:
        # TODO : db results finalized?
        nd.update_results_table(modelspec)

    return savepath
'modelname': modelname, 'loader': loadkey, 'fitkey': fitkey, 'modelspecname': modelspecname, 'username': '******', 'labgroup': 'lbhb', 'public': 1, 'githash': os.environ.get('CODEHASH', ''), 'recording': loadkey } uri_key = escaped_split(loadkey, '-')[0] recording_uri = generate_recording_uri(cellid, batch, uri_key) registry_args = {'cellid': cellid, 'batch': int(batch)} xfspec = xhelp.generate_xforms_spec(modelname=modelname, meta=meta, xforms_kwargs=registry_args) # actually do the fit ctx = {} for i, xfa in enumerate(xfspec): ctx = xforms.evaluate_step(xfa, ctx) m = ctx['modelspec'] e = ctx['est'] v = ctx['val'] r = ctx['rec'] p = m.phi() # Plot spikes vs sim to check model behavior
"dlog.f" # Spectral filter (nems.modules.weight_channels) "-wc.18x1.g" # Temporal filter (nems.modules.fir) "-fir.1x15" # Scale, currently init to 1. "-scl.1" # Level shift, usually init to mean response (nems.modules.levelshift) "-lvl.1" # Nonlinearity (nems.modules.nonlinearity -> double_exponential) #"-dexp.1" "_" # modules -> fitters # Set initial values and do a rough "pre-fit" # Initialize fir coeffs to L2-norm of random values "-init.lnp"#.L2f" # Do the full fits "-lnp.t5" #"-nestspec" ) result_dict = {} for name, stim in stim_dict.items(): ctx = {'rec': stim} xfspec = xhelp.generate_xforms_spec(None, modelname, autoPlot=False) for i, xf in enumerate(xfspec): ctx = xforms.evaluate_step(xf, ctx) # Store tuple of ctx, error for each stim stim_length = ctx['val'][0]['stim'].shape[1] result_dict[name] = (ctx, _lnp_metric(ctx['val'][0])/stim_length)
# Command-line entry point: fit_single <modelname> <recording_uri>
# Fits `modelname` to the recording and saves the analysis in the recording
# URI's directory.
if len(sys.argv) < 3:
    print('Two parameters required.')
    print('Syntax: fit_single <modelname> <recording_uri>')
    # Use sys.exit: the bare exit() builtin is injected by the site module for
    # interactive use only and may be missing (e.g. under `python -S`).
    sys.exit(-1)

modelname = sys.argv[1]
recording_uri = sys.argv[2]

log.info("Running fit_single(%s, %s)", modelname, recording_uri)

# The recording URI doubles as the cellid label in the metadata.
meta = {'cellid': recording_uri, 'modelname': modelname,
        'githash': os.environ.get('CODEHASH', ''),
        'recording_uri': recording_uri}

# set up sequence of events for fitting
xfspec = xform_helper.generate_xforms_spec(recording_uri, modelname, meta=meta)

# actually do the fit
ctx, log_xf = xforms.evaluate(xfspec)

# save results alongside the recording
destination = os.path.dirname(recording_uri)
log.info('Saving modelspec(s) to %s ...', destination)
# NOTE(review): this uses the `modelspecs=` keyword and ctx['modelspecs'];
# newer xforms versions take `modelspec=` / ctx['modelspec'] - confirm which
# xforms version this script targets.
save_data = xforms.save_analysis(destination,
                                 recording=ctx['rec'],
                                 modelspecs=ctx['modelspecs'],
                                 xfspec=xfspec,
                                 figures=ctx['figures'],
                                 log=log_xf)
savepath = save_data['savepath']
def _equiv_fit_val_pred(modelname, context):
    """Fit `modelname` via xforms from `context`; return the flattened val prediction.

    `context` is expected to already hold 'est' and 'val' recordings, which is
    why callers pass a 'none_' loader keyword in `modelname`.
    """
    xfspec = xhelp.generate_xforms_spec(modelname=modelname)
    fitted_ctx, _ = xforms.evaluate(xfspec, context=context)
    return fitted_ctx['val']['pred'].as_continuous().flatten()


def equiv_vs_self(cellid, batch, modelname, LN_model, random_seed=1234):
    """Compare a model with itself when fit on two halves of its estimation data.

    Loads the saved fit for (cellid, batch, modelname) to get its est/val split.
    The estimation set's 'STIM_' epochs are split at random into two halves,
    seeded with `random_seed` and without touching numpy's global RNG state.
    Both `modelname` and `LN_model` are refit on each half. Their loader
    keyword is replaced by 'none_' because est/val are supplied directly. The
    validation predictions are then compared.

    Parameters
    ----------
    cellid, batch, modelname :
        Identify the saved fit to load via xhelp.load_model_xform.
    LN_model : str
        Modelname of the reference (LN) model to control for.
    random_seed : int
        Seed for the stimulus split.

    Returns
    -------
    float
        Mean of two values, partial_corr(C)[0, 1], where the columns of C are
        [test pred (half 1), test pred (half 2), LN pred (half k)] for k = 1, 2.
        This is presumably the partial correlation of the two test predictions
        controlling for the LN prediction; confirm against partial_corr.
    """
    # evaluate old fit just to get est/val already split up
    _, ctx = xhelp.load_model_xform(cellid, batch, modelname)

    # further divide est into two datasets by randomly assigning stimuli
    est = ctx['est']
    val = ctx['val']
    epochs = est['stim'].epochs
    stims = np.array(ep.epoch_names_matching(epochs, 'STIM_'))
    n_stims = len(stims)

    # A fresh RandomState(seed) gives the same draw as seeding the global RNG,
    # without mutating global state. Builtin bool is used because np.int/np.bool
    # were removed in NumPy 1.24.
    rng = np.random.RandomState(random_seed)
    set1_idx = rng.choice(n_stims, round(n_stims / 2), replace=False)
    mask = np.zeros(n_stims, dtype=bool)
    mask[set1_idx] = True
    set1_stims = stims[mask].tolist()
    set2_stims = stims[~mask].tolist()
    est1, est2 = est.split_by_epochs(set1_stims, set2_stims)

    # re-fit both models on the smaller est sets; val is shared but copied
    ctx1 = {'est': est1, 'val': val.copy()}
    ctx2 = {'est': est2, 'val': val.copy()}
    # Independent copies (taken before any fitting) so the LN fits cannot see
    # anything the test-model fits may write into ctx1/ctx2.
    LN_ctx1 = copy.deepcopy(ctx1)
    LN_ctx2 = copy.deepcopy(ctx2)

    tm = 'none_' + '_'.join(modelname.split('_')[1:])
    lm = 'none_' + '_'.join(LN_model.split('_')[1:])

    test_pred1 = _equiv_fit_val_pred(tm, ctx1)
    test_pred2 = _equiv_fit_val_pred(tm, ctx2)
    LN_pred1 = _equiv_fit_val_pred(lm, LN_ctx1)
    LN_pred2 = _equiv_fit_val_pred(lm, LN_ctx2)

    # test equivalence on the new fits (each prediction becomes one column)
    p1 = partial_corr(np.column_stack((test_pred1, test_pred2, LN_pred1)))[0, 1]
    p2 = partial_corr(np.column_stack((test_pred1, test_pred2, LN_pred2)))[0, 1]

    return 0.5 * (p1 + p2)
'nems_lbhb.old_xforms.xforms.init_from_keywords', { 'keywordstring': modelspecname, 'meta': meta } ]) xfspec.extend(oxfh.generate_fitter_xfspec(fitkey)) xfspec.append([ 'nems.analysis.api.standard_correlation', {}, ['est', 'val', 'modelspecs', 'rec'], ['modelspecs'] ]) if autoPlot: log.info('Generating summary plot ...') xfspec.append(['nems.xforms.plot_summary', {}]) else: recording_uri = nw.generate_recording_uri(cellid, batch, loadkey) xfspec = xhelp.generate_xforms_spec(recording_uri, modelname, meta) # Create a log stream set to the debug level; add it as a root log handler log_stream = io.StringIO() ch = logging.StreamHandler(log_stream) ch.setLevel(logging.DEBUG) fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s' formatter = logging.Formatter(fmt) ch.setFormatter(formatter) rootlogger = logging.getLogger() rootlogger.addHandler(ch) ctx = {} for xfa in xfspec: ctx = xforms.evaluate_step(xfa, ctx)
meta = {'batch': batch, 'cellid': cellid, 'modelname': modelname, 'loader': loadkey, 'fitkey': fitkey, 'modelspecname': modelspecname, 'username': '******', 'labgroup': 'lbhb', 'public': 1, 'githash': os.environ.get('CODEHASH', ''), 'recording': loadkey} if type(cellid) is list: meta['siteid'] = cellid[0][:7] # registry_args = {'cellid': cellid, 'batch': int(batch)} registry_args = {} xforms_init_context = {'cellid': cellid, 'batch': int(batch)} log.info("TODO: simplify generate_xforms_spec parameters") xfspec = generate_xforms_spec(recording_uri=None, modelname=modelname, meta=meta, xforms_kwargs=registry_args, xforms_init_context=xforms_init_context, autoPlot=autoPlot) log.info(xfspec) # actually do the loading, preprocessing, fit ctx, log_xf = xforms.evaluate(xfspec) # save some extra metadata modelspec = ctx['modelspec'] # this code may not be necessary any more. #destination = '{0}/{1}/{2}/{3}'.format( # get_setting('NEMS_RESULTS_DIR'), batch, cellid, modelspec.get_longname()) if type(cellid) is list: destination = os.path.join( get_setting('NEMS_RESULTS_DIR'), str(batch),