Example #1
def fit_to_simulation(fit_model, simulation_spec):
    '''
    Parameters:
    -----------
    fit_model : str
        Modelname to fit to the simulation.
    simulation_spec : NEMS ModelSpec
        Modelspec to base simulation on.

    Returns:
    --------
    ctx : dict
        Xforms context. See nems.xforms.

    '''
    rec = get_default_ctx()['rec']
    ctk_idx = find_module('contrast_kernel', simulation_spec)
    if ctk_idx is not None:
        simulation_spec[ctk_idx]['fn_kwargs']['evaluate_contrast'] = True
    new_resp = simulation_spec.evaluate(rec)['pred']
    rec['resp'] = new_resp

    # replace ozgf and ld with ldm
    modelname = '-'.join(fit_model.split('-')[2:])
    xfspec = xhelp.generate_xforms_spec(modelname=modelname)
    ctx, _ = xforms.evaluate(xfspec, context={'rec': rec})

    return ctx
Example #2
File: fit_models.py (Project: LBHB/NEMS)
def fit_model(recording_uri, modelstring, destination):
    '''
    Fit a single model and save it to nems_db.
    '''
    recordings = [recording_uri]

    xfspec = [
        ['nems.xforms.load_recordings', {'recording_uri_list': recordings}],
        ['nems.xforms.add_average_sig', {'signal_to_average': 'resp',
                                         'new_signalname': 'resp',
                                         'epoch_regex': '^STIM_'}],
        ['nems.xforms.split_by_occurrence_counts', {'epoch_regex': '^STIM_'}],
        ['nems.xforms.init_from_keywords', {'keywordstring': modelstring}],
        #['nems.xforms.set_random_phi',  {}],
        ['nems.xforms.fit_basic',  {}],
        # ['nems.xforms.add_summary_statistics',    {}],
        ['nems.xforms.plot_summary',    {}]
    ]

    ctx, log = xforms.evaluate(xfspec)

    xforms.save_analysis(destination,
                         recording=ctx['rec'],
                         modelspecs=ctx['modelspecs'],
                         xfspec=xfspec,
                         figures=ctx['figures'],
                         log=log)
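
A minimal usage sketch for fit_model above. The recording path is the demo recording referenced in Example #17; the keyword string and destination directory are placeholders.

# Hedged sketch: placeholder inputs, adjust to a real recording and results directory.
recording_uri = "/data/recordings/TAR010c-18-1.tgz"   # demo recording (see Example #17)
modelstring = "wc18x1_lvl1_fir15x1_dexp1"             # keyword string for init_from_keywords
destination = "/data/results/TAR010c-18-1/"           # placeholder save location
fit_model(recording_uri, modelstring, destination)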
Example #3
def get_model_preds(cellid, batch, modelname):
    xf, ctx = xhelp.load_model_xform(cellid,
                                     batch,
                                     modelname,
                                     eval_model=False)
    ctx, l = xforms.evaluate(xf, ctx, stop=-1)
    #ctx, l = oxf.evaluate(xf, ctx, stop=-1)

    return xf, ctx
Example #4
def load_model_baphy_xform(
        cellid="chn020f-b1",
        batch=271,
        modelname="ozgf100ch18_wc18x1_fir15x1_lvl1_dexp1_fit01",
        eval=True):

    logging.info('Loading modelspecs...')
    d = nd.get_results_file(batch, [modelname], [cellid])
    savepath = d['modelpath'][0]

    xfspec = xforms.load_xform(savepath + 'xfspec.json')
    mspath = savepath + 'modelspec.0000.json'
    context = xforms.load_modelspecs([], uris=[mspath], IsReload=False)

    context['IsReload'] = True
    ctx, log_xf = xforms.evaluate(xfspec, context)

    return ctx
Example #5
def reload_model(model_uri):
    '''
    Reloads an xfspec and modelspec that were saved to some directory.
    This recreates the context that occurred during the fit.
    Passes additional context {'IsReload': True}, which xforms should react to
    if they are not intended to be run on a reload.
    '''
    xfspec_uri = model_uri + 'xfspec.json'

    # TODO: instead of just reading the first modelspec, read ALL of the modelspecs
    # I'm not sure how to know how many there are without a directory listing!
    modelspec_uri = model_uri + 'modelspec.0000.json'

    xfspec = load_resource(xfspec_uri)
    modelspec = load_resource(modelspec_uri)

    ctx, reloadlog = xforms.evaluate(xfspec, {
        'IsReload': True,
        'modelspecs': [modelspec]
    })

    return ctx
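
A short sketch of calling reload_model above. The saved-fit directory is hypothetical and must end with a separator, since the helper builds the JSON URIs by plain string concatenation.

model_uri = '/auto/data/nems_db/results/271/TAR010c-18-1/some_model_dir/'  # hypothetical path
ctx = reload_model(model_uri)
modelspec = ctx['modelspecs'][0]   # the single modelspec read by the helper above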
Example #6
File: test_xforms.py (Project: LBHB/NEMS)
def context():
    modelspecname = 'fir.1x15-lvl.1'
    load_command = 'nems.demo.loaders.dummy_loader'
    meta = {'cellid': "DUMMY01", 'batch': 0, 'modelname': modelspecname,
            'exptid': "DUMMY"}

    xfspec = []
    xfspec.append(['nems.xforms.load_recording_wrapper',
                   {'load_command': load_command, 'exptid': meta['exptid'],
                    'save_cache': False}])
    xfspec.append(['nems.xforms.split_at_time',
                   {'valfrac': 0.2}])
    xfspec.append(['nems.xforms.init_from_keywords',
                   {'keywordstring': modelspecname, 'meta': meta}])
    xfspec.append(['nems.xforms.fit_basic', {}])
    xfspec.append(['nems.xforms.predict', {}])
    xfspec.append(['nems.xforms.add_summary_statistics', {}])

    ctx = {}
    ctx, xf_log = xforms.evaluate(xfspec, ctx)

    return ctx
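
For reference, a sketch of inspecting the returned context, assuming the helper is callable as a plain function (any pytest fixture decorator is not shown above). The keys are inferred from the xfspec steps and may differ between NEMS versions.

ctx = context()
est, val = ctx['est'], ctx['val']                       # produced by split_at_time
fitted = ctx.get('modelspec', ctx.get('modelspecs'))    # fitted model(s) left by fit_basic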
Example #7
def fit_model_xforms_baphy(cellid, batch, modelname,
                           autoPlot=True, saveInDB=False):
    """
    DEPRECATED? This should now be handled by xhelp.fit_model_xform().

    Fit a single NEMS model using data from baphy/celldb
    eg, 'ozgf100ch18_wc18x1_lvl1_fir15x1_dexp1_fit01'
    generates modelspec with 'wc18x1_lvl1_fir15x1_dexp1'

    based on this function in nems/scripts/fit_model.py
       def fit_model(recording_uri, modelstring, destination):

     xfspec = [
        ['nems.xforms.load_recordings', {'recording_uri_list': recordings}],
        ['nems.xforms.add_average_sig', {'signal_to_average': 'resp',
                                         'new_signalname': 'resp',
                                         'epoch_regex': '^STIM_'}],
        ['nems.xforms.split_by_occurrence_counts', {'epoch_regex': '^STIM_'}],
        ['nems.xforms.init_from_keywords', {'keywordstring': modelspecname}],
        ['nems.xforms.set_random_phi',  {}],
        ['nems.xforms.fit_basic',       {}],
        # ['nems.xforms.add_summary_statistics',    {}],
        ['nems.xforms.plot_summary',    {}],
        # ['nems.xforms.save_recordings', {'recordings': ['est', 'val']}],
        ['nems.xforms.fill_in_default_metadata',    {}],
    ]

    """
    raise NotImplementedError("Replaced by xhelper function?")
    raise DeprecationWarning("Replaced by xhelp.fit_model_xforms")
    log.info('Initializing modelspec(s) for cell/batch %s/%d...',
             cellid, int(batch))

    # Segment modelname for meta information
    kws = nems.utils.escaped_split(modelname, '_')

    old = False
    if (len(kws) > 3) or ((len(kws) == 3) and kws[1].startswith('stategain')
                          and not kws[1].startswith('stategain.')):
        # Check if modelname uses old format.
        log.info("Using old modelname format ... ")
        old = True
        modelspecname = nems.utils.escaped_join(kws[1:-1], '_')
    else:
        modelspecname = nems.utils.escaped_join(kws[1:-1], '-')
    loadkey = kws[0]
    fitkey = kws[-1]

    meta = {'batch': batch, 'cellid': cellid, 'modelname': modelname,
            'loader': loadkey, 'fitkey': fitkey, 'modelspecname': modelspecname,
            'username': '******', 'labgroup': 'lbhb', 'public': 1,
            'githash': os.environ.get('CODEHASH', ''),
            'recording': loadkey}

    if old:
        recording_uri = ogru(cellid, batch, loadkey)
        xfspec = oxfh.generate_loader_xfspec(loadkey, recording_uri)
        xfspec.append(['nems_lbhb.old_xforms.xforms.init_from_keywords',
                       {'keywordstring': modelspecname, 'meta': meta}])
        xfspec.extend(oxfh.generate_fitter_xfspec(fitkey))
        xfspec.append(['nems.analysis.api.standard_correlation', {},
                       ['est', 'val', 'modelspec', 'rec'], ['modelspec']])
        if autoPlot:
            log.info('Generating summary plot ...')
            xfspec.append(['nems.xforms.plot_summary', {}])
    else:
#        uri_key = nems.utils.escaped_split(loadkey, '-')[0]
#        recording_uri = generate_recording_uri(cellid, batch, uri_key)
        log.info("DONE? Moved handling of registry_args to xforms_init_context")
        recording_uri = None

        # registry_args = {'cellid': cellid, 'batch': int(batch)}
        registry_args = {}
        xforms_init_context = {'cellid': cellid, 'batch': int(batch)}

        xfspec = xhelp.generate_xforms_spec(recording_uri, modelname, meta,
                                            xforms_kwargs=registry_args,
                                            xforms_init_context=xforms_init_context)
        log.info(xfspec)

    # actually do the loading, preprocessing, fit
    ctx, log_xf = xforms.evaluate(xfspec)

    # save some extra metadata
    modelspec = ctx['modelspec']

    # this code may not be necessary any more.
    destination = '/auto/data/nems_db/results/{0}/{1}/{2}/'.format(
            batch, cellid, ms.get_modelspec_longname(modelspec))
    modelspec.meta['modelpath'] = destination
    modelspec.meta['figurefile'] = destination+'figure.0000.png'
    modelspec.meta.update(meta)

    # save results
    log.info('Saving modelspec(s) to {0} ...'.format(destination))
    save_data = xforms.save_analysis(destination,
                                     recording=ctx['rec'],
                                     modelspec=modelspec,
                                     xfspec=xfspec,
                                     figures=ctx['figures'],
                                     log=log_xf)
    savepath = save_data['savepath']

    # save in database as well
    if saveInDB:
        # TODO : db results finalized?
        nd.update_results_table(modelspec)

    return savepath
Example #8
        'epoch_regex': '^STIM_'
    }])
# xfspec.append(['nems.xforms.average_away_stim_occurrences', {}])
xfspec.append([
    'nems.xforms.init_from_keywords', {
        'keywordstring': modelspec_name,
        'meta': meta
    }
])
xfspec.append(['nems.xforms.fit_basic_init', {}])
xfspec.append(['nems.xforms.fit_basic', {}])
xfspec.append(['nems.xforms.predict', {}])
xfspec.append(['nems.xforms.add_summary_statistics', {}])
xfspec.append(['nems.xforms.plot_summary', {}])

ctx, log_xf = xforms.evaluate(xfspec)
modelspecs = ctx['modelspecs']
destination = '/auto/data/nems_db/results/{0}/{1}/{2}/'.format(
    batch, cellid, ms.get_modelspec_longname(modelspecs[0]))
modelspecs[0][0]['meta']['modelpath'] = destination
modelspecs[0][0]['meta']['figurefile'] = destination + 'figure.0000.png'
modelspecs[0][0]['meta'].update(meta)

save_data = xforms.save_analysis(destination,
                                 recording=ctx['rec'],
                                 modelspecs=modelspecs,
                                 xfspec=xfspec,
                                 figures=ctx['figures'],
                                 log=log_xf)

savepath = save_data['savepath']
Example #9
def fit_model_xforms(recording_uri,
                     modelname,
                     fitter_kwargs=None,
                     autoPlot=True):
    """
    Fits a single NEMS model
    eg, 'ozgf100ch18_wc18x1_lvl1_fir15x1_dexp1_fit01'
    generates modelspec with 'wc18x1_lvl1_fir15x1_dexp1'

    based on fit_model function in nems/scripts/fit_model.py

    example xfspec:
     xfspec = [
        ['nems.xforms.load_recordings', {'recording_uri_list': recordings}],
        ['nems.xforms.add_average_sig', {'signal_to_average': 'resp',
                                         'new_signalname': 'resp',
                                         'epoch_regex': '^STIM_'}],
        ['nems.xforms.split_by_occurrence_counts', {'epoch_regex': '^STIM_'}],
        ['nems.xforms.init_from_keywords', {'keywordstring': modelspecname}],
        ['nems.xforms.set_random_phi',  {}],
        ['nems.xforms.fit_basic',       {}],
        # ['nems.xforms.add_summary_statistics',    {}],
        ['nems.xforms.plot_summary',    {}],
        # ['nems.xforms.save_recordings', {'recordings': ['est', 'val']}],
        ['nems.xforms.fill_in_default_metadata',    {}],
     ]
    """

    log.info('Initializing modelspec(s) for recording/model {0}/{1}...'.format(
        recording_uri, modelname))

    # parse modelname
    kws = modelname.split("_")
    loader = kws[0]
    modelspecname = "_".join(kws[1:-1])
    fitter = kws[-1]

    meta = {
        'modelname': modelname,
        'loader': loader,
        'fitter': fitter,
        'modelspecname': modelspecname
    }

    # TODO: These should be added to meta by nems_db after ctx is returned.
    #       'username': '******', 'labgroup': 'lbhb', 'public': 1,
    #       'githash': os.environ.get('CODEHASH', ''),
    #       'recording': loader}

    # Generate the xfspec, which defines the sequence of events
    # to run through (like a packaged-up script)

    # 1) Load the data
    xfspec = generate_loader_xfspec(loader, recording_uri)

    # 2) generate a modelspec
    xfspec.append([
        'nems.xforms.init_from_keywords', {
            'keywordstring': modelspecname,
            'meta': meta
        }
    ])

    # 3) fit the data
    xfspec += generate_fitter_xfspec(fitter, fitter_kwargs)

    # 4) add some performance statistics
    xfspec.append([
        'nems.analysis.api.standard_correlation', {},
        ['est', 'val', 'modelspecs'], ['modelspecs']
    ])

    # 5) generate plots
    if autoPlot:
        log.info('Generating summary plot...')
        xfspec.append(['nems.xforms.plot_summary', {}])

    # Now that the xfspec is assembled, run through it
    # in order to get the fitted modelspec, evaluated recording, etc.
    # (all packaged up in the ctx dictionary).
    ctx, log_xf = xforms.evaluate(xfspec)

    return ctx
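
A hedged call sketch for fit_model_xforms; the recording URI is a placeholder and the modelname reuses the format from the docstring above.

recording_uri = '/data/recordings/TAR010c-18-1.tgz'   # placeholder; any NEMS-format recording
ctx = fit_model_xforms(recording_uri,
                       'ozgf100ch18_wc18x1_lvl1_fir15x1_dexp1_fit01')
modelspecs = ctx['modelspecs']   # fitted modelspec(s); est/val recordings also live in ctx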
Example #10
def fit_model_xforms_baphy(cellid,
                           batch,
                           modelname,
                           autoPlot=True,
                           saveInDB=False):
    """
    Fits a single NEMS model using data from baphy/celldb
    eg, 'ozgf100ch18_wc18x1_lvl1_fir15x1_dexp1_fit01'
    generates modelspec with 'wc18x1_lvl1_fir15x1_dexp1'

    based on this function in nems/scripts/fit_model.py
       def fit_model(recording_uri, modelstring, destination):

     xfspec = [
        ['nems.xforms.load_recordings', {'recording_uri_list': recordings}],
        ['nems.xforms.add_average_sig', {'signal_to_average': 'resp',
                                         'new_signalname': 'resp',
                                         'epoch_regex': '^STIM_'}],
        ['nems.xforms.split_by_occurrence_counts', {'epoch_regex': '^STIM_'}],
        ['nems.xforms.init_from_keywords', {'keywordstring': modelspecname}],
        ['nems.xforms.set_random_phi',  {}],
        ['nems.xforms.fit_basic',       {}],
        # ['nems.xforms.add_summary_statistics',    {}],
        ['nems.xforms.plot_summary',    {}],
        # ['nems.xforms.save_recordings', {'recordings': ['est', 'val']}],
        ['nems.xforms.fill_in_default_metadata',    {}],
    ]
"""

    log.info('Initializing modelspec(s) for cell/batch {0}/{1}...'.format(
        cellid, batch))

    # parse modelname
    kws = modelname.split("_")
    loader = kws[0]
    modelspecname = "_".join(kws[1:-1])
    fitter = kws[-1]

    # generate xfspec, which defines sequence of events to load data,
    # generate modelspec, fit data, plot results and save
    xfspec = generate_loader_xfspec(cellid, batch, loader)

    xfspec.append(
        ['nems.xforms.init_from_keywords', {
            'keywordstring': modelspecname
        }])

    # parse the fit spec: Use gradient descent on whole data set(Fast)
    if fitter == "fit01":
        # prefit strf
        log.info("Prefitting STRF without other modules...")
        xfspec.append(['nems.xforms.fit_basic_init', {}])
        xfspec.append(['nems.xforms.fit_basic', {}])
    elif fitter == "fitjk01":
        # prefit strf
        log.info("Prefitting STRF without NL then JK...")
        xfspec.append(['nems.xforms.fit_basic_init', {}])
        xfspec.append(['nems.xforms.split_for_jackknife', {'njacks': 10}])
        xfspec.append(['nems.xforms.fit_basic', {}])
    elif fitter == "fit02":
        # no pre-fit
        log.info("Performing full fit...")
        xfspec.append(['nems.xforms.fit_basic', {}])
    else:
        raise ValueError('unknown fitter string')

    xfspec.append(['nems.xforms.add_summary_statistics', {}])

    if autoPlot:
        # GENERATE PLOTS
        log.info('Generating summary plot...')
        xfspec.append(['nems.xforms.plot_summary', {}])

    # actually do the fit
    ctx, log_xf = xforms.evaluate(xfspec)

    # save some extra metadata
    modelspecs = ctx['modelspecs']

    if 'CODEHASH' in os.environ.keys():
        githash = os.environ['CODEHASH']
    else:
        githash = ""
    meta = {
        'batch': batch,
        'cellid': cellid,
        'modelname': modelname,
        'loader': loader,
        'fitter': fitter,
        'modelspecname': modelspecname,
        'username': '******',
        'labgroup': 'lbhb',
        'public': 1,
        'githash': githash,
        'recording': loader
    }
    if not 'meta' in modelspecs[0][0].keys():
        modelspecs[0][0]['meta'] = meta
    else:
        modelspecs[0][0]['meta'].update(meta)
    destination = '/auto/data/tmp/modelspecs/{0}/{1}/{2}/'.format(
        batch, cellid, ms.get_modelspec_longname(modelspecs[0]))
    modelspecs[0][0]['meta']['modelpath'] = destination
    modelspecs[0][0]['meta']['figurefile'] = destination + 'figure.0000.png'

    # save results

    xforms.save_analysis(destination,
                         recording=ctx['rec'],
                         modelspecs=modelspecs,
                         xfspec=xfspec,
                         figures=ctx['figures'],
                         log=log_xf)
    log.info('Saved modelspec(s) to {0} ...'.format(destination))

    # save in database as well
    if saveInDB:
        # TODO : db results
        nd.update_results_table(modelspecs[0])

    return ctx
Example #11
def _model_step_plot(cellid, batch, modelnames, factors, state_colors=None):
    """
    state_colors : N x 2 list
       color spec for high/low lines in each of the N states
    """
    global line_colors
    global fill_colors

    modelname_p0b0, modelname_p0b, modelname_pb0, modelname_pb = \
       modelnames
    factor0, factor1, factor2 = factors

    xf_p0b0, ctx_p0b0 = xhelp.load_model_xform(cellid,
                                               batch,
                                               modelname_p0b0,
                                               eval_model=False)
    # ctx_p0b0, l = xforms.evaluate(xf_p0b0, ctx_p0b0, stop=-2)

    ctx_p0b0, l = xforms.evaluate(xf_p0b0, ctx_p0b0, start=0, stop=-2)

    xf_p0b, ctx_p0b = xhelp.load_model_xform(cellid,
                                             batch,
                                             modelname_p0b,
                                             eval_model=False)
    ctx_p0b, l = xforms.evaluate(xf_p0b, ctx_p0b, start=0, stop=-2)

    xf_pb0, ctx_pb0 = xhelp.load_model_xform(cellid,
                                             batch,
                                             modelname_pb0,
                                             eval_model=False)
    #ctx_pb0['rec'] = ctx_p0b0['rec'].copy()
    ctx_pb0, l = xforms.evaluate(xf_pb0, ctx_pb0, start=0, stop=-2)

    xf_pb, ctx_pb = xhelp.load_model_xform(cellid,
                                           batch,
                                           modelname_pb,
                                           eval_model=False)
    #ctx_pb['rec'] = ctx_p0b0['rec'].copy()
    ctx_pb, l = xforms.evaluate(xf_pb, ctx_pb, start=0, stop=-2)

    # organize predictions by different models
    val = ctx_pb['val'][0].copy()

    # val['pred_p0b0'] = ctx_p0b0['val'][0]['pred'].copy()
    val['pred_p0b'] = ctx_p0b['val'][0]['pred'].copy()
    val['pred_pb0'] = ctx_pb0['val'][0]['pred'].copy()

    state_var_list = val['state'].chans

    pred_mod = np.zeros([len(state_var_list), 2])
    pred_mod_full = np.zeros([len(state_var_list), 2])
    resp_mod_full = np.zeros([len(state_var_list), 1])

    state_std = np.nanstd(val['state'].as_continuous(), axis=1, keepdims=True)
    for i, var in enumerate(state_var_list):
        if state_std[i]:
            # actual response modulation index for each state var
            resp_mod_full[i] = state_mod_index(val,
                                               epoch='REFERENCE',
                                               psth_name='resp',
                                               state_chan=var)

            mod2_p0b = state_mod_index(val,
                                       epoch='REFERENCE',
                                       psth_name='pred_p0b',
                                       state_chan=var)
            mod2_pb0 = state_mod_index(val,
                                       epoch='REFERENCE',
                                       psth_name='pred_pb0',
                                       state_chan=var)
            mod2_pb = state_mod_index(val,
                                      epoch='REFERENCE',
                                      psth_name='pred',
                                      state_chan=var)

            pred_mod[i] = np.array([mod2_pb - mod2_p0b, mod2_pb - mod2_pb0])
            pred_mod_full[i] = np.array([mod2_pb0, mod2_p0b])

    pred_mod_norm = pred_mod / (state_std + (state_std == 0).astype(float))
    pred_mod_full_norm = pred_mod_full / (state_std +
                                          (state_std == 0).astype(float))

    if 'each_passive' in factors:
        psth_names_ctl = ["pred_p0b"]
        factors.remove('each_passive')
        for v in state_var_list:
            if v.startswith('FILE_'):
                factors.append(v)
                psth_names_ctl.append("pred_pb0")
    else:
        psth_names_ctl = ["pred_p0b", "pred_pb0"]

    col_count = len(factors) - 1
    if state_colors is None:
        state_colors = [[None, None]] * col_count

    fh = plt.figure(figsize=(8, 8))
    ax = plt.subplot(4, 1, 1)
    nplt.state_vars_timeseries(val,
                               ctx_pb['modelspecs'][0],
                               state_colors=[s[1] for s in state_colors])
    ax.set_title("{}/{} - {}".format(cellid, batch, modelname_pb))
    ax.set_ylabel("{} r={:.3f}".format(
        factor0, ctx_p0b0['modelspecs'][0][0]['meta']['r_test'][0]))
    nplt.ax_remove_box(ax)

    for i, var in enumerate(factors[1:]):
        if var.startswith('FILE_'):
            varlbl = var[5:]
        else:
            varlbl = var
        ax = plt.subplot(4, col_count, col_count + i + 1)

        nplt.state_var_psth_from_epoch(val,
                                       epoch="REFERENCE",
                                       psth_name="resp",
                                       psth_name2=psth_names_ctl[i],
                                       state_chan=var,
                                       ax=ax,
                                       colors=state_colors[i])
        if i == 0:
            ax.set_ylabel("Control model")
            ax.set_title("{} ctl r={:.3f}".format(
                varlbl.lower(),
                ctx_p0b['modelspecs'][0][0]['meta']['r_test'][0]),
                         fontsize=6)
        else:
            ax.yaxis.label.set_visible(False)
            ax.set_title("{} ctl r={:.3f}".format(
                varlbl.lower(),
                ctx_pb0['modelspecs'][0][0]['meta']['r_test'][0]),
                         fontsize=6)
        if ax.legend_:
            ax.legend_.remove()
        ax.xaxis.label.set_visible(False)
        nplt.ax_remove_box(ax)

        ax = plt.subplot(4, col_count, col_count * 2 + i + 1)
        nplt.state_var_psth_from_epoch(val,
                                       epoch="REFERENCE",
                                       psth_name="resp",
                                       psth_name2="pred",
                                       state_chan=var,
                                       ax=ax,
                                       colors=state_colors[i])
        if i == 0:
            ax.set_ylabel("Full Model")
        else:
            ax.yaxis.label.set_visible(False)
        if ax.legend_:
            ax.legend_.remove()

        if psth_names_ctl[i] == "pred_p0b":
            j = 0
        else:
            j = 1

        ax.set_title("r={:.3f} rawmod={:.3f} umod={:.3f}".format(
            ctx_pb['modelspecs'][0][0]['meta']['r_test'][0],
            pred_mod_full_norm[i + 1][j], pred_mod_norm[i + 1][j]),
                     fontsize=6)

        if var == 'active':
            ax.legend(('pas', 'act'))
        elif var == 'pupil':
            ax.legend(('small', 'large'))
        elif var == 'PRE_PASSIVE':
            ax.legend(('act+post', 'pre'))
        elif var.startswith('FILE_'):
            ax.legend(('this', 'others'))
        nplt.ax_remove_box(ax)

    # EXTRA PANELS
    # figure out some basic aspects of tuning/selectivity for target vs.
    # reference:
    r = ctx_pb['rec']['resp']
    e = r.epochs
    fs = r.fs

    passive_epochs = r.get_epoch_indices("PASSIVE_EXPERIMENT")
    tar_names = ep.epoch_names_matching(e, "^TAR_")
    tar_resp = {}
    for tarname in tar_names:
        t = r.get_epoch_indices(tarname)
        t = ep.epoch_intersection(t, passive_epochs)
        tar_resp[tarname] = r.extract_epoch(t) * fs

    # only plot tar responses with max SNR or probe SNR
    keys = []
    for k in list(tar_resp.keys()):
        if k.endswith('0') | k.endswith('2'):
            keys.append(k)
    keys.sort()

    # assume the reference with most reps is the one overlapping the target
    groups = ep.group_epochs_by_occurrence_counts(e, '^STIM_')
    l = np.array(list(groups.keys()))
    hi = np.max(l)
    ref_name = groups[hi][0]
    t = r.get_epoch_indices(ref_name)
    t = ep.epoch_intersection(t, passive_epochs)
    ref_resp = r.extract_epoch(t) * fs

    t = r.get_epoch_indices('REFERENCE')
    t = ep.epoch_intersection(t, passive_epochs)
    all_ref_resp = r.extract_epoch(t) * fs

    prestimsilence = r.get_epoch_indices('PreStimSilence')
    prebins = prestimsilence[0, 1] - prestimsilence[0, 0]
    poststimsilence = r.get_epoch_indices('PostStimSilence')
    postbins = poststimsilence[0, 1] - poststimsilence[0, 0]
    durbins = ref_resp.shape[-1] - prebins

    spont = np.nanmean(all_ref_resp[:, 0, :prebins])
    ref_mean = np.nanmean(ref_resp[:, 0, prebins:durbins]) - spont
    all_ref_mean = np.nanmean(all_ref_resp[:, 0, prebins:durbins]) - spont
    #print(spont)
    #print(np.nanmean(ref_resp[:,0,prebins:-postbins]))
    ax1 = plt.subplot(4, 2, 7)
    ref_psth = [
        np.nanmean(ref_resp[:, 0, :], axis=0),
        np.nanmean(all_ref_resp[:, 0, :], axis=0)
    ]
    ll = [
        "{} {:.1f}".format(ref_name, ref_mean),
        "all refs {:.1f}".format(all_ref_mean)
    ]
    nplt.timeseries_from_vectors(ref_psth,
                                 fs=fs,
                                 legend=ll,
                                 ax=ax1,
                                 time_offset=prebins / fs)

    ax2 = plt.subplot(4, 2, 8)
    ll = []
    tar_mean = np.zeros(np.max([2, len(keys)])) * np.nan
    tar_psth = []
    for ii, k in enumerate(keys):
        tar_psth.append(np.nanmean(tar_resp[k][:, 0, :], axis=0))
        tar_mean[ii] = np.nanmean(tar_resp[k][:, 0, prebins:durbins]) - spont
        ll.append("{} {:.1f}".format(k, tar_mean[ii]))
    nplt.timeseries_from_vectors(tar_psth,
                                 fs=fs,
                                 legend=ll,
                                 ax=ax2,
                                 time_offset=prebins / fs)
    # plt.legend(ll, fontsize=6)

    ymin = np.min([ax1.get_ylim()[0], ax2.get_ylim()[0]])
    ymax = np.max([ax1.get_ylim()[1], ax2.get_ylim()[1]])
    ax1.set_ylim([ymin, ymax])
    ax2.set_ylim([ymin, ymax])
    nplt.ax_remove_box(ax1)
    nplt.ax_remove_box(ax2)

    plt.tight_layout()

    stats = {
        'cellid': cellid,
        'batch': batch,
        'modelnames': modelnames,
        'state_vars': state_var_list,
        'factors': factors,
        'r_test': np.array([
            ctx_p0b0['modelspecs'][0][0]['meta']['r_test'][0],
            ctx_p0b['modelspecs'][0][0]['meta']['r_test'][0],
            ctx_pb0['modelspecs'][0][0]['meta']['r_test'][0],
            ctx_pb['modelspecs'][0][0]['meta']['r_test'][0]
        ]),
        'se_test': np.array([
            ctx_p0b0['modelspecs'][0][0]['meta']['se_test'][0],
            ctx_p0b['modelspecs'][0][0]['meta']['se_test'][0],
            ctx_pb0['modelspecs'][0][0]['meta']['se_test'][0],
            ctx_pb['modelspecs'][0][0]['meta']['se_test'][0]
        ]),
        'r_floor': np.array([
            ctx_p0b0['modelspecs'][0][0]['meta']['r_floor'][0],
            ctx_p0b['modelspecs'][0][0]['meta']['r_floor'][0],
            ctx_pb0['modelspecs'][0][0]['meta']['r_floor'][0],
            ctx_pb['modelspecs'][0][0]['meta']['r_floor'][0]
        ]),
        'pred_mod': pred_mod.T,
        'pred_mod_full': pred_mod_full.T,
        'pred_mod_norm': pred_mod_norm.T,
        'pred_mod_full_norm': pred_mod_full_norm.T,
        'g': np.array([
            ctx_p0b0['modelspecs'][0][0]['phi']['g'],
            ctx_p0b['modelspecs'][0][0]['phi']['g'],
            ctx_pb0['modelspecs'][0][0]['phi']['g'],
            ctx_pb['modelspecs'][0][0]['phi']['g']
        ]),
        'b': np.array([
            ctx_p0b0['modelspecs'][0][0]['phi']['d'],
            ctx_p0b['modelspecs'][0][0]['phi']['d'],
            ctx_pb0['modelspecs'][0][0]['phi']['d'],
            ctx_pb['modelspecs'][0][0]['phi']['d']
        ]),
        'ref_all_resp': all_ref_mean,
        'ref_common_resp': ref_mean,
        'tar_max_resp': tar_mean[0],
        'tar_probe_resp': tar_mean[1]
    }

    return fh, stats
Example #12
def equiv_vs_self(cellid, batch, modelname, LN_model, random_seed=1234):
    # evaluate old fit just to get est/val already split up
    xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)

    # further divide est into two datasets
    # (how to do this?  pick from epochs randomly?)
    est = ctx['est']
    val = ctx['val']
    epochs = est['stim'].epochs
    stims = np.array(ep.epoch_names_matching(epochs, 'STIM_'))
    indices = np.linspace(0, len(stims) - 1, len(stims), dtype=int)  # np.int alias removed in NumPy 1.24+

    st0 = np.random.get_state()
    np.random.seed(random_seed)
    set1_idx = np.random.choice(indices, round(len(stims) / 2), replace=False)
    np.random.set_state(st0)

    mask = np.zeros_like(stims, dtype=bool)  # np.bool alias removed in NumPy 1.24+
    mask[set1_idx] = True
    set1_stims = stims[mask].tolist()
    set2_stims = stims[~mask].tolist()

    est1, est2 = est.split_by_epochs(set1_stims, set2_stims)

    # re-fit on the smaller est sets
    # (will have to re-fit LN model as well?)
    # also have to remove -sev- from modelname and add est-val in manually
    ctx1 = {'est': est1, 'val': val.copy()}
    ctx2 = {'est': est2, 'val': val.copy()}
    LN_ctx1 = copy.deepcopy(ctx1)
    LN_ctx2 = copy.deepcopy(ctx2)
    #    modelname = modelname.replace('-sev', '')
    #    LN_model = LN_model.replace('-sev', '')
    tm = 'none_' + '_'.join(modelname.split('_')[1:])
    lm = 'none_' + '_'.join(LN_model.split('_')[1:])

    # test model, est1
    xfspec = xhelp.generate_xforms_spec(modelname=tm)
    ctx, _ = xforms.evaluate(xfspec, context=ctx1)
    test_pred1 = ctx['val']['pred'].as_continuous().flatten()

    # test model, est2
    xfspec = xhelp.generate_xforms_spec(modelname=tm)
    ctx, _ = xforms.evaluate(xfspec, context=ctx2)
    test_pred2 = ctx['val']['pred'].as_continuous().flatten()

    # LN model, est1
    xfspec = xhelp.generate_xforms_spec(modelname=lm)
    ctx, _ = xforms.evaluate(xfspec, context=ctx1)
    LN_pred1 = ctx['val']['pred'].as_continuous().flatten()

    # LN model, est2
    xfspec = xhelp.generate_xforms_spec(modelname=lm)
    ctx, _ = xforms.evaluate(xfspec, context=ctx2)
    LN_pred2 = ctx['val']['pred'].as_continuous().flatten()

    # test equivalence on the new fits
    C1 = np.hstack((np.expand_dims(test_pred1, 0).transpose(),
                    np.expand_dims(test_pred2, 0).transpose(),
                    np.expand_dims(LN_pred1, 0).transpose()))
    p1 = partial_corr(C1)[0, 1]

    C2 = np.hstack((np.expand_dims(test_pred1, 0).transpose(),
                    np.expand_dims(test_pred2, 0).transpose(),
                    np.expand_dims(LN_pred2, 0).transpose()))
    p2 = partial_corr(C2)[0, 1]

    return 0.5 * (p1 + p2)
Example #13
def fit_xfspec(xfspec):
    # Now that the xfspec is assembled, run through it
    # in order to get the fitted modelspec, evaluated recording, etc.
    # (all packaged up in the ctx dictionary).
    ctx, log_xf = xforms.evaluate(xfspec)
    return ctx
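
Since fit_xfspec simply forwards a prebuilt xfspec to xforms.evaluate, here is a sketch assembling one from steps used in the other examples on this page; the recording path is a placeholder.

xfspec = [
    ['nems.xforms.load_recordings',
     {'recording_uri_list': ['/data/recordings/TAR010c-18-1.tgz']}],   # placeholder recording
    ['nems.xforms.split_at_time', {'valfrac': 0.2}],
    ['nems.xforms.init_from_keywords', {'keywordstring': 'fir.1x15-lvl.1'}],
    ['nems.xforms.fit_basic', {}],
    ['nems.xforms.predict', {}],
    ['nems.xforms.add_summary_statistics', {}],
]
ctx = fit_xfspec(xfspec)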
Example #14
def fit_model_xform(cellid,
                    batch,
                    modelname,
                    autoPlot=True,
                    saveInDB=False,
                    returnModel=False,
                    recording_uri=None,
                    initial_context=None):
    """
    Fit a single NEMS model using data stored in the database. First generates an
    xforms script based on the modelname parameter and then evaluates it.
    :param cellid: cellid identifying the cell- and batch-specific dataset in the database
    :param batch: batch number of the dataset in the database
    :param modelname: string specifying model architecture, preprocessing
       and fit method
    :param autoPlot: generate summary plot when complete
    :param saveInDB: save results to the Results table
    :param returnModel: boolean (default False). If False, return savepath;
       if True, return the (xfspec, ctx) tuple
    :param recording_uri: optional recording URI passed to generate_xforms_spec
    :return: savepath = path to saved results, or the (xfspec, ctx) tuple
    """
    startime = time.time()
    log.info('Initializing modelspec(s) for cell/batch %s/%d...', cellid,
             int(batch))

    # Segment modelname for meta information
    kws = escaped_split(modelname, '_')

    modelspecname = escaped_join(kws[1:-1], '-')
    loadkey = kws[0]
    fitkey = kws[-1]

    meta = {
        'batch': batch,
        'cellid': cellid,
        'modelname': modelname,
        'loader': loadkey,
        'fitkey': fitkey,
        'modelspecname': modelspecname,
        'username': '******',
        'labgroup': 'lbhb',
        'public': 1,
        'githash': os.environ.get('CODEHASH', ''),
        'recording': loadkey
    }
    if type(cellid) is list:
        meta['siteid'] = cellid[0][:7]

    # registry_args = {'cellid': cellid, 'batch': int(batch)}
    registry_args = {}
    xforms_init_context = {'cellid': cellid, 'batch': int(batch)}
    if initial_context is not None:
        xforms_init_context.update(initial_context)

    log.info("TODO: simplify generate_xforms_spec parameters")
    xfspec = generate_xforms_spec(recording_uri=recording_uri,
                                  modelname=modelname,
                                  meta=meta,
                                  xforms_kwargs=registry_args,
                                  xforms_init_context=xforms_init_context,
                                  autoPlot=autoPlot)
    log.debug(xfspec)

    # actually do the loading, preprocessing, fit
    if initial_context is None:
        initial_context = {}
    ctx, log_xf = xforms.evaluate(xfspec)  #, context=initial_context)

    # save some extra metadata
    modelspec = ctx['modelspec']

    if type(cellid) is list:
        cell_name = cellid[0].split("-")[0]
    else:
        cell_name = cellid

    if 'modelpath' not in modelspec.meta:
        prefix = get_setting('NEMS_RESULTS_DIR')
        destination = os.path.join(prefix, str(batch), cell_name,
                                   modelspec.get_longname())

        log.info(f'Setting modelpath to "{destination}"')
        modelspec.meta['modelpath'] = destination
        modelspec.meta['figurefile'] = os.path.join(destination,
                                                    'figure.0000.png')
    else:
        destination = modelspec.meta['modelpath']

    # figure out URI for location to save results (either file or http, depending on USE_NEMS_BAPHY_API)
    if get_setting('USE_NEMS_BAPHY_API'):
        prefix = 'http://' + get_setting('NEMS_BAPHY_API_HOST') + ":" + str(get_setting('NEMS_BAPHY_API_PORT')) + \
                 '/results'
        save_loc = str(
            batch) + '/' + cell_name + '/' + modelspec.get_longname()
        save_destination = prefix + '/' + save_loc
        # set the modelspec meta save locations to be the filesystem and not baphy
        modelspec.meta['modelpath'] = get_setting(
            'NEMS_RESULTS_DIR') + '/' + save_loc
        modelspec.meta['figurefile'] = modelspec.meta[
            'modelpath'] + '/' + 'figure.0000.png'
    else:
        save_destination = destination

    modelspec.meta['runtime'] = int(time.time() - startime)
    modelspec.meta.update(meta)

    if returnModel:
        # return fit, skip save!
        return xfspec, ctx

    # save results
    log.info('Saving modelspec(s) to {0} ...'.format(save_destination))
    if 'figures' in ctx.keys():
        figs = ctx['figures']
    else:
        figs = []
    save_data = xforms.save_analysis(save_destination,
                                     recording=ctx.get('rec'),
                                     modelspec=modelspec,
                                     xfspec=xfspec,
                                     figures=figs,
                                     log=log_xf,
                                     update_meta=False)

    # save in database as well
    if saveInDB:
        nd.update_results_table(modelspec)

    return save_data['savepath']
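
A minimal usage sketch for fit_model_xform. The cellid and batch come from other examples on this page; the modelname is a hypothetical loader_modelspec_fitter string in the underscore-separated format the code expects.

# Hypothetical modelname; substitute any valid loader_modelspec_fitter string.
xfspec, ctx = fit_model_xform('TAR010c-18-1', 271,
                              'ozgf.fs100.ch18-ld_dlog-wc.18x1.g-fir.1x15-lvl.1-dexp.1_init-basic',
                              returnModel=True)   # return the fit instead of saving it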
Example #15
import nems.recording as recording
import nems.epoch as ep
import nems.xforms as xforms
import nems.xform_helper as xhelp
import nems_lbhb.xform_wrappers as nw

cellid = "chn002h-a1"
batch = 259
uri = nw.generate_recording_uri(cellid, batch, "env.fs100")

modelname1 = "env100_dlog_fir2x15_lvl1_dexp1_basic"
#modelname2="env100_dlog_stp2_fir2x15_lvl1_dexp1_basic"
modelname2 = "env100_dlog_wcc2x2_stp2_fir2x15_lvl1_dexp1_basic"
modelname2 = "env100_dlog_wcc2x3_stp3_fir3x15_lvl1_dexp1_basic-shr"
xf1, ctx1 = xhelp.load_model_xform(cellid, batch, modelname1, eval_model=False)
ctx1, l = xforms.evaluate(xf1, ctx1, stop=-1)
xf2, ctx2 = xhelp.load_model_xform(cellid, batch, modelname2, eval_model=False)
ctx2, l = xforms.evaluate(xf2, ctx2, stop=-1)

rec = ctx1['rec']
val1 = ctx1['val'][0]
val2 = ctx2['val'][0]

resp = rec['resp'].rasterize()
pred1 = val1['pred']
pred2 = val2['pred']

epoch_regex = "^STIM_"

stim_epochs = ep.epoch_names_matching(resp.epochs, epoch_regex)
Example #16
def fit_pop_model_xforms_baphy(cellid, batch, modelname, saveInDB=False):
    """
    Fits a NEMS population model using baphy data

    DEPRECATED? This should now be handled by xhelp.fit_model_xform().

    """

    raise NotImplementedError("Replaced by xhelper function?")
    log.info("Preparing pop model: ({0},{1},{2})".format(
            cellid, batch, modelname))

    # Segment modelname for meta information
    kws = modelname.split("_")
    modelspecname = "-".join(kws[1:-1])

    loadkey = kws[0]
    fitkey = kws[-1]
    if type(cellid) is list:
        disp_cellid = "_".join(cellid)
    else:
        disp_cellid = cellid

    meta = {'batch': batch, 'cellid': disp_cellid, 'modelname': modelname,
            'loader': loadkey, 'fitkey': fitkey,
            'modelspecname': modelspecname,
            'username': '******', 'labgroup': 'lbhb', 'public': 1,
            'githash': os.environ.get('CODEHASH', ''),
            'recording': loadkey}

    uri_key = nems.utils.escaped_split(loadkey, '-')[0]
    recording_uri = generate_recording_uri(cellid, batch, uri_key)

    # pass cellid information to xforms so that loader knows which cells
    # to load from recording_uri
    xfspec = xhelp.generate_xforms_spec(recording_uri, modelname, meta,
                                        xforms_kwargs={'cellid': cellid})

    # actually do the fit
    ctx, log_xf = xforms.evaluate(xfspec)

    # save some extra metadata
    modelspec = ctx['modelspec']

    destination = '/auto/data/nems_db/results/{0}/{1}/{2}/'.format(
            batch, disp_cellid, ms.get_modelspec_longname(modelspec))
    modelspec.meta['modelpath'] = destination
    modelspec.meta['figurefile'] = destination+'figure.0000.png'
    modelspec.meta.update(meta)

    # extra thing to save for pop model
    modelspec.meta['cellids'] = ctx['val']['resp'].chans

    # save results
    log.info('Saving modelspec(s) to {0} ...'.format(destination))
    save_data = xforms.save_analysis(destination,
                                     recording=ctx['rec'],
                                     modelspec=modelspec,
                                     xfspec=xfspec,
                                     figures=ctx['figures'],
                                     log=log_xf)
    savepath = save_data['savepath']

    if saveInDB:
        # save in database as well
        nd.update_results_table(modelspec)

    return savepath
Example #17
#modelkeywords = 'dlog-wc.18x2.g-stp.2-fir.2x15-dexp.1'
meta = {'cellids': ['TAR010c-18-1'], 'batch': 271, 'modelname': modelkeywords}

xfspec = [
    ['load_recordings', {'recording_uri_list': recordings, 'meta': meta}],
    ['split_val_and_average_reps', {'epoch_regex': '^STIM_'}],
    ['init_from_keywords', {'keywordstring': modelkeywords}],
    ['fit_basic_init', {}],
    ['fit_basic', {}],
    ['predict', {}],
    ['add_summary_statistics', {}],
    ['plot_summary', {}],
]

ctx, log_xf = xforms.evaluate(xfspec)  # evaluate the fit script

#xforms.save_context(dest='/data/results', ctx=ctx, xfspec=xfspec, log=log_xf)
"""
Simplified:
import nems.recording as recording
import nems.xforms as xforms
    
recording.get_demo_recordings("/data/recordings/")
recordings = ["/data/recordings/TAR010c-18-1.tgz"]
modelkeywords = 'dlog-wc.18x1.g-fir.1x15-lvl.1-dexp.1'
#modelkeywords = 'dlog-wc.18x2.g-fir.2x15-lvl.1-dexp.1'
meta = {'cellid': 'TAR010c-18-1', 'batch': 271, 'modelname': modelkeywords}

xfspec = []
xfspec.append(['nems.xforms.load_recordings',
Example #18
def model_per_time_wrapper(cellid,
                           batch=307,
                           loader="psth.fs20.pup-ld-",
                           fitter="_jk.nf20-basic",
                           basemodel="-ref-psthfr_stategain.S",
                           state_list=None,
                           plot_halves=True,
                           colors=None):
    """
    batch = 307  # A1 SUA and MUA
    batch = 309  # IC SUA and MUA

    alternatives:
        basemodels = ["-ref-psthfr.s_stategain.S",
                      "-ref-psthfr.s_sdexp.S",
                      "-ref.a-psthfr.s_sdexp.S"]
        state_list = ['st.pup0.hlf0','st.pup0.hlf','st.pup.hlf0','st.pup.hlf']
        state_list = ['st.pup0.far0.hit0.hlf0','st.pup0.far0.hit0.hlf',
                      'st.pup.far.hit.hlf0','st.pup.far.hit.hlf']
        state_list = ['st.pup0.fil0','st.pup0.fil','st.pup.fil0','st.pup.fil']
        
    """

    # pup vs. active/passive
    if state_list is None:
        state_list = [
            'st.pup0.hlf0', 'st.pup0.hlf', 'st.pup.hlf0', 'st.pup.hlf'
        ]
        #state_list = ['st.pup0.far0.hit0.hlf0','st.pup0.far0.hit0.hlf',
        #              'st.pup.far.hit.hlf0','st.pup.far.hit.hlf']
        #state_list = ['st.pup0.fil0','st.pup0.fil','st.pup.fil0','st.pup.fil']

    modelnames = []
    contexts = []
    for i, s in enumerate(state_list):
        modelnames.append(loader + s + basemodel + fitter)

        xf, ctx = xhelp.load_model_xform(cellid,
                                         batch,
                                         modelnames[i],
                                         eval_model=False)
        ctx, l = xforms.evaluate(xf, ctx, start=0, stop=-2)

        ctx['val'] = preproc.make_state_signal(ctx['val'],
                                               state_signals=['each_half'],
                                               new_signalname='state_f')

        contexts.append(ctx)
        #import pdb;
        #pdb.set_trace()

    plt.figure()
    #if ('hlf' in state_list[0]) or ('fil' in state_list[0]):
    if plot_halves:
        files_only = True
    else:
        files_only = False

    for i, ctx in enumerate(contexts):

        rec = ctx['val'].apply_mask()
        modelspec = ctx['modelspec']
        epoch = "REFERENCE"
        rec = ms.evaluate(rec, modelspec)
        if i == len(contexts) - 1:
            ax = plt.subplot(len(contexts) + 1, 1, 1)
            nplt.state_vars_timeseries(rec, modelspec, ax=ax)
            ax.set_title('{} {}'.format(cellid, modelnames[-1]))

        ax = plt.subplot(len(contexts) + 1, 1, 2 + i)
        nplt.state_vars_psth_all(rec,
                                 epoch,
                                 psth_name='resp',
                                 psth_name2='pred',
                                 state_sig='state_f',
                                 colors=colors,
                                 channel=None,
                                 decimate_by=1,
                                 ax=ax,
                                 files_only=files_only,
                                 modelspec=modelspec)
        ax.set_ylabel(state_list[i])
        ax.set_xticks([])