def _strf_to_contrast(modelspec, absolute_value=True):
    '''
    Copy prefitted weight_channels and fir phi values onto their
    contrast-based counterparts.

    Parameters
    ----------
    modelspec : NEMS modelspec
        Must yield exactly two weight_channels matches and exactly two fir
        matches from find_module(..., find_all_matches=True). The first of
        each pair is the source (prefitted STRF) and the second is the
        contrast-based destination.
    absolute_value : bool
        If True, every copied phi value is replaced by its np.abs.

    Returns
    -------
    modelspec : a deep copy of the input with the contrast phi updated.
    '''
    modelspec = copy.deepcopy(modelspec)
    wc_src, wc_dst = find_module('weight_channels', modelspec,
                                 find_all_matches=True)
    fir_src, fir_dst = find_module('fir', modelspec, find_all_matches=True)

    log.info("Updating contrast phi to match prefitted strf ...")
    for src, dst in ((wc_src, wc_dst), (fir_src, fir_dst)):
        modelspec[dst]['phi'] = copy.deepcopy(modelspec[src]['phi'])

    if absolute_value:
        for dst in (wc_dst, fir_dst):
            target_phi = modelspec[dst]['phi']
            for name in list(target_phi):
                target_phi[name] = np.abs(target_phi[name])

    return modelspec
def fixed_contrast_strf(modelspec=None, **kwargs):
    '''
    Move the prefitted STRF phi into the contrast modules' fn_kwargs and
    clear their phi (presumably so the contrast STRF stays fixed during
    fitting — TODO confirm).

    WARNING: This modifies modelspec in-place mid-evaluation.
    Really not sure this is the right way to do this.

    Parameters
    ----------
    modelspec : NEMS modelspec or None
        If None, nothing is done. If the first weight_channels module's id
        contains 'g', nothing is done either.
    **kwargs
        Accepted and ignored.

    Returns
    -------
    False, always.
    '''
    if modelspec is None:
        return False

    wc_idx = find_module('weight_channels', modelspec)
    if 'g' in modelspec[wc_idx]['id']:
        # NOTE(review): matches the letter 'g' anywhere in the id
        # (presumably to skip gaussian weight_channels) — confirm.
        return False

    _, ctwc_idx = find_module('weight_channels', modelspec,
                              find_all_matches=True)
    fir_idx, ctfir_idx = find_module('fir', modelspec, find_all_matches=True)

    modelspec[ctwc_idx]['fn_kwargs'].update(
        copy.deepcopy(modelspec[wc_idx]['phi']))
    modelspec[ctfir_idx]['fn_kwargs'].update(
        copy.deepcopy(modelspec[fir_idx]['phi']))
    modelspec[ctwc_idx]['phi'] = {}
    modelspec[ctfir_idx]['phi'] = {}
    # NOTE(review): the fixed contrast values copied above are not passed
    # through np.abs (compare _strf_to_contrast); confirm that is intended.
    return False
def _set_nonlinearity(modelspec):
    '''
    Overwrite the output nonlinearity's phi with fixed toy-model values:
    base=[0], amplitude=[2], shift=[0.275], kappa=[2.5].

    The double_exponential module is used if present; otherwise the
    dynamic_sigmoid module (gain-control models) is assumed. The target
    module's phi dict is updated in place and the modelspec is returned.
    '''
    nl_idx = find_module('double_exponential', modelspec)
    if nl_idx is None:
        # no dexp, assume dsig for gc instead
        nl_idx = find_module('dynamic_sigmoid', modelspec)

    fixed_phi = {
        'base': np.array([0]),
        'amplitude': np.array([2]),
        'shift': np.array([0.275]),
        'kappa': np.array([2.5]),
    }
    modelspec[nl_idx]['phi'].update(fixed_phi)
    return modelspec
def _set_LN_phi(modelspec):
    '''
    Give the first weight_channels and fir modules fixed toy-model phi,
    then set the output nonlinearity with _set_nonlinearity.

    weight_channels: two channels, means 0.4 / 0.5, sd 0.15 each.
    fir: 2 x 15 coefficients, non-zero only in the first three lags.
    The modelspec is modified in place and returned.
    '''
    wc_pos = find_module('weight_channels', modelspec)
    fir_pos = find_module('fir', modelspec)

    modelspec[wc_pos]['phi'] = {
        'mean': np.array([0.4, 0.5]),
        'sd': np.array([0.15, 0.15]),
    }

    coefs = np.zeros((2, 15))
    coefs[0, :3] = [0, -.125, -.25]
    coefs[1, :3] = [0, .275, .15]
    modelspec[fir_pos]['phi'] = {'coefficients': coefs}

    return _set_nonlinearity(modelspec)
def add_gc_signal(rec, modelspec, name='GC'):
    '''
    Add a signal containing the time-varying dynamic_sigmoid parameters.

    Each parameter x becomes x + (x_mod - x) * pred, where pred is the
    recording's 'pred' signal and values from fn_kwargs override phi.
    The four traces (base, amplitude, shift, kappa) are stacked into a
    RasterizedSignal with chans ['B', 'A', 'S', 'K'], using fs, recording
    and epochs from rec['stim'].

    rec and modelspec are deep-copied, so the inputs are not modified.
    Assumes a dynamic_sigmoid module exists (no check is made).

    Returns
    -------
    rec : the copied recording with rec[name] added.
    '''
    modelspec = copy.deepcopy(modelspec)
    rec = copy.deepcopy(rec)
    dsig_idx = find_module('dynamic_sigmoid', modelspec)

    params = modelspec[dsig_idx]['phi']
    params.update(modelspec[dsig_idx]['fn_kwargs'])
    pred = rec['pred'].as_continuous()

    traces = [
        params[key] + (params[key + '_mod'] - params[key]) * pred
        for key in ('base', 'amplitude', 'shift', 'kappa')
    ]
    stacked = np.squeeze(np.stack(traces, axis=0))

    stim = rec['stim']
    rec[name] = nems.signal.RasterizedSignal(
        stim.fs, stacked, name, stim.recording,
        chans=['B', 'A', 'S', 'K'], epochs=stim.epochs)
    return rec
def get_module(ctx, val, key='index', mspec_idx=0, find_all_matches=False):
    '''
    Look up a module in one of the context's modelspecs.

    Parameters
    ----------
    ctx : dict
        Context containing a 'modelspecs' sequence.
    val : int or str
        Module position when ``key`` is 'index', 'idx' or 'i'; otherwise
        the value searched for by find_module.
    key : str
        'index' / 'idx' / 'i' for positional lookup; any other value is
        passed to find_module as its ``key`` argument.
    mspec_idx : int
        Which entry of ctx['modelspecs'] to search.
    find_all_matches : bool
        Forwarded to find_module.
    '''
    modelspec = ctx['modelspecs'][mspec_idx]
    if key not in ('index', 'idx', 'i'):
        val = find_module(val, modelspec, find_all_matches=find_all_matches,
                          key=key)
    return modelspec[val]
def fit_to_simulation(fit_model, simulation_spec):
    '''
    Fit a model to a response simulated by ``simulation_spec``.

    The recording comes from the default context; its 'resp' signal is
    replaced by the 'pred' that ``simulation_spec`` produces from it.

    Parameters:
    -----------
    fit_model : str
        Modelname to fit to the simulation. Its first two '-'-separated
        parts are dropped before building the xforms spec.
    simulation_spec : NEMS ModelSpec
        Modelspec to base simulation on. If it contains a contrast_kernel
        module, that module's fn_kwargs['evaluate_contrast'] is set to True
        in place.

    Returns:
    --------
    ctx : dict
        Xforms context. See nems.xforms.
    '''
    rec = get_default_ctx()['rec']

    ctk_idx = find_module('contrast_kernel', simulation_spec)
    if ctk_idx is not None:
        simulation_spec[ctk_idx]['fn_kwargs']['evaluate_contrast'] = True

    rec['resp'] = simulation_spec.evaluate(rec)['pred']

    # keep only the modelname parts after the first two '-' groups
    remaining = '-'.join(fit_model.split('-')[2:])
    xfspec = xhelp.generate_xforms_spec(modelname=remaining)
    ctx, _ = xforms.evaluate(xfspec, context={'rec': rec})
    return ctx
def build_toy_combined_cell(base, amplitude, shift, kappa, u, tau):
    '''
    Build a toy modelspec with STP and gain control.

    Keywords come from the second '_' field of ``combined`` (presumably a
    module-level modelname string — not defined in this function). LN phi
    is set via _set_LN_phi, stp phi is set to ``u`` / ``tau``, and the
    dynamic_sigmoid *_mod values are set as offsets (base, amplitude,
    shift, kappa) from the static values via _set_gc_phi.
    '''
    keywords = combined.split('_')[1]
    modelspec = _set_LN_phi(from_keywords(keywords))
    stp_pos = find_module('stp', modelspec)
    modelspec[stp_pos]['phi'] = {'u': u, 'tau': tau}
    return _set_gc_phi(modelspec, base, amplitude, shift, kappa)
def dsig_phi_to_prior(modelspec):
    '''
    Sets priors for dynamic_sigmoid equal to the current phi for the same
    module.

    Used for random-sample fits - all samples are initialized and pre-fit
    the same way, and then randomly sampled from the new priors.

    The existing prior tuples are edited: 'beta' of base, amplitude and
    kappa and 'mean' of shift are set to the current phi values.

    NOTE(review): dsig_phi_to_prior is defined again later in this module,
    and that later definition is the one bound at import time.

    Parameters
    ----------
    modelspec : list of dictionaries
        A NEMS modelspec containing, at minimum, a dynamic_sigmoid module

    Returns
    -------
    modelspec : A copy of the input modelspec with priors updated.
    '''
    modelspec = copy.deepcopy(modelspec)
    dsig = modelspec[find_module('dynamic_sigmoid', modelspec)]
    phi = dsig['phi']
    base, amp, kappa, shift = (
        phi[name] for name in ('base', 'amplitude', 'kappa', 'shift'))

    # each prior is a (distribution_name, params_dict) tuple
    prior = dsig['prior']
    prior['base'][1]['beta'] = base
    prior['amplitude'][1]['beta'] = amp
    prior['shift'][1]['mean'] = shift  # Do anything to scale sd?
    prior['kappa'][1]['beta'] = kappa
    return modelspec
def _prefit_dsig_only(est, modelspec, analysis_function, fitter,
                      metric=None, fit_kwargs=None):
    '''
    Perform a rough fit that only allows dynamic_sigmoid parameters to vary.

    Steps:
      1. Priors of any *_mod parameters are removed and those parameters
         are set to NaN in fn_kwargs, so only the static parameters fit.
      2. Modules whose id contains 'ct' (ctwc, ctfir, ctlvl, ...) are
         dropped and the remaining modules are fit with prefit_mod_subset
         (fit_set=['dynamic_sigmoid']).
      3. The fitted modules are written back at their original positions.
      4. The removed priors are restored, and each *_mod phi is set to its
         prior's mean.

    NOTE: modifies ``modelspec`` in place (and returns it).

    Parameters
    ----------
    est : recording
    modelspec : NEMS ModelSpec
    analysis_function, fitter, metric :
        Forwarded to prefit_mod_subset.
    fit_kwargs : dict or None
        Forwarded to prefit_mod_subset; None (the default) means {}.
        A fresh dict is created per call instead of a shared mutable default.

    Returns
    -------
    modelspec
    '''
    if fit_kwargs is None:
        fit_kwargs = {}

    dsig_idx = find_module('dynamic_sigmoid', modelspec)

    # freeze all non-static dynamic sigmoid parameters
    frozen = {}
    for p in ('amplitude_mod', 'base_mod', 'kappa_mod', 'shift_mod'):
        v = modelspec[dsig_idx]['prior'].pop(p, False)
        if v:
            modelspec[dsig_idx]['fn_kwargs'][p] = np.nan
            frozen[p] = v

    # Remove ctwc, ctfir, and ctlvl if they exist
    # (NOTE: substring match on 'ct' in the module id)
    kept_idx = []
    kept_mods = []
    for i, m in enumerate(modelspec.modules):
        if 'ct' not in m['id']:
            kept_idx.append(i)
            kept_mods.append(m)

    temp = ms.ModelSpec(raw=[kept_mods])
    temp = prefit_mod_subset(est, temp, analysis_function,
                             fit_set=['dynamic_sigmoid'], fitter=fitter,
                             metric=metric, fit_kwargs=fit_kwargs)

    # Put the fitted modules back; ct modules stay as they were.
    for j, i in enumerate(kept_idx):
        modelspec[i] = temp[j]

    # reset dynamic sigmoid parameters that were frozen
    # (re-index modelspec: the dsig module object was replaced above)
    for p, v in frozen.items():
        prior = priors._tuples_to_distributions({p: v})[p]
        modelspec[dsig_idx]['fn_kwargs'].pop(p, None)
        modelspec[dsig_idx]['prior'][p] = v
        modelspec[dsig_idx]['phi'][p] = prior.mean()

    return modelspec
def _figure_out_mod_split(modelspec):
    """
    Determine where to split a modelspec for a population ("all") fit
    versus a per-slice fit.

    The split point is the second weight_channels module when there are at
    least two; otherwise the filter_bank module when there is exactly one.

    :param modelspec: NEMS modelspec
    :return: (fit_set_all, fit_set_slice) -- module indices before the split
        point, and from the split point through the end of the modelspec.
    :raises ValueError: if neither rule applies.
    """
    bank_mods = find_module('filter_bank', modelspec, find_all_matches=True)
    wc_mods = find_module('weight_channels', modelspec, find_all_matches=True)

    if len(wc_mods) >= 2:
        split_at = wc_mods[1]
    elif len(bank_mods) == 1:
        split_at = bank_mods[0]
    else:
        raise ValueError("Can't figure out how to split all and slices")

    return list(range(split_at)), list(range(split_at, len(modelspec)))
def build_toy_stp_cell(u, tau):
    '''
    Build a toy LN modelspec with an STP module.

    Keywords come from the second '_' field of ``stp`` (presumably a
    module-level modelname string — not defined in this function). LN phi
    is set via _set_LN_phi, and the stp phi is set to ``u`` / ``tau``
    (converted with np.array unless already ndarrays).
    '''
    u = u if isinstance(u, np.ndarray) else np.array(u)
    tau = tau if isinstance(tau, np.ndarray) else np.array(tau)

    modelspec = _set_LN_phi(from_keywords(stp.split('_')[1]))
    stp_pos = find_module('stp', modelspec)
    modelspec[stp_pos]['phi'] = {'u': u, 'tau': tau}
    return modelspec
def fir_L2_norm(modelspec):
    '''
    Initialize the first fir module's coefficients to random values with
    unit L2 (Frobenius) norm.

    The shape is taken from the mean of the 'coefficients' prior. Values
    are drawn from np.random.rand (global NumPy RNG, uniform on [0, 1))
    and divided by np.linalg.norm of the draw.

    Returns a deep copy; the input modelspec is not modified.
    '''
    modelspec = copy.deepcopy(modelspec)
    fir_pos = find_module('fir', modelspec)
    dists = priors._tuples_to_distributions(modelspec[fir_pos]['prior'])
    shape = dists['coefficients'].mean().shape

    coeffs = np.random.rand(*shape)
    coeffs = coeffs / np.linalg.norm(coeffs)

    # Assumes fir phi hasn't been initialized yet and that coefficients
    # is the only parameter to set. MAY NOT BE TRUE FOR SOME MODELS.
    modelspec[fir_pos]['phi'] = {'coefficients': coeffs}
    return modelspec
def gd_ratio(cellid, batch, modelname):
    '''
    Return kappa_mod / kappa of the dynamic_sigmoid module of a saved model.

    The model is loaded with eval_model=False.
    '''
    _, ctx = xhelp.load_model_xform(cellid, batch, modelname,
                                    eval_model=False)
    modelspec = ctx['modelspec']
    dsig_phi = modelspec[find_module('dynamic_sigmoid', modelspec)]['phi']
    return dsig_phi['kappa_mod'] / dsig_phi['kappa']
def freeze_dsig_statics(modelspec):
    '''
    Pin dynamic_sigmoid parameters by setting each bound to (value, value).

    NOTE(review): every key currently in the dsig phi is frozen, including
    any *_mod entries if present — confirm only the static parameters are
    expected in phi when this is called.

    Returns a deep copy. If no dynamic_sigmoid module exists, a warning is
    logged and the (copied) modelspec is returned unchanged.
    '''
    modelspec = copy.deepcopy(modelspec)
    dsig_idx = find_module('dynamic_sigmoid', modelspec)
    if dsig_idx is None:
        log.warning("No dsig module was found, can't initialize.")
        return modelspec

    dsig = modelspec[dsig_idx]
    pinned = {name: (value, value) for name, value in dsig['phi'].items()}
    dsig['bounds'].update(pinned)
    return modelspec
def dsig_phi_to_prior(modelspec):
    '''
    Replace the dynamic_sigmoid priors with priors centered on the module's
    current phi.

    Used for random-sample fits - all samples are initialized and pre-fit
    the same way, and then randomly sampled from the new priors.

    Operates on modelspec IN-PLACE.

    New priors:
      amplitude, kappa, shift : Normal(mean=value, sd=|2 * value|)
      base                    : Exponential(beta=value)
    Each *_mod parameter present in phi gets the same prior as its static
    counterpart.

    Parameters
    ----------
    modelspec : NEMS modelspec
        Must contain a dynamic_sigmoid module.

    Returns
    -------
    modelspec : the same (modified) modelspec object, returned for
        convenience so the function can also be used in expressions.
    '''
    dsig_idx = find_module('dynamic_sigmoid', modelspec)
    phi = modelspec[dsig_idx]['phi']
    b = phi['base']
    a = phi['amplitude']
    k = phi['kappa']
    s = phi['shift']

    amp_prior = ('Normal', {'mean': a, 'sd': np.abs(a * 2)})
    base_prior = ('Exponential', {'beta': b})
    kappa_prior = ('Normal', {'mean': k, 'sd': np.abs(k * 2)})
    shift_prior = ('Normal', {'mean': s, 'sd': np.abs(s * 2)})

    # Named `new_priors` (not `priors`) so the `priors` helper name used
    # elsewhere in this file is not shadowed here.
    new_priors = {
        'amplitude': amp_prior,
        'base': base_prior,
        'kappa': kappa_prior,
        'shift': shift_prior,
    }
    for name in ('base', 'amplitude', 'kappa', 'shift'):
        mod_name = name + '_mod'
        if mod_name in phi:
            new_priors[mod_name] = new_priors[name]

    modelspec[dsig_idx]['prior'] = new_priors
    return modelspec
def before_and_after_scatter(rec, modelspec, idx, sig_name='pred',
                             compare='resp', smoothing_bins=False,
                             mod_name='Unknown', xlabel1=None, xlabel2=None,
                             ylabel1=None, ylabel2=None):
    '''
    Build scatter-plot callables of ``sig_name`` vs ``compare`` immediately
    before and immediately after module ``idx``.

    "Before" is the model evaluated up to (not including) module ``idx``;
    for idx == 0 the input 'stim' signal is plotted instead. "After"
    evaluates one more module. If the signal has a single channel, the
    correlation with ``compare`` is shown on each plot; otherwise r is
    reported as 0 and a warning is logged.

    Parameters
    ----------
    rec : recording
    modelspec : NEMS modelspec
    idx : int
        Index of the module whose effect is visualized.
    sig_name, compare : str
        Plotted signal and comparison signal names.
    smoothing_bins, xlabel1, xlabel2, ylabel1, ylabel2 :
        Forwarded to plot_scatter.
    mod_name : str
        Used in the titles and to look up (with find_module) the module
        passed to the "before" plot.

    Returns
    -------
    fn1, fn2 : functools.partial
        Deferred plot_scatter calls for the before and after plots.
    '''
    # HACK: shouldn't hardcode 'stim', might be named something else
    #       or not present at all. Need to figure out a better solution
    #       for special case of idx = 0
    if idx == 0:
        # Can't have anything before index 0, so use input stimulus
        before = rec.copy()
        before_sig = rec['stim']
        before.name = '**stim'
    else:
        before = ms.evaluate(rec.copy(), modelspec, start=None, stop=idx)
        before_sig = before[sig_name]

    # now evaluate next module step
    after = ms.evaluate(before.copy(), modelspec, start=idx, stop=idx+1)
    after_sig = after[sig_name]

    # compute correlation for pre-module before it's over-written
    if before[sig_name].shape[0] == 1:
        corr1 = nm.corrcoef(before, pred_name=sig_name, resp_name=compare)
        corr2 = nm.corrcoef(after, pred_name=sig_name, resp_name=compare)
    else:
        corr1 = 0
        corr2 = 0
        log.warning('corr coef expects single-dim predictions')

    compare_to = rec[compare]

    title1 = '{} vs {} before {}'.format(sig_name, compare, mod_name)
    title2 = '{} vs {} after {}'.format(sig_name, compare, mod_name)
    # TODO: These are coming out the same, but that seems unlikely
    text1 = "r = {0:.5f}".format(corr1)
    text2 = "r = {0:.5f}".format(corr2)

    modidx = find_module(mod_name, modelspec)
    # find_module returns None when nothing matches; index 0 is a valid
    # match and must not be treated as "not found".
    module = modelspec[modidx] if modidx is not None else None

    fn1 = partial(plot_scatter, before_sig, compare_to, title=title1,
                  smoothing_bins=smoothing_bins, xlabel=xlabel1,
                  ylabel=ylabel1, text=text1, module=module)
    fn2 = partial(plot_scatter, after_sig, compare_to, title=title2,
                  smoothing_bins=smoothing_bins, xlabel=xlabel2,
                  ylabel=ylabel2, text=text2)

    return fn1, fn2
def init_logsig(rec, modelspec):
    '''
    Initialization of priors for logistic_sigmoid, based on process
    described in methods of Rabinowitz et al. 2014.

    Priors set on a deep copy of the modelspec (all statistics ignore NaN):
      base      : Exponential(beta = min(resp) + 0.05 * range(resp))
      amplitude : Exponential(beta = 2 * range(resp))
      shift     : Normal(mean = mean(stim), sd = range(stim))
      kappa     : Exponential(beta = range(stim) / stim.shape[1])

    Priors rather than phi are set so the fitter/analysis decides how to
    use them. If no logistic_sigmoid module is found, a warning is logged
    and the copy is returned unchanged.
    '''
    # preserve input modelspec
    modelspec = copy.deepcopy(modelspec)
    logsig_idx = find_module('logistic_sigmoid', modelspec)
    if logsig_idx is None:
        log.warning("No logsig module was found, can't initialize.")
        return modelspec

    stim = rec['stim'].as_continuous()
    resp = rec['resp'].as_continuous()

    # TODO: Maybe need a more sophisticated calculation for this?
    #       Paper isn't very clear on how they calculate "X-bar" and "Y-bar"
    #       They also mention that their stim-resp data is split up into 20
    #       bins, maybe averaged across trials or something?
    stim_mean = np.nanmean(stim)
    stim_range = np.nanmax(stim) - np.nanmin(stim)
    resp_min = np.nanmin(resp)
    resp_range = np.nanmax(resp) - resp_min

    # stim.shape[1] is presumably the number of time bins — TODO confirm
    modelspec[logsig_idx]['prior'] = {
        'base': ('Exponential', {'beta': resp_min + 0.05 * resp_range}),
        'amplitude': ('Exponential', {'beta': 2 * resp_range}),
        'shift': ('Normal', {'mean': stim_mean, 'sd': stim_range}),
        'kappa': ('Exponential', {'beta': stim_range / stim.shape[1]}),
    }

    log.info(
        "logistic_sigmoid priors initialized to: "
        "base: %s\namplitude: %s\nshift: %s\nkappa: %s\n",
        *modelspec[logsig_idx]['prior'].values())

    return modelspec
def remove_dsig_bounds(modelspec):
    '''
    Loosen the dynamic_sigmoid bounds.

    base, amplitude and kappa keep only a strictly positive lower bound
    (1e-15); shift and every *_mod parameter become unbounded.

    Returns a deep copy. If no dynamic_sigmoid module is found, a warning
    is logged and the input modelspec itself (not a copy) is returned.
    '''
    dsig_idx = find_module('dynamic_sigmoid', modelspec)
    if dsig_idx is None:
        log.warning("No dsig module was found, can't initialize.")
        return modelspec

    modelspec = copy.deepcopy(modelspec)
    positive = (1e-15, None)
    unbounded = (None, None)
    loosened = {
        'base': positive,
        'amplitude': positive,
        'shift': unbounded,
        'kappa': positive,
    }
    for name in ('amplitude', 'base', 'kappa', 'shift'):
        loosened[name + '_mod'] = unbounded

    modelspec[dsig_idx]['bounds'].update(loosened)
    return modelspec
def _set_gc_phi(modelspec, base, amplitude, shift, kappa):
    '''
    Set the dynamic_sigmoid *_mod phi values as offsets from the statics.

    Parameters given as differences, e.g. kappa = -0.5 means set kappa_mod
    to be 0.5 less than kappa. The module's phi is updated in place and
    the modelspec is returned.
    '''
    dsig_phi = modelspec[find_module('dynamic_sigmoid', modelspec)]['phi']
    names = ('base', 'amplitude', 'shift', 'kappa')
    statics = [dsig_phi[name] for name in names]
    deltas = (base, amplitude, shift, kappa)

    dsig_phi.update({
        name + '_mod': static + delta
        for name, static, delta in zip(names, statics, deltas)
    })
    return modelspec
def contrast_kernel_heatmap2(rec, modelspec, ax=None, title=None, idx=0,
                             channels=0, xlabel='Lag (s)', ylabel='Channel In',
                             **options):
    '''
    Plot the contrast module's kernel as |wc| @ |fir|.

    Gaussian weight-channel coefficients (phi 'mean', 'sd') and exponential
    FIR coefficients (phi 'tau', 'a', 'b', 's') are rebuilt, shifted by
    'offsets' if present (phi checked first, then fn_kwargs), made
    absolute and multiplied.

    ``idx``, ``channels`` and ``options`` are accepted but not used.
    Returns the ``ax`` argument as given (may be None).
    '''
    ct_idx = nu.find_module('contrast', modelspec)
    phi = copy.deepcopy(modelspec[ct_idx]['phi'])
    fn_kwargs = copy.deepcopy(modelspec[ct_idx]['fn_kwargs'])
    fs = rec['stim'].fs

    wc_kwargs = dict(mean=phi['mean'], sd=phi['sd'],
                     n_chan_in=fn_kwargs['n_channels'])
    fir_kwargs = dict(tau=phi['tau'], a=phi['a'], b=phi['b'], s=phi['s'],
                      n_coefs=fn_kwargs['n_coefs'])
    wc_coefs = gaussian_coefficients(**wc_kwargs)
    fir_coefs = fir_exp_coefficients(**fir_kwargs)

    offsets = phi.get('offsets', fn_kwargs.get('offsets'))
    if offsets is not None:
        fir_coefs = _offset_coefficients(fir_coefs, offsets, fs,
                                         pad_bins=True)

    wc_abs = np.abs(wc_coefs).T
    fir_abs = np.abs(fir_coefs)
    # TODO: This isn't really doing the same operation as an STRF anymore
    #       so it may be better not to plot it this way in the future.
    strf = wc_abs @ fir_abs

    _strf_heatmap(strf, wc_abs, fir_abs, xlabel=xlabel, ylabel=ylabel,
                  ax=ax, title=title)
    return ax
def stp_sigmoid_pred_matched(cellid, batch, modelname, LN, include_phi=True):
    '''
    Scatter the LN model's prediction (before its final module) against the
    STP model's final prediction, colored by the STP effect.

    stp_effect = (after - before) / (after + before), where before / after
    are the STP model's channel-averaged predictions evaluated up to but
    not including, and up to and including, the stp module.

    Parameters
    ----------
    cellid, batch : identify the saved fits to load.
    modelname : str
        Model containing an stp module (y-axis prediction).
    LN : str
        Model without stp; its prediction before the last module is the
        x-axis.
    include_phi : bool
        If True, the stp phi values are written beside the plot.

    Returns
    -------
    fig : the matplotlib figure.
    '''
    # NOTE(review): xfspec and ln_spec are unpacked but not used.
    xfspec, ctx = xhelp.load_model_xform(cellid, batch, modelname)
    ln_spec, ln_ctx = xhelp.load_model_xform(cellid, batch, LN)
    modelspec = ctx['modelspec']
    modelspec.recording = ctx['val']
    val = ctx['val'].apply_mask()
    ln_modelspec = ln_ctx['modelspec']
    ln_modelspec.recording = ln_ctx['val']
    ln_val = ln_ctx['val'].apply_mask()

    pred_after_NL = val['pred'].as_continuous().flatten() # with stp
    # stop=-1: evaluate everything except the LN model's last module
    val_before_NL = ms.evaluate(ln_val, ln_modelspec, stop=-1)
    pred_before_NL = val_before_NL['pred'].as_continuous().flatten() # no stp

    stp_idx = find_module('stp', modelspec)
    val_before_stp = ms.evaluate(val, modelspec, stop=stp_idx)
    val_after_stp = ms.evaluate(val, modelspec, stop=stp_idx + 1)
    # average over channels before flattening
    pred_before_stp = val_before_stp['pred'].as_continuous().mean(
        axis=0).flatten()
    pred_after_stp = val_after_stp['pred'].as_continuous().mean(
        axis=0).flatten()
    # normalized difference; NaN/inf where after + before == 0
    stp_effect = (pred_after_stp - pred_before_stp) / (pred_after_stp + pred_before_stp)

    fig = plt.figure()
    plasma = plt.get_cmap('plasma')
    plt.scatter(pred_before_NL, pred_after_NL, c=stp_effect, s=2, alpha=0.75, cmap=plasma)
    plt.title(cellid)
    plt.xlabel('pred in (no stp)')
    plt.ylabel('pred out (with stp)')

    if include_phi:
        stp_phi = modelspec.phi[stp_idx]
        phi_string = '\n'.join(
            ['%s: %.4E' % (k, v) for k, v in stp_phi.items()])
        fig.text(0.775, 0.9, phi_string, va='top', ha='left')
        # leave room on the right for the phi text
        plt.subplots_adjust(right=0.775, left=0.075)
    plt.colorbar()

    return fig
def pca_proj_layer(rec, modelspec, **ctx):
    '''
    Pre-fit the layers below the last weight_channels module so that their
    output matches the response's principal components.

    Steps:
      1. w = index of the last weight_channels module, moved back by one if
         the preceding module's fn contains 'state'.
         pcs_needed = ceil(coefficients.shape[1] / 2) of that module.
      2. Reuse precomputed PC weights if rec.meta['pc_weights'] is usable,
         otherwise compute PCs with resp_to_pc.
      3. pc_modelspec = modules [0, w). Its target 'resp' is the first
         pcs_needed 'pca' channels stacked with their negatives, truncated
         to coefficients.shape[1] rows. It is fit with fit_tf_init.
      4. The fitted phi of modules [0, w) is copied into a copy of
         modelspec. The PC weights themselves are not pasted into the last
         weight_channels module.

    Parameters
    ----------
    rec : recording
    modelspec : NEMS ModelSpec
    **ctx : forwarded to resp_to_pc.

    Returns
    -------
    dict with 'modelspec' (updated copy) and 'pc_modelspec' (fitted
    truncated modelspec).
    '''
    from nems.tf.cnnlink_new import fit_tf_init

    weight_chan_idx = find_module("weight_channels", modelspec,
                                  find_all_matches=True)
    w = weight_chan_idx[-1]
    coefficients = modelspec.phi[w]['coefficients'].copy()
    pcs_needed = int(np.ceil(coefficients.shape[1] / 2))
    if 'state' in modelspec[w-1]['fn']:
        w -= 1

    # Probe for usable precomputed PC weights. A bare `except:` here used to
    # also swallow KeyboardInterrupt/SystemExit; only lookup failures
    # (missing meta/key, wrong type or shape) now trigger the fallback.
    try:
        _ = rec.meta['pc_weights'].T[:, :pcs_needed]
    except (AttributeError, IndexError, KeyError, TypeError):
        pc_rec = resp_to_pc(rec=rec, pc_count=pcs_needed, pc_source='all',
                            overwrite_resp=False, **ctx)['rec']
    else:
        pc_rec = rec.copy()
        log.info('Found %d sets of PC weights', pcs_needed)

    # keep only modules [0, w)
    pc_modelspec = modelspec.copy()
    for _ in range(w, len(modelspec)):
        pc_modelspec.pop_module()

    d = pc_rec.signals['pca'].as_continuous()[:pcs_needed, :]
    d = np.concatenate((d, -d), axis=0)
    d = d[:coefficients.shape[1], :]
    pc_rec['resp'] = pc_rec['resp']._modified_copy(data=d)

    _d = fit_tf_init(pc_modelspec, pc_rec, use_modelspec_init=True,
                     epoch_name="")
    pc_modelspec = _d['modelspec']

    modelspec = modelspec.copy()
    for i in range(w):
        for k in modelspec.phi[i].keys():
            modelspec.phi[i][k] = pc_modelspec.phi[i][k]

    log.info('modelspec len: %d  pc_modelspec len: %d',
             len(modelspec), len(pc_modelspec))

    return {'modelspec': modelspec, 'pc_modelspec': pc_modelspec}
def state_gain_plot(modelspec, ax=None, colors=None, clim=None, title=None, **options):
    '''
    Plot the baseline ('d') and gain ('g') coefficients of the state module.

    Means come from modelspec.phi_mean and errors from modelspec.phi_sem.
    If d has more than one row, each column of d (dashed) and g (solid) is
    drawn as a line, optionally colored by ``colors[cc]``. Otherwise the
    single row of each is drawn with error bars (d blue, g red), and any
    channel with |value / sem| > 2 is labelled with its state_chans name
    (g is checked before d).

    ``clim`` and ``options`` are accepted but not used. MI is read from
    modelspec[0]['meta']['state_mod'] but not plotted; the lookup still
    raises KeyError if that entry is missing.
    '''
    state_idx = find_module('state', modelspec)

    g = modelspec.phi_mean[state_idx]['g']
    d = modelspec.phi_mean[state_idx]['d']
    ge = modelspec.phi_sem[state_idx]['g']
    de = modelspec.phi_sem[state_idx]['d']

    MI = modelspec[0]['meta']['state_mod']
    state_chans = modelspec[0]['meta']['state_chans']

    # draw into the given axes, or the current axes if none was given
    if ax is not None:
        plt.sca(ax)
    else:
        ax=plt.gca()
    if d.shape[0] > 1:
        opt={}
        for cc in range(d.shape[1]):
            if colors is not None:
                opt = {'color': colors[cc]}
            plt.plot(d[:,cc],'--', **opt)
            plt.plot(g[:,cc], **opt)
    else:
        plt.errorbar(np.arange(len(d[0, :])), d[0, :], de[0, :], color='blue')
        plt.errorbar(np.arange(len(g[0, :])), g[0, :], ge[0, :], color='red')
        # |mean / sem|; channels above 2 get a text label
        dz = np.abs(d[0, :] / de[0, :])
        gz = np.abs(g[0, :] / ge[0, :])
        for i in range(len(gz)):
            if gz[i] > 2:
                ax.text(i, g[0, i] + np.sign(g[0, i]) * ge[0, i], state_chans[i],
                        color='red', ha='center', fontsize=6)
            elif dz[i] > 2:
                ax.text(i, d[0,i]+np.sign(d[0,i])*de[0,i], state_chans[i],
                        color='blue', ha='center', fontsize=6)

    #plt.plot(MI)
    #plt.xticks(np.arange(len(state_chans)), state_chans, fontsize=6)
    plt.legend(('baseline', 'gain'), frameon=False)
    # zero reference line across all state channels
    plt.plot(np.arange(len(state_chans)),np.zeros(len(state_chans)),'k--',
             linewidth=0.5)
    if title:
        plt.title(title)

    ax_remove_box(ax)
def single_state_mod_index(rec, modelspec, epoch='REFERENCE', psth_name='pred',
                           state_sig='state', state_chan='pupil'):
    '''
    Modulation index attributable to a single state channel.

    In a deep copy of the modelspec, every column of the state module's
    'd' and 'g' phi is zeroed except column 0 and the column of
    ``state_chan``. The model is then re-evaluated, and state_mod_index is
    computed on the result.

    If ``state_chan`` is a list, a list of results is returned, one per
    channel. An empty list means every channel of rec[state_sig].

    Raises ValueError if the modelspec has no state module.
    '''
    if type(state_chan) is list:
        channels = state_chan or rec[state_sig].chans
        return [
            single_state_mod_index(rec, modelspec, epoch=epoch,
                                   psth_name=psth_name, state_sig=state_sig,
                                   state_chan=chan)
            for chan in channels
        ]

    sidx = find_module('state', modelspec)
    if sidx is None:
        raise ValueError("no state signal found")

    modelspec = copy.deepcopy(modelspec)

    keep_pos = rec[state_sig].chans.index(state_chan)
    zero_mask = np.ones(rec[state_sig].shape[0], dtype=bool)
    zero_mask[[0, keep_pos]] = False

    state_phi = modelspec[sidx]['phi']
    for coef in ('d', 'g'):
        state_phi[coef][:, zero_mask] = 0

    newrec = ms.evaluate(rec, modelspec)
    return np.array(state_mod_index(newrec, epoch=epoch, psth_name=psth_name,
                                    state_sig=state_sig,
                                    state_chan=state_chan))
def init_dsig(rec, modelspec):
    '''
    Initialize the dynamic_sigmoid module, choosing the routine from
    fn_kwargs['eq']:
      'dexp', 'd', 'double_exponential' -> _init_double_exponential
      anything else                     -> _init_logistic_sigmoid
    (the logistic path presumably follows the methods of Rabinowitz et al.
    2014).

    rec and modelspec are deep-copied first. If no dynamic_sigmoid module
    is found, a warning is logged and the input modelspec is returned.

    NOTE(review): init_dsig is defined again later in this module; that
    later definition is the one bound at import time.
    '''
    dsig_idx = find_module('dynamic_sigmoid', modelspec)
    if dsig_idx is None:
        log.warning("No dsig module was found, can't initialize.")
        return modelspec

    modelspec = copy.deepcopy(modelspec)
    rec = copy.deepcopy(rec)

    eq = modelspec[dsig_idx]['fn_kwargs'].get('eq', '')
    if eq in ('dexp', 'd', 'double_exponential'):
        initializer = _init_double_exponential
    else:
        initializer = _init_logistic_sigmoid
    return initializer(rec, modelspec, dsig_idx)
def contrast_kernel_heatmap(rec, modelspec, ax=None, title=None, idx=0,
                            channels=0, xlabel='Lag (s)', ylabel='Channel In',
                            **options):
    '''
    Plot the kernel of the contrast_kernel module.

    Coefficients come from _get_ctk_coefficients, called with the module's
    fn_kwargs and phi (phi wins where both define a key) plus the stimulus
    fs. If fn_kwargs contains 'auto_copy' (presumably an older module
    format), 'use_phi' is forced to True.

    ``idx``, ``channels`` and ``options`` are accepted but not used.
    Returns the ``ax`` argument as given (may be None).
    '''
    ctk_idx = nu.find_module('contrast_kernel', modelspec)
    phi = copy.deepcopy(modelspec[ctk_idx]['phi'])
    fn_kwargs = copy.deepcopy(modelspec[ctk_idx]['fn_kwargs'])
    fs = rec['stim'].fs

    if 'auto_copy' in fn_kwargs:
        fn_kwargs['use_phi'] = True

    # Remove duplicates from fn_kwargs (phi is more up to date)
    # to avoid argument collisions
    fn_kwargs = {k: v for k, v in fn_kwargs.items() if k not in phi}

    strf, wc_coefs, fir_coefs = _get_ctk_coefficients(**fn_kwargs, **phi,
                                                      fs=fs)
    _strf_heatmap(strf, wc_coefs, fir_coefs, xlabel=xlabel, ylabel=ylabel,
                  ax=ax, title=title)
    return ax
def init_relsat(rec, modelspec):
    '''
    Initialize priors for the saturated_rectifier module. The priors use
    the response and the prediction of the modules that precede it.

    Priors, attached to the saturated_rectifier module of a deep copy:
      base      : Exponential(beta = min(resp))
      amplitude : Exponential(beta = min(mean(resp) + 3*std(resp), max(resp)))
      shift     : Normal(mean = m, sd = |m|),
                  with m = mean(pred) - 1.5 * std(pred)
      kappa     : Exponential(beta = 1)

    Returns the copied modelspec. If no saturated_rectifier module is
    found, a warning is logged and the copy is returned unchanged.
    '''
    modelspec = copy.deepcopy(modelspec)

    target_i = find_module('saturated_rectifier', modelspec)
    if target_i is None:
        log.warning("No relsat module was found, can't initialize.")
        return modelspec

    # generate prediction from the modules preceding relsat
    # (target_i is always a valid index, so this stops just before it)
    fit_portion = modelspec.modules[:target_i]
    rec = ms.ModelSpec(fit_portion).evaluate(rec).apply_mask()

    pred = rec['pred'].as_continuous().flatten()
    resp = rec['resp'].as_continuous().flatten()
    stdr = np.nanstd(resp)

    base = np.min(resp)
    amplitude = min(np.mean(resp) + stdr * 3, np.max(resp))
    shift = np.mean(pred) - 1.5 * np.nanstd(pred)
    kappa = 1

    base_prior = ('Exponential', {'beta': base})
    amplitude_prior = ('Exponential', {'beta': amplitude})
    # shift can be negative; a Normal prior needs a non-negative sd
    shift_prior = ('Normal', {'mean': shift, 'sd': np.abs(shift)})
    kappa_prior = ('Exponential', {'beta': kappa})

    # attach the priors to the relsat module itself, not the modelspec
    modelspec[target_i]['prior'] = {
        'base': base_prior,
        'amplitude': amplitude_prior,
        'shift': shift_prior,
        'kappa': kappa_prior,
    }

    return modelspec
def strf_local_lin(rec, modelspec, cursor_time=20, channels=0, **options):
    '''
    Plot a locally linear STRF around ``cursor_time`` (seconds).

    With the hard-coded ``use_dstrf = True``, the STRF is computed by
    modelspec.get_dstrf(rec, index=int(cursor_time * rec['resp'].fs),
    width=20, out_channel=channels). The else branch is currently
    unreachable. It estimates the STRF by finite differences: each stim
    bin at lags 0..tbin_count-1 before the cursor bin is scaled by
    (1 + eps), the model is re-evaluated, and the change in the prediction
    at the cursor bin is divided by eps.

    options['clim'] is overwritten with limits symmetric about zero
    (+/- max |strf|) before calling plot_heatmap. Returns None.
    '''
    rec = rec.copy()

    tbin = int(cursor_time * rec['resp'].fs)
    chan_count = rec['stim'].shape[0]
    firmod = find_module('fir', modelspec)
    tbin_count = modelspec.phi[firmod]['coefficients'].shape[1] + 2

    use_dstrf = True
    if use_dstrf:
        index = int(cursor_time * rec['resp'].fs)
        strf = modelspec.get_dstrf(rec, index=index, width=20, out_channel=channels)
    else:
        resp_chan = channels
        d = rec['stim']._data.copy()
        strf = np.zeros((chan_count, tbin_count))
        # baseline prediction at the cursor bin
        _p1 = rec['pred']._data[resp_chan, tbin]
        eps = np.nanstd(d) / 100
        eps = 0.01
        #print('eps: {}'.format(eps))
        for c in range(chan_count):
            #eps = np.std(d[c, :])/100
            for t in range(tbin_count):
                # perturb one stim bin t lags before the cursor and re-run
                _d = d.copy()
                _d[c, tbin - t] *= 1 + eps
                rec['stim'] = rec['stim']._modified_copy(data=_d)
                rec = modelspec.evaluate(rec)
                _p2 = rec['pred']._data[resp_chan, tbin]
                strf[c, t] = (_p2 - _p1) / eps
    print('strf min: {} max: {}'.format(np.min(strf), np.max(strf)))
    options['clim'] = np.array([-np.max(np.abs(strf)), np.max(np.abs(strf))])
    plot_heatmap(strf, cmap=get_setting('FILTER_CMAP'), **options)
def init_dsig(rec, modelspec, nl_mode=2):
    '''
    Initialize the dynamic_sigmoid module, choosing the routine from
    fn_kwargs['eq']:
      'dexp', 'd', 'double_exponential' -> _init_double_exponential
                                           (with nl_mode)
      'relsat', 'rs', 'saturated_rectifier' -> init_relsat
      anything else -> init_logsig (priors based on the methods of
                       Rabinowitz et al. 2014)

    No copy is made here. init_relsat and init_logsig deep-copy the
    modelspec themselves; whether _init_double_exponential does is not
    visible here. If no dynamic_sigmoid module is found, a warning is
    logged and the input modelspec is returned.
    '''
    dsig_idx = find_module('dynamic_sigmoid', modelspec)
    if dsig_idx is None:
        log.warning("No dsig module was found, can't initialize.")
        return modelspec

    eq = modelspec[dsig_idx]['fn_kwargs'].get('eq', '')
    if eq in ('dexp', 'd', 'double_exponential'):
        return _init_double_exponential(rec, modelspec, dsig_idx,
                                        nl_mode=nl_mode)
    if eq in ('relsat', 'rs', 'saturated_rectifier'):
        return init_relsat(rec, modelspec)
    return init_logsig(rec, modelspec)