def fit_iteratively(
        data, modelspec,
        cost_function=basic_cost,
        fitter=coordinate_descent, evaluator=ms.evaluate,
        segmentor=nems.segmentors.use_all_data,
        mapper=nems.fitters.mappers.simple_vector,
        metric=lambda data: nems.metrics.api.nmse(data, 'pred', 'resp'),
        metaname='fit_basic', fit_kwargs={}, module_sets=None,
        invert=False, tolerances=None, tol_iter=100, fit_iter=20,
        ):
    '''
    Required Arguments:
     data          A recording object
     modelspec     A modelspec object

    Optional Arguments:
     fitter        A function of (sigma, costfn) that tests various points
                   in fitspace (i.e. sigmas) using the cost function costfn,
                   and hopefully returns a better sigma after some time.
     mapper        A class that has two methods, pack and unpack, which
                   define the mapping between modelspecs and a fitter's
                   fitspace.
     segmentor     A function that selects a subset of the data during the
                   fitting process. This is NOT the same as est/val data
                   splits.
     metric        A function of a Recording that returns an error value
                   that is to be minimized.
     module_sets   A nested list specifying which model indices should be fit.
                   Overall iteration will occur len(module_sets) many times.
                   ex: [[0], [1, 3], [0, 1, 2, 3]]
     invert        Boolean. Causes module_sets to specify the model indices
                   that should *not* be fit.

    Returns
    A list containing a single modelspec, which has the best parameters found
    by this fitter.
    '''
    if module_sets is None:
        module_sets = [[i] for i in range(len(modelspec))]
    if tolerances is None:
        tolerances = [1e-6]

    start_time = time.time()
    ms.fit_mode_on(modelspec)

    # Ensure that phi exists for all modules; choose prior mean if not found
    for i, m in enumerate(modelspec):
        if not m.get('phi'):
            log.debug('Phi not found for module, using mean of prior: {}'
                      .format(m))
            m = nems.priors.set_mean_phi([m])[0]  # Inits phi for 1 module
            modelspec[i] = m

    error = np.inf
    for tol in tolerances:
        log.info("Fitting all subsets with tolerance: %.2E", tol)
        fit_kwargs.update({'tolerance': tol, 'max_iter': fit_iter})
        error_reduction = np.inf
        i = 0
        while (error_reduction >= tol) and (i < tol_iter):
            for subset in module_sets:
                improved_modelspec = _module_set_loop(
                        subset, data, modelspec, cost_function, fitter,
                        mapper, segmentor, evaluator, metric, fit_kwargs)
                new_error = cost_function.error
                error_reduction = error - new_error
                error = new_error
                log.debug("Error reduction was: %.6E", error_reduction)
            i += 1

    elapsed_time = (time.time() - start_time)

    # TODO: Should this maybe be moved to a higher level
    #       so it applies to ALL the fitters?
    ms.fit_mode_off(improved_modelspec)
    ms.set_modelspec_metadata(improved_modelspec, 'fitter', metaname)
    ms.set_modelspec_metadata(improved_modelspec, 'fit_time', elapsed_time)
    results = [copy.deepcopy(improved_modelspec)]

    return results
def fit_module_sets(
        data, modelspec,
        cost_function=basic_cost,
        evaluator=ms.evaluate,
        segmentor=nems.segmentors.use_all_data,
        mapper=nems.fitters.mappers.simple_vector,
        metric=lambda data: nems.metrics.api.nmse(data, 'pred', 'resp'),
        fitter=coordinate_descent, fit_kwargs={}, metaname='fit_module_sets',
        module_sets=None, invert=False, tolerance=1e-4, max_iter=1000):
    '''
    Required Arguments:
     data          A recording object
     modelspec     A modelspec object

    Optional Arguments:
     fitter        A function of (sigma, costfn) that tests various points
                   in fitspace (i.e. sigmas) using the cost function costfn,
                   and hopefully returns a better sigma after some time.
     module_sets   A nested list specifying which model indices should be fit.
                   Overall iteration will occur len(module_sets) many times.
                   ex: [[0], [1, 3], [0, 1, 2, 3]]
     invert        Boolean. Causes module_sets to specify the model indices
                   that should *not* be fit.

    Returns
    A list containing a single modelspec, which has the best parameters found
    by this fitter.
    '''
    if module_sets is None:
        module_sets = [[i] for i in range(len(modelspec))]
    fit_kwargs.update({'tolerance': tolerance, 'max_iter': max_iter})

    # Ensure that phi exists for all modules; choose prior mean if not found
    for i, m in enumerate(modelspec):
        if not m.get('phi'):
            log.debug('Phi not found for module, using mean of prior: {}'
                      .format(m))
            m = nems.priors.set_mean_phi([m])[0]  # Inits phi for 1 module
            modelspec[i] = m

    if invert:
        module_sets = _invert_subsets(modelspec, module_sets)

    ms.fit_mode_on(modelspec)
    start_time = time.time()

    log.info("Fitting all subsets with tolerance: %.2E", tolerance)
    for subset in module_sets:
        improved_modelspec = _module_set_loop(
                subset, data, modelspec, cost_function, fitter,
                mapper, segmentor, evaluator, metric, fit_kwargs)

    elapsed_time = (time.time() - start_time)

    # TODO: Should this maybe be moved to a higher level
    #       so it applies to ALL the fitters?
    ms.fit_mode_off(improved_modelspec)
    ms.set_modelspec_metadata(improved_modelspec, 'fitter', metaname)
    ms.set_modelspec_metadata(improved_modelspec, 'fit_time', elapsed_time)
    results = [copy.deepcopy(improved_modelspec)]

    return results
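
# Hedged sketch (not part of the original module): `_invert_subsets` is
# referenced above but not shown in this section, so the helper below only
# assumes the behavior implied by the docstring, i.e. that invert=True makes
# each subset select the complement of the listed module indices.
def _example_invert_subsets():
    def invert_subsets(n_modules, module_sets):
        # complement of each subset over all module indices
        return [[i for i in range(n_modules) if i not in subset]
                for subset in module_sets]

    # with 4 modules, [[0], [1, 3]] -> [[1, 2, 3], [0, 2]]
    return invert_subsets(4, [[0], [1, 3]])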
def init_logsig(rec, modelspec):
    '''
    Initialization of priors for logistic_sigmoid, based on process described
    in methods of Rabinowitz et al. 2014.
    '''
    # preserve input modelspec
    modelspec = copy.deepcopy(modelspec)

    target_i = find_module('logistic_sigmoid', modelspec)
    if target_i is None:
        log.warning("No logsig module was found, can't initialize.")
        return modelspec

    if target_i == len(modelspec):
        fit_portion = modelspec
    else:
        fit_portion = modelspec[:target_i]

    # generate prediction from the modules preceding the logsig
    ms.fit_mode_on(fit_portion)
    rec = ms.evaluate(rec, fit_portion)
    ms.fit_mode_off(fit_portion)

    pred = rec['pred'].as_continuous()
    resp = rec['resp'].as_continuous()

    mean_pred = np.nanmean(pred)
    min_pred = np.nanmean(pred) - np.nanstd(pred) * 3
    max_pred = np.nanmean(pred) + np.nanstd(pred) * 3
    if min_pred < 0:
        min_pred = 0
        mean_pred = (min_pred + max_pred) / 2

    pred_range = max_pred - min_pred
    min_resp = max(np.nanmean(resp) - np.nanstd(resp) * 3, 0)  # must be >= 0
    max_resp = np.nanmean(resp) + np.nanstd(resp) * 3
    resp_range = max_resp - min_resp

    # Rather than setting a hard value for initial phi,
    # set the prior distributions and let the fitter/analysis
    # decide how to use it.
    base0 = min_resp + 0.05 * (resp_range)
    amplitude0 = resp_range
    shift0 = mean_pred
    kappa0 = pred_range
    log.info("Initial base,amplitude,shift,kappa=({}, {}, {}, {})".format(
            base0, amplitude0, shift0, kappa0))

    base = ('Exponential', {'beta': base0})
    amplitude = ('Exponential', {'beta': amplitude0})
    shift = ('Normal', {'mean': shift0, 'sd': pred_range})
    kappa = ('Exponential', {'beta': kappa0})

    modelspec[target_i]['prior'].update({
            'base': base, 'amplitude': amplitude, 'shift': shift,
            'kappa': kappa})

    modelspec[target_i]['bounds'] = {
            'base': (1e-15, None),
            'amplitude': (1e-15, None),
            'shift': (None, None),
            'kappa': (1e-15, None)}

    return modelspec
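
# Standalone numeric sketch of the prior statistics computed by init_logsig
# above, on synthetic data (variable names here are illustrative, not part
# of the original module): the four hyperparameters come straight from the
# mean/std of the prediction and response.
def _example_logsig_priors():
    import numpy as np
    rng = np.random.RandomState(0)
    pred = rng.randn(1000)
    resp = np.abs(rng.randn(1000))

    mean_pred = np.nanmean(pred)
    min_pred = mean_pred - np.nanstd(pred) * 3
    max_pred = mean_pred + np.nanstd(pred) * 3
    if min_pred < 0:
        min_pred = 0
        mean_pred = (min_pred + max_pred) / 2
    pred_range = max_pred - min_pred

    min_resp = max(np.nanmean(resp) - np.nanstd(resp) * 3, 0)
    max_resp = np.nanmean(resp) + np.nanstd(resp) * 3
    resp_range = max_resp - min_resp

    return {'base0': min_resp + 0.05 * resp_range,
            'amplitude0': resp_range,
            'shift0': mean_pred,
            'kappa0': pred_range}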
def fit_basic(data, modelspec,
              fitter=scipy_minimize, cost_function=None,
              segmentor=nems.segmentors.use_all_data,
              mapper=nems.fitters.mappers.simple_vector,
              metric=lambda data: metrics.nmse(data, 'pred', 'resp'),
              metaname='fit_basic', fit_kwargs={}, require_phi=True):
    '''
    Required Arguments:
     data          A recording object
     modelspec     A modelspec object

    Optional Arguments:
     fitter        A function of (sigma, costfn) that tests various points
                   in fitspace (i.e. sigmas) using the cost function costfn,
                   and hopefully returns a better sigma after some time.
     mapper        A class that has two methods, pack and unpack, which
                   define the mapping between modelspecs and a fitter's
                   fitspace.
     segmentor     A function that selects a subset of the data during the
                   fitting process. This is NOT the same as est/val data
                   splits.
     metric        A function of a Recording that returns an error value
                   that is to be minimized.

    Returns
    A list containing a single modelspec, which has the best parameters found
    by this fitter.
    '''
    start_time = time.time()

    if cost_function is None:
        # Use the cost function defined in this module by default
        cost_function = basic_cost

    if require_phi:
        # Ensure that phi exists for all modules;
        # choose prior mean if not found
        for i, m in enumerate(modelspec):
            if not m.get('phi'):
                log.debug('Phi not found for module, using mean of prior: %s',
                          m)
                m = nems.priors.set_mean_phi([m])[0]  # Inits phi for 1 module
                modelspec[i] = m

    ms.fit_mode_on(modelspec)

    # Create the mapper object that translates to and from modelspecs.
    # It has two methods that, when defined as mathematical functions, are:
    #    .pack(modelspec) -> fitspace_point
    #    .unpack(fitspace_point) -> modelspec
    packer, unpacker = mapper(modelspec)

    # A function to evaluate the modelspec on the data
    evaluator = nems.modelspec.evaluate

    my_cost_function = cost_function
    my_cost_function.counter = 0

    # Freeze everything but sigma, since that's all the fitter should be
    # updating.
    cost_fn = partial(my_cost_function,
                      unpacker=unpacker, modelspec=modelspec,
                      data=data, segmentor=segmentor, evaluator=evaluator,
                      metric=metric)

    # get initial sigma value representing some point in the fit space
    sigma = packer(modelspec)

    # Results should be a list of modelspecs
    # (might only be one in list, but still should be packaged as a list)
    improved_sigma = fitter(sigma, cost_fn, **fit_kwargs)
    improved_modelspec = unpacker(improved_sigma)

    elapsed_time = (time.time() - start_time)

    # TODO: Should this maybe be moved to a higher level
    #       so it applies to ALL the fitters?
    ms.fit_mode_off(improved_modelspec)
    ms.set_modelspec_metadata(improved_modelspec, 'fitter', metaname)
    ms.set_modelspec_metadata(improved_modelspec, 'fit_time', elapsed_time)
    results = [copy.deepcopy(improved_modelspec)]

    return results
def fit_basic(data, modelspec,
              fitter=scipy_minimize, cost_function=None,
              segmentor=nems.segmentors.use_all_data,
              mapper=nems.fitters.mappers.simple_vector,
              metric=None,
              metaname='fit_basic', fit_kwargs={}, require_phi=True):
    '''
    Required Arguments:
     data          A recording object
     modelspec     A modelspec object

    Optional Arguments:
     fitter        A function of (sigma, costfn) that tests various points
                   in fitspace (i.e. sigmas) using the cost function costfn,
                   and hopefully returns a better sigma after some time.
     mapper        A class that has two methods, pack and unpack, which
                   define the mapping between modelspecs and a fitter's
                   fitspace.
     segmentor     A function that selects a subset of the data during the
                   fitting process. This is NOT the same as est/val data
                   splits.
     metric        A function of a Recording that returns an error value
                   that is to be minimized.

    Returns
    A list containing a single modelspec, which has the best parameters found
    by this fitter.
    '''
    start_time = time.time()

    modelspec = copy.deepcopy(modelspec)
    output_name = modelspec.meta.get('output_name', 'resp')
    if metric is None:
        metric = lambda data: metrics.nmse(data, 'pred', output_name)

    if cost_function is None:
        # Use the cost function defined in this module by default
        cost_function = basic_cost

    if require_phi:
        # Ensure that phi exists for all modules;
        # choose prior mean if not found
        for i, m in enumerate(modelspec.modules):
            if ('phi' not in m.keys()) and ('prior' in m.keys()):
                log.debug('Phi not found for module, using mean of prior: %s',
                          m)
                m = nems.priors.set_mean_phi([m])[0]  # Inits phi for 1 module
                modelspec[i] = m

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    if 'mask' in data.signals.keys():
        log.info("Data len pre-mask: %d", data['mask'].shape[1])
        data = data.apply_mask()
        log.info("Data len post-mask: %d", data['mask'].shape[1])

    # turn on "fit mode". currently this serves one purpose, for normalization
    # parameters to be re-fit for the output of each module that uses
    # normalization. does nothing if normalization is not being used.
    ms.fit_mode_on(modelspec, data)

    # Create the mapper functions that translate to and from modelspecs.
    # It has three functions that, when defined as mathematical functions, are:
    #    .pack(modelspec) -> fitspace_point
    #    .unpack(fitspace_point) -> modelspec
    #    .bounds(modelspec) -> fitspace_bounds
    packer, unpacker, pack_bounds = mapper(modelspec)

    # A function to evaluate the modelspec on the data
    evaluator = nems.modelspec.evaluate

    my_cost_function = cost_function
    my_cost_function.counter = 0

    # Freeze everything but sigma, since that's all the fitter should be
    # updating.
    cost_fn = partial(my_cost_function,
                      unpacker=unpacker, modelspec=modelspec,
                      data=data, segmentor=segmentor, evaluator=evaluator,
                      metric=metric)

    # get initial sigma value representing some point in the fit space,
    # and corresponding bounds for each value
    sigma = packer(modelspec)
    bounds = pack_bounds(modelspec)

    # Results should be a list of modelspecs
    # (might only be one in list, but still should be packaged as a list)
    improved_sigma = fitter(sigma, cost_fn, bounds=bounds, **fit_kwargs)
    improved_modelspec = unpacker(improved_sigma)
    elapsed_time = (time.time() - start_time)

    start_err = cost_fn(sigma)
    final_err = cost_fn(improved_sigma)
    log.info("Delta error: %.06f - %.06f = %e",
             start_err, final_err, final_err - start_err)

    # TODO: Should this maybe be moved to a higher level
    #       so it applies to ALL the fitters?
    ms.fit_mode_off(improved_modelspec)
    ms.set_modelspec_metadata(improved_modelspec, 'fitter', metaname)
    ms.set_modelspec_metadata(improved_modelspec, 'n_parms',
                              len(improved_sigma))

    if modelspec.fit_count == 1:
        improved_modelspec.meta['fit_time'] = elapsed_time
        improved_modelspec.meta['loss'] = final_err
    else:
        fit_index = modelspec.fit_index
        if fit_index == 0:
            improved_modelspec.meta['fit_time'] = \
                np.zeros(improved_modelspec.fit_count)
            improved_modelspec.meta['loss'] = \
                np.zeros(improved_modelspec.fit_count)
        improved_modelspec.meta['fit_time'][fit_index] = elapsed_time
        improved_modelspec.meta['loss'][fit_index] = final_err

    if type(improved_modelspec) is list:
        return [copy.deepcopy(improved_modelspec)]
    else:
        return improved_modelspec.copy()
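
# Toy illustration of the fitter contract documented above: a fitter is a
# function of (sigma, cost_fn) that returns an improved sigma vector. This
# self-contained sketch uses scipy directly with a quadratic cost; it is a
# demonstration of the calling convention, not NEMS code.
def _example_fitter_contract():
    import numpy as np
    from scipy.optimize import minimize

    def toy_fitter(sigma, cost_fn, bounds=None, **fit_kwargs):
        # mirrors the fitter(sigma, cost_fn, bounds=bounds, **fit_kwargs)
        # call in fit_basic
        return minimize(cost_fn, sigma, bounds=bounds, method='L-BFGS-B').x

    target = np.array([1.0, -2.0])
    cost_fn = lambda sigma: float(np.sum((sigma - target) ** 2))
    improved_sigma = toy_fitter(np.zeros(2), cost_fn)
    return improved_sigma  # approximately [1.0, -2.0]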
def init_dexp(rec, modelspec):
    """
    choose initial values for dexp applied after preceding fir is initialized
    """
    # preserve input modelspec
    modelspec = copy.deepcopy(modelspec)

    target_i = find_module('double_exponential', modelspec)
    if target_i is None:
        log.warning("No dexp module was found, can't initialize.")
        return modelspec

    if target_i == len(modelspec):
        fit_portion = modelspec
    else:
        fit_portion = modelspec[:target_i]

    # ensures all previous modules have their phi initialized;
    # choose prior mean if not found
    for i, m in enumerate(fit_portion):
        if ('phi' not in m.keys()) and ('prior' in m.keys()):
            log.debug('Phi not found for module, using mean of prior: %s', m)
            m = priors.set_mean_phi([m])[0]  # Inits phi for 1 module
            fit_portion[i] = m

    # generate prediction from the modules preceding dexp
    ms.fit_mode_on(fit_portion)
    rec = ms.evaluate(rec, fit_portion)
    ms.fit_mode_off(fit_portion)

    in_signal = modelspec[target_i]['fn_kwargs']['i']
    pchans = rec[in_signal].shape[0]
    amp = np.zeros([pchans, 1])
    base = np.zeros([pchans, 1])
    kappa = np.zeros([pchans, 1])
    shift = np.zeros([pchans, 1])

    for i in range(pchans):
        resp = rec['resp'].as_continuous()
        pred = rec[in_signal].as_continuous()[i:(i + 1), :]
        if resp.shape[0] == pchans:
            resp = resp[i:(i + 1), :]

        keepidx = np.isfinite(resp) * np.isfinite(pred)
        resp = resp[keepidx]
        pred = pred[keepidx]

        # choose phi s.t. dexp starts as almost a straight line
        # phi=[max_out min_out slope mean_in]
        # meanr = np.nanmean(resp)
        stdr = np.nanstd(resp)

        # base = np.max(np.array([meanr - stdr * 4, 0]))
        base[i, 0] = np.min(resp)
        # base = meanr - stdr * 3

        # amp = np.max(resp) - np.min(resp)
        amp[i, 0] = stdr * 3

        shift[i, 0] = np.mean(pred)
        # shift = (np.max(pred) + np.min(pred)) / 2

        predrange = 2 / (np.max(pred) - np.min(pred) + 1)
        kappa[i, 0] = np.log(predrange)

    modelspec[target_i]['phi'] = {'amplitude': amp, 'base': base,
                                  'kappa': kappa, 'shift': shift}
    log.info("Init dexp: %s", modelspec[target_i]['phi'])

    return modelspec
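
# Standalone sketch of the per-channel dexp initialization heuristic above,
# on synthetic data (names are illustrative): base starts at the response
# minimum, amplitude at ~3 standard deviations, shift at the mean prediction,
# and kappa is chosen so the sigmoid starts out nearly linear.
def _example_dexp_init():
    import numpy as np
    rng = np.random.RandomState(0)
    pred = rng.randn(1000)
    resp = np.abs(rng.randn(1000))

    base0 = np.min(resp)
    amp0 = np.nanstd(resp) * 3
    shift0 = np.mean(pred)
    predrange = 2 / (np.max(pred) - np.min(pred) + 1)
    kappa0 = np.log(predrange)
    return base0, amp0, shift0, kappa0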
def fit_population_iteratively(
        est, modelspec,
        cost_function=basic_cost,
        fitter=coordinate_descent, evaluator=ms.evaluate,
        segmentor=nems.segmentors.use_all_data,
        mapper=nems.fitters.mappers.simple_vector,
        metric=lambda data: nems.metrics.api.nmse(data, 'pred', 'resp'),
        metaname='fit_basic', fit_kwargs={}, module_sets=None,
        invert=False, tolerances=None, tol_iter=50, fit_iter=10,
        IsReload=False, **context
        ):
    '''
    Required Arguments:
     est           A recording object
     modelspec     A modelspec object

    Optional Arguments:
     TODO: need to deal with the fact that you can't pass functions in an
           xforms-friendly function
     fitter        (CURRENTLY NOT USED?)
                   A function of (sigma, costfn) that tests various points
                   in fitspace (i.e. sigmas) using the cost function costfn,
                   and hopefully returns a better sigma after some time.
     mapper        (CURRENTLY NOT USED?)
                   A class that has two methods, pack and unpack, which
                   define the mapping between modelspecs and a fitter's
                   fitspace.
     segmentor     (CURRENTLY NOT USED?)
                   A function that selects a subset of the data during the
                   fitting process. This is NOT the same as est/val data
                   splits.
     metric        A function of a Recording that returns an error value
                   that is to be minimized.
     module_sets   (CURRENTLY NOT USED?)
                   A nested list specifying which model indices should be fit.
                   Overall iteration will occur len(module_sets) many times.
                   ex: [[0], [1, 3], [0, 1, 2, 3]]
     invert        (CURRENTLY NOT USED?)
                   Boolean. Causes module_sets to specify the model indices
                   that should *not* be fit.

    Returns
    A dictionary {'modelspec': improved_modelspec} containing the best
    parameters found by this fitter.
    '''
    if IsReload:
        return {}

    modelspec = copy.deepcopy(modelspec)
    data = est.copy()

    fit_set_all, fit_set_slice = _figure_out_mod_split(modelspec)

    if tolerances is None:
        tolerances = [1e-4, 1e-5]

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals.
    # then delete the mask signal so that it's not reapplied on each fit
    if 'mask' in data.signals.keys():
        log.info("Data len pre-mask: %d", data['mask'].shape[1])
        data = data.apply_mask()
        log.info("Data len post-mask: %d", data['mask'].shape[1])
        del data.signals['mask']

    start_time = time.time()
    ms.fit_mode_on(modelspec, data)

    # modelspec = init_pop_pca(data, modelspec)
    # print(modelspec)

    # Ensure that phi exists for all modules; choose prior mean if not found
    # for i, m in enumerate(modelspec):
    #     if ('phi' not in m.keys()) and ('prior' in m.keys()):
    #         m = nems.priors.set_mean_phi([m])[0]  # Inits phi for 1 module
    #         log.debug('Phi not found for module, using mean of prior: {}'
    #                   .format(m))
    #         modelspec[i] = m

    error = np.inf

    slice_count = data['resp'].shape[0]
    step_size = 0.1
    if 'nonlinearity' in modelspec[-1]['fn']:
        skip_nl_first = True
        tolerances = [tolerances[0]] + tolerances
    else:
        skip_nl_first = False

    for toli, tol in enumerate(tolerances):

        log.info("Fitting subsets with tol: %.2E fit_iter %d tol_iter %d",
                 tol, fit_iter, tol_iter)
        cd_kwargs = fit_kwargs.copy()
        cd_kwargs.update({'tolerance': tol, 'max_iter': fit_iter,
                          'step_size': step_size})
        sp_kwargs = fit_kwargs.copy()
        sp_kwargs.update({'tolerance': tol, 'max_iter': fit_iter})

        if (toli == 0) and skip_nl_first:
            log.info('skipping nl on first tolerance loop')
            saved_modelspec = copy.deepcopy(modelspec)
            saved_fit_set_slice = fit_set_slice.copy()
            modelspec.pop_module()
            fit_set_slice = fit_set_slice[:-1]

        inner_i = 0
        error_reduction = np.inf
        # big_slice = 0
        # big_n = data['resp'].ntimes
        # big_step = int(big_n/10)
        # big_slice_size = int(big_n/2)
        while (error_reduction >= tol) and (inner_i < tol_iter):

            log.info("(%d) Tol %.2e: Loop %d/%d (max)",
                     toli, tol, inner_i, tol_iter)
            improved_modelspec = copy.deepcopy(modelspec)
            cc = 0
            slist = list(range(slice_count))
            # random.shuffle(slist)

            for i, m in enumerate(modelspec):
                if i in fit_set_all:
                    log.info(m['fn'] + ": fitting")
                else:
                    log.info(m['fn'] + ": frozen")

            # partially implemented: select temporal subset of data for
            # fitting on current loop.
            # data2 = data.copy()
            # big_slice += 1
            # sl = np.zeros(big_n, dtype=bool)
            # sl[:big_slice_size] = True
            # sl = np.roll(sl, big_step*big_slice)
            # log.info('Sampling temporal subset %d (size=%d/%d)',
            #          big_step, big_slice_size, big_n)
            # for s in data2.signals.values():
            #     e = s._modified_copy(s._data[:, sl])
            #     data2[e.name] = e

            # improved_modelspec = init.prefit_mod_subset(
            #         data, improved_modelspec, analysis.fit_basic,
            #         metric=metric,
            #         fit_set=fit_set_all,
            #         fit_kwargs=sp_kwargs)
            improved_modelspec = fit_population_channel_fast2(
                    data, improved_modelspec, fit_set_all, fit_set_slice,
                    analysis_function=analysis.fit_basic,
                    metric=metric,
                    fitter=scipy_minimize, fit_kwargs=sp_kwargs)

            for s in slist:
                log.info('Slice %d set %s' % (s, fit_set_slice))
                improved_modelspec = fit_population_slice(
                        data, improved_modelspec, slice=s,
                        fit_set=fit_set_slice,
                        analysis_function=analysis.fit_basic,
                        metric=metric,
                        fitter=scipy_minimize,
                        fit_kwargs=sp_kwargs)
                # fitter=coordinate_descent,
                # fit_kwargs=cd_kwargs)

                cc += 1
                # if (cc % 8 == 0) or (cc == slice_count):

            data = ms.evaluate(data, improved_modelspec)
            new_error = metric(data)
            error_reduction = error - new_error
            error = new_error
            log.info("tol=%.2E, iter=%d/%d: deltaE=%.6E",
                     tol, inner_i, tol_iter, error_reduction)
            inner_i += 1
            if error_reduction > 0:
                modelspec = improved_modelspec

        log.info("Done with tol %.2E (i=%d, max_error_reduction %.7f)",
                 tol, inner_i, error_reduction)

        if (toli == 0) and skip_nl_first:
            log.info('Restoring NL module after first tol loop')
            modelspec.append(saved_modelspec[-1])
            fit_set_slice = saved_fit_set_slice
            if 'double_exponential' in saved_modelspec[-1]['fn']:
                modelspec = init.init_dexp(data, modelspec)
            elif 'logistic_sigmoid' in saved_modelspec[-1]['fn']:
                modelspec = init.init_logsig(data, modelspec)
            elif 'relu' in saved_modelspec[-1]['fn']:
                # just keep initialized to zero
                pass
            else:
                raise ValueError("Output NL %s not supported"
                                 % saved_modelspec[-1]['fn'])

            # just fit the NL
            improved_modelspec = copy.deepcopy(modelspec)

            kwa = cd_kwargs.copy()
            kwa['max_iter'] *= 2
            for s in range(slice_count):
                log.info('Slice %d set %s' % (s, [fit_set_slice[-1]]))
                improved_modelspec = fit_population_slice(
                        data, improved_modelspec, slice=s,
                        fit_set=fit_set_slice,
                        analysis_function=analysis.fit_basic,
                        metric=metric,
                        fitter=scipy_minimize,
                        fit_kwargs=sp_kwargs)
                # fitter=coordinate_descent,
                # fit_kwargs=cd_kwargs)

            data = ms.evaluate(data, modelspec)
            old_error = metric(data)
            data = ms.evaluate(data, improved_modelspec)
            new_error = metric(data)
            log.info('Init NL fit error change %.5f-%.5f = %.5f',
                     old_error, new_error, old_error - new_error)
            modelspec = improved_modelspec
        else:
            step_size *= 0.25

    elapsed_time = (time.time() - start_time)

    # TODO: Should this maybe be moved to a higher level
    #       so it applies to ALL the fitters?
    ms.fit_mode_off(improved_modelspec)
    ms.set_modelspec_metadata(improved_modelspec, 'fitter', metaname)
    ms.set_modelspec_metadata(improved_modelspec, 'fit_time', elapsed_time)

    return {'modelspec': improved_modelspec.copy()}
def init_dexp(rec, modelspec):
    """
    choose initial values for dexp applied after preceding fir is initialized
    """
    # preserve input modelspec
    modelspec = copy.deepcopy(modelspec)

    target_i = find_module('double_exponential', modelspec)
    if target_i is None:
        log.warning("No dexp module was found, can't initialize.")
        return modelspec

    if target_i == len(modelspec):
        fit_portion = modelspec
    else:
        fit_portion = modelspec[:target_i]

    # generate prediction from the modules preceding dexp
    ms.fit_mode_on(fit_portion)
    rec = ms.evaluate(rec, fit_portion)
    ms.fit_mode_off(fit_portion)

    pchans = rec['pred'].shape[0]
    amp = np.zeros([pchans, 1])
    base = np.zeros([pchans, 1])
    kappa = np.zeros([pchans, 1])
    shift = np.zeros([pchans, 1])

    for i in range(pchans):
        resp = rec['resp'].as_continuous()
        pred = rec['pred'].as_continuous()[i:(i + 1), :]

        keepidx = np.isfinite(resp) * np.isfinite(pred)
        resp = resp[keepidx]
        pred = pred[keepidx]

        # choose phi s.t. dexp starts as almost a straight line
        # phi=[max_out min_out slope mean_in]
        # meanr = np.nanmean(resp)
        stdr = np.nanstd(resp)

        # base = np.max(np.array([meanr - stdr * 4, 0]))
        base[i, 0] = np.min(resp)
        # base = meanr - stdr * 3

        # amp = np.max(resp) - np.min(resp)
        amp[i, 0] = stdr * 3

        shift[i, 0] = np.mean(pred)
        # shift = (np.max(pred) + np.min(pred)) / 2

        predrange = 2 / (np.max(pred) - np.min(pred) + 1)
        kappa[i, 0] = np.log(predrange)

    modelspec[target_i]['phi'] = {'amplitude': amp, 'base': base,
                                  'kappa': kappa, 'shift': shift}
    log.info("Init dexp: %s", modelspec[target_i]['phi'])

    return modelspec
def fit_iteratively(
        data, modelspec,
        cost_function=basic_cost,
        fitter=coordinate_descent, evaluator=ms.evaluate,
        segmentor=nems.segmentors.use_all_data,
        mapper=nems.fitters.mappers.simple_vector,
        metric=lambda data: nems.metrics.api.nmse(data, 'pred', 'resp'),
        metaname='fit_basic', fit_kwargs={}, module_sets=None,
        invert=False, tolerances=None, tol_iter=50, fit_iter=10,
        ):
    '''
    Required Arguments:
     data          A recording object
     modelspec     A modelspec object

    Optional Arguments:
     fitter        A function of (sigma, costfn) that tests various points
                   in fitspace (i.e. sigmas) using the cost function costfn,
                   and hopefully returns a better sigma after some time.
     mapper        A class that has two methods, pack and unpack, which
                   define the mapping between modelspecs and a fitter's
                   fitspace.
     segmentor     A function that selects a subset of the data during the
                   fitting process. This is NOT the same as est/val data
                   splits.
     metric        A function of a Recording that returns an error value
                   that is to be minimized.
     module_sets   A nested list specifying which model indices should be fit.
                   Overall iteration will occur len(module_sets) many times.
                   ex: [[0], [1, 3], [0, 1, 2, 3]]
     invert        Boolean. Causes module_sets to specify the model indices
                   that should *not* be fit.

    Returns
    A list containing a single modelspec, which has the best parameters found
    by this fitter.
    '''
    if module_sets is None:
        module_sets = []
        for i, m in enumerate(modelspec):
            if 'prior' in m.keys():
                if 'levelshift' in m['fn'] and 'fir' in modelspec[i - 1]['fn']:
                    # group levelshift with preceding fir filter by default
                    module_sets[-1].append(i)
                else:
                    # otherwise just fit each module separately
                    module_sets.append([i])
        log.info('Fit sets: %s', module_sets)

    if tolerances is None:
        tolerances = [1e-6]

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    if 'mask' in data.signals.keys():
        log.info("Data len pre-mask: %d", data['mask'].shape[1])
        data = data.apply_mask()
        log.info("Data len post-mask: %d", data['mask'].shape[1])

    start_time = time.time()
    ms.fit_mode_on(modelspec)

    # Ensure that phi exists for all modules; choose prior mean if not found
    for i, m in enumerate(modelspec):
        if ('phi' not in m.keys()) and ('prior' in m.keys()):
            m = nems.priors.set_mean_phi([m])[0]  # Inits phi for 1 module
            log.debug('Phi not found for module, using mean of prior: {}'
                      .format(m))
            modelspec[i] = m

    error = np.inf
    for tol in tolerances:
        log.info("Fitting subsets with tol: %.2E fit_iter %d tol_iter %d",
                 tol, fit_iter, tol_iter)
        fit_kwargs.update({'tolerance': tol, 'max_iter': fit_iter})
        max_error_reduction = np.inf
        i = 0

        while (max_error_reduction >= tol) and (i < tol_iter):
            max_error_reduction = 0
            j = 0
            for subset in module_sets:
                improved_modelspec = _module_set_loop(
                        subset, data, modelspec, cost_function, fitter,
                        mapper, segmentor, evaluator, metric, fit_kwargs)
                new_error = cost_function.error
                error_reduction = error - new_error
                error = new_error
                j += 1
                if error_reduction > max_error_reduction:
                    max_error_reduction = error_reduction
            log.info("tol=%.2E, iter=%d/%d: max deltaE=%.6E",
                     tol, i, tol_iter, max_error_reduction)
            i += 1

        log.info("Done with tol %.2E (i=%d, max_error_reduction %.7f)",
                 tol, i, error_reduction)

    elapsed_time = (time.time() - start_time)

    # TODO: Should this maybe be moved to a higher level
    #       so it applies to ALL the fitters?
    ms.fit_mode_off(improved_modelspec)
    ms.set_modelspec_metadata(improved_modelspec, 'fitter', metaname)
    ms.set_modelspec_metadata(improved_modelspec, 'fit_time', elapsed_time)
    results = [copy.deepcopy(improved_modelspec)]

    return results
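
# Illustration of the default module_sets grouping built by fit_iteratively
# above, using stand-in 'fn' names (hypothetical modelspec; mirrors the
# grouping loop, assuming every module has a prior):
def _example_default_module_sets():
    fns = ['nems.modules.weight_channels.basic',
           'nems.modules.fir.basic',
           'nems.modules.levelshift.levelshift',
           'nems.modules.nonlinearity.double_exponential']
    module_sets = []
    for i, fn in enumerate(fns):
        if 'levelshift' in fn and 'fir' in fns[i - 1]:
            # group levelshift with the preceding fir filter
            module_sets[-1].append(i)
        else:
            module_sets.append([i])
    return module_sets  # -> [[0], [1, 2], [3]]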
def fit_pcnorm(modelspec, est: recording.Recording, metric=None,
               use_modelspec_init: bool = True,
               optimizer: str = 'adam',
               max_iter: int = 10000,
               early_stopping_steps: int = 5,
               tolerance: float = 5e-4,
               learning_rate: float = 1e-4,
               batch_size: typing.Union[None, int] = None,
               seed: int = 0,
               initializer: str = 'random_normal',
               freeze_layers: typing.Union[None, list] = None,
               epoch_name: str = "REFERENCE",
               n_pcs=2,
               **context):
    '''
    Required Arguments:
     est           A recording object
     modelspec     A modelspec object

    Optional Arguments:
     (copied from fit_tf for now)

    Returns
     dictionary: {'modelspec': updated_modelspec}
    '''
    # Hard-coded
    cost_function = basic_cost
    fitter = scipy_minimize
    segmentor = nems.segmentors.use_all_data
    mapper = nems.fitters.mappers.simple_vector
    fit_kwargs = {'tolerance': tolerance, 'max_iter': max_iter}

    start_time = time.time()

    modelspec = copy.deepcopy(modelspec)

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    if 'mask' in est.signals.keys():
        log.info("Data len pre-mask: %d", est['mask'].shape[1])
        est = est.apply_mask()
        log.info("Data len post-mask: %d", est['mask'].shape[1])

    conditions = ["_".join(k.split("_")[1:]) for k in est.signals.keys()
                  if k.startswith("mask_")]
    if (len(conditions) > 2) and \
            any([c.split("_")[-1] == 'lg' for c in conditions]):
        conditions.remove("small")
        conditions.remove("large")
    # conditions = conditions[0:2]
    # conditions = ['large', 'small']

    group_idx = [est['mask_' + c].as_continuous()[0, :] for c in conditions]
    cg_filtered = [(c, g) for c, g in zip(conditions, group_idx)
                   if g.sum() > 0]
    conditions, group_idx = zip(*cg_filtered)

    for c, g in zip(conditions, group_idx):
        log.info(f"Data subset for {c} len {g.sum()}")

    resp = est['resp'].as_continuous()
    pred0 = est['pred0'].as_continuous()
    residual = resp - pred0

    pca = PCA(n_components=n_pcs)
    pca.fit(residual.T)
    pc_axes = pca.components_

    pcproj = residual.T.dot(pc_axes.T).T

    group_pc = [pcproj[:, idx].std(axis=1) for idx in group_idx]
    resp_std = resp.std(axis=1)

    if metric is None:
        metric = lambda d: pc_err(d, pred_name='pred', pred0_name='pred0',
                                  group_idx=group_idx, group_pc=group_pc,
                                  pc_axes=pc_axes, resp_std=resp_std)

    # turn on "fit mode". currently this serves one purpose, for normalization
    # parameters to be re-fit for the output of each module that uses
    # normalization. does nothing if normalization is not being used.
    ms.fit_mode_on(modelspec, est)

    # Create the mapper functions that translate to and from modelspecs.
    # It has three functions that, when defined as mathematical functions, are:
    #    .pack(modelspec) -> fitspace_point
    #    .unpack(fitspace_point) -> modelspec
    #    .bounds(modelspec) -> fitspace_bounds
    packer, unpacker, pack_bounds = mapper(modelspec)

    # A function to evaluate the modelspec on the data
    evaluator = nems.modelspec.evaluate

    my_cost_function = cost_function
    my_cost_function.counter = 0

    # Freeze everything but sigma, since that's all the fitter should be
    # updating.
    cost_fn = partial(my_cost_function,
                      unpacker=unpacker, modelspec=modelspec,
                      data=est, segmentor=segmentor, evaluator=evaluator,
                      metric=metric, display_N=1000)

    # get initial sigma value representing some point in the fit space,
    # and corresponding bounds for each value
    sigma = packer(modelspec)
    bounds = pack_bounds(modelspec)

    # Results should be a list of modelspecs
    # (might only be one in list, but still should be packaged as a list)
    improved_sigma = fitter(sigma, cost_fn, bounds=bounds, **fit_kwargs)
    improved_modelspec = unpacker(improved_sigma)
    elapsed_time = (time.time() - start_time)

    start_err = cost_fn(sigma)
    final_err = cost_fn(improved_sigma)
    log.info("Delta error: %.06f - %.06f = %e",
             start_err, final_err, final_err - start_err)

    # TODO: Should this maybe be moved to a higher level
    #       so it applies to ALL the fitters?
    ms.fit_mode_off(improved_modelspec)
    ms.set_modelspec_metadata(improved_modelspec, 'fitter', 'ccnorm')
    ms.set_modelspec_metadata(improved_modelspec, 'fit_time', elapsed_time)
    ms.set_modelspec_metadata(improved_modelspec, 'n_parms',
                              len(improved_sigma))

    return {'modelspec': improved_modelspec.copy(), 'save_context': True}
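
# Self-contained sketch of the residual-PCA step used by fit_pcnorm above,
# on synthetic data (names are illustrative): PCs are fit to the residual
# (resp - pred0), and the per-condition std of the residual's projection
# onto those PCs is what the pc_err metric compares against.
def _example_residual_pca(n_pcs=2):
    import numpy as np
    from sklearn.decomposition import PCA

    rng = np.random.RandomState(0)
    resp = rng.randn(20, 500)      # channels x time
    pred0 = rng.randn(20, 500)
    residual = resp - pred0

    pca = PCA(n_components=n_pcs)
    pca.fit(residual.T)
    pc_axes = pca.components_              # (n_pcs, channels)

    pcproj = residual.T.dot(pc_axes.T).T   # (n_pcs, time)
    return pcproj.std(axis=1)              # per-PC projection std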
def fit_ccnorm(modelspec, est: recording.Recording, metric=None,
               use_modelspec_init: bool = True,
               optimizer: str = 'adam',
               max_iter: int = 10000,
               early_stopping_steps: int = 5,
               tolerance: float = 5e-4,
               learning_rate: float = 1e-4,
               batch_size: typing.Union[None, int] = None,
               seed: int = 0,
               initializer: str = 'random_normal',
               freeze_layers: typing.Union[None, list] = None,
               epoch_name: str = "REFERENCE",
               shrink_cc: float = 0,
               noise_pcs: int = 0,
               shared_pcs: int = 0,
               also_fit_resp: bool = False,
               force_psth: bool = False,
               use_metric: typing.Union[None, str] = None,
               alpha: float = 0.1,
               beta: float = 1,
               exclude_idx=None, exclude_after=None,
               freeze_idx=None, freeze_after=None,
               **context):
    '''
    Required Arguments:
     est           A recording object
     modelspec     A modelspec object

    Optional Arguments:
     (copied from fit_tf for now)

    Returns
     dictionary: {'modelspec': updated_modelspec}
    '''
    # Hard-coded
    cost_function = basic_cost
    fitter = scipy_minimize
    segmentor = nems.segmentors.use_all_data
    mapper = nems.fitters.mappers.simple_vector
    fit_kwargs = {'tolerance': tolerance, 'max_iter': max_iter}

    start_time = time.time()

    fit_index = modelspec.fit_index
    if (exclude_idx is not None) | (freeze_idx is not None) | \
            (exclude_after is not None) | (freeze_after is not None):
        modelspec0 = modelspec.copy()
        modelspec, include_set = modelspec_freeze_layers(
            modelspec, include_idx=None,
            exclude_idx=exclude_idx, exclude_after=exclude_after,
            freeze_idx=freeze_idx, freeze_after=freeze_after)
        modelspec0.set_fit(fit_index)
        modelspec.set_fit(fit_index)
    else:
        include_set = None

    # Computing PCs before masking out unwanted stimuli in order to
    # preserve match with epochs
    epoch_regex = "^STIM_"
    stims = (est.epochs['name'].value_counts() >= 8)
    stims = [stims.index[i] for i, s in enumerate(stims)
             if bool(re.search(epoch_regex, stims.index[i])) and s == True]

    Rall_u = est.apply_mask()['psth'].as_continuous().T
    # can't simply extract evoked for refs because it can be longer/shorter
    # if it came after a target and/or if it was the last stim. So, masking
    # prestim/poststim doesn't work. Do it manually:
    # d = est['resp'].extract_epochs(stims, mask=est['mask'])
    # R = [v.mean(axis=0) for (k, v) in d.items()]
    # R = [np.reshape(np.transpose(v, [1, 0, 2]), [v.shape[1], -1])
    #      for (k, v) in d.items()]
    # Rall_u = np.hstack(R).T

    pca = PCA(n_components=2)
    pca.fit(Rall_u)
    pc_axes = pca.components_

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals
    if 'mask_small' in est.signals.keys():
        log.info('resetting mask with mask_small+mask_large subset')
        est['mask'] = est['mask']._modified_copy(
            data=est['mask_small']._data + est['mask_large']._data)

    if 'mask' in est.signals.keys():
        log.info("Data len pre-mask: %d", est['mask'].shape[1])
        est = est.apply_mask()
        log.info("Data len post-mask: %d", est['mask'].shape[1])

    # if we want to fit to first-order cc error, uncomment this and make
    # sure sdexp is generating a pred0 signal
    est = modelspec.evaluate(est, stop=2)
    if ('pred0' in est.signals.keys()) & (not force_psth):
        input_name = 'pred0'
        log.info('Found pred0 for fitting CC')
    else:
        input_name = 'psth'
        log.info('No pred0, using psth for fitting CC')

    conditions = ["_".join(k.split("_")[1:]) for k in est.signals.keys()
                  if k.startswith("mask_")]
    if (len(conditions) > 2) and \
            any([c.split("_")[-1] == 'lg' for c in conditions]):
        conditions.remove("small")
        conditions.remove("large")
    # conditions = conditions[0:2]
    # conditions = ['large', 'small']

    group_idx = [est['mask_' + c].as_continuous()[0, :] for c in conditions]
    cg_filtered = [(c, g) for c, g in zip(conditions, group_idx)
                   if g.sum() > 0]
    conditions, group_idx = zip(*cg_filtered)

    for c, g in zip(conditions, group_idx):
        log.info(f"cc data for {c} len {g.sum()}")

    resp = est['resp'].as_continuous()
    pred0 = est[input_name].as_continuous()

    if shrink_cc > 0:
        log.info(f'cc approx: shrink_cc={shrink_cc}')
        group_cc = [cc_shrink(resp[:, idx] - pred0[:, idx], sigrat=shrink_cc)
                    for idx in group_idx]
    elif shared_pcs > 0:
        log.info(f'cc approx: shared_pcs={shared_pcs}')
        cc = np.cov(resp - pred0)
        u, s, vh = np.linalg.svd(cc)
        U = u[:, :shared_pcs] @ u[:, :shared_pcs].T
        group_cc = [cc_shared_space(resp[:, idx] - pred0[:, idx], U)
                    for idx in group_idx]
    elif noise_pcs > 0:
        log.info(f'cc approx: noise_pcs={noise_pcs}')
        group_cc = [cc_lowrank(resp[:, idx] - pred0[:, idx], n_pcs=noise_pcs)
                    for idx in group_idx]
    else:
        group_cc = [np.cov(resp[:, idx] - pred0[:, idx])
                    for idx in group_idx]
    group_cc_raw = [np.cov(resp[:, idx] - pred0[:, idx])
                    for idx in group_idx]

    # variance of projection onto PCs (PCs computed above before masking)
    pcproj0 = (resp - pred0).T.dot(pc_axes.T).T
    pcproj_std = pcproj0.std(axis=1)

    if (use_metric == 'cc_err_w'):
        def metric(d, verbose=False):
            return metrics.cc_err_w(d, pred_name='pred',
                                    pred0_name=input_name,
                                    group_idx=group_idx, group_cc=group_cc,
                                    alpha=alpha,
                                    pcproj_std=None, pc_axes=None,
                                    verbose=verbose)
        log.info(f"fit_ccnorm metric: cc_err_w (alpha={alpha})")
    elif (metric is None) and also_fit_resp:
        log.info(f"resp_cc_err: pred0_name: {input_name} beta: {beta}")
        metric = lambda d: metrics.resp_cc_err(d, pred_name='pred',
                                               pred0_name=input_name,
                                               group_idx=group_idx,
                                               group_cc=group_cc, beta=beta,
                                               pcproj_std=None, pc_axes=None)
    elif (use_metric == 'cc_err_md'):
        def metric(d, verbose=False):
            return metrics.cc_err_md(d, pred_name='pred',
                                     pred0_name=input_name,
                                     group_idx=group_idx, group_cc=group_cc,
                                     pcproj_std=None, pc_axes=None)
        log.info(f"fit_ccnorm metric: cc_err_md")
    elif (metric is None):
        # def cc_err(result, pred_name='pred_lv', resp_name='resp',
        #            pred0_name='pred', group_idx=None, group_cc=None,
        #            pcproj_std=None, pc_axes=None):
        # current implementation of cc_err
        metric = lambda d: metrics.cc_err(d, pred_name='pred',
                                          pred0_name=input_name,
                                          group_idx=group_idx,
                                          group_cc=group_cc,
                                          pcproj_std=None, pc_axes=None)
        log.info(f"fit_ccnorm metric: cc_err")

    # turn on "fit mode". currently this serves one purpose, for normalization
    # parameters to be re-fit for the output of each module that uses
    # normalization. does nothing if normalization is not being used.
    ms.fit_mode_on(modelspec, est)

    # Create the mapper functions that translate to and from modelspecs.
    # It has three functions that, when defined as mathematical functions, are:
    #    .pack(modelspec) -> fitspace_point
    #    .unpack(fitspace_point) -> modelspec
    #    .bounds(modelspec) -> fitspace_bounds
    packer, unpacker, pack_bounds = mapper(modelspec)

    # A function to evaluate the modelspec on the data
    evaluator = nems.modelspec.evaluate

    my_cost_function = cost_function
    my_cost_function.counter = 0

    # Freeze everything but sigma, since that's all the fitter should be
    # updating.
    cost_fn = partial(my_cost_function,
                      unpacker=unpacker, modelspec=modelspec,
                      data=est, segmentor=segmentor, evaluator=evaluator,
                      metric=metric, display_N=1000)

    # get initial sigma value representing some point in the fit space,
    # and corresponding bounds for each value
    sigma = packer(modelspec)
    bounds = pack_bounds(modelspec)

    # Results should be a list of modelspecs
    # (might only be one in list, but still should be packaged as a list)
    improved_sigma = fitter(sigma, cost_fn, bounds=bounds, **fit_kwargs)
    improved_modelspec = unpacker(improved_sigma)
    elapsed_time = (time.time() - start_time)

    start_err = cost_fn(sigma)
    final_err = cost_fn(improved_sigma)
    log.info("Delta error: %.06f - %.06f = %e",
             start_err, final_err, final_err - start_err)

    # TODO: Should this maybe be moved to a higher level
    #       so it applies to ALL the fitters?
    ms.fit_mode_off(improved_modelspec)

    if include_set is not None:
        # pull out updated phi values from improved_modelspec,
        # include_set only
        improved_modelspec = modelspec_unfreeze_layers(
            improved_modelspec, modelspec0, include_set)
        improved_modelspec.set_fit(fit_index)
        log.info("Updating improved modelspec with fit_idx="
                 f"{improved_modelspec.fit_index}")

    improved_modelspec.meta['fitter'] = 'ccnorm'
    improved_modelspec.meta['n_parms'] = len(improved_sigma)

    if modelspec.fit_count == 1:
        improved_modelspec.meta['fit_time'] = elapsed_time
        improved_modelspec.meta['loss'] = final_err
    else:
        if modelspec.fit_index == 0:
            improved_modelspec.meta['fit_time'] = \
                np.zeros(improved_modelspec.fit_count)
            improved_modelspec.meta['loss'] = \
                np.zeros(improved_modelspec.fit_count)
        improved_modelspec.meta['fit_time'][fit_index] = elapsed_time
        improved_modelspec.meta['loss'][fit_index] = final_err

    return {'modelspec': improved_modelspec}
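
# Self-contained sketch of the covariance approximation selected by the
# shared_pcs branch of fit_ccnorm above, on synthetic data. The projector
# U is built exactly as in the source; the final constrained-covariance line
# is an assumption about what a shared-space constraint might look like,
# since cc_shared_space itself is not shown in this section.
def _example_shared_pc_cov(shared_pcs=3):
    import numpy as np

    rng = np.random.RandomState(0)
    diff = rng.randn(20, 500)       # stand-in for resp - pred0
    cc = np.cov(diff)
    u, s, vh = np.linalg.svd(cc)
    U = u[:, :shared_pcs] @ u[:, :shared_pcs].T   # projector onto top PCs
    return U @ cc @ U.T             # covariance constrained to shared space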
def _init_double_exponential(rec, modelspec, target_i):

    if target_i == len(modelspec):
        fit_portion = modelspec.modules
    else:
        fit_portion = modelspec.modules[:target_i]

    # generate prediction from the modules preceding dsig
    # HACK
    for i, m in enumerate(fit_portion):
        if not m.get('phi', None):
            old = m.get('prior', {})
            m = priors.set_mean_phi([m])[0]
            m['prior'] = old
            fit_portion[i] = m

    ms.fit_mode_on(fit_portion)
    rec = ms.evaluate(rec, fit_portion)
    ms.fit_mode_off(fit_portion)

    in_signal = modelspec[target_i]['fn_kwargs']['i']
    pchans = rec[in_signal].shape[0]
    amp = np.zeros([pchans, 1])
    base = np.zeros([pchans, 1])
    kappa = np.zeros([pchans, 1])
    shift = np.zeros([pchans, 1])

    for i in range(pchans):
        resp = rec['resp'].as_continuous()
        pred = rec[in_signal].as_continuous()[i:(i + 1), :]
        if resp.shape[0] == pchans:
            resp = resp[i:(i + 1), :]

        keepidx = np.isfinite(resp) * np.isfinite(pred)
        resp = resp[keepidx]
        pred = pred[keepidx]

        # choose phi s.t. dexp starts as almost a straight line
        # phi=[max_out min_out slope mean_in]
        # meanr = np.nanmean(resp)
        stdr = np.nanstd(resp)

        # base = np.max(np.array([meanr - stdr * 4, 0]))
        base[i, 0] = np.min(resp)
        # base = meanr - stdr * 3

        # amp = np.max(resp) - np.min(resp)
        amp[i, 0] = stdr * 3

        shift[i, 0] = np.mean(pred)
        # shift = (np.max(pred) + np.min(pred)) / 2

        predrange = 2 / (np.max(pred) - np.min(pred) + 1)
        kappa[i, 0] = np.log(predrange)

    amp = ('Normal', {'mean': amp, 'sd': 1.0})
    base = ('Normal', {'mean': base, 'sd': 1.0})
    kappa = ('Normal', {'mean': kappa, 'sd': 1.0})
    shift = ('Normal', {'mean': shift, 'sd': 1.0})

    modelspec[target_i]['prior'].update({
            'base': base, 'amplitude': amp, 'shift': shift, 'kappa': kappa,
            })

    return modelspec
def _init_logistic_sigmoid(rec, modelspec, dsig_idx):

    if dsig_idx == len(modelspec):
        fit_portion = modelspec.modules
    else:
        fit_portion = modelspec.modules[:dsig_idx]

    # generate prediction from the modules preceding dsig
    # HACK to get phi for ctwc, ctfir, ctlvl which have not been prefit yet
    for i, m in enumerate(fit_portion):
        if not m.get('phi', None):
            if any(k in m['id'] for k in ['ctwc', 'ctfir', 'ctlvl']):
                old = m.get('prior', {})
                m = priors.set_mean_phi([m])[0]
                m['prior'] = old
                fit_portion[i] = m
            else:
                log.warning("unexpected module missing phi during init "
                            "step\n: %s, #%d", m['id'], i)

    ms.fit_mode_on(fit_portion)
    rec = ms.evaluate(rec, fit_portion)
    ms.fit_mode_off(fit_portion)

    pred = rec['pred'].as_continuous()
    resp = rec['resp'].as_continuous()

    mean_pred = np.nanmean(pred)
    min_pred = np.nanmean(pred) - np.nanstd(pred) * 3
    max_pred = np.nanmean(pred) + np.nanstd(pred) * 3
    if min_pred < 0:
        min_pred = 0
        mean_pred = (min_pred + max_pred) / 2

    pred_range = max_pred - min_pred
    min_resp = max(np.nanmean(resp) - np.nanstd(resp) * 3, 0)  # must be >= 0
    max_resp = np.nanmean(resp) + np.nanstd(resp) * 3
    resp_range = max_resp - min_resp

    # Rather than setting a hard value for initial phi,
    # set the prior distributions and let the fitter/analysis
    # decide how to use it.
    base0 = min_resp + 0.05 * (resp_range)
    amplitude0 = resp_range
    shift0 = mean_pred
    kappa0 = pred_range
    log.info("Initial base,amplitude,shift,kappa=({}, {}, {}, {})".format(
            base0, amplitude0, shift0, kappa0))

    base = ('Exponential', {'beta': base0})
    amplitude = ('Exponential', {'beta': amplitude0})
    shift = ('Normal', {'mean': shift0, 'sd': pred_range**2})
    kappa = ('Exponential', {'beta': kappa0})

    modelspec[dsig_idx]['prior'].update({
            'base': base, 'amplitude': amplitude, 'shift': shift,
            'kappa': kappa, 'base_mod': base, 'amplitude_mod': amplitude,
            'shift_mod': shift, 'kappa_mod': kappa})

    for kw in modelspec[dsig_idx]['fn_kwargs']:
        if kw in ['base_mod', 'amplitude_mod', 'shift_mod', 'kappa_mod']:
            modelspec[dsig_idx]['prior'].pop(kw)

    modelspec[dsig_idx]['bounds'] = {
            'base': (1e-15, None),
            'amplitude': (1e-15, None),
            'shift': (None, None),
            'kappa': (1e-15, None),
            }

    return modelspec