def plot_unabs(
        self,
        unabs_td_ev: np.ndarray,
        p_td_ch: np.ndarray = None,
        prcts=(25., 50., 75.),
):
    nt = unabs_td_ev.shape[0]
    t = np.arange(nt) * self.dt
    ev = self.ev_bin
    b = npy(self.get_bound(torch.tensor(t)))
    b_max = np.amax(b)

    unabs_ev = unabs_td_ev.sum(0)
    iev_unabs_max = np.amax(np.nonzero(unabs_ev)[0])
    iev_unabs_min = np.amin(np.nonzero(unabs_ev)[0])
    # ev_unabs_max = ev[iev_unabs_max]
    dev = ev[1] - ev[0]
    iev_bound_max = int(np.clip(iev_unabs_max + 2, 0, len(ev) - 1))
    iev_bound_min = int(np.clip(iev_unabs_min - 2, 0, len(ev) - 1))
    ev_max = ev[iev_bound_max] + dev / 2
    ev_min = ev[iev_bound_min] - dev / 2
    extent = [t[0], t[-1], ev_min, ev_max]

    unabs_plot = unabs_td_ev.copy()[:, iev_bound_min:iev_bound_max]
    if p_td_ch is not None:
        unabs_plot[:, 0] += p_td_ch[:, 0]
        unabs_plot[:, -1] += p_td_ch[:, 1]
        unabs_plot = np2.sumto1(unabs_plot, 1)
    plt.imshow(npy(unabs_plot.T), extent=[npy(v) for v in extent])
    plt.plot(t, b, 'w-')
    plt.plot(t, -b, 'w-')

def simulate_data(self, pPred_cond_rt_ch: torch.Tensor, seed=0,
                  rt_only=False):
    torch.random.manual_seed(seed)
    if rt_only:
        dcond_tr = self.dcond_tr
        chSim_tr_dim = self.ch_tr_dim
        chSimFlat_tr = consts.ch_by_dim2ch_flat(chSim_tr_dim)
        pPred_tr_rt = pPred_cond_rt_ch[dcond_tr, :, chSimFlat_tr]
        rtSim_tr = npy(
            npt.categrnd(probs=npt.sumto1(pPred_tr_rt, -1)) * self.dt)
    else:
        dcond_tr = self.dcond_tr
        pPred_tr_rt_ch = pPred_cond_rt_ch[dcond_tr, :, :]
        n_tr, nt, n_ch = pPred_tr_rt_ch.shape
        chSim_tr_rt_ch = npy(
            npt.categrnd(probs=pPred_tr_rt_ch.reshape([n_tr, -1])))
        rtSim_tr = npy((chSim_tr_rt_ch // n_ch) * self.dt)
        chSim_tr = npy(chSim_tr_rt_ch % n_ch)
        chs = np.array(consts.CHS)
        chSim_tr_dim = np.stack(
            [chs[dim][chSim_tr] for dim in range(consts.N_DIM)], -1)
    self.update_data(ch_tr_dim=chSim_tr_dim, rt_tr=rtSim_tr)

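
# Illustrative sketch (toy sizes are assumptions; not used by the pipeline
# above): simulate_data() samples the joint (rt, ch) distribution by
# flattening it to a single categorical axis and decoding the draw back
# with // and %.
def _demo_flat_index_sampling():
    import numpy as np
    rng = np.random.default_rng(0)
    nt, n_ch = 5, 4  # hypothetical toy sizes
    p_rt_ch = rng.random([nt, n_ch])
    p_rt_ch /= p_rt_ch.sum()  # joint P(rt, ch)
    i_flat = rng.choice(nt * n_ch, p=p_rt_ch.reshape(-1))
    # Row-major decode: i_flat = i_rt * n_ch + i_ch
    i_rt, i_ch = i_flat // n_ch, i_flat % n_ch
    assert 0 <= i_rt < nt and 0 <= i_ch < n_ch
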
def get_named_bounded_params(
        self,
        named_bounded_params: Dict[str, BoundedParameter] = None,
        exclude: Iterable[str] = ()
) -> (Iterable[str], np.ndarray, np.ndarray, np.ndarray, np.ndarray,
      np.ndarray):
    """
    :param named_bounded_params: dict of name -> parameter; defaults to
        all OverriddenParameter modules found in self, minus `exclude`.
    :param exclude: names to skip.
    :return: names, v, grad, lb, ub, requires_grad
    """
    if named_bounded_params is None:
        d = odict([
            (k, v) for k, v in self.named_modules()
            if (isinstance(v, OverriddenParameter)  # BoundedParameter
                and k not in exclude)
        ])
    else:
        d = named_bounded_params

    names = []
    v = []
    lb = []
    ub = []
    grad = []
    requires_grad = []
    for name, param in d.items():
        v0 = param.v.flatten()
        if param._param.grad is None:
            g0 = torch.zeros_like(v0)
        else:
            g0 = param._param.grad.flatten()
        l0 = npt.tensor(param.lb).expand_as(param.v).flatten()
        u0 = npt.tensor(param.ub).expand_as(param.v).flatten()
        for i, (v1, g1, l1, u1) in enumerate(zip(v0, g0, l0, u0)):
            v.append(npy(v1))
            grad.append(npy(g1))
            lb.append(npy(l1))
            ub.append(npy(u1))
            requires_grad.append(npy(param._param.requires_grad))
            if v0.numel() > 1:
                names.append(name + '%d' % i)
            else:
                names.append(name)
    v = np.stack(v)
    lb = np.stack(lb)
    ub = np.stack(ub)
    grad = -np.stack(grad)  # minimizing, so take the negative
    requires_grad = np.stack(requires_grad)
    return names, v, grad, lb, ub, requires_grad

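
# Illustrative sketch (hypothetical helper, not part of the original
# module): the six parallel arrays returned by get_named_bounded_params()
# line up one-to-one, so a fit summary can be printed as a flat table.
def _demo_print_param_table(module):
    names, v, grad, lb, ub, requires_grad = \
        module.get_named_bounded_params()
    for row in zip(names, v, grad, lb, ub, requires_grad):
        print('%-24s v=%8.3g grad=%8.3g lb=%8.3g ub=%8.3g free=%s' % row)
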
def plot_coefs_dur_odif_pred_data(data, model, kw_plot_model=(),
                                  to_plot_data=True, axs=None,
                                  coefs_to_plot=(0, 1), dim_incl=(0, 1),
                                  **kwargs):
    with torch.no_grad():
        ev_cond_fr_dim_meanvar, n_cond_dur_ch, durs0 = data.get_data_by_cond(
            'all')[:3]
        dt1 = model.dt
        durs = torch.arange(durs0[0] - dt1 * 0. - 1e-2 * 0,
                            durs0[-1] + dt1, dt1 * 3).clone()
        out0 = model(ev_cond_fr_dim_meanvar, durs)
        out1 = npy(out0)
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])

    if axs is None:
        fig = plt.figure(figsize=[6, 4])
    kw_plot_model = {'zorder': -1, **dict(kw_plot_model)}
    axs, hs_pred = vd2d.plot_coefs_dur_odif(
        out1, data=data, durs=npy(durs), kw_plot=kw_plot_model,
        style='pred', axs=axs,  # fig=fig,
        coefs_to_plot=coefs_to_plot, dim_incl=dim_incl,
        jitter0=0., **kwargs)[2:4]
    if to_plot_data:
        axs, hs_data = vd2d.plot_coefs_dur_odif(
            npy(n_cond_dur_ch), data=data, style='data', axs=axs,  # fig=fig,
            jitter0=0., coefs_to_plot=coefs_to_plot, dim_incl=dim_incl,
            **kwargs)[2:4]
    else:
        hs_data = None
    hs = {'pred': hs_pred, 'data': hs_data}
    return axs, hs

def plot_unabs(model, inp):
    ev_cond_fr_dim_meanvar, durs = inp
    with torch.no_grad():
        # unabs[dim, td, cond, ev]
        # _, unabs = model(*inp, return_unabs=True)
        p_tds = []
        unabss = []
        for dim_rel in range(consts.N_DIM):
            ev1 = ev_cond_fr_dim_meanvar[:, :, dim_rel, :]
            dtb1d = model.dtb.dtb.dtb1ds[dim_rel]  # type: sim1d.Dtb1D
            p_td, unabs = dtb1d.forward(ev1, return_unabs=True)
            unabs = unabs.permute([1, 0, 2])
            nt = p_td.shape[1]
            p_td = npy(p_td).reshape([6, 6, nt, 2])
            unabs = npy(unabs).reshape([6, 6, nt, -1])
            if dim_rel == 0:
                p_td = p_td.mean(1)
                unabs = unabs.mean(1)
            else:
                p_td = p_td.mean(0)
                unabs = unabs.mean(0)
            p_tds.append(p_td)
            unabss.append(unabs)
        dtb1ds = model.dtb.dtb.dtb1ds

    # unabss[dim, cond, td, ev]
    unabss = np.array(unabss)
    # p_tds[dim, cond, td, ch]
    p_tds = np.array(p_tds)

    n_conds = 6
    gs = plt.GridSpec(nrows=n_conds, ncols=consts.N_DIM)
    for dim_rel, (unabs, p_td, dtb1d) in enumerate(
            zip(unabss, p_tds, dtb1ds)):
        for cond, (unabs1, p_td1) in enumerate(zip(unabs, p_td)):
            ax = plt.subplot(gs[cond, dim_rel])  # type: plt.Axes  # noqa
            dtb1d.plot_unabs(unabs1, p_td1)
            if dim_rel > 0 or cond < len(unabs) - 1:
                ax.set_xticklabels([])
                ax.set_yticklabels([])

def plot_p_ch_by_dur(self, p_dim_cond_dur_ch0, cond_irr=2, ch=1):
    p_dim_conds_dur_ch = npy(
        p_dim_cond_dur_ch0.reshape(
            p_dim_cond_dur_ch0.shape[:1] + (6, 6) + (10, 2)))
    n_cond1 = p_dim_conds_dur_ch.shape[1]
    cmap = plt2.cool2_rev(n_cond1)
    for i in range(n_cond1):
        plt.plot(p_dim_conds_dur_ch[0, i, cond_irr, :, ch].T,
                 color=cmap(i))

def choose_correct_ch(n_cond__rt_ch):
    """
    :param n_cond__rt_ch: [condition, frame, ch]
    :return: n_cond__rt_correct_ch[cond, frame]
    """
    n_cond__rt_ch = npy(n_cond__rt_ch)
    n_cond__ch = n_cond__rt_ch.sum(1)
    n_cond = n_cond__rt_ch.shape[0]
    correct_ch = np.argmax(n_cond__ch, axis=1)
    return n_cond__rt_ch[np.arange(n_cond), :, correct_ch]

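
# Illustrative sketch (toy numbers assumed): choose_correct_ch() picks,
# per condition, the RT histogram of whichever choice was made most often,
# via argmax plus advanced indexing.
def _demo_choose_correct_ch():
    import numpy as np
    # 2 conditions x 3 frames x 2 choices
    n = np.array([[[0., 5.], [1., 9.], [0., 4.]],
                  [[7., 1.], [6., 0.], [3., 2.]]])
    correct = choose_correct_ch(n)
    # Condition 0 mostly chose ch=1; condition 1 mostly chose ch=0:
    assert np.allclose(correct, [[5., 9., 4.], [7., 6., 3.]])
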
def plot_rt_distrib1(model, d, data_mode='train_valid'):
    data = d['data_' + data_mode]
    out = d['out_' + data_mode]
    target = d['target_' + data_mode]

    fig = plt.figure('rtdstr', figsize=[4, 4])
    ev_cond_dim = npy(data[:, :, :, 0].sum(1))
    axs = plot_rt_distrib(npy(out), ev_cond_dim,
                          alpha_face=0., colors=['b', 'b'], fig=fig)[0]
    axs = plot_rt_distrib(npy(target), ev_cond_dim,
                          alpha_face=0., colors=['k', 'k'], axs=axs)[0]
    return fig, d

def update_data(self, ch_tr_dim: np.ndarray = None,
                rt_tr: np.ndarray = None,
                ev_tr_dim: np.ndarray = None):
    if ch_tr_dim is None:
        ch_tr_dim = self.ch_tr_dim
    else:
        self.ch_tr_dim = ch_tr_dim
    if rt_tr is None:
        rt_tr = self.rt_tr
    else:
        self.rt_tr = rt_tr
    if ev_tr_dim is None:
        ev_tr_dim = self.ev_tr_dim
    else:
        self.ev_tr_dim = ev_tr_dim
    self.ev_cond_dim, self.dcond_tr = self.dat2p_dat(
        npy(ch_tr_dim), npy(rt_tr), ev_tr_dim)[2:4]

def simulate_data(self, pPred_cond_dur_ch: torch.Tensor, seed=0):
    dcond_tr = self.dcond_tr
    ddur_tr = self.ddur_tr
    pPred_tr_ch = pPred_cond_dur_ch[dcond_tr, ddur_tr, :]

    torch.random.manual_seed(seed)
    chSim_tr_ch = npt.categrnd(probs=pPred_tr_ch)
    chs = np.array(consts.CHS)
    chSim_tr_dim = np.stack(
        [chs[dim][npy(chSim_tr_ch)] for dim in range(consts.N_DIM)], -1)
    self.update_data(ch_tr_dim=chSim_tr_dim)

def plot_coefs_dur_irr_ixn_pred_data(
        data, model, kw_plot_model=(), to_plot_data=True, axs=None,
        coefs_to_plot=(2, ),
):
    with torch.no_grad():
        ev_cond_fr_dim_meanvar, n_cond_dur_ch, durs0 = data.get_data_by_cond(
            'all')[:3]
        dt1 = model.dt
        durs = torch.arange(durs0[0] - dt1 * 0. - 1e-2,
                            durs0[-1] + dt1, dt1 * 1).clone()
        out0 = model(ev_cond_fr_dim_meanvar, durs)
        out1 = npy(out0)

    if axs is None:
        fig = plt.figure(figsize=[6, 4])
    axs = vd2d.plot_coefs_dur_irrixn(
        out1, data=data, durs=npy(durs), kw_plot=kw_plot_model,
        style='pred', axs=axs,  # fig=fig,
        jitter0=0., coefs_to_plot=coefs_to_plot)[2]
    if to_plot_data:
        axs = vd2d.plot_coefs_dur_irrixn(
            npy(n_cond_dur_ch), data=data, style='data', axs=axs,  # fig=fig,
            jitter0=0., coefs_to_plot=coefs_to_plot)[2]
    return axs  # , fig

def dat2p_dat(
        self, ch_tr_dim: np.ndarray, dur_tr: np.ndarray,
        ev_tr_dim: np.ndarray
) -> (torch.Tensor, torch.Tensor, np.ndarray, np.ndarray, np.ndarray,
      np.ndarray):
    """
    :param ch_tr_dim: [tr, dim]
    :param dur_tr: [tr]
    :param ev_tr_dim: [tr, dim]
    :return: n_cond_dur_ch[cond, dur, ch],
        ev_cond_fr_dim_meanvar[dcond, fr, dim, (mean, var)],
        ev_cond_dim[dcond, dim], dcond_tr[tr], durs[dur], ddur_tr[tr]
    """
    nt0 = self.nt0
    dt0 = self.dt0
    n_ch_flat = self.n_ch
    subsample_factor = self.subsample_factor
    nt = int(nt0 // subsample_factor)

    durs, ddur_tr = np.unique(dur_tr, return_inverse=True)
    ddur_tr = ddur_tr.astype(int)
    n_dur = len(durs)
    durs = torch.tensor(durs)
    ddur_tr = torch.tensor(ddur_tr, dtype=torch.long)

    ch_tr_flat = consts.ch_by_dim2ch_flat(ch_tr_dim)
    ev_cond_dim, dcond_tr = np.unique(ev_tr_dim, return_inverse=True,
                                      axis=0)
    n_cond_flat = len(ev_cond_dim)
    ev_cond_fr_dim = torch.tensor(ev_cond_dim)[:, None, :].expand(
        [-1, nt, -1])
    ev_cond_fr_dim_meanvar = torch.stack(
        [ev_cond_fr_dim, torch.zeros_like(ev_cond_fr_dim)], -1)

    n_cond_dur_ch = npt.tensor(
        npg.aggregate(np.stack([dcond_tr, npy(ddur_tr), ch_tr_flat]),
                      1., 'sum', [n_cond_flat, n_dur, n_ch_flat]))
    return n_cond_dur_ch, ev_cond_fr_dim_meanvar, ev_cond_dim, dcond_tr, \
        durs, ddur_tr

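
# Illustrative sketch (pure numpy; npg.aggregate above is from
# numpy_groupies): the binning in dat2p_dat() rests on np.unique(...,
# return_inverse=True) for condition labels plus a grouped sum for the
# histogram. np.add.at gives the same grouped sum without the dependency.
def _demo_cond_histogram():
    import numpy as np
    ev_tr_dim = np.array([[0., 1.], [0., 1.], [2., 1.]])  # toy trials
    ch_tr = np.array([1, 0, 1])
    ev_cond_dim, dcond_tr = np.unique(ev_tr_dim, return_inverse=True,
                                      axis=0)
    n_cond_ch = np.zeros([len(ev_cond_dim), 2])
    np.add.at(n_cond_ch, (dcond_tr, ch_tr), 1.)  # grouped sum of counts
    assert np.allclose(n_cond_ch, [[1., 1.], [0., 1.]])
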
def fun_loss(p_cond__rt_ch_pred: torch.Tensor,
             n_cond__rt_ch_data: torch.Tensor,
             ignore_hard_RT=False,
             conds: Union[torch.Tensor, np.ndarray] = None,
             **kwargs) -> torch.Tensor:
    """
    :param conds: [cond, dim]
    """
    if ignore_hard_RT:
        conds = npy(conds)
        ix_conds_to_ignore_rt = np.any(
            conds == np.amax(np.abs(conds), axis=0), axis=1)
    else:
        ix_conds_to_ignore_rt = None
    return sim1d.fun_loss(p_cond__rt_ch_pred, n_cond__rt_ch_data,
                          ix_conds_to_ignore_rt=ix_conds_to_ignore_rt,
                          **kwargs)

def get_fit_sim(
        subj: str, seed_sim: int, bufdur_sim: float, bufdur_fit: float,
        parad='VD',
        skip_fit_if_absent=False,
        fix_post=None,
) -> (dict, dict, dict):
    """
    :param subj:
    :param seed_sim:
    :param bufdur_sim:
    :param bufdur_fit:
    :param parad:
    :param skip_fit_if_absent: if True, return (None, dict_fit_sim,
        dict_subdir_sim) when no cached fit exists, instead of fitting.
    :param fix_post:
    :return: d, dict_fit_sim, dict_subdir_sim
    """
    if fix_post is None:
        fix_post = ('basym0_fix', 'bhalf042_lb01', 'diffix', 'lps0')

    # --- Load fit to simulated data
    def remove_buffix(fix_strs):
        return [s for s in fix_strs if not s.startswith('(buffix')]

    _, dict_cache, dict_subdir = vdfit.init_model(
        subj=subj, bufdur=bufdur_sim, parad=parad, fix_post=fix_post)
    _, dict_subdir_sim = vdfit.get_subdir(
        fix_strs=remove_buffix(dict_subdir['fix']), **dict_subdir)
    dict_sim = {
        **dict_cache,
        'fix': remove_buffix(dict_cache['fix']),
        'bufdur_sim': bufdur_sim,
        'seed_sim': seed_sim
    }
    dict_fit_sim = {**dict_sim, 'bufdur_fit': bufdur_fit}
    cache_fit_sim = locfile.get_cache('fit_sim', dict_fit_sim,
                                      subdir=dict_subdir_sim)
    if cache_fit_sim.exists():
        best_state, d = cache_fit_sim.getdict(['best_state', 'd'])
    elif skip_fit_if_absent:
        return None, dict_fit_sim, dict_subdir_sim
    else:
        # --- Load model fit to real data
        _, data, dict_cache, d, subdir = vdfit.load_fit(
            subj=subj, bufdur=bufdur_sim, fix_post=fix_post,
            skip_fit_if_absent=skip_fit_if_absent)

        # --- Simulate new data and save
        data_sim = deepcopy(data)  # type: vd2d.Data2DVD
        cache_data_sim = locfile.get_cache('data_sim', dict_sim,
                                           subdir=dict_subdir_sim)
        if cache_data_sim.exists():
            data_sim.update_data(
                ch_tr_dim=cache_data_sim.getdict(['chSim_tr_dim'])[0])
        else:
            ch_tr_dim_bef = data_sim.ch_tr_dim.copy()
            data_bef = npy(data_sim.n_cond_dur_ch).copy()
            data_sim.simulate_data(pPred_cond_dur_ch=d['out_all'],
                                   seed=seed_sim)
            ch_tr_dim_aft = data_sim.ch_tr_dim.copy()
            data_aft = npy(data_sim.n_cond_dur_ch).copy()
            print('Proportion of trials with the same choice:')
            print(np.mean(ch_tr_dim_bef == ch_tr_dim_aft))
            cache_data_sim.set({'chSim_tr_dim': data_sim.ch_tr_dim})
            cache_data_sim.save()
        del cache_data_sim

        # --- Fit simulated data
        model, dict_cache, dict_subdir = vdfit.init_model(
            subj=subj, bufdur=bufdur_fit, fix_post=fix_post)
        _, best_state, d, plotfuns = vd2d.fit_dtb(
            model, data_sim,
            comment='+' + argsutil.dict2fname(dict_fit_sim),
            max_epoch=vdfit.max_epoch0,
        )
        dtb.dtb_1D_sim.save_fit_results(
            model=model, best_state=best_state, d=d, plotfuns=plotfuns,
            locfile=locfile, dict_cache=dict_fit_sim,
            subdir=dict_subdir_sim)
        cache_fit_sim.set({
            'best_state': best_state,
            'd': {k: v for k, v in d.items() if k.startswith('loss_')}
        })
        cache_fit_sim.save()
        del cache_fit_sim
    return d, dict_fit_sim, dict_subdir_sim

def plot_rt_distrib_pred_data(
        p_pred_cond_rt_ch, n_cond_rt_ch, ev_cond_dim,
        dt_model, dt_data=None,
        smooth_sigma_sec=0.1,
        to_plot_scale=False,
        to_cumsum=False,
        to_normalize_max=True,
        xlim=None,
        colors=('magenta', 'cyan'),
        kw_plot_pred=(),
        kw_plot_data=(),
        to_skip_zero_trials=False,
        labels=None,
        **kwargs
):
    """
    :param p_pred_cond_rt_ch: [model, cond, rt, ch] = P(rt, ch | cond, model)
    :param n_cond_rt_ch: [cond, rt, ch] = n_tr(cond, rt, ch)
    :param ev_cond_dim:
    :param dt_model:
    :param dt_data:
    :param smooth_sigma_sec:
    :param to_plot_scale:
    :param to_cumsum:
    :param xlim:
    :param kwargs:
    :return: axs, hss
    """
    axs = None
    ps = []
    ps0 = []
    hss = []
    p_pred_cond_rt_ch = p_pred_cond_rt_ch / np.sum(
        p_pred_cond_rt_ch, (-1, -2), keepdims=True)
    n_preds1 = p_pred_cond_rt_ch * np.sum(
        n_cond_rt_ch, (-1, -2))[None, :, None, None]
    nt = p_pred_cond_rt_ch.shape[-2]
    if dt_data is None:
        dt_data = dt_model
    if labels is None:
        labels = [''] * (len(n_preds1) + 1)

    for i_pred, n_pred in enumerate(n_preds1):
        color = colors[i_pred]
        axs, p0, p1, hs = sim2d.plot_rt_distrib(
            n_pred, ev_cond_dim, dt=dt_model, axs=axs, alpha=1.,
            smooth_sigma_sec=smooth_sigma_sec,
            to_skip_zero_trials=to_skip_zero_trials,
            colors=color, alpha_face=0,
            to_normalize_max=to_normalize_max,
            to_cumsum=to_cumsum,
            to_use_sameaxes=False,
            kw_plot={
                'linewidth': 1.5,
                **dict(kw_plot_pred),
            },
            label=labels[i_pred],
            **kwargs,
        )[:4]
        ps.append(p1)
        ps0.append(p0)
        hss.append(hs)

    axs, p0, p1, hs = sim2d.plot_rt_distrib(
        n_cond_rt_ch, ev_cond_dim, dt=dt_data, axs=axs,
        smooth_sigma_sec=smooth_sigma_sec,
        colors='k', alpha_face=0.,
        to_normalize_max=to_normalize_max,  # normalize across preds and data instead
        to_cumsum=to_cumsum,
        to_skip_zero_trials=to_skip_zero_trials,
        kw_plot={
            'linewidth': 0.5,
            **dict(kw_plot_data),
        },
        label=labels[-1],
        **kwargs,
    )
    ps.append(p1)
    ps0.append(p0)
    hss.append(hs)

    ps = np.stack(ps)
    ps0 = np.stack(ps0)
    ps_flat = np.swapaxes(ps, 0, 2).reshape(
        [ps.shape[1] * ps.shape[2], -1])

    if xlim is None:
        xlim = [0.5, 4.5]
    for ax in axs.flatten():
        plt2.detach_axis('x', *xlim, ax=ax)
        ax.set_xlim(xlim[0] - 0.1, xlim[1] + 0.1)
    axs[-1, 0].set_xticks(xlim)
    axs[-1, 0].set_xticklabels(['%g' % v for v in xlim])

    from lib.pylabyk import numpytorch as npt
    t = torch.arange(nt) * dt_model
    mean_rts = []
    for p1 in ps0:
        p11 = npt.sumto1(torch.tensor(p1).sum([-1, -2])[0, 0, :])
        mean_rts.append(npy((torch.tensor(t) * p11).sum()))
    print('mean_rts:')
    print(mean_rts)
    print(mean_rts[1] - mean_rts[0])

    conds = [np.unique(ev_cond_dim[:, i]) for i in [0, 1]]
    p_preds = torch.tensor(n_preds1).reshape([
        2, len(conds[0]), len(conds[1]), nt, 2, 2]) + 1e-12

    if to_plot_scale:
        y = 0.8
        axs[-1, -1].plot(mean_rts[:2], y + np.zeros(2), 'k-',
                         linewidth=0.5)
        x = np.mean(mean_rts[:2])
        plt.text(x, y + 0.1,
                 '%1.0f ms' % (np.abs(mean_rts[1] - mean_rts[0]) * 1e3),
                 ha='center', va='bottom')
    return axs, hss

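
# Illustrative sketch (toy histogram assumed): the mean-RT readout used in
# plot_rt_distrib_pred_data() is an expectation over the normalized RT
# marginal, E[T] = sum_t t * P(t).
def _demo_mean_rt():
    import numpy as np
    dt = 0.1
    n_rt_ch = np.array([[0., 1.], [2., 3.], [3., 1.]])  # [rt, ch] counts
    p_rt = n_rt_ch.sum(-1) / n_rt_ch.sum()  # marginalize ch, normalize
    t = np.arange(len(p_rt)) * dt
    mean_rt = (t * p_rt).sum()
    assert np.isclose(mean_rt, (0.1 * 5 + 0.2 * 4) / 10)
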
def plot_rt_vs_ev(
        ev_cond: Union[torch.Tensor, np.ndarray],
        n_cond__rt_ch: Union[torch.Tensor, np.ndarray],
        dim_rel=0,
        group_dcond_irr: Iterable[Iterable[int]] = None,
        correct_only=True,
        cmap: Union[str, Callable] = 'cool',
        kw_plot=(),
        **kwargs
) -> Tuple[Sequence[Sequence[plt.Line2D]], Sequence[Sequence[float]]]:
    """
    @param ev_cond: [condition, dim] or [condition, frame, dim, (mean, var)]
    @type ev_cond: torch.Tensor
    @param n_cond__rt_ch: [condition, frame, ch]
    @type n_cond__rt_ch: torch.Tensor
    @return: hs, rts
    """
    if ev_cond.ndim != 2:
        assert ev_cond.ndim == 4
        ev_cond = npt.p2st(ev_cond.mean(1))[0]
    assert n_cond__rt_ch.ndim == 3
    n_cond__rt_ch = n_cond__rt_ch.copy()

    ev_cond = npy(ev_cond)
    n_cond__rt_ch = npy(n_cond__rt_ch)

    nt = n_cond__rt_ch.shape[1]
    n_ch = n_cond__rt_ch.shape[2]
    n_cond_all = n_cond__rt_ch.shape[0]

    dim_irr = consts.get_odim(dim_rel)
    conds_rel, dcond_rel = np.unique(ev_cond[:, dim_rel],
                                     return_inverse=True)
    conds_irr0, dcond_irr0 = np.unique(ev_cond[:, dim_irr],
                                       return_inverse=True)
    conds_irr, dcond_irr = np.unique(np.abs(ev_cond[:, dim_irr]),
                                     return_inverse=True)

    # --- Exclude wrong choice on either dim
    n_cond__rt_ch0 = n_cond__rt_ch.copy()
    if correct_only:
        ch_signs = consts.ch_bool2sign(consts.CHS_ARRAY)
        for dim, ch_signs1 in enumerate(ch_signs):
            for i_ch, ch_sign1 in enumerate(ch_signs1):
                for i, cond1 in enumerate(ev_cond[:, dim]):
                    if np.sign(cond1) == -ch_sign1:
                        n_cond__rt_ch[i, :, i_ch] = 0.

    if group_dcond_irr is not None:
        conds_irr, dcond_irr = group_conds(conds_irr, dcond_irr,
                                           group_dcond_irr)
    n_conds = [len(conds_rel), len(conds_irr)]

    if correct_only:
        # simply split at zero
        hs = []
        rts = []
        for dcond_irr1 in range(n_conds[1]):
            incl = dcond_irr == dcond_irr1
            ev_cond1 = conds_rel
            n_cond__rt_ch1 = np.empty([n_conds[0], nt, n_ch])
            for dcond_rel1 in range(n_conds[0]):
                incl1 = incl & (dcond_rel1 == dcond_rel)
                n_cond__rt_ch1[dcond_rel1] = n_cond__rt_ch[incl1].sum(0)
            if type(cmap) is str:
                color = plt.get_cmap(cmap, n_conds[1])(dcond_irr1)
            else:
                color = cmap(n_conds[1])(dcond_irr1)

            # --- Pool across ch_irr
            chs = np.array(consts.CHS)
            n_cond__rt_ch11 = np.zeros(n_cond__rt_ch1.shape[:2]
                                       + (consts.N_CH, ))
            for ch_rel in range(consts.N_CH):
                incl_ch = chs[dim_rel] == ch_rel
                n_cond__rt_ch11[:, :, ch_rel] = \
                    n_cond__rt_ch1[:, :, incl_ch].sum(-1)

            hs1, rts1 = sim1d.plot_rt_vs_ev(
                ev_cond1, n_cond__rt_ch11,
                color=color,
                correct_only=correct_only,
                kw_plot=kw_plot,
                **kwargs,
            )
            hs.append(hs1)
            rts.append(rts1)
    else:
        raise NotImplementedError()
    return hs, rts

def fit_dtb(model: FitVD2D, data: Data2DVD,
            n_fold_valid=1,
            mode_train='all',
            to_debug=False,
            max_epoch=500,
            **kwargs) -> (float, dict):
    """
    Provide functions fun_data() and plot_*() to ykt.optimize().
    See ykt.optimize() for details about fun_data and plot_*.

    :param model:
    :param data:
    :param n_fold_valid:
    :param mode_train: 'all'|'easiest' - which conditions to use in training
    :param to_debug:
    :param max_epoch:
    :param kwargs: fed to ykt.optimize()
    :return: best_loss, best_state, d, plotfuns
    """
    def fun_data1(mode='all', fold_valid=0, epoch=0, n_fold_valid=1):
        """
        :param mode:
        :param fold_valid:
        :param epoch:
        :return: (ev_cond_fr_dim_meanvar, durs), n_cond_dur_ch
        """
        return fun_data(data=data, mode=mode, fold_valid=fold_valid,
                        epoch=epoch, n_fold_valid=n_fold_valid,
                        mode_train=mode_train, to_debug=to_debug)

    kw_optim = argsutil.kwdefault(
        argsutil.kwdef({'n_fold_valid': n_fold_valid}, kwargs),
        filename_suffix='',
        optimizer_kind='Adam',
        learning_rate=.5,
        patience=100,
        max_epoch=max_epoch,
        reduce_lr_after=25,
        reset_lr_after=50,
        thres_patience=1e-4,
        to_print_grad=False)

    def plot_coefs_dur_odif1(model, d):
        fig = plt.figure('coefs_dur_odif', figsize=[6, 4])
        axs = None
        axs = plot_coefs_dur_odif(npy(d['out_train_valid']), data,
                                  style='pred', axs=axs, fig=fig)[2]
        axs = plot_coefs_dur_odif(npy(d['target_train_valid']), data,
                                  style='data', axs=axs, fig=fig)[2]
        return fig, d

    def plot_ch_ev_by_dur1(model, d):
        fig = plt.figure('ch_ev_by_dur', figsize=[6, 4])
        axs = None
        axs = plot_ch_ev_by_dur(npy(d['out_train_valid']), data,
                                style='pred', axs=axs, fig=fig)
        axs = plot_ch_ev_by_dur(npy(d['target_train_valid']), data,
                                style='data', axs=axs, fig=fig)
        return fig, d

    def plot_unabs1(model, d):
        fig = plt.figure('unabs', figsize=[3, 7])
        plot_unabs(model, d['data_train_valid'])
        return fig, d

    plotfuns = [
        ('coefs_dur_odif', plot_coefs_dur_odif1),
        ('ch_ev_by_dur', plot_ch_ev_by_dur1),
        # ('bound', sim2d.plot_bound),
        ('unabs', plot_unabs1),
        ('params', sim2d.plot_params)
    ]

    best_loss, best_state, d = ykt.optimize(
        model, fun_data1, fun_loss, plotfuns=plotfuns, **kw_optim)[:3]

    with torch.no_grad():
        for data_mode in ['train_valid', 'test', 'all']:
            inp, target = fun_data1(data_mode)
            out = model(inp)
            for loss_kind in ['CE', 'NLL', 'BIC']:
                if loss_kind == 'CE':
                    loss = fun_loss(out, target, to_average=True,
                                    base_n_bin=True)
                elif loss_kind in ['NLL', 'BIC']:
                    loss = fun_loss(out, target, to_average=False,
                                    base_n_bin=False)
                    if loss_kind == 'BIC':
                        n = npy(target.sum())
                        k = np.sum([
                            v.numel() if v.requires_grad else 0
                            for v in model.parameters()
                        ])
                        loss = loss * 2 + k * np.log(n)
                        d['loss_ndata_%s' % data_mode] = n
                        d['loss_nparam'] = k
                d['loss_%s_%s' % (loss_kind, data_mode)] = loss
    return best_loss, best_state, d, plotfuns

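
# Illustrative sketch of the information criterion computed at the end of
# fit_dtb(): with the total (not averaged) NLL, BIC = 2 * NLL + k * log(n),
# where k counts only trainable parameters and n the number of trials.
# The numbers here are hypothetical.
def _demo_bic():
    import numpy as np
    nll, k, n = 1234.5, 17, 5000  # hypothetical NLL, n_params, n_trials
    bic = 2 * nll + k * np.log(n)
    assert np.isclose(bic, 2613.79, atol=0.5)
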
def plot_rt_vs_ev(
        ev_cond,
        n_cond__rt_ch: Union[torch.Tensor, np.ndarray],
        style='pred',
        pool='mean',
        dt=consts.DT,
        correct_only=True,
        thres_n_trial=10,
        color='k',
        color_ch=('tab:red', 'tab:blue'),
        ax: plt.Axes = None,
        kw_plot=(),
) -> (Sequence[plt.Line2D], List[np.ndarray]):
    """
    @param ev_cond: [condition]
    @type ev_cond: torch.Tensor
    @param n_cond__rt_ch: [condition, frame, ch]
    @type n_cond__rt_ch: torch.Tensor
    @return: hs, rts
    """
    if ax is None:
        ax = plt.gca()

    if ev_cond.ndim != 1:
        if ev_cond.ndim == 3:
            ev_cond = npt.p2st(ev_cond)[0]
        assert ev_cond.ndim == 2
        ev_cond = ev_cond.mean(1)
    assert n_cond__rt_ch.ndim == 3

    ev_cond = npy(ev_cond)
    n_cond__rt_ch = npy(n_cond__rt_ch)

    def plot_rt_given_cond_ch(ev_cond1, n_rt_given_cond_ch, **kw1):
        # n_rt_given_cond_ch[cond, fr]
        # p_rt_given_cond_ch[cond, fr]
        p_rt_given_cond_ch = np2.sumto1(n_rt_given_cond_ch, 1)
        nt = n_rt_given_cond_ch.shape[1]
        t = np.arange(nt) * dt
        # n_in_cond_ch[cond]
        n_in_cond_ch = n_rt_given_cond_ch.sum(1)
        if pool == 'mean':
            rt_pooled = (t[None, :] * p_rt_given_cond_ch).sum(1)
        elif pool == 'var':
            raise NotImplementedError()
        else:
            raise ValueError()
        if style.startswith('data'):
            rt_pooled[n_in_cond_ch < thres_n_trial] = np.nan
        kw = get_kw_plot(style, **kw1)
        h = ax.plot(ev_cond1, rt_pooled, **kw)
        return h, rt_pooled

    if correct_only:
        hs = []
        rtss = []
        n_cond__rt_ch1 = n_cond__rt_ch.copy()  # type: np.ndarray
        if style.startswith('data'):
            cond0 = ev_cond == 0
            n_cond__rt_ch1[cond0, :, :] = np.sum(
                n_cond__rt_ch1[cond0, :, :], axis=-1, keepdims=True)
        # --- Choose the ch with correct sign (or both chs if cond == 0)
        for ch in range(consts.N_CH):
            cond_sign = np.sign(ev_cond)
            ch_sign = consts.ch_bool2sign(ch)
            cond_accu = cond_sign != -ch_sign
            h, rts = plot_rt_given_cond_ch(
                ev_cond[cond_accu],
                n_cond__rt_ch1[cond_accu, :, ch],
                color=color, **dict(kw_plot))
            hs.append(h)
            rtss.append(rts)
        # rts = np.stack(rtss)
        rts = rtss
    else:
        hs = []
        rtss = []
        for ch in range(consts.N_CH):
            # n_rt_given_cond_ch[cond, fr]
            n_rt_given_cond_ch = n_cond__rt_ch[:, :, ch]
            h, rts = plot_rt_given_cond_ch(ev_cond, n_rt_given_cond_ch,
                                           color=color_ch[ch])
            hs.append(h)
            rtss.append(rts)
        rts = np.stack(rtss)

    y_lim = ax.get_ylim()
    x_lim = ax.get_xlim()
    plt2.box_off()
    plt2.detach_axis('x', ax=ax, amin=x_lim[0], amax=x_lim[1])
    plt2.detach_axis('y', ax=ax, amin=y_lim[0], amax=y_lim[1])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{E}[T^\mathrm{r} \mid c]~(\mathrm{s})$")
    return hs, rts

def plot_p_ch_vs_ev(
        ev_cond: Union[torch.Tensor, np.ndarray],
        n_ch: Union[torch.Tensor, np.ndarray],
        style='pred',
        ax: plt.Axes = None,
        dim_rel=0,
        group_dcond_irr: Iterable[Iterable[int]] = None,
        cmap: Union[str, Callable] = 'cool',
        kw_plot=(),
) -> Iterable[plt.Line2D]:
    """
    @param ev_cond: [condition, dim] or [condition, frame, dim, (mean, var)]
    @type ev_cond: torch.Tensor
    @param n_ch: [condition, ch] or [condition, rt_frame, ch]
    @type n_ch: torch.Tensor
    @return: hs[cond_irr][0] = Line2D, conds_irr
    """
    if ax is None:
        ax = plt.gca()

    if ev_cond.ndim != 2:
        assert ev_cond.ndim == 4
        ev_cond = npt.p2st(ev_cond.mean(1))[0]
    if n_ch.ndim != 2:
        assert n_ch.ndim == 3
        n_ch = n_ch.sum(1)

    ev_cond = npy(ev_cond)
    n_ch = npy(n_ch)

    n_cond_all = n_ch.shape[0]
    ch_rel = np.repeat(np.array(consts.CHS[dim_rel])[None, :],
                       n_cond_all, 0)
    n_ch = n_ch.reshape([-1])
    ch_rel = ch_rel.reshape([-1])

    dim_irr = consts.get_odim(dim_rel)
    conds_rel, dcond_rel = np.unique(ev_cond[:, dim_rel],
                                     return_inverse=True)
    conds_irr, dcond_irr = np.unique(np.abs(ev_cond[:, dim_irr]),
                                     return_inverse=True)
    if group_dcond_irr is not None:
        conds_irr, dcond_irr = group_conds(conds_irr, dcond_irr,
                                           group_dcond_irr)
    n_conds = [len(conds_rel), len(conds_irr)]

    n_ch_rel = npg.aggregate(np.stack([
        ch_rel,
        np.repeat(dcond_irr[:, None], consts.N_CH_FLAT, 1).flatten(),
        np.repeat(dcond_rel[:, None], consts.N_CH_FLAT, 1).flatten(),
    ]), n_ch, 'sum', [consts.N_CH, n_conds[1], n_conds[0]])
    p_ch_rel = n_ch_rel[1] / n_ch_rel.sum(0)

    hs = []
    for dcond_irr1, p_ch1 in enumerate(p_ch_rel):
        if type(cmap) is str:
            color = plt.get_cmap(cmap, n_conds[1])(dcond_irr1)
        else:
            color = cmap(n_conds[1])(dcond_irr1)
        kw1 = get_kw_plot(style, color=color, **dict(kw_plot))
        h = ax.plot(conds_rel, p_ch1, **kw1)
        hs.append(h)

    plt2.box_off(ax=ax)
    x_lim = ax.get_xlim()
    plt2.detach_axis('x', amin=x_lim[0], amax=x_lim[1], ax=ax)
    plt2.detach_axis('y', amin=0, amax=1, ax=ax)
    ax.set_yticks([0, 0.5, 1])
    ax.set_yticklabels(['0', '', '1'])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{P}(z=1 \mid c)$")
    return hs, conds_irr

def plot_coefs_dur_odif(
        n_cond_dur_ch: np.ndarray = None,
        data: Data2DVD = None,
        dif_irrs: Sequence[Union[Iterable[int], int]] = ((0, 1), (2, )),
        ev_cond_dim: np.ndarray = None,
        durs: np.ndarray = None,
        style='data',
        kw_plot=(),
        jitter0=0.,
        coefs_to_plot=(0, 1),
        add_rowtitle=True,
        dim_incl=(0, 1),
        axs=None,
        fig=None,
):
    """
    :param n_cond_dur_ch:
    :param data:
    :param dif_irrs:
    :param ev_cond_dim:
    :param durs:
    :param style:
    :param kw_plot:
    :param jitter0:
    :param coefs_to_plot:
    :param add_rowtitle:
    :param dim_incl:
    :param axs:
    :param fig:
    :return: coef, se_coef, axs, hs[coef, dim, dif]
    """
    if ev_cond_dim is None:
        ev_cond_dim = data.ev_cond_dim
    if durs is None:
        durs = npy(data.durs)
    if n_cond_dur_ch is None:
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])

    coef, se_coef = coefdur.get_coefs_mesh_from_histogram(
        n_cond_dur_ch, ev_cond_dim=ev_cond_dim, dif_irrs=dif_irrs
    )[:2]  # type: (np.ndarray, np.ndarray)
    coef_names = ['bias', 'slope']

    n_coef = len(coefs_to_plot)
    n_dim = len(dim_incl)
    if axs is None:
        if fig is None:
            fig = plt.figure(figsize=[6, 1.5 * n_coef])
        n_row = n_dim
        n_col = n_coef
        gs = plt.GridSpec(nrows=n_row, ncols=n_col, figure=fig,
                          left=0.25, right=0.95, bottom=0.15, top=0.9)
        axs = np.empty([n_row, n_col], dtype=object)
        for row in range(n_row):
            for col in range(n_col):
                axs[row, col] = plt.subplot(gs[row, col])

    n_dif = len(dif_irrs)
    hs = np.empty([len(coefs_to_plot), len(dim_incl), len(dif_irrs)],
                  dtype=object)
    for ii_coef, i_coef in enumerate(coefs_to_plot):
        coef_name = coef_names[i_coef]
        for i_dim, dim_rel in enumerate(dim_incl):
            ax = axs[i_dim, ii_coef]  # type: plt.Axes
            plt.sca(ax)
            cmap = consts.CMAP_DIM[dim_rel](n_dif)
            for idif, dif_irr in enumerate(dif_irrs):
                y = coef[i_coef, dim_rel, idif, :]
                e = se_coef[i_coef, dim_rel, idif, :]
                ddur = durs[1] - durs[0]
                jitter = ddur * jitter0 * (idif - (n_dif - 1) / 2)
                kw = consts.get_kw_plot(style, color=cmap(idif),
                                        for_err=True, **dict(kw_plot))
                if style.startswith('data'):
                    h = plt.errorbar(durs + jitter, y, yerr=e, **kw)[0]
                else:
                    h = plt.plot(durs + jitter, y, **kw)
                hs[ii_coef, i_dim, idif] = h
            plt2.box_off()
            max_dur = np.amax(npy(durs))
            ax.set_xlim(-0.05 * max_dur, 1.05 * max_dur)
            plt2.detach_axis('x', 0, max_dur)
            if dim_rel == 0:
                ax.set_xticklabels([])
                ax.set_title(coef_name.lower())
            elif i_coef == 0:
                ax.set_xlabel('duration (s)')

    if add_rowtitle:
        plt2.rowtitle(consts.DIM_NAMES_LONG, axes=axs)
    return coef, se_coef, axs, hs

def plot_fit_combined(
        data: Union[sim2d.Data2DRT, dict] = None,
        pModel_cond_rt_chFlat=None,
        model=None,
        pModel_dimRel_condDense_chFlat=None,
        # --- in place of data:
        pAll_cond_rt_chFlat=None, evAll_cond_dim=None,
        pTrain_cond_rt_chFlat=None, evTrain_cond_dim=None,
        pTest_cond_rt_chFlat=None, evTest_cond_dim=None,
        dt=None,
        # --- optional
        ev_dimRel_condDense_fr_dim_meanvar=None,
        dt_model=None,
        to_plot_internals=True,
        to_plot_params=True,
        to_plot_choice=True,
        # to_group_irr=False,
        group_dcond_irr=None,
        to_combine_ch_irr_cond=True,
        kw_plot_pred=(),
        kw_plot_pred_ch=(),
        kw_plot_data=(),
        axs=None,
):
    """
    :param data:
    :param pModel_cond_rt_chFlat:
    :param model:
    :param pModel_dimRel_condDense_chFlat:
    :param ev_dimRel_condDense_fr_dim_meanvar:
    :param to_plot_internals:
    :param to_plot_params:
    :param to_combine_ch_irr_cond:
    :param kw_plot_pred:
    :param kw_plot_data:
    :param axs:
    :return: axs, rts, hs
    """
    if data is None:
        if pTrain_cond_rt_chFlat is None:
            pTrain_cond_rt_chFlat = pAll_cond_rt_chFlat
        if evTrain_cond_dim is None:
            evTrain_cond_dim = evAll_cond_dim
        if pTest_cond_rt_chFlat is None:
            pTest_cond_rt_chFlat = pAll_cond_rt_chFlat
        if evTest_cond_dim is None:
            evTest_cond_dim = evAll_cond_dim
    else:
        _, pAll_cond_rt_chFlat, _, _, evAll_cond_dim = \
            data.get_data_by_cond('all')
        _, pTrain_cond_rt_chFlat, _, _, evTrain_cond_dim = \
            data.get_data_by_cond('train_valid', mode_train='easiest')
        _, pTest_cond_rt_chFlat, _, _, evTest_cond_dim = \
            data.get_data_by_cond('test', mode_train='easiest')
        dt = data.dt

    hs = {}
    if model is None:
        assert not to_plot_internals
        assert not to_plot_params
    if dt_model is None:
        if model is None:
            dt_model = dt
        else:
            dt_model = model.dt

    if axs is None:
        if to_plot_params:
            axs = plt2.GridAxes(3, 3)
        else:
            if to_plot_internals:
                axs = plt2.GridAxes(3, 3)
            else:
                if to_plot_choice:
                    axs = plt2.GridAxes(2, 2)
                else:
                    axs = plt2.GridAxes(1, 2)  # TODO: beautify ratios

    rts = []
    hs['rt'] = []
    for dim_rel in range(consts.N_DIM):
        # --- data_pred may not have all conditions, so concatenate the rest
        # of the conditions so that the color scale is correct. Then also
        # concatenate pTest_cond_rt_chFlat with zeros so that nothing is
        # plotted for the concatenated conditions.
        evTest_cond_dim1 = np.concatenate([
            evTest_cond_dim, evAll_cond_dim], axis=0)
        pTest_cond_rt_chFlat1 = np.concatenate([
            pTest_cond_rt_chFlat, np.zeros_like(pAll_cond_rt_chFlat)],
            axis=0)

        if ev_dimRel_condDense_fr_dim_meanvar is None:
            evModel_cond_dim = evAll_cond_dim
        else:
            if ev_dimRel_condDense_fr_dim_meanvar.ndim == 5:
                evModel_cond_dim = npy(
                    ev_dimRel_condDense_fr_dim_meanvar[dim_rel][:, 0, :, 0])
            else:
                assert ev_dimRel_condDense_fr_dim_meanvar.ndim == 4
                evModel_cond_dim = npy(
                    ev_dimRel_condDense_fr_dim_meanvar[dim_rel][:, 0, :])
            pModel_cond_rt_chFlat = npy(
                pModel_dimRel_condDense_chFlat[dim_rel])

        if to_plot_choice:
            # --- Plot choice
            ax = axs[0, dim_rel]
            plt.sca(ax)
            if to_combine_ch_irr_cond:
                ev_cond_model1, p_rt_ch_model1 = combine_irr_cond(
                    dim_rel, evModel_cond_dim, pModel_cond_rt_chFlat)
                sim2d.plot_p_ch_vs_ev(
                    ev_cond_model1, p_rt_ch_model1,
                    dim_rel=dim_rel, style='pred',
                    group_dcond_irr=None,
                    kw_plot=kw_plot_pred_ch,
                    cmap=lambda n: lambda v: [0., 0., 0.],
                )
            else:
                sim2d.plot_p_ch_vs_ev(
                    evModel_cond_dim, pModel_cond_rt_chFlat,
                    dim_rel=dim_rel, style='pred',
                    group_dcond_irr=group_dcond_irr,
                    kw_plot=kw_plot_pred,
                    cmap=cmaps[dim_rel])

            hs_ch, conds_irr = sim2d.plot_p_ch_vs_ev(
                evTest_cond_dim1, pTest_cond_rt_chFlat1,
                dim_rel=dim_rel, style='data_pred',
                group_dcond_irr=group_dcond_irr,
                cmap=cmaps[dim_rel],
                kw_plot=kw_plot_data,
            )
            hs1 = [h[0] for h in hs_ch]
            odim = 1 - dim_rel
            odim_name = consts.DIM_NAMES_LONG[odim]
            legend_odim(conds_irr, hs1, odim_name)

            sim2d.plot_p_ch_vs_ev(
                evTrain_cond_dim, pTrain_cond_rt_chFlat,
                dim_rel=dim_rel, style='data_fit',
                group_dcond_irr=group_dcond_irr,
                cmap=cmaps[dim_rel],
                kw_plot=kw_plot_data)
            plt2.detach_axis('x',
                             np.amin(evTrain_cond_dim[:, dim_rel]),
                             np.amax(evTrain_cond_dim[:, dim_rel]))
            ax.set_xlabel('')
            ax.set_xticklabels([])
            if dim_rel != 0:
                plt2.box_off(['left'])
                plt.yticks([])
            ax.set_ylabel('P(%s choice)' % consts.CH_NAMES[dim_rel][1])

        # --- Plot RT
        ax = axs[int(to_plot_choice) + 0, dim_rel]
        plt.sca(ax)
        hs1, rts1 = sim2d.plot_rt_vs_ev(
            evModel_cond_dim, pModel_cond_rt_chFlat,
            dim_rel=dim_rel, style='pred',
            group_dcond_irr=group_dcond_irr,
            dt=dt_model,
            kw_plot=kw_plot_pred,
            cmap=cmaps[dim_rel])
        hs['rt'].append(hs1)
        rts.append(rts1)
        sim2d.plot_rt_vs_ev(
            evTest_cond_dim1, pTest_cond_rt_chFlat1,
            dim_rel=dim_rel, style='data_pred',
            group_dcond_irr=group_dcond_irr,
            dt=dt,
            cmap=cmaps[dim_rel],
            kw_plot=kw_plot_data)
        sim2d.plot_rt_vs_ev(
            evTrain_cond_dim, pTrain_cond_rt_chFlat,
            dim_rel=dim_rel, style='data_fit',
            group_dcond_irr=group_dcond_irr,
            dt=dt,
            cmap=cmaps[dim_rel],
            kw_plot=kw_plot_data)
        plt2.detach_axis('x',
                         np.amin(evTrain_cond_dim[:, dim_rel]),
                         np.amax(evTrain_cond_dim[:, dim_rel]))
        if dim_rel != 0:
            ax.set_ylabel('')
            plt2.box_off(['left'])
            plt.yticks([])
        ax.set_xlabel(consts.DIM_NAMES_LONG[dim_rel].lower() + ' strength')
        if dim_rel == 0:
            ax.set_ylabel('RT (s)')

        if to_plot_internals:
            for ch1 in range(consts.N_CH):
                ch0 = dim_rel
                ax = axs[3 + ch1, dim_rel]
                plt.sca(ax)
                ch_flat = consts.ch_by_dim2ch_flat(np.array([ch0, ch1]))
                model.tnds[ch_flat].plot_p_tnd()
                ax.set_xlabel('')
                ax.set_xticklabels([])
                ax.set_yticks([0, 1])
                if ch0 > 0:
                    ax.set_yticklabels([])
                ax.set_ylabel(r"$\mathrm{P}(T^\mathrm{n} \mid"
                              r" \mathbf{z}=[%d,%d])$" % (ch0, ch1))

            ax = axs[5, dim_rel]
            plt.sca(ax)
            if hasattr(model.dtb, 'dtb1ds'):
                model.dtb.dtb1ds[dim_rel].plot_bound(color='k')

    plt2.sameaxes(axs[-1, :consts.N_DIM], xy='y')

    if to_plot_params:
        ax = axs[0, -1]
        plt.sca(ax)
        model.plot_params()
    return axs, rts, hs

def plot_coefs_dur_irrixn(
        n_cond_dur_ch: np.ndarray = None,
        data: Data2DVD = None,
        ev_cond_dim: np.ndarray = None,
        durs: np.ndarray = None,
        style='data',
        kw_plot=(),
        jitter0=0.,
        coefs_to_plot=(2, ),
        axs=None,
        fig=None,
):
    if ev_cond_dim is None:
        ev_cond_dim = data.ev_cond_dim
    if durs is None:
        durs = npy(data.durs)
    if n_cond_dur_ch is None:
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])

    coef, se_coef = coefdur.get_coefs_irr_ixn_from_histogram(
        n_cond_dur_ch, ev_cond_dim=ev_cond_dim
    )[:2]  # type: (np.ndarray, np.ndarray)
    coef_names = [
        'bias', 'slope', 'rel x abs(irr)', 'rel x irr', 'abs(irr)', 'irr'
    ]

    n_coef = len(coefs_to_plot)
    if axs is None:
        if fig is None:
            fig = plt.figure(figsize=[6, 1.5 * n_coef])
        n_row = consts.N_DIM
        n_col = n_coef
        gs = plt.GridSpec(nrows=n_row, ncols=n_col, figure=fig,
                          left=0.25, right=0.95, bottom=0.15, top=0.9)
        axs = np.empty([n_row, n_col], dtype=object)
        for row in range(n_row):
            for col in range(n_col):
                axs[row, col] = plt.subplot(gs[row, col])

    for ii_coef, i_coef in enumerate(coefs_to_plot):
        coef_name = coef_names[i_coef]
        for dim_rel in range(consts.N_DIM):
            ax = axs[dim_rel, ii_coef]  # type: plt.Axes
            plt.sca(ax)
            y = coef[i_coef, dim_rel, :]
            e = se_coef[i_coef, dim_rel, :]
            jitter = 0.
            kw_plot1 = {'color': 'k', 'for_err': True, **dict(kw_plot)}
            kw = consts.get_kw_plot(style, **kw_plot1)
            if style.startswith('data'):
                plt.errorbar(durs + jitter, y, yerr=e, **kw)
            else:
                plt.plot(durs + jitter, y, **kw)
            plt2.box_off()
            max_dur = np.amax(npy(durs))
            ax.set_xlim(-0.05 * max_dur, 1.05 * max_dur)
            plt2.detach_axis('x', 0, max_dur)  # , ax=ax)
            if dim_rel == 0:
                ax.set_xticklabels([])
                ax.set_title(coef_name.lower())
            elif i_coef == 0:
                ax.set_xlabel('duration (s)')

    plt2.rowtitle(consts.DIM_NAMES_LONG, axes=axs)
    return coef, se_coef, axs

def optimize(
        model: ModelType,
        fun_data: FunDataType,
        fun_loss: FunLossType,
        plotfuns: PlotFunsType,
        optimizer_kind='Adam',
        max_epoch=100,
        patience=20,  # how many epochs to wait before quitting
        thres_patience=0.001,  # how much it should improve within patience
        learning_rate=.5,
        reduce_lr_by=0.5,
        reduced_lr_on_epoch=0,
        reduce_lr_after=50,
        reset_lr_after=100,
        to_plot_progress=True,
        show_progress_every=5,  # number of epochs
        to_print_grad=True,
        n_fold_valid=1,
        epoch_to_check=None,  # CHECKED
        comment='',
        **kwargs  # to ignore unnecessary kwargs
) -> (float, dict, dict, List[float], List[float]):
    """
    :param model:
    :param fun_data: (mode='all'|'train'|'valid'|'train_valid'|'test',
        fold_valid=0, epoch=0, n_fold_valid=1) -> (data, target)
    :param fun_loss: (out, target) -> loss
    :param plotfuns: [(str, fun)] where fun takes a dict d with keys
        'data_*', 'target_*', 'out_*', 'loss_*', where * = 'train',
        'valid', etc.
    :param optimizer_kind:
    :param max_epoch:
    :param patience:
    :param thres_patience:
    :param learning_rate:
    :param reduce_lr_by:
    :param reduced_lr_on_epoch:
    :param reduce_lr_after:
    :param to_plot_progress:
    :param show_progress_every:
    :param to_print_grad:
    :param n_fold_valid:
    :param kwargs:
    :return: loss_test, best_state, d, losses_train, losses_valid
        where d contains 'data_*', 'target_*', 'out_*', and 'loss_*',
        where * is 'train_valid', 'test', and 'all'.
    """
    def get_optimizer(model, lr):
        if optimizer_kind == 'SGD':
            return optim.SGD(model.parameters(), lr=lr)
        elif optimizer_kind == 'Adam':
            return optim.Adam(model.parameters(), lr=lr)
        elif optimizer_kind == 'LBFGS':
            return optim.LBFGS(model.parameters(), lr=lr)
        else:
            raise NotImplementedError()

    learning_rate0 = learning_rate
    optimizer = get_optimizer(model, learning_rate)

    best_loss_epoch = 0
    best_loss_valid = np.inf
    best_state = model.state_dict()
    best_losses = []

    # CHECKED storing and loading states
    state0 = None
    loss0 = None
    data0 = None
    target0 = None
    out0 = None
    outs0 = None

    def array2str(v):
        return ', '.join(['%1.2g' % v1 for v1 in v.flatten()[:10]])

    def print_targ_out(target0, out0, outs0, loss0):
        print('target:\n' + array2str(target0))
        print('outs:\n' + '\n'.join(
            ['[%s]' % array2str(v) for v in outs0]))
        print('out:\n' + array2str(out0))
        print('loss: ' + '%g' % loss0)

    def fun_outs(model, data):
        p_bef_lapse0 = model.dtb(*data)[0].detach().clone()
        p_aft_lapse0 = model.lapse(p_bef_lapse0).detach().clone()
        return [p_bef_lapse0, p_aft_lapse0]

    def are_all_equal(outs, outs0):
        for i, (out1, out0) in enumerate(zip(outs, outs0)):
            if (out1 != out0).any():
                warnings.warn('output %d different!\n'
                              'max diff = %g'
                              % (i, (out1 - out0).abs().max()))
                print('--')

    # losses_train[epoch] = average cross-validated loss for the epoch
    losses_train = []
    losses_valid = []

    if to_plot_progress:
        writer = SummaryWriter(comment=comment)
    t_st = time.time()
    epoch = 0
    try:
        for epoch in range(max([max_epoch, 1])):
            losses_fold_train = []
            losses_fold_valid = []
            for i_fold in range(n_fold_valid):
                # NOTE: core part
                data_train, target_train = fun_data('train', i_fold,
                                                    epoch, n_fold_valid)
                model.train()
                if optimizer_kind == 'LBFGS':
                    def closure():
                        optimizer.zero_grad()
                        out_train = model(data_train)
                        loss = fun_loss(out_train, target_train)
                        loss.backward()
                        return loss

                    if max_epoch > 0:
                        optimizer.step(closure)
                    out_train = model(data_train)
                    loss_train1 = fun_loss(out_train, target_train)
                    raise NotImplementedError(
                        'Restoring best state is not implemented yet')
                else:
                    optimizer.zero_grad()
                    out_train = model(data_train)
                    loss_train1 = fun_loss(out_train, target_train)
                    # DEBUGGED: optimizer.step() must not be taken before
                    # storing best_loss or best_state
                losses_fold_train.append(loss_train1)

                if n_fold_valid == 1:
                    out_valid = npt.tensor(npy(out_train))
                    loss_valid1 = npt.tensor(npy(loss_train1))
                    data_valid = data_train
                    target_valid = target_train
                    # DEBUGGED: unless directly assigned, target_valid !=
                    # target_train when n_fold_valid = 1, which doesn't
                    # make sense. Suggests a bug in fun_data when
                    # n_fold = 1.
                else:
                    model.eval()
                    data_valid, target_valid = fun_data(
                        'valid', i_fold, epoch, n_fold_valid)
                    out_valid = model(data_valid)
                    loss_valid1 = fun_loss(out_valid, target_valid)
                    model.train()
                losses_fold_valid.append(loss_valid1)

            loss_train = torch.mean(torch.stack(losses_fold_train))
            loss_valid = torch.mean(torch.stack(losses_fold_valid))
            losses_train.append(npy(loss_train))
            losses_valid.append(npy(loss_valid))

            if to_plot_progress:
                writer.add_scalar('loss_train', loss_train,
                                  global_step=epoch)
                writer.add_scalar('loss_valid', loss_valid,
                                  global_step=epoch)

            # --- Store best loss
            # NOTE: storing losses/states must happen BEFORE taking a step!
            if loss_valid < best_loss_valid:
                # is_best = True
                best_loss_epoch = deepcopy(epoch)
                best_loss_valid = npt.tensor(npy(loss_valid))
                best_state = model.state_dict()
            best_losses.append(best_loss_valid)

            # CHECKED storing and loading state
            if epoch == epoch_to_check:
                loss0 = loss_valid.detach().clone()
                state0 = model.state_dict()
                data0 = deepcopy(data_valid)
                target0 = deepcopy(target_valid)
                out0 = out_valid.detach().clone()
                outs0 = fun_outs(model, data0)
                loss001 = fun_loss(out0, target0)
                # CHECKED: loss001 must equal loss0
                print('loss001 - loss0: %g' % (loss001 - loss0))
                print_targ_out(target0, out0, outs0, loss0)
                print('--')

            def print_loss():
                t_el = time.time() - t_st
                print('%1.0f sec/%d epochs = %1.1f sec/epoch, Ltrain: %f, '
                      'Lvalid: %f, LR: %g, best: %f, epochB: %d'
                      % (t_el, epoch + 1, t_el / (epoch + 1),
                         loss_train, loss_valid, learning_rate,
                         best_loss_valid, best_loss_epoch))

            if epoch % show_progress_every == 0:
                model.train()
                data_train_valid, target_train_valid = fun_data(
                    'train_valid', i_fold, epoch, n_fold_valid)
                out_train_valid = model(data_train_valid)
                loss_train_valid = fun_loss(out_train_valid,
                                            target_train_valid)
                print_loss()
                if to_plot_progress:
                    d = {
                        'data_train': data_train,
                        'data_valid': data_valid,
                        'data_train_valid': data_train_valid,
                        'out_train': out_train.detach(),
                        'out_valid': out_valid.detach(),
                        'out_train_valid': out_train_valid.detach(),
                        'target_train': target_train.detach(),
                        'target_valid': target_valid.detach(),
                        'target_train_valid': target_train_valid.detach(),
                        'loss_train': loss_train.detach(),
                        'loss_valid': loss_valid.detach(),
                        'loss_train_valid': loss_train_valid.detach()
                    }
                    for k, f in odict(plotfuns).items():
                        fig, d = f(model, d)
                        if fig is not None:
                            writer.add_figure(k, fig, global_step=epoch)

            # --- Learning rate reduction and patience
            # if epoch == reduced_lr_on_epoch + reset_lr_after
            # if epoch == reduced_lr_on_epoch + reduce_lr_after and (
            #         best_loss_valid
            #         > best_losses[-reduce_lr_after] - thres_patience
            # ):
            if epoch > 0 and epoch % reset_lr_after == 0:
                learning_rate = learning_rate0
                optimizer = get_optimizer(model, learning_rate)
            elif epoch > 0 and epoch % reduce_lr_after == 0:
                learning_rate *= reduce_lr_by
                optimizer = get_optimizer(model, learning_rate)
                reduced_lr_on_epoch = epoch

            if epoch >= patience and (
                    best_loss_valid
                    > best_losses[-patience] - thres_patience):
                print('Ran out of patience!')
                if to_print_grad:
                    print_grad(model)
                break

            # --- Take a step
            if optimizer_kind != 'LBFGS':
                # steps are not taken above for n_fold_valid == 1, so take
                # a step here, after storing the best state
                loss_train.backward()
                if to_print_grad and epoch == 0:
                    print_grad(model)
                if max_epoch > 0:
                    optimizer.step()
    except Exception as ex:
        from lib.pylabyk.cacheutil import is_keyboard_interrupt
        if not is_keyboard_interrupt(ex):
            raise ex
        print('fit interrupted by user at epoch %d' % epoch)

        from lib.pylabyk.localfile import LocalFile, datetime4filename
        localfile = LocalFile()
        cache = localfile.get_cache('model_data_target')
        data_train_valid, target_train_valid = fun_data(
            'all', 0, 0, n_fold_valid)
        cache.set({
            'model': model,
            'data_train_valid': data_train_valid,
            'target_train_valid': target_train_valid
        })
        cache.save()

    print_loss()
    if to_plot_progress:
        writer.close()

    if epoch_to_check is not None:
        # Must print the same output as the previous call to print_targ_out
        print_targ_out(target0, out0, outs0, loss0)

        model.load_state_dict(state0)
        state1 = model.state_dict()
        for (key0, param0), (key1, param1) in zip(
                state0.items(), state1.items()
        ):  # type: ((str, torch.Tensor), (str, torch.Tensor))
            if (param0 != param1).any():
                with torch.no_grad():
                    warnings.warn(
                        'Strange! loaded %s = %s\n'
                        '!= stored %s = %s\n'
                        'loaded - stored = %s'
                        % (key1, param1, key0, param0, param1 - param0))

        data, target = fun_data('valid', 0, epoch_to_check, n_fold_valid)
        if not torch.is_tensor(data):
            p_unequal = torch.tensor([
                (v1 != v0).double().mean()
                for v1, v0 in zip(data, data0)
            ])
            if (p_unequal > 0).any():
                print('Strange! loaded data != stored data0\n'
                      'Proportion: %s' % p_unequal)
            else:
                print('All loaded data == stored data')
        elif (data != data0).any():
            print('Strange! loaded data != stored data0')
        else:
            print('All loaded data == stored data')
        if (target != target0).any():
            print('Strange! loaded target != stored target0')
        else:
            print('All loaded target == stored target')
        print_targ_out(target0, out0, outs0, loss0)

        # with torch.no_grad():
        #     out01 = model(data0)
        #     loss01 = fun_loss(out01, target0)
        model.train()
        # with torch.no_grad():  # CHECKED
        #     outs1 = fun_outs(model, data)
        #     are_all_equal(outs1, outs0)
        out1 = model(data)
        if (out0 != out1).any():
            warnings.warn(
                'Strange! out from loaded params != stored out\n'
                'Max abs(loaded - stored): %g'
                % (out1 - out0).abs().max())
            print('--')
        else:
            print('out from loaded params = stored out')

        loss01 = fun_loss(out0, target0)
        print_targ_out(target0, out0, outs0, loss01)
        if loss0 != loss01:
            warnings.warn(
                'Strange! loss01 = %g simply computed again with out0, '
                'target0\n'
                '!= stored loss0 = %g\n'
                'loaded - stored: %g\n'
                'Therefore, fun_loss, out0, or target0 has changed!'
                % (loss01, loss0, loss01 - loss0))
            print('--')
        else:
            print('loss0 == loss01, simply computed again with out0, '
                  'target0')

        loss1 = fun_loss(out1, target)
        if loss0 != loss1:
            warnings.warn(
                'Strange! loss1 = %g from loaded params\n'
                '!= stored loss0 = %g\n'
                'loaded - stored: %g' % (loss1, loss0, loss1 - loss0))
            print('--')
        else:
            print('loss1 = %g = loss0 = %g' % (loss1, loss0))

        loss10 = fun_loss(out1, target0)
        if loss0 != loss10:
            warnings.warn(
                'Strange! loss10 = %g from loaded params and stored '
                'target0\n'
                '!= stored loss0 = %g\n'
                'loaded - stored: %g' % (loss10, loss0, loss10 - loss0))
            print('--')
        else:
            print('loss10 = %g = loss0 = %g' % (loss10, loss0))
        print('--')

    model.load_state_dict(best_state)
    d = {}
    for mode in ['train_valid', 'valid', 'test', 'all']:
        data, target = fun_data(mode, 0, 0, n_fold_valid)
        out = model(data)
        loss = fun_loss(out, target)
        d.update({
            'data_' + mode: data,
            'target_' + mode: target,
            'out_' + mode: npt.tensor(npy(out)),
            'loss_' + mode: npt.tensor(npy(loss))
        })
    if d['loss_valid'] != best_loss_valid:
        print('d[loss_valid] = %g from loaded best_state\n'
              '!= best_loss_valid = %g\n'
              'd[loss_valid] - best_loss_valid = %g'
              % (d['loss_valid'], best_loss_valid,
                 d['loss_valid'] - best_loss_valid))
        print('--')

    if isinstance(model, OverriddenParameter):
        print(model.__str__())
    elif isinstance(model, BoundedModule):
        pprint(model._parameters_incl_bounded)
    else:
        pprint(model.state_dict())

    return d['loss_test'], best_state, d, losses_train, losses_valid

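
# Illustrative sketch (hypothetical toy problem; the argument protocol
# follows the optimize() docstring above): fun_data here ignores the
# mode/fold arguments and always returns the full data, so train == valid,
# as with n_fold_valid=1 in the real pipeline. plotfuns is left empty.
def _demo_optimize():
    import torch
    from torch import nn

    x = torch.linspace(-1, 1, 50)[:, None]
    y = 3. * x + 1.  # toy linear-regression target
    model = nn.Linear(1, 1)

    def fun_data(mode='all', fold_valid=0, epoch=0, n_fold_valid=1):
        return x, y

    def fun_loss(out, target):
        return ((out - target) ** 2).mean()

    loss_test, best_state, d = optimize(
        model, fun_data, fun_loss, plotfuns=[],
        optimizer_kind='Adam', learning_rate=0.1, max_epoch=200,
        to_plot_progress=False, to_print_grad=False)[:3]
    return loss_test, best_state, d
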
def plot_ch_ev_by_dur(
        n_cond_dur_ch: np.ndarray = None,
        data: Data2DVD = None,
        dif_irrs: Sequence[Union[Iterable[int], int]] = ((0, 1), (2, )),
        ev_cond_dim: np.ndarray = None,
        durs: np.ndarray = None,
        dur_prct_groups=((0, 33), (33, 67), (67, 100)),
        style='data',
        kw_plot=(),
        jitter=0.,
        axs=None,
        fig=None):
    """
    Panels[dim, irr_dif_group], curves by dur_group

    :param n_cond_dur_ch:
    :param data:
    :param dif_irrs:
    :param ev_cond_dim:
    :param durs:
    :param dur_prct_groups:
    :param style:
    :param kw_plot:
    :param axs: [row, col]
    :return: axs
    """
    if ev_cond_dim is None:
        ev_cond_dim = data.ev_cond_dim
    if durs is None:
        durs = data.durs
    if n_cond_dur_ch is None:
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])

    n_conds_dur_chs = get_n_conds_dur_chs(n_cond_dur_ch)
    conds_dim = [np.unique(cond1) for cond1 in ev_cond_dim.T]

    n_dur = len(dur_prct_groups)
    n_dif = len(dif_irrs)
    if axs is None:
        if fig is None:
            fig = plt.figure(figsize=[6, 4])
        n_row = consts.N_DIM
        n_col = n_dur
        gs = plt.GridSpec(nrows=n_row, ncols=n_col, figure=fig)
        axs = np.empty([n_row, n_col], dtype=object)
        for row in range(n_row):
            for col in range(n_col):
                axs[row, col] = plt.subplot(gs[row, col])

    for dim_rel in range(consts.N_DIM):
        for idif, dif_irr in enumerate(dif_irrs):
            for i_dur, dur_prct_group in enumerate(dur_prct_groups):
                ax = axs[dim_rel, i_dur]
                plt.sca(ax)

                conds_rel = conds_dim[dim_rel]
                dim_irr = consts.get_odim(dim_rel)
                _, cond_irr = np.unique(np.abs(conds_dim[dim_irr]),
                                        return_inverse=True)
                incl_irr = np.isin(cond_irr, dif_irr)
                cmap = consts.CMAP_DIM[dim_rel](n_dif)

                ix_dur = np.arange(len(durs))
                dur_incl = (
                    (ix_dur >= np.percentile(ix_dur, dur_prct_group[0]))
                    & (ix_dur <= np.percentile(ix_dur, dur_prct_group[1])))

                if dim_rel == 0:
                    n_cond_dur_ch1 = n_conds_dur_chs[
                        :, incl_irr, :, :, :].sum((1, -1))
                else:
                    n_cond_dur_ch1 = n_conds_dur_chs[
                        incl_irr, :, :, :, :].sum((0, -2))
                n_cond_ch1 = n_cond_dur_ch1[:, dur_incl, :].sum(1)
                p_cond__ch1 = n_cond_ch1[:, 1] / n_cond_ch1.sum(1)

                x = conds_rel
                y = p_cond__ch1
                dx = conds_rel[1] - conds_rel[0]
                jitter1 = 0
                kw = consts.get_kw_plot(style, color=cmap(idif),
                                        **dict(kw_plot))
                plt.plot(x + jitter1, y, **kw)

                plt2.box_off(ax=ax)
                plt2.detach_yaxis(0, 1, ax=ax)
                ax.set_yticks([0, 0.5, 1.])
                if i_dur == 0:
                    ax.set_yticklabels(['0', '', '1'])
                else:
                    ax.set_yticklabels([])
                ax.set_xticklabels([])
    return axs