示例#1
0
def fit_basic_cd(modelspecs,
                 est,
                 maxiter=1000,
                 ftol=1e-8,
                 IsReload=False,
                 **context):
    '''
    Fit every input modelspec with the coordinate_descent fitter, using
    metrics.nmse_shrink on the 'pred' and 'resp' signals as cost function.

    Parameters
    ----------
    modelspecs : list
        Modelspecs to fit.
    est : list or single dataset
        When a list, it is treated as jackknife data: modelspecs[i] is fit
        to est[i] (pairs are formed with zip, so extra entries on the
        longer side are ignored). Otherwise every modelspec is fit to the
        same est.
    maxiter, ftol :
        Forwarded to the fitter as
        fit_kwargs={'options': {'ftol': ftol, 'maxiter': maxiter}}.
    IsReload : bool
        When true nothing is fit and modelspecs is returned unchanged.
    **context :
        Accepted and ignored.

    Returns
    -------
    dict
        {'modelspecs': <list of fitted modelspecs>}
    '''
    if IsReload:
        return {'modelspecs': modelspecs}

    fit_kwargs = {'options': {'ftol': ftol, 'maxiter': maxiter}}
    # The cost function does not depend on the fold, so build it once.
    metric = lambda data: metrics.nmse_shrink(data, 'pred', 'resp')

    if isinstance(est, list):
        # jackknife: one estimation set per modelspec
        njacks = len(modelspecs)
        fitted = []
        for i, (m, d) in enumerate(zip(modelspecs, est), start=1):
            log.info("Fitting JK {}/{}".format(i, njacks))
            # keep every modelspec returned for this fold
            fitted.extend(
                nems.analysis.api.fit_basic(d,
                                            m,
                                            fit_kwargs=fit_kwargs,
                                            metric=metric,
                                            fitter=coordinate_descent))
    else:
        # standard single shot: keep only the first result of each fit
        fitted = [
            nems.analysis.api.fit_basic(est,
                                        modelspec,
                                        fit_kwargs=fit_kwargs,
                                        metric=metric,
                                        fitter=coordinate_descent)[0]
            for modelspec in modelspecs
        ]
    return {'modelspecs': fitted}
示例#2
0
def fit_cd_nfold_shrinkage(modelspecs, est, IsReload=False, **context):
    '''
    N-fold fit via nems.analysis.api.fit_nfold with the coordinate_descent
    fitter; cost function is metrics.nmse_shrink on 'pred' vs 'resp'.

    When IsReload is true, modelspecs is returned unchanged. Extra
    keyword arguments (**context) are accepted and ignored.

    Returns {'modelspecs': ...}.
    '''
    if IsReload:
        # nothing to fit when reloading
        return {'modelspecs': modelspecs}

    cost = lambda rec: metrics.nmse_shrink(rec, 'pred', 'resp')
    fitted = nems.analysis.api.fit_nfold(
        est, modelspecs, metric=cost, fitter=coordinate_descent)
    return {'modelspecs': fitted}
示例#3
0
文件: xforms.py 项目: nadoss/nems_db
def fit_basic(modelspecs,
              est,
              max_iter=1000,
              tolerance=1e-7,
              shrinkage=0,
              IsReload=False,
              **context):
    '''
    Fit every input modelspec with the scipy_minimize fitter.

    Parameters
    ----------
    modelspecs : list
        Modelspecs to fit.
    est : list or single dataset
        When a list, it is treated as jackknife data: modelspecs[i] is fit
        to est[i] (pairs are formed with zip, so extra entries on the
        longer side are ignored). Otherwise every modelspec is fit to the
        same est.
    max_iter, tolerance :
        Forwarded to the fitter as
        fit_kwargs={'tolerance': tolerance, 'max_iter': max_iter}.
    shrinkage :
        If truthy the cost function is metrics.nmse_shrink, otherwise
        metrics.nmse (both evaluated on 'pred' vs 'resp').
    IsReload : bool
        When true nothing is fit and modelspecs is returned unchanged.
    **context :
        Accepted and ignored.

    Returns
    -------
    dict
        {'modelspecs': <list of fitted modelspecs>}
    '''
    if IsReload:
        return {'modelspecs': modelspecs}

    if shrinkage:
        metric = lambda data: metrics.nmse_shrink(data, 'pred', 'resp')
    else:
        metric = lambda data: metrics.nmse(data, 'pred', 'resp')

    fit_kwargs = {'tolerance': tolerance, 'max_iter': max_iter}

    if isinstance(est, list):
        # jackknife: one estimation set per modelspec
        njacks = len(modelspecs)
        fitted = []
        for i, (m, d) in enumerate(zip(modelspecs, est), start=1):
            log.info("Fitting JK {}/{}".format(i, njacks))
            # keep every modelspec returned for this fold
            fitted.extend(
                nems.analysis.api.fit_basic(d,
                                            m,
                                            fit_kwargs=fit_kwargs,
                                            metric=metric,
                                            fitter=scipy_minimize))
    else:
        # standard single shot: keep only the first result of each fit
        fitted = [
            nems.analysis.api.fit_basic(est,
                                        modelspec,
                                        fit_kwargs=fit_kwargs,
                                        metric=metric,
                                        fitter=scipy_minimize)[0]
            for modelspec in modelspecs
        ]
    return {'modelspecs': fitted}
示例#4
0
文件: xforms.py 项目: nadoss/nems_db
def fit_nfold_shrinkage(modelspecs,
                        est,
                        ftol=1e-7,
                        maxiter=1000,
                        IsReload=False,
                        **context):
    '''
    N-fold fit via nems.analysis.api.fit_nfold with the scipy_minimize
    fitter; cost function is metrics.nmse_shrink on 'pred' vs 'resp'.

    Parameters
    ----------
    modelspecs, est :
        Passed straight through to nems.analysis.api.fit_nfold.
    ftol, maxiter :
        Stopping tolerance and iteration cap, forwarded to the fitter as
        fit_kwargs={'options': {'tolerance': ftol, 'max_iter': maxiter}}.
    IsReload : bool
        When true nothing is fit and modelspecs is returned unchanged.
    **context :
        Accepted and ignored.

    Returns
    -------
    dict
        {'modelspecs': ...}
    '''
    if IsReload:
        return {'modelspecs': modelspecs}

    metric = lambda data: metrics.nmse_shrink(data, 'pred', 'resp')
    # The body previously read the undefined names `tolerance` and
    # `max_iter`, raising NameError on every real fit; use the declared
    # parameters instead.
    # NOTE(review): the dict layout and key names are kept as originally
    # written; confirm the fitter reads 'tolerance'/'max_iter' from
    # 'options'.
    fit_kwargs = {
        'options': {
            'tolerance': ftol,
            'max_iter': maxiter
        }
    }
    fitted = nems.analysis.api.fit_nfold(est,
                                         modelspecs,
                                         metric=metric,
                                         fitter=scipy_minimize,
                                         fit_kwargs=fit_kwargs)
    return {'modelspecs': fitted}
示例#5
0
文件: xforms.py 项目: nadoss/nems_db
def fit_basic_shr_init(modelspecs, est, IsReload=False, **context):
    '''
    Pre-fit the first modelspec with init.prefit_LN so the later fit is
    less likely to get stuck in a local minimum.

    Written/optimized for (dlog)-wc-(stp)-fir-(dexp) architectures, where
    modules in (parens) are optional.

    Only modelspecs[0] is used; the result is a one-element list. The
    prefit runs nems.analysis.api.fit_basic with the scipy_minimize fitter
    and metrics.nmse_shrink ('pred' vs 'resp') as cost function. When
    IsReload is true, modelspecs is returned unchanged. Extra keyword
    arguments (**context) are accepted and ignored.

    Returns {'modelspecs': ...}.
    '''
    if IsReload:
        # only run when actually fitting
        return {'modelspecs': modelspecs}

    shrink_cost = lambda rec: metrics.nmse_shrink(rec, 'pred', 'resp')
    prefit = init.prefit_LN(
        est,
        modelspecs[0],
        analysis_function=nems.analysis.api.fit_basic,
        fitter=scipy_minimize,
        metric=shrink_cost,
        tolerance=10**-5,
        max_iter=700,
    )
    return {'modelspecs': [prefit]}