def fit_population_channel(rec, modelspec, fit_set_all, fit_set_slice,
                           analysis_function=analysis.fit_basic,
                           metric=metrics.nmse,
                           fitter=scipy_minimize, fit_kwargs=None):
    """
    Fit the modules listed in fit_set_all against a recording that has been
    passed through _invert_slice for the modules in fit_set_slice.

    :param rec: recording handed to _invert_slice
    :param modelspec: full modelspec; phi of the modules in fit_set_all is
        overwritten in place with the fitted values
    :param fit_set_all: indices of the modules to fit
    :param fit_set_slice: indices of the modules handed to _invert_slice
    :param analysis_function: called as
        analysis_function(rec, modelspec, fitter=, metric=, fit_kwargs=)
    :param metric: forwarded to analysis_function
    :param fitter: forwarded to analysis_function
    :param fit_kwargs: dict forwarded to analysis_function; a fresh empty
        dict is used when omitted so no state is shared between calls
    :return: modelspec (the same object that was passed in)
    """
    if fit_kwargs is None:
        fit_kwargs = {}

    # invert cell-specific modules
    trec = _invert_slice(rec, modelspec, fit_set_slice)

    # temporary model holding (shallow) copies of the modules being fit
    tmodelspec = ms.ModelSpec()
    for i in fit_set_all:
        tmodelspec.append(modelspec[i].copy())

    tmodelspec = analysis_function(trec, tmodelspec, fitter=fitter,
                                   metric=metric, fit_kwargs=fit_kwargs)

    # copy fitted parameters back into the full model
    for i in fit_set_all:
        for k, v in tmodelspec[i]['phi'].items():
            modelspec[i]['phi'][k] = v

    return modelspec
def _prefit_dsig_only(est, modelspec, analysis_function, fitter,
                      metric=None, fit_kwargs={}):
    '''
    Rough fit in which only the dynamic_sigmoid module is handed to
    prefit_mod_subset.

    Priors for the four "*_mod" parameters are removed for the duration of
    the fit (their fn_kwargs entry is set to NaN) and restored afterwards,
    with phi reset to the mean of the restored prior. Modules whose id
    contains 'ct' are left out of the temporary model and keep their place
    in the returned modelspec.
    '''
    dsig_idx = find_module('dynamic_sigmoid', modelspec)

    # take the dynamic priors off the module, remembering what was removed
    stashed_priors = {}
    for name in ('amplitude_mod', 'base_mod', 'kappa_mod', 'shift_mod'):
        removed = modelspec[dsig_idx]['prior'].pop(name, False)
        if removed:
            modelspec[dsig_idx]['fn_kwargs'][name] = np.nan
            stashed_priors[name] = removed

    # temporary model without the 'ct' modules (ctwc, ctfir, ctlvl)
    kept_modules = [m for m in modelspec.modules if 'ct' not in m['id']]
    reduced = ms.ModelSpec(raw=[kept_modules])
    reduced = prefit_mod_subset(est, reduced, analysis_function,
                                fit_set=['dynamic_sigmoid'], fitter=fitter,
                                metric=metric, fit_kwargs=fit_kwargs)

    # write the fitted modules back, stepping over the 'ct' modules
    j = 0
    for i, m in enumerate(modelspec.modules):
        if 'ct' in m['id']:
            continue
        modelspec[i] = reduced[j]
        j += 1

    # restore the priors that were removed above
    for name, prior_spec in stashed_priors.items():
        distribution = priors._tuples_to_distributions({name: prior_spec})[name]
        modelspec[dsig_idx]['fn_kwargs'].pop(name, None)
        modelspec[dsig_idx]['prior'][name] = prior_spec
        modelspec[dsig_idx]['phi'][name] = distribution.mean()

    return modelspec
def fit_population_channel_fast_old(rec, modelspec, fit_set_all, fit_set_slice,
                                    analysis_function=analysis.fit_basic,
                                    metric=metrics.nmse,
                                    fitter=scipy_minimize, fit_kwargs={}):
    """
    Fit the population model one subspace dimension at a time.

    The number of dimensions is read from the second axis of
    phi['coefficients'] of the first module in fit_set_slice. For each
    dimension the recording is passed through _invert_slice for that
    channel, a one-channel copy of the modules in fit_set_all is fit with
    analysis_function, and the fitted values are written back into the
    matching rows of modelspec.

    :return: modelspec, updated in place
    """
    # guess at number of subspace dimensions
    n_dims = modelspec[fit_set_slice[0]]['phi']['coefficients'].shape[1]

    for dim in range(n_dims):
        log.info('Updating dim %d/%d', dim+1, n_dims)

        # invert cell-specific modules
        trec = _invert_slice(rec, modelspec, fit_set_slice,
                             population_channel=dim)

        tmodelspec = ms.ModelSpec()
        for idx in fit_set_all:
            module = copy.deepcopy(modelspec[idx])
            for key, value in module['phi'].items():
                n_rows = value.shape[0]
                if n_rows < n_dims:
                    # single model-wide parameter, only fit for dim == 0
                    if dim == 0:
                        module['phi'][key] = value
                    else:
                        module['fn_kwargs'][key] = value
                    continue

                # rows belonging to this dimension
                width = int(n_rows / n_dims)
                module['phi'][key] = value[width * dim:width * (dim + 1)]
                if 'bank_count' in module['fn_kwargs']:
                    module['fn_kwargs']['bank_count'] = 1
            tmodelspec.append(module)

        tmodelspec = analysis_function(trec, tmodelspec, fitter=fitter,
                                       metric=metric, fit_kwargs=fit_kwargs)

        # write the fitted slice back into the full model
        for idx in fit_set_all:
            for key, value in tmodelspec[idx]['phi'].items():
                n_rows = modelspec[idx]['phi'][key].shape[0]
                if n_rows < n_dims:
                    modelspec[idx]['phi'][key] = value
                    continue
                width = int(n_rows / n_dims)
                modelspec[idx]['phi'][key][width * dim:width * (dim + 1)] = value

    return modelspec
def _get_modelspecs(cellids, batch, modelname, multi='mean'):
    """
    Load the saved modelspec(s) for each cell and wrap each in a ModelSpec.

    :param cellids: forwarded to load_batch_modelpaths
    :param batch: forwarded to load_batch_modelpaths
    :param modelname: forwarded to load_batch_modelpaths
    :param multi: what to do when a result directory holds more than one
        "modelspec*" file:
        'first' - use the first file (in sorted filename order),
        'all'   - keep the whole list,
        'mean'  - copy the first modelspec and replace every phi entry with
                  the 'mean' field of the ms.summary_stats entry whose key
                  ends with '--<phi key>'.
        Any other value logs a warning and behaves like 'first'.
    :return: list of ms.ModelSpec, one per path returned by
        load_batch_modelpaths
    """
    filepaths = load_batch_modelpaths(batch, modelname, cellids,
                                      eval_model=False)
    speclists = []
    for path in filepaths:
        path = path.replace('http://hyrax.ohsu.edu:3003/',
                            '/auto/data/nems_db/')
        if get_setting('NEMS_RESULTS_DIR').startswith("/Volumes"):
            path = path.replace('/auto/', '/Volumes/')
        # os.listdir returns entries in arbitrary order; sort so that
        # 'first' (and the template used by 'mean') is reproducible.
        mspaths = [os.path.join(path, fname)
                   for fname in sorted(os.listdir(path))
                   if fname.startswith("modelspec")]
        speclists.append([load_resource(p) for p in mspaths])

    modelspecs = []
    for speclist in speclists:
        if len(speclist) <= 1 or multi == 'first':
            this_mspec = speclist[0]
        elif multi == 'all':
            this_mspec = speclist
        elif multi == 'mean':
            stats = ms.summary_stats(speclist)
            this_mspec = copy.deepcopy(speclist[0])
            for module in this_mspec:
                phi = module['phi']
                for k in phi:
                    for s in stats:
                        # NOTE(review): matches on the key suffix only, so
                        # modules sharing a phi key name share the stat.
                        if s.endswith('--' + k):
                            phi[k] = stats[s]['mean']
                module['phi'] = phi
        else:
            log.warning(
                "Couldn't interpret <multi> parameter. Got: %s,\n"
                "Expected one of: 'mean, first, all'.\n"
                "Using first modelspec instead.", multi)
            this_mspec = speclist[0]

        modelspecs.append(ms.ModelSpec([this_mspec]))

    return modelspecs
def _extract_pop_channel(modelspec, d, fit_set_all, freeze_idx=[]): """ extract mini model from modelspec, just for channel d over the modules indexed by fit_set_all :param modelspec: :param fit_set_all: :param d: :param freeze_idx: list of module indices to freeze (move phi to fn_kwargs) :return: tmodelspec - subspace model """ # create modelspec with single population subspace filter dim_count = modelspec[fit_set_all[-1]+1]['phi']['coefficients'].shape[1] tmodelspec = ms.ModelSpec() for i in fit_set_all: m = copy.deepcopy(modelspec[i]) print(m['fn']) for k, v in m['phi'].items(): x = v.shape[0] if x >= dim_count: x1 = int(x/dim_count) * d x2 = int(x/dim_count) * (d+1) m['phi'][k] = v[x1:x2] if 'bounds' in m.keys(): m['bounds'][k] = (m['bounds'][k][0][x1:x2], m['bounds'][k][1][x1:x2]) if 'bank_count' in m['fn_kwargs'].keys(): m['fn_kwargs']['bank_count'] = 1 else: # single model-wide parameter, only fit for d==0 if d == 0: m['phi'][k] = v else: m['fn_kwargs'][k] = v # keep fixed for d>0 del m['phi'] del m['prior'] tmodelspec.append(m) return tmodelspec
def fit_population_channel(rec, modelspec, fit_set_all, fit_set_slice,
                           analysis_function=analysis.fit_basic,
                           metric=metrics.nmse,
                           fitter=scipy_minimize, fit_kwargs={}):
    """
    DEPRECATED?  Fit all the population channels at once, predicting only
    the responses inverted through the weight_channels of layer 2.

    :param rec: recording handed to _invert_slice
    :param modelspec: full modelspec; phi of the modules in fit_set_all is
        overwritten in place
    :param fit_set_all: indices of the modules to fit
    :param fit_set_slice: indices of the modules handed to _invert_slice
    :param analysis_function: fit routine, called with fitter, metric and
        fit_kwargs as keyword arguments
    :param metric: forwarded to analysis_function
    :param fitter: forwarded to analysis_function
    :param fit_kwargs: forwarded to analysis_function
    :return: modelspec
    """
    # invert cell-specific modules
    inverted_rec = _invert_slice(rec, modelspec, fit_set_slice)

    # shallow copies of the modules that are going to be fit
    sub_spec = ms.ModelSpec()
    for idx in fit_set_all:
        sub_spec.append(modelspec[idx].copy())

    sub_spec = analysis_function(inverted_rec, sub_spec, fitter=fitter,
                                 metric=metric, fit_kwargs=fit_kwargs)

    # push the fitted parameters back into the full model
    for idx in fit_set_all:
        fitted_phi = sub_spec[idx]['phi']
        for name, value in fitted_phi.items():
            modelspec[idx]['phi'][name] = value

    return modelspec
def init_relsat(rec, modelspec):
    """
    Set priors for the saturated_rectifier module from the statistics of
    the prediction that feeds it and of the response.

    The input modelspec is deep-copied and the copy is returned. If no
    saturated_rectifier module is found, a warning is logged and the copy
    is returned unchanged.

    :param rec: recording; evaluated through the modules that precede the
        rectifier and then masked with apply_mask()
    :param modelspec: modelspec containing a saturated_rectifier module
    :return: copy of modelspec with the 'base', 'amplitude', 'shift' and
        'kappa' priors of the rectifier module set
    """
    modelspec = copy.deepcopy(modelspec)

    target_i = find_module('saturated_rectifier', modelspec)
    if target_i is None:
        log.warning("No relsat module was found, can't initialize.")
        return modelspec

    if target_i == len(modelspec):
        fit_portion = modelspec.modules
    else:
        fit_portion = modelspec.modules[:target_i]

    # generate prediction from the modules preceding the rectifier
    rec = ms.ModelSpec(fit_portion).evaluate(rec).apply_mask()

    pred = rec['pred'].as_continuous().flatten()
    resp = rec['resp'].as_continuous().flatten()
    stdr = np.nanstd(resp)

    base = np.min(resp)
    amplitude = min(np.mean(resp) + stdr * 3, np.max(resp))
    shift = np.mean(pred) - 1.5 * np.nanstd(pred)
    kappa = 1

    # NOTE(review): the Normal prior uses shift as both mean and sd, so sd
    # is negative whenever shift < 0 -- confirm this is intended.
    new_priors = {
        'base': ('Exponential', {'beta': base}),
        'amplitude': ('Exponential', {'beta': amplitude}),
        'shift': ('Normal', {'mean': shift, 'sd': shift}),
        'kappa': ('Exponential', {'beta': kappa}),
    }

    # Priors belong to the rectifier module. Indexing the modelspec itself
    # with the string 'prior' does not address any module.
    modelspec[target_i].setdefault('prior', {}).update(new_priors)

    return modelspec
def fit_population_channel_fast(rec, modelspec, fit_set_all, fit_set_slice,
                                analysis_function=analysis.fit_basic,
                                metric=metrics.nmse,
                                fitter=scipy_minimize, fit_kwargs={}):
    """
    Fit the population model one subspace dimension at a time, against the
    residual left by the full model.

    For each dimension d a temporary model is built from one-channel copies
    of the modules in fit_set_all followed by copies of the modules in
    fit_set_slice with their parameters moved to fn_kwargs (not free). The
    temporary model is fit to resp - pred_full + pred_temp and the fitted
    values are written back into the matching rows of modelspec.

    :param rec: recording with 'resp'; 'pred' is produced by ms.evaluate
    :param modelspec: full modelspec, updated in place
    :param fit_set_all: indices of the modules to fit
    :param fit_set_slice: indices of the modules appended as fixed layers;
        the first one must have phi['coefficients'], whose second axis
        gives the number of dimensions
    :param analysis_function: fit routine, called with fitter, metric and
        fit_kwargs as keyword arguments
    :param metric: forwarded to analysis_function
    :param fitter: forwarded to analysis_function
    :param fit_kwargs: forwarded to analysis_function
    :return: modelspec
    """
    # guess at number of subspace dimensions
    dim_count = modelspec[fit_set_slice[0]]['phi']['coefficients'].shape[1]

    for d in range(dim_count):
        # fit each dim separately
        log.info('Updating dim %d/%d', d+1, dim_count)

        # create modelspec with single population subspace filter
        tmodelspec = ms.ModelSpec()
        for i in fit_set_all:
            m = copy.deepcopy(modelspec[i])
            for k, v in m['phi'].items():
                x = v.shape[0]
                if x >= dim_count:
                    # keep only the block of rows that belongs to dim d
                    x1 = int(x/dim_count) * d
                    x2 = int(x/dim_count) * (d+1)
                    m['phi'][k] = v[x1:x2]
                    if 'bank_count' in m['fn_kwargs'].keys():
                        m['fn_kwargs']['bank_count'] = 1
                else:
                    # single model-wide parameter, only fit for d==0
                    if d == 0:
                        m['phi'][k] = v
                    else:
                        m['fn_kwargs'][k] = v  # keep fixed for d>0
                        # NOTE(review): placement of these two deletes in
                        # the d>0 branch is assumed -- confirm. The
                        # write-back loop below reads tmodelspec[i]['phi'].
                        del m['phi']
                        del m['prior']
            tmodelspec.append(m)

        # append full-population layer as non-free parameters
        for i in fit_set_slice:
            m = copy.deepcopy(modelspec[i])
            for k, v in m['phi'].items():
                # just applies to wc module?
                if v.shape[1] >= dim_count:
                    # keep column d only (as a 2-D, single-column array)
                    m['fn_kwargs'][k] = v[:, [d]]
                else:
                    m['fn_kwargs'][k] = v
                # print(k)
                # print(m['fn_kwargs'][k])
            # everything now lives in fn_kwargs, so nothing here is fit
            del m['phi']
            del m['prior']
            tmodelspec.append(m)
        # print(tmodelspec[-1]['fn_kwargs'])

        # compute residual from prediction by the rest of the pop model
        trec = rec.copy()
        trec = ms.evaluate(trec, modelspec)
        r = trec['resp'].as_continuous()
        p = trec['pred'].as_continuous()
        trec = ms.evaluate(trec, tmodelspec)
        p2 = trec['pred'].as_continuous()
        # fit target: response minus the full prediction, plus what the
        # one-dimension model currently contributes
        trec['resp'] = trec['resp']._modified_copy(data=r-p+p2)

        # import pdb;
        # pdb.set_trace()
        tmodelspec = analysis_function(trec, tmodelspec, fitter=fitter,
                                       metric=metric, fit_kwargs=fit_kwargs)

        # write the fitted rows back into the full model
        for i in fit_set_all:
            for k, v in tmodelspec[i]['phi'].items():
                x = modelspec[i]['phi'][k].shape[0]
                if x >= dim_count:
                    x1 = int(x/dim_count) * d
                    x2 = int(x/dim_count) * (d+1)
                    modelspec[i]['phi'][k][x1:x2] = v
                else:
                    modelspec[i]['phi'][k] = v

        # for i in fit_set_slice:
        #     for k, v in tmodelspec[i]['phi'].items():
        #         if modelspec[i]['phi'][k].shape[0] >= dim_count:
        #             modelspec[i]['phi'][k][d,:] = v

        # print([modelspec.phi[f] for f in fit_set_all])

    return modelspec
def init_logsig(rec, modelspec):
    '''
    Initialize priors (and phi, when present) for the logistic_sigmoid
    module, based on the process described in the methods of
    Rabinowitz et al. 2014.

    Works on a deep copy of modelspec and returns the copy. When no
    logistic_sigmoid module is found a warning is logged and the copy is
    returned unchanged.
    '''
    # never touch the caller's modelspec
    modelspec = copy.deepcopy(modelspec)

    target_i = find_module('logistic_sigmoid', modelspec)
    if target_i is None:
        log.warning("No logsig module was found, can't initialize.")
        return modelspec

    if target_i == len(modelspec):
        fit_portion = modelspec.modules
    else:
        fit_portion = modelspec.modules[:target_i]

    # prediction of everything that comes before the sigmoid
    rec = ms.ModelSpec(fit_portion).evaluate(rec)
    pred = rec['pred'].as_continuous()
    resp = rec['resp'].as_continuous()

    # input range: mean +/- 3 sd, floored at zero
    pred_mu = np.nanmean(pred)
    pred_spread = np.nanstd(pred) * 3
    min_pred = pred_mu - pred_spread
    max_pred = pred_mu + pred_spread
    mean_pred = pred_mu
    if min_pred < 0:
        min_pred = 0
        mean_pred = (min_pred + max_pred) / 2
    pred_range = max_pred - min_pred

    # output range: mean +/- 3 sd, lower end must be >= 0
    resp_mu = np.nanmean(resp)
    resp_spread = np.nanstd(resp) * 3
    min_resp = max(resp_mu - resp_spread, 0)
    max_resp = resp_mu + resp_spread
    resp_range = max_resp - min_resp

    # Starting values; the priors below are centred on these so the
    # fitter/analysis can decide how to use them.
    base0 = min_resp + 0.05 * (resp_range)
    amplitude0 = resp_range
    shift0 = mean_pred
    kappa0 = pred_range
    log.info("Initial base,amplitude,shift,kappa=({}, {}, {}, {})".format(
        base0, amplitude0, shift0, kappa0))

    target = modelspec[target_i]
    if 'phi' in target:
        target['phi'].update({
            'base': base0,
            'amplitude': amplitude0,
            'shift': shift0,
            'kappa': kappa0,
        })
    target['prior'].update({
        'base': ('Exponential', {'beta': base0}),
        'amplitude': ('Exponential', {'beta': amplitude0}),
        'shift': ('Normal', {'mean': shift0, 'sd': pred_range}),
        'kappa': ('Exponential', {'beta': kappa0}),
    })
    target['bounds'] = {
        'base': (1e-15, None),
        'amplitude': (1e-15, None),
        'shift': (None, None),
        'kappa': (1e-15, None),
    }

    return modelspec
def init_dexp(rec, modelspec, nl_mode=2, override_target_i=None):
    """
    Choose initial phi for a double_exponential (dexp) module, applied
    after the preceding modules have been initialized.

    One set of (amplitude, base, kappa, shift) is computed per channel of
    the module's input signal, from the finite samples of that channel
    ("pred") and of the output signal ("resp"):

    nl_mode 1 (pre 11/29/18 models):
        base = min(resp), amp = nanstd(resp) * 3, shift = mean(pred),
        kappa = log(2 / (max(pred) - min(pred) + 1))
    nl_mode 2 (default):
        base = min(resp), amp = mean of resp where pred exceeds its 90th
        percentile (percentile lowered until at least 1% of samples pass),
        shift = mean(pred), kappa = log(2 / (std(pred) * 3)).
        Falls back to the mode 1 values when std(pred) == 0.
    nl_mode 3:
        base = min(resp) - nanstd(resp), amp = nanstd(resp) * 4,
        shift = mean(pred), kappa = log(1 / (std(pred) * 3))
    nl_mode 4:
        base = mean(resp) - nanstd(resp), amp = nanstd(resp) * 4,
        shift = 0, kappa = 0

    override_target_i should be an integer index into the modelspec. It
    replaces the lookup of the 'double_exponential' module, so the same
    procedure can be used for a similar nonlinearity module.

    Returns a deep copy of modelspec with phi of the target module set.
    Raises ValueError for any other nl_mode.
    """
    # work on a copy, leave the input modelspec alone
    modelspec = copy.deepcopy(modelspec)

    target_i = override_target_i
    if target_i is None:
        target_i = find_module('double_exponential', modelspec)
        if target_i is None:
            log.warning("No dexp module was found, can't initialize.")
            return modelspec

    if target_i == len(modelspec):
        fit_portion = modelspec.modules
    else:
        fit_portion = modelspec.modules[:target_i]

    # every preceding module needs a phi; fall back to the prior mean
    for idx, module in enumerate(fit_portion):
        if ('phi' not in module) and ('prior' in module):
            log.debug('Phi not found for module, using mean of prior: %s',
                      module)
            fit_portion[idx] = priors.set_mean_phi([module])[0]

    # prediction of everything that comes before the dexp
    rec = ms.ModelSpec(fit_portion).evaluate(rec)

    in_signal = modelspec[target_i]['fn_kwargs']['i']
    pchans = rec[in_signal].shape[0]
    amp = np.zeros([pchans, 1])
    base = np.zeros([pchans, 1])
    kappa = np.zeros([pchans, 1])
    shift = np.zeros([pchans, 1])
    out_signal = modelspec.meta.get('output_name', 'resp')

    for ch in range(pchans):
        resp = rec[out_signal].as_continuous()
        pred = rec[in_signal].as_continuous()[ch:(ch + 1), :]
        if resp.shape[0] == pchans:
            resp = resp[ch:(ch + 1), :]

        # keep only samples where both signals are finite
        finite = np.isfinite(resp) * np.isfinite(pred)
        resp = resp[finite]
        pred = pred[finite]

        # choose phi s.t. dexp starts as almost a straight line
        stdr = np.nanstd(resp)
        base[ch, 0] = np.min(resp)

        if nl_mode == 1:
            amp[ch, 0] = stdr * 3
            predrange = 2 / (np.max(pred) - np.min(pred) + 1)
            shift[ch, 0] = np.mean(pred)
            kappa[ch, 0] = np.log(predrange)

        elif (nl_mode == 2) & (np.std(pred) == 0):
            log.warning(
                'Init dexp: channel %d prediction has zero std, reverting to nl_mode==1',
                ch)
            amp[ch, 0] = stdr * 3
            predrange = 2 / (np.max(pred) - np.min(pred) + 1)
            shift[ch, 0] = np.mean(pred)
            kappa[ch, 0] = np.log(predrange)

        elif nl_mode == 2:
            # lower the percentile until at least 1% of samples exceed it
            above = np.zeros_like(pred, dtype=bool)
            pct = 91
            while (sum(above) < .01 * pred.shape[0]) and (pct > 1):
                pct -= 1
                above = pred > np.percentile(pred, pct)
            if np.sum(above) == 0:
                above = np.ones_like(pred, dtype=bool)

            if pct != 90:
                log.warning(
                    'Init dexp: Default for init mode 2 is to find mean '
                    'of responses for times where pred>pctile(pred,90). '
                    '\nNo times were found so this was lowered to '
                    'pred>pctile(pred,%d).', pct)

            amp[ch, 0] = resp[above].mean()
            predrange = 2 / (np.std(pred) * 3)
            if not np.isfinite(predrange):
                predrange = 1
            shift[ch, 0] = np.mean(pred)
            kappa[ch, 0] = np.log(predrange)

        elif nl_mode == 3:
            base[ch, 0] = np.min(resp) - stdr
            amp[ch, 0] = stdr * 4
            predrange = 1 / (np.std(pred) * 3)
            shift[ch, 0] = np.mean(pred)
            kappa[ch, 0] = np.log(predrange)

        elif nl_mode == 4:
            base[ch, 0] = np.mean(resp) - stdr * 1
            amp[ch, 0] = stdr * 4
            predrange = 2 / (np.std(pred) * 3)
            if not np.isfinite(predrange):
                predrange = 1
            kappa[ch, 0] = 0
            shift[ch, 0] = 0

        else:
            raise ValueError('nl mode = {} not valid'.format(nl_mode))

    modelspec[target_i]['phi'] = {
        'amplitude': amp,
        'base': base,
        'kappa': kappa,
        'shift': shift,
    }
    log.info("Init dexp: %s", modelspec[target_i]['phi'])

    return modelspec
def from_keywords(keyword_string, registry=None, rec=None, meta=None,
                  init_phi_to_mean_prior=True, input_name='stim',
                  output_name='resp'):
    '''
    Return a modelspec built by splitting keyword_string on '-' and
    replacing each keyword with the template(s) found in the keyword
    registry.

    :param keyword_string: keywords joined by '-'
    :param registry: keyword registry; defaults to nems.xforms.keyword_lib
    :param rec: optional recording. When given, the placeholders N, S and R
        inside keywords are replaced with the channel counts of
        rec[input_name], rec['state'] and rec[output_name], and a 'cellids'
        entry may be added to meta.
    :param meta: dict of metadata merged into modelspec.meta. A dict passed
        by the caller is updated in place ('cellids', 'input_name',
        'output_name'); when omitted a fresh dict is used for each call.
    :param init_phi_to_mean_prior: if True, phi of every module is
        initialized with priors.set_mean_phi
    :param input_name: signal name given to the first module whose input
        'i' is 'pred'
    :param output_name: name of the output signal, stored in meta
    :return: ms.ModelSpec
    :raises ValueError: if a keyword is not found in the registry
    '''
    if meta is None:
        # a mutable default would be shared (and mutated) across calls
        meta = {}

    if registry is None:
        from nems.xforms import keyword_lib
        registry = keyword_lib

    def _substitute(kw, token, replacement):
        # replace token in kw and log the substitution
        new_kw = kw.replace(token, replacement)
        log.info("kw: dynamically subbing %s with %s", kw, new_kw)
        return new_kw

    keywords = keyword_string.split('-')

    # Lookup the modelspec fragments in the registry
    modelspec = ms.ModelSpec()
    for kw in keywords:
        if rec is not None:
            # N: number of input channels (state channels for stategain).
            # Branch order matters: "stategain.N" also ends with ".N".
            if kw.startswith(("fir.Nx", "wc.Nx")):
                kw = _substitute(kw, ".N",
                                 ".{}".format(rec[input_name].nchans))
            elif kw.startswith("stategain.N"):
                kw = _substitute(kw, "stategain.N",
                                 "stategain.{}".format(rec['state'].nchans))
            elif kw.endswith(".N"):
                kw = _substitute(kw, ".N",
                                 ".{}".format(rec[input_name].nchans))
            elif kw.endswith(".cN"):
                kw = _substitute(kw, ".cN",
                                 ".c{}".format(rec[input_name].nchans))
            elif "xN" in kw:
                kw = _substitute(kw, "xN",
                                 "x{}".format(rec[input_name].nchans))

            # S: number of state channels
            if ".S" in kw:
                kw = _substitute(kw, ".S",
                                 ".{}".format(rec['state'].nchans))
            if "xS" in kw:
                kw = _substitute(kw, "xS",
                                 "x{}".format(rec['state'].nchans))

        # R: number of output channels
        if (rec is not None) and (".R" in kw):
            kw = _substitute(kw, ".R", ".{}".format(rec[output_name].nchans))
        elif (rec is not None) and ("xR" in kw):
            kw = _substitute(kw, "xR", "x{}".format(rec[output_name].nchans))
        else:
            log.info('kw: %s', kw)

        if registry.kw_head(kw) not in registry:
            raise ValueError("unknown keyword: {}".format(kw))

        templates = copy.deepcopy(registry[kw])
        if not isinstance(templates, list):
            templates = [templates]
        for d in templates:
            d['id'] = kw
            if init_phi_to_mean_prior:
                d = priors.set_mean_phi([d])[0]  # Inits phi for 1 module
            modelspec.append(d)

    # first module that takes input='pred' should take input_name instead.
    # can't hard-code in keywords, since we don't know which keyword will
    # be first. and can't assume that it will be module[0] because those
    # might be state manipulations
    for i in range(len(modelspec)):
        fn_kwargs = modelspec[i]['fn_kwargs']
        if fn_kwargs.get('i') == 'pred':
            log.info("Setting modelspec[%d] input to %s", i, input_name)
            fn_kwargs['i'] = input_name
            break

    # insert metadata, if provided
    if rec is not None:
        if 'cellids' in meta:
            # cellids list already exists. keep it.
            pass
        elif 'cellid' in meta:
            meta['cellids'] = [meta['cellid']]
        elif ((rec['resp'].shape[0] > 1) and
              isinstance(rec.meta['cellid'], list)):
            # guess cellids list from rec.meta
            meta['cellids'] = rec.meta['cellid']

    meta['input_name'] = input_name
    meta['output_name'] = output_name

    # for modelspec object, we know that meta must exist, so just update
    modelspec.meta.update(meta)

    if modelspec.meta.get('modelpath') is None:
        destination = get_default_savepath(modelspec)
        modelspec.meta['modelpath'] = destination
        modelspec.meta['figurefile'] = os.path.join(destination,
                                                    'figure.0000.png')

    return modelspec
def pick_best_phi(modelspec=None, est=None, val=None, est_list=None,
                  val_list=None, criterion='mse_fit',
                  metric_fn='nems.metrics.mse.nmse', jackknifed_fit=False,
                  keep_n=1, IsReload=False, **context):
    """
    For models with multiple fits (eg, based on multiple initial
    conditions), keep the keep_n fits with the lowest metric for each
    jackknife set, so an F x J modelspec is reduced to keep_n x J.

    The metric, one value per fit, is chosen as follows for each cell:
      * if modelspec.meta has a 'loss' entry, that vector is used;
      * else, if metric_fn is 'nems.metrics.mse.nmse' and criterion is
        'mse_fit', each fit is evaluated on est/val, passed through
        standard_correlation, and meta[criterion] is summed over channels;
      * otherwise the function named by metric_fn is looked up and applied
        to the evaluated est recording.
    Values are summed over cells and divided by the accumulated count.

    :param modelspec: should have fit_count > 1, otherwise {} is returned
    :param est: estimation recording (used when est_list is None)
    :param val: validation recording (used when est_list is None)
    :param est_list: one est recording per cell
    :param val_list: one val recording per cell; only read by the
        nmse/'mse_fit' path
    :param criterion: meta key used by the nmse path; also names the
        'rand_<criterion>' entry written to the new modelspec's meta
    :param metric_fn: dotted path of the metric function
    :param jackknifed_fit: unused
    :param keep_n: number of fits to keep per jackknife
    :param IsReload: if True, do nothing and return {}
    :param context: extra context for xforms compatibility; forwarded to
        the metric function in the generic path
    :return: {'modelspec': reduced modelspec, 'best_random_idx': index of
        the last fit selected for each jackknife}, or {} if nothing to do
    :raises RuntimeError: if any averaged metric value is exactly 0
    """
    if IsReload:
        return {}
    if modelspec.fit_count <= 1:
        return {}
    if est_list is None:
        est_list = [est]
        val_list = [val]

    jack_count = modelspec.jack_count
    fit_count = modelspec.fit_count
    best_idx = np.zeros(jack_count, dtype=int)
    modelspec.set_cell(0)
    use_nmse = (metric_fn == 'nems.metrics.mse.nmse') and \
        (criterion == 'mse_fit')

    # for each jackknife set, figure out best fit. save it to new_raw
    new_raw = np.zeros((modelspec.cell_count, keep_n, jack_count), dtype='O')
    for j in range(jack_count):
        x = None
        n = 0
        # support for multi-cell, len(est_list)>1 fits
        for cell_idx, this_est in enumerate(est_list):
            modelspec.cell_index = cell_idx
            modelspec.jack_index = j

            if 'loss' in modelspec.meta:
                # quick ranking: use a vector of loss values, one per
                # fitidx stored in meta
                if x is None:
                    x = modelspec.meta['loss']
                else:
                    x = x + modelspec.meta['loss']
                n = n + 1
                if j > 1:
                    log.info(
                        'Not supported yet! jackknife + multifit using tf loss to select'
                    )
                    # NOTE(review): drops into the debugger; this blocks
                    # (or aborts) unattended runs.
                    import pdb
                    pdb.set_trace()

            elif use_nmse:
                # for backwards compatibility, compute the metric named by
                # criterion via standard_correlation.
                this_val = val_list[cell_idx]
                cell_x = np.zeros(fit_count)
                for fitidx in range(fit_count):
                    traw = modelspec.raw[cell_idx, fitidx, j]
                    tmodelspec = ms.ModelSpec(traw)
                    this_est = tmodelspec.evaluate(this_est)
                    this_val = tmodelspec.evaluate(this_val)
                    tmodelspec = standard_correlation(est=this_est,
                                                      val=this_val,
                                                      modelspec=tmodelspec)
                    # sum across output channels (if more than one output)
                    cell_x[fitidx] += tmodelspec.meta[criterion].sum(axis=0)
                # number of output channels contributed by this cell
                n += tmodelspec.meta[criterion].shape[0]
                # accumulate over cells instead of restarting from zero
                x = cell_x if x is None else x + cell_x

            else:
                # unclear if this works
                fn = nems.utils.lookup_fn_at(metric_fn)
                tx = []
                for fitidx in range(fit_count):
                    traw = modelspec.raw[cell_idx, fitidx, j]
                    tmodelspec = ms.ModelSpec(traw)
                    this_est = tmodelspec.evaluate(this_est)
                    tx.append(fn(this_est, **context))
                n = n + tx[0].shape[0]
                tx = np.concatenate(tx, axis=1).mean(axis=0)
                x = tx if x is None else x + tx

        x = x / n
        if np.any(x == 0):
            raise RuntimeError('Error: At least one elements of modelspec.meta[\'loss\'] is 0. '
                               'This will cause this fn to pick that phi, but it\'s probably not. Bug in the code?')

        # pick the keep_n lowest values; each winner is overwritten with
        # the current maximum so it is not selected again
        tx = x.copy()
        for rank in range(keep_n):
            best_idx[j] = int(np.nanargmin(tx))
            new_raw[:, rank, j] = modelspec.raw[:, best_idx[j], j]
            log.info(
                'jack %d: %d/%d best phi (fit_idx=%d) has fit_metric=%.5f',
                j, rank + 1, keep_n, best_idx[j], tx[best_idx[j]])
            tx[best_idx[j]] = np.nanmax(tx)

    # carry the meta of the original first fit over to the new first fit
    for cell_index in range(new_raw.shape[0]):
        new_raw[cell_index, 0, 0][0]['meta'] = \
            modelspec.raw[cell_index, 0, 0][0]['meta'].copy()

    new_modelspec = ms.ModelSpec(new_raw)
    new_modelspec.set_cell(0)
    new_modelspec.set_jack(0)
    # x holds the per-fit metric of the last jackknife set
    new_modelspec.meta['rand_' + criterion] = x

    return {'modelspec': new_modelspec, 'best_random_idx': best_idx}
def predict_and_summarize_for_all_modelspec(
        modelspec=None, est=None, val=None, est_list=None, val_list=None,
        criterion=['r_test', 'se_test', 'mse_test', 'll_test'],
        metric_fn='nems.metrics.mse.nmse', jackknifed_fit=False, keep_n=1,
        IsReload=False, **context):
    """
    For models with multiple fits (eg, based on multiple initial
    conditions), evaluate every fit and store its summary statistics in
    modelspec.meta.

    Each fit is evaluated on est and val and passed through
    standard_correlation. For every name in criterion the per-fit value,
    summed over output channels and divided by the accumulated channel
    count, is stored in meta['rand_<name>']. When cell_count == 1 the
    per-channel values are stored in meta['rand_<name>_all'] as well.

    :param modelspec: should have fit_count > 1, otherwise {} is returned
    :param est: estimation recording (used when est_list is None)
    :param val: validation recording (used when est_list is None)
    :param est_list: one est recording per cell index
    :param val_list: one val recording per cell index
    :param criterion: list of meta keys produced by standard_correlation
    :param IsReload: if True, do nothing and return {}
    :param context: extra context stuff for xforms compatibility.
    :return: {'modelspec': modelspec} with the same fits and updated meta
    """
    if IsReload or modelspec.fit_count <= 1:
        return {}

    if est_list is None:
        est_list, val_list = [est], [val]

    if modelspec.jack_count > 1:
        log.info(
            'Not supported yet! jackknife + multifit using tf loss to select')
        import pdb
        pdb.set_trace()
    if modelspec.cell_count > 1:
        log.info(
            'Not tested for mega models. I think it will work but confirm. LAS'
        )
        import pdb
        pdb.set_trace()

    n_fits = modelspec.fit_count
    modelspec.set_cell(0)
    # "cell" is a misnomer here, it means site
    single_site = modelspec.cell_count == 1

    n_chans = val_list[0]['resp'].shape[0]
    totals = [np.zeros(n_fits) for _ in criterion]
    if single_site:
        per_chan = [np.zeros((n_chans, n_fits)) for _ in criterion]

    n = 0
    # outer loop: "cell_idx" (site index for megamodels); inner loop: fits
    for cell_idx, (this_est, this_val) in enumerate(zip(est_list, val_list)):
        for fitidx in range(n_fits):
            tmodelspec = ms.ModelSpec(modelspec.raw[cell_idx, fitidx, 0])
            this_est = tmodelspec.evaluate(this_est)
            this_val = tmodelspec.evaluate(this_val)
            tmodelspec = standard_correlation(est=this_est, val=this_val,
                                              modelspec=tmodelspec)

            # sum performance across output channels (if more than one)
            for c, name in enumerate(criterion):
                stat = tmodelspec.meta[name]
                totals[c][fitidx] += stat.sum(axis=0)
                if single_site:
                    per_chan[c][:, fitidx] = stat[:, 0]
        n += tmodelspec.meta[criterion[0]].shape[0]

    for c, name in enumerate(criterion):
        modelspec.meta['rand_' + name] = totals[c] / n
        if single_site:
            modelspec.meta['rand_' + name + '_all'] = per_chan[c]

    return {'modelspec': modelspec}
def fit_tf_init(modelspec, est: recording.Recording, nl_init: str = 'tf',
                IsReload: bool = False, **kwargs) -> dict:
    """Initialize a model with a two-stage tf fit.

    Stage 1 builds a temporary model from the leading modules, up to and
    including the last 'relu' or 'levelshift' module, leaving out
    rdt_gain, state_dc_gain and state_gain and freezing stp. If the final
    module of the temporary model is a levelshift whose level has one entry
    per output channel, it is set to the mean response. The temporary model
    is fit with fit_tf and copied back into modelspec.

    Stage 2 depends on nl_init:
      'skip'  - return after stage 1;
      'scipy' - run init_static_nl and return;
      other   - look at the last two modules and, for the first one that is
                a dexp, relu, logistic_sigmoid or saturated_rectifier,
                initialize it and run fit_tf with all other layers frozen.

    :return: {} when IsReload is set, otherwise a dict holding the
        modelspec (the second fit_tf call's return value is passed through)
    """
    if IsReload:
        return {}

    def first_substring_index(strings, substring):
        # position of the first string containing substring, else None
        return next(
            (pos for pos, text in enumerate(strings) if substring in text),
            None)

    # positions below are counted from the end of the model
    fn_names = [module['fn'] for module in modelspec]
    relu_idx = first_substring_index(reversed(fn_names), 'relu')
    lvl_idx = first_substring_index(reversed(fn_names), 'levelshift')
    last_pos = len(modelspec) - 1
    candidates = [pos for pos in (relu_idx, lvl_idx, last_pos)
                  if pos is not None]
    up_to_idx = last_pos - np.min(candidates)
    log.info('up_to_idx=%d (%s)', up_to_idx, modelspec[up_to_idx]['fn'])

    # make it an exclusive upper bound
    up_to_idx += 1

    # modules left out of / frozen during the first fit
    exclude = ['rdt_gain', 'state_dc_gain', 'state_gain']
    freeze = ['stp']

    init_idxes = [
        idx for idx, fn_name in enumerate(fn_names[:up_to_idx])
        if not any(sub in fn_name for sub in exclude)
    ]
    freeze_idxes = []

    # build the temporary model
    temp_ms = mslib.ModelSpec()
    log.info('Creating temporary model for init with:')
    for idx in init_idxes:
        # TODO: handle 'merge_channels'
        module = copy.deepcopy(modelspec[idx])
        log.info(f'{module["fn"]}')

        # fix levelshift if present (will always be the last module)
        if idx == init_idxes[-1] and 'levelshift' in module['fn']:
            output_name = modelspec.meta.get('output_name', 'resp')
            try:
                mean_resp = np.nanmean(est[output_name].as_continuous(),
                                       axis=1, keepdims=True)
            except NotImplementedError:
                # as_continuous only available for RasterizedSignal
                mean_resp = np.nanmean(
                    est[output_name].rasterize().as_continuous(),
                    axis=1, keepdims=True)
            if len(module['phi']['level'][:]) == len(mean_resp):
                log.info(
                    f'Fixing "{module["fn"]}" to: {mean_resp.flatten()[0]:.3f}')
                module['phi']['level'][:] = mean_resp

        temp_ms.append(module)
        if any(fr in module['fn'] for fr in freeze):
            freeze_idxes.append(len(temp_ms) - 1)

    log.info(
        'Running first init fit: model up to first lvl/relu without stp/gain.')
    log.debug('freeze_idxes: %s', freeze_idxes)
    filepath = Path(modelspec.meta['modelpath']) / 'init_part1'
    temp_ms = fit_tf(temp_ms, est, freeze_layers=freeze_idxes,
                     filepath=filepath, **kwargs)['modelspec']

    # copy the fitted modules back into the original modelspec
    for ms_idx, fitted_module in zip(init_idxes, temp_ms):
        modelspec[ms_idx] = fitted_module

    if nl_init == 'skip':
        return {'modelspec': modelspec}

    if nl_init == 'scipy':
        # pre-fit static NL if it exists
        modelspec = init_static_nl(est=est, modelspec=modelspec)['modelspec']
        # TODO : Initialize relu in some intelligent way?
        log.info('finished fit_tf_init, fit_idx=%d/%d',
                 modelspec.fit_index + 1, modelspec.fit_count)
        return {'modelspec': modelspec}

    # init the static nl
    init_static_nl_mapping = {
        'double_exponential': initializers.init_dexp,
        'relu': None,
        'logistic_sigmoid': initializers.init_logsig,
        'saturated_rectifier': initializers.init_relsat,
    }
    # first static nl found in the last two layers: initialize it, freeze
    # every other layer, fit, and return
    for idx, module in enumerate(modelspec[-2:], len(modelspec) - 2):
        for layer_name, init_fn in init_static_nl_mapping.items():
            if layer_name not in module['fn']:
                continue

            log.info(
                f'Initializing static nl "{module["fn"]}" at layer #{idx}')
            # relu has a custom init
            if layer_name == 'relu':
                module['phi']['offset'][:] = -0.1
            else:
                modelspec = init_fn(est, modelspec, nl_mode=4)

            static_nl_idx_not = list(set(range(len(modelspec))) - {idx})
            log.info(
                'Running second init fit: all frozen but static nl.')
            # don't overwrite the phis in the modelspec
            kwargs['use_modelspec_init'] = True
            filepath = Path(modelspec.meta['modelpath']) / 'init_part2'
            return fit_tf(modelspec, est, freeze_layers=static_nl_idx_not,
                          filepath=filepath, **kwargs)

    # no static nl to init
    return {'modelspec': modelspec}
def fit_tf_init(modelspec,
                est: recording.Recording,
                est_list: typing.Union[None, list] = None,
                nl_init: str = 'tf',
                IsReload: bool = False,
                isolate_NL: bool = False,
                skip_init: bool = False,
                up_to_idx=None,
                **kwargs) -> dict:
    """Initialize a model with a two-stage TensorFlow fit.

    Stage 1: build a temporary modelspec from modules ``0..up_to_idx``
    (by default ``up_to_idx`` is the last module whose ``fn`` contains 'relu'
    or 'levelshift'). Modules whose ``fn`` contains 'rdt_gain' are left out;
    modules whose ``fn`` contains 'stp' are included but frozen. If the last
    included module is a levelshift (or, failing that, a module is a
    ``state_dc`` module) with a matching channel count, its offset is set to
    the mean of the output signal. The temporary model is fit with ``fit_tf``
    (or ``fit_tf_iterate`` when ``modelspec.cell_count > 1``) and the fitted
    modules are copied back into ``modelspec``.

    Stage 2 (per cell index) depends on ``nl_init``:

    * ``'skip'``  -- nothing further is done.
    * ``'scipy'`` -- ``init_static_nl`` is called on the modelspec.
    * otherwise   -- the first static nonlinearity found in the last two
      modules (double_exponential, relu, logistic_sigmoid or
      saturated_rectifier) is initialized and ``fit_tf`` is run once more.

    :param modelspec: modelspec to initialize. Modified in place
        (``cell_index``, modules, ``meta``) and returned.
    :param est: estimation recording; used when ``est_list`` is None.
    :param est_list: recordings indexed by cell index. Defaults to ``[est]``.
    :param nl_init: stage-2 strategy, see above.
    :param IsReload: if True, do nothing and return ``{}``.
    :param isolate_NL: in stage 2, freeze every layer except the static NL.
    :param skip_init: if True, do nothing and return ``{}``.
    :param up_to_idx: index of the last module to include in stage 1.
    :param kwargs: forwarded to ``fit_tf`` / ``fit_tf_iterate``. An optional
        ``freeze_layers`` entry (any iterable of layer indices) is removed
        from ``kwargs``, merged with the automatically frozen layers for
        stage 1, and used as the frozen set for stage 2.
    :return: ``{'modelspec': modelspec}``, or ``{}`` when skipped.
    """
    if IsReload or skip_init:
        return {}

    def first_substring_index(strings, substring):
        """Index of the first string containing ``substring``, else None."""
        return next((i for i, string in enumerate(strings)
                     if substring in string), None)

    def _mean_response(rec, signal_name):
        """Per-channel nan-mean of a signal, as a column (keepdims=True)."""
        try:
            return np.nanmean(rec[signal_name].as_continuous(),
                              axis=1, keepdims=True)
        except NotImplementedError:
            # as_continuous only available for RasterizedSignal
            return np.nanmean(rec[signal_name].rasterize().as_continuous(),
                              axis=1, keepdims=True)

    modelspec.cell_index = 0
    if est_list is None:
        est_list = [est]

    # find the last 'relu' or 'levelshift', whichever is closer to the end
    ms_modules = [ms['fn'] for ms in modelspec]
    if up_to_idx is None:
        # indices below are counted from the END of the module list
        relu_idx = first_substring_index(reversed(ms_modules), 'relu')
        lvl_idx = first_substring_index(reversed(ms_modules), 'levelshift')
        # NOTE(review): the fallback len(modelspec) - 1 is also a reversed
        # index, so with no relu/levelshift only module 0 is initialized --
        # confirm that this is intended.
        _idxs = [
            i for i in [relu_idx, lvl_idx, len(modelspec) - 1]
            if i is not None
        ]
        up_to_idx = len(modelspec) - 1 - np.min(_idxs)
        log.info('up_to_idx=%d (%s)', up_to_idx, modelspec[up_to_idx]['fn'])

    # convert "last index to include" into an exclusive slice end
    up_to_idx += 1

    # exclude the following from the init
    exclude = ['rdt_gain']  # , 'state_dc_gain', 'state_gain', 'sdexp']
    # include these in the temp model but do not fit them
    freeze = ['stp']

    init_idxes = [
        idx for idx, fn_name in enumerate(ms_modules[:up_to_idx])
        if not any(sub in fn_name for sub in exclude)
    ]
    freeze_idxes = []

    # ---- build the temporary modelspec, one pass per cell index ----
    log.info('Creating temporary model for init')
    temp_ms = mslib.ModelSpec(cell_count=modelspec.cell_count)
    temp_ms.shared_count = modelspec.shared_count
    for cell_idx in range(modelspec.cell_count):
        modelspec.cell_index = cell_idx
        temp_ms.cell_index = cell_idx
        est = est_list[cell_idx]

        output_name = modelspec.meta.get('output_name', 'resp')
        mean_resp = _mean_response(est, output_name)

        # the mean response is written into at most one module per cell
        mean_added = False
        for idx in init_idxes:
            # TODO: handle 'merge_channels'
            ms = copy.deepcopy(modelspec[idx])
            log.info(f'{ms["fn"]}')

            # fix dc to mean_resp if present
            if (not mean_added) and (idx == init_idxes[-1]) and \
                    ('levelshift' in ms['fn']) and \
                    (len(ms['phi']['level']) == len(mean_resp)):
                log.info(
                    f'Fixing "{ms["fn"]}" to: {mean_resp.flatten()[0]:.3f}')
                ms['phi']['level'][:] = mean_resp
                mean_added = True
            elif (not mean_added) and ('state_dc' in ms['fn']) and \
                    (ms['phi']['d'].shape[0] == len(mean_resp)):
                log.info(f'Fixing "{ms["fn"]}[d][:,0]" to mean_resp')
                ms['phi']['d'][:, [0]] = mean_resp
                mean_added = True

            temp_ms.append(ms)
            if any(fr in ms['fn'] for fr in freeze):
                # index is relative to the temp model, not the full modelspec
                freeze_idxes.append(len(temp_ms) - 1)

    # important to reset cell_index to zero???
    est = est_list[0]
    temp_ms.cell_index = 0

    log.info(
        'Running first init fit: model up to last lvl/relu without stp/gain.')
    log.debug('freeze_idxes: %s', freeze_idxes)
    filepath = Path(modelspec.meta['modelpath']) / 'init_part1'

    # freeze_layers must not be passed twice, so take it out of kwargs and
    # merge it with the automatically frozen layers. Using a set union (rather
    # than list "+") accepts any iterable of indices: list, tuple, ndarray...
    force_freeze = kwargs.pop('freeze_layers', None)
    if force_freeze is not None:
        force_freeze = list(force_freeze)
        freeze_idxes = sorted(set(force_freeze) | set(freeze_idxes))

    # ---- stage 1 fit ----
    if modelspec.cell_count > 1:
        temp_ms = fit_tf_iterate(temp_ms,
                                 est,
                                 est_list=est_list,
                                 freeze_layers=freeze_idxes,
                                 filepath=filepath,
                                 **kwargs)['modelspec']
    else:
        temp_ms = fit_tf(temp_ms,
                         est,
                         est_list=est_list,
                         freeze_layers=freeze_idxes,
                         filepath=filepath,
                         **kwargs)['modelspec']

    # ---- copy results back and run stage 2, one pass per cell index ----
    for cell_idx in range(temp_ms.cell_count):
        log.info(
            "***********************************************************************************"
        )
        log.info(
            f"**** fit_tf_init, fit_index={modelspec.fit_index} cell_index={cell_idx} fitting output NL ****"
        )
        modelspec.cell_index = cell_idx
        temp_ms.cell_index = cell_idx
        est = est_list[cell_idx]

        # put back into original modelspec; meta is snapshotted first and
        # re-applied after the modules have been replaced
        meta_save = modelspec.meta.copy()
        for ms_idx, temp_ms_module in zip(init_idxes, temp_ms):
            modelspec[ms_idx] = temp_ms_module

        if modelspec.fit_count > 1:
            # keep one n_epochs / loss entry per fit index
            if modelspec.fit_index == 0:
                meta_save['n_epochs'] = np.zeros(modelspec.fit_count)
                meta_save['loss'] = np.zeros(modelspec.fit_count)
            meta_save['n_epochs'][modelspec.fit_index] = \
                temp_ms.meta['n_epochs']
            meta_save['loss'][modelspec.fit_index] = temp_ms.meta['loss']
        modelspec.meta.update(meta_save)

        if nl_init == 'skip':
            # to support multiple cell_indexes, don't return here,
            # wait until done looping instead
            pass
        elif nl_init == 'scipy':
            # pre-fit static NL if it exists
            _d = init_static_nl(est=est, modelspec=modelspec)
            modelspec = _d['modelspec']
            # TODO : Initialize relu in some intelligent way?

            log.info('finished fit_tf_init, fit_idx=%d/%d',
                     modelspec.fit_index + 1, modelspec.fit_count)
        else:
            # init the static nl
            init_static_nl_mapping = {
                'double_exponential': initializers.init_dexp,
                'relu': None,
                'logistic_sigmoid': initializers.init_logsig,
                'saturated_rectifier': initializers.init_relsat,
            }

            # Find the first occurrence of a static nl in the last two
            # layers; init it and refit. Only the first match is handled,
            # which is what stage2_pending enforces.
            stage2_pending = True
            for idx, ms in enumerate(modelspec[-2:], len(modelspec) - 2):
                for init_static_layer, init_fn in \
                        init_static_nl_mapping.items():
                    if stage2_pending and (init_static_layer in ms['fn']):
                        log.info(
                            f'Initializing static nl "{ms["fn"]}" at layer #{idx}'
                        )
                        # relu has a custom init
                        if init_static_layer == 'relu':
                            ms['phi']['offset'][:] = -0.1
                        else:
                            modelspec = init_fn(est, modelspec, nl_mode=4)

                        # choose which layers stay frozen for the second fit
                        if force_freeze is not None:
                            log.info(
                                f'Running second init fit: force_freeze: {force_freeze}.'
                            )
                            static_nl_idx_not = force_freeze
                        elif isolate_NL:
                            log.info(
                                'Running second init fit: all frozen but static nl.'
                            )
                            static_nl_idx_not = list(
                                set(range(len(modelspec))) - {idx})
                        elif modelspec.cell_count > 1:
                            log.info(
                                f'Titan model: freezing first {modelspec.shared_count} layers.'
                            )
                            static_nl_idx_not = list(
                                range(modelspec.shared_count))
                        else:
                            log.info(
                                'Running second init fit: not frozen but coarser tolerance.'
                            )
                            static_nl_idx_not = []

                        # don't overwrite the phis in the modelspec
                        kwargs['use_modelspec_init'] = True
                        filepath = Path(
                            modelspec.meta['modelpath']) / 'init_part2'
                        modelspec = fit_tf(modelspec,
                                           est,
                                           freeze_layers=static_nl_idx_not,
                                           filepath=filepath,
                                           **kwargs)['modelspec']
                        stage2_pending = False

    modelspec.cell_index = 0
    return {'modelspec': modelspec}
def _prefit_to_target(rec,
                      modelspec,
                      analysis_function,
                      target_module,
                      extra_exclude=[],
                      fitter=scipy_minimize,
                      metric=None,
                      fit_kwargs={}):
    """Rough-fit the modelspec up to the LAST occurrence of a target module.

    Modules after the last module whose ``fn`` contains one of the
    ``target_module`` strings are left out of a temporary modelspec (except
    'merge_channels' modules, which are always kept). The temporary modelspec
    is fit with ``analysis_function`` and the fitted modules are copied back
    into a deep copy of the input modelspec.

    :param rec: recording to fit. Its output signal (``meta['output_name']``,
        default 'resp') is used to initialize levelshift modules to the mean
        response.
    :param modelspec: input modelspec; not modified (a deep copy is used).
    :param analysis_function: called as
        ``analysis_function(rec, tmodelspec, fitter=..., [metric=...,]
        fit_kwargs=...)``.
    :param target_module: string, or list/tuple of strings, matched as
        substrings of each module's ``fn``.
    :param extra_exclude: substrings of ``fn`` identifying modules whose phi
        is held fixed (moved into ``fn_kwargs``) during the fit. Not mutated.
    :param fitter: forwarded to ``analysis_function``.
    :param metric: forwarded to ``analysis_function`` only when not None.
    :param fit_kwargs: forwarded to ``analysis_function``. Not mutated here.
    :return: the (copied) full modelspec with updated phi; returned unchanged
        if no target module is found.
    """
    # preserve input modelspec
    modelspec = copy.deepcopy(modelspec)

    # accept a single name or any list/tuple of names
    if isinstance(target_module, str):
        target_module = [target_module]
    elif not isinstance(target_module, list):
        target_module = list(target_module)

    # figure out last modelspec module to fit
    target_i = None
    for i, m in enumerate(modelspec.modules):
        if any(t in m['fn'] for t in target_module):
            # don't break. use last occurrence of target module
            target_i = i + 1

    if not target_i:
        log.info(
            "target_module: {} not found in modelspec.".format(target_module))
        return modelspec
    log.info("target_module: {0} found at modelspec[{1}].".format(
        target_module, target_i - 1))

    # identify any excluded modules and freeze them in the temp modelspec
    # that will be fit here
    exclude_idx = []
    tmodelspec = ms.ModelSpec()
    for i in range(len(modelspec)):
        m = copy.deepcopy(modelspec[i])

        for fn in extra_exclude:
            if fn in m['fn']:
                if m.get('phi') is None:
                    m = priors.set_mean_phi([m])[0]  # Inits phi
                    log.info('Mod %d (%s) fixing phi to prior mean', i, fn)
                else:
                    log.info('Mod %d (%s) fixing phi', i, fn)

                # moving phi into fn_kwargs makes the values constants
                m['fn_kwargs'].update(m['phi'])
                del m['phi']
                m.pop('prior', None)
                exclude_idx.append(i)
                # a module is frozen once; a second matching name would find
                # phi already deleted
                break

        if 'relu' in m['fn']:
            log.info('found relu')
        elif 'levelshift' in m['fn']:
            output_name = modelspec.meta.get('output_name', 'resp')
            try:
                mean_resp = np.nanmean(rec[output_name].as_continuous(),
                                       axis=1,
                                       keepdims=True)
            except NotImplementedError:
                # as_continuous only available for RasterizedSignal
                mean_resp = np.nanmean(
                    rec[output_name].rasterize().as_continuous(),
                    axis=1,
                    keepdims=True)
            # mean_resp is a column (keepdims=True); log a true scalar so
            # "%.3f" never has to convert a 1-element array
            log.info('Mod %d (%s) initializing level to %s mean %.3f', i,
                     m['fn'], output_name, mean_resp.flatten()[0])
            log.info('resp has %d channels', len(mean_resp))
            m['phi']['level'][:] = mean_resp

        if (i < target_i) or ('merge_channels' in m['fn']):
            tmodelspec.append(m)

    # fit the subset of modules
    if metric is None:
        tmodelspec = analysis_function(rec,
                                       tmodelspec,
                                       fitter=fitter,
                                       fit_kwargs=fit_kwargs)
    else:
        tmodelspec = analysis_function(rec,
                                       tmodelspec,
                                       fitter=fitter,
                                       metric=metric,
                                       fit_kwargs=fit_kwargs)
    if isinstance(tmodelspec, list):
        # backward compatibility
        tmodelspec = tmodelspec[0]

    # reassemble the full modelspec with updated phi values from tmodelspec;
    # frozen (excluded) modules keep their original definition
    excluded = set(exclude_idx)
    for i in range(target_i):
        if i not in excluded:
            modelspec[i] = tmodelspec[i]

    return modelspec
def fit_population_slice(rec,
                         modelspec,
                         slice=0,
                         fit_set=None,
                         analysis_function=analysis.fit_basic,
                         metric=metrics.nmse,
                         fitter=scipy_minimize,
                         fit_kwargs={}):
    """Fit one response channel ("slice") of a population model.

    Modified from prefit_mod_subset. For every module selected by
    ``fit_set``, only the phi rows belonging to channel ``slice`` are kept in
    a temporary modelspec; all other modules have their phi frozen (moved
    into ``fn_kwargs``). The temporary model is fit against a recording that
    contains only that response channel. If the error did not get worse, the
    fitted rows are written back into a deep copy of the input modelspec.

    :param rec: recording with a 'resp' signal.
    :param modelspec: input modelspec; not modified (a deep copy is used).
    :param slice: int, response channel to fit.
    :param fit_set: list of module indices, or list of strings matched as
        substrings of each module's ``fn``.
    :param analysis_function: called as ``analysis_function(rec, modelspec,
        fitter=..., metric=..., fit_kwargs=...)``.
    :param metric: called as ``metric(rec)`` to score before/after the fit.
    :param fitter: forwarded to ``analysis_function``.
    :param fit_kwargs: forwarded to ``analysis_function``. Not mutated here.
    :return: copy of ``modelspec``, updated for ``slice`` if the fit helped.
    :raises ValueError: if ``slice`` is past the last response channel, if
        ``fit_set`` is None, or if a parameter cannot be sliced.
    """
    # preserve input modelspec
    modelspec = copy.deepcopy(modelspec)

    slice_count = rec['resp'].shape[0]
    # valid channels are 0 .. slice_count - 1
    if slice >= slice_count:
        raise ValueError("Slice %d >= slice_count %d" % (slice, slice_count))

    if fit_set is None:
        raise ValueError("fit_set list of module indices must be specified")

    if len(fit_set) > 0 and isinstance(fit_set[0], (int, np.integer)):
        fit_idx = [int(i) for i in fit_set]
    else:
        # each module is listed once, even if several names match it
        fit_idx = [
            i for i, m in enumerate(modelspec)
            if any(fn in m['fn'] for fn in fit_set)
        ]

    if len(fit_idx) == 0:
        log.info('No modules matching fit_set for slice fit')
        return modelspec
    fit_idx_set = set(fit_idx)

    # build the temp modelspec: fitted modules are reduced to the sliced
    # channel, everything else is frozen below
    tmodelspec = ms.ModelSpec()
    # module index -> {phi key -> row indices that were kept}
    sliceinfo = {}
    for i, m in enumerate(modelspec):
        m = copy.deepcopy(m)
        # need to have phi in place
        if not m.get('phi'):
            log.info('Initializing phi for module %d (%s)', i, m['fn'])
            m = priors.set_mean_phi([m])[0]  # Inits phi

        if i in fit_idx_set:
            s = {}
            # read bank_count once per module: it is reset to 1 below and
            # must not affect the slicing of later phi keys
            has_banks = 'bank_count' in m['fn_kwargs']
            bank_count = m['fn_kwargs']['bank_count'] if has_banks else None
            for key, value in list(m['phi'].items()):
                log.debug('Slicing %d (%s) key %s chan %d for fit', i,
                          m['fn'], key, slice)

                # keep only sliced channel(s)
                if has_banks:
                    filters_per_bank = int(value.shape[0] / bank_count)
                    slices = np.arange(slice * filters_per_bank,
                                       (slice + 1) * filters_per_bank)
                    m['phi'][key] = value[slices, ...]
                    s[key] = slices
                elif value.shape[0] == slice_count:
                    m['phi'][key] = value[[slice], ...]
                    s[key] = [slice]
                else:
                    raise ValueError("Not sure how to slice %s %s" %
                                     (m['fn'], key))
            if has_banks and s:
                # the temp module now holds a single bank
                m['fn_kwargs']['bank_count'] = 1

            # record info about how sliced this module parameter
            sliceinfo[i] = s

        tmodelspec.append(m)

    # freeze every module that is not being fit
    for i in range(len(modelspec)):
        if i in fit_idx_set:
            continue
        m = tmodelspec[i]
        log.debug('Freezing phi for module %d (%s)', i, m['fn'])
        m['fn_kwargs'].update(m['phi'])
        m['phi'] = {}
        tmodelspec[i] = m

    # generate temp recording with only responses of interest
    temp_rec = rec.copy()
    slice_chans = [temp_rec['resp'].chans[slice]]
    temp_rec['resp'] = temp_rec['resp'].extract_channels(slice_chans)

    # pre-evaluate the modules that precede the first fitted module so they
    # are not recomputed during the fit
    first_idx = min(fit_idx)
    if first_idx > 0:
        temp_rec = ms.evaluate(temp_rec, tmodelspec, stop=first_idx)
        tmodelspec = tmodelspec.copy()
        tmodelspec.fast_eval_on(rec=temp_rec, subset=fit_idx)

    # fit the subset of modules
    temp_rec = ms.evaluate(temp_rec, tmodelspec)
    error_before = metric(temp_rec)
    tmodelspec = analysis_function(temp_rec,
                                   tmodelspec,
                                   fitter=fitter,
                                   metric=metric,
                                   fit_kwargs=fit_kwargs)
    tmodelspec.fast_eval_off()
    temp_rec = ms.evaluate(temp_rec, tmodelspec)
    error_after = metric(temp_rec)

    dError = error_before - error_after
    if dError < 0:
        log.info(
            "dError (%.6f - %.6f) = %.6f worse. not updating modelspec",
            error_before, error_after, dError)
    else:
        log.info(
            "dError (%.6f - %.6f) = %.6f better. updating modelspec",
            error_before, error_after, dError)

        # reassemble the full modelspec with updated phi values from
        # tmodelspec (module indices are the same in both modelspecs)
        for mod_idx, s in sliceinfo.items():
            m = copy.deepcopy(modelspec[mod_idx])
            # need to have phi in place
            if not m.get('phi'):
                log.info('Intializing phi for module %d (%s)', mod_idx,
                         m['fn'])
                m = priors.set_mean_phi([m])[0]  # Inits phi
            for key, value in tmodelspec[mod_idx]['phi'].items():
                # write the fitted rows back where they were sliced from
                m['phi'][key][s[key], :] = value
            modelspec[mod_idx] = m

    return modelspec
def pick_best_phi(modelspec=None,
                  est=None,
                  val=None,
                  criterion='mse_fit',
                  metric_fn='nems.metrics.mse.nmse',
                  jackknifed_fit=False,
                  keep_n=1,
                  IsReload=False,
                  **context):
    """Keep the best ``keep_n`` fits of a multi-fit modelspec.

    For models with multiple fits (eg, based on multiple initial
    conditions), find the best prediction for the recording provided
    (presumably est data, though possibly something held out).

    For jackknifed fits, pick the best fit for each jackknife set, so an
    F x J modelspec is reduced in size to keep_n x J. Models are tested with
    est data for that jackknife. Lower metric values are treated as better.

    :param modelspec: should have fit_count>0
    :param est: view_count should match fit_count, ie, after
        generate_prediction is called
    :param val: validation recording, passed to generate_prediction and
        standard_correlation.
    :param criterion: key of ``modelspec.meta`` read as the fit metric in
        the default (backwards compatible) path; also used to name the
        ``'rand_' + criterion`` meta entry of the result.
    :param metric_fn: dotted path of a metric function. Any value other
        than the default (or any criterion other than 'mse_fit') makes the
        function evaluate ``metric_fn(view, **context)`` on each est view.
    :param jackknifed_fit: passed to generate_prediction.
    :param keep_n: number of fits to keep per jackknife set.
    :param IsReload: if True, do nothing and return ``{}``.
    :param context: extra context stuff for xforms compatibility.
    :return: dict with 'modelspec' (fit_count == keep_n) and
        'best_random_idx' (one fit index per jackknife set), or ``{}``.
    """
    if IsReload:
        return {}

    # generate prediction for each jack and fit
    new_est, new_val = generate_prediction(est,
                                           val,
                                           modelspec,
                                           jackknifed_fit=jackknifed_fit)
    jack_count = modelspec.jack_count
    fit_count = modelspec.fit_count
    best_idx = np.zeros(jack_count, dtype=int)
    new_raw = np.zeros((1, keep_n, jack_count), dtype='O')

    # for each jackknife set, figure out best fit
    for j in range(jack_count):
        # views are ordered fit-major: view = fit * jack_count + jack
        view_range = [i * jack_count + j for i in range(fit_count)]
        this_est = new_est.view_subset(view_range)
        this_modelspec = modelspec.copy(jack_index=j)

        if metric_fn == 'nems.metrics.mse.nmse' and criterion == 'mse_fit':
            # for backwards compatibility, run the below code to compute
            # metric specified by criterion.
            new_modelspec = standard_correlation(est=this_est,
                                                 val=new_val,
                                                 modelspec=this_modelspec)
            # average performance across output channels (if more than one
            # output)
            x = np.mean(new_modelspec.meta[criterion], axis=0)
        else:
            fn = nems.utils.lookup_fn_at(metric_fn)
            x = [fn(e, **context) for e in this_est.views()]

        # Work on a float ndarray copy: x may be a plain list (custom metric
        # path), and chosen fits are masked with +inf so they can never be
        # selected again, even when metric values tie.
        tx = np.array(x, dtype=float)
        for n in range(keep_n):
            # NOTE(review): best_idx[j] is overwritten on every pass, so for
            # keep_n > 1 it ends up holding the keep_n-th best fit -- confirm
            # callers of 'best_random_idx' expect that.
            best_idx[j] = int(np.argmin(tx))
            new_raw[0, n, j] = modelspec.raw[0, best_idx[j], j]
            log.info(
                'jack %d: %d/%d best phi (fit_idx=%d) has fit_metric=%.5f', j,
                n + 1, keep_n, best_idx[j], tx[best_idx[j]])
            tx[best_idx[j]] = np.inf

    # meta is stored on the first module of the first fit/jack
    new_raw[0, 0, 0][0]['meta'] = modelspec.meta.copy()
    new_modelspec = ms.ModelSpec(new_raw)
    # metric values of all candidate fits (from the last jackknife set)
    new_modelspec.meta['rand_' + criterion] = x

    return {'modelspec': new_modelspec, 'best_random_idx': best_idx}