Example #1
0
def fit_to_simulation(fit_model, simulation_spec):
    '''
    Fit a model to a response simulated by another modelspec.

    The recording from the default context is run through ``simulation_spec``
    and its ``pred`` signal is installed as the recording's ``resp``. The
    model named by ``fit_model`` (minus its first two '-'-separated tokens)
    is then fitted to that simulated response via xforms.

    Parameters:
    -----------
    fit_model : str
        Modelname to fit to the simulation. The first two '-'-separated
        tokens are dropped before building the xforms spec.
    simulation_spec : NEMS ModelSpec
        Modelspec to base simulation on. NOTE: if it contains a
        'contrast_kernel' module, that module's ``fn_kwargs`` are modified
        in place (``evaluate_contrast`` is set to True).

    Returns:
    --------
    ctx : dict
        Xforms context. See nems.xforms.

    '''
    recording = get_default_ctx()['rec']

    # Enable contrast evaluation on the contrast kernel, if there is one.
    kernel_index = find_module('contrast_kernel', simulation_spec)
    if kernel_index is not None:
        kernel_kwargs = simulation_spec[kernel_index]['fn_kwargs']
        kernel_kwargs['evaluate_contrast'] = True

    # Swap the recorded response for the simulation's prediction.
    simulated = simulation_spec.evaluate(recording)
    recording['resp'] = simulated['pred']

    # Strip the leading two '-'-separated tokens from the modelname
    # (presumably the loader keywords) and fit the remainder.
    trimmed_name = '-'.join(fit_model.split('-')[2:])
    xfspec = xhelp.generate_xforms_spec(modelname=trimmed_name)
    result_ctx, _unused_log = xforms.evaluate(xfspec,
                                              context={'rec': recording})
    return result_ctx
Example #2
0
def fit_pop_model_xforms_baphy(cellid, batch, modelname, saveInDB=False):
    """
    Fits a NEMS population model using baphy data.

    DEPRECATED ? Now should work for xhelp.fit_model_xform()

    This function currently always raises NotImplementedError. Everything
    after the ``raise`` statement is unreachable and is kept for reference
    only.

    Parameters
    ----------
    cellid : str or list of str
        Cell id, or a list of cell ids for a population fit. A list is
        joined with '_' to form the id used in metadata and the results path.
    batch : int or str
        Batch number; used in metadata and in the results path.
    modelname : str
        '_'-separated modelname: loader keyword first, fit keyword last,
        modelspec keywords in between.
    saveInDB : bool
        If True, the (unreachable) body would also record the results via
        ``nd.update_results_table``.

    Returns
    -------
    str
        Path the results were saved to (never reached; see above).

    Raises
    ------
    NotImplementedError
        Always.
    """

    raise NotImplementedError("Replaced by xhelper function?")
    # ---- Unreachable below this line; kept for reference. ----
    log.info("Preparing pop model: ({0},{1},{2})".format(
            cellid, batch, modelname))

    # Segment modelname for meta information
    # (first token = loader, last token = fitter, middle = modelspec)
    kws = modelname.split("_")
    modelspecname = "-".join(kws[1:-1])

    loadkey = kws[0]
    fitkey = kws[-1]
    # A list of cellids (population fit) is collapsed into one display id.
    if type(cellid) is list:
        disp_cellid="_".join(cellid)
    else:
        disp_cellid=cellid

    meta = {'batch': batch, 'cellid': disp_cellid, 'modelname': modelname,
            'loader': loadkey, 'fitkey': fitkey,
            'modelspecname': modelspecname,
            'username': '******', 'labgroup': 'lbhb', 'public': 1,
            'githash': os.environ.get('CODEHASH', ''),
            'recording': loadkey}

    # Only the first '-'-separated part of the loader keyword selects the
    # recording URI.
    uri_key = nems.utils.escaped_split(loadkey, '-')[0]
    recording_uri = generate_recording_uri(cellid, batch, uri_key)

    # pass cellid information to xforms so that loader knows which cells
    # to load from recording_uri
    xfspec = xhelp.generate_xforms_spec(recording_uri, modelname, meta,
                                        xforms_kwargs={'cellid': cellid})

    # actually do the fit
    ctx, log_xf = xforms.evaluate(xfspec)

    # save some extra metadata
    modelspec = ctx['modelspec']

    destination = '/auto/data/nems_db/results/{0}/{1}/{2}/'.format(
            batch, disp_cellid, ms.get_modelspec_longname(modelspec))
    modelspec.meta['modelpath'] = destination
    modelspec.meta['figurefile'] = destination+'figure.0000.png'
    modelspec.meta.update(meta)

    # extra thing to save for pop model
    # (the channel names of the validation response signal)
    modelspec.meta['cellids'] = ctx['val']['resp'].chans

    # save results
    log.info('Saving modelspec(s) to {0} ...'.format(destination))
    save_data = xforms.save_analysis(destination,
                                     recording=ctx['rec'],
                                     modelspec=modelspec,
                                     xfspec=xfspec,
                                     figures=ctx['figures'],
                                     log=log_xf)
    savepath = save_data['savepath']

    if saveInDB:
        # save in database as well
        nd.update_results_table(modelspec)

    return savepath
Example #3
0
def fit_model_xforms_baphy(cellid, batch, modelname,
                           autoPlot=True, saveInDB=False):
    """
    DEPRECATED ? Now should work for xhelp.fit_model_xform()

    This function currently always raises NotImplementedError. The second
    ``raise`` (DeprecationWarning) and everything after it are unreachable
    and are kept for reference only.

    Fit a single NEMS model using data from baphy/celldb
    eg, 'ozgf100ch18_wc18x1_lvl1_fir15x1_dexp1_fit01'
    generates modelspec with 'wc18x1_lvl1_fir15x1_dexp1'

    Two modelname formats are handled by the (unreachable) body:
    - "old" format (more than three '_'-separated tokens, or three tokens
      whose middle token starts with 'stategain' but not 'stategain.'):
      modelspec keywords are joined with '_' and the spec is built with
      the old-xforms helpers;
    - otherwise: modelspec keywords are joined with '-' and the spec is
      built by ``xhelp.generate_xforms_spec`` with cellid/batch passed via
      ``xforms_init_context``.

    based on this function in nems/scripts/fit_model.py
       def fit_model(recording_uri, modelstring, destination):

    xfspec = [
        ['nems.xforms.load_recordings', {'recording_uri_list': recordings}],
        ['nems.xforms.add_average_sig', {'signal_to_average': 'resp',
                                         'new_signalname': 'resp',
                                         'epoch_regex': '^STIM_'}],
        ['nems.xforms.split_by_occurrence_counts', {'epoch_regex': '^STIM_'}],
        ['nems.xforms.init_from_keywords', {'keywordstring': modelspecname}],
        ['nems.xforms.set_random_phi',  {}],
        ['nems.xforms.fit_basic',       {}],
        # ['nems.xforms.add_summary_statistics',    {}],
        ['nems.xforms.plot_summary',    {}],
        # ['nems.xforms.save_recordings', {'recordings': ['est', 'val']}],
        ['nems.xforms.fill_in_default_metadata',    {}],
    ]

    Parameters
    ----------
    cellid : str
        Cell id.
    batch : int or str
        Batch number (converted with ``int()`` where needed).
    modelname : str
        '_'-separated modelname: loader keyword, modelspec keywords, fit
        keyword.
    autoPlot : bool
        Old-format only: append a ``plot_summary`` step.
    saveInDB : bool
        Also record results via ``nd.update_results_table``.

    Returns
    -------
    str
        Path the results were saved to (never reached; see above).

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError("Replaced by xhelper function?")
    # ---- Unreachable below this line; kept for reference. ----
    raise DeprecationWarning("Replaced by xhelp.fit_model_xforms")
    log.info('Initializing modelspec(s) for cell/batch %s/%d...',
             cellid, int(batch))

    # Segment modelname for meta information
    kws = nems.utils.escaped_split(modelname, '_')

    old = False
    if (len(kws) > 3) or ((len(kws) == 3) and kws[1].startswith('stategain')
                          and not kws[1].startswith('stategain.')):
        # Check if modelname uses old format.
        log.info("Using old modelname format ... ")
        old = True
        modelspecname = nems.utils.escaped_join(kws[1:-1], '_')
    else:
        modelspecname = nems.utils.escaped_join(kws[1:-1], '-')
    loadkey = kws[0]
    fitkey = kws[-1]

    meta = {'batch': batch, 'cellid': cellid, 'modelname': modelname,
            'loader': loadkey, 'fitkey': fitkey, 'modelspecname': modelspecname,
            'username': '******', 'labgroup': 'lbhb', 'public': 1,
            'githash': os.environ.get('CODEHASH', ''),
            'recording': loadkey}

    if old:
        # Old format: assemble loader, init, fitter and correlation steps by
        # hand with the old-xforms helpers.
        recording_uri = ogru(cellid, batch, loadkey)
        xfspec = oxfh.generate_loader_xfspec(loadkey, recording_uri)
        xfspec.append(['nems_lbhb.old_xforms.xforms.init_from_keywords',
                       {'keywordstring': modelspecname, 'meta': meta}])
        xfspec.extend(oxfh.generate_fitter_xfspec(fitkey))
        xfspec.append(['nems.analysis.api.standard_correlation', {},
                       ['est', 'val', 'modelspec', 'rec'], ['modelspec']])
        if autoPlot:
            log.info('Generating summary plot ...')
            xfspec.append(['nems.xforms.plot_summary', {}])
    else:
#        uri_key = nems.utils.escaped_split(loadkey, '-')[0]
#        recording_uri = generate_recording_uri(cellid, batch, uri_key)
        log.info("DONE? Moved handling of registry_args to xforms_init_context")
        # New format: no URI here; cellid/batch reach the loader through
        # xforms_init_context instead of registry kwargs.
        recording_uri = None

        # registry_args = {'cellid': cellid, 'batch': int(batch)}
        registry_args = {}
        xforms_init_context = {'cellid': cellid, 'batch': int(batch)}

        xfspec = xhelp.generate_xforms_spec(recording_uri, modelname, meta,
                                            xforms_kwargs=registry_args,
                                            xforms_init_context=xforms_init_context)
        log.info(xfspec)

    # actually do the loading, preprocessing, fit
    ctx, log_xf = xforms.evaluate(xfspec)

    # save some extra metadata
    modelspec = ctx['modelspec']

    # this code may not be necessary any more.
    destination = '/auto/data/nems_db/results/{0}/{1}/{2}/'.format(
            batch, cellid, ms.get_modelspec_longname(modelspec))
    modelspec.meta['modelpath'] = destination
    modelspec.meta['figurefile'] = destination+'figure.0000.png'
    modelspec.meta.update(meta)

    # save results
    log.info('Saving modelspec(s) to {0} ...'.format(destination))
    save_data = xforms.save_analysis(destination,
                                     recording=ctx['rec'],
                                     modelspec=modelspec,
                                     xfspec=xfspec,
                                     figures=ctx['figures'],
                                     log=log_xf)
    savepath = save_data['savepath']

    # save in database as well
    if saveInDB:
        # TODO : db results finalized?
        nd.update_results_table(modelspec)

    return savepath
Example #4
0
    'modelname': modelname,
    'loader': loadkey,
    'fitkey': fitkey,
    'modelspecname': modelspecname,
    'username': '******',
    'labgroup': 'lbhb',
    'public': 1,
    'githash': os.environ.get('CODEHASH', ''),
    'recording': loadkey
}

# Resolve the recording URI from the first '-'-separated part of the loader
# keyword.
# NOTE(review): recording_uri is computed but never passed to
# generate_xforms_spec below (cellid/batch go in via registry_args instead).
# The call is kept in case generate_recording_uri has side effects — confirm
# whether it is still needed.
uri_key = escaped_split(loadkey, '-')[0]
recording_uri = generate_recording_uri(cellid, batch, uri_key)
registry_args = {'cellid': cellid, 'batch': int(batch)}
xfspec = xhelp.generate_xforms_spec(modelname=modelname,
                                    meta=meta,
                                    xforms_kwargs=registry_args)

# actually do the fit
# (steps are evaluated one at a time, threading the context through each;
# the index i is not used)
ctx = {}
for i, xfa in enumerate(xfspec):
    ctx = xforms.evaluate_step(xfa, ctx)

# Short aliases for the fitted modelspec and the est/val/rec recordings.
m = ctx['modelspec']
e = ctx['est']
v = ctx['val']
r = ctx['rec']

# Model parameters of the fitted modelspec (as returned by m.phi()).
p = m.phi()

# Plot spikes vs sim to check model behavior
Example #5
0
        "dlog.f"
        # Spectral filter (nems.modules.weight_channels)
        "-wc.18x1.g"
        # Temporal filter (nems.modules.fir)
        "-fir.1x15"
        # Scale, currently init to 1.
        "-scl.1"
        # Level shift, usually init to mean response (nems.modules.levelshift)
        "-lvl.1"
        # Nonlinearity (nems.modules.nonlinearity -> double_exponential)
        #"-dexp.1"
        "_"  # modules -> fitters
        # Set initial values and do a rough "pre-fit"
        # Initialize fir coeffs to L2-norm of random values
        "-init.lnp"#.L2f"
        # Do the full fits
        "-lnp.t5"
        #"-nestspec"
        )

# Fit the same modelname separately to each stimulus recording and keep the
# resulting context together with a length-normalized _lnp_metric score.
result_dict = {}
for name, stim in stim_dict.items():
    # Seed the context with this stimulus' recording; no loader URI is given.
    ctx = {'rec': stim}
    xfspec = xhelp.generate_xforms_spec(None, modelname, autoPlot=False)
    for i, xf in enumerate(xfspec):
        ctx = xforms.evaluate_step(xf, ctx)

    # Store tuple of ctx, error for each stim
    # NOTE(review): ctx['val'][0] implies 'val' is a list of recordings here
    # (older xforms API); shape[1] of 'stim' is presumably the number of
    # time bins — confirm.
    stim_length = ctx['val'][0]['stim'].shape[1]
    result_dict[name] = (ctx, _lnp_metric(ctx['val'][0])/stim_length)
Example #6
0
    # Command-line entry: fit_single <modelname> <recording_uri>
    if len(sys.argv) < 3:
        print('Two parameters required.')
        print('Syntax: fit_single <modelname> <recording_uri>')
        # Use sys.exit: the bare `exit` helper is injected by the site
        # module for the interactive shell and is undefined under `python -S`.
        sys.exit(-1)

    modelname = sys.argv[1]
    recording_uri = sys.argv[2]

    log.info("Running fit_single(%s, %s)", modelname, recording_uri)

    # The recording URI doubles as the cell id in the saved metadata.
    meta = {'cellid': recording_uri, 'modelname': modelname,
            'githash': os.environ.get('CODEHASH', ''),
            'recording_uri': recording_uri}

    # set up sequence of events for fitting
    xfspec = xform_helper.generate_xforms_spec(recording_uri, modelname,
                                               meta=meta)

    # actually do the fit
    ctx, log_xf = xforms.evaluate(xfspec)

    # save results in the directory that contains the recording
    destination = os.path.dirname(recording_uri)
    log.info('Saving modelspec(s) to %s ...', destination)
    # NOTE(review): other callers in this file pass modelspec=ctx['modelspec'];
    # the 'modelspecs' key/kwarg matches an older xforms API — confirm
    # against the installed version.
    save_data = xforms.save_analysis(destination,
                                     recording=ctx['rec'],
                                     modelspecs=ctx['modelspecs'],
                                     xfspec=xfspec,
                                     figures=ctx['figures'],
                                     log=log_xf)
    savepath = save_data['savepath']
Example #7
0
def equiv_vs_self(cellid, batch, modelname, LN_model, random_seed=1234):
    """
    Self-equivalence score of a model, controlling for an LN model.

    The saved fit of ``modelname`` is loaded only to recover its est/val
    split. The est stimuli (epoch names matching 'STIM_') are divided at
    random into two halves. Both the test model and ``LN_model`` are refit on
    each half, with the loader keyword replaced by 'none' so the supplied
    est/val are used directly. Their validation predictions are then
    compared.

    Parameters
    ----------
    cellid, batch :
        Identify the saved fit (passed to ``xhelp.load_model_xform``).
    modelname : str
        Test model; '_'-separated with the loader keyword first.
    LN_model : str
        LN reference model, same format as ``modelname``.
    random_seed : int
        Seed for the random split of est stimuli. A local RandomState is
        used, so the global NumPy random state is left untouched.

    Returns
    -------
    float
        Mean over the two halves of ``partial_corr(C)[0, 1]``, where C's
        columns are (test pred half 1, test pred half 2, LN pred half k).
    """
    # evaluate old fit just to get est/val already split up
    _xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    est = ctx['est']
    val = ctx['val']

    # further divide est into two datasets by randomly assigning stimuli
    epochs = est['stim'].epochs
    stims = np.array(ep.epoch_names_matching(epochs, 'STIM_'))
    set1_stims, set2_stims = _equiv_split_stims(stims, random_seed)
    est1, est2 = est.split_by_epochs(set1_stims, set2_stims)

    # re-fit on the smaller est sets. The LN fits get their own deep copies,
    # taken before any fitting, so every fit starts from a pristine context
    # even if xforms.evaluate mutates the context it is given.
    ctx1 = {'est': est1, 'val': val.copy()}
    ctx2 = {'est': est2, 'val': val.copy()}
    LN_ctx1 = copy.deepcopy(ctx1)
    LN_ctx2 = copy.deepcopy(ctx2)

    # replace the loader keyword with 'none': est/val are supplied directly
    tm = 'none_' + '_'.join(modelname.split('_')[1:])
    lm = 'none_' + '_'.join(LN_model.split('_')[1:])

    test_pred1 = _equiv_fit_val_pred(tm, ctx1)
    test_pred2 = _equiv_fit_val_pred(tm, ctx2)
    LN_pred1 = _equiv_fit_val_pred(lm, LN_ctx1)
    LN_pred2 = _equiv_fit_val_pred(lm, LN_ctx2)

    # test equivalence on the new fits; each column is one prediction
    p1 = partial_corr(np.column_stack((test_pred1, test_pred2, LN_pred1)))[0, 1]
    p2 = partial_corr(np.column_stack((test_pred1, test_pred2, LN_pred2)))[0, 1]

    return 0.5 * (p1 + p2)


def _equiv_split_stims(stims, random_seed):
    """Randomly split array ``stims`` into two name lists (first has round(n/2))."""
    n_stims = len(stims)
    rng = np.random.RandomState(random_seed)
    set1_idx = rng.choice(np.arange(n_stims), round(n_stims / 2),
                          replace=False)
    mask = np.zeros(n_stims, dtype=bool)
    mask[set1_idx] = True
    return stims[mask].tolist(), stims[~mask].tolist()


def _equiv_fit_val_pred(modelname, context):
    """Fit ``modelname`` from ``context``; return the flattened val prediction."""
    xfspec = xhelp.generate_xforms_spec(modelname=modelname)
    fit_ctx, _ = xforms.evaluate(xfspec, context=context)
    return fit_ctx['val']['pred'].as_continuous().flatten()
Example #8
0
        'nems_lbhb.old_xforms.xforms.init_from_keywords', {
            'keywordstring': modelspecname,
            'meta': meta
        }
    ])
    xfspec.extend(oxfh.generate_fitter_xfspec(fitkey))
    xfspec.append([
        'nems.analysis.api.standard_correlation', {},
        ['est', 'val', 'modelspecs', 'rec'], ['modelspecs']
    ])
    if autoPlot:
        log.info('Generating summary plot ...')
        xfspec.append(['nems.xforms.plot_summary', {}])
else:
    recording_uri = nw.generate_recording_uri(cellid, batch, loadkey)
    xfspec = xhelp.generate_xforms_spec(recording_uri, modelname, meta)

# Create a log stream set to the debug level; add it as a root log handler
# so that log output is also captured in memory (read it back with
# log_stream.getvalue()).
# NOTE(review): the handler accepts DEBUG, but records are still filtered by
# the loggers' own levels (the root logger defaults to WARNING unless it is
# configured elsewhere) — confirm DEBUG/INFO messages actually arrive.
log_stream = io.StringIO()
ch = logging.StreamHandler(log_stream)
ch.setLevel(logging.DEBUG)
fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
formatter = logging.Formatter(fmt)
ch.setFormatter(formatter)
rootlogger = logging.getLogger()
rootlogger.addHandler(ch)

# Evaluate the xforms spec one step at a time, threading the context through.
ctx = {}
for xfa in xfspec:
    ctx = xforms.evaluate_step(xfa, ctx)
Example #9
0
# Metadata stored with the fitted modelspec.
meta = {'batch': batch, 'cellid': cellid, 'modelname': modelname,
        'loader': loadkey, 'fitkey': fitkey, 'modelspecname': modelspecname,
        'username': '******', 'labgroup': 'lbhb', 'public': 1,
        'githash': os.environ.get('CODEHASH', ''),
        'recording': loadkey}

# For a list of cellids, record a site id taken from the first 7 characters
# of the first cellid.
if type(cellid) is list:
    meta['siteid'] = cellid[0][:7]

# cellid/batch reach the loader through xforms_init_context rather than as
# registry keyword args.
# registry_args = {'cellid': cellid, 'batch': int(batch)}
registry_args = {}
xforms_init_context = {'cellid': cellid, 'batch': int(batch)}

log.info("TODO: simplify generate_xforms_spec parameters")
xfspec = generate_xforms_spec(recording_uri=None, modelname=modelname,
                              meta=meta,  xforms_kwargs=registry_args,
                              xforms_init_context=xforms_init_context,
                              autoPlot=autoPlot)
log.info(xfspec)

# actually do the loading, preprocessing, fit
ctx, log_xf = xforms.evaluate(xfspec)

# save some extra metadata
modelspec = ctx['modelspec']

# this code may not be necessary any more.
#destination = '{0}/{1}/{2}/{3}'.format(
#    get_setting('NEMS_RESULTS_DIR'), batch, cellid, modelspec.get_longname())
if type(cellid) is list:
    destination = os.path.join(
        get_setting('NEMS_RESULTS_DIR'), str(batch),