def generate_prediction(est, val, modelspecs):
    list_val = False
    if type(val) is list:
        # ie, if jackknifing
        list_val = True
    else:
        # Evaluate estimation and validation data
        # Since ms.evaluate only does a shallow copy of rec, successive
        # evaluations of one rec on many modelspecs just results in a list of
        # different pointers to the same recording. So need to force copies of
        # est/val before evaluating.
        if len(modelspecs) == 1:
            # no copies needed for 1 modelspec
            est = [est]
            val = [val]
        else:
            est = [est.copy() for i, _ in enumerate(modelspecs)]
            val = [val.copy() for i, _ in enumerate(modelspecs)]

    new_est = [ms.evaluate(d, m) for m, d in zip(modelspecs, est)]
    new_val = [ms.evaluate(d, m) for m, d in zip(modelspecs, val)]

    if list_val:
        new_val = [recording.jackknife_inverse_merge(new_val)]

    return new_est, new_val

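# A minimal, self-contained sketch (plain dicts standing in for recordings;
# not NEMS objects) of the aliasing pitfall that the copies above guard
# against: evaluating one mutable object against many modelspecs without
# copying yields N references to the same, last-evaluated result.

def _shallow_evaluate(rec, mspec):
    rec['pred'] = mspec  # mutates the shared object in place
    return rec           # returns a pointer, not a copy

_rec = {'pred': None}
_mspecs = ['mspec_a', 'mspec_b']

_results = [_shallow_evaluate(_rec, m) for m in _mspecs]
assert all(r is _rec for r in _results)   # every entry aliases _rec
assert _results[0]['pred'] == 'mspec_b'   # first result was clobbered

# forcing a copy per modelspec keeps the evaluations independent
_results = [_shallow_evaluate(dict(_rec), m) for m in _mspecs]
assert _results[0]['pred'] == 'mspec_a'
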
def before_and_after_scatter(rec, modelspec, idx, sig_name='pred',
                             compare='resp', smoothing_bins=False,
                             mod_name='Unknown', xlabel1=None, xlabel2=None,
                             ylabel1=None, ylabel2=None):
    # HACK: shouldn't hardcode 'stim', might be named something else
    # or not present at all. Need to figure out a better solution
    # for special case of idx = 0
    if idx == 0:
        # Can't have anything before index 0, so use input stimulus
        before = rec
        before_sig = rec['stim']
        before.name = '**stim'
    else:
        before = ms.evaluate(rec, modelspec, start=None, stop=idx)
        before_sig = before[sig_name]

    # compute correlation for pre-module before it's over-written
    corr1 = nm.corrcoef(before, pred_name=sig_name, resp_name=compare)

    # now evaluate next module step
    after = ms.evaluate(before.copy(), modelspec, start=idx, stop=idx + 1)
    after_sig = after[sig_name]
    corr2 = nm.corrcoef(after, pred_name=sig_name, resp_name=compare)

    compare_to = rec[compare]
    title1 = '{} vs {} before {}'.format(sig_name, compare, mod_name)
    title2 = '{} vs {} after {}'.format(sig_name, compare, mod_name)
    # TODO: These are coming out the same, but that seems unlikely
    text1 = "r = {0:.5f}".format(corr1)
    text2 = "r = {0:.5f}".format(corr2)

    fn1 = partial(plot_scatter, before_sig, compare_to, title=title1,
                  smoothing_bins=smoothing_bins, xlabel=xlabel1,
                  ylabel=ylabel1, text=text1)
    fn2 = partial(plot_scatter, after_sig, compare_to, title=title2,
                  smoothing_bins=smoothing_bins, xlabel=xlabel2,
                  ylabel=ylabel2, text=text2)
    return fn1, fn2

def generate_prediction_sets(est, val, modelspecs):
    if type(val) is list:
        # ie, if jackknifing
        new_est = [ms.evaluate(d, m) for m, d in zip(modelspecs, est)]
        new_val = [ms.evaluate(d, m) for m, d in zip(modelspecs, val)]
    else:
        # raise instead of print: otherwise new_est/new_val would be
        # undefined at the return statement below
        raise ValueError('val and est must be lists')

    return new_est, new_val

def before_and_after(rec, modelspec, sig_name, ax=None, title=None, idx=0,
                     channels=0, xlabel='Time', ylabel='Value', **options):
    '''
    Plots a timeseries of specified signal just before and just after
    the transformation performed at some step in the modelspec.

    Arguments:
    ----------
    rec : recording object
        The dataset to use. See nems/recording.py.

    modelspec : list of dicts
        The transformations to perform. See nems/modelspec.py.

    sig_name : str
        Specifies the signal in 'rec' to be examined.

    idx : int
        An index into the modelspec. rec[sig_name] will be plotted
        as it exists after step idx-1 and after step idx.

    Returns:
    --------
    None
    '''
    # HACK: shouldn't hardcode 'stim', might be named something else
    # or not present at all. Need to figure out a better solution
    # for special case of idx = 0
    if idx == 0:
        input_name = modelspec[0]['fn_kwargs']['i']
        before = rec[input_name].copy()
        before.name += ' before**'
    else:
        before = ms.evaluate(rec.copy(), modelspec,
                             start=None, stop=idx)[sig_name]
        before.name += ' before'

    after = ms.evaluate(rec, modelspec, start=idx,
                        stop=idx + 1)[sig_name].copy()
    after.name += ' after'
    timeseries_from_signals([before, after], channels=channels,
                            xlabel=xlabel, ylabel=ylabel, ax=ax,
                            title=title, **options)

def generate_prediction(est, val, modelspecs):
    if type(val) is list:
        # ie, if jackknifing
        new_est = [ms.evaluate(d, m) for m, d in zip(modelspecs, est)]
        new_val = [ms.evaluate(d, m) for m, d in zip(modelspecs, val)]
        new_val = [recording.jackknife_inverse_merge(new_val)]
    else:
        # Evaluate estimation and validation data
        new_est = [ms.evaluate(est, m) for m in modelspecs]
        new_val = [ms.evaluate(val, m) for m in modelspecs]

    return new_est, new_val

def _reduced_param_pred(mspec, rec, idx, param):
    gc_ms_no_param = mspec.copy()
    gc_ms_no_param[idx]['phi']['%s_mod' % param] = \
        gc_ms_no_param[idx]['phi']['%s' % param].copy()
    pred_no_param = ms.evaluate(rec, gc_ms_no_param)['pred'].as_continuous().T
    return pred_no_param

def nl_scatter(rec, modelspec, idx, sig_name='pred',
               compare='resp', smoothing_bins=False, cursor_time=None,
               xlabel1=None, ylabel1=None, **options):
    # HACK: shouldn't hardcode 'stim', might be named something else
    # or not present at all. Need to figure out a better solution
    # for special case of idx = 0
    if 'mask' in rec.signals.keys():
        before = rec.apply_mask()
    else:
        before = rec.copy()

    if idx == 0:
        # Can't have anything before index 0, so use input stimulus
        sig_name = 'stim'
        before_sig = before['stim']
        before.name = '**stim'
    else:
        before = ms.evaluate(before, modelspec, start=None, stop=idx)
        before_sig = before[sig_name]

    # compute correlation for pre-module before it's over-written
    if before[sig_name].shape[0] == 1:
        corr1 = nm.corrcoef(before, pred_name=sig_name, resp_name=compare)
    else:
        corr1 = 0
        log.warning('corr coef expects single-dim predictions')
    compare_to = before[compare]

    module = modelspec[idx]
    mod_name = module['fn'].replace('nems.modules.', '') \
                           .replace('.', ' ').replace('_', ' ').title()

    title1 = mod_name
    text1 = "r = {0:.5f}".format(np.mean(corr1))

    ax = plot_scatter(before_sig, compare_to, title=title1,
                      smoothing_bins=smoothing_bins, xlabel=xlabel1,
                      ylabel=ylabel1, text=text1, module=module, **options)

    if cursor_time is not None:
        tbin = int(cursor_time * rec[sig_name].fs)
        x = before_sig.as_continuous()[0, tbin]
        ylim = ax.get_ylim()
        ax.plot([x, x], ylim, 'r-')

def model_per_time(ctx):
    """
    state_colors : N x 2 list
        color spec for high/low lines in each of the N states
    """
    rec = ctx['val'][0].apply_mask()
    modelspec = ctx['modelspecs'][0]
    epoch = "REFERENCE"
    rec = ms.evaluate(rec, modelspec)

    plt.figure()
    ax = plt.subplot(2, 1, 1)
    state_vars_timeseries(rec, modelspec, ax=ax)

    ax = plt.subplot(2, 1, 2)
    state_vars_psth_all(rec, epoch, psth_name='resp', psth_name2='pred',
                        state_sig='state_raw', colors=None,
                        channel=None, decimate_by=1, ax=ax,
                        files_only=True, modelspec=modelspec)

def stp_sigmoid_pred_matched(cellid, batch, modelname, LN, include_phi=True):
    xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    ln_spec, ln_ctx = xhelp.load_model_xform(cellid, batch, LN)
    modelspec = ctx['modelspec']
    modelspec.recording = ctx['val']
    val = ctx['val'].apply_mask()
    ln_modelspec = ln_ctx['modelspec']
    ln_modelspec.recording = ln_ctx['val']
    ln_val = ln_ctx['val'].apply_mask()

    pred_after_NL = val['pred'].as_continuous().flatten()  # with stp
    val_before_NL = ms.evaluate(ln_val, ln_modelspec, stop=-1)
    pred_before_NL = val_before_NL['pred'].as_continuous().flatten()  # no stp

    stp_idx = find_module('stp', modelspec)
    val_before_stp = ms.evaluate(val, modelspec, stop=stp_idx)
    val_after_stp = ms.evaluate(val, modelspec, stop=stp_idx + 1)
    pred_before_stp = val_before_stp['pred'].as_continuous().mean(
        axis=0).flatten()
    pred_after_stp = val_after_stp['pred'].as_continuous().mean(
        axis=0).flatten()
    stp_effect = (pred_after_stp - pred_before_stp) \
                 / (pred_after_stp + pred_before_stp)

    fig = plt.figure()
    plasma = plt.get_cmap('plasma')
    plt.scatter(pred_before_NL, pred_after_NL, c=stp_effect, s=2,
                alpha=0.75, cmap=plasma)
    plt.title(cellid)
    plt.xlabel('pred in (no stp)')
    plt.ylabel('pred out (with stp)')
    if include_phi:
        stp_phi = modelspec.phi[stp_idx]
        phi_string = '\n'.join(
            ['%s: %.4E' % (k, v) for k, v in stp_phi.items()])
        fig.text(0.775, 0.9, phi_string, va='top', ha='left')
        plt.subplots_adjust(right=0.775, left=0.075)
    plt.colorbar()

    return fig

def evaluate_context(ctx, rec_key='val', rec_idx=0, mspec_idx=0,
                     start=None, stop=None):
    # use rec_idx rather than a hardcoded [0], so the argument has an effect
    rec = ctx[rec_key][rec_idx]
    mspec = ctx['modelspecs'][mspec_idx]
    return ms.evaluate(rec, mspec, start=start, stop=stop)

def reconsitute_rec(batch, cellid_list, modelname):
    '''
    Takes a group of single-cell recordings (from cells of a population
    recording), including their model predictions, and builds a recording
    with signals containing the responses and predictions of all the cells
    in the population. This makes the recordings compatible with downstream
    dispersion analysis, or any analysis working with signals of neuronal
    populations.
    :param batch: int batch number
    :param cellid_list: [str, str ...] list of cell IDs
    :param modelname: str modelname
    :return: NEMS Recording object
    '''
    result_paths = _get_result_paths(batch, cellid_list, modelname)

    cell_resp_dict = dict()
    cell_pred_dict = col.defaultdict()

    for ff, filepath in enumerate(result_paths):
        # use modelspecs to predict the response of resp
        xfspec, ctx = xforms.load_analysis(filepath=filepath,
                                           eval_model=False,
                                           only=slice(0, 2, 1))
        modelspecs = ctx['modelspecs'][0]
        cellid = modelspecs[0]['meta']['cellid']
        real_modelname = modelspecs[0]['meta']['modelname']
        rec = ctx['rec'].copy()
        # recording containing signals for resp and pred
        rec = ms.evaluate(rec, modelspecs)

        # holds and organizes the raw data, keeping track of the cell for
        # later concatenations.
        # in PointProcess signals _data is already a dict, thus the use
        # of update
        cell_resp_dict.update(rec['resp']._data)
        # in Rasterized signals _data is a matrix, thus the need to assign
        # to a key
        cell_pred_dict[cellid] = rec['pred']._data

    # create a new population recording: pull stim from the last single cell,
    # create signals from the meta of the last resp signal and the stacked
    # data for all cells. modify signal metadata to be consistent with the
    # new data and cells contained
    pop_resp = rec['resp']._modified_copy(
        data=cell_resp_dict,
        chans=list(cell_resp_dict.keys()),
        nchans=len(list(cell_resp_dict.keys())))

    stack_data = np.concatenate(list(cell_pred_dict.values()), axis=0)
    pop_pred = rec['pred']._modified_copy(
        data=stack_data,
        chans=list(cell_pred_dict.keys()),
        nchans=len(list(cell_pred_dict.keys())))

    reconstituted_recording = rec.copy()

    reconstituted_recording['resp'] = pop_resp
    reconstituted_recording['pred'] = pop_pred
    del reconstituted_recording.signals['state']
    del reconstituted_recording.signals['state_raw']

    return reconstituted_recording

def dynamic_sigmoid_range(cellid, batch, modelname, plot=True,
                          show_min_max=True):
    xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    modelspec = ctx['modelspec']
    modelspec.recording = ctx['val']
    val = ctx['val'].apply_mask()
    ctpred = val['ctpred'].as_continuous().flatten()
    val_before_dsig = ms.evaluate(val, modelspec, stop=-1)
    pred_before_dsig = val_before_dsig['pred'].as_continuous().flatten()

    lows = {k: v for k, v in modelspec[-1]['phi'].items()
            if '_mod' not in k}
    highs = {k[:-4]: v for k, v in modelspec[-1]['phi'].items()
             if '_mod' in k}
    for k in lows:
        if k not in highs:
            highs[k] = lows[k]
    # re-sort keys to make sure they're in the same order
    lows = {k: lows[k] for k in sorted(lows)}
    highs = {k: highs[k] for k in sorted(highs)}

    ctmax_val = np.max(ctpred)
    ctmin_val = np.min(ctpred)
    ctmax_idx = np.argmax(ctpred)
    ctmin_idx = np.argmin(ctpred)

    thetas = list(lows.values())
    theta_mods = list(highs.values())
    for t, t_m, k in zip(thetas, theta_mods, list(lows.keys())):
        lows[k] = t + (t_m - t) * ctmin_val
        highs[k] = t + (t_m - t) * ctmax_val

    low_out = _double_exponential(pred_before_dsig, **lows).flatten()
    high_out = _double_exponential(pred_before_dsig, **highs).flatten()

    if plot:
        fig = plt.figure()
        plt.scatter(pred_before_dsig, low_out, color='blue', s=0.7,
                    alpha=0.6)
        plt.scatter(pred_before_dsig, high_out, color='red', s=0.7,
                    alpha=0.6)
        if show_min_max:
            max_pred = pred_before_dsig[ctmax_idx]
            min_pred = pred_before_dsig[ctmin_idx]
            plt.scatter(min_pred, low_out[ctmin_idx], facecolors=None,
                        edgecolors='blue', s=60)
            plt.scatter(max_pred, high_out[ctmax_idx], facecolors=None,
                        edgecolors='red', s=60)

    return pred_before_dsig, low_out, high_out

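# _double_exponential is referenced above but not defined in this file.
# For reference, a sketch of the double-exponential nonlinearity matching
# the form used in NEMS (an assumption based on the phi names above, with
# kappa stored in log space):
import numpy as np

def _double_exponential_sketch(x, base, amplitude, shift, kappa):
    # rises from 'base' toward 'base + amplitude', centered at 'shift',
    # with slope controlled by exp(kappa)
    return base + amplitude * np.exp(-np.exp(-np.exp(kappa) * (x - shift)))
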
def dynamic_sigmoid_differences(batch, modelname, hist_bins=60,
                                test_limit=None, save_path=None,
                                load_path=None, use_quartiles=False,
                                avg_bin_count=20):
    if load_path is None:
        cellids = nd.get_batch_cells(batch, as_list=True)
        ratios = []
        for cellid in cellids[:test_limit]:
            xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
            val = ctx['val'].apply_mask()
            ctpred = val['ctpred'].as_continuous().flatten()
            pred_after = val['pred'].as_continuous().flatten()
            val_before = ms.evaluate(val, ctx['modelspec'], stop=-1)
            pred_before = val_before['pred'].as_continuous().flatten()
            median_ct = np.nanmedian(ctpred)
            if use_quartiles:
                low = np.percentile(ctpred, 25)
                high = np.percentile(ctpred, 75)
                low_mask = (ctpred >= low) & (ctpred < median_ct)
                high_mask = ctpred >= high
            else:
                low_mask = ctpred < median_ct
                high_mask = ctpred >= median_ct

            # TODO: do some kind of binning here since the two vectors
            #       don't actually overlap in x axis
            mean_before, bin_masks = _binned_xvar(pred_before, avg_bin_count)
            low = _binned_yavg(pred_after, low_mask, bin_masks)
            high = _binned_yavg(pred_after, high_mask, bin_masks)

            ratio = np.nanmean((low - high) / (np.abs(low) + np.abs(high)))
            ratios.append(ratio)

        ratios = np.array(ratios)
        if save_path is not None:
            np.save(save_path, ratios)
    else:
        ratios = np.load(load_path)

    plt.figure()
    plt.hist(ratios, bins=hist_bins, color=[wsu_gray_light],
             edgecolor='black', linewidth=1)
    #plt.rc('text', usetex=True)
    #plt.xlabel(r'\texit{\frac{low-high}{\left|high\right|+\left|low\right|}}')
    plt.xlabel('(low - high)/(|low| + |high|)')
    plt.ylabel('cell count')
    plt.title("difference of low-contrast output and high-contrast output\n"
              "positive means low-contrast has higher firing rate on average")

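# _binned_xvar and _binned_yavg are not defined in this file. The following
# are hypothetical reconstructions (assumptions, not the originals) of what
# they might do: split the x variable into equal-width bins, then average y
# within each bin, restricted to an additional mask.
import numpy as np

def _binned_xvar_sketch(x, bin_count):
    edges = np.linspace(np.nanmin(x), np.nanmax(x), bin_count + 1)
    bin_masks = [(x >= lo) & (x <= hi)
                 for lo, hi in zip(edges[:-1], edges[1:])]
    means = np.array([np.nanmean(x[m]) if m.any() else np.nan
                      for m in bin_masks])
    return means, bin_masks

def _binned_yavg_sketch(y, submask, bin_masks):
    return np.array([np.nanmean(y[m & submask]) if (m & submask).any()
                     else np.nan for m in bin_masks])
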
def before_and_after_signal(rec, modelspec, idx, sig_name='pred'):
    # HACK: shouldn't hardcode 'stim', might be named something else
    # or not present at all. Need to figure out a better solution
    # for special case of idx = 0
    if idx == 0:
        # Can't have anything before index 0, so use input stimulus
        before = rec
        before_sig = copy.deepcopy(rec['stim'])
    else:
        before = ms.evaluate(rec, modelspec, start=None, stop=idx)
        before_sig = copy.deepcopy(before[sig_name])

    before_sig.name = 'before'

    after = ms.evaluate(before.copy(), modelspec, start=idx, stop=idx + 1)
    after_sig = copy.deepcopy(after[sig_name])
    after_sig.name = 'after'

    return before_sig, after_sig

def generate_prediction(est, val, modelspecs):
    list_val = False
    if type(val) is list:
        # ie, if jackknifing
        list_val = True
    else:
        # Evaluate estimation and validation data
        # Since ms.evaluate only does a shallow copy of rec, successive
        # evaluations of one rec on many modelspecs just results in a list of
        # different pointers to the same recording. So need to force copies of
        # est/val before evaluating.
        if len(modelspecs) == 1:
            # no copies needed for 1 modelspec
            est = [est]
            val = [val]
        else:
            est = [est.copy() for i, _ in enumerate(modelspecs)]
            val = [val.copy() for i, _ in enumerate(modelspecs)]

    new_est = []
    new_val = []
    for m, e, v in zip(modelspecs, est, val):
        e = ms.evaluate(e, m)
        v = ms.evaluate(v, m)
        if 'mask' in v.signals.keys():
            # nan-out periods outside of mask
            # (mask renamed from m in the original to avoid shadowing the
            # modelspec loop variable)
            mask = v['mask'].as_continuous()
            x = v['pred'].as_continuous().copy()
            x[..., mask[0, :] == 0] = np.nan
            v['pred'] = v['pred']._modified_copy(x)
        new_est.append(e)
        new_val.append(v)
    #new_est = [ms.evaluate(d, m) for m, d in zip(modelspecs, est)]
    #new_val = [ms.evaluate(d, m) for m, d in zip(modelspecs, val)]

    if list_val:
        new_val = [recording.jackknife_inverse_merge(new_val)]

    return new_est, new_val

def dynamic_sigmoid_distribution(cellid, batch, modelname, sample_every=10,
                                 alpha=0.1):
    xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    modelspec = ctx['modelspec']
    val = ctx['val'].apply_mask()
    modelspec.recording = val
    val_before_dsig = ms.evaluate(val, modelspec, stop=-1)
    pred_before_dsig = val_before_dsig['pred'].as_continuous().T
    ctpred = val_before_dsig['ctpred'].as_continuous().T

    lows = {k: v for k, v in modelspec[-1]['phi'].items()
            if '_mod' not in k}
    highs = {k[:-4]: v for k, v in modelspec[-1]['phi'].items()
             if '_mod' in k}
    for k in lows:
        if k not in highs:
            highs[k] = lows[k]
    # re-sort keys to make sure they're in the same order
    lows = {k: lows[k] for k in sorted(lows)}
    highs = {k: highs[k] for k in sorted(highs)}
    thetas = list(lows.values())
    theta_mods = list(highs.values())

    fig = plt.figure()
    for i in range(int(len(pred_before_dsig) / sample_every)):
        try:
            ts = {}
            for t, t_m, k in zip(thetas, theta_mods, list(lows.keys())):
                ts[k] = t + (t_m - t) * ctpred[i * sample_every]
            y = _double_exponential(pred_before_dsig, **ts)
            plt.scatter(pred_before_dsig, y, color='black', alpha=alpha,
                        s=0.01)
        except IndexError:
            # Will happen on last attempt if array wasn't evenly divisible
            # by sample_every
            pass

    t_max = {}
    t_min = {}
    for t, t_m, k in zip(thetas, theta_mods, list(lows.keys())):
        t_max[k] = t + (t_m - t) * np.nanmax(ctpred)
        t_min[k] = t + (t_m - t) * np.nanmin(ctpred)
    max_out = _double_exponential(pred_before_dsig, **t_max)
    min_out = _double_exponential(pred_before_dsig, **t_min)
    plt.scatter(pred_before_dsig, max_out, color='red', s=0.1)
    plt.scatter(pred_before_dsig, min_out, color='blue', s=0.1)

def mod_output(rec, modelspec, sig_name='pred', ax=None, title=None,
               idx=0, channels=0, xlabel='Time', ylabel='Value', **options):
    '''
    Plots a time series of the specified signal(s) as output by step idx
    in the modelspec.

    Arguments:
    ----------
    rec : recording object
        The dataset to use. See nems/recording.py.

    modelspec : list of dicts
        The transformations to perform. See nems/modelspec.py.

    sig_name : str or list of strings
        Specifies the signal(s) in 'rec' to be examined.

    idx : int
        An index into the modelspec. rec[sig_name] will be plotted
        as it exists after step idx.

    Returns:
    --------
    ax : axis containing plot
    '''
    if type(sig_name) is str:
        sig_name = [sig_name]

    trec = ms.evaluate(rec, modelspec, stop=idx + 1)
    if 'mask' in trec.signals.keys():
        trec = trec.apply_mask()

    sigs = [trec[s] for s in sig_name]
    ax = timeseries_from_signals(sigs, channels=channels,
                                 xlabel=xlabel, ylabel=ylabel, ax=ax,
                                 title=title, **options)
    return ax

def init_dexp(rec, modelspec):
    """
    Choose initial values for dexp applied after the preceding fir
    is initialized.
    """
    target_i = None
    target_module = 'double_exponential'
    for i, m in enumerate(modelspec):
        if target_module in m['fn']:
            target_i = i
            break

    # test against None explicitly, so a match at index 0 isn't treated
    # as "not found"
    if target_i is None:
        log.info("target_module: {} not found in modelspec."
                 .format(target_module))
        return modelspec
    else:
        log.info("target_module: {0} found at modelspec[{1}]."
                 .format(target_module, target_i))

    if target_i == len(modelspec):
        fit_portion = modelspec
    else:
        fit_portion = modelspec[:target_i]

    # generate prediction from the modules preceding dexp
    rec = ms.evaluate(rec, fit_portion)

    resp = rec['resp'].as_continuous()
    pred = rec['pred'].as_continuous()
    keepidx = np.isfinite(resp) * np.isfinite(pred)
    resp = resp[keepidx]
    pred = pred[keepidx]

    # choose phi s.t. dexp starts as almost a straight line
    # phi=[max_out min_out slope mean_in]
    meanr = np.nanmean(resp)
    stdr = np.nanstd(resp)
    modelspec[target_i]['phi'] = {}
    modelspec[target_i]['phi']['amplitude'] = stdr * 8
    modelspec[target_i]['phi']['base'] = meanr - stdr * 4
    modelspec[target_i]['phi']['kappa'] = np.log(np.std(pred) / 10)
    modelspec[target_i]['phi']['shift'] = np.mean(pred)
    log.info(modelspec[target_i])

    return modelspec

def dynamic_sigmoid_pred_matched(cellid, batch, modelname, include_phi=True):
    xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    modelspec = ctx['modelspec']
    modelspec.recording = ctx['val']
    val = ctx['val'].apply_mask()
    ctpred = val['ctpred'].as_continuous().flatten()
    # HACK
    # this really shouldn't happen.. but for some reason some of the
    # batch 263 cells are getting nans, so temporary fix.
    ctpred[np.isnan(ctpred)] = 0

    pred_after_dsig = val['pred'].as_continuous().flatten()
    val_before_dsig = ms.evaluate(val, modelspec, stop=-1)
    pred_before_dsig = val_before_dsig['pred'].as_continuous().flatten()

    fig = plt.figure(figsize=(12, 7))
    plasma = plt.get_cmap('plasma')
    plt.scatter(pred_before_dsig, pred_after_dsig, c=ctpred, s=2,
                alpha=0.75, cmap=plasma)
    plt.title(cellid)
    plt.xlabel('pred in')
    plt.ylabel('pred out')

    if include_phi:
        dsig_phi = modelspec.phi[-1]
        phi_string = '\n'.join(
            ['%s: %.4E' % (k, v) for k, v in dsig_phi.items()])
        thetas = list(dsig_phi.keys())[0:-1:2]
        mods = list(dsig_phi.keys())[1::2]
        weights = {k: (dsig_phi[mods[i]] - dsig_phi[thetas[i]])
                   for i, k in enumerate(thetas)}
        weights_string = 'weights:\n' + '\n'.join(
            ['%s: %.4E' % (k, v) for k, v in weights.items()])
        fig.text(0.775, 0.9, phi_string, va='top', ha='left')
        fig.text(0.775, 0.1, weights_string, va='bottom', ha='left')
        plt.subplots_adjust(right=0.775, left=0.075)
    plt.colorbar()

    return fig

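# The weights computation above assumes the dsig_phi keys alternate
# (theta, theta_mod, theta, theta_mod, ...). A quick demonstration of the
# interleaved slicing on a stand-in dict:
_phi = {'base': 0.1, 'base_mod': 0.3, 'kappa': 1.0, 'kappa_mod': 0.5}
_keys = list(_phi.keys())
_thetas, _mods = _keys[0:-1:2], _keys[1::2]
_weights = {k: _phi[_mods[i]] - _phi[_thetas[i]]
            for i, k in enumerate(_thetas)}
print(_weights)  # ~ {'base': 0.2, 'kappa': -0.5}
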
def contrast_kernel_output(rec, modelspec, ax=None, title=None,
                           idx=0, channels=0, xlabel='Time', ylabel='Value',
                           **options):
    output = ms.evaluate(rec, modelspec, stop=idx + 1)['ctpred']
    # capture the returned axis; otherwise the function returns the (possibly
    # None) ax argument instead of the axis that was actually drawn on
    ax = timeseries_from_signals([output], channels=channels,
                                 xlabel=xlabel, ylabel=ylabel, ax=ax,
                                 title=title)
    return ax

def single_state_mod_index(rec, modelspec, epoch='REFERENCE',
                           psth_name='pred', state_sig='state',
                           state_chan='pupil'):
    if type(state_chan) is list:
        if len(state_chan) == 0:
            state_chan = rec[state_sig].chans

        mod_list = [single_state_mod_index(rec, modelspec, epoch=epoch,
                                           psth_name=psth_name,
                                           state_sig=state_sig,
                                           state_chan=s)
                    for s in state_chan]
        return mod_list

    sidx = find_module('state', modelspec)
    if sidx is None:
        raise ValueError("no state signal found")

    modelspec = copy.deepcopy(modelspec)

    state_chan_idx = rec[state_sig].chans.index(state_chan)

    # zero out the gain and offset of every state channel except the
    # baseline (chan 0) and the channel of interest, then re-evaluate
    k = np.ones(rec[state_sig].shape[0], dtype=bool)
    k[0] = False
    k[state_chan_idx] = False
    modelspec[sidx]['phi']['d'][:, k] = 0
    modelspec[sidx]['phi']['g'][:, k] = 0

    newrec = ms.evaluate(rec, modelspec)

    return np.array(state_mod_index(newrec, epoch=epoch,
                                    psth_name=psth_name,
                                    state_sig=state_sig,
                                    state_chan=state_chan))

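# Numeric illustration of the channel-masking step above: with 4 state
# channels and state_chan_idx = 2, only the baseline column (0) and the
# queried column (2) survive; gain/offset for all other channels are zeroed.
import numpy as np
_k = np.ones(4, dtype=bool)
_k[0] = False
_k[2] = False
_g = np.arange(4, dtype=float)[np.newaxis, :]  # stand-in for phi['g']
_g[:, _k] = 0
print(_g)  # [[0. 0. 2. 0.]]
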
# 4. Now add the signal to the recording
rec.add_signal(respavg)

# Now split into est and val data sets
est, val = rec.split_using_epoch_occurrence_counts(epoch_regex='^STIM_')
# est, val = rec.split_at_time(0.8)

# Load some modelspecs and create their predictions
modelspecs = ms.load_modelspecs(modelspecs_dir, 'TAR010c-18-1')
# , regex=('^TAR010c-18-1\.{\d+}\.json'))

# Testing summary statistics:
means, stds = ms.summary_stats(modelspecs)
print("means: {}".format(means))
print("stds: {}".format(stds))

pred = [ms.evaluate(val, m)['pred'] for m in modelspecs]

# Shorthands for unchanging signals
stim = val['stim']
resp = val['resp']
respavg = val['respavg']


def plot_layout(plot_fn_struct):
    '''
    Accepts a list of lists of functions of 1 argument (ax).
    Basically a fancy subplot that lets you lay out functions
    without worrying about details. See example below.
    '''
    # Count how many plot functions we want
    nrows = len(plot_fn_struct)

def _get_pred_resp(rec, modelspec, pred_name, resp_name):
    """Evaluates rec using a fitted modelspec, then returns pred and resp
    signals."""
    new_rec = ms.evaluate(rec, modelspec)
    return [new_rec[pred_name], new_rec[resp_name]]

def init_logsig(rec, modelspec):
    '''
    Initialization of priors for logistic_sigmoid, based on the process
    described in the methods of Rabinowitz et al. 2014.
    '''
    # preserve input modelspec
    modelspec = copy.deepcopy(modelspec)

    target_i = find_module('logistic_sigmoid', modelspec)
    if target_i is None:
        log.warning("No logsig module was found, can't initialize.")
        return modelspec

    if target_i == len(modelspec):
        fit_portion = modelspec
    else:
        fit_portion = modelspec[:target_i]

    # generate prediction from the modules preceding logsig
    ms.fit_mode_on(fit_portion)
    rec = ms.evaluate(rec, fit_portion)
    ms.fit_mode_off(fit_portion)

    pred = rec['pred'].as_continuous()
    resp = rec['resp'].as_continuous()

    mean_pred = np.nanmean(pred)
    min_pred = np.nanmean(pred) - np.nanstd(pred) * 3
    max_pred = np.nanmean(pred) + np.nanstd(pred) * 3
    if min_pred < 0:
        min_pred = 0
        mean_pred = (min_pred + max_pred) / 2

    pred_range = max_pred - min_pred
    min_resp = max(np.nanmean(resp) - np.nanstd(resp) * 3, 0)  # must be >= 0
    max_resp = np.nanmean(resp) + np.nanstd(resp) * 3
    resp_range = max_resp - min_resp

    # Rather than setting a hard value for initial phi,
    # set the prior distributions and let the fitter/analysis
    # decide how to use it.
    base0 = min_resp + 0.05 * (resp_range)
    amplitude0 = resp_range
    shift0 = mean_pred
    kappa0 = pred_range
    log.info("Initial base,amplitude,shift,kappa=({}, {}, {}, {})"
             .format(base0, amplitude0, shift0, kappa0))

    base = ('Exponential', {'beta': base0})
    amplitude = ('Exponential', {'beta': amplitude0})
    shift = ('Normal', {'mean': shift0, 'sd': pred_range})
    kappa = ('Exponential', {'beta': kappa0})

    modelspec[target_i]['prior'].update({
        'base': base, 'amplitude': amplitude,
        'shift': shift, 'kappa': kappa})

    modelspec[target_i]['bounds'] = {
        'base': (1e-15, None),
        'amplitude': (1e-15, None),
        'shift': (None, None),
        'kappa': (1e-15, None)}

    return modelspec

def init_dexp(rec, modelspec):
    """
    Choose initial values for dexp applied after the preceding fir
    is initialized.
    """
    # preserve input modelspec
    modelspec = copy.deepcopy(modelspec)

    target_i = find_module('double_exponential', modelspec)
    if target_i is None:
        log.warning("No dexp module was found, can't initialize.")
        return modelspec

    if target_i == len(modelspec):
        fit_portion = modelspec
    else:
        fit_portion = modelspec[:target_i]

    # ensure all previous modules have their phi initialized;
    # choose prior mean if not found
    for i, m in enumerate(fit_portion):
        if ('phi' not in m.keys()) and ('prior' in m.keys()):
            log.debug('Phi not found for module, using mean of prior: %s', m)
            m = priors.set_mean_phi([m])[0]  # Inits phi for 1 module
            fit_portion[i] = m

    # generate prediction from the modules preceding dexp
    ms.fit_mode_on(fit_portion)
    rec = ms.evaluate(rec, fit_portion)
    ms.fit_mode_off(fit_portion)

    in_signal = modelspec[target_i]['fn_kwargs']['i']
    pchans = rec[in_signal].shape[0]
    amp = np.zeros([pchans, 1])
    base = np.zeros([pchans, 1])
    kappa = np.zeros([pchans, 1])
    shift = np.zeros([pchans, 1])

    for i in range(pchans):
        resp = rec['resp'].as_continuous()
        pred = rec[in_signal].as_continuous()[i:(i + 1), :]
        if resp.shape[0] == pchans:
            resp = resp[i:(i + 1), :]

        keepidx = np.isfinite(resp) * np.isfinite(pred)
        resp = resp[keepidx]
        pred = pred[keepidx]

        # choose phi s.t. dexp starts as almost a straight line
        # phi=[max_out min_out slope mean_in]
        # meanr = np.nanmean(resp)
        stdr = np.nanstd(resp)

        # base = np.max(np.array([meanr - stdr * 4, 0]))
        base[i, 0] = np.min(resp)
        # base = meanr - stdr * 3

        # amp = np.max(resp) - np.min(resp)
        amp[i, 0] = stdr * 3

        shift[i, 0] = np.mean(pred)
        # shift = (np.max(pred) + np.min(pred)) / 2

        predrange = 2 / (np.max(pred) - np.min(pred) + 1)
        kappa[i, 0] = np.log(predrange)

    modelspec[target_i]['phi'] = {'amplitude': amp, 'base': base,
                                  'kappa': kappa, 'shift': shift}
    log.info("Init dexp: %s", modelspec[target_i]['phi'])

    return modelspec

def fit_population_channel_fast2(rec, modelspec, fit_set_all, fit_set_slice,
                                 analysis_function=analysis.fit_basic,
                                 metric=metrics.nmse,
                                 fitter=scipy_minimize, fit_kwargs={}):

    # guess at number of subspace dimensions
    dim_count = modelspec[fit_set_slice[0]]['phi']['coefficients'].shape[1]

    wi = [i for i in fit_set_slice if 'weight_channels' in modelspec[i]['fn']]
    wi = wi[0]
    li = [i for i in fit_set_slice if 'levelshift' in modelspec[i]['fn']]
    li = li[0]

    for d in range(dim_count):
        # fit each dim separately
        log.info('Updating dim %d/%d', d + 1, dim_count)

        # create modelspec with single population subspace filter
        tmodelspec = _extract_pop_channel(modelspec, d, fit_set_all)

        # temp append full-population layer as non-free parameters
        tmodelspec2 = copy.deepcopy(tmodelspec)
        for i in fit_set_slice:
            m = copy.deepcopy(modelspec[i])
            for k, v in m['phi'].items():
                # just applies to wc module?
                if v.shape[1] >= dim_count:
                    m['phi'][k] = v[:, [d]]
                else:
                    m['phi'][k] = v
            tmodelspec2.append(m)

        # compute residual from prediction by the rest of the pop model
        trec = rec.copy()
        trec = ms.evaluate(trec, modelspec)
        r = trec['resp'].as_continuous()
        p = trec['pred'].as_continuous().copy()
        respstd = np.nanstd(r)  # std of actual response
        trec = ms.evaluate(trec, tmodelspec2)
        p2 = trec['pred'].as_continuous()
        trec = ms.evaluate(trec, tmodelspec)

        # residual we're trying to predict with tmodelspec
        r = r - p + p2

        # calculate streamlined nMSE function for single pop channel model
        # by inverting neuron-specific gains and level shifts
        a = modelspec[wi]['phi']['coefficients'][:, [d]]
        b = modelspec[li]['phi']['level']
        r -= b  # subtract level shift from residual

        A1 = np.sum(a ** 2)
        A2 = np.sum(2 * a * r, axis=0, keepdims=True)
        A3 = np.sum(r ** 2, axis=0, keepdims=True)

        def my_nmse(result):
            '''
            hacked from nems.metrics.mse.nmse. optimized nMSE for situation
            when a single population channel is predicting responses with
            fixed per-neuron gains and levelshifts

            A1, A2, A3, respstd defined outside of function

            result : recording object updated by fitter, prediction response
            of single pop channel
            '''
            X1 = result['pred'].as_continuous()
            squared_errors = A1 * (X1 ** 2) - A2 * X1 + A3
            mean_sq_err = np.sum(squared_errors) / (r.shape[0] * r.shape[1])
            mse = np.sqrt(mean_sq_err)
            return mse / respstd

        # import pdb
        # pdb.set_trace()
        tmodelspec = analysis_function(trec, tmodelspec, fitter=fitter,
                                       metric=my_nmse, fit_kwargs=fit_kwargs)

        modelspec = _update_pop_channel(tmodelspec, modelspec, d, fit_set_all)

    return modelspec

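# Numeric sanity check of the A1/A2/A3 shortcut used in my_nmse above: for
# per-neuron gains a and residuals r, sum_j (a_j*x_t - r_jt)**2 expands to
# A1*x_t**2 - A2_t*x_t + A3_t, so the error can be evaluated without
# re-broadcasting a against the full residual matrix on every fit iteration.
import numpy as np

_rng = np.random.default_rng(0)
_a = _rng.normal(size=(5, 1))    # per-neuron gains
_r = _rng.normal(size=(5, 100))  # per-neuron residuals
_x = _rng.normal(size=(1, 100))  # single pop-channel prediction

_A1 = np.sum(_a ** 2)
_A2 = np.sum(2 * _a * _r, axis=0, keepdims=True)
_A3 = np.sum(_r ** 2, axis=0, keepdims=True)

_direct = np.sum((_a * _x - _r) ** 2)
_shortcut = np.sum(_A1 * (_x ** 2) - _A2 * _x + _A3)
assert np.allclose(_direct, _shortcut)
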
def fit_population_channel_fast(rec, modelspec, fit_set_all, fit_set_slice,
                                analysis_function=analysis.fit_basic,
                                metric=metrics.nmse,
                                fitter=scipy_minimize, fit_kwargs={}):

    # guess at number of subspace dimensions
    dim_count = modelspec[fit_set_slice[0]]['phi']['coefficients'].shape[1]

    for d in range(dim_count):
        # fit each dim separately
        log.info('Updating dim %d/%d', d + 1, dim_count)

        # create modelspec with single population subspace filter
        tmodelspec = ms.ModelSpec()
        for i in fit_set_all:
            m = copy.deepcopy(modelspec[i])
            for k, v in m['phi'].items():
                x = v.shape[0]
                if x >= dim_count:
                    x1 = int(x / dim_count) * d
                    x2 = int(x / dim_count) * (d + 1)
                    m['phi'][k] = v[x1:x2]
                    if 'bank_count' in m['fn_kwargs'].keys():
                        m['fn_kwargs']['bank_count'] = 1
                else:
                    # single model-wide parameter, only fit for d==0
                    if d == 0:
                        m['phi'][k] = v
                    else:
                        m['fn_kwargs'][k] = v  # keep fixed for d>0
                        del m['phi']
                        del m['prior']
            tmodelspec.append(m)

        # append full-population layer as non-free parameters
        for i in fit_set_slice:
            m = copy.deepcopy(modelspec[i])
            for k, v in m['phi'].items():
                # just applies to wc module?
                if v.shape[1] >= dim_count:
                    m['fn_kwargs'][k] = v[:, [d]]
                else:
                    m['fn_kwargs'][k] = v
                # print(k)
                # print(m['fn_kwargs'][k])
            del m['phi']
            del m['prior']
            tmodelspec.append(m)
        # print(tmodelspec[-1]['fn_kwargs'])

        # compute residual from prediction by the rest of the pop model
        trec = rec.copy()
        trec = ms.evaluate(trec, modelspec)
        r = trec['resp'].as_continuous()
        p = trec['pred'].as_continuous()
        trec = ms.evaluate(trec, tmodelspec)
        p2 = trec['pred'].as_continuous()
        trec['resp'] = trec['resp']._modified_copy(data=r - p + p2)

        # import pdb;
        # pdb.set_trace()
        tmodelspec = analysis_function(trec, tmodelspec, fitter=fitter,
                                       metric=metric, fit_kwargs=fit_kwargs)

        for i in fit_set_all:
            for k, v in tmodelspec[i]['phi'].items():
                x = modelspec[i]['phi'][k].shape[0]
                if x >= dim_count:
                    x1 = int(x / dim_count) * d
                    x2 = int(x / dim_count) * (d + 1)
                    modelspec[i]['phi'][k][x1:x2] = v
                else:
                    modelspec[i]['phi'][k] = v

        # for i in fit_set_slice:
        #     for k, v in tmodelspec[i]['phi'].items():
        #         if modelspec[i]['phi'][k].shape[0] >= dim_count:
        #             modelspec[i]['phi'][k][d, :] = v

    # print([modelspec.phi[f] for f in fit_set_all])
    return modelspec

def fit_population_slice(rec, modelspec, slice=0, fit_set=None,
                         analysis_function=analysis.fit_basic,
                         metric=metrics.nmse,
                         fitter=scipy_minimize, fit_kwargs={}):
    """
    Fits a slice of a population model. Modified from prefit_mod_subset.

    slice: int
        response channel to fit
    fit_set: list
        list of mod names to fit
    """

    # preserve input modelspec
    modelspec = copy.deepcopy(modelspec)

    slice_count = rec['resp'].shape[0]
    if slice > slice_count:
        raise ValueError("Slice %d > slice_count %d", slice, slice_count)

    if fit_set is None:
        raise ValueError("fit_set list of module indices must be specified")

    if type(fit_set[0]) is int:
        fit_idx = fit_set
    else:
        fit_idx = []
        for i, m in enumerate(modelspec):
            for fn in fit_set:
                if fn in m['fn']:
                    fit_idx.append(i)

    # identify any excluded modules and take them out of temp modelspec
    # that will be fit here
    tmodelspec = ms.ModelSpec()
    sliceinfo = []
    for i, m in enumerate(modelspec):
        m = copy.deepcopy(m)
        # need to have phi in place
        if not m.get('phi'):
            log.info('Initializing phi for module %d (%s)', i, m['fn'])
            m = priors.set_mean_phi([m])[0]  # Inits phi

        if i in fit_idx:
            s = {}
            for key, value in m['phi'].items():
                log.debug('Slicing %d (%s) key %s chan %d for fit',
                          i, m['fn'], key, slice)

                # keep only sliced channel(s)
                if 'bank_count' in m['fn_kwargs'].keys():
                    bank_count = m['fn_kwargs']['bank_count']
                    filters_per_bank = int(value.shape[0] / bank_count)
                    slices = np.arange(slice * filters_per_bank,
                                       (slice + 1) * filters_per_bank)
                    m['phi'][key] = value[slices, ...]
                    s[key] = slices
                    m['fn_kwargs']['bank_count'] = 1
                elif value.shape[0] == slice_count:
                    m['phi'][key] = value[[slice], ...]
                    s[key] = [slice]
                else:
                    raise ValueError("Not sure how to slice %s %s",
                                     m['fn'], key)

            # record info about how we sliced this module parameter
            sliceinfo.append(s)

        tmodelspec.append(m)

    if len(fit_idx) == 0:
        log.info('No modules matching fit_set for slice fit')
        return modelspec

    exclude_idx = np.setdiff1d(np.arange(0, len(modelspec)),
                               np.array(fit_idx)).tolist()
    for i in exclude_idx:
        m = tmodelspec[i]
        log.debug('Freezing phi for module %d (%s)', i, m['fn'])
        m['fn_kwargs'].update(m['phi'])
        m['phi'] = {}
        tmodelspec[i] = m

    # generate temp recording with only responses of interest
    temp_rec = rec.copy()
    slice_chans = [temp_rec['resp'].chans[slice]]
    temp_rec['resp'] = temp_rec['resp'].extract_channels(slice_chans)

    # remove initial modules
    first_idx = fit_idx[0]
    if first_idx > 0:
        # print('firstidx {}'.format(first_idx))
        temp_rec = ms.evaluate(temp_rec, tmodelspec, stop=first_idx)
        # temp_rec['stim'] = temp_rec['pred'].copy()
        # tmodelspec = tmodelspec.copy(lb=first_idx)
        # tmodelspec[0]['fn_kwargs']['i'] = 'stim'
        tmodelspec = tmodelspec.copy()
        tmodelspec.fast_eval_on(rec=temp_rec, subset=fit_idx)
        first_idx = 0
    # print(tmodelspec)
    # print(temp_rec.signals.keys())

    # Is this mask necessary? Does it work?
    # if 'mask' in temp_rec.signals.keys():
    #     print("Data len pre-mask: %d" % (temp_rec['mask'].shape[1]))
    #     temp_rec = temp_rec.apply_mask()
    #     print("Data len post-mask: %d" % (temp_rec['mask'].shape[1]))

    # fit the subset of modules
    temp_rec = ms.evaluate(temp_rec, tmodelspec)
    error_before = metric(temp_rec)
    tmodelspec = analysis_function(temp_rec, tmodelspec, fitter=fitter,
                                   metric=metric, fit_kwargs=fit_kwargs)
    tmodelspec.fast_eval_off()
    temp_rec = ms.evaluate(temp_rec, tmodelspec)
    error_after = metric(temp_rec)
    dError = error_before - error_after
    if dError < 0:
        log.info("dError (%.6f - %.6f) = %.6f worse. not updating modelspec"
                 % (error_before, error_after, dError))
    else:
        log.info("dError (%.6f - %.6f) = %.6f better. updating modelspec"
                 % (error_before, error_after, dError))

        # reassemble the full modelspec with updated phi values from
        # tmodelspec
        for i, mod_idx in enumerate(fit_idx):
            m = copy.deepcopy(modelspec[mod_idx])
            # need to have phi in place
            if not m.get('phi'):
                log.info('Initializing phi for module %d (%s)',
                         mod_idx, m['fn'])
                m = priors.set_mean_phi([m])[0]  # Inits phi
            for key, value in tmodelspec[mod_idx - first_idx]['phi'].items():
                # print(key)
                # print(m['phi'][key].shape)
                # print(sliceinfo[i][key])
                # print(value)
                m['phi'][key][sliceinfo[i][key], :] = value

            modelspec[mod_idx] = m

    return modelspec

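# Numeric illustration of the filterbank slicing above: with 12 filters
# arranged in 4 banks (3 filters per bank), slice 2 keeps rows 6, 7, 8 of
# each phi array.
import numpy as np
_bank_count, _total_filters = 4, 12
_filters_per_bank = int(_total_filters / _bank_count)
_slice = 2
_slices = np.arange(_slice * _filters_per_bank,
                    (_slice + 1) * _filters_per_bank)
print(_slices)  # [6 7 8]
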
def fit_population_iteratively(
        est, modelspec,
        cost_function=basic_cost,
        fitter=coordinate_descent, evaluator=ms.evaluate,
        segmentor=nems.segmentors.use_all_data,
        mapper=nems.fitters.mappers.simple_vector,
        metric=lambda data: nems.metrics.api.nmse(data, 'pred', 'resp'),
        metaname='fit_basic', fit_kwargs={}, module_sets=None,
        invert=False, tolerances=None, tol_iter=50, fit_iter=10,
        IsReload=False, **context
):
    '''
    Required Arguments:
     est           A recording object
     modelspec     A modelspec object

    Optional Arguments:
     TODO: need to deal with the fact that you can't pass functions in an
     xforms-friendly function
     fitter        (CURRENTLY NOT USED?)
                   A function of (sigma, costfn) that tests various points
                   in fitspace (i.e. sigmas) using the cost function costfn,
                   and hopefully returns a better sigma after some time.
     mapper        (CURRENTLY NOT USED?)
                   A class that has two methods, pack and unpack, which
                   define the mapping between modelspecs and a fitter's
                   fitspace.
     segmentor     (CURRENTLY NOT USED?)
                   A function that selects a subset of the data during the
                   fitting process. This is NOT the same as est/val data
                   splits.
     metric        A function of a Recording that returns an error value
                   that is to be minimized.
     module_sets   (CURRENTLY NOT USED?)
                   A nested list specifying which model indices should be
                   fit. Overall iteration will occur len(module_sets) many
                   times. ex: [[0], [1, 3], [0, 1, 2, 3]]
     invert        (CURRENTLY NOT USED?)
                   Boolean. Causes module_sets to specify the model indices
                   that should *not* be fit.

    Returns:
     A list containing a single modelspec, which has the best parameters
     found by this fitter.
    '''
    if IsReload:
        return {}

    modelspec = copy.deepcopy(modelspec)
    data = est.copy()

    fit_set_all, fit_set_slice = _figure_out_mod_split(modelspec)

    if tolerances is None:
        tolerances = [1e-4, 1e-5]

    # apply mask to remove invalid portions of signals and allow fit to
    # only evaluate the model on the valid portion of the signals,
    # then delete the mask signal so that it's not reapplied on each fit
    if 'mask' in data.signals.keys():
        log.info("Data len pre-mask: %d", data['mask'].shape[1])
        data = data.apply_mask()
        log.info("Data len post-mask: %d", data['mask'].shape[1])
        del data.signals['mask']

    start_time = time.time()
    ms.fit_mode_on(modelspec, data)

    # modelspec = init_pop_pca(data, modelspec)
    # print(modelspec)

    # Ensure that phi exists for all modules; choose prior mean if not found
    # for i, m in enumerate(modelspec):
    #     if ('phi' not in m.keys()) and ('prior' in m.keys()):
    #         m = nems.priors.set_mean_phi([m])[0]  # Inits phi for 1 module
    #         log.debug('Phi not found for module, using mean of prior: {}'
    #                   .format(m))
    #         modelspec[i] = m

    error = np.inf
    slice_count = data['resp'].shape[0]
    step_size = 0.1
    if 'nonlinearity' in modelspec[-1]['fn']:
        skip_nl_first = True
        tolerances = [tolerances[0]] + tolerances
    else:
        skip_nl_first = False

    for toli, tol in enumerate(tolerances):

        log.info("Fitting subsets with tol: %.2E fit_iter %d tol_iter %d",
                 tol, fit_iter, tol_iter)
        cd_kwargs = fit_kwargs.copy()
        cd_kwargs.update({'tolerance': tol, 'max_iter': fit_iter,
                          'step_size': step_size})
        sp_kwargs = fit_kwargs.copy()
        sp_kwargs.update({'tolerance': tol, 'max_iter': fit_iter})

        if (toli == 0) and skip_nl_first:
            log.info('skipping nl on first tolerance loop')
            saved_modelspec = copy.deepcopy(modelspec)
            saved_fit_set_slice = fit_set_slice.copy()
            # import pdb;
            # pdb.set_trace()
            modelspec.pop_module()
            fit_set_slice = fit_set_slice[:-1]

        inner_i = 0
        error_reduction = np.inf
        # big_slice = 0
        # big_n = data['resp'].ntimes
        # big_step = int(big_n/10)
        # big_slice_size = int(big_n/2)
        while (error_reduction >= tol) and (inner_i < tol_iter):

            log.info("(%d) Tol %.2e: Loop %d/%d (max)",
                     toli, tol, inner_i, tol_iter)
            improved_modelspec = copy.deepcopy(modelspec)
            cc = 0
            slist = list(range(slice_count))
            # random.shuffle(slist)

            for i, m in enumerate(modelspec):
                if i in fit_set_all:
                    log.info(m['fn'] + ": fitting")
                else:
                    log.info(m['fn'] + ": frozen")

            # partially implemented: select temporal subset of data for
            # fitting on current loop.
            # data2 = data.copy()
            # big_slice += 1
            # sl = np.zeros(big_n, dtype=bool)
            # sl[:big_slice_size] = True
            # sl = np.roll(sl, big_step*big_slice)
            # log.info('Sampling temporal subset %d (size=%d/%d)',
            #          big_step, big_slice_size, big_n)
            # for s in data2.signals.values():
            #     e = s._modified_copy(s._data[:, sl])
            #     data2[e.name] = e

            # improved_modelspec = init.prefit_mod_subset(
            #     data, improved_modelspec, analysis.fit_basic,
            #     metric=metric,
            #     fit_set=fit_set_all,
            #     fit_kwargs=sp_kwargs)
            improved_modelspec = fit_population_channel_fast2(
                data, improved_modelspec, fit_set_all, fit_set_slice,
                analysis_function=analysis.fit_basic, metric=metric,
                fitter=scipy_minimize, fit_kwargs=sp_kwargs)

            for s in slist:
                log.info('Slice %d set %s' % (s, fit_set_slice))
                improved_modelspec = fit_population_slice(
                    data, improved_modelspec, slice=s,
                    fit_set=fit_set_slice,
                    analysis_function=analysis.fit_basic,
                    metric=metric,
                    fitter=scipy_minimize,
                    fit_kwargs=sp_kwargs)
                # fitter=coordinate_descent,
                # fit_kwargs=cd_kwargs)

                cc += 1
                # if (cc % 8 == 0) or (cc == slice_count):

            data = ms.evaluate(data, improved_modelspec)
            new_error = metric(data)
            error_reduction = error - new_error
            error = new_error
            log.info("tol=%.2E, iter=%d/%d: deltaE=%.6E",
                     tol, inner_i, tol_iter, error_reduction)
            inner_i += 1
            if error_reduction > 0:
                modelspec = improved_modelspec

        log.info("Done with tol %.2E (i=%d, max_error_reduction %.7f)",
                 tol, inner_i, error_reduction)

        if (toli == 0) and skip_nl_first:
            log.info('Restoring NL module after first tol loop')
            modelspec.append(saved_modelspec[-1])
            fit_set_slice = saved_fit_set_slice
            if 'double_exponential' in saved_modelspec[-1]['fn']:
                modelspec = init.init_dexp(data, modelspec)
            elif 'logistic_sigmoid' in saved_modelspec[-1]['fn']:
                modelspec = init.init_logsig(data, modelspec)
            elif 'relu' in saved_modelspec[-1]['fn']:
                # just keep initialized to zero
                pass
            else:
                raise ValueError("Output NL %s not supported",
                                 saved_modelspec[-1]['fn'])

            # just fit the NL
            improved_modelspec = copy.deepcopy(modelspec)

            kwa = cd_kwargs.copy()
            kwa['max_iter'] *= 2
            for s in range(slice_count):
                log.info('Slice %d set %s' % (s, [fit_set_slice[-1]]))
                improved_modelspec = fit_population_slice(
                    data, improved_modelspec, slice=s,
                    fit_set=fit_set_slice,
                    analysis_function=analysis.fit_basic,
                    metric=metric,
                    fitter=scipy_minimize,
                    fit_kwargs=sp_kwargs)
                # fitter=coordinate_descent,
                # fit_kwargs=cd_kwargs)

            data = ms.evaluate(data, modelspec)
            old_error = metric(data)
            data = ms.evaluate(data, improved_modelspec)
            new_error = metric(data)
            log.info('Init NL fit error change %.5f-%.5f = %.5f',
                     old_error, new_error, old_error - new_error)
            modelspec = improved_modelspec

        else:
            step_size *= 0.25

    elapsed_time = (time.time() - start_time)

    # TODO: Should this maybe be moved to a higher level
    #       so it applies to ALL the fitters?
    ms.fit_mode_off(improved_modelspec)
    ms.set_modelspec_metadata(improved_modelspec, 'fitter', metaname)
    ms.set_modelspec_metadata(improved_modelspec, 'fit_time', elapsed_time)

    return {'modelspec': improved_modelspec.copy()}

def init_dexp(rec, modelspec):
    """
    Choose initial values for dexp applied after the preceding fir
    is initialized.
    """
    # preserve input modelspec
    modelspec = copy.deepcopy(modelspec)

    target_i = find_module('double_exponential', modelspec)
    if target_i is None:
        log.warning("No dexp module was found, can't initialize.")
        return modelspec

    if target_i == len(modelspec):
        fit_portion = modelspec
    else:
        fit_portion = modelspec[:target_i]

    # generate prediction from the modules preceding dexp
    ms.fit_mode_on(fit_portion)
    rec = ms.evaluate(rec, fit_portion)
    ms.fit_mode_off(fit_portion)

    pchans = rec['pred'].shape[0]
    amp = np.zeros([pchans, 1])
    base = np.zeros([pchans, 1])
    kappa = np.zeros([pchans, 1])
    shift = np.zeros([pchans, 1])

    for i in range(pchans):
        resp = rec['resp'].as_continuous()
        pred = rec['pred'].as_continuous()[i:(i + 1), :]

        keepidx = np.isfinite(resp) * np.isfinite(pred)
        resp = resp[keepidx]
        pred = pred[keepidx]

        # choose phi s.t. dexp starts as almost a straight line
        # phi=[max_out min_out slope mean_in]
        # meanr = np.nanmean(resp)
        stdr = np.nanstd(resp)

        # base = np.max(np.array([meanr - stdr * 4, 0]))
        base[i, 0] = np.min(resp)
        # base = meanr - stdr * 3

        # amp = np.max(resp) - np.min(resp)
        amp[i, 0] = stdr * 3

        shift[i, 0] = np.mean(pred)
        # shift = (np.max(pred) + np.min(pred)) / 2

        predrange = 2 / (np.max(pred) - np.min(pred) + 1)
        kappa[i, 0] = np.log(predrange)

    modelspec[target_i]['phi'] = {'amplitude': amp, 'base': base,
                                  'kappa': kappa, 'shift': shift}
    log.info("Init dexp: %s", modelspec[target_i]['phi'])

    return modelspec

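# A minimal numeric check of the kappa/shift heuristic above (synthetic
# data, assumed values): for a pred spanning roughly [-2, 2],
# predrange = 2/(max - min + 1) ~ 0.4, so kappa0 = log(0.4) < 0 and the
# dexp starts out shallow (nearly linear) over the data range, centered
# on the mean input.
import numpy as np
_pred = np.random.default_rng(1).uniform(-2, 2, 1000)
_predrange = 2 / (np.max(_pred) - np.min(_pred) + 1)
_kappa0 = np.log(_predrange)  # ~ log(0.4) ~ -0.9
_shift0 = np.mean(_pred)      # ~ 0
print(_kappa0, _shift0)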