def factor_strf_fit(site='TAR010c16', factorN=0, batch=271, modelname="fchan100_wc02_fir15_fit01"):
    """Fit an STRF model for a single factor of a recording site.

    Builds a stack for the pseudo-cell "<site>-F<factorN>", runs the keyword
    chain from ``modelname``, computes validation correlation, plots, and
    saves the fitted stack to its standard file location.

    Returns the evaluated stack.
    """
    #site='zee015h05'
    doval = 1  # always-truthy flag gating the validation/plot section
    cellid = "{0}-F{1}".format(site, factorN)

    # load the stimulus
    stack = ns.nems_stack(cellid=cellid, batch=batch, modelname=modelname)
    stack.meta['site'] = site
    stack.meta['resp_channels'] = [factorN]
    stack.keyfuns = 0
    stack.valmode = False
    print(modelname)

    # evaluate the stack of keywords
    final_keyword = stack.keywords[-1]
    if 'nested' in final_keyword:
        # a trailing "nested" keyword drives the whole fit by itself
        print('Using nested cross-validation, fitting will take longer!')
        keyword_registry[final_keyword](stack)
    else:
        print('Using standard est/val conditions')
        for keyword in stack.keywords:
            print(keyword)
            keyword_registry[keyword](stack)

    if doval:
        # validation stuff
        stack.valmode = True
        stack.evaluate(1)
        stack.append(nm.metrics.correlation)
        # plot the first validation dataset, falling back to dataset 0
        stack.plot_dataidx = next(
            (idx for idx, dat in enumerate(stack.data[-1]) if not dat['est']),
            0)
        stack.quick_plot()

    savefile = nu.io.get_file_name(cellid, batch, modelname)
    nu.io.save_model(stack, savefile)
    return stack
def fit_single_cell(cellid, batch, modelname):
    """Build and evaluate a model stack for one cell.

    Each underscore-separated token of ``modelname`` is looked up in
    ``nk.keyfuns`` and applied to the stack in order; a correlation metric
    module is appended last.  Returns the stack (estimation mode only).
    """
    stack = ns.nems_stack()
    stack.meta['modelname'] = modelname
    stack.meta['cellid'] = cellid
    stack.meta['batch'] = batch
    stack.valmode = False
    stack.keywords = modelname.split("_")

    print('Evaluating stack')
    for keyword in stack.keywords:
        nk.keyfuns[keyword](stack)

    stack.append(nmet.correlation)
    return stack
def build_stack_from_signals_and_keywords(signals, modelname):
    """Build and evaluate a model stack directly from in-memory signals.

    A hacked variant of fit_single_model: instead of loading data by
    cellid/batch, the given ``signals`` are pushed into the stack and split
    into estimation/validation sets (20% validation).  Dummy batch/cellid
    metadata are filled in since the stack format expects them.

    Returns the evaluated stack with metrics on both est and val data.
    """
    stack = ns.nems_stack()
    stack.append(nm.loaders.load_signals, signals=signals)
    stack.append(nm.est_val.crossval, valfrac=0.2)

    # I don't think a model should really know or care about its batch / cellid
    stack.meta['batch'] = '9999'           # Batch numbers are not strictly needed
    stack.meta['cellid'] = 'dummy-cellid'  # Cellids are also unnecessary
    stack.meta['modelname'] = modelname
    stack.valmode = False
    stack.keywords = modelname.split("_")

    # evaluate the stack of keywords
    final_keyword = stack.keywords[-1]
    if 'nested' in final_keyword:
        # a trailing "nested" keyword drives the whole fit by itself
        print('Using nested cross-validation, fitting will take longer!')
        keyword_registry[final_keyword](stack)
    else:
        print('Using standard est/val conditions')
        for keyword in stack.keywords:
            print(keyword)
            keyword_registry[keyword](stack)

    # measure performance on both estimation and validation data
    stack.valmode = True
    stack.evaluate(1)
    stack.append(nm.metrics.correlation)
    print("mse_est={0}, mse_val={1}, r_est={2}, r_val={3}"
          .format(stack.meta['mse_est'], stack.meta['mse_val'],
                  stack.meta['r_est'], stack.meta['r_val']))

    # point plotting at the first validation dataset, falling back to 0
    stack.plot_dataidx = next(
        (idx for idx, dat in enumerate(stack.data[-1]) if not dat['est']), 0)
    return stack
def newfunc(batch=296, cellid='gus030d-b1', modelname="env100e_stp1pc_fir20_fit01_ssa"):
    """Load envelope-format data for a cell and attach an SSA-index metric.

    Generalized from the original hard-coded version: batch/cellid/modelname
    are now parameters whose defaults reproduce the original behavior
    (cellid 'gus030d-b1' -- "first good example" -- in batch 296).

    Parameters
    ----------
    batch : experiment batch number.
    cellid : cell identifier to load.
    modelname : model name stored in the stack metadata (not evaluated here).

    Returns
    -------
    The stack with a load_mat loader (100 Hz envelope data, averaged
    responses) and an ssa_index metric module appended.
    """
    stack = ns.nems_stack()
    stack.meta['batch'] = batch
    stack.meta['cellid'] = cellid
    stack.meta['modelname'] = modelname

    # locate the cell's data file (100 Hz sampling, envelope stimulus format)
    # NOTE: renamed from `file`, which shadowed the builtin
    est_file = ut.baphy.get_celldb_file(stack.meta['batch'],
                                        stack.meta['cellid'],
                                        fs=100, stimfmt='envelope')
    stack.append(nm.loaders.load_mat, est_files=[est_file], fs=100,
                 avg_resp=True)
    stack.append(nm.metrics.ssa_index)
    return stack
# --- Exploratory script: fit an SSA model for one cell ---
# NOTE(review): this chunk is truncated -- the code following the
# nested-crossval branch continues past the visible source.

# reload project modules so edits are picked up in an interactive session
imp.reload(nk)
imp.reload(nu)
imp.reload(ns)

batch = 296
# cellid='gus018d-d1'
#cellid = 'gus023e-c2'
cellid = 'gus016c-c2'
#modelname = "env100e_fir20_fit01_ssa"
modelname="env100e_stp1pc_fir20_fit01_ssa"
#modelname="env50e_stp1pc_fir10_fit01_ssa"

#if 0:
#    stack = main.fit_single_model(cellid, batch, modelname, autoplot=False)
#else:

# manual stack construction (instead of main.fit_single_model above)
stack = ns.nems_stack()
stack.meta['batch'] = batch
stack.meta['cellid'] = cellid
stack.meta['modelname'] = modelname
stack.valmode = False

# extract keywords from modelname, look up relevant functions in nk and save
# so they don't have to be found again.
stack.keywords = modelname.split("_")

# evaluate the stack of keywords
if 'nested' in stack.keywords[-1]:
    # special case if last keyword contains "nested". TODO: better imp!
    print('Evaluating stack using nested cross validation. May be slow!')
    k = stack.keywords[-1]
def fit_single_model(cellid, batch, modelname, autoplot=True, saveInDB=True, **xvals):
    """Fit a single NEMS model and return the evaluated stack.

    With the exception of the autoplot feature, all details of model fitting
    are taken care of by the model keywords: each underscore-separated token
    of ``modelname`` is looked up in ``keyword_registry`` and applied in
    order, usually appending a nems module.  Nested crossvalidation is
    implemented as a special keyword, which is placed last in a modelname.

    The returned stack contains both the estimation and validation datasets.
    In the case of nested crossvalidation, the validation dataset contains
    all the data, while the estimation dataset is just the estimation data
    that was fitted last (i.e. on the last nest).

    Parameters
    ----------
    cellid, batch, modelname : identify the cell, dataset, and keyword chain.
    autoplot : if True, draw a summary plot of the fitted stack.
    saveInDB : if truthy, save the fitted stack to its standard file location.
    **xvals : accepted for backward compatibility; not used in this body.
    """
    log.info("Creating empty stack object")
    stack = ns.nems_stack()
    stack.meta['batch'] = batch
    stack.meta['cellid'] = cellid
    stack.meta['modelname'] = modelname
    # lazy %-style args so the meta dict is only formatted if INFO is enabled
    log.info("Stack.meta information added: %s", stack.meta)
    stack.valmode = False
    stack.keywords = modelname.split("_")

    # evaluate the stack of keywords
    if 'nested' in stack.keywords[-1]:
        # special case for nested keywords. Stick with this design?
        log.info('Using nested cross-validation, fitting will take longer!')
        k = stack.keywords[-1]
        keyword_registry[k](stack)
    else:
        log.info('Using standard est/val conditions')
        for k in stack.keywords:
            log.info("Adding keyword: %s", k)
            keyword_registry[k](stack)

    # measure performance on both estimation and validation data
    stack.valmode = True
    stack.evaluate(1)
    stack.append(nm.metrics.correlation)
    log.info("mse_est=%s, mse_val=%s, r_est=%s, r_val=%s",
             stack.meta['mse_est'], stack.meta['mse_val'],
             stack.meta['r_est'], stack.meta['r_val'])

    # point plotting at the first validation dataset, if any
    valdata = [i for i, d in enumerate(stack.data[-1]) if not d['est']]
    stack.plot_dataidx = valdata[0] if valdata else 0

    #phi = stack.fitter.fit_to_phi()
    #stack.meta['n_parms'] = len(phi)

    # edit: added autoplot kwarg for option to disable auto plotting
    # -jacob, 6/20/17
    if autoplot:
        stack.quick_plot()

    # save in data base
    if saveInDB:  # was "saveInDB == True"; truthiness is the intended test
        filename = nems.utilities.io.get_file_name(cellid, batch, modelname)
        nems.utilities.io.save_model(stack, filename)
    return stack
# --- Exploratory script: single-factor STRF fit setup for one site ---
# NOTE(review): this chunk is truncated -- the body of the final for-loop
# continues past the visible source.

# reload project modules so edits are picked up in an interactive session
imp.reload(main)
imp.reload(nf)
imp.reload(nk)
imp.reload(nu)
imp.reload(ns)

site = 'TAR010c16'
factorN = 3
#site='zee015h05'
doval = 1  # presumably gates validation further down -- not visible here
cellid = "{0}-F{1}".format(site, factorN)

# load the stimulus
batch = 271 #A1
modelname = "fchan100_wc02_fir15_fit01"
stack = ns.nems_stack(cellid=cellid, batch=batch, modelname=modelname)
stack.meta['resp_channels'] = [factorN]
stack.meta['site'] = site
stack.keyfuns = 0
stack.valmode = False

# evaluate the stack of keywords
if 'nested' in stack.keywords[-1]:
    # special case for nested keywords. Stick with this design?
    print('Using nested cross-validation, fitting will take longer!')
    k = stack.keywords[-1]
    keyword_registry[k](stack)
else:
    print('Using standard est/val conditions')
    for k in stack.keywords:
def model_name_to_stack(self, model_name):
    """Translate an underscore-delimited model name into a stack.

    Each token of ``model_name`` is used as a keyword, looked up on
    ``self`` (a keyword -> function mapping), and applied to a freshly
    created stack in order.  Returns the resulting stack.
    """
    stack = nems_stack()
    keywords = model_name.split('_')
    for keyword in keywords:
        self[keyword](stack)
    return stack
def pop_factor_strf_init(site='TAR010c16', factorCount=4, batch=271,
                         fmodelname="fchan100_wc02_fir15_fit01",
                         modelname=None):
    """Initialize a population STRF model from previously fit per-factor STRFs.

    Builds a population stack for ``site``, then loads the saved single-factor
    fits ("<site>-F0" .. "<site>-F<factorCount-1>", model ``fmodelname``) and
    concatenates their weight-channel / STP / FIR / gain parameters along the
    channel axis so each factor occupies its own subspace.  Finally appends a
    cell-weighting module (one output channel per cell passing the isolation
    criterion, weights zeroed) and a mean-square-error metric.

    Parameters
    ----------
    site : site identifier; per-factor cellids are "<site>-F<n>".
    factorCount : number of factor fits to load and stack.
    batch : experiment batch number.
    fmodelname : modelname under which the factor fits were saved.
    modelname : population modelname; derived from ``fmodelname`` if None.

    Returns
    -------
    The initialized, evaluated population stack.
    """
    # find all cells in site that meet iso criterion
    d = ndb.get_batch_cells(batch=batch, cellid=site[:-2])
    d = d.loc[d['min_isolation'] >= 75]
    d = d.loc[d['cellid'] != 'TAR010c-21-2']  # manually excluded cell
    d.reset_index(inplace=True)
    cellcount = len(d['cellid'])

    # modelname should be compatible with fmodelname
    if modelname is None:
        modelname = fmodelname.replace("fchan100", "ssfb18ch100")
        modelname = modelname.replace("_fit01", "")

    stack = ns.nems_stack(cellid=site, batch=batch, modelname=modelname)
    stack.meta['site'] = site
    stack.meta['d'] = d
    stack.meta['factorCount'] = factorCount
    stack.meta['mini_fit'] = False
    stack.keyfuns = 0
    stack.valmode = False

    # evaluate the stack of keywords
    if 'nested' in stack.keywords[-1]:
        # special case for nested keywords. Stick with this design?
        print('Using nested cross-validation, fitting will take longer!')
        k = stack.keywords[-1]
        keyword_registry[k](stack)
    else:
        print('Using standard est/val conditions')
        for k in stack.keywords:
            print(k)
            keyword_registry[k](stack)

    # locate the modules whose parameters will be overwritten; stp and gain
    # are optional, so a failed lookup leaves the index flag at 0
    wc0 = nu.utils.find_modules(stack, 'filters.weight_channels')[0]
    try:
        stp0 = nu.utils.find_modules(stack, 'filters.stp')[0]
    except Exception:  # was bare except; narrowed
        stp0 = 0
    fir0 = nu.utils.find_modules(stack, 'filters.fir')[0]
    try:
        gn0 = nu.utils.find_modules(stack, 'nonlin.gain')[-1]
    except Exception:
        gn0 = 0

    stack.modules[fir0].baseline[0, 0] = 0
    stack.modules[fir0].fit_fields = ['coefs']

    # load strf fit to factor 0 and copy its parameters into the stack
    factorN = 0
    cellid = "{0}-F{1}".format(site, factorN)
    savefile = nu.io.get_file_name(cellid, batch, fmodelname)
    tstack = nu.io.load_model(savefile)
    try:
        tstp0 = nu.utils.find_modules(tstack, 'filters.stp')[0]
    except Exception:
        tstp0 = 0
    twc0 = nu.utils.find_modules(tstack, 'filters.weight_channels')[0]
    tfir0 = nu.utils.find_modules(tstack, 'filters.fir')[0]
    try:
        tgn0 = nu.utils.find_modules(tstack, 'nonlin.gain')[-1]
    except Exception:
        tgn0 = 0

    stack.modules[wc0].phi = tstack.modules[twc0].phi
    stack.modules[wc0].coefs = tstack.modules[twc0].coefs
    stack.modules[wc0].baseline = tstack.modules[twc0].baseline
    # BUGFIX: was stack.modules[twc0].coefs -- num_chans must reflect the
    # coefs just copied onto this stack's own weight-channels module
    stack.modules[wc0].num_chans = stack.modules[wc0].coefs.shape[0]
    chans_per_subspace = stack.modules[wc0].num_chans
    if tstp0:
        stack.modules[stp0].u = tstack.modules[tstp0].u
        # BUGFIX: was tstack.modules[stp0].tau (source stack indexed with
        # the destination stack's module index)
        stack.modules[stp0].tau = tstack.modules[tstp0].tau
    stack.modules[fir0].coefs = tstack.modules[tfir0].coefs
    stack.modules[fir0].num_dims = stack.modules[fir0].coefs.shape[0]
    stack.modules[fir0].bank_count = 1
    if tgn0:
        stack.modules[gn0].nltype = tstack.modules[tgn0].nltype
        stack.modules[gn0].phi = tstack.modules[tgn0].phi

    # append the remaining factors' parameters along the channel axis.
    # NOTE(review): module indices found for factor 0 (twc0 etc.) are reused
    # for every subsequent tstack -- assumes all factor fits share the same
    # architecture; confirm against how the factor fits are produced.
    for factorN in range(1, factorCount):
        cellid = "{0}-F{1}".format(site, factorN)
        savefile = nu.io.get_file_name(cellid, batch, fmodelname)
        tstack = nu.io.load_model(savefile)
        stack.modules[wc0].phi = np.concatenate(
            (stack.modules[wc0].phi, tstack.modules[twc0].phi), axis=0)
        stack.modules[wc0].coefs = np.concatenate(
            (stack.modules[wc0].coefs, tstack.modules[twc0].coefs), axis=0)
        stack.modules[wc0].baseline = np.concatenate(
            (stack.modules[wc0].baseline, tstack.modules[twc0].baseline),
            axis=0)
        stack.modules[wc0].num_chans = stack.modules[wc0].coefs.shape[0]
        if tstp0:
            stack.modules[stp0].u = np.concatenate(
                (stack.modules[stp0].u, tstack.modules[tstp0].u), axis=0)
            # BUGFIX: was tstack.modules[stp0].tau
            stack.modules[stp0].tau = np.concatenate(
                (stack.modules[stp0].tau, tstack.modules[tstp0].tau), axis=0)
        elif stp0:
            # population stack has STP but the factor fits do not: replicate
            # the first subspace's STP parameters for this factor
            #stack.modules[stp0].num_channels=stack.modules[wc0].num_chans
            stack.modules[stp0].u = np.concatenate(
                (stack.modules[stp0].u,
                 stack.modules[stp0].u[0:chans_per_subspace, :]), axis=0)
            stack.modules[stp0].tau = np.concatenate(
                (stack.modules[stp0].tau,
                 stack.modules[stp0].tau[0:chans_per_subspace, :]), axis=0)
        stack.modules[fir0].coefs = np.concatenate(
            (stack.modules[fir0].coefs, tstack.modules[tfir0].coefs), axis=0)
        stack.modules[fir0].num_dims = stack.modules[fir0].coefs.shape[0]
        stack.modules[fir0].bank_count += 1
        if tgn0:
            stack.modules[gn0].phi = np.concatenate(
                (stack.modules[gn0].phi, tstack.modules[tgn0].phi), axis=0)

    stack.evaluate(1)

    # final stage: weighted sum across factor outputs, one row per cell,
    # starting from zero weights so the fitter determines the mixing
    stack.append(nm.filters.WeightChannels, num_chans=cellcount)
    stack.modules[-1].fit_fields = ['coefs', 'baseline']
    stack.modules[-1].coefs[:, :] = 0
    stack.append(nm.metrics.mean_square_error)
    stack.error = stack.modules[-1].error
    stack.evaluate(2)
    return stack