# ----- Example 1 -----
def load_high_res_stim():
    """Load (and memoize) the high-resolution stimulus context.

    The expensive recording load and preprocessing run only once; the
    result is cached in the module-level ``high_res_ctx`` and returned
    on every subsequent call.
    """
    global high_res_ctx

    # Fast path: already loaded.
    if high_res_ctx is not None:
        return high_res_ctx

    # Load the hi-res (64-channel) spectrogram for this cell/batch.
    experiment = baphy_experiment.BAPHYExperiment(batch=322,
                                                  cellid="DRX006b-128-2")
    ctx = {'rec': experiment.get_recording(loadkey="ozgf.fs100.ch64")}

    # Split epochs by occurrence count, keeping everything (keepfrac=1.0).
    ctx = xforms.evaluate_step(
        ['nems.xforms.split_by_occurrence_counts',
         {'epoch_regex': '^STIM', 'keepfrac': 1.0}],
        ctx)

    # Average across repeated presentations of each stimulus.
    ctx = xforms.evaluate_step(
        ['nems.xforms.average_away_stim_occurrences',
         {'epoch_regex': '^STIM'}],
        ctx)

    high_res_ctx = ctx
    return high_res_ctx
# ----- Example 2 -----
        log.info('Generating summary plot...')
        xfspec.append(['nems.xforms.plot_summary', {}])

# Create a log stream set to the debug level; add it as a root log handler
# so the full DEBUG output of the xforms evaluation below is captured.
log_stream = io.StringIO()
ch = logging.StreamHandler(log_stream)
ch.setLevel(logging.DEBUG)
fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
formatter = logging.Formatter(fmt)
ch.setFormatter(formatter)
rootlogger = logging.getLogger()
rootlogger.addHandler(ch)

# Evaluate every step of the previously assembled xfspec, threading the
# accumulated context dict through each step.
ctx = {}
for xfa in xfspec:
    ctx = xforms.evaluate_step(xfa, ctx)


# NOTE(review): dead debugging code — flip `if False` to True to re-run the
# spec step by step up to the initial fit and inspect intermediate context.
# Left unchanged to preserve behavior.
if False:
    # Index of the first occurrence of each named step within xfspec.
    Ipred = np.where([xf[0]=='nems.xforms.predict' for xf in xfspec])[0][0]
    Ifit_init = np.where([xf[0]=='nems.xforms.fit_basic_init' for xf in xfspec])[0][0]
    Ifit = np.where([xf[0]=='nems.xforms.fit_basic' for xf in xfspec])[0][0]
    Isbc = np.where([xf[0]=='nems_lbhb.postprocessing.add_summary_statistics_by_condition' for xf in xfspec])[0][0]
    Ipav = np.where([xf[0]=='nems_lbhb.SPO_helpers.plot_all_vals_' for xf in xfspec])[0][0]
    # Re-evaluate only through the initial fit, then snapshot the context.
    ctx = {}
    for xfa in xfspec[:Ifit_init+1]:
        ctx = xforms.evaluate_step(xfa, ctx)
    ctxI=ctx.copy()
    
    # Restore the snapshot and re-initialize the dexp nonlinearity.
    # (init_dexp is defined elsewhere in this project — not visible here.)
    ctx=ctxI.copy()
    ctx['modelspecs'][0] = init_dexp(ctx['est'], ctx['modelspecs'][0])
# ----- Example 3 -----
def fit_xforms_model(batch, cellid, modelname, save_analysis=False):
    """Build an xforms spec from a modelname, fit it, and optionally save.

    Parameters
    ----------
    batch : int or str
        Batch id; used when loading data and when building the results path.
    cellid : str or list of str
        Cell identifier(s). A list is collapsed to its site prefix
        (text before the first "-") when building the save path.
    modelname : str
        "<loader>_<modelspec>_<fitter>" keyword string, split on "_".
    save_analysis : bool, optional
        If True, save the fitted model (to the baphy API or the local
        results directory) and update the database results table.

    Returns
    -------
    tuple
        ``(xfspec, ctx)`` — the evaluated xforms spec and final context.
    """
    # parse modelname into loaders, modelspecs, and fit keys
    load_keywords, model_keywords, fit_keywords = modelname.split("_")

    # construct the meta data dict stored alongside the fitted model
    meta = {
        'batch': batch,
        'cellid': cellid,
        'modelname': modelname,
        'loader': load_keywords,
        'fitkey': fit_keywords,
        'modelspecname': model_keywords,
        'username': '******',
        'labgroup': 'lbhb',
        'public': 1,
        'githash': os.environ.get('CODEHASH', ''),
        'recording': load_keywords
    }

    xforms_kwargs = {}
    xforms_init_context = {'cellid': cellid, 'batch': int(batch)}
    kw_kwargs = {}

    # Registries that translate keyword strings into xforms steps.
    xforms_lib = KeywordRegistry(**xforms_kwargs)
    xforms_lib.register_modules(
        [default_loaders, default_fitters, default_initializers])
    xforms_lib.register_plugins(get_setting('XFORMS_PLUGINS'))

    keyword_lib = KeywordRegistry()
    keyword_lib.register_module(default_keywords)
    keyword_lib.register_plugins(get_setting('KEYWORD_PLUGINS'))

    # Generate the xfspec, which defines the sequence of events
    # to run through (like a packaged-up script).

    # 0) set up initial context
    xforms_init_context['kw_kwargs'] = kw_kwargs
    xforms_init_context['keywordstring'] = model_keywords
    xforms_init_context['meta'] = meta
    xfspec = [['nems.xforms.init_context', xforms_init_context]]

    # 1) Load the data
    xfspec.extend(xhelp._parse_kw_string(load_keywords, xforms_lib))

    # 2) generate a modelspec
    xfspec.append(
        ['nems.xforms.init_from_keywords', {
            'registry': keyword_lib
        }])

    # 3) fit the data
    xfspec.extend(xhelp._parse_kw_string(fit_keywords, xforms_lib))

    # 4) generate a prediction and add performance statistics
    xfspec.append(['nems.xforms.predict', {}])
    xfspec.append(['nems.xforms.add_summary_statistics', {}])

    # 5) plot (disabled)
    #xfspec.append(['nems_lbhb.lv_helpers.add_summary_statistics', {}])

    # Create a log stream set to the debug level; add it as a root log
    # handler so the fit's log output can be saved with the analysis.
    log_stream = io.StringIO()
    ch = logging.StreamHandler(log_stream)
    ch.setLevel(logging.DEBUG)
    fmt = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    ch.setFormatter(logging.Formatter(fmt))
    rootlogger = logging.getLogger()
    rootlogger.addHandler(ch)

    # Evaluate each step, threading the context dict through.
    ctx = {}
    for xfa in xfspec:
        ctx = xforms.evaluate_step(xfa, ctx)

    # Detach the capture handler, close it, and grab the log text.
    log.info('Done (re-)evaluating xforms.')
    # Fix: original called removeFilter(ch), which is a no-op for a handler,
    # leaving the closed handler attached to the root logger forever.
    rootlogger.removeHandler(ch)
    ch.close()

    log_xf = log_stream.getvalue()

    modelspec = ctx['modelspec']
    if save_analysis:
        # save results under the baphy API or the local results directory
        if get_setting('USE_NEMS_BAPHY_API'):
            prefix = 'http://' + get_setting(
                'NEMS_BAPHY_API_HOST') + ":" + str(
                    get_setting('NEMS_BAPHY_API_PORT')) + '/results/'
        else:
            prefix = get_setting('NEMS_RESULTS_DIR')

        # A list of cellids is saved under the site name (prefix before "-").
        if isinstance(cellid, list):
            cell_name = cellid[0].split("-")[0]
        else:
            cell_name = cellid

        destination = os.path.join(prefix, str(batch), cell_name,
                                   modelspec.get_longname())

        modelspec.meta['modelpath'] = destination
        modelspec.meta.update(meta)

        log.info('Saving modelspec(s) to %s ...', destination)

        xforms.save_analysis(destination,
                             recording=ctx['rec'],
                             modelspec=modelspec,
                             xfspec=xfspec,
                             figures=[],
                             log=log_xf)

        # save performance and some other metadata in database Results table
        nd.update_results_table(modelspec)

    return xfspec, ctx
# ----- Example 4 -----
        "dlog.f"
        # Spectral filter (nems.modules.weight_channels)
        "-wc.18x1.g"
        # Temporal filter (nems.modules.fir)
        "-fir.1x15"
        # Scale, currently init to 1.
        "-scl.1"
        # Level shift, usually init to mean response (nems.modules.levelshift)
        "-lvl.1"
        # Nonlinearity (nems.modules.nonlinearity -> double_exponential)
        #"-dexp.1"
        "_"  # modules -> fitters
        # Set initial values and do a rough "pre-fit"
        # Initialize fir coeffs to L2-norm of random values
        "-init.lnp"#.L2f"
        # Do the full fits
        "-lnp.t5"
        #"-nestspec"
        )

# Fit the model independently on each stimulus and collect per-stim results.
result_dict = {}
for name, stim in stim_dict.items():
    # Fresh context per stimulus, seeded with the raw recording.
    ctx = {'rec': stim}
    xfspec = xhelp.generate_xforms_spec(None, modelname, autoPlot=False)
    # Fix: original used `enumerate` but never read the index.
    for xf in xfspec:
        ctx = xforms.evaluate_step(xf, ctx)

    # Store tuple of ctx, error (normalized by stim length) for each stim.
    stim_length = ctx['val'][0]['stim'].shape[1]
    result_dict[name] = (ctx, _lnp_metric(ctx['val'][0])/stim_length)
# ----- Example 5 -----
                               xforms_kwargs=registry_args,
                               xforms_init_context=xforms_init_context,
                               autoPlot=False)

# perform "normal" fit, with siteid held out
# context2 = {}
# for xfa in xfspec2[:-4]:
#     context2 = evaluate_step(xfa, context2)
#
# context3 = copy.deepcopy(context2)
# for xfa in xfspec2[-4:]:
#     context3 = evaluate_step(xfa, context3)

# Run every step except the last five (presumably post-fit steps such as
# predict / statistics / plotting — TODO confirm against xfspec2).
context2 = {}
for xfa in xfspec2[:-5]:
    context2 = evaluate_step(xfa, context2)

# Branch the fitted context: context3 continues through the remaining spec
# steps, while context2 gets only predict + summary statistics directly.
context3 = copy.deepcopy(context2)
context2.update(nems.xforms.predict(**context2))
context2.update(nems.xforms.add_summary_statistics(**context2))
rt2 = context2['modelspec'].meta['r_test']

for xfa in xfspec2[-5:]:
    context3 = evaluate_step(xfa, context3)
rt3 = context3['modelspec'].meta['r_test']

# Reference: prediction accuracy of a standard population model pulled from
# the database, re-indexed to this recording's response channels.
std_model = 'ozgf.fs100.ch18.pop-ld-norm.l1-popev_conv2d.4.8x3.rep3-wcn.40-relu.40-wc.40xR-lvl.R-dexp.R_tfinit.n.lr1e3.et3.rb5.es20-newtf.n.lr1e4.es20'
rt4 = nd.batch_comp(batch, [std_model], stat='r_test')
rt5 = rt4.reindex(context2['rec']['resp'].chans)

# fig = plt.figure()