def fit_basic_cd(modelspecs, est, maxiter=1000, ftol=1e-8, IsReload=False,
                 **context):
    '''
    Fit every input modelspec with coordinate descent, scoring candidates
    with metrics.nmse_shrink on the 'pred' and 'resp' signals.

    Parameters
    ----------
    modelspecs : list
        Modelspecs to fit. When ``est`` is a list (jackknife mode), the i-th
        modelspec is paired with the i-th estimation dataset.
    est : object or list
        Estimation data. A plain ``list`` triggers per-jackknife fitting;
        anything else is used as the single dataset for every modelspec.
    maxiter, ftol
        Forwarded to the fitter as ``fit_kwargs['options']``.
    IsReload : bool
        When true, no fitting is done and ``modelspecs`` is returned as-is.
    **context
        Accepted and ignored.

    Returns
    -------
    dict
        ``{'modelspecs': <fitted or original modelspecs>}``
    '''
    if IsReload:
        return {'modelspecs': modelspecs}

    # NOTE(review): ftol/maxiter are nested under 'options' here, whereas
    # fit_basic passes flat 'tolerance'/'max_iter' keys; confirm that
    # coordinate_descent reads 'options', otherwise these limits may be ignored.
    fit_kwargs = {'options': {'ftol': ftol, 'maxiter': maxiter}}

    def shrink_cost(data):
        return metrics.nmse_shrink(data, 'pred', 'resp')

    if type(est) is list:
        # Jackknife: each modelspec is fit to its own estimation set.
        njacks = len(modelspecs)
        fitted = []
        for jk, (spec, data) in enumerate(zip(modelspecs, est), start=1):
            log.info("Fitting JK {}/{}".format(jk, njacks))
            fitted += nems.analysis.api.fit_basic(
                data, spec, fit_kwargs=fit_kwargs, metric=shrink_cost,
                fitter=coordinate_descent)
        return {'modelspecs': fitted}

    # Single shot: every modelspec is fit to the same data; keep result [0].
    fitted = []
    for spec in modelspecs:
        result = nems.analysis.api.fit_basic(
            est, spec, fit_kwargs=fit_kwargs, metric=shrink_cost,
            fitter=coordinate_descent)
        fitted.append(result[0])
    return {'modelspecs': fitted}
def fit_cd_nfold_shrinkage(modelspecs, est, IsReload=False, **context):
    '''
    n-fold fit using coordinate descent, one fit per entry in ``est``,
    with metrics.nmse_shrink on 'pred'/'resp' as the cost metric.

    The actual fitting is delegated to nems.analysis.api.fit_nfold.
    When ``IsReload`` is true nothing is fit and ``modelspecs`` is returned
    unchanged. ``**context`` is accepted and ignored.

    Returns
    -------
    dict
        ``{'modelspecs': <fitted or original modelspecs>}``
    '''
    if IsReload:
        return {'modelspecs': modelspecs}

    def shrink_cost(data):
        return metrics.nmse_shrink(data, 'pred', 'resp')

    fitted = nems.analysis.api.fit_nfold(
        est, modelspecs, metric=shrink_cost, fitter=coordinate_descent)
    return {'modelspecs': fitted}
def fit_basic(modelspecs, est, max_iter=1000, tolerance=1e-7, shrinkage=0,
              IsReload=False, **context):
    '''
    Fit every input modelspec with scipy_minimize.

    Parameters
    ----------
    modelspecs : list
        Modelspecs to fit. When ``est`` is a list (jackknife mode), the i-th
        modelspec is paired with the i-th estimation dataset.
    est : object or list
        Estimation data. A plain ``list`` triggers per-jackknife fitting;
        anything else is used as the single dataset for every modelspec.
    max_iter, tolerance
        Passed to the fitter as flat ``fit_kwargs`` entries.
    shrinkage
        Truthy selects metrics.nmse_shrink; falsy selects metrics.nmse
        (both on 'pred'/'resp').
    IsReload : bool
        When true, no fitting is done and ``modelspecs`` is returned as-is.
    **context
        Accepted and ignored.

    Returns
    -------
    dict
        ``{'modelspecs': <fitted or original modelspecs>}``
    '''
    if IsReload:
        return {'modelspecs': modelspecs}

    if shrinkage:
        def cost(data):
            return metrics.nmse_shrink(data, 'pred', 'resp')
    else:
        def cost(data):
            return metrics.nmse(data, 'pred', 'resp')

    fit_kwargs = dict(tolerance=tolerance, max_iter=max_iter)

    if type(est) is list:
        # Jackknife: each modelspec is fit to its own estimation set.
        njacks = len(modelspecs)
        fitted = []
        for jk, (spec, data) in enumerate(zip(modelspecs, est), start=1):
            log.info("Fitting JK {}/{}".format(jk, njacks))
            fitted += nems.analysis.api.fit_basic(
                data, spec, fit_kwargs=fit_kwargs, metric=cost,
                fitter=scipy_minimize)
        return {'modelspecs': fitted}

    # Single shot: every modelspec is fit to the same data; keep result [0].
    fitted = []
    for spec in modelspecs:
        result = nems.analysis.api.fit_basic(
            est, spec, fit_kwargs=fit_kwargs, metric=cost,
            fitter=scipy_minimize)
        fitted.append(result[0])
    return {'modelspecs': fitted}
def fit_nfold_shrinkage(modelspecs, est, ftol=1e-7, maxiter=1000,
                        IsReload=False, **context):
    '''
    n-fold fit using scipy_minimize, one fit per entry in ``est``, with
    metrics.nmse_shrink on 'pred'/'resp' as the cost metric.

    Parameters
    ----------
    modelspecs : list
        Modelspecs handed to nems.analysis.api.fit_nfold.
    est
        Estimation data handed to nems.analysis.api.fit_nfold.
    ftol : float
        Forwarded to the fitter as ``fit_kwargs['tolerance']``.
    maxiter : int
        Forwarded to the fitter as ``fit_kwargs['max_iter']``.
    IsReload : bool
        When true, no fitting is done and ``modelspecs`` is returned as-is.
    **context
        Accepted and ignored.

    Returns
    -------
    dict
        ``{'modelspecs': <fitted or original modelspecs>}``
    '''
    if IsReload:
        return {'modelspecs': modelspecs}

    def shrink_cost(data):
        return metrics.nmse_shrink(data, 'pred', 'resp')

    # Build fit_kwargs from this function's own ftol/maxiter arguments, using
    # the same flat 'tolerance'/'max_iter' keys that fit_basic passes to
    # scipy_minimize (the fitter used here as well).
    fit_kwargs = {'tolerance': ftol, 'max_iter': maxiter}

    fitted = nems.analysis.api.fit_nfold(
        est, modelspecs, metric=shrink_cost, fitter=scipy_minimize,
        fit_kwargs=fit_kwargs)
    return {'modelspecs': fitted}
def fit_basic_shr_init(modelspecs, est, IsReload=False, **context):
    '''
    Initialize a modelspec in a way intended to avoid getting stuck in
    local minima, using init.prefit_LN with scipy_minimize and
    metrics.nmse_shrink on 'pred'/'resp'.

    Written/optimized for (dlog)-wc-(stp)-fir-(dexp) architectures, where
    modules in (parens) are optional.

    Only ``modelspecs[0]`` is initialized; the returned list contains that
    single result, so any further modelspecs are dropped. When ``IsReload``
    is true nothing is run and ``modelspecs`` is returned unchanged.
    ``**context`` is accepted and ignored.

    Returns
    -------
    dict
        ``{'modelspecs': <[initialized modelspec] or original modelspecs>}``
    '''
    if IsReload:
        return {'modelspecs': modelspecs}

    def shrink_cost(data):
        return metrics.nmse_shrink(data, 'pred', 'resp')

    initialized = init.prefit_LN(
        est, modelspecs[0],
        analysis_function=nems.analysis.api.fit_basic,
        fitter=scipy_minimize,
        metric=shrink_cost,
        tolerance=10**-5,
        max_iter=700)
    return {'modelspecs': [initialized]}