def plot_p_ch_vs_ev(ev_cond, p_ch, style='pred', ax: plt.Axes = None,
                    **kwargs) -> plt.Line2D:
    """
    Plot the probability of choice 1 against per-condition evidence.

    @param ev_cond: [condition] or [condition, frame]; a 3-D input is first
        reduced with npt.p2st(...)[0], and a 2-D input is averaged over
        axis 1.
    @type ev_cond: torch.Tensor
    @param p_ch: [condition, ch] or [condition, rt_frame, ch]; a 3-D input
        is summed over rt_frame.
    @type p_ch: torch.Tensor
    @param style: forwarded to get_kw_plot to choose plot defaults.
    @param ax: axes to draw on; plt.gca() when omitted.
    @return: the value returned by ax.plot.
    """
    ax = plt.gca() if ax is None else ax

    # -- Evidence: one scalar per condition.
    if ev_cond.ndim != 1:
        if ev_cond.ndim == 3:
            ev_cond = npt.p2st(ev_cond)[0]
        assert ev_cond.ndim == 2
        ev_cond = ev_cond.mean(1)

    # -- Choice counts/probabilities: marginalize over RT frames if present.
    if p_ch.ndim != 2:
        assert p_ch.ndim == 3
        p_ch = p_ch.sum(1)

    kw_line = get_kw_plot(style, **kwargs)
    p_ch_normalized = npt.sumto1(p_ch, -1)
    p_ch1 = npt.p2st(p_ch_normalized)[1]
    xy = npys(ev_cond, p_ch1)
    line = ax.plot(*xy, **kw_line)

    # -- Axes cosmetics.
    plt2.box_off(ax=ax)
    x_lo, x_hi = ax.get_xlim()
    plt2.detach_axis('x', amin=x_lo, amax=x_hi, ax=ax)
    plt2.detach_axis('y', amin=0, amax=1, ax=ax)
    ax.set_yticks([0, 0.5, 1])
    ax.set_yticklabels(['0', '', '1'])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{P}(z=1 \mid c)$")
    return line
def beautify(ax, x_max=1.2):
    """
    Tidy an axes whose data span [0, x_max] on the x-axis.

    Makes `ax` the current axes, sets the right x-limit to 5% past `x_max`,
    detaches the x spine to [0, x_max] and removes the box via plt2.

    :param ax: axes to tidy (it also becomes the current axes).
    :param x_max: right end of the data range; defaults to 1.2 as before.
    """
    plt.sca(ax)
    # 5% margin to the right of the last data point.
    plt.xlim(right=x_max * 1.05)
    plt2.detach_axis('x', 0, x_max)
    plt2.box_off()
def plot_params(self,
                named_bounded_params: Sequence[Tuple[
                    str, BoundedParameter]] = None,
                exclude: Iterable[str] = (),
                cmap='coolwarm',
                ax: plt.Axes = None) -> mpl.container.BarContainer:
    """
    Draw one horizontal bar per bounded parameter showing where its value
    lies between its bounds, coloured by its gradient.

    Each row shows the lower bound (left), the upper bound (right) and the
    value with the gradient's order of magnitude, or "(fixed)" for
    parameters that do not require a gradient.

    :param named_bounded_params: forwarded to self.get_named_bounded_params.
    :param exclude: parameter names to leave out (forwarded as well).
    :param cmap: colormap name used to colour bars by the gradient.
    :param ax: axes to draw on; plt.gca() when omitted.
    :return: the BarContainer returned by ax.barh.
    """
    if ax is None:
        ax = plt.gca()
    # NOTE: the previous version called plt.gca() unconditionally here,
    # silently ignoring the caller's `ax`.

    names, v, grad, lb, ub, requires_grad = self.get_named_bounded_params(
        named_bounded_params, exclude=exclude)

    # Map gradients to [0, 1] symmetric around 0 for colouring.
    max_grad = np.amax(np.abs(grad))
    if max_grad == 0:
        max_grad = 1.
    v01 = (v - lb) / (ub - lb)
    grad01 = (grad + max_grad) / (max_grad * 2)
    n = len(v)
    grad01[~requires_grad] = np.nan

    for i, (lb1, v1, ub1, g1, r1) in enumerate(
            zip(lb, v, ub, grad, requires_grad)):
        color = 'k' if r1 else 'gray'
        ax.text(0, i, '%1.2g' % lb1, ha='left', va='center', color=color)
        ax.text(1, i, '%1.2g' % ub1, ha='right', va='center', color=color)
        ax.text(
            0.5, i, '%1.2g %s' %
            (v1, ('(e%1.0f)' % np.log10(np.abs(g1))) if r1 else '(fixed)'),
            ha='center', va='center', color=color)

    lut = 256
    colors = plt.get_cmap(cmap, lut)(grad01)
    # A colormap maps NaN to its "bad" colour, which is fully transparent
    # by default; set RGB *and* alpha so fixed parameters show light grey.
    for i, r in enumerate(requires_grad):
        if not r:
            colors[i, :] = np.array([0.95, 0.95, 0.95, 1.])

    h = ax.barh(np.arange(n), v01, left=0, color=colors)
    ax.set_xlim(-0.025, 1)
    ax.set_xticks([])
    ax.set_yticks(np.arange(n))
    ax.set_yticklabels([name.replace('_', '-') for name in names])
    ax.set_ylim(n - 0.5, -0.5)
    plt2.box_off(['top', 'right', 'bottom'], ax=ax)
    plt2.detach_axis('x', amin=0, amax=1, ax=ax)
    plt2.detach_axis('y', amin=0, amax=n - 1, ax=ax)
    return h
def beautify(row, col, axs):
    """
    Tidy the panel axs[row, col] of a panel grid.

    Relies on `n_row` and `beautify_ticks` from the enclosing scope. Tick
    labels are requested only for the bottom-left panel, and the y-label is
    cleared on every column but the first.
    """
    panel = axs[row, col]
    plt.sca(panel)
    plt.xlim(xmax=1.2 * 1.05)
    plt2.detach_axis('x', 0, 1.2)
    plt2.box_off()

    is_bottom_left = (row == n_row - 1) and (col == 0)
    beautify_ticks(axs[row, col], add_ticklabel=is_bottom_left)

    if col <= 0:
        return
    plt.ylabel('')
def plot_bars(gs1, bufdurs, losses1, add_xticklabel=True):
    """
    Draw a broken-axes bar chart of losses1[i_subj, :] against buffer
    durations inside the gridspec cell `gs1`.

    Relies on `i_subj` and `thres_strong` from the enclosing scope.

    :param gs1: subplot spec to draw into (passed to brokenaxes).
    :param bufdurs: list of buffer durations; must contain 0., 0.2, 0.6 and
        1.2 because their positions are looked up with .index().
    :param losses1: array indexed [subject, duration].
    :param add_xticklabel: label the x-ticks (only applied when i_subj == 0).
    :return: the brokenaxes object.
    """
    # x is broken after the last duration below 0.3 s.
    i_break = np.amax(np.nonzero(np.array(bufdurs) < 0.3)[0])
    bax = brokenaxes(
        subplot_spec=gs1,
        xlims=((-1, i_break + 0.5), (i_break + 0.5, len(bufdurs) - 0.5)),
        # NOTE(review): the upper y-range ends at 1250 / 5 = 250 while ax00
        # gets yticks [500, 1000] and height_ratios still refer to 1250 —
        # looks like a leftover from an earlier scale; confirm.
        ylims=((-3, 20), (20, 1250 / 5)),
        height_ratios=(50 / 100, 500 / (1250 - 100)),
        hspace=0.15,
        wspace=0.075,
        d=0.005,
    )
    bax.bar(np.arange(len(bufdurs)), losses1[i_subj, :], color='k')

    # bax.axs indices as used below: 0 = upper-left, 2 = lower-left,
    # 3 = lower-right (presumably row-major order — verify with brokenaxes).
    ax11 = bax.axs[3]  # type: plt.Axes
    ax11.set_xticks([bufdurs.index(0.6), bufdurs.index(1.2)])
    if i_subj == 0 and add_xticklabel:
        ax11.set_xticklabels(['0.6', '1.2'])
    else:
        ax11.set_xticklabels([])

    ax00 = bax.axs[0]  # type: plt.Axes
    ax00.set_yticks([500, 1000])

    ax10 = bax.axs[2]  # type: plt.Axes
    ax10.set_yticks([0, 50])
    plt.sca(ax10)
    plt2.detach_axis('x', amin=-0.4, amax=i_break + 0.5)

    # Zero line and +/- strong-evidence threshold lines on the lower panels.
    for ax in [ax10, ax11]:
        plt.sca(ax)
        plt.axhline(0, linewidth=0.5, color='k', linestyle='--')
        for sign in [-1, 1]:
            plt.axhline(sign * thres_strong,
                        linewidth=0.5,
                        color='silver',
                        linestyle='--')

    ax10.set_xticks([bufdurs.index(0.), bufdurs.index(0.2)])
    if i_subj == 0:
        if add_xticklabel:
            ax10.set_xticklabels(['0', '0.2'])
        else:
            ax10.set_xticklabels([])
    else:
        # Only the first subject's column keeps tick labels.
        ax10.set_yticklabels([])
        ax10.set_xticklabels([])
        ax00.set_yticklabels([])
    return bax
def plot_bound(self,
               t_all: Sequence[float] = None,
               ax: plt.Axes = None,
               **kwargs) -> plt.Line2D:
    """
    Plot the decision bound b(t) returned by self.get_bound.

    :param t_all: time points; defaults to self.t_all.
    :param ax: axes to draw on; plt.gca() when omitted.
    :param kwargs: line properties (defaults: color='k', linestyle='-').
    :return: the value returned by ax.plot.
    """
    if ax is None:
        ax = plt.gca()
    if t_all is None:
        t_all = self.t_all
    kwargs = argsutil.kwdefault(kwargs, color='k', linestyle='-')
    h = ax.plot(*npys(t_all, self.get_bound(t_all)), **kwargs)
    ax.set_xlabel('time (s)')
    ax.set_ylabel(r"$b(t)$")

    # Leave a margin of 5% of the top limit below zero.
    y_lim = ax.get_ylim()
    ax.set_ylim(bottom=-y_lim[1] * 0.05)

    # Pass `ax` explicitly (as plot_p_tnd does) so the styling applies to the
    # axes we drew on even when it is not the current axes.
    plt2.detach_axis('y', amin=0, ax=ax)
    plt2.detach_axis('x', amin=0, ax=ax)
    plt2.box_off(ax=ax)
    return h
def plot_p_tnd(self,
               t_all: Sequence[float] = None,
               ax: plt.Axes = None,
               **kwargs) -> Union[plt.Line2D, Sequence[plt.Line2D]]:
    """
    Plot self.get_p_tnd(t_all) against time.

    :param t_all: time points; defaults to self.t_all.
    :param ax: axes to draw on; plt.gca() when omitted.
    :param kwargs: line properties (defaults: color='k', linestyle='-').
    :return: the value returned by ax.plot.
    """
    t_all = self.t_all if t_all is None else t_all
    ax = plt.gca() if ax is None else ax

    density = self.get_p_tnd(t_all)
    kw_line = argsutil.kwdefault(kwargs, color='k', linestyle='-')
    lines = ax.plot(*npys(t_all, density), **kw_line)

    ax.set_xlabel('time (s)')
    plt2.box_off(ax=ax)
    plt2.detach_axis('y', amin=0, amax=1, ax=ax)
    plt2.detach_axis('x', amin=0, ax=ax)
    return lines
def plot_rt_distrib(
        n_cond_rt_ch: np.ndarray,
        ev_cond_dim: np.ndarray,
        abs_cond=True,
        lump_wrong=True,
        dt=consts.DT,
        colors=None,
        alpha=1.,
        alpha_face=0.5,
        smooth_sigma_sec=0.05,
        to_normalize_max=False,
        to_cumsum=False,
        to_exclude_last_frame=True,
        to_skip_zero_trials=False,
        label='',
        # to_exclude_bins_wo_trials=10,
        kw_plot=(),
        fig=None,
        axs=None,
        to_use_sameaxes=True,
):
    """
    Plot RT distributions in a grid of panels, one per pair of
    |condition| levels in the two dimensions.

    :param n_cond_rt_ch: counts indexed [condition, frame, ch]. The flatten
        below pairs it with a meshgrid over (unique cond0, unique cond1,
        frame, ch0, ch1), so it presumably must cover the full factorial of
        conditions in that order with 4 flattened choices — TODO confirm.
    :param ev_cond_dim: [condition, 2] signed evidence per dimension.
    :param abs_cond: must be True (choices are recoded to correct/wrong);
        False raises ValueError.
    :param lump_wrong: when True, a trial counts as "correct" only if both
        dimensions are correct (choices at zero evidence count as correct),
        and only the lumped channel is drawn.
    :param dt: frame duration (s).
    :param colors: None, one colour string, or two colours (indexed by ch1).
    :param alpha: line alpha.
    :param alpha_face: fill alpha.
    :param smooth_sigma_sec: Gaussian smoothing SD in seconds (0 = none).
    :param to_normalize_max: divide each panel by its max |value|.
    :param to_cumsum: plot cumulative distributions.
    :param to_exclude_last_frame: zero the counts in the last frame.
    :param to_skip_zero_trials: skip panels whose summed |y| is < 1e-2.
    :param label: legend label attached to the (ch0=1, ch1=1) line.
    :param kw_plot: extra kwargs for ax.plot (override defaults).
    :param fig: unused.
    :param axs: grid of axes; created with plt2.GridAxes when None.
    :param to_use_sameaxes: share limits across all panels.
    :return: axs, p_cond01__rt_ch01, p_cond01__rt_ch01_sm, h (last line
        plotted). Note p_cond01__rt_ch01_sm is returned with ch0 == 0
        entries sign-flipped, exactly as plotted.
    """
    from copy import deepcopy

    if colors is None:
        colors = ['red', 'blue']
    elif isinstance(colors, str):
        colors = [colors] * 2
    else:
        assert len(colors) == 2

    nt = n_cond_rt_ch.shape[1]
    t_all = np.arange(nt) * dt + dt

    out = np.meshgrid(np.unique(ev_cond_dim[:, 0]),
                      np.unique(ev_cond_dim[:, 1]),
                      np.arange(nt),
                      np.arange(2),
                      np.arange(2),
                      indexing='ij')
    cond0, cond1, fr, ch0, ch1 = [v.flatten() for v in out]

    n0 = deepcopy(n_cond_rt_ch)
    if to_exclude_last_frame:
        n0[:, -1, :] = 0.
    n0 = n0.flatten()

    def sign_cond(v):
        # sign of the evidence, with zero treated as positive
        v1 = np.sign(v)
        v1[v == 0] = 1
        return v1

    if abs_cond:
        ch0 = consts.ch_bool2sign(ch0)
        ch1 = consts.ch_bool2sign(ch1)
        # 1 = correct, -1 = wrong
        ch0 = sign_cond(cond0) * ch0
        ch1 = sign_cond(cond1) * ch1
        cond0 = np.abs(cond0)
        cond1 = np.abs(cond1)
        # builtin int: np.int was removed in NumPy 1.24
        ch0 = consts.ch_sign2bool(ch0).astype(int)
        ch1 = consts.ch_sign2bool(ch1).astype(int)
    else:
        raise ValueError()

    if lump_wrong:
        # treat all choices as correct when cond == 0
        ch00 = ch0 | (cond0 == 0)
        ch10 = ch1 | (cond1 == 0)
        ch0 = (ch00 & ch10)
        ch1 = np.ones_like(ch00, dtype=int)

    cond_dim = np.stack([cond0, cond1], -1)

    conds = []
    dcond_dim = []
    for cond in cond_dim.T:
        conds1, dcond1 = np.unique(cond, return_inverse=True)
        conds.append(conds1)
        dcond_dim.append(dcond1)
    dcond_dim = np.stack(dcond_dim)

    # counts indexed [dcond0, dcond1, frame, ch0, ch1]
    n_cond01_rt_ch01 = npg.aggregate(
        [*dcond_dim, fr, ch0, ch1], n0, 'sum',
        [*(np.amax(dcond_dim, 1) + 1), nt, consts.N_CH, consts.N_CH])

    # normalize within each condition pair (over time and both choices)
    p_cond01__rt_ch01 = np2.nan2v(
        n_cond01_rt_ch01 / n_cond01_rt_ch01.sum((2, 3, 4), keepdims=True))
    n_conds = p_cond01__rt_ch01.shape[:2]

    if axs is None:
        axs = plt2.GridAxes(
            n_conds[1], n_conds[0],
            left=0.6, right=0.3,
            bottom=0.45, top=0.74,
            widths=[1], heights=[1],
            wspace=0.04, hspace=0.04,
        )
        kw_label = {
            'fontsize': 12,
        }
        pad = 8
        axs[0, 0].set_title('strong\nmotion', pad=pad, **kw_label)
        axs[0, -1].set_title('weak\nmotion', pad=pad, **kw_label)
        axs[0, 0].set_ylabel('strong\ncolor', labelpad=pad, **kw_label)
        axs[-1, 0].set_ylabel('weak\ncolor', labelpad=pad, **kw_label)

    if smooth_sigma_sec > 0:
        from scipy import signal, stats
        sigma_fr = smooth_sigma_sec / dt
        width = np.ceil(sigma_fr * 2.5).astype(int)
        kernel = stats.norm.pdf(np.arange(-width, width + 1), 0, sigma_fr)
        # place the 1-D kernel on axis 2 (time) of a 5-D kernel
        kernel = np2.vec_on(kernel, 2, 5)
        p_cond01__rt_ch01_sm = signal.convolve(p_cond01__rt_ch01, kernel,
                                               mode='same')
    else:
        p_cond01__rt_ch01_sm = p_cond01__rt_ch01.copy()

    if to_cumsum:
        p_cond01__rt_ch01_sm = np.cumsum(p_cond01__rt_ch01_sm, axis=2)

    if to_normalize_max:
        p_cond01__rt_ch01_sm = np2.nan2v(
            p_cond01__rt_ch01_sm /
            np.amax(np.abs(p_cond01__rt_ch01_sm), (2, 3, 4), keepdims=True))

    n_row = n_conds[1]
    n_col = n_conds[0]
    h = None
    for dcond0 in range(n_conds[0]):
        for dcond1 in range(n_conds[1]):
            # strongest evidence in the top-left panel
            row = n_row - 1 - dcond1
            col = n_col - 1 - dcond0
            ax = axs[row, col]  # type: plt.Axes

            for ch0 in [0, 1]:
                for ch1 in [0, 1]:
                    if lump_wrong and ch1 == 0:
                        continue

                    p1 = p_cond01__rt_ch01_sm[dcond0, dcond1, :, ch0, ch1]

                    kw = {
                        'linewidth': 1,
                        'color': colors[ch1],
                        'alpha': alpha,
                        'zorder': 1,
                        **dict(kw_plot)
                    }

                    # ch0 == 0 (wrong) is drawn downward
                    y = p1 * consts.ch_bool2sign(ch0)
                    p_cond01__rt_ch01_sm[dcond0, dcond1, :, ch0, ch1] = y

                    if to_skip_zero_trials and np.sum(np.abs(y)) < 1e-2:
                        h = None
                    else:
                        h = ax.plot(
                            t_all, y,
                            label=label if ch0 == 1 and ch1 == 1 else None,
                            **kw)
                        ax.fill_between(t_all, 0, y,
                                        ec='None',
                                        fc=kw['color'],
                                        alpha=alpha_face,
                                        zorder=-1)
                    plt2.box_off(ax=ax)

            ax.axhline(0, color='k', linewidth=0.5)
            ax.set_yticklabels([])
            if row < n_row - 1 or col > 0:
                ax.set_xticklabels([])
                ax.set_xticks([])
                plt2.box_off(['bottom'], ax=ax)
            else:
                ax.set_xlabel('RT (s)')
            ax.set_yticks([])
            ax.set_yticklabels([])
            plt2.box_off(['left'], ax=ax)

            plt2.detach_axis('x', 0, 5, ax=ax)
    if to_use_sameaxes:
        plt2.sameaxes(axs)
    axs[-1, 0].set_xlabel('RT (s)')
    return axs, p_cond01__rt_ch01, p_cond01__rt_ch01_sm, h
def plot_p_ch_vs_ev(
        ev_cond: Union[torch.Tensor, np.ndarray],
        n_ch: Union[torch.Tensor, np.ndarray],
        style='pred',
        ax: plt.Axes = None,
        dim_rel=0,
        group_dcond_irr: Iterable[Iterable[int]] = None,
        cmap: Union[str, Callable] = 'cool',
        kw_plot=(),
) -> Iterable[plt.Line2D]:
    """
    Plot P(choice along dim_rel) against the evidence in dim_rel, with one
    line per |evidence| level of the other (irrelevant) dimension.

    @param ev_cond: [condition, dim] or [condition, frame, dim, (mean, var)]
    @type ev_cond: torch.Tensor
    @param n_ch: [condition, ch] or [condition, rt_frame, ch]
    @type n_ch: torch.Tensor
    @param dim_rel: index of the relevant dimension.
    @param group_dcond_irr: optional grouping of irrelevant-dimension levels,
        forwarded to group_conds.
    @param cmap: colormap name, or a callable taking the number of levels
        and returning a colormap.
    @return: hs[cond_irr][0] = Line2D, conds_irr
    """
    ax = plt.gca() if ax is None else ax

    if ev_cond.ndim != 2:
        assert ev_cond.ndim == 4
        ev_cond = npt.p2st(ev_cond.mean(1))[0]
    if n_ch.ndim != 2:
        assert n_ch.ndim == 3
        n_ch = n_ch.sum(1)

    ev_cond = npy(ev_cond)
    n_ch = npy(n_ch)

    n_cond_total = n_ch.shape[0]
    # choice label along dim_rel for every (condition, flat choice) cell
    ch_rel_grid = np.repeat(np.array(consts.CHS[dim_rel])[None, :],
                            n_cond_total, 0)

    counts_flat = n_ch.reshape([-1])
    ch_rel_flat = ch_rel_grid.reshape([-1])

    dim_irr = consts.get_odim(dim_rel)
    conds_rel, dcond_rel = np.unique(ev_cond[:, dim_rel],
                                     return_inverse=True)
    conds_irr, dcond_irr = np.unique(np.abs(ev_cond[:, dim_irr]),
                                     return_inverse=True)
    if group_dcond_irr is not None:
        conds_irr, dcond_irr = group_conds(conds_irr, dcond_irr,
                                           group_dcond_irr)
    n_rel = len(conds_rel)
    n_irr = len(conds_irr)

    def per_cell(dcond):
        # repeat a per-condition index for each flattened choice
        return np.repeat(dcond[:, None], consts.N_CH_FLAT, 1).flatten()

    index = np.stack([ch_rel_flat, per_cell(dcond_irr), per_cell(dcond_rel)])
    n_ch_rel = npg.aggregate(index, counts_flat, 'sum',
                             [consts.N_CH, n_irr, n_rel])
    p_ch_rel = n_ch_rel[1] / n_ch_rel.sum(0)

    hs = []
    for i_irr, p_row in enumerate(p_ch_rel):
        if type(cmap) is str:
            palette = plt.get_cmap(cmap, n_irr)
        else:
            palette = cmap(n_irr)
        kw_line = get_kw_plot(style, color=palette(i_irr), **dict(kw_plot))
        hs.append(ax.plot(conds_rel, p_row, **kw_line))

    plt2.box_off(ax=ax)
    x_lo, x_hi = ax.get_xlim()
    plt2.detach_axis('x', amin=x_lo, amax=x_hi, ax=ax)
    plt2.detach_axis('y', amin=0, amax=1, ax=ax)
    ax.set_yticks([0, 0.5, 1])
    ax.set_yticklabels(['0', '', '1'])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{P}(z=1 \mid c)$")
    return hs, conds_irr
def plot_coef_by_dur_vs_odif(coef,
                             se_coef,
                             normalize_bias=True,
                             coef_name='slope',
                             savefig=True,
                             difs=((2, 1), (0, )),
                             t_RDK_durs=np.arange(1, 11) * 0.12,
                             fig=None,
                             horizontal_panels=False):
    """
    Plot a logistic coefficient against stimulus duration, one panel per
    dimension and one line per difficulty group of the other dimension.

    @param coef: array indexed [coef (bias, slope), dim, dif, dur].
        Not modified.
    @param se_coef: standard errors, same layout as coef. Not modified.
    @param normalize_bias: express the bias as -bias / slope (coh units).
    @param coef_name: 'bias' or 'slope'.
    @param savefig: unused.
    @param difs: difficulty groups; only len(difs) (1, 2 or 3) is used, to
        choose the line labels.
    @param t_RDK_durs: x values (stimulus durations, s).
    @param fig: figure to draw into; a new one is created when None.
    @param horizontal_panels: lay the two panels out side by side.
    @return: the GridSpec used.
    """
    if fig is None:
        if horizontal_panels:
            fig = plt.figure(figsize=(7, 2))
        else:
            fig = plt.figure(figsize=(4, 3))
    if horizontal_panels:
        gs = plt.GridSpec(figure=fig, nrows=1, ncols=2, hspace=0.3,
                          left=0.11, right=0.98, top=0.9, bottom=0.15)
    else:
        gs = mpl.gridspec.GridSpec(figure=fig, nrows=2, ncols=1,
                                   left=0.22, right=0.98, top=0.9,
                                   bottom=0.15)

    coef_names = ['bias', 'slope']
    i_coef = coef_names.index(coef_name)

    # Work on copies: previously coef[0] / se_coef[0] were overwritten in
    # the caller's arrays.
    coef = np.array(coef, dtype=float)
    se_coef = np.array(se_coef, dtype=float)

    units = ['(logit)', '(logit/coh)']
    if normalize_bias:
        # SE of -bias/slope scaled by |slope| (slope uncertainty ignored);
        # abs() keeps it non-negative, which errorbar requires.
        se_coef[0] = np.abs(se_coef[0] / coef[1])
        coef[0] = -coef[0] / coef[1]
        units[0] = '(coh)'
    unit = units[i_coef]

    y = coef[i_coef]
    se = se_coef[i_coef]

    labels_by_n_difs = {
        1: ['all'],
        2: ['easy', 'hard'],
        3: ['easy', 'medium', 'hard'],
    }
    if len(difs) not in labels_by_n_difs:
        raise ValueError('difs must have 1, 2 or 3 groups; got %d'
                         % len(difs))
    labels = labels_by_n_difs[len(difs)]

    for dim, (slope1, se_slope1) in enumerate(zip(y, se)):
        odim = consts.N_DIM - 1 - dim
        # add to `fig` explicitly: plt.subplot would target the current
        # figure, which need not be `fig` when one is passed in.
        ax = fig.add_subplot(gs[0, dim] if horizontal_panels else gs[dim, 0])
        plt.sca(ax)

        for odif, (slope2, se_slope2) in enumerate(zip(slope1, se_slope1)):
            label = labels[odif] + ' ' + consts.DIM_NAMES_SHORT[odim]
            plt.errorbar(t_RDK_durs, slope2, se_slope2, marker='o',
                         label=label)

        plt.ylabel(consts.DIM_NAMES_LONG[dim].lower() + ' ' + coef_name +
                   '\n' + unit)
        plt.axis('auto')
        plt.xlim(xmin=-0.02)
        if coef_name == 'slope':
            plt.ylim(ymin=-0.05)
        plt2.box_off()
        plt2.detach_axis('x')
        if dim == 0 and not horizontal_panels:
            plt.gca().set_xticklabels([])
        plt.legend(handlelength=0.4, frameon=False, loc='upper left')
    plt.xlabel('stimulus duration (s)')
    return gs
def main_fit(
        subjs=None,
        skip_fit_if_absent=False,
        bufdur=None,
        bufdur_best=0.12,
        bufdurs_sim=None,
        bufdurs_fit=None,
        loss_kind='NLL',
        plot_kind='line_sim_fit',
        fix_post=None,
        seed_sims=(0, ),
        err_kind='std',
        base=10,
):
    """
    Collect model-recovery losses (simulated buffer duration x fitted
    buffer duration) and plot them.

    :param subjs: subjects; defaults to consts.SUBJS_VD.
    :param skip_fit_if_absent: forwarded to get_fit_sim.
    :param bufdur: if given, both simulated and fitted durations become
        {bufdur_best, bufdur}.
    :param bufdur_best: reference ("best") buffer duration.
    :param bufdurs_sim: simulated durations (default bufdurs0).
    :param bufdurs_fit: fitted durations (default depends on bufdurs_sim).
    :param loss_kind: 'NLL' or 'BIC'.
    :param plot_kind: 'loss_serpar', 'imshow_ser_buf_par', 'line_sim_fit',
        'bar_sim_fit', 'bar_ser_buf_par', or 'None' (no figure saved).
    :param fix_post: forwarded to get_fit_sim.
    :param seed_sims: simulation seeds.
    :param err_kind: 'std' or 'sem' (used by 'line_sim_fit').
    :param base: logarithm base the losses are expressed in.
    :return: losses[seed, subj, sim_dur, fit_dur] in log-`base` units.
    """
    from copy import copy, deepcopy

    # if torch.cuda.is_available():
    #     torch.set_default_tensor_type(torch.cuda.FloatTensor)
    torch.set_num_threads(1)
    torch.set_default_dtype(torch.double)

    dict_res_add = {}
    parad = 'VD'

    seed_sims = np.array(seed_sims)

    if subjs is None:
        subjs = consts.SUBJS_VD

    if bufdur is not None:
        bufdurs_sim = list({bufdur_best}.union({bufdur}))
        bufdurs_fit = list({bufdur_best}.union({bufdur}))
    else:
        if bufdurs_sim is None:
            bufdurs_sim = bufdurs0
            bufdurs_fit = bufdurs0
        else:
            if bufdurs_fit is None:
                bufdurs_fit0 = [0., bufdur_best, 1.2]
                if bufdurs_sim[0] in bufdurs_fit0:
                    bufdurs_fit = copy(bufdurs_fit0)
                else:
                    bufdurs_fit = bufdurs_fit0[:2] + bufdurs_sim + [1.2]

    n_dur_sim = len(bufdurs_sim)
    n_dur_fit = len(bufdurs_fit)

    # DEF: losses[seed, subj, simSerDurPar, fitDurs]
    size_all = [len(seed_sims), len(subjs), n_dur_sim, n_dur_fit]
    losses0 = np.zeros(size_all) + np.nan
    ns = np.zeros(size_all) + np.nan
    ks0 = np.zeros(size_all) + np.nan

    for i_seed, seed_sim in enumerate(seed_sims):
        for i_subj, subj in enumerate(subjs):
            for i_fit, bufdur_fit in enumerate(bufdurs_fit):
                for i_sim, bufdur_sim in enumerate(bufdurs_sim):
                    d, dict_fit_sim, dict_subdir_sim = get_fit_sim(
                        subj,
                        seed_sim,
                        bufdur_sim,
                        bufdur_fit,
                        parad=parad,
                        skip_fit_if_absent=skip_fit_if_absent,
                        fix_post=fix_post,
                    )
                    if d is not None:
                        losses0[i_seed, i_subj, i_sim, i_fit] = \
                            d['loss_NLL_test']
                        ns[i_seed, i_subj, i_sim, i_fit] = \
                            d['loss_ndata_test']
                        ks0[i_seed, i_subj, i_sim, i_fit] = \
                            d['loss_nparam']
    ks = copy(ks0)
    if loss_kind == 'BIC':
        losses = ks * np.log(ns) + 2 * losses0
        # REF: Kass, Raftery 1995 https://doi.org/10.2307%2F2291091
        # https://en.wikipedia.org/wiki/Bayesian_information_criterion#Gaussian_special_case
        thres_strong = 10. / np.log(base)
    elif loss_kind == 'NLL':
        losses = copy(losses0)
        thres_strong = np.log(100) / np.log(base)
    else:
        raise ValueError('Unsupported loss_kind: %s' % loss_kind)

    losses = losses / np.log(base)

    if plot_kind == 'loss_serpar':
        n_row = len(subjs)
        n_col = 1
        axs = plt2.GridAxes(n_row, n_col, left=1, widths=2, bottom=0.75)
        for i_subj, subj in enumerate(subjs):
            plt.sca(axs[i_subj, 0])
            loss_dur_sim_ser_fit = losses[:, i_subj, :, 0].mean(0)
            loss_dur_sim_par_fit = losses[:, i_subj, :, -1].mean(0)
            dloss_serpar = loss_dur_sim_par_fit - loss_dur_sim_ser_fit
            plt.bar(bufdurs_sim, dloss_serpar, color='k')
            plt.axhline(0, color='k', linestyle='-', linewidth=0.5)
            if i_subj < len(subjs) - 1:
                plt2.box_off(['top', 'right', 'bottom'])
                plt.xticks([])
            else:
                plt2.box_off()
                plt.xticks(np.arange(0, 1.4, 0.2),
                           ['0\nser'] + [''] * 2 + ['0.6'] + [''] * 2 +
                           ['1.2\npar'])
                plt2.detach_axis('x', 0, 1.2)
            vdfit.patch_chance_level()
        plt.sca(axs[0, 0])
        # second piece is raw: '\m' is not a valid escape (same string value)
        plt.title('Support for Serial\n'
                  r'($\mathcal{L}_\mathrm{ser} - \mathcal{L}_\mathrm{par}$)')
        plt.sca(axs[-1, 0])
        plt.xlabel('true buffer duration (s)')
        plt2.rowtitle(consts.SUBJS_VD, axs)

    elif plot_kind == 'imshow_ser_buf_par':
        n_row = len(subjs)
        n_col = 1
        axs = plt2.GridAxes(n_row, n_col, left=1, widths=2, heights=2,
                            bottom=0.75)
        for i_subj, subj in enumerate(subjs):
            plt.sca(axs[i_subj, 0])
            dloss = losses[:, i_subj, :, :].mean(0)
            # relative to fitting the true (simulated) duration
            dloss = dloss - np.diag(dloss)[:, None]
            plt.imshow(dloss, cmap='bwr')
            cmax = np.amax(np.abs(dloss))
            plt.clim(-cmax, +cmax)

        plt.sca(axs[-1, 0])
        plt.xlabel('true buffer duration (s)')
        plt.ylabel('model buffer duration (s)')
        plt2.rowtitle(consts.SUBJS_VD, axs)

    elif plot_kind == 'line_sim_fit':
        err_funs = {'std': np.std, 'sem': np2.sem}
        if err_kind not in err_funs:
            raise ValueError('Unsupported err_kind: %s' % err_kind)
        err_fun = err_funs[err_kind]

        bufdur_bests = np.array([get_bufdur_best(subj) for subj in subjs])
        i_bests = np.array([
            list(bufdurs_fit).index(bufdur_best1)
            for bufdur_best1 in bufdur_bests
        ])
        i_subjs = np.arange(len(subjs))

        # Row 1: data simulated with each subject's best duration, fitted
        # with every duration, relative to fitting the best duration.
        loss_sim_best_fit_best = losses[:, i_subjs, i_bests, i_bests]
        loss_sim_best_fit_rest = (losses[:, i_subjs, i_bests, :] -
                                  loss_sim_best_fit_best[:, :, None])
        mean_loss_sim_best_fit_rest = np.mean(loss_sim_best_fit_rest, 0)
        err_loss_sim_best_fit_rest = err_fun(loss_sim_best_fit_rest, 0)

        # Row 2: data simulated with every duration, fitted with the best
        # duration, relative to fitting the true duration.
        n_dur = len(bufdurs_fit)
        loss_sim_rest_fit_best = np.swapaxes(
            np.stack([
                losses[:, i_subj, np.arange(n_dur), i_best] -
                losses[:, i_subj, np.arange(n_dur), np.arange(n_dur)]
                for i_subj, i_best in zip(i_subjs, i_bests)
            ]), 0, 1)
        mean_loss_sim_rest_fit_best = np.mean(loss_sim_rest_fit_best, 0)
        err_loss_sim_rest_fit_best = err_fun(loss_sim_rest_fit_best, 0)
        dict_res_add['err'] = err_kind

        n_row = 2
        n_col = len(subjs)
        axs = plt2.GridAxes(n_row, n_col,
                            left=1.15, right=0.1,
                            widths=2, heights=1.5,
                            wspace=0.35, hspace=0.5,
                            top=0.25, bottom=0.6)

        for row, (m, s, xlabel, ylabel) in enumerate([
            (mean_loss_sim_best_fit_rest, err_loss_sim_best_fit_rest,
             'model buffer capacity (s)',
             (r'$-\mathrm{log}_{%g}\mathrm{BF}$ given simulated' +
              '\nbest duration data') % base),
            (mean_loss_sim_rest_fit_best, err_loss_sim_rest_fit_best,
             'simulated data buffer capacity (s)',
             (r'$-\mathrm{log}_{%g}\mathrm{BF}$ of the' +
              '\nbest duration model') % base)
        ]):
            for i_subj, subj in enumerate(subjs):
                ax = axs[row, i_subj]
                plt.sca(ax)
                gs1 = axs.gs[row * 2 + 1, i_subj * 2 + 1]
                bax = vdfit.breakaxis(gs1)
                # x must match the last axis of m/s, i.e. bufdurs_fit
                # (it was bufdurs0, which only works when they coincide).
                bax.axs[1].errorbar(bufdurs_fit, m[i_subj, :],
                                    yerr=s[i_subj, :],
                                    color='k', marker='.',
                                    linewidth=0.75, elinewidth=0.5,
                                    markersize=3)
                m1 = copy(m[i_subj, :])
                m1[3:] = np.nan
                s1 = copy(s[i_subj, :])
                s1[3:] = np.nan
                bax.axs[0].errorbar(bufdurs_fit, m1, yerr=s1,
                                    color='k', marker='.',
                                    linewidth=0.75, elinewidth=0.5,
                                    markersize=3)

                ax1 = bax.axs[1]  # type: plt.Axes
                plt.sca(ax1)
                vdfit.patch_chance_level(level=thres_strong, signs=[-1, 1])
                plt.axhline(0, color='k', linestyle='--', linewidth=0.5)
                vdfit.beautify_ticks(ax1)
                vdfit.beautify(ax1)
                ax1.set_yticks([0, 20])
                if i_subj > 0:
                    ax1.set_yticklabels([])
                if row == 0:
                    ax1.set_xticklabels([])

                ax1 = bax.axs[0]  # type: plt.Axes
                plt.sca(ax1)
                ax1.set_yticks([40, 200])
                if i_subj > 0:
                    ax1.set_yticklabels([])

                plt.sca(ax)
                plt2.box_off('all')
                if row == 0:
                    plt.title(consts.SUBJS['VD'][i_subj])
                if i_subj == 0:
                    plt.xlabel(xlabel, labelpad=8 if row == 0 else 20)
                    plt.ylabel(ylabel, labelpad=30)
                plt2.sameaxes(bax.axs, xy='x')

    elif plot_kind == 'bar_sim_fit':
        # list() so this also works when bufdurs_fit is an ndarray
        i_best = list(bufdurs_fit).index(bufdur_best)
        loss_sim_best_fit_best = losses[0, :, i_best, i_best]
        loss_sim_best_fit_rest = np2.nan2v(
            losses[0, :, i_best, :] - loss_sim_best_fit_best[:, None])
        n_dur = len(bufdurs_fit)
        loss_sim_rest_fit_best = np2.nan2v(
            losses[0, :, np.arange(n_dur), i_best] -
            losses[0, :, np.arange(n_dur), np.arange(n_dur)]).T

        n_row = 2
        n_col = len(subjs)
        axs = plt2.GridAxes(n_row, n_col,
                            left=1, widths=3, right=0.25,
                            heights=1,
                            wspace=0.3, hspace=0.6,
                            top=0.25, bottom=0.6)

        def plot_bars(gs1, bufdurs, losses1, add_xticklabel=True):
            # Broken-axes bars of losses1[i_subj, :]; reads i_subj (the
            # current subject of the loop below) and thres_strong.
            i_break = np.amax(np.nonzero(np.array(bufdurs) < 0.3)[0])
            bax = brokenaxes(
                subplot_spec=gs1,
                xlims=((-1, i_break + 0.5),
                       (i_break + 0.5, len(bufdurs) - 0.5)),
                ylims=((-3, 20), (20, 1250 / 5)),
                height_ratios=(50 / 100, 500 / (1250 - 100)),
                hspace=0.15,
                wspace=0.075,
                d=0.005,
            )
            bax.bar(np.arange(len(bufdurs)), losses1[i_subj, :], color='k')
            ax11 = bax.axs[3]  # type: plt.Axes
            ax11.set_xticks([bufdurs.index(0.6), bufdurs.index(1.2)])
            if i_subj == 0 and add_xticklabel:
                ax11.set_xticklabels(['0.6', '1.2'])
            else:
                ax11.set_xticklabels([])
            ax00 = bax.axs[0]  # type: plt.Axes
            ax00.set_yticks([500, 1000])
            ax10 = bax.axs[2]  # type: plt.Axes
            ax10.set_yticks([0, 50])
            plt.sca(ax10)
            plt2.detach_axis('x', amin=-0.4, amax=i_break + 0.5)
            for ax_ in [ax10, ax11]:
                plt.sca(ax_)
                plt.axhline(0, linewidth=0.5, color='k', linestyle='--')
                for sign in [-1, 1]:
                    plt.axhline(sign * thres_strong, linewidth=0.5,
                                color='silver', linestyle='--')
            ax10.set_xticks([bufdurs.index(0.), bufdurs.index(0.2)])
            if i_subj == 0:
                if add_xticklabel:
                    ax10.set_xticklabels(['0', '0.2'])
                else:
                    ax10.set_xticklabels([])
            else:
                ax10.set_yticklabels([])
                ax10.set_xticklabels([])
                ax00.set_yticklabels([])
            return bax

        for i_subj, subj in enumerate(subjs):
            plot_bars(axs.gs[1, i_subj * 2 + 1], bufdurs_fit,
                      loss_sim_best_fit_rest, add_xticklabel=False)
            ax = axs[0, i_subj]
            plt.sca(ax)
            plt2.box_off('all')
            plt.title(consts.SUBJS['VD'][consts.SUBJS['VD'].index(subj)])
            if i_subj == 0:
                ax.set_ylabel(
                    'misfit to simulated\nbest duration data\n'
                    r'($\Delta$BIC)', labelpad=35)
                ax.set_xlabel('model buffer duration (s)', labelpad=8)

            plot_bars(axs.gs[3, i_subj * 2 + 1], bufdurs_sim,
                      loss_sim_rest_fit_best, add_xticklabel=True)
            ax = axs[1, i_subj]
            plt.sca(ax)
            plt2.box_off('all')
            if i_subj == 0:
                ax.set_ylabel(
                    'misfit of\nbest duration model\n'
                    r'($\Delta$BIC)', labelpad=35)
                ax.set_xlabel('simulated data buffer duration (s)',
                              labelpad=20)

    elif plot_kind == 'bar_ser_buf_par':
        n_row = len(subjs)
        n_col = 1
        axs = plt2.GridAxes(n_row * n_dur_sim, n_col,
                            left=1.25, widths=1.5, right=0.25,
                            heights=0.4,
                            hspace=[0.15] * (n_dur_sim - 1) + [0.5] +
                            [0.15] * (n_dur_sim - 1),
                            top=0.6, bottom=0.5)

        row = -1
        for i_subj, subj in enumerate(subjs):
            for i_sim in range(n_dur_sim):
                row += 1
                plt.sca(axs[row, 0])
                loss1 = losses[:, i_subj, i_sim, :].mean(0)
                dloss = loss1 - loss1[i_sim]
                x = np.arange(n_dur_fit)
                for x1, dloss1 in zip(x, dloss):
                    plt.bar(x1, dloss1, color='r' if dloss1 > 0 else 'b')
                plt.axhline(0, color='k', linewidth=0.5)
                vdfit.patch_chance_level(6.)
                plt.ylim([-100, 100])
                # one (blank) label per tick: 3 ticks need 3 labels
                plt.yticks([-100, 0, 100], [''] * 3)
                plt.ylabel('%g' % bufdurs_sim[i_sim], rotation=0,
                           va='center', ha='right')
                if i_sim == 0:
                    plt.title(subj)
                if i_sim < n_dur_sim - 1 or i_subj < len(subjs) - 1:
                    plt2.box_off(['top', 'bottom', 'right'])
                    plt.xticks([])
                else:
                    plt2.box_off(['top', 'right'])
        plt.sca(axs[-1, 0])
        plt.xticks(np.arange(n_dur_sim), ['%g' % v for v in bufdurs_sim])
        plt2.detach_axis('x', 0, n_dur_sim - 1)
        plt.xlabel('model buffer duration (s)', fontsize=10)
        c = axs[-1, 0].get_position().corners()
        plt.figtext(x=0.15, y=np.mean(c[:2, 1]), fontsize=10,
                    s='true\nbuffer\nduration\n(s)', rotation=0,
                    va='center', ha='center')
        plt.figtext(
            x=(1.25 + 1.5 / 2) / (1.25 + 1.5 + 0.25),
            y=0.98,
            s=r'$\mathrm{BIC} - \mathrm{BIC}_\mathrm{true}$',
            ha='center',
            va='top',
            fontsize=12,
        )

    if plot_kind != 'None':
        # dict_fit_sim / dict_subdir_sim come from the last get_fit_sim call
        dict_res = deepcopy(dict_fit_sim)  # noqa
        for k in ['sbj', 'prd']:
            dict_res.pop(k)
        dict_res = {**dict_res, 'los': loss_kind}
        dict_res.update(dict_res_add)
        for ext in ['.png', '.pdf']:
            file = locfile.get_file_fig(plot_kind, dict_res,
                                        subdir=dict_subdir_sim,
                                        ext=ext)  # noqa
            plt.savefig(file, dpi=300)
            print('Saved to %s' % file)
    print('--')
    return losses
def plot_coefs_dur_odif(
        n_cond_dur_ch: np.ndarray = None,
        data: Data2DVD = None,
        dif_irrs: Sequence[Union[Iterable[int], int]] = ((0, 1), (2, )),
        ev_cond_dim: np.ndarray = None,
        durs: np.ndarray = None,
        style='data',
        kw_plot=(),
        jitter0=0.,
        coefs_to_plot=(0, 1),
        add_rowtitle=True,
        dim_incl=(0, 1),
        axs=None,
        fig=None,
):
    """
    Plot logistic bias/slope against stimulus duration, per dimension and
    per difficulty group of the irrelevant dimension.

    :param n_cond_dur_ch: choice counts; read from `data` when None.
    :param data: source of ev_cond_dim, durs and counts when those are None.
    :param dif_irrs: difficulty groups of the irrelevant dimension.
    :param ev_cond_dim: [condition, dim] evidence; from data when None.
    :param durs: durations (x values); from data when None.
    :param style: plot style; 'data*' styles draw error bars.
    :param kw_plot: extra kwargs forwarded to consts.get_kw_plot.
    :param jitter0: horizontal jitter between groups, as a fraction of the
        spacing between the first two durations.
    :param coefs_to_plot: coefficient indices (0 = bias, 1 = slope), one
        column each.
    :param add_rowtitle: label rows with the included dimensions' names.
    :param dim_incl: dimensions to plot, one row each.
    :param axs: [n_dim, n_coef] axes array; created when None.
    :param fig: figure to create axes in when axs is None.
    :return: coef, se_coef, axs, hs[coef, dim, dif]
    """
    if ev_cond_dim is None:
        ev_cond_dim = data.ev_cond_dim
    if durs is None:
        durs = npy(data.durs)
    if n_cond_dur_ch is None:
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])

    coef, se_coef = coefdur.get_coefs_mesh_from_histogram(
        n_cond_dur_ch, ev_cond_dim=ev_cond_dim,
        dif_irrs=dif_irrs)[:2]  # type: (np.ndarray, np.ndarray)

    coef_names = ['bias', 'slope']
    n_coef = len(coefs_to_plot)
    n_dim = len(dim_incl)

    if axs is None:
        if fig is None:
            fig = plt.figure(figsize=[6, 1.5 * n_coef])

        n_row = n_dim
        n_col = n_coef
        gs = plt.GridSpec(nrows=n_row, ncols=n_col, figure=fig,
                          left=0.25, right=0.95, bottom=0.15, top=0.9)
        # builtin object: np.object was removed in NumPy 1.24
        axs = np.empty([n_row, n_col], dtype=object)
        for row in range(n_row):
            for col in range(n_col):
                # add to `fig` itself; plt.subplot targets the current figure
                axs[row, col] = fig.add_subplot(gs[row, col])

    n_dif = len(dif_irrs)
    hs = np.empty([len(coefs_to_plot), len(dim_incl), len(dif_irrs)],
                  dtype=object)

    # Spacing between durations; only needed (and only defined for >= 2
    # durations) when jitter is requested.
    ddur = (durs[1] - durs[0]) if jitter0 != 0 else 0.
    max_dur = np.amax(npy(durs))

    for ii_coef, i_coef in enumerate(coefs_to_plot):
        coef_name = coef_names[i_coef]
        for i_dim, dim_rel in enumerate(dim_incl):
            ax = axs[i_dim, ii_coef]  # type: plt.Axes
            plt.sca(ax)

            cmap = consts.CMAP_DIM[dim_rel](n_dif)

            for idif, dif_irr in enumerate(dif_irrs):
                y = coef[i_coef, dim_rel, idif, :]
                e = se_coef[i_coef, dim_rel, idif, :]
                jitter = ddur * jitter0 * (idif - (n_dif - 1) / 2)
                kw = consts.get_kw_plot(style, color=cmap(idif),
                                        for_err=True, **dict(kw_plot))
                if style.startswith('data'):
                    h = plt.errorbar(durs + jitter, y, yerr=e, **kw)[0]
                else:
                    h = plt.plot(durs + jitter, y, **kw)
                hs[ii_coef, i_dim, idif] = h
            plt2.box_off()

            ax.set_xlim(-0.05 * max_dur, 1.05 * max_dur)
            plt2.detach_axis('x', 0, max_dur)

            # Place decorations by grid position (row/column), not by which
            # dimension/coefficient is shown; identical for the defaults.
            if i_dim == 0:
                ax.set_title(coef_name.lower())
            if i_dim < n_dim - 1:
                ax.set_xticklabels([])
            elif ii_coef == 0:
                ax.set_xlabel('duration (s)')

    if add_rowtitle:
        plt2.rowtitle([consts.DIM_NAMES_LONG[d] for d in dim_incl], axes=axs)

    return coef, se_coef, axs, hs
def main(base=10.):
    """
    Plot model-comparison costs for the real data and the two model-recovery
    simulations (serial / parallel), save the figures, and write the mean
    and SEM of each condition to a CSV file.

    :param base: logarithm base in which costs are expressed.
    """
    ds = load_comp(os.path.join(locfile_in.pth_root, file_model_comp))
    ds['dcost'] = np.array(ds['dcost']) / np.log(base)
    vmax = 160
    m = np.empty(3)
    e = np.empty(3)

    # First call only determines the panel geometry for a 1x3 layout.
    axs = plot_bar_dloss_across_subjs(dlosses=np.array(ds['dcost']),
                                      subj_parad_bis=ds['subj_parad_bi'],
                                      vmax=vmax)
    axs = plt2.GridAxes(nrows=1, ncols=3,
                        heights=axs.heights, widths=[2],
                        left=axs.left, right=axs.right, bottom=axs.bottom)
    plot_bar_dloss_across_subjs(
        dlosses=np.array(ds['dcost']),
        subj_parad_bis=ds['subj_parad_bi'],
        vmax=vmax,
        axs=axs[:, [0]],
        base=base,
    )
    plt.title('Data')
    m[0] = np.mean(ds['dcost'])
    e[0] = np2.sem(ds['dcost'])
    print('--')

    titles = ['Simulated\nSerial', 'Simulated\nParallel']
    for i in range(2):
        ds = load_comp(
            os.path.join(locfile_in.pth_root, files_model_recovery[i]))
        ds['dcost'] = np.array(ds['dcost']) / np.log(base)
        plot_bar_dloss_across_subjs(
            dlosses=ds['dcost'],
            subj_parad_bis=ds['subj_parad_bi'],
            vmax=vmax,
            axs=axs[:, [i + 1]],
            base=base,
        )
        plt.title(titles[i])
        m[i + 1] = np.mean(ds['dcost'])
        e[i + 1] = np2.sem(ds['dcost'])
        plt.sca(axs[0, i + 1])
        plt2.box_off(['left'])
        plt.yticks([])

    for ext in ['.pdf', '.png']:
        file = locfile_out.get_file_fig('model_comp_recovery',
                                        {'easiest_only': use_easiest_only},
                                        ext=ext)
        plt.savefig(file, dpi=300)
        print('Saved to %s' % file)

    # --- Print mean +- SEM to CSV
    csv_file = locfile_out.get_file_csv('model_comp_recovery',
                                        {'easiest_only': use_easiest_only})
    d_csv = np2.dictlist2listdict({
        'data': ['original', 'simulated serial', 'simulated parallel'],
        'mean_dcost': m,
        'sem_dcost': e,
    })
    # newline='' as required by the csv module (avoids blank rows on
    # platforms that translate line endings).
    with open(csv_file, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=d_csv[0].keys())
        writer.writeheader()
        writer.writerows(d_csv)
    print('Wrote to %s' % csv_file)

    # --- Mean cost within each condition
    axs = plt2.GridAxes(
        1, 1,
        left=0.85, right=0.25,
        heights=[1.5],
        top=0.1, bottom=0.9,
    )
    plt.sca(axs[0, 0])
    plt.barh(np.arange(3), m, xerr=e, color='w', edgecolor='k')
    plt.yticks(np.arange(3),
               ['data', 'simulated\nserial', 'simulated\nparallel'])
    plt2.box_off(['top', 'right'])
    vmax = np.amax(np.abs(m) + np.abs(e))
    plt.xlim(np.array([-vmax, vmax]) * 1.1)
    plt2.detach_axis(xy='y', amin=0, amax=2)
    plt2.detach_axis(xy='x', amin=-vmax, amax=vmax)
    axvline_dcost()
    xticks_serial_vs_parallel(vmax, base)

    for ext in ['.pdf', '.png']:
        file = locfile_out.get_file_fig('model_comp_vs_recovery',
                                        {'easiest_only': use_easiest_only},
                                        ext=ext)
        plt.savefig(file, dpi=300)
        print('Saved to %s' % file)
    print('--')
def plot_rt_vs_ev(
        ev_cond,
        n_cond__rt_ch: Union[torch.Tensor, np.ndarray],
        style='pred',
        pool='mean',
        dt=consts.DT,
        correct_only=True,
        thres_n_trial=10,
        color='k',
        color_ch=('tab:red', 'tab:blue'),
        ax: plt.Axes = None,
        kw_plot=(),
) -> (Sequence[plt.Line2D], List[np.ndarray]):
    """
    Plot mean RT against evidence.

    @param ev_cond: [condition] (a [condition, frame] input is averaged over
        frames; a 3-D input is first reduced with npt.p2st(...)[0])
    @type ev_cond: torch.Tensor
    @param n_cond__rt_ch: [condition, frame, ch]
    @type n_cond__rt_ch: torch.Tensor
    @param pool: only 'mean' is implemented.
    @param correct_only: plot only correct choices (both choices at zero
        evidence; for 'data*' styles, counts at zero evidence are pooled
        across choices); otherwise one line per choice coloured by color_ch.
    @param thres_n_trial: for 'data*' styles, conditions with fewer trials
        are set to NaN.
    @param kw_plot: extra line kwargs, applied in both modes; they override
        `color` / `color_ch`.
    @return: hs, rts (list of arrays when correct_only, else a stacked
        [ch, cond] array)
    """
    if ax is None:
        ax = plt.gca()
    if ev_cond.ndim != 1:
        if ev_cond.ndim == 3:
            ev_cond = npt.p2st(ev_cond)[0]
        assert ev_cond.ndim == 2
        ev_cond = ev_cond.mean(1)
    assert n_cond__rt_ch.ndim == 3

    ev_cond = npy(ev_cond)
    n_cond__rt_ch = npy(n_cond__rt_ch)

    def plot_rt_given_cond_ch(ev_cond1, n_rt_given_cond_ch, **kw1):
        # n_rt_given_cond_ch[cond, fr]
        # p_rt_given_cond_ch[cond, fr]
        p_rt_given_cond_ch = np2.sumto1(n_rt_given_cond_ch, 1)

        nt = n_rt_given_cond_ch.shape[1]
        t = np.arange(nt) * dt

        # n_in_cond_ch[cond]
        n_in_cond_ch = n_rt_given_cond_ch.sum(1)
        if pool == 'mean':
            rt_pooled = (t[None, :] * p_rt_given_cond_ch).sum(1)
        elif pool == 'var':
            raise NotImplementedError()
        else:
            raise ValueError()
        if style.startswith('data'):
            rt_pooled[n_in_cond_ch < thres_n_trial] = np.nan
        kw = get_kw_plot(style, **kw1)
        h = ax.plot(ev_cond1, rt_pooled, **kw)
        return h, rt_pooled

    hs = []
    rtss = []
    if correct_only:
        n_cond__rt_ch1 = n_cond__rt_ch.copy()  # type: np.ndarray
        if style.startswith('data'):
            # at zero evidence, pool both choices into each channel
            cond0 = ev_cond == 0
            n_cond__rt_ch1[cond0, :, :] = np.sum(n_cond__rt_ch1[cond0, :, :],
                                                 axis=-1,
                                                 keepdims=True)

        # -- Choose the ch with correct sign (or both chs if cond == 0)
        cond_sign = np.sign(ev_cond)
        for ch in range(consts.N_CH):
            ch_sign = consts.ch_bool2sign(ch)
            cond_accu = cond_sign != -ch_sign

            # merged dict: a 'color' in kw_plot overrides instead of
            # raising a duplicate-keyword TypeError
            h, rts = plot_rt_given_cond_ch(
                ev_cond[cond_accu], n_cond__rt_ch1[cond_accu, :, ch],
                **{'color': color, **dict(kw_plot)})
            hs.append(h)
            rtss.append(rts)
        rts = rtss
    else:
        for ch in range(consts.N_CH):
            # n_rt_given_cond_ch[cond, fr]
            n_rt_given_cond_ch = n_cond__rt_ch[:, :, ch]
            # kw_plot was previously ignored in this branch
            h, rts = plot_rt_given_cond_ch(
                ev_cond, n_rt_given_cond_ch,
                **{'color': color_ch[ch], **dict(kw_plot)})
            hs.append(h)
            rtss.append(rts)
        rts = np.stack(rtss)

    y_lim = ax.get_ylim()
    x_lim = ax.get_xlim()
    plt2.box_off(ax=ax)
    plt2.detach_axis('x', ax=ax, amin=x_lim[0], amax=x_lim[1])
    plt2.detach_axis('y', ax=ax, amin=y_lim[0], amax=y_lim[1])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{E}[T^\mathrm{r} \mid c]~(\mathrm{s})$")
    return hs, rts
def plot_bar_dloss_across_subjs(
        dlosses,
        elosses=None,
        ix_datas=None,
        subj_parad_bis: Iterable[Tuple[str, str, bool]] = None,
        axs: Union[plt2.GridAxes, plt2.AxesArray] = None,
        vmax=None,
        add_scale=True,
        base=10.,
):
    """
    Horizontal bar plot of loss differences, one bar per subject/paradigm.

    Bars are reordered: eye ('RT' paradigm) subjects first, then hand
    subjects by number with unimanual/bimanual interleaved, then 'binary'.

    :param dlosses: [ix_data]
    :param elosses: [ix_data] x error bars, in the same order as dlosses
        (reordered together with them); None for no error bars.
    :param ix_datas: unused
    :param subj_parad_bis: [('subj', 'parad', is_bimanual), ...]; defaults
        to subj_parad_bis0.
    :param axs: axes to draw on (axs[0, 0] is used); created if None.
    :param vmax: half-range of the x scale; defaults to max(|dlosses|).
    :param add_scale: if True, draw axvline_dcost(), limit x to +-1.2 vmax
        (marking clipped bars with wave patches) and set the x ticks.
    :param base: passed to xticks_serial_vs_parallel().
    :return: axs
    """
    import re

    if subj_parad_bis is None:
        subj_parad_bis = subj_parad_bis0
    if not isinstance(subj_parad_bis, np.ndarray):
        # Materialize so a one-shot iterable survives zip() below and can be
        # reordered afterwards.
        subj_parad_bis = list(subj_parad_bis)
    dlosses = np.asarray(dlosses)
    if vmax is None:
        vmax = np.amax(np.abs(dlosses))

    # order: eye S1-S3, hand by ID, paired uni-bimanual
    subjs, parads, bis = zip(*subj_parad_bis)
    subjs = np.array([
        'ID0' + v[-1] if v[:2] == 'ID' and len(v) == 3 else v
        for v in subjs
    ])
    parads = np.array(parads)
    bis = np.array(bis)
    is_eye = parads == 'RT'
    is_bin = parads == 'binary'
    ix = np.arange(len(subjs))

    def subj_number(subj):
        # Trailing digits ('S2' -> 2, 'ID03' -> 3); int(subj[1:]) alone
        # fails on the two-letter 'ID' prefix.
        digits = re.search(r'\d+$', subj)
        if digits is None:
            return int(subj[1:])
        return int(digits.group())

    def filt_sort(filt):
        ind = [subj_number(subj) for subj in subjs[filt]]
        return ix[filt][np.argsort(ind)]

    ix = np.concatenate([
        filt_sort(is_eye & ~is_bin),
        np.stack([
            filt_sort(~is_eye & ~bis & ~is_bin),
            filt_sort(~is_eye & bis & ~is_bin)
        ], -1).flatten('C'),
        filt_sort(is_bin)
    ])
    parads = parads[ix]
    bis = bis[ix]
    is_eye = is_eye[ix]
    dlosses = dlosses[ix]
    if elosses is not None:
        # Keep error bars aligned with their (reordered) bars.
        elosses = np.asarray(elosses)[ix]
    if isinstance(subj_parad_bis, np.ndarray):
        subj_parad_bis = subj_parad_bis[ix]
    else:
        subj_parad_bis = [subj_parad_bis[i] for i in ix]

    n_eye = int(np.sum(is_eye))
    n_hand = int(np.sum(~is_eye))
    y = np.empty([n_eye + n_hand])
    y[is_eye] = 1.5 + np.arange(n_eye)
    # Hand bars come in uni/bimanual pairs: a 1.5 gap before each pair,
    # 1.0 within it.
    y[~is_eye] = n_eye - 1 + 1.5 + np.cumsum([1.5, 1.] * (n_hand // 2))
    y_max = np.amax(y) + 1.5

    if axs is None:
        axs = plt2.GridAxes(nrows=1, ncols=1,
                            heights=y_max * 0.2, widths=2,
                            left=1.5, right=0.25, bottom=0.85)
    ax = axs[0, 0]
    plt.sca(ax)

    m = dlosses
    e = np.zeros_like(m) if elosses is None else elosses
    for y1, m1, e1, parad1, bi1 in zip(y, m, e, parads, bis):
        plt.barh(y1, m1, xerr=e1,
                 color=colors_parad[(parad1, '%s' % bi1)],
                 edgecolor='None')

    # NOTE(review): the extent of this block was reconstructed from a
    # whitespace-collapsed source -- confirm which statements add_scale
    # was meant to guard.
    if add_scale:
        axvline_dcost()

        x_lim = [-vmax * 1.2, vmax * 1.2]
        for ix_big in range(len(y)):
            if np.abs(m[ix_big]) > vmax:
                for i_sign, sign in enumerate([1, -1]):
                    plt2.patch_wave(
                        y[ix_big], x_lim[i_sign] * 1.01,
                        ax=ax,
                        color='w',
                        wave_margin=0.15,
                        wave_amplitude=sign * 0.025,
                    )
        plt.xlim(x_lim)
        xticks_serial_vs_parallel(vmax, base)

    subj_parad_bi_str = get_subj_parad_bi_str(subj_parad_bis)
    plt.yticks(y, subj_parad_bi_str)
    plt2.detach_axis('y', y[0], y[-1])
    plt2.detach_axis('x', -vmax, vmax)
    plt.ylim([y_max - 1, 1.])
    return axs
def plot_fit_combined(
        data: Union[sim2d.Data2DRT, dict] = None,
        pModel_cond_rt_chFlat=None, model=None,
        pModel_dimRel_condDense_chFlat=None,
        # --- in place of data:
        pAll_cond_rt_chFlat=None,
        evAll_cond_dim=None,
        pTrain_cond_rt_chFlat=None,
        evTrain_cond_dim=None,
        pTest_cond_rt_chFlat=None,
        evTest_cond_dim=None,
        dt=None,
        # --- optional
        ev_dimRel_condDense_fr_dim_meanvar=None,
        dt_model=None,
        to_plot_internals=True,
        to_plot_params=True,
        to_plot_choice=True,
        # to_group_irr=False,
        group_dcond_irr=None,
        to_combine_ch_irr_cond=True,
        kw_plot_pred=(),
        kw_plot_pred_ch=(),
        kw_plot_data=(),
        axs=None,
):
    """
    Plot model predictions against train/test data, one column per dimension.

    Per dimension: an optional choice row (to_plot_choice), an RT row, and
    optionally model internals; model params go to axs[0, -1].

    :param data: provides get_data_by_cond('all' / 'train_valid' / 'test')
        and .dt. If None, pass the p*/ev* arrays and dt instead; missing
        train/test arrays default to the 'all' ones.
    :param pModel_cond_rt_chFlat: not read; the prediction for each
        dimension comes from pModel_dimRel_condDense_chFlat[dim_rel].
    :param model: needs .dt; .tnds/.dtb when to_plot_internals and
        .plot_params() when to_plot_params. Required if either is True.
    :param pModel_dimRel_condDense_chFlat: indexed by dim_rel.
    :param ev_dimRel_condDense_fr_dim_meanvar: indexed by dim_rel, then
        [:, 0, :, 0] (5-D) or [:, 0, :] (4-D) for the model's condition
        evidence; if None, evAll_cond_dim is used.
    :param dt_model: defaults to model.dt (or dt if model is None).
    :param to_plot_internals:
    :param to_plot_params:
    :param to_plot_choice:
    :param group_dcond_irr: passed to the sim2d plotting functions.
    :param to_combine_ch_irr_cond: combine_irr_cond() the model choice
        prediction and draw it in black with kw_plot_pred_ch.
    :param kw_plot_pred:
    :param kw_plot_pred_ch:
    :param kw_plot_data:
    :param axs: created if None.
    :return: axs, rts (model RTs per dim_rel), hs (dict; hs['rt'] holds the
        model RT handles per dim_rel)
    """
    if data is None:
        if pTrain_cond_rt_chFlat is None:
            pTrain_cond_rt_chFlat = pAll_cond_rt_chFlat
        if evTrain_cond_dim is None:
            evTrain_cond_dim = evAll_cond_dim
        if pTest_cond_rt_chFlat is None:
            pTest_cond_rt_chFlat = pAll_cond_rt_chFlat
        if evTest_cond_dim is None:
            evTest_cond_dim = evAll_cond_dim
    else:
        _, pAll_cond_rt_chFlat, _, _, evAll_cond_dim = \
            data.get_data_by_cond('all')
        _, pTrain_cond_rt_chFlat, _, _, evTrain_cond_dim = \
            data.get_data_by_cond('train_valid', mode_train='easiest')
        _, pTest_cond_rt_chFlat, _, _, evTest_cond_dim = \
            data.get_data_by_cond('test', mode_train='easiest')
        dt = data.dt

    hs = {}
    if model is None:
        assert not to_plot_internals
        assert not to_plot_params

    if dt_model is None:
        dt_model = dt if model is None else model.dt

    if axs is None:
        if to_plot_params or to_plot_internals:
            axs = plt2.GridAxes(3, 3)
        elif to_plot_choice:
            axs = plt2.GridAxes(2, 2)
        else:
            axs = plt2.GridAxes(1, 2)
        # TODO: beautify ratios

    rts = []
    hs['rt'] = []
    for dim_rel in range(consts.N_DIM):
        # --- data_pred may not have all conditions, so concatenate the rest
        # of the conditions so that the color scale is correct. Then also
        # concatenate pTest_cond_rt_chFlat with zeros so that nothing is
        # plotted in the concatenated conditions.
        evTest_cond_dim1 = np.concatenate([
            evTest_cond_dim, evAll_cond_dim
        ], axis=0)
        pTest_cond_rt_chFlat1 = np.concatenate([
            pTest_cond_rt_chFlat, np.zeros_like(pAll_cond_rt_chFlat)
        ], axis=0)

        if ev_dimRel_condDense_fr_dim_meanvar is None:
            evModel_cond_dim = evAll_cond_dim
        elif ev_dimRel_condDense_fr_dim_meanvar.ndim == 5:
            evModel_cond_dim = npy(
                ev_dimRel_condDense_fr_dim_meanvar[dim_rel][:, 0, :, 0])
        else:
            assert ev_dimRel_condDense_fr_dim_meanvar.ndim == 4
            evModel_cond_dim = npy(
                ev_dimRel_condDense_fr_dim_meanvar[dim_rel][:, 0, :])
        pModelDim_cond_rt_chFlat = npy(
            pModel_dimRel_condDense_chFlat[dim_rel])

        if to_plot_choice:
            # --- Plot choice
            ax = axs[0, dim_rel]
            plt.sca(ax)

            if to_combine_ch_irr_cond:
                ev_cond_model1, p_rt_ch_model1 = combine_irr_cond(
                    dim_rel, evModel_cond_dim, pModelDim_cond_rt_chFlat
                )
                sim2d.plot_p_ch_vs_ev(
                    ev_cond_model1, p_rt_ch_model1,
                    dim_rel=dim_rel, style='pred',
                    group_dcond_irr=None,
                    kw_plot=kw_plot_pred_ch,
                    cmap=lambda n: lambda v: [0., 0., 0.],
                )
            else:
                sim2d.plot_p_ch_vs_ev(
                    evModel_cond_dim, pModelDim_cond_rt_chFlat,
                    dim_rel=dim_rel, style='pred',
                    group_dcond_irr=group_dcond_irr,
                    kw_plot=kw_plot_pred,
                    cmap=cmaps[dim_rel]
                )

            # Separate name: rebinding `hs` here clobbered the returned dict
            # and made hs['rt'].append() below fail.
            hs_data, conds_irr = sim2d.plot_p_ch_vs_ev(
                evTest_cond_dim1, pTest_cond_rt_chFlat1,
                dim_rel=dim_rel, style='data_pred',
                group_dcond_irr=group_dcond_irr,
                cmap=cmaps[dim_rel],
                kw_plot=kw_plot_data,
            )
            hs_legend = [h[0] for h in hs_data]
            odim = 1 - dim_rel
            odim_name = consts.DIM_NAMES_LONG[odim]
            legend_odim(conds_irr, hs_legend, odim_name)

            sim2d.plot_p_ch_vs_ev(
                evTrain_cond_dim, pTrain_cond_rt_chFlat,
                dim_rel=dim_rel, style='data_fit',
                group_dcond_irr=group_dcond_irr,
                cmap=cmaps[dim_rel],
                kw_plot=kw_plot_data
            )
            plt2.detach_axis('x', np.amin(evTrain_cond_dim[:, dim_rel]),
                             np.amax(evTrain_cond_dim[:, dim_rel]))
            ax.set_xlabel('')
            ax.set_xticklabels([])
            if dim_rel != 0:
                plt2.box_off(['left'])
                plt.yticks([])
            ax.set_ylabel('P(%s choice)' % consts.CH_NAMES[dim_rel][1])

        # --- Plot RT
        ax = axs[int(to_plot_choice) + 0, dim_rel]
        plt.sca(ax)
        hs_rt, rts1 = sim2d.plot_rt_vs_ev(
            evModel_cond_dim,
            pModelDim_cond_rt_chFlat,
            dim_rel=dim_rel, style='pred',
            group_dcond_irr=group_dcond_irr,
            dt=dt_model,
            kw_plot=kw_plot_pred,
            cmap=cmaps[dim_rel]
        )
        hs['rt'].append(hs_rt)
        rts.append(rts1)

        sim2d.plot_rt_vs_ev(
            evTest_cond_dim1, pTest_cond_rt_chFlat1,
            dim_rel=dim_rel, style='data_pred',
            group_dcond_irr=group_dcond_irr,
            dt=dt,
            cmap=cmaps[dim_rel],
            kw_plot=kw_plot_data
        )
        sim2d.plot_rt_vs_ev(
            evTrain_cond_dim, pTrain_cond_rt_chFlat,
            dim_rel=dim_rel, style='data_fit',
            group_dcond_irr=group_dcond_irr,
            dt=dt,
            cmap=cmaps[dim_rel],
            kw_plot=kw_plot_data
        )
        plt2.detach_axis('x', np.amin(evTrain_cond_dim[:, dim_rel]),
                         np.amax(evTrain_cond_dim[:, dim_rel]))
        if dim_rel != 0:
            ax.set_ylabel('')
            plt2.box_off(['left'])
            plt.yticks([])
        ax.set_xlabel(consts.DIM_NAMES_LONG[dim_rel].lower() + ' strength')
        if dim_rel == 0:
            ax.set_ylabel('RT (s)')

        if to_plot_internals:
            # NOTE(review): rows 3..5 are beyond the 3-row GridAxes created
            # above when axs is None -- confirm the intended layout.
            for ch1 in range(consts.N_CH):
                ch0 = dim_rel
                ax = axs[3 + ch1, dim_rel]
                plt.sca(ax)
                ch_flat = consts.ch_by_dim2ch_flat(np.array([ch0, ch1]))
                model.tnds[ch_flat].plot_p_tnd()
                ax.set_xlabel('')
                ax.set_xticklabels([])
                ax.set_yticks([0, 1])
                if ch0 > 0:
                    ax.set_yticklabels([])
                ax.set_ylabel(r"$\mathrm{P}(T^\mathrm{n} \mid"
                              r" \mathbf{z}=[%d,%d])$" % (ch0, ch1))

            ax = axs[5, dim_rel]
            plt.sca(ax)
            if hasattr(model.dtb, 'dtb1ds'):
                model.dtb.dtb1ds[dim_rel].plot_bound(color='k')

    plt2.sameaxes(axs[-1, :consts.N_DIM], xy='y')

    if to_plot_params:
        ax = axs[0, -1]
        plt.sca(ax)
        model.plot_params()

    return axs, rts, hs
def plot_rt_distrib_pred_data(
        p_pred_cond_rt_ch, n_cond_rt_ch, ev_cond_dim, dt_model,
        dt_data=None,
        smooth_sigma_sec=0.1,
        to_plot_scale=False,
        to_cumsum=False,
        to_normalize_max=True,
        xlim=None,
        colors=('magenta', 'cyan'),
        kw_plot_pred=(),
        kw_plot_data=(),
        to_skip_zero_trials=False,
        labels=None,
        **kwargs
):
    """
    Overlay predicted RT distributions of one or more models on the data.

    :param n_cond_rt_ch: [cond, rt, ch] = n_tr(cond, rt, ch)
    :param p_pred_cond_rt_ch: [model, cond, rt, ch] = P(rt, ch | cond, model);
        renormalized over (rt, ch) and scaled by the number of data trials
        in each condition before plotting.
    :param ev_cond_dim: passed to sim2d.plot_rt_distrib.
    :param dt_model: RT frame duration of the predictions.
    :param dt_data: RT frame duration of the data; defaults to dt_model.
    :param smooth_sigma_sec:
    :param to_plot_scale: draw a bar between the first two mean RTs on
        axs[-1, -1], labeled with their difference in ms.
    :param to_cumsum:
    :param to_normalize_max:
    :param xlim: x range to show; defaults to [0.5, 4.5].
    :param colors: one color per model.
    :param labels: one per model, then one for the data; default ''.
    :param kwargs: passed to sim2d.plot_rt_distrib.
    :return: axs, hss (handles per model, then for the data)
    """
    axs = None
    ps0 = []
    hss = []

    p_pred_cond_rt_ch = p_pred_cond_rt_ch / np.sum(
        p_pred_cond_rt_ch, (-1, -2), keepdims=True)
    # Expected counts: P(rt, ch | cond) times the data's trials per cond.
    n_preds1 = p_pred_cond_rt_ch * np.sum(
        n_cond_rt_ch, (-1, -2))[None, :, None, None]
    nt = p_pred_cond_rt_ch.shape[-2]

    if dt_data is None:
        dt_data = dt_model
    if labels is None:
        labels = [''] * (len(n_preds1) + 1)

    for i_pred, n_pred in enumerate(n_preds1):
        color = colors[i_pred]
        axs, p0, _, hs = sim2d.plot_rt_distrib(
            n_pred, ev_cond_dim,
            dt=dt_model,
            axs=axs,
            alpha=1.,
            smooth_sigma_sec=smooth_sigma_sec,
            to_skip_zero_trials=to_skip_zero_trials,
            colors=color,
            alpha_face=0,
            to_normalize_max=to_normalize_max,
            to_cumsum=to_cumsum,
            to_use_sameaxes=False,
            kw_plot={
                'linewidth': 1.5,
                **dict(kw_plot_pred),
            },
            label=labels[i_pred],
            **kwargs,
        )[:4]
        ps0.append(p0)
        hss.append(hs)

    # Same [:4] slice as the prediction calls above.
    axs, p0, _, hs = sim2d.plot_rt_distrib(
        n_cond_rt_ch, ev_cond_dim,
        dt=dt_data,
        axs=axs,
        smooth_sigma_sec=smooth_sigma_sec,
        colors='k',
        alpha_face=0.,
        to_normalize_max=to_normalize_max,
        # normalize across preds and data instead
        to_cumsum=to_cumsum,
        to_skip_zero_trials=to_skip_zero_trials,
        kw_plot={
            'linewidth': 0.5,
            **dict(kw_plot_data),
        },
        label=labels[-1],
        **kwargs,
    )[:4]
    ps0.append(p0)
    hss.append(hs)
    ps0 = np.stack(ps0)

    if xlim is None:
        # The default is the same with or without to_cumsum.
        xlim = [0.5, 4.5]
    for ax in axs.flatten():
        plt2.detach_axis('x', *xlim, ax=ax)
        ax.set_xlim(xlim[0] - 0.1, xlim[1] + 0.1)
    axs[-1, 0].set_xticks(xlim)
    axs[-1, 0].set_xticklabels(['%g' % v for v in xlim])

    from lib.pylabyk import numpytorch as npt
    t = torch.arange(nt) * dt_model
    mean_rts = []
    for p1 in ps0:
        p11 = npt.sumto1(torch.tensor(p1).sum([-1, -2])[0, 0, :])
        mean_rts.append(npy((t * p11).sum()))
    print('mean_rts:')
    print(mean_rts)
    print(mean_rts[1] - mean_rts[0])

    if to_plot_scale:
        # Draw on the same axes as the bar (plt.text used the current axes).
        ax_scale = axs[-1, -1]
        y = 0.8
        ax_scale.plot(mean_rts[:2], y + np.zeros(2), 'k-', linewidth=0.5)
        x = np.mean(mean_rts[:2])
        ax_scale.text(
            x, y + 0.1,
            '%1.0f ms' % (np.abs(mean_rts[1] - mean_rts[0]) * 1e3),
            ha='center', va='bottom')
    return axs, hss
def plot_coefs_dur_irrixn(
        n_cond_dur_ch: np.ndarray = None,
        data: Data2DVD = None,
        ev_cond_dim: np.ndarray = None,
        durs: np.ndarray = None,
        style='data',
        kw_plot=(),
        jitter0=0.,
        coefs_to_plot=(2, ),
        axs=None,
        fig=None,
):
    """
    Plot irrelevant-dimension interaction coefficients against duration.

    Coefficients come from coefdur.get_coefs_irr_ixn_from_histogram and are
    indexed [coef, dim_rel, duration]. One row per dimension
    (consts.N_DIM), one column per entry of coefs_to_plot.

    :param n_cond_dur_ch: defaults to data.get_data_by_cond('all')[1].
    :param data: source of the defaults for n_cond_dur_ch, ev_cond_dim, durs.
    :param ev_cond_dim: defaults to data.ev_cond_dim.
    :param durs: x values; defaults to data.durs.
    :param style: styles starting with 'data' draw error bars (se_coef).
    :param kw_plot: merged over {'color': 'k', 'for_err': True}.
    :param jitter0: horizontal offset added to durs (default 0).
    :param coefs_to_plot: indices into
        ['bias', 'slope', 'rel x abs(irr)', 'rel x irr', 'abs(irr)', 'irr'].
    :param axs: [N_DIM, len(coefs_to_plot)] axes; created if None.
    :param fig: figure for new axes; created if None.
    :return: coef, se_coef, axs
    """
    if ev_cond_dim is None:
        ev_cond_dim = data.ev_cond_dim
    if durs is None:
        durs = npy(data.durs)
    if n_cond_dur_ch is None:
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])

    coef, se_coef = coefdur.get_coefs_irr_ixn_from_histogram(
        n_cond_dur_ch, ev_cond_dim=ev_cond_dim
    )[:2]  # type: (np.ndarray, np.ndarray)

    coef_names = [
        'bias',
        'slope',
        'rel x abs(irr)',
        'rel x irr',
        'abs(irr)',
        'irr'
    ]
    n_coef = len(coefs_to_plot)

    if axs is None:
        if fig is None:
            fig = plt.figure(figsize=[6, 1.5 * n_coef])

        n_row = consts.N_DIM
        n_col = n_coef
        gs = plt.GridSpec(nrows=n_row, ncols=n_col, figure=fig,
                          left=0.25, right=0.95, bottom=0.15, top=0.9)
        # np.object was removed in NumPy 1.24; builtin object is equivalent.
        axs = np.empty([n_row, n_col], dtype=object)
        for row in range(n_row):
            for col in range(n_col):
                axs[row, col] = plt.subplot(gs[row, col])

    # Loop-invariant plot settings.
    kw = consts.get_kw_plot(
        style, **{'color': 'k', 'for_err': True, **dict(kw_plot)})
    max_dur = np.amax(npy(durs))
    x = durs + jitter0

    for ii_coef, i_coef in enumerate(coefs_to_plot):
        coef_name = coef_names[i_coef]
        for dim_rel in range(consts.N_DIM):
            ax = axs[dim_rel, ii_coef]  # type: plt.Axes
            plt.sca(ax)

            y = coef[i_coef, dim_rel, :]
            e = se_coef[i_coef, dim_rel, :]
            if style.startswith('data'):
                plt.errorbar(x, y, yerr=e, **kw)
            else:
                plt.plot(x, y, **kw)

            plt2.box_off()
            ax.set_xlim(-0.05 * max_dur, 1.05 * max_dur)
            plt2.detach_axis('x', 0, max_dur)

            if dim_rel == 0:
                ax.set_xticklabels([])
                ax.set_title(coef_name.lower())
            elif ii_coef == 0:
                # First column (not coefficient #0) gets the x label.
                ax.set_xlabel('duration (s)')

    plt2.rowtitle(consts.DIM_NAMES_LONG, axes=axs)

    return coef, se_coef, axs