Example #1
0
def factor_strf_fit(site='TAR010c16',
                    factorN=0,
                    batch=271,
                    modelname="fchan100_wc02_fir15_fit01",
                    doval=True):
    """
    Fit a model to one response factor of a recording site, plot and save it.

    The factor is addressed by the pseudo-cellid "<site>-F<factorN>", and only
    response channel ``factorN`` is used (stack.meta['resp_channels']).

    Parameters
    ----------
    site : str
        Recording site name.
    factorN : int
        Index of the response factor/channel to fit.
    batch : int
        Batch number passed to the stack and used to build the save filename.
    modelname : str
        Keyword string. The stack's keywords (``stack.keywords``, presumably
        derived from ``modelname`` by the nems_stack constructor) are looked
        up in ``keyword_registry`` and applied to the stack.
    doval : bool, optional
        If true (default, matching the previously hard-coded behavior),
        re-evaluate in validation mode, append a correlation metric and
        select the first validation dataset for plotting.

    Returns
    -------
    stack
        The fitted stack. It is also saved via ``nu.io.save_model``.
    """
    cellid = "{0}-F{1}".format(site, factorN)

    # Create the stack; its keywords are applied below.
    stack = ns.nems_stack(cellid=cellid, batch=batch, modelname=modelname)
    stack.meta['resp_channels'] = [factorN]
    stack.meta['site'] = site
    # NOTE(review): keyfuns is cleared here; its purpose is not visible in this file.
    stack.keyfuns = 0

    stack.valmode = False
    print(modelname)

    # Evaluate the stack of keywords. If the last keyword contains "nested",
    # only that keyword's function is called. It presumably applies the other
    # keywords itself for each nest (TODO confirm).
    if 'nested' in stack.keywords[-1]:
        print('Using nested cross-validation, fitting will take longer!')
        k = stack.keywords[-1]
        keyword_registry[k](stack)
    else:
        print('Using standard est/val conditions')
        for k in stack.keywords:
            print(k)
            keyword_registry[k](stack)

    if doval:
        # Re-evaluate with validation data enabled and score the fit.
        stack.valmode = True
        stack.evaluate(1)

        stack.append(nm.metrics.correlation)

        # Plot the first validation dataset if there is one, else dataset 0.
        valdata = [i for i, d in enumerate(stack.data[-1]) if not d['est']]
        stack.plot_dataidx = valdata[0] if valdata else 0

    stack.quick_plot()

    savefile = nu.io.get_file_name(cellid, batch, modelname)
    nu.io.save_model(stack, savefile)

    return stack
Example #2
0
def fit_single_cell(cellid, batch, modelname):
    """
    Build a stack for one cell by applying every keyword of ``modelname``.

    ``modelname`` is split on "_". Each keyword is looked up in
    ``nk.keyfuns``, and the resulting function is called on the stack in
    order. A correlation metric is appended at the end. Keywords containing
    "nested" get no special handling here, and the stack is not re-evaluated
    in validation mode.

    Returns the resulting stack.
    """
    new_stack = ns.nems_stack()
    for field, value in (('batch', batch),
                         ('cellid', cellid),
                         ('modelname', modelname)):
        new_stack.meta[field] = value
    new_stack.valmode = False

    new_stack.keywords = modelname.split("_")

    print('Evaluating stack')
    for keyword in new_stack.keywords:
        keyword_fn = nk.keyfuns[keyword]
        keyword_fn(new_stack)

    new_stack.append(nmet.correlation)
    return new_stack
Example #3
0
def build_stack_from_signals_and_keywords(signals, modelname):
    """
    Build and evaluate a stack from in-memory signals plus a keyword string.

    This is adapted from fit_single_model. Instead of loading data for a real
    batch/cellid, the stack starts with a signal loader and a crossval module
    with valfrac=0.2 (presumably holding out 20% of the data for validation).
    The "_"-separated keywords of ``modelname`` are then applied via
    ``keyword_registry``. The stack is re-evaluated in validation mode, a
    correlation metric is appended, and the MSE/r summary is printed.

    Returns the evaluated stack. ``plot_dataidx`` is set to the first
    validation dataset, or 0 if there is none.
    """
    model_stack = ns.nems_stack()
    model_stack.append(nm.loaders.load_signals, signals=signals)
    model_stack.append(nm.est_val.crossval, valfrac=0.2)

    # Placeholders only: a model built from raw signals has no real
    # batch or cellid.
    model_stack.meta['batch'] = '9999'
    model_stack.meta['cellid'] = 'dummy-cellid'
    model_stack.meta['modelname'] = modelname

    model_stack.valmode = False
    model_stack.keywords = modelname.split("_")

    final_keyword = model_stack.keywords[-1]
    if 'nested' not in final_keyword:
        print('Using standard est/val conditions')
        for keyword in model_stack.keywords:
            print(keyword)
            keyword_registry[keyword](model_stack)
    else:
        # A "nested" keyword is special-cased: only it is applied.
        print('Using nested cross-validation, fitting will take longer!')
        keyword_registry[final_keyword](model_stack)

    # Score the model on estimation and validation data.
    model_stack.valmode = True
    model_stack.evaluate(1)
    model_stack.append(nm.metrics.correlation)

    summary_keys = ('mse_est', 'mse_val', 'r_est', 'r_val')
    print("mse_est={0}, mse_val={1}, r_est={2}, r_val={3}"
          .format(*[model_stack.meta[key] for key in summary_keys]))

    val_indices = [idx for idx, dataset in enumerate(model_stack.data[-1])
                   if not dataset['est']]
    model_stack.plot_dataidx = val_indices[0] if val_indices else 0

    return model_stack
Example #4
0
def newfunc(batch=296,
            cellid='gus030d-b1',
            modelname="env100e_stp1pc_fir20_fit01_ssa",
            fs=100):
    """
    Build a stack that loads one cell's envelope data and computes an SSA index.

    The defaults reproduce the previously hard-coded example. The cell
    'gus030d-b1' was noted as the "first good example".

    Only the loader and the ``ssa_index`` metric are appended. ``modelname``
    is recorded in ``stack.meta`` but its keywords are NOT evaluated here.

    Parameters
    ----------
    batch : int
        Batch number used to locate the cell's data file.
    cellid : str
        Cell identifier used to locate the cell's data file.
    modelname : str
        Stored in ``stack.meta['modelname']`` only.
    fs : int
        Sampling rate. It is used both to select the data file and for the
        loader, so the two always agree.

    Returns
    -------
    stack
    """
    stack = ns.nems_stack()
    stack.meta['batch'] = batch
    stack.meta['cellid'] = cellid
    stack.meta['modelname'] = modelname

    est_file = ut.baphy.get_celldb_file(stack.meta['batch'],
                                        stack.meta['cellid'],
                                        fs=fs,
                                        stimfmt='envelope')

    stack.append(nm.loaders.load_mat, est_files=[est_file], fs=fs,
                 avg_resp=True)
    stack.append(nm.metrics.ssa_index)

    return stack
# Interactive script (the source's "Example #5" header is missing).
# It reloads the nems helper modules so source edits are picked up without
# restarting the interpreter, then builds a stack by hand for one cell.
imp.reload(nk)
imp.reload(nu)
imp.reload(ns)

batch = 296
# Alternative cells / models tried previously are kept commented out.
# cellid='gus018d-d1'
#cellid = 'gus023e-c2'
cellid = 'gus016c-c2'
#modelname = "env100e_fir20_fit01_ssa"
modelname="env100e_stp1pc_fir20_fit01_ssa"
#modelname="env50e_stp1pc_fir10_fit01_ssa"

# Previously toggled between fitting via main.fit_single_model and the
# manual stack construction below; now always builds the stack manually.
#if 0:
#    stack = main.fit_single_model(cellid, batch, modelname, autoplot=False)
#else:
stack = ns.nems_stack()

stack.meta['batch'] = batch
stack.meta['cellid'] = cellid
stack.meta['modelname'] = modelname
stack.valmode = False

# extract keywords from modelname, look up relevant functions in nk and save
# so they don't have to be found again.
stack.keywords = modelname.split("_")

# evaluate the stack of keywords
if 'nested' in stack.keywords[-1]:
    # special case if last keyword contains "nested". TODO: better imp!
    print('Evaluating stack using nested cross validation. May be slow!')
    k = stack.keywords[-1]
    # NOTE(review): the snippet appears truncated here. `k` is never applied,
    # and there is no branch for the standard (non-nested) keywords.
Example #6
0
def fit_single_model(cellid, batch, modelname,
                     autoplot=True, saveInDB=True, **xvals):
    """
    Fit a single NEMS model and return the evaluated stack.

    Apart from the optional plotting and saving, all details of model fitting
    are handled by the model keywords. ``modelname`` is split on "_", and
    each keyword's function from ``keyword_registry`` is called on the stack,
    which usually appends a nems module. Nested cross-validation is a special
    keyword that must be placed last in the modelname. When it is present,
    only that keyword's function is called.

    Parameters
    ----------
    cellid : str
        Stored in ``stack.meta['cellid']`` and used to build the save filename.
    batch : int or str
        Stored in ``stack.meta['batch']`` and used to build the save filename.
    modelname : str
        "_"-separated keyword string describing the model.
    autoplot : bool, optional
        If true (default), call ``stack.quick_plot()`` after fitting.
    saveInDB : bool, optional
        If true (default), save the stack via ``nems.utilities.io.save_model``.
    **xvals
        Accepted for backward compatibility; currently unused.

    Returns
    -------
    stack
        The evaluated stack, which contains both the estimation and
        validation datasets. With nested cross-validation, the validation
        dataset contains all the data, while the estimation dataset is only
        the data fitted last (i.e. on the last nest).
    """
    log.info("Creating empty stack object")
    stack = ns.nems_stack()

    stack.meta['batch'] = batch
    stack.meta['cellid'] = cellid
    stack.meta['modelname'] = modelname
    log.info("Stack.meta information added: {0}".format(stack.meta))
    stack.valmode = False
    stack.keywords = modelname.split("_")

    # Evaluate the stack of keywords. For a trailing "nested" keyword, only
    # that keyword's function is called. It presumably applies the other
    # keywords itself for each nest (TODO confirm).
    if 'nested' in stack.keywords[-1]:
        log.info('Using nested cross-validation, fitting will take longer!')
        k = stack.keywords[-1]
        keyword_registry[k](stack)
    else:
        log.info('Using standard est/val conditions')
        for k in stack.keywords:
            log.info("Adding keyword: {0}".format(k))
            keyword_registry[k](stack)

    # Measure performance on both estimation and validation data.
    stack.valmode = True
    stack.evaluate(1)

    stack.append(nm.metrics.correlation)

    log.info("mse_est={0}, mse_val={1}, r_est={2}, r_val={3}".format(
        stack.meta['mse_est'], stack.meta['mse_val'],
        stack.meta['r_est'], stack.meta['r_val']))

    # Plot the first validation dataset if there is one, else dataset 0.
    valdata = [i for i, d in enumerate(stack.data[-1]) if not d['est']]
    stack.plot_dataidx = valdata[0] if valdata else 0

    if autoplot:
        stack.quick_plot()

    # Persist the fitted model under a name derived from cellid/batch/modelname.
    if saveInDB:
        filename = nems.utilities.io.get_file_name(cellid, batch, modelname)
        nems.utilities.io.save_model(stack, filename)

    return stack
Example #7
0
# Interactive script: the script form of fitting one response factor of a
# site. It reloads the helper modules so source edits are picked up, then
# builds the stack and applies its keywords.
imp.reload(main)
imp.reload(nf)
imp.reload(nk)
imp.reload(nu)
imp.reload(ns)

site = 'TAR010c16'
factorN = 3
#site='zee015h05'
# NOTE(review): doval is set but not used in the visible part of this snippet.
doval = 1
# Pseudo-cellid addressing factor `factorN` of the site.
cellid = "{0}-F{1}".format(site, factorN)

# load the stimulus
batch = 271  #A1
modelname = "fchan100_wc02_fir15_fit01"
stack = ns.nems_stack(cellid=cellid, batch=batch, modelname=modelname)
# Fit only the response channel corresponding to this factor.
stack.meta['resp_channels'] = [factorN]
stack.meta['site'] = site
stack.keyfuns = 0

stack.valmode = False

# evaluate the stack of keywords
if 'nested' in stack.keywords[-1]:
    # special case for nested keywords. Stick with this design?
    print('Using nested cross-validation, fitting will take longer!')
    k = stack.keywords[-1]
    keyword_registry[k](stack)
else:
    print('Using standard est/val conditions')
    # NOTE(review): the snippet is truncated here. The loop below has no
    # body, so this script is not runnable as shown.
    for k in stack.keywords:
Example #8
0
 def model_name_to_stack(self, model_name):
     """
     Create a fresh nems_stack and apply each keyword of ``model_name`` to it.

     ``model_name`` is split on '_'. Each piece is looked up with
     ``self[piece]``, and the resulting callable is invoked on the new stack,
     in keyword order. Returns the populated stack.
     """
     keyword_stack = nems_stack()
     keywords = model_name.split('_')
     for keyword in keywords:
         apply_keyword = self[keyword]
         apply_keyword(keyword_stack)
     return keyword_stack
Example #9
0
def _module_index_or_zero(stack, module_name, position):
    """
    Return the index of a matching module in ``stack``, or 0 if none is found.

    ``position`` selects which match to use (0 = first, -1 = last). As in the
    original code, 0 serves as the "module absent" sentinel. This assumes a
    matching module never sits at index 0 (TODO confirm).
    """
    try:
        return nu.utils.find_modules(stack, module_name)[position]
    except Exception:
        # NOTE(review): most likely IndexError when nothing matches; narrow
        # this once find_modules' failure mode is confirmed.
        return 0


def pop_factor_strf_init(site='TAR010c16',
                         factorCount=4,
                         batch=271,
                         fmodelname="fchan100_wc02_fir15_fit01",
                         modelname=None):
    """
    Initialise a population model whose subspace filters come from factor fits.

    Each of the ``factorCount`` factors of ``site`` is expected to have been
    fit and saved separately with ``fmodelname``, under the pseudo-cellid
    "<site>-F<n>". This function builds a population stack from
    ``modelname`` and copies (factor 0) or concatenates (factors 1..) the
    weight-channel and FIR parameters of those saved models into it. It does
    the same for STP and gain parameters when both models contain those
    modules. Finally it appends a WeightChannels module with
    ``num_chans`` equal to the number of selected cells (coefficients zeroed)
    and a mean-square-error metric.

    Parameters
    ----------
    site : str
        Site name. ``site[:-2]`` is used as the cellid prefix when querying
        the batch's cells.
    factorCount : int
        Number of factor models to load (F0 .. F{factorCount-1}).
    batch : int
        Batch number for the cell query, the stack and the factor-model files.
    fmodelname : str
        Modelname the factor STRFs were fit with.
    modelname : str or None
        Population modelname. If None, it is derived from ``fmodelname`` by
        replacing "fchan100" with "ssfb18ch100" and dropping "_fit01".

    Returns
    -------
    stack
        The initialised population stack, with ``stack.error`` bound to the
        appended MSE module's ``error``.
    """
    # Find all cells in the site that meet the isolation criterion.
    d = ndb.get_batch_cells(batch=batch, cellid=site[:-2])
    d = d.loc[d['min_isolation'] >= 75]
    # NOTE(review): hard-coded exclusion of one cell; the reason is not recorded.
    d = d.loc[d['cellid'] != 'TAR010c-21-2']
    d.reset_index(inplace=True)

    cellcount = len(d['cellid'])

    # modelname should be compatible with fmodelname.
    if modelname is None:
        modelname = fmodelname.replace("fchan100", "ssfb18ch100")
        modelname = modelname.replace("_fit01", "")

    stack = ns.nems_stack(cellid=site, batch=batch, modelname=modelname)
    stack.meta['site'] = site
    stack.meta['d'] = d
    stack.meta['factorCount'] = factorCount
    stack.meta['mini_fit'] = False
    stack.keyfuns = 0
    stack.valmode = False

    # Evaluate the stack of keywords.
    if 'nested' in stack.keywords[-1]:
        # special case for nested keywords. Stick with this design?
        print('Using nested cross-validation, fitting will take longer!')
        k = stack.keywords[-1]
        keyword_registry[k](stack)
    else:
        print('Using standard est/val conditions')
        for k in stack.keywords:
            print(k)
            keyword_registry[k](stack)

    # Module indices in the population stack (0 = optional module absent).
    wc0 = nu.utils.find_modules(stack, 'filters.weight_channels')[0]
    stp0 = _module_index_or_zero(stack, 'filters.stp', 0)
    fir0 = nu.utils.find_modules(stack, 'filters.fir')[0]
    gn0 = _module_index_or_zero(stack, 'nonlin.gain', -1)

    stack.modules[fir0].baseline[0, 0] = 0
    stack.modules[fir0].fit_fields = ['coefs']

    # Load the STRF fit to factor 0. Its module indices (twc0, tfir0, tstp0,
    # tgn0) are reused for every later factor model, on the assumption that
    # all of them were fit with the same fmodelname and share one layout.
    factorN = 0
    cellid = "{0}-F{1}".format(site, factorN)
    savefile = nu.io.get_file_name(cellid, batch, fmodelname)
    tstack = nu.io.load_model(savefile)
    tstp0 = _module_index_or_zero(tstack, 'filters.stp', 0)
    twc0 = nu.utils.find_modules(tstack, 'filters.weight_channels')[0]
    tfir0 = nu.utils.find_modules(tstack, 'filters.fir')[0]
    tgn0 = _module_index_or_zero(tstack, 'nonlin.gain', -1)

    stack.modules[wc0].phi = tstack.modules[twc0].phi
    stack.modules[wc0].coefs = tstack.modules[twc0].coefs
    stack.modules[wc0].baseline = tstack.modules[twc0].baseline
    # Fixed: previously read stack.modules[twc0], which indexed the
    # population stack with the factor model's module index.
    stack.modules[wc0].num_chans = stack.modules[wc0].coefs.shape[0]
    chans_per_subspace = stack.modules[wc0].num_chans

    # Only copy STP/gain when BOTH stacks have the module; otherwise the
    # 0 sentinel would target stack.modules[0].
    if tstp0 and stp0:
        stack.modules[stp0].u = tstack.modules[tstp0].u
        stack.modules[stp0].tau = tstack.modules[tstp0].tau

    stack.modules[fir0].coefs = tstack.modules[tfir0].coefs
    stack.modules[fir0].num_dims = stack.modules[fir0].coefs.shape[0]
    stack.modules[fir0].bank_count = 1

    if tgn0 and gn0:
        stack.modules[gn0].nltype = tstack.modules[tgn0].nltype
        stack.modules[gn0].phi = tstack.modules[tgn0].phi

    # Append the remaining factors' parameters along axis 0.
    for factorN in range(1, factorCount):
        cellid = "{0}-F{1}".format(site, factorN)
        savefile = nu.io.get_file_name(cellid, batch, fmodelname)
        tstack = nu.io.load_model(savefile)

        stack.modules[wc0].phi = np.concatenate(
            (stack.modules[wc0].phi, tstack.modules[twc0].phi), axis=0)
        stack.modules[wc0].coefs = np.concatenate(
            (stack.modules[wc0].coefs, tstack.modules[twc0].coefs), axis=0)
        stack.modules[wc0].baseline = np.concatenate(
            (stack.modules[wc0].baseline, tstack.modules[twc0].baseline),
            axis=0)
        stack.modules[wc0].num_chans = stack.modules[wc0].coefs.shape[0]

        if tstp0 and stp0:
            stack.modules[stp0].u = np.concatenate(
                (stack.modules[stp0].u, tstack.modules[tstp0].u), axis=0)
            stack.modules[stp0].tau = np.concatenate(
                (stack.modules[stp0].tau, tstack.modules[tstp0].tau), axis=0)
        elif stp0:
            # Factor models have no STP: replicate the population stack's
            # initial STP parameters for the new subspace's channels.
            stack.modules[stp0].u = np.concatenate(
                (stack.modules[stp0].u,
                 stack.modules[stp0].u[0:chans_per_subspace, :]),
                axis=0)
            stack.modules[stp0].tau = np.concatenate(
                (stack.modules[stp0].tau,
                 stack.modules[stp0].tau[0:chans_per_subspace, :]),
                axis=0)

        stack.modules[fir0].coefs = np.concatenate(
            (stack.modules[fir0].coefs, tstack.modules[tfir0].coefs), axis=0)
        stack.modules[fir0].num_dims = stack.modules[fir0].coefs.shape[0]
        stack.modules[fir0].bank_count += 1

        if tgn0 and gn0:
            stack.modules[gn0].phi = np.concatenate(
                (stack.modules[gn0].phi, tstack.modules[tgn0].phi), axis=0)

    stack.evaluate(1)
    # Output mixing stage: one channel per selected cell, weights start at zero.
    stack.append(nm.filters.WeightChannels, num_chans=cellcount)
    stack.modules[-1].fit_fields = ['coefs', 'baseline']
    stack.modules[-1].coefs[:, :] = 0

    stack.append(nm.metrics.mean_square_error)
    stack.error = stack.modules[-1].error

    stack.evaluate(2)

    return stack