def plot_rt_distrib(
        n_cond_rt_ch: np.ndarray, ev_cond_dim: np.ndarray,
        abs_cond=True,
        lump_wrong=True,
        dt=consts.DT,
        colors=None,
        alpha=1.,
        alpha_face=0.5,
        smooth_sigma_sec=0.05,
        to_normalize_max=False,
        to_cumsum=False,
        to_exclude_last_frame=True,
        to_skip_zero_trials=False,
        label='',
        # to_exclude_bins_wo_trials=10,
        kw_plot=(),
        fig=None,
        axs=None,
        to_use_sameaxes=True,
):
    """
    Plot RT distributions on a grid of axes, one per pair of absolute
    condition values of the two dimensions.

    :param n_cond_rt_ch: 3-D array of counts; axis 1 is the time frame.
        It is flattened and paired element-wise with a meshgrid of
        (unique ev of dim 0, unique ev of dim 1, frame, ch0, ch1), so its
        size must equal the size of that grid.
        NOTE(review): presumably [cond, frame, flat choice] with
        conditions on a full factorial grid - confirm against callers.
    :param ev_cond_dim: [cond, dim] condition values; columns 0 and 1
        are used.
    :param abs_cond: if True, fold conditions by absolute value and
        recode choices as correct (1) / wrong (0). False is unsupported
        and raises ValueError.
    :param lump_wrong: if True, a trial counts as correct only when both
        dimensions are correct (choices at cond == 0 are treated as
        correct), and only the lumped curve is drawn.
    :param dt: duration of one frame; time axis is (frame + 1) * dt.
    :param colors: None (-> ['red', 'blue']), one color name (used for
        both), or a sequence of two colors indexed by ch1.
    :param alpha: alpha of the lines.
    :param alpha_face: alpha of the filled area under the lines.
    :param smooth_sigma_sec: SD of the Gaussian smoothing kernel along
        time, in the same unit as dt; no smoothing if <= 0.
    :param to_normalize_max: divide each condition's (smoothed) curves by
        their maximum absolute value.
    :param to_cumsum: plot the cumulative sum along time.
    :param to_exclude_last_frame: zero out the counts of the last frame
        before aggregating.
    :param to_skip_zero_trials: skip plotting curves whose summed
        absolute value is below 1e-2.
    :param label: legend label attached to the ch0 == 1, ch1 == 1 curve.
    :param kw_plot: extra keyword arguments (dict or key-value pairs)
        for ax.plot.
    :param fig: unused.
    :param axs: existing grid of axes; created if None.
    :param to_use_sameaxes: if True, call plt2.sameaxes(axs) at the end.
    :return: axs, p_cond01__rt_ch01, p_cond01__rt_ch01_sm, h
        p_cond01__rt_ch01[cond0, cond1, frame, ch0, ch1]: proportions
        normalized within each condition pair;
        p_cond01__rt_ch01_sm: the smoothed / cumulated / normalized
        version, with plotted slices sign-flipped in place for ch0 == 0;
        h: return value of the last ax.plot call (None if that curve was
        skipped) - not a collection of all handles.
    """
    if colors is None:
        colors = ['red', 'blue']
    elif type(colors) is str:
        colors = [colors] * 2
    else:
        assert len(colors) == 2

    nt = n_cond_rt_ch.shape[1]
    t_all = np.arange(nt) * dt + dt

    out = np.meshgrid(
        np.unique(ev_cond_dim[:, 0]),
        np.unique(ev_cond_dim[:, 1]),
        np.arange(nt), np.arange(2), np.arange(2),
        indexing='ij')
    cond0, cond1, fr, ch0, ch1 = [v.flatten() for v in out]

    from copy import deepcopy
    # Copy so that zeroing the last frame does not modify the caller's array.
    n0 = deepcopy(n_cond_rt_ch)
    if to_exclude_last_frame:
        n0[:, -1, :] = 0.
    n0 = n0.flatten()

    def sign_cond(v):
        # Sign of the condition, treating 0 as positive.
        v1 = np.sign(v)
        v1[v == 0] = 1
        return v1

    if abs_cond:
        ch0 = consts.ch_bool2sign(ch0)
        ch1 = consts.ch_bool2sign(ch1)

        # 1 = correct, -1 = wrong
        ch0 = sign_cond(cond0) * ch0
        ch1 = sign_cond(cond1) * ch1

        cond0 = np.abs(cond0)
        cond1 = np.abs(cond1)

        # builtin int instead of np.int, which NumPy 1.24 removed
        ch0 = consts.ch_sign2bool(ch0).astype(int)
        ch1 = consts.ch_sign2bool(ch1).astype(int)
    else:
        raise ValueError('abs_cond=False is not supported')

    if lump_wrong:
        # treat all choices as correct when cond == 0
        ch00 = ch0 | (cond0 == 0)
        ch10 = ch1 | (cond1 == 0)

        # ch0 becomes "both correct"; ch1 is a constant 1 so that all
        # counts land in the ch1 == 1 slice.
        ch0 = (ch00 & ch10)
        ch1 = np.ones_like(ch00, dtype=int)

    cond_dim = np.stack([cond0, cond1], -1)

    conds = []
    dcond_dim = []
    for cond in cond_dim.T:
        conds1, dcond1 = np.unique(cond, return_inverse=True)
        conds.append(conds1)
        dcond_dim.append(dcond1)
    dcond_dim = np.stack(dcond_dim)

    # Sum counts into [cond0, cond1, frame, ch0, ch1]
    n_cond01_rt_ch01 = npg.aggregate(
        [*dcond_dim, fr, ch0, ch1], n0, 'sum',
        [*(np.amax(dcond_dim, 1) + 1), nt, consts.N_CH, consts.N_CH])

    # Normalize within each condition pair
    p_cond01__rt_ch01 = np2.nan2v(n_cond01_rt_ch01 / n_cond01_rt_ch01.sum(
        (2, 3, 4), keepdims=True))

    n_conds = p_cond01__rt_ch01.shape[:2]

    if axs is None:
        axs = plt2.GridAxes(
            n_conds[1], n_conds[0],
            left=0.6, right=0.3,
            bottom=0.45, top=0.74,
            widths=[1], heights=[1],
            wspace=0.04, hspace=0.04,
        )
        # NOTE(review): the titles below are applied only to newly created
        # axes here; the source's nesting was ambiguous - confirm.
        kw_label = {
            'fontsize': 12,
        }
        pad = 8
        axs[0, 0].set_title('strong\nmotion', pad=pad, **kw_label)
        axs[0, -1].set_title('weak\nmotion', pad=pad, **kw_label)
        axs[0, 0].set_ylabel('strong\ncolor', labelpad=pad, **kw_label)
        axs[-1, 0].set_ylabel('weak\ncolor', labelpad=pad, **kw_label)

    if smooth_sigma_sec > 0:
        from scipy import signal, stats
        sigma_fr = smooth_sigma_sec / dt
        # Kernel half-width of 2.5 SD, in frames
        width = int(np.ceil(sigma_fr * 2.5))
        kernel = stats.norm.pdf(np.arange(-width, width + 1), 0, sigma_fr)
        # presumably reshapes the kernel to lie along axis 2 of a 5-D array
        kernel = np2.vec_on(kernel, 2, 5)
        p_cond01__rt_ch01_sm = signal.convolve(p_cond01__rt_ch01, kernel,
                                               mode='same')
    else:
        p_cond01__rt_ch01_sm = p_cond01__rt_ch01.copy()

    if to_cumsum:
        p_cond01__rt_ch01_sm = np.cumsum(p_cond01__rt_ch01_sm, axis=2)

    if to_normalize_max:
        p_cond01__rt_ch01_sm = np2.nan2v(
            p_cond01__rt_ch01_sm
            / np.amax(np.abs(p_cond01__rt_ch01_sm), (2, 3, 4), keepdims=True))

    n_row = n_conds[1]
    n_col = n_conds[0]
    for dcond0 in range(n_conds[0]):
        for dcond1 in range(n_conds[1]):
            # Largest condition index goes to the top-left axes
            row = n_row - 1 - dcond1
            col = n_col - 1 - dcond0

            ax = axs[row, col]  # type: plt.Axes

            for i_ch0 in [0, 1]:
                for i_ch1 in [0, 1]:
                    if lump_wrong and i_ch1 == 0:
                        continue

                    p1 = p_cond01__rt_ch01_sm[dcond0, dcond1, :, i_ch0, i_ch1]

                    kw = {
                        'linewidth': 1,
                        'color': colors[i_ch1],
                        'alpha': alpha,
                        'zorder': 1,
                        **dict(kw_plot)
                    }

                    # Flip ch0 == 0 curves below zero; the flipped values
                    # are written back into the returned array.
                    y = p1 * consts.ch_bool2sign(i_ch0)
                    p_cond01__rt_ch01_sm[
                        dcond0, dcond1, :, i_ch0, i_ch1] = y
                    if to_skip_zero_trials and np.sum(np.abs(y)) < 1e-2:
                        h = None
                    else:
                        h = ax.plot(
                            t_all, y,
                            label=(label if i_ch0 == 1 and i_ch1 == 1
                                   else None),
                            **kw)
                        ax.fill_between(t_all, 0, y,
                                        ec='None',
                                        fc=kw['color'],
                                        alpha=alpha_face,
                                        zorder=-1)

            plt2.box_off(ax=ax)
            ax.axhline(0, color='k', linewidth=0.5)

            ax.set_yticklabels([])
            if row < n_row - 1 or col > 0:
                ax.set_xticklabels([])
                ax.set_xticks([])
                plt2.box_off(['bottom'], ax=ax)
            else:
                ax.set_xlabel('RT (s)')
            # if col > 0:
            ax.set_yticks([])
            ax.set_yticklabels([])
            plt2.box_off(['left'], ax=ax)

            plt2.detach_axis('x', 0, 5, ax=ax)

    if to_use_sameaxes:
        plt2.sameaxes(axs)
    axs[-1, 0].set_xlabel('RT (s)')

    return axs, p_cond01__rt_ch01, p_cond01__rt_ch01_sm, h
def plot_fit_combined(
        data: Union[sim2d.Data2DRT, dict] = None,
        pModel_cond_rt_chFlat=None, model=None,
        pModel_dimRel_condDense_chFlat=None,
        # --- in place of data:
        pAll_cond_rt_chFlat=None,
        evAll_cond_dim=None,
        pTrain_cond_rt_chFlat=None,
        evTrain_cond_dim=None,
        pTest_cond_rt_chFlat=None,
        evTest_cond_dim=None,
        dt=None,
        # --- optional
        ev_dimRel_condDense_fr_dim_meanvar=None,
        dt_model=None,
        to_plot_internals=True,
        to_plot_params=True,
        to_plot_choice=True,
        # to_group_irr=False,
        group_dcond_irr=None,
        to_combine_ch_irr_cond=True,
        kw_plot_pred=(),
        kw_plot_pred_ch=(),
        kw_plot_data=(),
        axs=None,
):
    """
    Plot model predictions together with training and test data: choice
    (optional) and RT against stimulus strength for each relevant
    dimension, and optionally model internals and parameters.

    :param data: object providing get_data_by_cond() and dt. If None,
        the p*/ev* arguments below are used instead.
    :param pModel_cond_rt_chFlat: model prediction; used as is unless
        ev_dimRel_condDense_fr_dim_meanvar is given, in which case it is
        replaced per dimension by pModel_dimRel_condDense_chFlat[dim_rel].
    :param model: fitted model; required when to_plot_internals or
        to_plot_params is True. Also supplies dt_model when that is None.
    :param pModel_dimRel_condDense_chFlat: per-dimension predictions;
        read only when ev_dimRel_condDense_fr_dim_meanvar is given.
    :param pAll_cond_rt_chFlat: data for all conditions (when data is
        None); also the fallback for the train and test arguments.
    :param evAll_cond_dim: condition values matching
        pAll_cond_rt_chFlat; fallback for the train and test arguments.
    :param pTrain_cond_rt_chFlat: training data (style 'data_fit').
    :param evTrain_cond_dim: condition values of the training data.
    :param pTest_cond_rt_chFlat: test data (style 'data_pred').
    :param evTest_cond_dim: condition values of the test data.
    :param dt: time step of the data; overwritten by data.dt when data
        is given.
    :param ev_dimRel_condDense_fr_dim_meanvar: optional 4- or 5-D array
        from which the model's condition values are sliced per dimension.
    :param dt_model: time step of the prediction; defaults to model.dt,
        or to dt when model is None.
    :param to_plot_internals: plot model.tnds and, if present,
        model.dtb.dtb1ds bounds.
    :param to_plot_params: call model.plot_params() on axs[0, -1].
    :param to_plot_choice: plot the choice row above the RT row.
    :param group_dcond_irr: forwarded to the sim2d plotting functions.
    :param to_combine_ch_irr_cond: for the choice prediction, combine
        across the other dimension's conditions via combine_irr_cond()
        and draw it in black.
    :param kw_plot_pred: keyword arguments for prediction plots.
    :param kw_plot_pred_ch: keyword arguments for the combined choice
        prediction plot.
    :param kw_plot_data: keyword arguments for data plots.
    :param axs: grid of axes; created if None.
    :return: axs, rts, hs
        rts: per-dimension second outputs of sim2d.plot_rt_vs_ev for the
        prediction; hs: dict with hs['rt'] holding the per-dimension
        handles of the predicted RT plots.
    """
    if data is None:
        if pTrain_cond_rt_chFlat is None:
            pTrain_cond_rt_chFlat = pAll_cond_rt_chFlat
        if evTrain_cond_dim is None:
            evTrain_cond_dim = evAll_cond_dim
        if pTest_cond_rt_chFlat is None:
            pTest_cond_rt_chFlat = pAll_cond_rt_chFlat
        if evTest_cond_dim is None:
            evTest_cond_dim = evAll_cond_dim
    else:
        _, pAll_cond_rt_chFlat, _, _, evAll_cond_dim = \
            data.get_data_by_cond('all')
        _, pTrain_cond_rt_chFlat, _, _, evTrain_cond_dim = \
            data.get_data_by_cond('train_valid', mode_train='easiest')
        _, pTest_cond_rt_chFlat, _, _, evTest_cond_dim = \
            data.get_data_by_cond('test', mode_train='easiest')
        dt = data.dt
    hs = {}

    if model is None:
        assert not to_plot_internals
        assert not to_plot_params

    if dt_model is None:
        if model is None:
            dt_model = dt
        else:
            dt_model = model.dt

    if axs is None:
        if to_plot_params:
            axs = plt2.GridAxes(3, 3)
        else:
            if to_plot_internals:
                axs = plt2.GridAxes(3, 3)
            else:
                if to_plot_choice:
                    axs = plt2.GridAxes(2, 2)
                else:
                    axs = plt2.GridAxes(1, 2)  # TODO: beautify ratios

    rts = []
    hs['rt'] = []
    for dim_rel in range(consts.N_DIM):
        # --- data_pred may not have all conditions, so concatenate the rest
        #  of the conditions so that the color scale is correct. Then also
        #  concatenate p_rt_ch_data_pred1 with zeros so that nothing is
        #  plotted in the concatenated.
        evTest_cond_dim1 = np.concatenate([
            evTest_cond_dim, evAll_cond_dim
        ], axis=0)
        pTest_cond_rt_chFlat1 = np.concatenate([
            pTest_cond_rt_chFlat, np.zeros_like(pAll_cond_rt_chFlat)
        ], axis=0)

        if ev_dimRel_condDense_fr_dim_meanvar is None:
            evModel_cond_dim = evAll_cond_dim
        else:
            if ev_dimRel_condDense_fr_dim_meanvar.ndim == 5:
                evModel_cond_dim = npy(ev_dimRel_condDense_fr_dim_meanvar[
                                           dim_rel][:, 0, :, 0])
            else:
                assert ev_dimRel_condDense_fr_dim_meanvar.ndim == 4
                evModel_cond_dim = npy(ev_dimRel_condDense_fr_dim_meanvar[
                                           dim_rel][:, 0, :])
            # Dense predictions are only indexed when the dense condition
            # values are given; callers that pass pModel_cond_rt_chFlat
            # directly leave pModel_dimRel_condDense_chFlat as None.
            pModel_cond_rt_chFlat = npy(
                pModel_dimRel_condDense_chFlat[dim_rel])

        if to_plot_choice:
            # --- Plot choice
            ax = axs[0, dim_rel]
            plt.sca(ax)

            if to_combine_ch_irr_cond:
                ev_cond_model1, p_rt_ch_model1 = combine_irr_cond(
                    dim_rel, evModel_cond_dim, pModel_cond_rt_chFlat
                )

                sim2d.plot_p_ch_vs_ev(ev_cond_model1, p_rt_ch_model1,
                                      dim_rel=dim_rel, style='pred',
                                      group_dcond_irr=None,
                                      kw_plot=kw_plot_pred_ch,
                                      cmap=lambda n: lambda v: [0., 0., 0.],
                                      )
            else:
                sim2d.plot_p_ch_vs_ev(evModel_cond_dim, pModel_cond_rt_chFlat,
                                      dim_rel=dim_rel, style='pred',
                                      group_dcond_irr=group_dcond_irr,
                                      kw_plot=kw_plot_pred,
                                      cmap=cmaps[dim_rel]
                                      )
            # Keep the choice handles in their own local: reusing `hs`
            # here would replace the dict that must hold hs['rt'] and be
            # returned to the caller.
            hs_ch, conds_irr = sim2d.plot_p_ch_vs_ev(
                evTest_cond_dim1, pTest_cond_rt_chFlat1,
                dim_rel=dim_rel, style='data_pred',
                group_dcond_irr=group_dcond_irr,
                cmap=cmaps[dim_rel],
                kw_plot=kw_plot_data,
            )
            hs1 = [h[0] for h in hs_ch]
            odim = 1 - dim_rel
            odim_name = consts.DIM_NAMES_LONG[odim]
            legend_odim(conds_irr, hs1, odim_name)
            sim2d.plot_p_ch_vs_ev(evTrain_cond_dim, pTrain_cond_rt_chFlat,
                                  dim_rel=dim_rel, style='data_fit',
                                  group_dcond_irr=group_dcond_irr,
                                  cmap=cmaps[dim_rel],
                                  kw_plot=kw_plot_data
                                  )
            plt2.detach_axis('x', np.amin(evTrain_cond_dim[:, dim_rel]),
                             np.amax(evTrain_cond_dim[:, dim_rel]))
            ax.set_xlabel('')
            ax.set_xticklabels([])
            if dim_rel != 0:
                plt2.box_off(['left'])
                plt.yticks([])

            ax.set_ylabel('P(%s choice)' % consts.CH_NAMES[dim_rel][1])

        # --- Plot RT
        ax = axs[int(to_plot_choice) + 0, dim_rel]
        plt.sca(ax)
        hs1, rts1 = sim2d.plot_rt_vs_ev(
            evModel_cond_dim,
            pModel_cond_rt_chFlat,
            dim_rel=dim_rel, style='pred',
            group_dcond_irr=group_dcond_irr,
            dt=dt_model,
            kw_plot=kw_plot_pred,
            cmap=cmaps[dim_rel]
        )
        hs['rt'].append(hs1)
        rts.append(rts1)

        sim2d.plot_rt_vs_ev(evTest_cond_dim1, pTest_cond_rt_chFlat1,
                            dim_rel=dim_rel, style='data_pred',
                            group_dcond_irr=group_dcond_irr,
                            dt=dt,
                            cmap=cmaps[dim_rel],
                            kw_plot=kw_plot_data
                            )
        sim2d.plot_rt_vs_ev(evTrain_cond_dim, pTrain_cond_rt_chFlat,
                            dim_rel=dim_rel, style='data_fit',
                            group_dcond_irr=group_dcond_irr,
                            dt=dt,
                            cmap=cmaps[dim_rel],
                            kw_plot=kw_plot_data
                            )
        plt2.detach_axis('x', np.amin(evTrain_cond_dim[:, dim_rel]),
                         np.amax(evTrain_cond_dim[:, dim_rel]))
        if dim_rel != 0:
            ax.set_ylabel('')
            plt2.box_off(['left'])
            plt.yticks([])

        ax.set_xlabel(consts.DIM_NAMES_LONG[dim_rel].lower() + ' strength')

        if dim_rel == 0:
            ax.set_ylabel('RT (s)')

        if to_plot_internals:
            for ch1 in range(consts.N_CH):
                ch0 = dim_rel
                ax = axs[3 + ch1, dim_rel]
                plt.sca(ax)

                ch_flat = consts.ch_by_dim2ch_flat(np.array([ch0, ch1]))
                model.tnds[ch_flat].plot_p_tnd()
                ax.set_xlabel('')
                ax.set_xticklabels([])
                ax.set_yticks([0, 1])
                if ch0 > 0:
                    ax.set_yticklabels([])

                # Both pieces are raw strings: '\m' is not a valid escape
                # sequence, so the text is unchanged but no longer warns.
                ax.set_ylabel(r"$\mathrm{P}(T^\mathrm{n} \mid"
                              r" \mathbf{z}=[%d,%d])$"
                              % (ch0, ch1))

            ax = axs[5, dim_rel]
            plt.sca(ax)
            if hasattr(model.dtb, 'dtb1ds'):
                model.dtb.dtb1ds[dim_rel].plot_bound(color='k')

    # NOTE(review): applied once after the loop to the last row of axes;
    # the source's nesting was ambiguous - confirm.
    plt2.sameaxes(axs[-1, :consts.N_DIM], xy='y')

    if to_plot_params:
        ax = axs[0, -1]
        plt.sca(ax)
        model.plot_params()

    return axs, rts, hs
def plot_rt(
        parad,
        to_cumsum=False,
        corners_only=False,
        to_add_legend=True,
        use_easiest_only=0,
        plot_kind='rt_mean',
        model_names=('Ser', 'Par'),
        to_normalize_max=False,
        axs=None,
):
    """
    Load the fits of all subjects of a paradigm, pool them across
    subjects, and plot mean RT or RT distributions.

    :param parad: paradigm name; key into consts.SUBJS. Stimulus
        strengths are normalized by their per-dimension maximum when
        parad == 'RT'.
    :param to_cumsum: ('rt_distrib' only) plot cumulative distributions.
    :param corners_only: forwarded to load_fit().
    :param to_add_legend: add a legend to the plot.
    :param use_easiest_only: forwarded to load_fit().
    :param plot_kind: 'rt_mean'|'rt_distrib'
    :param model_names: names of the models to load and plot.
    :param to_normalize_max: ('rt_distrib' only) forwarded to
        fit2d.plot_rt_distrib_pred_data().
    :param axs: ('rt_mean' only) grid of axes to plot into; created if
        None. The 'rt_distrib' branch always uses the axes returned by
        fit2d.plot_rt_distrib_pred_data().
    :return: axs
    :raises ValueError: if plot_kind is not one of the supported kinds
        (raised after the fits have been loaded).
    """
    subjs = consts.SUBJS[parad]
    n_models = len(model_names)

    def empty_cells():
        # Object array [model, subj] holding one loaded array per cell.
        # Builtin `object` replaces np.object, which NumPy 1.24 removed.
        return np.empty([n_models, len(subjs)], dtype=object)

    p_preds = empty_cells()
    p_datas = empty_cells()
    condss = empty_cells()
    ts = empty_cells()

    p_pred_trains = empty_cells()
    p_data_trains = empty_cells()
    p_pred_tests = empty_cells()
    p_data_tests = empty_cells()

    colors = ('red', 'blue')
    normalize_ev = parad == 'RT'

    for i_model, model_name in enumerate(model_names):
        for i_subj, subj in enumerate(subjs):
            (
                p_pred1, p_data1, conds1, t1, cond_ch_incl1,
                p_pred_train, p_data_train,
                p_pred_test, p_data_test,
            ) = load_fit(subj, parad, model_name,
                         corners_only=corners_only,
                         use_easiest_only=use_easiest_only)

            if normalize_ev:
                conds1 = conds1 / np.amax(conds1, axis=1, keepdims=True)

            p_preds[i_model, i_subj], p_datas[i_model, i_subj], \
                condss[i_model, i_subj], ts[i_model, i_subj] \
                = p_pred1, p_data1, conds1, t1

            (
                p_pred_trains[i_model, i_subj],
                p_data_trains[i_model, i_subj],
                p_pred_tests[i_model, i_subj],
                p_data_tests[i_model, i_subj],
            ) = p_pred_train, p_data_train, p_pred_test, p_data_test

    siz0 = list(p_preds[0, 0].shape)

    def cell2array(p):
        # Stack the per-cell arrays into [model, subj, *siz0]
        return np.stack(p.flatten()).reshape(
            [len(model_names), len(subjs)] + siz0)

    p_preds = cell2array(p_preds)
    p_datas = cell2array(p_datas)
    p_pred_trains = cell2array(p_pred_trains)
    p_data_trains = cell2array(p_data_trains)
    p_pred_tests = cell2array(p_pred_tests)
    p_data_tests = cell2array(p_data_tests)

    def pool_subjs(p_datas, p_preds):
        # Weight each subject's normalized prediction by that subject's
        # summed data, then sum predictions and data across subjects.
        n_data_subj = np.sum(p_datas, (2, 3, 4), keepdims=True)

        # P(ch, rt | cond, subj, model)
        p_preds = np2.nan2v(p_preds / p_preds.sum((3, 4), keepdims=True))
        p_pred_avg = np.sum(p_preds * n_data_subj, 1)
        n_data_sum = np.sum(p_datas, 1)
        return p_pred_avg, n_data_sum

    p_pred_avg, n_data_sum = pool_subjs(p_datas, p_preds)
    p_pred_avg_train, n_data_sum_train = pool_subjs(p_data_trains,
                                                    p_pred_trains)
    p_pred_avg_test, n_data_sum_test = pool_subjs(p_data_tests,
                                                  p_pred_tests)

    # Conditions of the first model and subject are used for all.
    ev_cond_dim = np.stack(condss[0, 0], -1)
    dt = 1 / 75

    if plot_kind == 'rt_mean':
        if axs is None:
            axs = plt2.GridAxes(1, 2, top=0.4, left=0.6, right=0.1,
                                bottom=0.65,
                                wspace=0.35, heights=1.7, widths=2.2)

        for i_model, model_name in enumerate(model_names):
            hs = fit2d.plot_fit_combined(
                pAll_cond_rt_chFlat=n_data_sum[i_model],
                evAll_cond_dim=ev_cond_dim,
                pTrain_cond_rt_chFlat=n_data_sum_train[i_model],
                pTest_cond_rt_chFlat=n_data_sum_test[i_model],
                pModel_cond_rt_chFlat=p_pred_avg[i_model],
                dt=dt,
                to_plot_params=False,
                to_plot_internals=False,
                to_plot_choice=False,
                group_dcond_irr=None,
                kw_plot_pred={
                    # 'linestyle': ':',
                    # 'alpha': 0.7,
                    # 'linewidth': 2,
                    # 'linestyle': '--' if model_name == 'Par' else '-',
                },
                kw_plot_data={
                    'markersize': 4,
                },
                axs=axs,
            )[2]

        # NOTE(review): legend is added once, from the handles of the
        # last model plotted; the source's nesting was ambiguous.
        if to_add_legend:
            n_dim = 2
            for dim in range(n_dim):
                plt.sca(axs[0, dim])
                odim = consts.get_odim(dim)
                conds_irr = np.unique(np.abs(ev_cond_dim[:, odim]))
                hs1 = [v[0][0] for v in hs['rt'][dim]]
                # Use the second block of len(conds_irr) handles
                hs1 = hs1[len(conds_irr):(len(conds_irr) * 2)]

                h = fit2d.legend_odim([np.round(v, 3) for v in conds_irr],
                                      hs1, '',
                                      loc='lower left',
                                      bbox_to_anchor=[1., 0.])
                h.set_title(consts.DIM_NAMES_LONG[odim] + '\nstrength')
                plt.setp(h.get_title(), multialignment='center')

        ev_max = np.amax(condss[0, 0], -1)
        for col in [0, 1]:
            plt.sca(axs[-1, col])
            txt = '%s strength' % consts.DIM_NAMES_LONG[col].lower()
            if normalize_ev:
                txt += '\n(a.u.)'
            plt.xlabel(txt)
            xticks = [-ev_max[col], 0, ev_max[col]]
            plt.xticks(xticks, ['%g' % v for v in xticks])

        from matplotlib.ticker import MultipleLocator
        axs[0, 0].yaxis.set_major_locator(MultipleLocator(1))

        print('--')

    elif plot_kind == 'rt_distrib':
        axs, hs = fit2d.plot_rt_distrib_pred_data(
            p_pred_avg, n_data_sum[0], ev_cond_dim, dt,
            to_normalize_max=to_normalize_max,
            to_cumsum=to_cumsum,
            to_skip_zero_trials=True,
            xlim=[0., 5.],
            kw_plot_data={
                'linewidth': 2,
                'linestyle': ':',
                'alpha': 0.75,
            },
            labels=['serial', 'parallel', 'data'],
            colors=colors,
        )
        if to_add_legend:
            plt.sca(axs[0, 0])
            locs = {
                'loc': 'center right',
                'bbox_to_anchor': (1.05, 0., 0., 1.)
            } if to_cumsum else {
                'loc': 'upper right',
                'bbox_to_anchor': (1.05, 1.01)
            }
            plt.legend(**locs, handlelength=0.8, handletextpad=0.5,
                       frameon=False, borderpad=0.)
        print('--')

    else:
        raise ValueError('Unsupported plot_kind: %s' % plot_kind)

    return axs
def main_plot(to_use_easiest_only=0):
    """
    Produce and save the pooled RT figures.

    First a 4 x 4 grid with mean RT, one 1 x 2 panel per paradigm and
    model (serial on the upper row, parallel on the lower row of each
    panel pair), then one cumulative RT distribution figure per paradigm.
    Every figure is saved as .png and .pdf.

    :param to_use_easiest_only: forwarded to plot_rt() as
        use_easiest_only and recorded in the file names.
    """
    # === Mean RT
    kind_mean = 'rt_mean_row_per_model'
    grid = plt2.GridAxes(4, 4, top=0.6, left=0.7, right=0.7, bottom=0.6,
                         wspace=[0.8, 1.6, 0.8], hspace=[0.1, 1.2, 0.1],
                         heights=1.5, widths=1.7)

    panel_letters = ['A', 'B', 'C']
    parads = ['RT', 'unimanual', 'bimanual']
    panel_origins = [(0, 0), (0, 2), (2, 0)]
    model_keys = ['Ser', 'Par']
    model_labels = ['serial', 'parallel']
    n_dim = 2

    for (row0, col0), parad, letter in zip(panel_origins, parads,
                                           panel_letters):
        for i_row, model_key in enumerate(model_keys):
            is_top_row = i_row == 0
            row = row0 + i_row
            panel = grid[row:(row + 1), col0:(col0 + n_dim)]
            plot_rt(parad, to_cumsum=False,
                    corners_only=False,
                    plot_kind='rt_mean',
                    model_names=[model_key],
                    use_easiest_only=to_use_easiest_only,
                    axs=panel,
                    to_add_legend=is_top_row)

            if is_top_row:
                # The upper row shares the x axis labels of the lower row
                for panel_ax in panel.flatten():
                    panel_ax.set_xticklabels([])
                    panel_ax.set_xlabel('')

                panel.suptitle(get_title_parad(parad), pad=0.05,
                               va='bottom', xprop=0.55)
                panel.suptitle(letter, pad=0.05, xprop=-0.05, va='bottom',
                               fontweight='bold')

            panel[0, 0].set_ylabel('Response time (s)\n(%s model)'
                                   % model_labels[i_row])

    # The lower right quadrant has no paradigm assigned
    for unused_ax in grid[2:, 2:].flatten():
        unused_ax.set_visible(False)

    for suffix in ['.png', '.pdf']:
        path = locfile.get_file_fig(kind_mean, {
            'parad': parads,
            'easiest_only': to_use_easiest_only,
        }, ext=suffix)
        localfile.mkdir4file(path)
        plt.savefig(path)
        print('Saved to %s' % path)

    # === RT distrib
    kind_distrib = 'rt_distrib'
    to_normalize_max = False
    for to_cumsum in [True]:  # [False, True]:
        for corners_only in [True]:  # , False]:
            for parad in ['RT', 'unimanual', 'bimanual']:
                distrib_axs = plot_rt(
                    parad,
                    to_cumsum=to_cumsum,
                    corners_only=corners_only,
                    plot_kind=kind_distrib,
                    use_easiest_only=to_use_easiest_only,
                    to_add_legend=parad == 'RT',
                )
                distrib_axs.suptitle(get_title_parad(parad), pad=0.25)

                for suffix in ['.png', '.pdf']:
                    path = locfile.get_file_fig(kind_distrib, {
                        'parad': parad,
                        'cumsum': to_cumsum,
                        'corner': corners_only,
                        'easiest_only': to_use_easiest_only,
                        'nrmmax': to_normalize_max,
                    }, ext=suffix)
                    localfile.mkdir4file(path)
                    plt.savefig(path)
                    print('Saved to %s' % path)

    print('--')
def main_fit(
        subjs=None,
        skip_fit_if_absent=False,
        bufdur=None,
        bufdur_best=0.12,
        bufdurs_sim=None,
        bufdurs_fit=None,
        loss_kind='NLL',
        plot_kind='line_sim_fit',
        fix_post=None,
        seed_sims=(0, ),
        err_kind='std',
        base=10,
):
    """
    Collect test losses of models fit to simulated data for every
    combination of simulated and fitted buffer duration, then plot and
    save a summary figure.

    :param subjs: subjects to include; defaults to consts.SUBJS_VD.
    :param skip_fit_if_absent: forwarded to get_fit_sim().
    :param bufdur: if given, both the simulated and the fitted durations
        become the set {bufdur_best, bufdur}.
    :param bufdur_best: best buffer duration (used when building the
        duration lists and in plot_kind == 'bar_sim_fit').
    :param bufdurs_sim: simulated buffer durations; defaults to bufdurs0
        (then bufdurs_fit is also set to bufdurs0).
    :param bufdurs_fit: fitted buffer durations; derived from
        bufdur_best and bufdurs_sim when None.
    :param loss_kind: 'NLL' or 'BIC'.
    :param plot_kind: 'loss_serpar' | 'imshow_ser_buf_par' |
        'line_sim_fit' | 'bar_sim_fit' | 'bar_ser_buf_par' | 'None'.
        'None' skips saving; any other unrecognized value draws nothing
        but still saves the current figure.
    :param fix_post: forwarded to get_fit_sim().
    :param seed_sims: seeds of the simulations.
    :param err_kind: 'std' or 'sem'; error bars across seeds in
        plot_kind == 'line_sim_fit'.
    :param base: base of the logarithm that losses are expressed in.
    :return: losses[seed, subj, sim_dur, fit_dur], divided by log(base);
        NaN where no fit was available.
    :raises ValueError: for an unsupported loss_kind, or an unsupported
        err_kind when plot_kind == 'line_sim_fit'.
    """
    # if torch.cuda.is_available():
    #     torch.set_default_tensor_type(torch.cuda.FloatTensor)
    torch.set_num_threads(1)
    torch.set_default_dtype(torch.double)

    dict_res_add = {}

    parad = 'VD'
    seed_sims = np.array(seed_sims)

    if subjs is None:
        subjs = consts.SUBJS_VD
    if bufdur is not None:
        bufdurs_sim = list({bufdur_best}.union({bufdur}))
        bufdurs_fit = list({bufdur_best}.union({bufdur}))
    else:
        if bufdurs_sim is None:
            bufdurs_sim = bufdurs0
            # if bufdurs_fit is None:
            bufdurs_fit = bufdurs0
        else:
            if bufdurs_fit is None:
                bufdurs_fit0 = [0., bufdur_best, 1.2]
                if bufdurs_sim[0] in bufdurs_fit0:
                    bufdurs_fit = copy(bufdurs_fit0)
                else:
                    bufdurs_fit = bufdurs_fit0[:2] + bufdurs_sim + [1.2]

    n_dur_sim = len(bufdurs_sim)
    n_dur_fit = len(bufdurs_fit)

    # DEF: losses[seed, subj, simSerDurPar, fitDurs]
    size_all = [len(seed_sims), len(subjs), n_dur_sim, n_dur_fit]

    losses0 = np.zeros(size_all) + np.nan
    ns = np.zeros(size_all) + np.nan
    ks0 = np.zeros(size_all) + np.nan
    for i_seed, seed_sim in enumerate(seed_sims):
        for i_subj, subj in enumerate(subjs):
            for i_fit, bufdur_fit in enumerate(bufdurs_fit):
                for i_sim, bufdur_sim in enumerate(bufdurs_sim):
                    d, dict_fit_sim, dict_subdir_sim = get_fit_sim(
                        subj, seed_sim, bufdur_sim, bufdur_fit,
                        parad=parad,
                        skip_fit_if_absent=skip_fit_if_absent,
                        fix_post=fix_post,
                    )
                    if d is not None:
                        losses0[i_seed, i_subj, i_sim, i_fit] = \
                            d['loss_NLL_test']
                        ns[i_seed, i_subj, i_sim, i_fit] = \
                            d['loss_ndata_test']
                        ks0[i_seed, i_subj, i_sim, i_fit] = \
                            d['loss_nparam']

    ks = copy(ks0)

    if loss_kind == 'BIC':
        losses = ks * np.log(ns) + 2 * losses0

        # REF: Kass, Raftery 1995 https://doi.org/10.2307%2F2291091
        # https://en.wikipedia.org/wiki/Bayesian_information_criterion#Gaussian_special_case
        thres_strong = 10. / np.log(base)
    elif loss_kind == 'NLL':
        losses = copy(losses0)
        thres_strong = np.log(100) / np.log(base)
    else:
        raise ValueError('Unsupported loss_kind: %s' % loss_kind)

    #%%
    losses = losses / np.log(base)

    if plot_kind == 'loss_serpar':
        n_row = len(subjs)
        n_col = 1
        axs = plt2.GridAxes(n_row, n_col, left=1, widths=2, bottom=0.75)
        for i_subj, subj in enumerate(subjs):
            plt.sca(axs[i_subj, 0])

            # First fitted duration = serial, last = parallel
            loss_dur_sim_ser_fit = losses[:, i_subj, :, 0].mean(0)
            loss_dur_sim_par_fit = losses[:, i_subj, :, -1].mean(0)
            dloss_serpar = loss_dur_sim_par_fit - loss_dur_sim_ser_fit

            plt.bar(bufdurs_sim, dloss_serpar, color='k')
            plt.axhline(0, color='k', linestyle='-', linewidth=0.5)

            if i_subj < len(subjs) - 1:
                plt2.box_off(['top', 'right', 'bottom'])
                plt.xticks([])
            else:
                plt2.box_off()
                plt.xticks(np.arange(0, 1.4, 0.2),
                           ['0\nser'] + [''] * 2
                           + ['0.6'] + [''] * 2 + ['1.2\npar'])
            plt2.detach_axis('x', 0, 1.2)
            vdfit.patch_chance_level()

        plt.sca(axs[0, 0])
        # Second piece is raw: '\m' is not a valid escape sequence, so the
        # text is unchanged but no longer triggers a warning.
        plt.title('Support for Serial\n'
                  r'($\mathcal{L}_\mathrm{ser} - \mathcal{L}_\mathrm{par}$)')

        plt.sca(axs[-1, 0])
        plt.xlabel('true buffer duration (s)')

        plt2.rowtitle(consts.SUBJS_VD, axs)

    elif plot_kind == 'imshow_ser_buf_par':
        n_row = len(subjs)
        n_col = 1
        axs = plt2.GridAxes(n_row, n_col, left=1, widths=2, heights=2,
                            bottom=0.75)
        for i_subj, subj in enumerate(subjs):
            plt.sca(axs[i_subj, 0])

            # Loss relative to the fit whose duration matches the simulation
            dloss = losses[:, i_subj, :, :].mean(0)
            dloss = dloss - np.diag(dloss)[:, None]

            plt.imshow(dloss, cmap='bwr')

            cmax = np.amax(np.abs(dloss))
            plt.clim(-cmax, +cmax)

        plt.sca(axs[0, 0])

        plt.sca(axs[-1, 0])
        plt.xlabel('true buffer duration (s)')
        plt.ylabel('model buffer duration (s)')

        plt2.rowtitle(consts.SUBJS_VD, axs)

    elif plot_kind == 'line_sim_fit':
        # Fail early with a clear message instead of a NameError later.
        if err_kind == 'std':
            err_fun = np.std
        elif err_kind == 'sem':
            err_fun = np2.sem
        else:
            raise ValueError('Unsupported err_kind: %s' % err_kind)

        bufdur_bests = np.array([get_bufdur_best(subj) for subj in subjs])
        i_bests = np.array([
            list(bufdurs_fit).index(b)
            for b in bufdur_bests
        ])
        i_subjs = np.arange(len(subjs))
        loss_sim_best_fit_best = losses[:, i_subjs, i_bests, i_bests]
        loss_sim_best_fit_rest = (losses[:, i_subjs, i_bests, :]
                                  - loss_sim_best_fit_best[:, :, None])
        mean_loss_sim_best_fit_rest = np.mean(loss_sim_best_fit_rest, 0)
        err_loss_sim_best_fit_rest = err_fun(loss_sim_best_fit_rest, 0)

        n_dur = len(bufdurs_fit)
        loss_sim_rest_fit_best = np.swapaxes(
            np.stack([
                losses[:, i_subj, np.arange(n_dur), i_best]
                - losses[:, i_subj, np.arange(n_dur), np.arange(n_dur)]
                for i_subj, i_best in zip(i_subjs, i_bests)
            ]), 0, 1)
        mean_loss_sim_rest_fit_best = np.mean(loss_sim_rest_fit_best, 0)
        err_loss_sim_rest_fit_best = err_fun(loss_sim_rest_fit_best, 0)

        dict_res_add['err'] = err_kind

        n_row = 2
        n_col = len(subjs)

        axs = plt2.GridAxes(n_row, n_col, left=1.15, right=0.1,
                            widths=2, heights=1.5,
                            wspace=0.35, hspace=0.5,
                            top=0.25, bottom=0.6)
        for row, (m, s, xlabel, ylabel) in enumerate([
            (mean_loss_sim_best_fit_rest, err_loss_sim_best_fit_rest,
             'model buffer capacity (s)',
             (r'$-\mathrm{log}_{%g}\mathrm{BF}$ given simulated'
              + '\nbest duration data') % base),
            (mean_loss_sim_rest_fit_best, err_loss_sim_rest_fit_best,
             'simulated data buffer capacity (s)',
             (r'$-\mathrm{log}_{%g}\mathrm{BF}$ of the'
              + '\nbest duration model') % base)
        ]):
            for i_subj, subj in enumerate(subjs):
                ax = axs[row, i_subj]
                plt.sca(ax)
                gs1 = axs.gs[row * 2 + 1, i_subj * 2 + 1]
                bax = vdfit.breakaxis(gs1)

                bax.axs[1].errorbar(bufdurs0, m[i_subj, :],
                                    yerr=s[i_subj, :],
                                    color='k', marker='.',
                                    linewidth=0.75, elinewidth=0.5,
                                    markersize=3)

                # bax.axs[0] shows only the first three durations
                m1 = copy(m[i_subj, :])
                m1[3:] = np.nan
                s1 = copy(s[i_subj, :])
                s1[3:] = np.nan
                bax.axs[0].errorbar(bufdurs0, m1,
                                    yerr=s1,
                                    color='k', marker='.',
                                    linewidth=0.75, elinewidth=0.5,
                                    markersize=3)

                ax1 = bax.axs[1]  # type: plt.Axes
                plt.sca(ax1)
                vdfit.patch_chance_level(level=thres_strong,
                                         signs=[-1, 1])
                plt.axhline(0, color='k', linestyle='--', linewidth=0.5)
                vdfit.beautify_ticks(ax1, )
                vdfit.beautify(ax1)
                ax1.set_yticks([0, 20])
                if i_subj > 0:
                    ax1.set_yticklabels([])
                if row == 0:
                    ax1.set_xticklabels([])

                ax1 = bax.axs[0]  # type: plt.Axes
                plt.sca(ax1)
                ax1.set_yticks([40, 200])
                if i_subj > 0:
                    ax1.set_yticklabels([])

                plt.sca(ax)
                plt2.box_off('all')

                if row == 0:
                    plt.title(consts.SUBJS['VD'][i_subj])

                if i_subj == 0:
                    plt.xlabel(xlabel, labelpad=8 if row == 0 else 20)
                    plt.ylabel(ylabel, labelpad=30)

                plt2.sameaxes(bax.axs, xy='x')

    elif plot_kind == 'bar_sim_fit':
        i_best = bufdurs_fit.index(bufdur_best)
        loss_sim_best_fit_best = losses[0, :, i_best, i_best]
        loss_sim_best_fit_rest = np2.nan2v(
            losses[0, :, i_best, :]
            - loss_sim_best_fit_best[:, None])
        n_dur = len(bufdurs_fit)
        loss_sim_rest_fit_best = np2.nan2v(
            losses[0, :, np.arange(n_dur), i_best]
            - losses[0, :, np.arange(n_dur), np.arange(n_dur)]).T

        n_row = 2
        n_col = len(subjs)

        axs = plt2.GridAxes(n_row, n_col, left=1, widths=3, right=0.25,
                            heights=1, wspace=0.3, hspace=0.6,
                            top=0.25, bottom=0.6)
        for i_subj, subj in enumerate(subjs):
            def plot_bars(gs1, bufdurs, losses1, add_xticklabel=True):
                # Bar plot with broken x and y axes; the x axis breaks
                # after the last duration below 0.3.
                i_break = np.amax(np.nonzero(np.array(bufdurs) < 0.3)[0])
                bax = brokenaxes(
                    subplot_spec=gs1,
                    xlims=((-1, i_break + 0.5),
                           (i_break + 0.5, len(bufdurs) - 0.5)),
                    ylims=((-3, 20), (20, 1250 / 5)),
                    height_ratios=(50 / 100, 500 / (1250 - 100)),
                    hspace=0.15, wspace=0.075,
                    d=0.005,
                )
                bax.bar(np.arange(len(bufdurs)), losses1[i_subj, :],
                        color='k')

                ax11 = bax.axs[3]  # type: plt.Axes
                ax11.set_xticks([bufdurs.index(0.6), bufdurs.index(1.2)])
                if i_subj == 0 and add_xticklabel:
                    ax11.set_xticklabels(['0.6', '1.2'])
                else:
                    ax11.set_xticklabels([])

                ax00 = bax.axs[0]  # type: plt.Axes
                ax00.set_yticks([500, 1000])

                ax10 = bax.axs[2]  # type: plt.Axes
                ax10.set_yticks([0, 50])
                plt.sca(ax10)
                plt2.detach_axis('x', amin=-0.4, amax=i_break + 0.5)

                for ax in [ax10, ax11]:
                    plt.sca(ax)
                    plt.axhline(0, linewidth=0.5, color='k',
                                linestyle='--')
                    for sign in [-1, 1]:
                        plt.axhline(sign * thres_strong, linewidth=0.5,
                                    color='silver',
                                    linestyle='--')

                ax10.set_xticks([bufdurs.index(0.), bufdurs.index(0.2)])
                if i_subj == 0:
                    if add_xticklabel:
                        ax10.set_xticklabels(['0', '0.2'])
                    else:
                        ax10.set_xticklabels([])
                else:
                    ax10.set_yticklabels([])
                    ax10.set_xticklabels([])
                    ax00.set_yticklabels([])

                return bax

            bax = plot_bars(axs.gs[1, i_subj * 2 + 1],
                            bufdurs_fit, loss_sim_best_fit_rest,
                            add_xticklabel=False)
            ax = axs[0, i_subj]
            plt.sca(ax)
            plt2.box_off('all')
            plt.title(consts.SUBJS['VD'][consts.SUBJS['VD'].index(subj)])
            if i_subj == 0:
                ax.set_ylabel(
                    'misfit to simulated\nbest duration data\n'
                    r'($\Delta$BIC)', labelpad=35)
                ax.set_xlabel('model buffer duration (s)', labelpad=8)

            plot_bars(axs.gs[3, i_subj * 2 + 1],
                      bufdurs_sim, loss_sim_rest_fit_best,
                      add_xticklabel=True)
            ax = axs[1, i_subj]
            plt.sca(ax)
            plt2.box_off('all')
            if i_subj == 0:
                ax.set_ylabel(
                    'misfit of\nbest duration model\n'
                    r'($\Delta$BIC)', labelpad=35)
                ax.set_xlabel('simulated data buffer duration (s)',
                              labelpad=20)

    elif plot_kind == 'bar_ser_buf_par':
        n_row = len(subjs)
        n_col = 1
        axs = plt2.GridAxes(n_row * n_dur_sim, n_col,
                            left=1.25, widths=1.5, right=0.25,
                            heights=0.4,
                            hspace=[0.15] * (n_dur_sim - 1) + [0.5]
                            + [0.15] * (n_dur_sim - 1),
                            top=0.6, bottom=0.5)
        row = -1
        for i_subj, subj in enumerate(subjs):
            for i_sim in range(n_dur_sim):
                row += 1
                plt.sca(axs[row, 0])

                loss1 = losses[:, i_subj, i_sim, :].mean(0)
                dloss = loss1 - loss1[i_sim]

                x = np.arange(n_dur_fit)
                for x1, dloss1 in zip(x, dloss):
                    plt.bar(x1, dloss1,
                            color='r' if dloss1 > 0 else 'b')
                plt.axhline(0, color='k', linewidth=0.5)
                vdfit.patch_chance_level(6.)
                plt.ylim([-100, 100])
                # One (empty) label per tick: recent matplotlib rejects a
                # label list shorter than the tick list.
                plt.yticks([-100, 0, 100], [''] * 3)
                plt.ylabel('%g' % bufdurs_sim[i_sim],
                           rotation=0, va='center', ha='right')

                if i_sim == 0:
                    plt.title(subj)

                if i_sim < n_dur_sim - 1 or i_subj < len(subjs) - 1:
                    plt2.box_off(['top', 'bottom', 'right'])
                    plt.xticks([])
                else:
                    plt2.box_off(['top', 'right'])

        plt.sca(axs[-1, 0])
        plt.xticks(np.arange(n_dur_sim), ['%g' % v for v in bufdurs_sim])
        plt2.detach_axis('x', 0, n_dur_sim - 1)
        plt.xlabel('model buffer duration (s)', fontsize=10)

        c = axs[-1, 0].get_position().corners()
        plt.figtext(x=0.15, y=np.mean(c[:2, 1]), fontsize=10,
                    s='true\nbuffer\nduration\n(s)',
                    rotation=0, va='center', ha='center')

        plt.figtext(
            x=(1.25 + 1.5 / 2) / (1.25 + 1.5 + 0.25), y=0.98,
            s=r'$\mathrm{BIC} - \mathrm{BIC}_\mathrm{true}$',
            ha='center', va='top', fontsize=12,
        )

    if plot_kind != 'None':
        # dict_fit_sim / dict_subdir_sim come from the last get_fit_sim()
        # call above; they are undefined if no combination was iterated.
        dict_res = deepcopy(dict_fit_sim)  # noqa
        for k in ['sbj', 'prd']:
            dict_res.pop(k)
        dict_res = {**dict_res, 'los': loss_kind}
        dict_res.update(dict_res_add)
        for ext in ['.png', '.pdf']:
            file = locfile.get_file_fig(plot_kind, dict_res,
                                        subdir=dict_subdir_sim,  # noqa
                                        ext=ext)
            plt.savefig(file, dpi=300)
            print('Saved to %s' % file)

    #%%
    print('--')
    return losses
def main(base=10.):
    """
    Plot the per-subject cost difference (dcost) between models for the
    original data and for the two model-recovery simulations, save the
    figures, and write the mean and SEM of dcost to a CSV file.

    Reads the module-level file_model_comp, files_model_recovery,
    locfile_in, locfile_out and use_easiest_only.

    :param base: base of the logarithm in which dcost is expressed
        (the loaded dcost is divided by log(base)).
    """
    ds = load_comp(os.path.join(locfile_in.pth_root, file_model_comp))
    ds['dcost'] = np.array(ds['dcost']) / np.log(base)

    vmax = 160
    m = np.empty(3)
    e = np.empty(3)

    # First call only creates default axes whose sizes are reused below.
    axs = plot_bar_dloss_across_subjs(dlosses=np.array(ds['dcost']),
                                      subj_parad_bis=ds['subj_parad_bi'],
                                      vmax=vmax)
    axs = plt2.GridAxes(nrows=1, ncols=3,
                        heights=axs.heights, widths=[2],
                        left=axs.left, right=axs.right,
                        bottom=axs.bottom)
    plot_bar_dloss_across_subjs(
        dlosses=np.array(ds['dcost']),
        subj_parad_bis=ds['subj_parad_bi'],
        vmax=vmax,
        axs=axs[:, [0]],
        base=base,
    )
    plt.title('Data')
    m[0] = np.mean(ds['dcost'])
    e[0] = np2.sem(ds['dcost'])
    print('--')

    titles = ['Simulated\nSerial', 'Simulated\nParallel']
    for i in range(2):
        ds = load_comp(
            os.path.join(locfile_in.pth_root, files_model_recovery[i]))
        ds['dcost'] = np.array(ds['dcost']) / np.log(base)

        plot_bar_dloss_across_subjs(
            dlosses=ds['dcost'],
            subj_parad_bis=ds['subj_parad_bi'],
            vmax=vmax,
            axs=axs[:, [i + 1]],
            base=base,
        )
        plt.title(titles[i])
        m[i + 1] = np.mean(ds['dcost'])
        e[i + 1] = np2.sem(ds['dcost'])

        # Only the leftmost panel keeps its y axis
        plt.sca(axs[0, i + 1])
        plt2.box_off(['left'])
        plt.yticks([])

    for ext in ['.pdf', '.png']:
        file = locfile_out.get_file_fig('model_comp_recovery',
                                        {'easiest_only': use_easiest_only},
                                        ext=ext)
        plt.savefig(file, dpi=300)
        print('Saved to %s' % file)

    # --- Print mean +- SEM to CSV
    csv_file = locfile_out.get_file_csv('model_comp_recovery',
                                        {'easiest_only': use_easiest_only})
    d_csv = np2.dictlist2listdict({
        'data': ['original', 'simulated serial', 'simulated parallel'],
        'mean_dcost': m,
        'sem_dcost': e,
    })
    # newline='' lets the csv module control line endings; without it an
    # extra blank line follows every row on platforms that write '\r\n'.
    with open(csv_file, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=d_csv[0].keys())
        writer.writeheader()
        for d in d_csv:
            writer.writerow(d)
    print('Wrote to %s' % csv_file)

    # --- Mean NLL within each
    axs = plt2.GridAxes(
        1, 1,
        left=0.85, right=0.25,
        heights=[1.5], top=0.1, bottom=0.9,
    )
    plt.sca(axs[0, 0])
    plt.barh(np.arange(3), m, xerr=e, color='w', edgecolor='k')
    plt.yticks(np.arange(3),
               ['data', 'simulated\nserial', 'simulated\nparallel'])
    plt2.box_off(['top', 'right'])
    vmax = np.amax(np.abs(m) + np.abs(e))
    plt.xlim(np.array([-vmax, vmax]) * 1.1)
    plt2.detach_axis(xy='y', amin=0, amax=2)
    plt2.detach_axis(xy='x', amin=-vmax, amax=vmax)
    axvline_dcost()
    xticks_serial_vs_parallel(vmax, base)

    for ext in ['.pdf', '.png']:
        file = locfile_out.get_file_fig('model_comp_vs_recovery',
                                        {'easiest_only': use_easiest_only},
                                        ext=ext)
        plt.savefig(file, dpi=300)
        print('Saved to %s' % file)

    print('--')
def plot_bar_dloss_across_subjs(
        dlosses, elosses=None, ix_datas=None,
        subj_parad_bis: Iterable[Tuple[str, str, bool]] = None,
        axs: Union[plt2.GridAxes, plt2.AxesArray] = None,
        vmax=None,
        add_scale=True,
        base=10.,
):
    """
    Draw one horizontal bar per (subject, paradigm) entry showing its
    loss difference, ordered as: 'RT' paradigm entries, then the
    remaining non-binary entries with non-bimanual and bimanual
    interleaved, then 'binary' entries.

    :param dlosses: [ix_data] loss differences, in the same order as
        subj_parad_bis.
    :param elosses: [ix_data] optional error bar sizes, in the same
        order as dlosses; zeros if None.
    :param ix_datas: unused.
    :param subj_parad_bis: [('subj', 'parad', is_bimanual), ...];
        defaults to the module-level subj_parad_bis0.
    :param axs: axes to draw into (axs[0, 0] is used); created if None.
    :param vmax: magnitude at which bars are visually cut; the x limits
        are +-1.2 * vmax. Defaults to the largest absolute dlosses.
    :param add_scale: kept for backward compatibility.
        NOTE(review): the only statement visibly tied to this flag
        computed an unused value; confirm whether the zero line was
        also meant to depend on it.
    :param base: forwarded to xticks_serial_vs_parallel().
    :return: axs
    :raises ValueError: if a subject name has no trailing number to
        sort by.
    """
    import re

    if subj_parad_bis is None:
        subj_parad_bis = subj_parad_bis0
    if not isinstance(subj_parad_bis, (np.ndarray, list, tuple)):
        # Materialize generic iterables: they are traversed and indexed.
        subj_parad_bis = list(subj_parad_bis)
    if vmax is None:
        vmax = np.amax(np.abs(dlosses))

    # order: eye S1-S3, hand by ID, paired uni-bimanual
    subjs, parads, bis = zip(*subj_parad_bis)
    subjs = np.array(
        ['ID0' + v[-1] if v[:2] == 'ID' and len(v) == 3 else v
         for v in subjs])
    parads = np.array(parads)
    bis = np.array(bis)
    is_eye = parads == 'RT'
    is_bin = parads == 'binary'

    ix = np.arange(len(subjs))

    def subj_number(subj):
        # Trailing number of the subject name. Equals int(subj[1:]) for
        # names like 'S1' and also handles the 'ID0x' names built above.
        found = re.search(r'\d+$', subj)
        if found is None:
            raise ValueError('Cannot parse subject number from %r' % subj)
        return int(found.group())

    def filt_sort(filt):
        # Indices of the selected entries, sorted by subject number
        ind = [subj_number(subj) for subj in subjs[filt]]
        return ix[filt][np.argsort(ind)]

    ix = np.concatenate([
        filt_sort(is_eye & ~is_bin),
        np.stack([
            filt_sort(~is_eye & ~bis & ~is_bin),
            filt_sort(~is_eye & bis & ~is_bin)
        ], -1).flatten('C'),
        filt_sort(is_bin)
    ])
    subjs = subjs[ix]
    parads = parads[ix]
    bis = bis[ix]
    is_eye = is_eye[ix]
    dlosses = np.asarray(dlosses)[ix]
    if elosses is not None:
        # Keep the error bars aligned with the reordered bars.
        elosses = np.asarray(elosses)[ix]
    if isinstance(subj_parad_bis, np.ndarray):
        subj_parad_bis = subj_parad_bis[ix]
    else:
        # Plain sequences cannot be indexed with an index array.
        subj_parad_bis = [subj_parad_bis[i] for i in ix]

    n_eye = int(np.sum(is_eye))
    n_hand = int(np.sum(~is_eye))

    # Bar positions: unit spacing for eye entries, then pairs of hand
    # entries with a wider gap (1.5) before each pair.
    y = np.empty([n_eye + n_hand])
    y[is_eye] = 1.5 + np.arange(n_eye)
    y[~is_eye] = n_eye - 1 + 1.5 + np.cumsum([1.5, 1.] * (n_hand // 2))
    y_max = np.amax(y) + 1.5
    if axs is None:
        axs = plt2.GridAxes(nrows=1, ncols=1,
                            heights=y_max * 0.2, widths=2,
                            left=1.5, right=0.25,
                            bottom=0.85)
    ax = axs[0, 0]
    plt.sca(ax)

    m = dlosses
    if elosses is None:
        e = np.zeros_like(m)
    else:
        e = elosses

    for y1, m1, e1, parad1, bi1 in zip(y, m, e, parads, bis):
        plt.barh(y1, m1, xerr=e1,
                 color=colors_parad[(parad1, '%s' % bi1)],
                 edgecolor='None')

    axvline_dcost()

    x_lim = [-vmax * 1.2, vmax * 1.2]

    # Mark bars that extend beyond vmax with a wave patch at both limits
    for y1, m1 in zip(y, m):
        if np.abs(m1) > vmax:
            for i_sign, sign in enumerate([1, -1]):
                plt2.patch_wave(
                    y1, x_lim[i_sign] * 1.01,
                    ax=ax, color='w',
                    wave_margin=0.15,
                    wave_amplitude=sign * 0.025,
                )

    plt.xlim(x_lim)
    xticks_serial_vs_parallel(vmax, base)

    subj_parad_bi_str = get_subj_parad_bi_str(subj_parad_bis)
    plt.yticks(y, subj_parad_bi_str)
    plt2.detach_axis('y', y[0], y[-1])
    plt2.detach_axis('x', -vmax, vmax)
    # Reversed limits: the first entry is drawn at the top
    plt.ylim([y_max - 1, 1.])
    return axs
def main_plot_models_on_columns(i_subjs=(0, 1), axs=None, coef='slope'):
    """
    Save one figure per subject with one column per model.

    Each figure has two rows (labelled motion and color) and three
    columns ('buffer+serial', 'parallel', 'serial'); every panel is drawn
    by ``plot_coefs_dur_odif_pred_data`` from a fit loaded with
    ``main_fit(..., fit_mode='load')``. Figures are written as .png and
    .pdf to the path given by ``locfile.get_file_fig``.

    Goodness-of-fit tables are no longer read here: their values were
    neither plotted nor returned by this function.

    :param i_subjs: subject indices; each one is passed to ``main_fit``
        and used to look up the name in ``consts.SUBJS['VD']``
    :param axs: ignored - a new ``plt2.GridAxes`` is created for every
        subject; kept so that existing callers keep working
    :param coef: 'slope' (plots coefficient 1) or 'bias' (coefficient 0)
    :return: models, data - ``models[i_model, i]`` is the model loaded for
        the i-th entry of ``i_subjs``; ``data`` is the data of the last
        subject (None when ``i_subjs`` is empty)
    :raises ValueError: if ``coef`` is neither 'slope' nor 'bias'
    """
    if coef not in ('slope', 'bias'):
        raise ValueError("coef must be 'slope' or 'bias', got %r" % (coef,))

    i_subjs = list(i_subjs)
    n_subj = len(i_subjs)

    # ---- Load coefs
    fixs = [
        ('buffer+serial', [('buffix', bufdur_best)]),
        ('parallel', [('buffix', 1.2)]),
        ('serial', [('buffix', 0.)]),
    ]
    model_names = [name for name, _ in fixs]
    n_models = len(fixs)

    # dtype=object: the np.object alias was removed from NumPy.
    models = np.empty([n_models, n_subj], dtype=object)
    dict_caches = np.empty([n_models, n_subj], dtype=object)
    datas = np.empty([n_subj], dtype=object)

    subdir = ','.join(fix0) + '+buffix=' + ','.join(
        [('%1.2f' % v) for v in [0., bufdur_best, 1.2]])

    data = None
    # ii_subj indexes the arrays; i_subj is the subject index for main_fit.
    for ii_subj, i_subj in enumerate(i_subjs):
        for i_model, (_, fix1) in enumerate(fixs):
            fix = fix0_pre + fix1 + fix0_post
            model, data, dict_cache, _, _ = main_fit(
                i_subj, fit_mode='load', fix=fix)
            models[i_model, ii_subj] = model
            datas[ii_subj] = data
            dict_caches[i_model, ii_subj] = dict_cache

    # ---- Plot coefs
    kw_plots = {
        'serial': {'linestyle': '-'},
        'buffer+serial': {'linestyle': '-'},
        'parallel': {'linestyle': '-'},
    }
    title_model = {
        'buffer+serial': 'parallel followed\nby serial',
        'parallel': 'strictly parallel',
        'serial': 'strictly serial'
    }
    models_in_columns = ['buffer+serial', 'parallel', 'serial']
    rowtitles = [
        r'Motion sensitivity ($\beta$)',
        r'Color sensitivity ($\beta$)',
    ] if coef == 'slope' else [
        r'Motion bias ($\beta$)',
        r'Color bias ($\beta$)',
    ]
    n_row = 2
    n_col = n_models

    def beautify_panel(row, col, axs):
        """Apply x-limits, axis styling and ticks to panel (row, col)."""
        plt.sca(axs[row, col])
        plt.xlim(xmax=1.2 * 1.05)
        plt2.detach_axis('x', 0, 1.2)
        plt2.box_off()
        # Tick labels only on the bottom-left panel.
        beautify_ticks(axs[row, col],
                       add_ticklabel=(row == n_row - 1) and (col == 0))
        if col > 0:
            plt.ylabel('')

    for ii_subj, i_subj in enumerate(i_subjs):
        plt.figure(figsize=[2.25 * n_models, 5.25])
        axs = plt2.GridAxes(n_row, n_col, top=.7, left=1.15, right=0.1,
                            bottom=.5, widths=[2], wspace=0.35,
                            heights=[1.5], hspace=[0.2])
        data = datas[ii_subj]

        # ---- Plot slopes
        for ii_model, model_name in enumerate(models_in_columns):
            i_model = model_names.index(model_name)
            plot_coefs_dur_odif_pred_data(
                data, models[i_model, ii_subj],
                axs=axs[:2, :][:, [ii_model]],
                to_plot_data=True,
                kw_plot_model=kw_plots[model_name],
                coefs_to_plot=[1 if coef == 'slope' else 0],
                add_rowtitle=False)

            for row in range(2):
                plt.sca(axs[row, ii_model])
                plt2.sameaxes(axs[row, :], xy='y')
                beautify_panel(row, ii_model, axs)
                if ii_model > 0:
                    plt.gca().set_yticklabels([])

                # NOTE(review): nesting reconstructed from a
                # whitespace-collapsed source - model titles on the top
                # row, x label only on the bottom-left panel; confirm.
                if row == 0:
                    plt.title(title_model[model_name])
                    plt.xlabel('')
                else:
                    plt.title('')
                    if ii_model == 0:
                        plt.xlabel('stimulus duration (s)')
                    else:
                        plt.xlabel('')

        # Center the subject name above the second column.
        bnd = axs[0, 1].get_position().bounds
        subj = consts.SUBJS['VD'][i_subj]
        plt.suptitle(subj, x=bnd[0] + bnd[2] / 2)

        for row, title in enumerate(rowtitles):
            plt.sca(axs[row, 0])
            plt.ylabel(title)

        # The first cache dict is reused (and modified in place) to build
        # the file name of every subject's figure.
        dict_cache = dict_caches[0, 0]
        dict_cache.pop('fix', None)
        dict_cache['sbj'] = subj
        dict_cache['coef'] = coef
        for ext in ['.png', '.pdf']:
            file = locfile.get_file_fig('coefs_dur_odif', dict_cache,
                                        subdir=subdir, ext=ext)
            plt.savefig(file, dpi=300)
            print('Saved to %s' % file)
        print('--')

    return models, data
def main_plot_across_models(
        i_subjs=i_subjs,
        axs=None,
        models_incl=('buffer+serial', ),
        base=10,
        **kwargs,
):
    """
    Plot coefficients and relative goodness-of-fit, one column per subject.

    The top two rows are drawn by ``plot_coefs_dur_odif_pred_data``
    (labelled motion and color sensitivity). The bottom row shows, on a
    broken axis, the test NLL for each buffer duration in ``bufdurs0``
    relative to the subject's best duration, expressed in log-``base``
    units. The figure is saved as .pdf and .png.

    :param i_subjs: subject indices; each one is passed to ``main_fit``
        and used to look up the name in ``consts.SUBJS['VD']``. The i-th
        entry is drawn in column i of ``axs``.
    :param axs: axes grid to draw into (must provide ``shape`` and
        ``gs``); a 3 x 2 ``plt2.GridAxes`` is created when None
    :param models_incl: names of the models to plot; each must be one of
        the models loaded here (currently only 'buffer+serial')
    :param base: base of the logarithm used to scale the NLL differences
    :param kwargs: forwarded to ``main_fit``
    :return: gofs, models, data - ``gofs[i_bufdur, i]`` and
        ``models[i_model, i]`` refer to the i-th entry of ``i_subjs``;
        ``data`` is the data of the last subject
    :raises ValueError: if a goodness-of-fit table exists but has no
        'loss_NLL_test' row
    """
    import csv
    import os

    i_subjs = list(i_subjs)
    n_subj = len(i_subjs)

    # ---- Load goodness-of-fit
    def load_gof(i_subj, buffix_dur):
        """Return the test NLL saved for one fit, or NaN if not saved."""
        fix = fix0_pre + [('buffix', buffix_dur)] + fix0_post
        dict_cache, subdir = main_fit(i_subj, fit_mode='dict_cache_only',
                                      fix=fix, **kwargs)
        file = locfile.get_file('tab', 'best_loss', dict_cache,
                                ext='.csv', subdir=subdir)
        if not os.path.exists(file):
            print('File absent - returning NaN: %s' % file)
            return np.nan

        gof_kind = 'loss_NLL_test'
        with open(file, 'r') as csvfile:
            for row in csv.DictReader(csvfile):
                if row['name'] == gof_kind:
                    # The column header includes a leading space, and the
                    # first three characters of the cell are skipped
                    # before parsing, as in the saved table format.
                    return float(row[' value'][3:])
        raise ValueError('%r not found in %s' % (gof_kind, file))

    nbufdurs = len(bufdurs0)
    gofs = np.zeros([nbufdurs, n_subj])
    for ii_subj, i_subj in enumerate(i_subjs):
        for idur, bufdur in enumerate(bufdurs0):
            gofs[idur, ii_subj] = load_gof(i_subj, bufdur)
    # Express each subject's losses relative to that subject's best
    # buffer duration, in log-`base` units.
    gofs = gofs - np.nanmin(gofs, 0, keepdims=True)
    gofs = gofs / np.log(base)

    # ---- Load slopes
    fixs = [
        ('buffer+serial', [('buffix', bufdur_best)]),
    ]
    model_names = [name for name, _ in fixs]
    n_models = len(fixs)

    # dtype=object: the np.object alias was removed from NumPy.
    models = np.empty([n_models, n_subj], dtype=object)
    dict_caches = np.empty([n_models, n_subj], dtype=object)
    datas = np.empty([n_subj], dtype=object)

    kw_plots = {
        'serial': {'linestyle': '--'},
        'buffer+serial': {'linestyle': '-'},
        'parallel': {'linestyle': ':'},
    }

    subdir = ','.join(fix0) + '+buffix=' + ','.join(
        [('%1.2f' % v) for v in [0., bufdur_best, 1.2]])

    for ii_subj, i_subj in enumerate(i_subjs):
        for i_model, (_, fix1) in enumerate(fixs):
            fix = fix0_pre + fix1 + fix0_post
            model, data, dict_cache, _, _ = main_fit(
                i_subj, fit_mode='load', fix=fix, **kwargs)
            models[i_model, ii_subj] = model
            datas[ii_subj] = data
            dict_caches[i_model, ii_subj] = dict_cache

    # ---- Plot goodness-of-fit
    if axs is None:
        axs = plt2.GridAxes(3, 2, top=.3, left=1.15, right=0.1, bottom=.5,
                            widths=[2], wspace=0.35,
                            heights=[1.5, 1.5, 1.5], hspace=[0.2, 0.9])
    # Taken from the grid itself so that a caller-supplied `axs` works too.
    n_row = axs.shape[0]

    for ii_subj in range(n_subj):
        ax = axs[-1, ii_subj]
        plt.sca(ax)
        plt2.box_off('all')

        # The broken axis is built on the grid-spec cell addressed below.
        gs1 = axs.gs[-2, ii_subj * 2 + 1]
        bax = breakaxis(gs1)
        ax0 = bax.axs[0]  # type: plt.Axes
        ax1 = bax.axs[1]  # type: plt.Axes

        ax0.plot(bufdurs0[:3], gofs[:3, ii_subj], 'k.-',
                 linewidth=0.75, markersize=4.5)
        ax1.plot(bufdurs0, gofs[:, ii_subj], 'k.-',
                 linewidth=0.75, markersize=4.5)

        plt.sca(ax1)
        patch_chance_level(level=np.log(100.) / np.log(base),
                           signs=[-1, 1])
        plt.axhline(0, color='k', linestyle='--', linewidth=0.5)
        beautify(ax1)
        beautify_ticks(ax1, add_ticklabel=True)
        plt2.sameaxes([ax0, ax1], xy='x')
        ax1.set_yticks([0, 20])
        ax0.set_yticks([40, 200])

        if ii_subj == 0:
            plt.sca(ax1)
            plt.xlabel('buffer capacity (s)')
            plt.sca(ax)
            # NOTE(review): the label says log10 regardless of `base`.
            plt.ylabel(r'$-\mathrm{log}_{10}\mathrm{BF}$', labelpad=27)
        else:
            ax0.set_yticklabels([])
            ax1.set_yticklabels([])

        plt.sca(axs[0, ii_subj])
        beautify_ticks(axs[0, ii_subj], add_ticklabel=False)
        beautify_ticks(axs[1, ii_subj], add_ticklabel=True)

    plt.sca(axs[-2, 0])
    plt.xlabel('stimulus duration (s)')

    # ---- Plot slopes
    for ii_subj in range(n_subj):
        hss = []
        data = datas[ii_subj]
        for model_name in models_incl:
            i_model = model_names.index(model_name)
            model = models[i_model, ii_subj]
            if model is not None:
                _, hs = plot_coefs_dur_odif_pred_data(
                    data, model, axs=axs[:2, [ii_subj]],
                    to_plot_data=True,
                    kw_plot_model=kw_plots[model_name],
                    coefs_to_plot=[1],
                    add_rowtitle=False)
                hss.append(hs)

        # The legend goes on the first column only, and only if something
        # was plotted there.
        if ii_subj == 0 and hss:
            # hss[0] = hs['pred'|'data'][coef, dim, dif]
            hs1 = hss[0]['data']  # type: np.ndarray
            for dim in range(hs1.shape[1] - 1, -1, -1):
                handles = [hs1[0, dim, dif] for dif in range(hs1.shape[2])]
                odim_name = consts.DIM_NAMES_LONG[1 - dim]
                plt.sca(axs[dim, 0])
                plt.legend(
                    handles, [
                        'weak ' + odim_name.lower(),
                        'strong ' + odim_name.lower()
                    ],
                    bbox_to_anchor=[0., 0., 1., 1.],
                    handletextpad=0.35,
                    handlelength=0.5,
                    labelspacing=0.3,
                    borderpad=0.,
                    frameon=False,
                    loc='upper left')

        for row in range(n_row):
            beautify(axs[row, ii_subj])
            if ii_subj > 0:
                plt.ylabel('')

    plt2.sameaxes(axs[-1, :])
    axs[-1, -1].set_yticklabels([])

    # Subject name above each column; no titles on the lower rows.
    for ii_subj, i_subj in enumerate(i_subjs):
        plt.sca(axs[0, ii_subj])
        plt.title(consts.SUBJS['VD'][i_subj])
        for row in range(1, n_row):
            plt.sca(axs[row, ii_subj])
            plt.title('')

    for row, title in enumerate([
        r'Motion sensitivity ($\beta$)',
        r'Color sensitivity ($\beta$)',
    ]):
        plt.sca(axs[row, 0])
        plt.ylabel(title)

    # The first cache dict is reused (and modified in place) to build the
    # file name, after dropping the 'fix' and 'sbj' entries.
    dict_cache = dict_caches[0, 0]
    for k in ['fix', 'sbj']:
        dict_cache.pop(k, None)

    from lib.pylabyk.cacheutil import mkdir4file
    for ext in ['.pdf', '.png']:
        file = locfile.get_file_fig('coefs_dur_odif_sbjs', dict_cache,
                                    subdir=subdir, ext=ext)
        mkdir4file(file)
        plt.savefig(file, dpi=300)
        print('Saved to %s' % file)
    print('--')

    return gofs, models, data