def beautify(ax, xmax=1.2):
    """Make ``ax`` current and style its x-axis for a [0, xmax] range.

    :param ax: axes to style; it becomes the current axes (``plt.sca``),
        because the ``plt2`` helpers below act on the current axes.
    :param xmax: right end of the detached x spine (default 1.2, the
        previously hard-coded value). The x-limit is set 5% beyond it.
    """
    plt.sca(ax)
    # Leave a 5% margin to the right of the detached spine.
    plt.xlim(right=xmax * 1.05)
    plt2.detach_axis('x', 0, xmax)
    plt2.box_off()
def axvline_dcost(BF=100., base=10., style='patch'):
    """Mark the +/- log_base(BF) band on the current axes' x-axis.

    :param BF: Bayes-factor threshold; the band half-width is
        ``log(BF) / log(base)``.
    :param base: logarithm base the x-axis values are expressed in.
    :param style: 'line' draws dashed vertical lines at 0 and +/- threshold;
        'patch' draws a translucent grey rectangle between -threshold and
        +threshold.
    :raises ValueError: if ``style`` is not 'line' or 'patch'.
    """
    thres = np.log(BF) / np.log(base)
    if style == 'line':
        plt.axvline(0, color='k', linewidth=0.5, linestyle='--', zorder=1)
        for sign in (-1, 1):
            plt.axvline(sign * thres,
                        color='silver',
                        linewidth=0.5,
                        linestyle='--',
                        zorder=1)
    elif style == 'patch':
        import matplotlib.patches as patches
        ax = plt.gca()
        # Rectangle spans y in [-100, 100] in data units.
        # NOTE(review): add_patch also updates the data limits, so this may
        # widen autoscaled y-limits if they are not fixed elsewhere.
        ax.add_patch(
            patches.Rectangle(
                (-thres, -100),
                thres * 2,
                200,
                edgecolor='None',
                facecolor=[0., 0., 0., 0.4],
                fill=True,
            ))
    else:
        raise ValueError(
            "style must be 'line' or 'patch', got %r" % (style, ))
    plt2.box_off()
def plot_p_ch_vs_ev(ev_cond, p_ch, style='pred', ax: plt.Axes = None,
                    **kwargs) -> plt.Line2D:
    """
    Plot the probability of the second choice against evidence.

    @param ev_cond: [condition] or [condition, frame] (a 3-D input is first
        reduced with ``npt.p2st(...)[0]``); frames are averaged.
    @type ev_cond: torch.Tensor
    @param p_ch: [condition, ch] or [condition, rt_frame, ch]; the RT axis
        is summed out and each row normalised to sum to 1.
    @type p_ch: torch.Tensor
    @param style: passed to ``get_kw_plot`` with the remaining kwargs.
    @param ax: target axes (defaults to the current axes).
    @return: whatever ``ax.plot`` returns.
    """
    ax = plt.gca() if ax is None else ax

    # Reduce evidence to one value per condition.
    if ev_cond.ndim != 1:
        if ev_cond.ndim == 3:
            ev_cond = npt.p2st(ev_cond)[0]
        assert ev_cond.ndim == 2
        ev_cond = ev_cond.mean(1)

    # Drop the RT-frame axis if present.
    if p_ch.ndim != 2:
        assert p_ch.ndim == 3
        p_ch = p_ch.sum(1)

    plot_kw = get_kw_plot(style, **kwargs)
    p_second_ch = npt.p2st(npt.sumto1(p_ch, -1))[1]
    line = ax.plot(*npys(ev_cond, p_second_ch), **plot_kw)

    plt2.box_off(ax=ax)
    x_left, x_right = ax.get_xlim()
    plt2.detach_axis('x', amin=x_left, amax=x_right, ax=ax)
    plt2.detach_axis('y', amin=0, amax=1, ax=ax)
    ax.set_yticks([0, 0.5, 1])
    ax.set_yticklabels(['0', '', '1'])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{P}(z=1 \mid c)$")
    return line
def plot_params(self,
                named_bounded_params: Sequence[Tuple[
                    str, BoundedParameter]] = None,
                exclude: Iterable[str] = (),
                cmap='coolwarm',
                ax: plt.Axes = None) -> mpl.container.BarContainer:
    """Draw one horizontal bar per bounded parameter on ``ax``.

    Bar length is the value rescaled to [0, 1] between its lower and upper
    bound. Bar colour maps the gradient through ``cmap`` symmetrically, so
    that a zero gradient lands in the middle of the colormap. Parameters
    with ``requires_grad`` False are drawn light grey and labelled
    "(fixed)". Each row is annotated with lb (left), ub (right), and the
    value plus gradient order of magnitude (centre).

    :param named_bounded_params: forwarded to
        ``self.get_named_bounded_params``.
    :param exclude: parameter names to omit (forwarded).
    :param cmap: colormap name for the gradient colours.
    :param ax: axes to draw on; defaults to the current axes. All text,
        bars and axis decorations go to this axes.
    :return: the BarContainer from ``ax.barh``.
    """
    if ax is None:
        ax = plt.gca()

    names, v, grad, lb, ub, requires_grad = self.get_named_bounded_params(
        named_bounded_params, exclude=exclude)

    # Map gradients from [-max|g|, +max|g|] to [0, 1] for the colormap.
    max_grad = np.amax(np.abs(grad))
    if max_grad == 0:
        max_grad = 1.
    v01 = (v - lb) / (ub - lb)
    grad01 = (grad + max_grad) / (max_grad * 2)
    n = len(v)
    grad01[~requires_grad] = np.nan

    for i, (lb1, v1, ub1, g1, r1) in enumerate(
            zip(lb, v, ub, grad, requires_grad)):
        color = 'k' if r1 else 'gray'
        ax.text(0, i, '%1.2g' % lb1, ha='left', va='center', color=color)
        ax.text(1, i, '%1.2g' % ub1, ha='right', va='center', color=color)
        grad_str = (('(e%1.0f)' % np.log10(np.abs(g1)))
                    if r1 else '(fixed)')
        ax.text(0.5, i, '%1.2g %s' % (v1, grad_str),
                ha='center', va='center', color=color)

    lut = 256
    colors = plt.get_cmap(cmap, lut)(grad01)
    # Fixed parameters: light grey fill (alpha channel untouched).
    colors[~requires_grad, :3] = 0.95
    h = ax.barh(np.arange(n), v01, left=0, color=colors)

    ax.set_xlim(-0.025, 1)
    ax.set_xticks([])
    ax.set_yticks(np.arange(n))
    ax.set_yticklabels([name.replace('_', '-') for name in names])
    # Inverted y-limits so the first parameter is at the top.
    ax.set_ylim(n - 0.5, -0.5)
    plt2.box_off(['top', 'right', 'bottom'], ax=ax)
    plt2.detach_axis('x', amin=0, amax=1, ax=ax)
    plt2.detach_axis('y', amin=0, amax=n - 1, ax=ax)
    return h
def beautify(row, col, axs):
    """Make ``axs[row, col]`` current and apply the standard panel styling.

    Tick labels are requested only for the bottom-left panel. Panels right
    of the first column have their y-label cleared. ``n_row`` and
    ``beautify_ticks`` are looked up from the enclosing/module scope, not
    passed in.
    """
    panel = axs[row, col]
    plt.sca(panel)
    plt.xlim(xmax=1.2 * 1.05)
    plt2.detach_axis('x', 0, 1.2)
    plt2.box_off()
    is_bottom_left = (row == n_row - 1) and (col == 0)
    beautify_ticks(panel, add_ticklabel=is_bottom_left)
    if col <= 0:
        return
    plt.ylabel('')
def plot_bound(self,
               t_all: Sequence[float] = None,
               ax: plt.Axes = None,
               **kwargs) -> plt.Line2D:
    """Plot ``self.get_bound(t_all)`` against time on ``ax``.

    :param t_all: time points; defaults to ``self.t_all``.
    :param ax: axes to draw on; defaults to the current axes. All axis
        decorations are applied to this axes (not the current one).
    :param kwargs: plot kwargs; ``color='k'`` and ``linestyle='-'`` are
        filled in by ``argsutil.kwdefault`` when absent.
    :return: whatever ``ax.plot`` returns.
    """
    if ax is None:
        ax = plt.gca()
    if t_all is None:
        t_all = self.t_all
    kwargs = argsutil.kwdefault(kwargs, color='k', linestyle='-')
    h = ax.plot(*npys(t_all, self.get_bound(t_all)), **kwargs)
    ax.set_xlabel('time (s)')
    ax.set_ylabel(r"$b(t)$")
    # Leave a small margin (5% of the current top) below 0 so the y spine,
    # detached to start at 0, is visibly separated from the x-axis.
    y_top = ax.get_ylim()[1]
    ax.set_ylim(bottom=-y_top * 0.05)
    plt2.detach_axis('y', amin=0, ax=ax)
    plt2.detach_axis('x', amin=0, ax=ax)
    plt2.box_off(ax=ax)
    return h
def plot_p_tnd(self,
               t_all: Sequence[float] = None,
               ax: plt.Axes = None,
               **kwargs) -> Union[plt.Line2D, Sequence[plt.Line2D]]:
    """Plot ``self.get_p_tnd(t_all)`` against time on ``ax``.

    :param t_all: time points; defaults to ``self.t_all``.
    :param ax: axes to draw on; defaults to the current axes.
    :param kwargs: plot kwargs, with ``color='k'`` / ``linestyle='-'``
        filled in when absent.
    :return: whatever ``ax.plot`` returns.
    """
    t_all = self.t_all if t_all is None else t_all
    ax = plt.gca() if ax is None else ax
    density = self.get_p_tnd(t_all)
    plot_kw = argsutil.kwdefault(kwargs, color='k', linestyle='-')
    lines = ax.plot(*npys(t_all, density), **plot_kw)
    ax.set_xlabel('time (s)')
    plt2.box_off(ax=ax)
    # y spine spans [0, 1]; x spine starts at 0.
    for which, bounds in (('y', {'amin': 0, 'amax': 1}), ('x', {'amin': 0})):
        plt2.detach_axis(which, ax=ax, **bounds)
    return lines
def plot_p_ch_vs_ev(
        ev_cond: Union[torch.Tensor, np.ndarray],
        n_ch: Union[torch.Tensor, np.ndarray],
        style='pred',
        ax: plt.Axes = None,
        dim_rel=0,
        group_dcond_irr: Iterable[Iterable[int]] = None,
        cmap: Union[str, Callable] = 'cool',
        kw_plot=(),
) -> Tuple[Sequence[Sequence[plt.Line2D]], np.ndarray]:
    """
    Plot P(choice on dim_rel) vs. evidence, one curve per |irrelevant cond|.

    @param ev_cond: [condition, dim] or [condition, frame, dim, (mean, var)];
        the 4-D form is averaged over frames and reduced with
        ``npt.p2st(...)[0]``.
    @type ev_cond: torch.Tensor
    @param n_ch: [condition, ch] or [condition, rt_frame, ch]; the RT axis
        is summed out.
    @type n_ch: torch.Tensor
    @param dim_rel: the dimension plotted on the x-axis.
    @param group_dcond_irr: optional grouping of irrelevant-dimension
        condition indices, applied via ``group_conds``.
    @param cmap: colormap name, or a callable taking the number of curves
        and returning a colormap.
    @param kw_plot: extra plot kwargs (any mapping or pair iterable).
    @return: (hs, conds_irr) where hs[cond_irr] is the list returned by
        ``ax.plot`` (so hs[cond_irr][0] is the Line2D), and conds_irr the
        unique |irrelevant evidence| values (or groups).
    """
    if ax is None:
        ax = plt.gca()
    if ev_cond.ndim != 2:
        assert ev_cond.ndim == 4
        ev_cond = npt.p2st(ev_cond.mean(1))[0]
    if n_ch.ndim != 2:
        assert n_ch.ndim == 3
        n_ch = n_ch.sum(1)
    ev_cond = npy(ev_cond)
    n_ch = npy(n_ch)

    # Flatten (condition, flat choice) pairs; ch_rel is the choice on
    # dim_rel for each flat choice, tiled over conditions.
    n_cond_all = n_ch.shape[0]
    ch_rel = np.repeat(
        np.array(consts.CHS[dim_rel])[None, :], n_cond_all, 0)
    n_ch = n_ch.reshape([-1])
    ch_rel = ch_rel.reshape([-1])

    dim_irr = consts.get_odim(dim_rel)
    conds_rel, dcond_rel = np.unique(ev_cond[:, dim_rel],
                                     return_inverse=True)
    conds_irr, dcond_irr = np.unique(np.abs(ev_cond[:, dim_irr]),
                                     return_inverse=True)
    if group_dcond_irr is not None:
        conds_irr, dcond_irr = group_conds(conds_irr, dcond_irr,
                                           group_dcond_irr)
    n_conds = [len(conds_rel), len(conds_irr)]

    # n_ch_rel[ch_rel, dcond_irr, dcond_rel]
    n_ch_rel = npg.aggregate(
        np.stack([
            ch_rel,
            np.repeat(dcond_irr[:, None], consts.N_CH_FLAT, 1).flatten(),
            np.repeat(dcond_rel[:, None], consts.N_CH_FLAT, 1).flatten(),
        ]), n_ch, 'sum', [consts.N_CH, n_conds[1], n_conds[0]])
    p_ch_rel = n_ch_rel[1] / n_ch_rel.sum(0)

    hs = []
    for dcond_irr1, p_ch1 in enumerate(p_ch_rel):
        if isinstance(cmap, str):
            color = plt.get_cmap(cmap, n_conds[1])(dcond_irr1)
        else:
            color = cmap(n_conds[1])(dcond_irr1)
        kw1 = get_kw_plot(style, color=color, **dict(kw_plot))
        hs.append(ax.plot(conds_rel, p_ch1, **kw1))

    plt2.box_off(ax=ax)
    x_lim = ax.get_xlim()
    plt2.detach_axis('x', amin=x_lim[0], amax=x_lim[1], ax=ax)
    plt2.detach_axis('y', amin=0, amax=1, ax=ax)
    ax.set_yticks([0, 0.5, 1])
    ax.set_yticklabels(['0', '', '1'])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{P}(z=1 \mid c)$")
    return hs, conds_irr
def plot_coef_by_dur_vs_odif(coef,
                             se_coef,
                             normalize_bias=True,
                             coef_name='slope',
                             savefig=True,
                             difs=((2, 1), (0, )),
                             t_RDK_durs=np.arange(1, 11) * 0.12,
                             fig=None,
                             horizontal_panels=False):
    """Plot one logistic coefficient against stimulus duration.

    One panel per dimension, one curve per group of other-dimension
    difficulties.

    :param coef: indexed ``coef[i_coef][dim][odif][dur]`` with i_coef
        0 = bias, 1 = slope. The caller's object is not modified.
    :param se_coef: standard errors, same layout as ``coef``. The caller's
        object is not modified.
    :param normalize_bias: if True, bias is expressed as -bias / slope
        (coherence units), with SE |se_bias / slope| (slope treated as
        exact).
    :param coef_name: 'bias' or 'slope'.
    :param savefig: unused; kept for backward compatibility.
    :param difs: only ``len(difs)`` (1, 2 or 3) is used, to choose the
        curve labels.
    :param t_RDK_durs: x-values (stimulus durations, s).
    :param fig: figure to draw in; a new one is created when None.
    :param horizontal_panels: place the panels side by side instead of
        stacked.
    :return: the GridSpec holding the panels.
    :raises ValueError: if ``coef_name`` is unknown or ``len(difs)`` is
        not 1, 2 or 3.
    """
    if fig is None:
        fig = plt.figure(figsize=(7, 2) if horizontal_panels else (4, 3))
    if horizontal_panels:
        gs = plt.GridSpec(figure=fig, nrows=1, ncols=2, hspace=0.3,
                          left=0.11, right=0.98, top=0.9, bottom=0.15)
    else:
        gs = mpl.gridspec.GridSpec(figure=fig, nrows=2, ncols=1,
                                   left=0.22, right=0.98, top=0.9,
                                   bottom=0.15)

    coef_names = ['bias', 'slope']
    i_coef = coef_names.index(coef_name)
    units = ['(logit)', '(logit/coh)']

    # Shallow copies: replacing element 0 below must not write through to
    # the caller's arrays.
    coef = list(coef)
    se_coef = list(se_coef)
    if normalize_bias:
        coef[0] = -coef[0] / coef[1]
        # A standard error is non-negative; matplotlib's errorbar rejects
        # negative yerr.
        se_coef[0] = np.abs(se_coef[0] / coef[1])
        units[0] = '(coh)'
    unit = units[i_coef]
    y = coef[i_coef]
    se = se_coef[i_coef]

    if len(difs) == 1:
        labels = ['all']
    elif len(difs) == 2:
        labels = ['easy', 'hard']
    elif len(difs) == 3:
        labels = ['easy', 'medium', 'hard']
    else:
        raise ValueError('len(difs) must be 1, 2 or 3, got %d' % len(difs))

    for dim, (slope1, se_slope1) in enumerate(zip(y, se)):
        odim = consts.N_DIM - 1 - dim
        if horizontal_panels:
            plt.subplot(gs[0, dim])
        else:
            plt.subplot(gs[dim, 0])
        for odif, (slope2, se_slope2) in enumerate(zip(slope1, se_slope1)):
            label = labels[odif] + ' ' + consts.DIM_NAMES_SHORT[odim]
            plt.errorbar(t_RDK_durs, slope2, se_slope2, marker='o',
                         label=label)
        plt.ylabel(consts.DIM_NAMES_LONG[dim].lower() + ' ' + coef_name +
                   '\n' + unit)
        plt.axis('auto')
        plt.xlim(xmin=-0.02)
        if coef_name == 'slope':
            plt.ylim(ymin=-0.05)
        plt2.box_off()
        plt2.detach_axis('x')
        if dim == 0 and not horizontal_panels:
            plt.gca().set_xticklabels([])
        # Legend per panel: labels name the other dimension, which differs
        # between panels.
        plt.legend(handlelength=0.4, frameon=False, loc='upper left')
    plt.xlabel('stimulus duration (s)')
    return gs
def main_fit(
        subjs=None,
        skip_fit_if_absent=False,
        bufdur=None,
        bufdur_best=0.12,
        bufdurs_sim=None,
        bufdurs_fit=None,
        loss_kind='NLL',
        plot_kind='line_sim_fit',
        fix_post=None,
        seed_sims=(0, ),
        err_kind='std',
        base=10,
):
    """Gather fits of models to simulated data across buffer durations and plot.

    For every (seed, subject, simulated buffer duration, fitted buffer
    duration) combination, ``get_fit_sim`` is queried and its test loss,
    data count and parameter count are stored. Losses are then converted
    to ``loss_kind`` (NLL or BIC), divided by ``log(base)``, plotted
    according to ``plot_kind``, and the figure saved (unless
    ``plot_kind == 'None'``).

    :param subjs: subjects; defaults to ``consts.SUBJS_VD``.
    :param skip_fit_if_absent: forwarded to ``get_fit_sim``.
    :param bufdur: if given, simulate and fit only {bufdur_best, bufdur}.
    :param bufdur_best: reference buffer duration (s).
    :param bufdurs_sim: simulated buffer durations (default: module-level
        ``bufdurs0``).
    :param bufdurs_fit: fitted buffer durations (derived when None).
    :param loss_kind: 'NLL' or 'BIC'.
    :param plot_kind: one of 'loss_serpar', 'imshow_ser_buf_par',
        'line_sim_fit', 'bar_sim_fit', 'bar_ser_buf_par', or 'None'.
    :param fix_post: forwarded to ``get_fit_sim``.
    :param seed_sims: simulation seeds.
    :param err_kind: 'std' or 'sem'; error bars for 'line_sim_fit'.
    :param base: logarithm base for the losses.
    :return: losses[seed, subj, sim_dur, fit_dur] in log-``base`` units.
    :raises ValueError: for an unsupported ``loss_kind``, or an unsupported
        ``err_kind`` when ``plot_kind == 'line_sim_fit'``.
    """
    torch.set_num_threads(1)
    torch.set_default_dtype(torch.double)

    dict_res_add = {}
    parad = 'VD'
    seed_sims = np.array(seed_sims)
    if subjs is None:
        subjs = consts.SUBJS_VD

    if bufdur is not None:
        bufdurs_sim = list({bufdur_best}.union({bufdur}))
        bufdurs_fit = list({bufdur_best}.union({bufdur}))
    else:
        if bufdurs_sim is None:
            bufdurs_sim = bufdurs0
            bufdurs_fit = bufdurs0
        else:
            if bufdurs_fit is None:
                bufdurs_fit0 = [0., bufdur_best, 1.2]
                if bufdurs_sim[0] in bufdurs_fit0:
                    bufdurs_fit = copy(bufdurs_fit0)
                else:
                    bufdurs_fit = bufdurs_fit0[:2] + bufdurs_sim + [1.2]
    n_dur_sim = len(bufdurs_sim)
    n_dur_fit = len(bufdurs_fit)

    # DEF: losses[seed, subj, simSerDurPar, fitDurs]
    size_all = [len(seed_sims), len(subjs), n_dur_sim, n_dur_fit]
    losses0 = np.zeros(size_all) + np.nan
    ns = np.zeros(size_all) + np.nan
    ks0 = np.zeros(size_all) + np.nan
    for i_seed, seed_sim in enumerate(seed_sims):
        for i_subj, subj in enumerate(subjs):
            for i_fit, bufdur_fit in enumerate(bufdurs_fit):
                for i_sim, bufdur_sim in enumerate(bufdurs_sim):
                    d, dict_fit_sim, dict_subdir_sim = get_fit_sim(
                        subj,
                        seed_sim,
                        bufdur_sim,
                        bufdur_fit,
                        parad=parad,
                        skip_fit_if_absent=skip_fit_if_absent,
                        fix_post=fix_post,
                    )
                    if d is not None:
                        ix = (i_seed, i_subj, i_sim, i_fit)
                        losses0[ix] = d['loss_NLL_test']
                        ns[ix] = d['loss_ndata_test']
                        ks0[ix] = d['loss_nparam']
    ks = copy(ks0)

    if loss_kind == 'BIC':
        losses = ks * np.log(ns) + 2 * losses0
        # REF: Kass, Raftery 1995 https://doi.org/10.2307%2F2291091
        # https://en.wikipedia.org/wiki/Bayesian_information_criterion#Gaussian_special_case
        thres_strong = 10. / np.log(base)
    elif loss_kind == 'NLL':
        losses = copy(losses0)
        thres_strong = np.log(100) / np.log(base)
    else:
        raise ValueError('Unsupported loss_kind: %s' % loss_kind)

    losses = losses / np.log(base)

    if plot_kind == 'loss_serpar':
        n_row = len(subjs)
        n_col = 1
        axs = plt2.GridAxes(n_row, n_col, left=1, widths=2, bottom=0.75)
        for i_subj, subj in enumerate(subjs):
            plt.sca(axs[i_subj, 0])
            loss_dur_sim_ser_fit = losses[:, i_subj, :, 0].mean(0)
            loss_dur_sim_par_fit = losses[:, i_subj, :, -1].mean(0)
            dloss_serpar = loss_dur_sim_par_fit - loss_dur_sim_ser_fit
            plt.bar(bufdurs_sim, dloss_serpar, color='k')
            plt.axhline(0, color='k', linestyle='-', linewidth=0.5)
            if i_subj < len(subjs) - 1:
                plt2.box_off(['top', 'right', 'bottom'])
                plt.xticks([])
            else:
                plt2.box_off()
                plt.xticks(np.arange(0, 1.4, 0.2), ['0\nser'] + [''] * 2 +
                           ['0.6'] + [''] * 2 + ['1.2\npar'])
            plt2.detach_axis('x', 0, 1.2)
            vdfit.patch_chance_level()
        plt.sca(axs[0, 0])
        # Raw string: '\m' is not a valid escape sequence.
        plt.title('Support for Serial\n'
                  r'($\mathcal{L}_\mathrm{ser} - \mathcal{L}_\mathrm{par}$)')
        plt.sca(axs[-1, 0])
        plt.xlabel('true buffer duration (s)')
        plt2.rowtitle(consts.SUBJS_VD, axs)

    elif plot_kind == 'imshow_ser_buf_par':
        n_row = len(subjs)
        n_col = 1
        axs = plt2.GridAxes(n_row, n_col, left=1, widths=2, heights=2,
                            bottom=0.75)
        for i_subj, subj in enumerate(subjs):
            plt.sca(axs[i_subj, 0])
            dloss = losses[:, i_subj, :, :].mean(0)
            # Loss relative to fitting with the true (simulated) duration.
            dloss = dloss - np.diag(dloss)[:, None]
            plt.imshow(dloss, cmap='bwr')
            cmax = np.amax(np.abs(dloss))
            plt.clim(-cmax, +cmax)
        plt.sca(axs[-1, 0])
        plt.xlabel('true buffer duration (s)')
        plt.ylabel('model buffer duration (s)')
        plt2.rowtitle(consts.SUBJS_VD, axs)

    elif plot_kind == 'line_sim_fit':
        if err_kind == 'std':
            err_fun = np.std
        elif err_kind == 'sem':
            err_fun = np2.sem
        else:
            raise ValueError('Unsupported err_kind: %s' % err_kind)

        bufdur_bests = np.array([get_bufdur_best(subj) for subj in subjs])
        i_bests = np.array(
            [list(bufdurs_fit).index(bd) for bd in bufdur_bests])
        i_subjs = np.arange(len(subjs))

        # Row 0: data simulated with each subject's best duration, fitted
        # with every duration, relative to the best-duration fit.
        loss_sim_best_fit_best = losses[:, i_subjs, i_bests, i_bests]
        loss_sim_best_fit_rest = (losses[:, i_subjs, i_bests, :] -
                                  loss_sim_best_fit_best[:, :, None])
        mean_loss_sim_best_fit_rest = np.mean(loss_sim_best_fit_rest, 0)
        err_loss_sim_best_fit_rest = err_fun(loss_sim_best_fit_rest, 0)

        # Row 1: data simulated with every duration, fitted with the best
        # duration, relative to fitting with the true duration.
        n_dur = len(bufdurs_fit)
        loss_sim_rest_fit_best = np.swapaxes(
            np.stack([
                losses[:, i_subj, np.arange(n_dur), i_best] -
                losses[:, i_subj, np.arange(n_dur), np.arange(n_dur)]
                for i_subj, i_best in zip(i_subjs, i_bests)
            ]), 0, 1)
        mean_loss_sim_rest_fit_best = np.mean(loss_sim_rest_fit_best, 0)
        err_loss_sim_rest_fit_best = err_fun(loss_sim_rest_fit_best, 0)
        dict_res_add['err'] = err_kind

        n_row = 2
        n_col = len(subjs)
        axs = plt2.GridAxes(n_row, n_col, left=1.15, right=0.1, widths=2,
                            heights=1.5, wspace=0.35, hspace=0.5, top=0.25,
                            bottom=0.6)
        for row, (m, s, xlabel, ylabel) in enumerate([
            (mean_loss_sim_best_fit_rest, err_loss_sim_best_fit_rest,
             'model buffer capacity (s)',
             (r'$-\mathrm{log}_{%g}\mathrm{BF}$ given simulated' +
              '\nbest duration data') % base),
            (mean_loss_sim_rest_fit_best, err_loss_sim_rest_fit_best,
             'simulated data buffer capacity (s)',
             (r'$-\mathrm{log}_{%g}\mathrm{BF}$ of the' +
              '\nbest duration model') % base),
        ]):
            for i_subj, subj in enumerate(subjs):
                ax = axs[row, i_subj]
                plt.sca(ax)
                gs1 = axs.gs[row * 2 + 1, i_subj * 2 + 1]
                bax = vdfit.breakaxis(gs1)
                bax.axs[1].errorbar(bufdurs0, m[i_subj, :],
                                    yerr=s[i_subj, :], color='k',
                                    marker='.', linewidth=0.75,
                                    elinewidth=0.5, markersize=3)
                # Upper broken axis shows only the first three durations.
                m1 = copy(m[i_subj, :])
                m1[3:] = np.nan
                s1 = copy(s[i_subj, :])
                s1[3:] = np.nan
                bax.axs[0].errorbar(bufdurs0, m1, yerr=s1, color='k',
                                    marker='.', linewidth=0.75,
                                    elinewidth=0.5, markersize=3)

                ax1 = bax.axs[1]  # type: plt.Axes
                plt.sca(ax1)
                vdfit.patch_chance_level(level=thres_strong, signs=[-1, 1])
                plt.axhline(0, color='k', linestyle='--', linewidth=0.5)
                vdfit.beautify_ticks(ax1)
                vdfit.beautify(ax1)
                ax1.set_yticks([0, 20])
                if i_subj > 0:
                    ax1.set_yticklabels([])
                if row == 0:
                    ax1.set_xticklabels([])

                ax1 = bax.axs[0]  # type: plt.Axes
                plt.sca(ax1)
                ax1.set_yticks([40, 200])
                if i_subj > 0:
                    ax1.set_yticklabels([])

                plt.sca(ax)
                plt2.box_off('all')
                if row == 0:
                    plt.title(consts.SUBJS['VD'][i_subj])
                if i_subj == 0:
                    plt.xlabel(xlabel, labelpad=8 if row == 0 else 20)
                    plt.ylabel(ylabel, labelpad=30)
                plt2.sameaxes(bax.axs, xy='x')

    elif plot_kind == 'bar_sim_fit':
        i_best = bufdurs_fit.index(bufdur_best)
        loss_sim_best_fit_best = losses[0, :, i_best, i_best]
        loss_sim_best_fit_rest = np2.nan2v(losses[0, :, i_best, :] -
                                           loss_sim_best_fit_best[:, None])
        n_dur = len(bufdurs_fit)
        loss_sim_rest_fit_best = np2.nan2v(
            losses[0, :, np.arange(n_dur), i_best] -
            losses[0, :, np.arange(n_dur), np.arange(n_dur)]).T

        n_row = 2
        n_col = len(subjs)
        axs = plt2.GridAxes(n_row, n_col, left=1, widths=3, right=0.25,
                            heights=1, wspace=0.3, hspace=0.6, top=0.25,
                            bottom=0.6)
        for i_subj, subj in enumerate(subjs):

            def plot_bars(gs1, bufdurs, losses1, add_xticklabel=True):
                # Broken x-axis after the last duration < 0.3 s and broken
                # y-axis at 20.
                i_break = np.amax(np.nonzero(np.array(bufdurs) < 0.3)[0])
                bax = brokenaxes(
                    subplot_spec=gs1,
                    xlims=((-1, i_break + 0.5), (i_break + 0.5,
                                                 len(bufdurs) - 0.5)),
                    ylims=((-3, 20), (20, 1250 / 5)),
                    height_ratios=(50 / 100, 500 / (1250 - 100)),
                    hspace=0.15,
                    wspace=0.075,
                    d=0.005,
                )
                bax.bar(np.arange(len(bufdurs)), losses1[i_subj, :],
                        color='k')

                ax11 = bax.axs[3]  # type: plt.Axes
                ax11.set_xticks([bufdurs.index(0.6), bufdurs.index(1.2)])
                if i_subj == 0 and add_xticklabel:
                    ax11.set_xticklabels(['0.6', '1.2'])
                else:
                    ax11.set_xticklabels([])

                ax00 = bax.axs[0]  # type: plt.Axes
                ax00.set_yticks([500, 1000])

                ax10 = bax.axs[2]  # type: plt.Axes
                ax10.set_yticks([0, 50])
                plt.sca(ax10)
                plt2.detach_axis('x', amin=-0.4, amax=i_break + 0.5)
                for ax in [ax10, ax11]:
                    plt.sca(ax)
                    plt.axhline(0, linewidth=0.5, color='k', linestyle='--')
                    for sign in [-1, 1]:
                        plt.axhline(sign * thres_strong, linewidth=0.5,
                                    color='silver', linestyle='--')
                ax10.set_xticks([bufdurs.index(0.), bufdurs.index(0.2)])
                if i_subj == 0:
                    if add_xticklabel:
                        ax10.set_xticklabels(['0', '0.2'])
                    else:
                        ax10.set_xticklabels([])
                else:
                    ax10.set_yticklabels([])
                    ax10.set_xticklabels([])
                    ax00.set_yticklabels([])
                return bax

            bax = plot_bars(axs.gs[1, i_subj * 2 + 1], bufdurs_fit,
                            loss_sim_best_fit_rest, add_xticklabel=False)
            ax = axs[0, i_subj]
            plt.sca(ax)
            plt2.box_off('all')
            plt.title(consts.SUBJS['VD'][consts.SUBJS['VD'].index(subj)])
            if i_subj == 0:
                ax.set_ylabel(
                    'misfit to simulated\nbest duration data\n'
                    r'($\Delta$BIC)', labelpad=35)
                ax.set_xlabel('model buffer duration (s)', labelpad=8)

            plot_bars(axs.gs[3, i_subj * 2 + 1], bufdurs_sim,
                      loss_sim_rest_fit_best, add_xticklabel=True)
            ax = axs[1, i_subj]
            plt.sca(ax)
            plt2.box_off('all')
            if i_subj == 0:
                ax.set_ylabel(
                    'misfit of\nbest duration model\n'
                    r'($\Delta$BIC)', labelpad=35)
                ax.set_xlabel('simulated data buffer duration (s)',
                              labelpad=20)

    elif plot_kind == 'bar_ser_buf_par':
        n_row = len(subjs)
        n_col = 1
        axs = plt2.GridAxes(n_row * n_dur_sim, n_col, left=1.25, widths=1.5,
                            right=0.25, heights=0.4,
                            hspace=[0.15] * (n_dur_sim - 1) + [0.5] +
                            [0.15] * (n_dur_sim - 1),
                            top=0.6, bottom=0.5)
        row = -1
        for i_subj, subj in enumerate(subjs):
            for i_sim in range(n_dur_sim):
                row += 1
                plt.sca(axs[row, 0])
                loss1 = losses[:, i_subj, i_sim, :].mean(0)
                dloss = loss1 - loss1[i_sim]
                x = np.arange(n_dur_fit)
                for x1, dloss1 in zip(x, dloss):
                    plt.bar(x1, dloss1, color='r' if dloss1 > 0 else 'b')
                plt.axhline(0, color='k', linewidth=0.5)
                vdfit.patch_chance_level(6.)
                plt.ylim([-100, 100])
                # One (empty) label per tick: matplotlib requires the counts
                # to match.
                plt.yticks([-100, 0, 100], [''] * 3)
                plt.ylabel('%g' % bufdurs_sim[i_sim], rotation=0,
                           va='center', ha='right')
                if i_sim == 0:
                    plt.title(subj)
                if i_sim < n_dur_sim - 1 or i_subj < len(subjs) - 1:
                    plt2.box_off(['top', 'bottom', 'right'])
                    plt.xticks([])
                else:
                    plt2.box_off(['top', 'right'])
        plt.sca(axs[-1, 0])
        plt.xticks(np.arange(n_dur_sim), ['%g' % v for v in bufdurs_sim])
        plt2.detach_axis('x', 0, n_dur_sim - 1)
        plt.xlabel('model buffer duration (s)', fontsize=10)
        c = axs[-1, 0].get_position().corners()
        plt.figtext(x=0.15, y=np.mean(c[:2, 1]), fontsize=10,
                    s='true\nbuffer\nduration\n(s)', rotation=0,
                    va='center', ha='center')
        plt.figtext(
            x=(1.25 + 1.5 / 2) / (1.25 + 1.5 + 0.25),
            y=0.98,
            s=r'$\mathrm{BIC} - \mathrm{BIC}_\mathrm{true}$',
            ha='center',
            va='top',
            fontsize=12,
        )

    if plot_kind != 'None':
        dict_res = deepcopy(dict_fit_sim)  # noqa
        for k in ['sbj', 'prd']:
            dict_res.pop(k)
        dict_res = {**dict_res, 'los': loss_kind}
        dict_res.update(dict_res_add)
        for ext in ['.png', '.pdf']:
            file = locfile.get_file_fig(plot_kind, dict_res,
                                        subdir=dict_subdir_sim,
                                        ext=ext)  # noqa
            plt.savefig(file, dpi=300)
            print('Saved to %s' % file)

    print('--')
    return losses
def plot_ch_ev_by_dur(n_cond_dur_ch: np.ndarray = None,
                      data: Data2DVD = None,
                      dif_irrs: Sequence[Union[Iterable[int],
                                               int]] = ((0, 1), (2, )),
                      ev_cond_dim: np.ndarray = None,
                      durs: np.ndarray = None,
                      dur_prct_groups=((0, 33), (33, 67), (67, 100)),
                      style='data',
                      kw_plot=(),
                      jitter=0.,
                      axs=None,
                      fig=None):
    """
    Plot P(choice) vs. evidence: panels[dim, dur_group], curves by the
    difficulty group of the irrelevant dimension.

    :param n_cond_dur_ch: choice counts; from ``data`` when None.
    :param data: source for any of ev_cond_dim / durs / n_cond_dur_ch left
        as None.
    :param dif_irrs: groups of |irrelevant condition| indices, one curve
        each.
    :param ev_cond_dim: [condition, dim] evidence.
    :param durs: durations; duration groups are percentile ranges of their
        index.
    :param dur_prct_groups: (low, high) percentile pairs, one column each.
    :param style: passed to ``consts.get_kw_plot``.
    :param kw_plot: extra plot kwargs.
    :param jitter: currently ignored (kept for backward compatibility).
    :param axs: [row, col] axes; created (consts.N_DIM x n_dur) when None.
    :param fig: figure used when axs is created.
    :return: axs
    """
    if ev_cond_dim is None:
        ev_cond_dim = data.ev_cond_dim
    if durs is None:
        durs = data.durs
    if n_cond_dur_ch is None:
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])

    n_conds_dur_chs = get_n_conds_dur_chs(n_cond_dur_ch)
    conds_dim = [np.unique(cond1) for cond1 in ev_cond_dim.T]
    n_dur = len(dur_prct_groups)
    n_dif = len(dif_irrs)

    if axs is None:
        if fig is None:
            fig = plt.figure(figsize=[6, 4])
        n_row = consts.N_DIM
        n_col = n_dur
        gs = plt.GridSpec(nrows=n_row, ncols=n_col, figure=fig)
        # Builtin object: np.object was removed in NumPy 1.24.
        axs = np.empty([n_row, n_col], dtype=object)
        for row in range(n_row):
            for col in range(n_col):
                axs[row, col] = plt.subplot(gs[row, col])

    ix_dur = np.arange(len(durs))
    for dim_rel in range(consts.N_DIM):
        conds_rel = conds_dim[dim_rel]
        dim_irr = consts.get_odim(dim_rel)
        _, cond_irr = np.unique(np.abs(conds_dim[dim_irr]),
                                return_inverse=True)
        cmap = consts.CMAP_DIM[dim_rel](n_dif)
        for idif, dif_irr in enumerate(dif_irrs):
            incl_irr = np.isin(cond_irr, dif_irr)
            for i_dur, dur_prct_group in enumerate(dur_prct_groups):
                ax = axs[dim_rel, i_dur]
                plt.sca(ax)
                dur_incl = (
                    (ix_dur >= np.percentile(ix_dur, dur_prct_group[0])) &
                    (ix_dur <= np.percentile(ix_dur, dur_prct_group[1])))
                # Sum out the irrelevant dimension's conditions and choices.
                if dim_rel == 0:
                    n_cond_dur_ch1 = n_conds_dur_chs[:, incl_irr, :, :, :].sum(
                        (1, -1))
                else:
                    n_cond_dur_ch1 = n_conds_dur_chs[incl_irr, :, :, :, :].sum(
                        (0, -2))
                n_cond_ch1 = n_cond_dur_ch1[:, dur_incl, :].sum(1)
                p_cond__ch1 = n_cond_ch1[:, 1] / n_cond_ch1.sum(1)

                kw = consts.get_kw_plot(style, color=cmap(idif),
                                        **dict(kw_plot))
                plt.plot(conds_rel, p_cond__ch1, **kw)

                plt2.box_off(ax=ax)
                plt2.detach_yaxis(0, 1, ax=ax)
                ax.set_yticks([0, 0.5, 1.])
                if i_dur == 0:
                    ax.set_yticklabels(['0', '', '1'])
                else:
                    ax.set_yticklabels([])
                ax.set_xticklabels([])
    return axs
def plot_coefs_dur_irrixn(
        n_cond_dur_ch: np.ndarray = None,
        data: Data2DVD = None,
        ev_cond_dim: np.ndarray = None,
        durs: np.ndarray = None,
        style='data',
        kw_plot=(),
        jitter0=0.,
        coefs_to_plot=(2, ),
        axs=None,
        fig=None,
):
    """Plot irrelevant-dimension interaction coefficients against duration.

    Coefficients come from ``coefdur.get_coefs_irr_ixn_from_histogram``,
    indexed coef[i_coef, dim, dur]. Panels are [dim, coefficient].

    :param n_cond_dur_ch: counts; from ``data`` when None.
    :param data: source for any of ev_cond_dim / durs / n_cond_dur_ch left
        as None.
    :param style: 'data*' styles draw error bars, others plain lines.
    :param kw_plot: extra plot kwargs (override the defaults color='k',
        for_err=True).
    :param jitter0: unused (kept for backward compatibility).
    :param coefs_to_plot: indices into
        ['bias', 'slope', 'rel x abs(irr)', 'rel x irr', 'abs(irr)', 'irr'].
    :param axs: [row, col] axes; created when None.
    :param fig: figure used when axs is created.
    :return: coef, se_coef, axs
    """
    if ev_cond_dim is None:
        ev_cond_dim = data.ev_cond_dim
    if durs is None:
        durs = npy(data.durs)
    if n_cond_dur_ch is None:
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])
    coef, se_coef = coefdur.get_coefs_irr_ixn_from_histogram(
        n_cond_dur_ch,
        ev_cond_dim=ev_cond_dim)[:2]  # type: (np.ndarray, np.ndarray)
    coef_names = [
        'bias', 'slope', 'rel x abs(irr)', 'rel x irr', 'abs(irr)', 'irr'
    ]
    n_coef = len(coefs_to_plot)

    if axs is None:
        if fig is None:
            fig = plt.figure(figsize=[6, 1.5 * n_coef])
        n_row = consts.N_DIM
        n_col = n_coef
        gs = plt.GridSpec(nrows=n_row, ncols=n_col, figure=fig, left=0.25,
                          right=0.95, bottom=0.15, top=0.9)
        # Builtin object: np.object was removed in NumPy 1.24.
        axs = np.empty([n_row, n_col], dtype=object)
        for row in range(n_row):
            for col in range(n_col):
                axs[row, col] = plt.subplot(gs[row, col])

    # Loop-invariant: computed once instead of rebinding kw_plot per panel.
    kw = consts.get_kw_plot(style, **{
        'color': 'k',
        'for_err': True,
        **dict(kw_plot)
    })
    max_dur = np.amax(npy(durs))

    for ii_coef, i_coef in enumerate(coefs_to_plot):
        coef_name = coef_names[i_coef]
        for dim_rel in range(consts.N_DIM):
            ax = axs[dim_rel, ii_coef]  # type: plt.Axes
            plt.sca(ax)
            y = coef[i_coef, dim_rel, :]
            e = se_coef[i_coef, dim_rel, :]
            if style.startswith('data'):
                plt.errorbar(durs, y, yerr=e, **kw)
            else:
                plt.plot(durs, y, **kw)
            plt2.box_off()
            ax.set_xlim(-0.05 * max_dur, 1.05 * max_dur)
            plt2.detach_axis('x', 0, max_dur)
            if dim_rel == 0:
                ax.set_xticklabels([])
                ax.set_title(coef_name.lower())
            elif i_coef == 0:
                # NOTE(review): keyed on the coefficient index, so with the
                # default coefs_to_plot=(2,) no x-label is drawn; possibly
                # ii_coef was intended — confirm.
                ax.set_xlabel('duration (s)')
    plt2.rowtitle(consts.DIM_NAMES_LONG, axes=axs)
    return coef, se_coef, axs
def main(base=10.):
    """Plot and tabulate model-comparison costs for data and recovery sims.

    Loads the data comparison (``file_model_comp``) and the two recovery
    comparisons (``files_model_recovery``), with ``dcost`` rescaled to
    log-``base`` units. It then:

    - saves a three-column bar figure;
    - writes mean and SEM of dcost per dataset to CSV;
    - saves a summary horizontal-bar figure with the log_base(100)
      evidence band.

    Relies on module-level names (locfile_in/locfile_out, file_model_comp,
    files_model_recovery, use_easiest_only, load_comp, ...).

    :param base: logarithm base for dcost and the evidence threshold.
    """
    ds = load_comp(os.path.join(locfile_in.pth_root, file_model_comp))
    ds['dcost'] = np.array(ds['dcost']) / np.log(base)
    vmax = 160
    m = np.empty(3)
    e = np.empty(3)

    # NOTE(review): the axes from this call appear to be used only for
    # their layout attributes (heights/left/right/bottom) below.
    axs = plot_bar_dloss_across_subjs(dlosses=np.array(ds['dcost']),
                                      subj_parad_bis=ds['subj_parad_bi'],
                                      vmax=vmax)
    axs = plt2.GridAxes(nrows=1, ncols=3, heights=axs.heights, widths=[2],
                        left=axs.left, right=axs.right, bottom=axs.bottom)
    plot_bar_dloss_across_subjs(
        dlosses=np.array(ds['dcost']),
        subj_parad_bis=ds['subj_parad_bi'],
        vmax=vmax,
        axs=axs[:, [0]],
        base=base,
    )
    plt.title('Data')
    m[0] = np.mean(ds['dcost'])
    e[0] = np2.sem(ds['dcost'])
    print('--')

    titles = ['Simulated\nSerial', 'Simulated\nParallel']
    for i in range(2):
        ds = load_comp(
            os.path.join(locfile_in.pth_root, files_model_recovery[i]))
        ds['dcost'] = np.array(ds['dcost']) / np.log(base)
        plot_bar_dloss_across_subjs(
            dlosses=ds['dcost'],
            subj_parad_bis=ds['subj_parad_bi'],
            vmax=vmax,
            axs=axs[:, [i + 1]],
            base=base,
        )
        plt.title(titles[i])
        m[i + 1] = np.mean(ds['dcost'])
        e[i + 1] = np2.sem(ds['dcost'])
        plt.sca(axs[0, i + 1])
        plt2.box_off(['left'])
        plt.yticks([])

    for ext in ['.pdf', '.png']:
        file = locfile_out.get_file_fig('model_comp_recovery',
                                        {'easiest_only': use_easiest_only},
                                        ext=ext)
        plt.savefig(file, dpi=300)
        print('Saved to %s' % file)

    # --- Print mean +- SEM to CSV
    csv_file = locfile_out.get_file_csv('model_comp_recovery',
                                        {'easiest_only': use_easiest_only})
    d_csv = np2.dictlist2listdict({
        'data': ['original', 'simulated serial', 'simulated parallel'],
        'mean_dcost': m,
        'sem_dcost': e,
    })
    # newline='' is required by the csv module (avoids blank rows on
    # Windows).
    with open(csv_file, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=d_csv[0].keys())
        writer.writeheader()
        for d in d_csv:
            writer.writerow(d)
    print('Wrote to %s' % csv_file)

    # --- Mean dcost of each dataset, with +/- SEM
    axs = plt2.GridAxes(
        1,
        1,
        left=0.85,
        right=0.25,
        heights=[1.5],
        top=0.1,
        bottom=0.9,
    )
    plt.sca(axs[0, 0])
    plt.barh(np.arange(3), m, xerr=e, color='w', edgecolor='k')
    plt.yticks(np.arange(3),
               ['data', 'simulated\nserial', 'simulated\nparallel'])
    plt2.box_off(['top', 'right'])
    vmax = np.amax(np.abs(m) + np.abs(e))
    plt.xlim(np.array([-vmax, vmax]) * 1.1)
    plt2.detach_axis(xy='y', amin=0, amax=2)
    plt2.detach_axis(xy='x', amin=-vmax, amax=vmax)
    # dcost is in log-`base` units, so the evidence band must use the same
    # base.
    axvline_dcost(base=base)
    xticks_serial_vs_parallel(vmax, base)
    for ext in ['.pdf', '.png']:
        file = locfile_out.get_file_fig('model_comp_vs_recovery',
                                        {'easiest_only': use_easiest_only},
                                        ext=ext)
        plt.savefig(file, dpi=300)
        print('Saved to %s' % file)
    print('--')
def plot_rt_vs_ev(
        ev_cond,
        n_cond__rt_ch: Union[torch.Tensor, np.ndarray],
        style='pred',
        pool='mean',
        dt=consts.DT,
        correct_only=True,
        thres_n_trial=10,
        color='k',
        color_ch=('tab:red', 'tab:blue'),
        ax: plt.Axes = None,
        kw_plot=(),
) -> (Sequence[plt.Line2D], List[np.ndarray]):
    """
    Plot mean RT against evidence.

    @param ev_cond: [condition] (2-D/3-D inputs are reduced to one value
        per condition, as in plot_p_ch_vs_ev).
    @type ev_cond: torch.Tensor
    @param n_cond__rt_ch: [condition, frame, ch] counts.
    @type n_cond__rt_ch: torch.Tensor
    @param pool: only 'mean' is implemented.
    @param correct_only: if True, for each choice plot only conditions
        whose sign agrees with it (condition 0 is kept for both). With a
        'data*' style, condition-0 counts are pooled across choices.
        Otherwise plot both choices over all conditions, coloured by
        color_ch.
    @param thres_n_trial: with a 'data*' style, conditions with fewer
        trials are set to NaN.
    @param ax: axes to draw on (default: current axes). All decorations
        are applied to it.
    @return: (hs, rts). rts is a list of arrays when correct_only,
        otherwise a stacked [ch, condition] array.
    @raise ValueError: for an unknown ``pool``.
    """
    if ax is None:
        ax = plt.gca()
    if ev_cond.ndim != 1:
        if ev_cond.ndim == 3:
            ev_cond = npt.p2st(ev_cond)[0]
        assert ev_cond.ndim == 2
        ev_cond = ev_cond.mean(1)
    assert n_cond__rt_ch.ndim == 3
    ev_cond = npy(ev_cond)
    n_cond__rt_ch = npy(n_cond__rt_ch)

    def plot_rt_given_cond_ch(ev_cond1, n_rt_given_cond_ch, **kw1):
        # n_rt_given_cond_ch[cond, fr]; p_rt_given_cond_ch[cond, fr]
        p_rt_given_cond_ch = np2.sumto1(n_rt_given_cond_ch, 1)
        nt = n_rt_given_cond_ch.shape[1]
        t = np.arange(nt) * dt
        # n_in_cond_ch[cond]
        n_in_cond_ch = n_rt_given_cond_ch.sum(1)
        if pool == 'mean':
            rt_pooled = (t[None, :] * p_rt_given_cond_ch).sum(1)
        elif pool == 'var':
            raise NotImplementedError()
        else:
            raise ValueError('Unsupported pool: %s' % pool)
        if style.startswith('data'):
            rt_pooled[n_in_cond_ch < thres_n_trial] = np.nan
        kw = get_kw_plot(style, **kw1)
        h = ax.plot(ev_cond1, rt_pooled, **kw)
        return h, rt_pooled

    hs = []
    rtss = []
    if correct_only:
        n_cond__rt_ch1 = n_cond__rt_ch.copy()  # type: np.ndarray
        if style.startswith('data'):
            # Pool both choices for the zero-evidence condition.
            cond0 = ev_cond == 0
            n_cond__rt_ch1[cond0, :, :] = np.sum(n_cond__rt_ch1[cond0, :, :],
                                                 axis=-1, keepdims=True)
        # -- Choose the ch with correct sign (or both chs if cond == 0)
        cond_sign = np.sign(ev_cond)
        for ch in range(consts.N_CH):
            ch_sign = consts.ch_bool2sign(ch)
            cond_accu = cond_sign != -ch_sign
            h, rts = plot_rt_given_cond_ch(ev_cond[cond_accu],
                                           n_cond__rt_ch1[cond_accu, :, ch],
                                           color=color, **dict(kw_plot))
            hs.append(h)
            rtss.append(rts)
        rts = rtss
    else:
        for ch in range(consts.N_CH):
            # n_rt_given_cond_ch[cond, fr]
            n_rt_given_cond_ch = n_cond__rt_ch[:, :, ch]
            h, rts = plot_rt_given_cond_ch(ev_cond, n_rt_given_cond_ch,
                                           color=color_ch[ch])
            hs.append(h)
            rtss.append(rts)
        rts = np.stack(rtss)

    y_lim = ax.get_ylim()
    x_lim = ax.get_xlim()
    plt2.box_off(ax=ax)
    plt2.detach_axis('x', ax=ax, amin=x_lim[0], amax=x_lim[1])
    plt2.detach_axis('y', ax=ax, amin=y_lim[0], amax=y_lim[1])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{E}[T^\mathrm{r} \mid c]~(\mathrm{s})$")
    return hs, rts
def plot_rt_distrib(
        n_cond_rt_ch: np.ndarray,
        ev_cond_dim: np.ndarray,
        abs_cond=True,
        lump_wrong=True,
        dt=consts.DT,
        colors=None,
        alpha=1.,
        alpha_face=0.5,
        smooth_sigma_sec=0.05,
        to_normalize_max=False,
        to_cumsum=False,
        to_exclude_last_frame=True,
        to_skip_zero_trials=False,
        label='',
        # to_exclude_bins_wo_trials=10,
        kw_plot=(),
        fig=None,
        axs=None,
        to_use_sameaxes=True,
):
    """
    Plot RT distributions in a grid of |cond0| x |cond1| panels.

    Choices are recoded as correct/wrong relative to the sign of each
    condition (``abs_cond``). With ``lump_wrong``, "correct" requires both
    dimensions correct, with a zero condition counting as correct.

    :param n_cond_rt_ch: counts. Flattened in C order, this must match the
        grid (unique cond0, unique cond1, frame, ch0, ch1). NOTE(review):
        presumably shaped [cond, frame, ch] with cond ordered
        cond0-major — confirm.
    :param ev_cond_dim: [condition, 2] evidence.
    :param abs_cond: must be True (False is not implemented).
    :param lump_wrong: see above.
    :param dt: frame duration (s).
    :param colors: None, one colour string, or two colours (wrong, correct).
    :param alpha: line alpha.
    :param alpha_face: fill alpha.
    :param smooth_sigma_sec: Gaussian smoothing SD (s) along time; 0 for
        none.
    :param to_normalize_max: divide each panel by its max |value|.
    :param to_cumsum: plot cumulative sums over time.
    :param to_exclude_last_frame: zero the last frame's counts.
    :param to_skip_zero_trials: skip curves whose sum |y| < 1e-2.
    :param label: legend label for the (ch0=1, ch1=1) curve.
    :param kw_plot: extra plot kwargs.
    :param fig: unused.
    :param axs: panel grid; created (with row/column titles) when None.
    :param to_use_sameaxes: share limits across panels via plt2.sameaxes.
    :return: axs, p_cond01__rt_ch01, p_cond01__rt_ch01_sm, h
        (``h`` is the last plotted line list, or None if nothing was
        plotted).
    :raises ValueError: if ``abs_cond`` is False.
    """
    from copy import deepcopy

    if colors is None:
        colors = ['red', 'blue']
    elif isinstance(colors, str):
        colors = [colors] * 2
    else:
        assert len(colors) == 2

    nt = n_cond_rt_ch.shape[1]
    t_all = np.arange(nt) * dt + dt

    out = np.meshgrid(np.unique(ev_cond_dim[:, 0]),
                      np.unique(ev_cond_dim[:, 1]),
                      np.arange(nt),
                      np.arange(2),
                      np.arange(2),
                      indexing='ij')
    cond0, cond1, fr, ch0, ch1 = [v.flatten() for v in out]

    n0 = deepcopy(n_cond_rt_ch)
    if to_exclude_last_frame:
        n0[:, -1, :] = 0.
    n0 = n0.flatten()

    def sign_cond(v):
        # Sign of the condition, treating 0 as positive.
        v1 = np.sign(v)
        v1[v == 0] = 1
        return v1

    if abs_cond:
        ch0 = consts.ch_bool2sign(ch0)
        ch1 = consts.ch_bool2sign(ch1)
        # 1 = correct, -1 = wrong
        ch0 = sign_cond(cond0) * ch0
        ch1 = sign_cond(cond1) * ch1
        cond0 = np.abs(cond0)
        cond1 = np.abs(cond1)
        # Builtin int: np.int was removed in NumPy 1.24.
        ch0 = consts.ch_sign2bool(ch0).astype(int)
        ch1 = consts.ch_sign2bool(ch1).astype(int)
    else:
        raise ValueError('abs_cond=False is not supported')

    if lump_wrong:
        # treat all choices as correct when cond == 0
        ch00 = ch0 | (cond0 == 0)
        ch10 = ch1 | (cond1 == 0)
        ch0 = (ch00 & ch10)
        ch1 = np.ones_like(ch00, dtype=int)

    cond_dim = np.stack([cond0, cond1], -1)
    dcond_dim = np.stack([
        np.unique(cond, return_inverse=True)[1] for cond in cond_dim.T
    ])

    n_cond01_rt_ch01 = npg.aggregate(
        [*dcond_dim, fr, ch0, ch1], n0, 'sum',
        [*(np.amax(dcond_dim, 1) + 1), nt, consts.N_CH, consts.N_CH])
    # Normalise within each (cond0, cond1) panel over time and choices.
    p_cond01__rt_ch01 = np2.nan2v(n_cond01_rt_ch01 / n_cond01_rt_ch01.sum(
        (2, 3, 4), keepdims=True))
    n_conds = p_cond01__rt_ch01.shape[:2]

    if axs is None:
        axs = plt2.GridAxes(
            n_conds[1],
            n_conds[0],
            left=0.6,
            right=0.3,
            bottom=0.45,
            top=0.74,
            widths=[1],
            heights=[1],
            wspace=0.04,
            hspace=0.04,
        )
        kw_label = {
            'fontsize': 12,
        }
        pad = 8
        axs[0, 0].set_title('strong\nmotion', pad=pad, **kw_label)
        axs[0, -1].set_title('weak\nmotion', pad=pad, **kw_label)
        axs[0, 0].set_ylabel('strong\ncolor', labelpad=pad, **kw_label)
        axs[-1, 0].set_ylabel('weak\ncolor', labelpad=pad, **kw_label)

    if smooth_sigma_sec > 0:
        from scipy import signal, stats
        sigma_fr = smooth_sigma_sec / dt
        width = np.ceil(sigma_fr * 2.5).astype(int)
        kernel = stats.norm.pdf(np.arange(-width, width + 1), 0, sigma_fr)
        # Place the 1-D kernel on the time axis (dim 2 of 5).
        kernel = np2.vec_on(kernel, 2, 5)
        p_cond01__rt_ch01_sm = signal.convolve(p_cond01__rt_ch01, kernel,
                                               mode='same')
    else:
        p_cond01__rt_ch01_sm = p_cond01__rt_ch01.copy()

    if to_cumsum:
        p_cond01__rt_ch01_sm = np.cumsum(p_cond01__rt_ch01_sm, axis=2)

    if to_normalize_max:
        p_cond01__rt_ch01_sm = np2.nan2v(
            p_cond01__rt_ch01_sm /
            np.amax(np.abs(p_cond01__rt_ch01_sm), (2, 3, 4), keepdims=True))

    n_row = n_conds[1]
    n_col = n_conds[0]
    h = None
    for dcond0 in range(n_conds[0]):
        for dcond1 in range(n_conds[1]):
            # Strongest conditions at the top-left.
            row = n_row - 1 - dcond1
            col = n_col - 1 - dcond0
            ax = axs[row, col]  # type: plt.Axes

            for ch0 in [0, 1]:
                for ch1 in [0, 1]:
                    if lump_wrong and ch1 == 0:
                        continue
                    p1 = p_cond01__rt_ch01_sm[dcond0, dcond1, :, ch0, ch1]
                    kw = {
                        'linewidth': 1,
                        'color': colors[ch1],
                        'alpha': alpha,
                        'zorder': 1,
                        **dict(kw_plot)
                    }
                    # Wrong choices (ch0 == 0) are drawn below zero.
                    y = p1 * consts.ch_bool2sign(ch0)
                    p_cond01__rt_ch01_sm[dcond0, dcond1, :, ch0, ch1] = y
                    if to_skip_zero_trials and np.sum(np.abs(y)) < 1e-2:
                        h = None
                    else:
                        h = ax.plot(
                            t_all, y,
                            label=label if ch0 == 1 and ch1 == 1 else None,
                            **kw)
                        ax.fill_between(t_all, 0, y, ec='None',
                                        fc=kw['color'], alpha=alpha_face,
                                        zorder=-1)

            plt2.box_off(ax=ax)
            ax.axhline(0, color='k', linewidth=0.5)
            ax.set_yticklabels([])
            if row < n_row - 1 or col > 0:
                ax.set_xticklabels([])
                ax.set_xticks([])
                plt2.box_off(['bottom'], ax=ax)
            else:
                ax.set_xlabel('RT (s)')
            ax.set_yticks([])
            ax.set_yticklabels([])
            plt2.box_off(['left'], ax=ax)
            plt2.detach_axis('x', 0, 5, ax=ax)

    if to_use_sameaxes:
        plt2.sameaxes(axs)
    axs[-1, 0].set_xlabel('RT (s)')
    return axs, p_cond01__rt_ch01, p_cond01__rt_ch01_sm, h
def plot_coefs_dur_odif(
        n_cond_dur_ch: np.ndarray = None,
        data: Data2DVD = None,
        dif_irrs: Sequence[Union[Iterable[int], int]] = ((0, 1), (2, )),
        ev_cond_dim: np.ndarray = None,
        durs: np.ndarray = None,
        style='data',
        kw_plot=(),
        jitter0=0.,
        coefs_to_plot=(0, 1),
        add_rowtitle=True,
        dim_incl=(0, 1),
        axs=None,
        fig=None,
):
    """
    Plot coefficients ('bias' / 'slope') against stimulus duration, one line
    per group in ``dif_irrs``.

    Axes are laid out with one row per entry of ``dim_incl`` and one column
    per entry of ``coefs_to_plot``: titles go on the top row, x tick labels
    are hidden on all but the bottom row, and the x label goes on the
    bottom-left panel.

    :param n_cond_dur_ch: histogram passed to
        ``coefdur.get_coefs_mesh_from_histogram``; defaults to
        ``data.get_data_by_cond('all')[1]``.
    :param data: source for ``ev_cond_dim``, ``durs`` and ``n_cond_dur_ch``
        when those are not given.
    :param dif_irrs: groups passed to ``get_coefs_mesh_from_histogram``;
        one line (and one colour) is drawn per group. Presumably difficulty
        levels of the irrelevant dimension - confirm.
    :param ev_cond_dim: defaults to ``data.ev_cond_dim``.
    :param durs: durations (s) used as x values; defaults to ``data.durs``.
    :param style: styles starting with 'data' are drawn with error bars
        (``se_coef``); others are drawn as plain lines.
    :param kw_plot: extra keyword arguments (dict or key/value pairs) for
        ``consts.get_kw_plot``.
    :param jitter0: horizontal offset between groups, as a fraction of the
        spacing between the first two durations.
    :param coefs_to_plot: indices into ('bias', 'slope'), one per column.
    :param add_rowtitle: whether to label rows with the names of the
        dimensions in ``dim_incl``.
    :param dim_incl: dimensions to plot, one per row.
    :param axs: [row, col] array of axes; created when None.
    :param fig: figure in which to create ``axs`` when ``axs`` is None.
    :return: coef, se_coef, axs, hs[coef, dim, dif]
    """
    if ev_cond_dim is None:
        ev_cond_dim = data.ev_cond_dim
    if durs is None:
        durs = npy(data.durs)
    if n_cond_dur_ch is None:
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])

    coef, se_coef = coefdur.get_coefs_mesh_from_histogram(
        n_cond_dur_ch, ev_cond_dim=ev_cond_dim,
        dif_irrs=dif_irrs)[:2]  # type: (np.ndarray, np.ndarray)

    coef_names = ['bias', 'slope']
    n_coef = len(coefs_to_plot)
    n_dim = len(dim_incl)
    n_dif = len(dif_irrs)

    if axs is None:
        if fig is None:
            # One 1.5-inch-tall row per plotted dimension.
            fig = plt.figure(figsize=[6, 1.5 * n_dim])
        gs = plt.GridSpec(nrows=n_dim, ncols=n_coef, figure=fig,
                          left=0.25, right=0.95, bottom=0.15, top=0.9)
        # `np.object` was removed in NumPy 1.24; the builtin works everywhere.
        axs = np.empty([n_dim, n_coef], dtype=object)
        for row in range(n_dim):
            for col in range(n_coef):
                # Add to `fig` explicitly: plt.subplot() targets the current
                # figure, which need not be a user-supplied `fig`.
                axs[row, col] = fig.add_subplot(gs[row, col])

    hs = np.empty([n_coef, n_dim, n_dif], dtype=object)

    # x-axis quantities do not depend on the panel, so compute them once.
    max_dur = np.amax(npy(durs))
    # With a single duration there is no spacing to jitter by.
    ddur = durs[1] - durs[0] if len(durs) > 1 else 0.

    for ii_coef, i_coef in enumerate(coefs_to_plot):
        coef_name = coef_names[i_coef]
        for i_dim, dim_rel in enumerate(dim_incl):
            ax = axs[i_dim, ii_coef]  # type: plt.Axes
            plt.sca(ax)
            cmap = consts.CMAP_DIM[dim_rel](n_dif)
            for idif in range(n_dif):
                y = coef[i_coef, dim_rel, idif, :]
                e = se_coef[i_coef, dim_rel, idif, :]
                # Offsets are centred so the groups straddle each duration.
                jitter = ddur * jitter0 * (idif - (n_dif - 1) / 2)
                kw = consts.get_kw_plot(style, color=cmap(idif), for_err=True,
                                        **dict(kw_plot))
                if style.startswith('data'):
                    h = plt.errorbar(durs + jitter, y, yerr=e, **kw)[0]
                else:
                    h = plt.plot(durs + jitter, y, **kw)
                hs[ii_coef, i_dim, idif] = h
            plt2.box_off()
            ax.set_xlim(-0.05 * max_dur, 1.05 * max_dur)
            plt2.detach_axis('x', 0, max_dur)

            # Decorate by grid position (row i_dim, column ii_coef), not by
            # which dimension/coefficient is shown, so any dim_incl /
            # coefs_to_plot gets titles on top and the x label bottom-left.
            if i_dim == 0:
                ax.set_title(coef_name.lower())
            if i_dim < n_dim - 1:
                ax.set_xticklabels([])
            elif ii_coef == 0:
                ax.set_xlabel('duration (s)')

    if add_rowtitle:
        plt2.rowtitle([consts.DIM_NAMES_LONG[d] for d in dim_incl], axes=axs)
    return coef, se_coef, axs, hs
def plot_fit_combined(
        data: Union[sim2d.Data2DRT, dict] = None,
        pModel_cond_rt_chFlat=None,
        model=None,
        pModel_dimRel_condDense_chFlat=None,
        # --- in place of data:
        pAll_cond_rt_chFlat=None,
        evAll_cond_dim=None,
        pTrain_cond_rt_chFlat=None,
        evTrain_cond_dim=None,
        pTest_cond_rt_chFlat=None,
        evTest_cond_dim=None,
        dt=None,
        # --- optional
        ev_dimRel_condDense_fr_dim_meanvar=None,
        dt_model=None,
        to_plot_internals=True,
        to_plot_params=True,
        to_plot_choice=True,
        group_dcond_irr=None,
        to_combine_ch_irr_cond=True,
        kw_plot_pred=(),
        kw_plot_pred_ch=(),
        kw_plot_data=(),
        axs=None,
):
    """
    Plot model predictions against training and test data, one column per
    dimension (``range(consts.N_DIM)``).

    Rows of column ``dim_rel``:
      0: choice probability vs evidence (only if ``to_plot_choice``);
      next: RT vs evidence;
      3, 4: ``model.tnds[...].plot_p_tnd()`` for z = [dim_rel, ch1],
        ch1 = 0, 1 (only if ``to_plot_internals``);
      5: ``model.dtb.dtb1ds[dim_rel].plot_bound()`` (only if
        ``to_plot_internals`` and the model has ``dtb1ds``).
    With ``to_plot_params``, ``model.plot_params()`` is drawn in
    ``axs[0, -1]``.

    Within a panel, style 'pred' is the model, 'data_pred' the test data and
    'data_fit' the training data.

    :param data: object providing ``get_data_by_cond`` and ``dt``. When None,
        the ``pAll_*``/``evAll_*`` (and optional train/test) arrays and
        ``dt`` are used instead; missing train/test arrays default to the
        "all" arrays.
    :param pModel_cond_rt_chFlat: NOTE(review): currently unused - the model
        prediction is always taken from ``pModel_dimRel_condDense_chFlat``.
    :param model: fitted model; required for internals and params.
    :param pModel_dimRel_condDense_chFlat: model predictions indexed by
        ``dim_rel`` first.
    :param ev_dimRel_condDense_fr_dim_meanvar: evidence for the model
        conditions, indexed by ``dim_rel`` first (5-D or 4-D; the first frame
        and, for 5-D, the first entry of the last axis are used). When None,
        ``evAll_cond_dim`` is used.
    :param dt_model: time step for the model RT; defaults to ``model.dt`` or,
        without a model, to ``dt``.
    :param to_plot_internals: plot Tnd and bound rows (needs ``model``).
    :param to_plot_params: plot model parameters (needs ``model``).
    :param to_plot_choice: plot the choice row.
    :param group_dcond_irr: passed through to the ``sim2d`` plotting calls.
    :param to_combine_ch_irr_cond: if True, the model choice curve is drawn
        in black after ``combine_irr_cond``; otherwise one curve per group
        with ``cmaps[dim_rel]``.
    :param kw_plot_pred: plot kwargs for model curves.
    :param kw_plot_pred_ch: plot kwargs for the combined model choice curve.
    :param kw_plot_data: plot kwargs for data.
    :param axs: GridAxes to draw into; created when None.
    :return: axs, rts (one entry per dimension, from the model RT plot), hs
        (dict: 'rt' -> model RT handles per dimension; 'ch' -> test-data
        choice handles per dimension, filled only if ``to_plot_choice``).
    """
    if data is None:
        if pTrain_cond_rt_chFlat is None:
            pTrain_cond_rt_chFlat = pAll_cond_rt_chFlat
        if evTrain_cond_dim is None:
            evTrain_cond_dim = evAll_cond_dim
        if pTest_cond_rt_chFlat is None:
            pTest_cond_rt_chFlat = pAll_cond_rt_chFlat
        if evTest_cond_dim is None:
            evTest_cond_dim = evAll_cond_dim
    else:
        _, pAll_cond_rt_chFlat, _, _, evAll_cond_dim = \
            data.get_data_by_cond('all')
        _, pTrain_cond_rt_chFlat, _, _, evTrain_cond_dim = \
            data.get_data_by_cond('train_valid', mode_train='easiest')
        _, pTest_cond_rt_chFlat, _, _, evTest_cond_dim = \
            data.get_data_by_cond('test', mode_train='easiest')
        dt = data.dt

    if model is None:
        assert not to_plot_internals, 'to_plot_internals requires a model'
        assert not to_plot_params, 'to_plot_params requires a model'

    if dt_model is None:
        dt_model = dt if model is None else model.dt

    if axs is None:
        # TODO: beautify ratios
        if to_plot_internals:
            # Internals use rows 3-5 (Tnd for ch1 = 0, 1; bound), so the grid
            # needs 6 rows; the third column holds the parameter plot.
            axs = plt2.GridAxes(6, 3)
        elif to_plot_params:
            axs = plt2.GridAxes(3, 3)
        elif to_plot_choice:
            axs = plt2.GridAxes(2, 2)
        else:
            axs = plt2.GridAxes(1, 2)

    # The test set may not contain all conditions. Append every condition
    # with zero probability so the colour scale covers all conditions while
    # nothing is drawn for the padded ones. This does not depend on dim_rel.
    evTest_cond_dim1 = np.concatenate(
        [evTest_cond_dim, evAll_cond_dim], axis=0)
    pTest_cond_rt_chFlat1 = np.concatenate(
        [pTest_cond_rt_chFlat, np.zeros_like(pAll_cond_rt_chFlat)], axis=0)

    rts = []
    hs = {'rt': [], 'ch': []}
    for dim_rel in range(consts.N_DIM):
        if ev_dimRel_condDense_fr_dim_meanvar is None:
            evModel_cond_dim = evAll_cond_dim
        elif ev_dimRel_condDense_fr_dim_meanvar.ndim == 5:
            evModel_cond_dim = npy(
                ev_dimRel_condDense_fr_dim_meanvar[dim_rel][:, 0, :, 0])
        else:
            assert ev_dimRel_condDense_fr_dim_meanvar.ndim == 4
            evModel_cond_dim = npy(
                ev_dimRel_condDense_fr_dim_meanvar[dim_rel][:, 0, :])
        pModel_cond_rt_chFlat = npy(pModel_dimRel_condDense_chFlat[dim_rel])

        evTrain_rel = evTrain_cond_dim[:, dim_rel]

        if to_plot_choice:
            # --- Plot choice
            ax = axs[0, dim_rel]
            plt.sca(ax)
            if to_combine_ch_irr_cond:
                ev_cond_model1, p_rt_ch_model1 = combine_irr_cond(
                    dim_rel, evModel_cond_dim, pModel_cond_rt_chFlat)
                sim2d.plot_p_ch_vs_ev(
                    ev_cond_model1, p_rt_ch_model1,
                    dim_rel=dim_rel, style='pred', group_dcond_irr=None,
                    kw_plot=kw_plot_pred_ch,
                    # constant black colour map
                    cmap=lambda n: lambda v: [0., 0., 0.],
                )
            else:
                sim2d.plot_p_ch_vs_ev(
                    evModel_cond_dim, pModel_cond_rt_chFlat,
                    dim_rel=dim_rel, style='pred',
                    group_dcond_irr=group_dcond_irr,
                    kw_plot=kw_plot_pred, cmap=cmaps[dim_rel],
                )
            # Use a separate name: rebinding `hs` here would clobber the
            # handle dict that the RT section appends to below.
            hs_ch, conds_irr = sim2d.plot_p_ch_vs_ev(
                evTest_cond_dim1, pTest_cond_rt_chFlat1,
                dim_rel=dim_rel, style='data_pred',
                group_dcond_irr=group_dcond_irr,
                cmap=cmaps[dim_rel], kw_plot=kw_plot_data,
            )
            hs['ch'].append(hs_ch)
            odim_name = consts.DIM_NAMES_LONG[1 - dim_rel]
            legend_odim(conds_irr, [h[0] for h in hs_ch], odim_name)
            sim2d.plot_p_ch_vs_ev(
                evTrain_cond_dim, pTrain_cond_rt_chFlat,
                dim_rel=dim_rel, style='data_fit',
                group_dcond_irr=group_dcond_irr,
                cmap=cmaps[dim_rel], kw_plot=kw_plot_data,
            )
            plt2.detach_axis('x', np.amin(evTrain_rel), np.amax(evTrain_rel))
            ax.set_xlabel('')
            ax.set_xticklabels([])
            if dim_rel != 0:
                plt2.box_off(['left'])
                plt.yticks([])
            ax.set_ylabel('P(%s choice)' % consts.CH_NAMES[dim_rel][1])

        # --- Plot RT (directly below the choice row, if any)
        ax = axs[int(to_plot_choice), dim_rel]
        plt.sca(ax)
        hs_rt, rts1 = sim2d.plot_rt_vs_ev(
            evModel_cond_dim, pModel_cond_rt_chFlat,
            dim_rel=dim_rel, style='pred', group_dcond_irr=group_dcond_irr,
            dt=dt_model, kw_plot=kw_plot_pred, cmap=cmaps[dim_rel],
        )
        hs['rt'].append(hs_rt)
        rts.append(rts1)
        sim2d.plot_rt_vs_ev(
            evTest_cond_dim1, pTest_cond_rt_chFlat1,
            dim_rel=dim_rel, style='data_pred',
            group_dcond_irr=group_dcond_irr, dt=dt,
            cmap=cmaps[dim_rel], kw_plot=kw_plot_data,
        )
        sim2d.plot_rt_vs_ev(
            evTrain_cond_dim, pTrain_cond_rt_chFlat,
            dim_rel=dim_rel, style='data_fit',
            group_dcond_irr=group_dcond_irr, dt=dt,
            cmap=cmaps[dim_rel], kw_plot=kw_plot_data,
        )
        plt2.detach_axis('x', np.amin(evTrain_rel), np.amax(evTrain_rel))
        if dim_rel != 0:
            ax.set_ylabel('')
            plt2.box_off(['left'])
            plt.yticks([])
        ax.set_xlabel(consts.DIM_NAMES_LONG[dim_rel].lower() + ' strength')
        if dim_rel == 0:
            ax.set_ylabel('RT (s)')

        if to_plot_internals:
            # Column index doubles as the first choice: column dim_rel shows
            # z = [dim_rel, ch1] in rows 3 + ch1.
            ch0 = dim_rel
            for ch1 in range(consts.N_CH):
                ax = axs[3 + ch1, dim_rel]
                plt.sca(ax)
                ch_flat = consts.ch_by_dim2ch_flat(np.array([ch0, ch1]))
                model.tnds[ch_flat].plot_p_tnd()
                ax.set_xlabel('')
                ax.set_xticklabels([])
                ax.set_yticks([0, 1])
                if ch0 > 0:
                    ax.set_yticklabels([])
                # Both parts raw so the LaTeX backslashes are not escapes.
                ax.set_ylabel(r"$\mathrm{P}(T^\mathrm{n} \mid"
                              r" \mathbf{z}=[%d,%d])$" % (ch0, ch1))

            ax = axs[5, dim_rel]
            plt.sca(ax)
            if hasattr(model.dtb, 'dtb1ds'):
                model.dtb.dtb1ds[dim_rel].plot_bound(color='k')

    plt2.sameaxes(axs[-1, :consts.N_DIM], xy='y')

    if to_plot_params:
        ax = axs[0, -1]
        plt.sca(ax)
        model.plot_params()

    return axs, rts, hs
def _read_gof_csv(file, gof_kind='loss_NLL_test'):
    """
    Read one goodness-of-fit value from a 'best_loss' CSV table.

    The row whose 'name' column equals ``gof_kind`` is selected, and its
    ' value' column (header has a leading space) is parsed as a float after
    dropping its first three characters.

    :param file: path to the CSV file.
    :param gof_kind: entry of the 'name' column to read.
    :return: the value, or NaN if the file is absent or has no data rows.
    """
    import csv
    import os

    if not os.path.exists(file):
        print('File absent - returning NaN: %s' % file)
        return np.nan
    with open(file, 'r', newline='') as csvfile:
        rows = list(csv.DictReader(csvfile))
    if not rows:
        print('File empty - returning NaN: %s' % file)
        return np.nan
    i_row = [row['name'] for row in rows].index(gof_kind)
    # NOTE(review): the first 3 characters of the value are discarded,
    # presumably a prefix written by the table writer - confirm.
    return float(rows[i_row][' value'][3:])


def main_plot_across_models(
        i_subjs=i_subjs,
        axs=None,
        models_incl=('buffer+serial', ),
        base=10,
        **kwargs,
):
    """
    For each subject, plot goodness of fit against buffer capacity (bottom
    row) and the motion / colour sensitivity against stimulus duration (top
    two rows), then save the figure as .pdf and .png.

    :param i_subjs: subjects to plot, one column each; passed to ``main_fit``
        and used to index ``consts.SUBJS['VD']``.
    :param axs: GridAxes with 3 rows and one column per subject; created when
        None.
    :param models_incl: names (from ``fixs`` below) of the models whose
        predictions are overlaid on the sensitivity panels.
    :param base: log base for the goodness-of-fit axis.
    :param kwargs: passed to ``main_fit``.
    :return: gofs[bufdur, subject], models[model, subject], and the data of
        the last subject loaded.
    """
    n_subj = len(i_subjs)

    # ---- Load goodness-of-fit
    def load_gof(i_subj, buffix_dur):
        fix = fix0_pre + [('buffix', buffix_dur)] + fix0_post
        dict_cache1, subdir1 = main_fit(
            i_subj, fit_mode='dict_cache_only', fix=fix, **kwargs)
        file = locfile.get_file('tab', 'best_loss', dict_cache1,
                                ext='.csv', subdir=subdir1)
        return _read_gof_csv(file, 'loss_NLL_test')

    gofs = np.zeros([len(bufdurs0), n_subj])
    for col, i_subj in enumerate(i_subjs):
        for idur, bufdur in enumerate(bufdurs0):
            gofs[idur, col] = load_gof(i_subj, bufdur)
    # Loss relative to each subject's best buffer duration, divided by
    # ln(base) (assumes the loss is in nats - confirm).
    gofs = gofs - np.nanmin(gofs, 0, keepdims=True)
    gofs = gofs / np.log(base)

    # ---- Load slopes
    fixs = [
        ('buffer+serial', [('buffix', bufdur_best)]),
        # ('parallel', [('buffix', 1.2)]),
        # ('serial', [('buffix', 0.)]),
    ]
    model_names = [name for name, _ in fixs]
    n_models = len(model_names)
    # `np.object` was removed in NumPy 1.24; the builtin works everywhere.
    models = np.empty([n_models, n_subj], dtype=object)
    dict_caches = np.empty([n_models, n_subj], dtype=object)
    datas = np.empty([n_subj], dtype=object)
    kw_plots = {
        'serial': {'linestyle': '--'},
        'buffer+serial': {'linestyle': '-'},
        'parallel': {'linestyle': ':'},
    }
    subdir = ','.join(fix0) + '+buffix=' + ','.join(
        ['%1.2f' % v for v in [0., bufdur_best, 1.2]])

    data = None
    for col, i_subj in enumerate(i_subjs):
        for i_model, (_, fix1) in enumerate(fixs):
            fix = fix0_pre + fix1 + fix0_post
            model, data, dict_cache, _, _ = main_fit(
                i_subj, fit_mode='load', fix=fix, **kwargs)
            models[i_model, col] = model
            datas[col] = data
            dict_caches[i_model, col] = dict_cache

    # ---- Plot goodness-of-fit
    if axs is None:
        axs = plt2.GridAxes(3, n_subj, top=.3, left=1.15, right=0.1,
                            bottom=.5, widths=[2], wspace=0.35,
                            heights=[1.5, 1.5, 1.5], hspace=[0.2, 0.9])
    # Derive from axs so that a caller-supplied axs also works.
    n_row = axs.shape[0]

    kw_gof = {'linewidth': 0.75, 'markersize': 4.5}
    for col in range(n_subj):
        ax = axs[-1, col]
        plt.sca(ax)
        plt2.box_off('all')
        # The GoF panel is a broken axis drawn in the grid cell behind
        # axs[-1, col]. NOTE(review): axs.gs presumably interleaves spacer
        # cells with axes cells, hence [-2, 2 * col + 1] - confirm.
        bax = breakaxis(axs.gs[-2, col * 2 + 1])
        ax0 = bax.axs[0]  # type: plt.Axes
        ax1 = bax.axs[1]  # type: plt.Axes
        ax0.plot(bufdurs0[:3], gofs[:3, col], 'k.-', **kw_gof)
        ax1.plot(bufdurs0, gofs[:, col], 'k.-', **kw_gof)
        plt.sca(ax1)
        patch_chance_level(level=np.log(100.) / np.log(base), signs=[-1, 1])
        plt.axhline(0, color='k', linestyle='--', linewidth=0.5)
        beautify(ax1)
        beautify_ticks(ax1, add_ticklabel=True)
        plt2.sameaxes([ax0, ax1], xy='x')
        ax1.set_yticks([0, 20])
        ax0.set_yticks([40, 200])
        if col == 0:
            plt.sca(ax1)
            plt.xlabel('buffer capacity (s)')
            plt.sca(ax)
            plt.ylabel(r'$-\mathrm{log}_{%g}\mathrm{BF}$' % base,
                       labelpad=27)
        else:
            ax0.set_yticklabels([])
            ax1.set_yticklabels([])
        plt.sca(axs[0, col])
        beautify_ticks(axs[0, col], add_ticklabel=False)
        beautify_ticks(axs[1, col], add_ticklabel=True)
    plt.sca(axs[-2, 0])
    plt.xlabel('stimulus duration (s)')

    # ---- Plot slopes
    for col in range(n_subj):
        hss = []
        for model_name in models_incl:
            i_model = model_names.index(model_name)
            model = models[i_model, col]
            if model is None:
                continue
            _, hs = plot_coefs_dur_odif_pred_data(
                datas[col], model,
                axs=axs[:2, [col]],
                to_plot_data=True,
                kw_plot_model=kw_plots[model_name],
                coefs_to_plot=[1],
                add_rowtitle=False)
            hss.append(hs)

        if col == 0 and hss:
            # hss[0] = hs['pred'|'data'][coef, dim, dif]
            hs1 = hss[0]['data']  # type: np.ndarray
            for dim in reversed(range(hs1.shape[1])):
                hs = [hs1[0, dim, dif] for dif in range(hs1.shape[2])]
                odim_name = consts.DIM_NAMES_LONG[1 - dim]
                plt.sca(axs[dim, 0])
                plt.legend(hs, [
                    'weak ' + odim_name.lower(),
                    'strong ' + odim_name.lower()
                ],
                           bbox_to_anchor=[0., 0., 1., 1.],
                           handletextpad=0.35,
                           handlelength=0.5,
                           labelspacing=0.3,
                           borderpad=0.,
                           frameon=False,
                           loc='upper left')

        for row in range(n_row):
            beautify(axs[row, col])
            if col > 0:
                plt.ylabel('')

    plt2.sameaxes(axs[-1, :])
    axs[-1, -1].set_yticklabels([])

    for col, i_subj in enumerate(i_subjs):
        plt.sca(axs[0, col])
        plt.title(consts.SUBJS['VD'][i_subj])
        for row in range(1, n_row):
            plt.sca(axs[row, col])
            plt.title('')

    for row, title in enumerate([
            r'Motion sensitivity ($\beta$)',
            r'Color sensitivity ($\beta$)',
    ]):
        plt.sca(axs[row, 0])
        plt.ylabel(title)

    from lib.pylabyk.cacheutil import mkdir4file

    # Work on a copy so the dict stored in dict_caches is not mutated.
    dict_cache = dict(dict_caches[0, 0])
    for k in ['fix', 'sbj']:
        dict_cache.pop(k, None)
    for ext in ['.pdf', '.png']:
        file = locfile.get_file_fig('coefs_dur_odif_sbjs', dict_cache,
                                    subdir=subdir, ext=ext)
        mkdir4file(file)
        plt.savefig(file, dpi=300)
        print('Saved to %s' % file)
    print('--')
    return gofs, models, data