Esempio n. 1
0
def plot_rt_distrib(
    n_cond_rt_ch: np.ndarray,
    ev_cond_dim: np.ndarray,
    abs_cond=True,
    lump_wrong=True,
    dt=consts.DT,
    colors=None,
    alpha=1.,
    alpha_face=0.5,
    smooth_sigma_sec=0.05,
    to_normalize_max=False,
    to_cumsum=False,
    to_exclude_last_frame=True,
    to_skip_zero_trials=False,
    label='',
    # to_exclude_bins_wo_trials=10,
    kw_plot=(),
    fig=None,
    axs=None,
    to_use_sameaxes=True,
):
    """
    Plot RT distributions on a grid with one panel per (|cond0|, |cond1|).

    Conditions are pooled by absolute value and choices are re-coded as
    correct (1) / wrong (0). Counts are normalized within each condition
    cell. They are then optionally smoothed along time with a Gaussian
    kernel, cumulated, and/or max-normalized. Each curve is multiplied by
    ``consts.ch_bool2sign(ch0)`` and drawn as a line with a filled area.

    :param n_cond_rt_ch: trial counts indexed [cond, frame, ch_flat].
        NOTE(review): the flattening below assumes conditions are ordered
        like ``np.meshgrid(unique(cond0), unique(cond1), indexing='ij')``
        and that the last axis is (ch0, ch1) in C order. Confirm with the
        caller.
    :param ev_cond_dim: [cond, dim] signed condition strength per dimension.
    :param abs_cond: must be True. Conditions are pooled by absolute value.
    :param lump_wrong: if True, a trial counts as correct only when both
        dimensions are correct (zero-strength dimensions always count as
        correct). All trials are then filed under ch1 == 1.
    :param dt: frame duration in seconds.
    :param colors: a single color, or a pair of colors indexed by ch1.
    :param alpha: line alpha.
    :param alpha_face: alpha of the filled area.
    :param smooth_sigma_sec: SD of the Gaussian smoothing kernel in seconds.
        A value <= 0 disables smoothing.
    :param to_normalize_max: divide each cell by its max absolute value.
    :param to_cumsum: plot cumulative sums along time.
    :param to_exclude_last_frame: zero the counts in the last frame.
    :param to_skip_zero_trials: do not draw curves whose total |y| < 1e-2.
    :param label: legend label attached to the (ch0=1, ch1=1) curve.
    :param kw_plot: extra keyword arguments for ``ax.plot``.
    :param fig: unused.
    :param axs: grid of axes indexed [row, col]. Created if None.
    :param to_use_sameaxes: call ``plt2.sameaxes`` on the grid at the end.
    :return: axs, p_cond01__rt_ch01, p_cond01__rt_ch01_sm, h.
        ``h`` is the handle returned by the last ``ax.plot`` call, or None
        if the last curve was skipped or nothing was drawn.
    """
    if colors is None:
        colors = ['red', 'blue']
    elif isinstance(colors, str):
        colors = [colors] * 2
    else:
        assert len(colors) == 2

    nt = n_cond_rt_ch.shape[1]
    t_all = np.arange(nt) * dt + dt  # time at the end of each frame

    # Labels for each element of n_cond_rt_ch.flatten(), in C order:
    # (cond0, cond1, frame, ch0, ch1).
    out = np.meshgrid(np.unique(ev_cond_dim[:, 0]),
                      np.unique(ev_cond_dim[:, 1]),
                      np.arange(nt),
                      np.arange(2),
                      np.arange(2),
                      indexing='ij')
    cond0, cond1, fr, ch0, ch1 = [v.flatten() for v in out]

    # Copy so that zeroing the last frame leaves the caller's array intact.
    from copy import deepcopy
    n0 = deepcopy(n_cond_rt_ch)
    if to_exclude_last_frame:
        n0[:, -1, :] = 0.
    n0 = n0.flatten()

    def sign_cond(v):
        # Like np.sign, but zero-strength conditions count as positive.
        v1 = np.sign(v)
        v1[v == 0] = 1
        return v1

    if not abs_cond:
        raise ValueError('only abs_cond=True is implemented')

    ch0 = consts.ch_bool2sign(ch0)
    ch1 = consts.ch_bool2sign(ch1)

    # 1 = correct, -1 = wrong
    ch0 = sign_cond(cond0) * ch0
    ch1 = sign_cond(cond1) * ch1

    cond0 = np.abs(cond0)
    cond1 = np.abs(cond1)

    # Builtin int instead of np.int, which was removed in NumPy 1.24.
    ch0 = consts.ch_sign2bool(ch0).astype(int)
    ch1 = consts.ch_sign2bool(ch1).astype(int)

    if lump_wrong:
        # treat all choices as correct when cond == 0
        ch00 = ch0 | (cond0 == 0)
        ch10 = ch1 | (cond1 == 0)

        # Correct only if both dimensions are correct; all trials go to ch1=1.
        ch0 = (ch00 & ch10)
        ch1 = np.ones_like(ch00, dtype=int)

    cond_dim = np.stack([cond0, cond1], -1)

    # Index of each |cond| among the sorted unique values of its dimension.
    dcond_dim = np.stack([
        np.unique(cond, return_inverse=True)[1] for cond in cond_dim.T
    ])

    n_cond01_rt_ch01 = npg.aggregate(
        [*dcond_dim, fr, ch0, ch1], n0, 'sum',
        [*(np.amax(dcond_dim, 1) + 1), nt, consts.N_CH, consts.N_CH])

    # Normalize within each (cond0, cond1) cell. Empty cells give 0/0;
    # presumably np2.nan2v maps those NaNs to 0.
    p_cond01__rt_ch01 = np2.nan2v(n_cond01_rt_ch01 / n_cond01_rt_ch01.sum(
        (2, 3, 4), keepdims=True))

    n_conds = p_cond01__rt_ch01.shape[:2]

    if axs is None:
        axs = plt2.GridAxes(
            n_conds[1],
            n_conds[0],
            left=0.6,
            right=0.3,
            bottom=0.45,
            top=0.74,
            widths=[1],
            heights=[1],
            wspace=0.04,
            hspace=0.04,
        )

    kw_label = {
        'fontsize': 12,
    }
    pad = 8
    axs[0, 0].set_title('strong\nmotion', pad=pad, **kw_label)
    axs[0, -1].set_title('weak\nmotion', pad=pad, **kw_label)
    axs[0, 0].set_ylabel('strong\ncolor', labelpad=pad, **kw_label)
    axs[-1, 0].set_ylabel('weak\ncolor', labelpad=pad, **kw_label)

    if smooth_sigma_sec > 0:
        from scipy import signal, stats
        sigma_fr = smooth_sigma_sec / dt
        width = int(np.ceil(sigma_fr * 2.5))
        kernel = stats.norm.pdf(np.arange(-width, width + 1), 0, sigma_fr)
        # Place the kernel on the time axis (axis 2) of a 5-D array.
        kernel = np2.vec_on(kernel, 2, 5)
        p_cond01__rt_ch01_sm = signal.convolve(p_cond01__rt_ch01,
                                               kernel,
                                               mode='same')
    else:
        p_cond01__rt_ch01_sm = p_cond01__rt_ch01.copy()

    if to_cumsum:
        p_cond01__rt_ch01_sm = np.cumsum(p_cond01__rt_ch01_sm, axis=2)

    if to_normalize_max:
        p_cond01__rt_ch01_sm = np2.nan2v(
            p_cond01__rt_ch01_sm /
            np.amax(np.abs(p_cond01__rt_ch01_sm), (2, 3, 4), keepdims=True))

    h = None
    n_row = n_conds[1]
    n_col = n_conds[0]
    for dcond0 in range(n_conds[0]):
        for dcond1 in range(n_conds[1]):
            # Largest |cond| (last unique index) goes to the top-left panel.
            row = n_row - 1 - dcond1
            col = n_col - 1 - dcond0

            ax = axs[row, col]  # type: plt.Axes

            # Separate names so the flat ch0/ch1 label arrays are not shadowed.
            for ich0 in [0, 1]:
                for ich1 in [0, 1]:
                    if lump_wrong and ich1 == 0:
                        continue

                    p1 = p_cond01__rt_ch01_sm[dcond0, dcond1, :, ich0, ich1]

                    kw = {
                        'linewidth': 1,
                        'color': colors[ich1],
                        'alpha': alpha,
                        'zorder': 1,
                        **dict(kw_plot)
                    }

                    y = p1 * consts.ch_bool2sign(ich0)

                    # The returned smoothed array holds the signed curves.
                    p_cond01__rt_ch01_sm[dcond0, dcond1, :, ich0, ich1] = y

                    if to_skip_zero_trials and np.sum(np.abs(y)) < 1e-2:
                        h = None
                    else:
                        h = ax.plot(
                            t_all,
                            y,
                            label=label if ich0 == 1 and ich1 == 1 else None,
                            **kw)
                        ax.fill_between(t_all,
                                        0,
                                        y,
                                        ec='None',
                                        fc=kw['color'],
                                        alpha=alpha_face,
                                        zorder=-1)
                    plt2.box_off(ax=ax)

            ax.axhline(0, color='k', linewidth=0.5)
            ax.set_yticklabels([])
            if row < n_row - 1 or col > 0:
                ax.set_xticklabels([])
                ax.set_xticks([])
                plt2.box_off(['bottom'], ax=ax)
            else:
                ax.set_xlabel('RT (s)')
            # if col > 0:
            ax.set_yticks([])
            ax.set_yticklabels([])
            plt2.box_off(['left'], ax=ax)

            plt2.detach_axis('x', 0, 5, ax=ax)
    if to_use_sameaxes:
        plt2.sameaxes(axs)
    axs[-1, 0].set_xlabel('RT (s)')

    return axs, p_cond01__rt_ch01, p_cond01__rt_ch01_sm, h
Esempio n. 2
0
def plot_fit_combined(
        data: Union[sim2d.Data2DRT, dict] = None,
        pModel_cond_rt_chFlat=None, model=None,
        pModel_dimRel_condDense_chFlat=None,
        # --- in place of data:
        pAll_cond_rt_chFlat=None,
        evAll_cond_dim=None,
        pTrain_cond_rt_chFlat=None,
        evTrain_cond_dim=None,
        pTest_cond_rt_chFlat=None,
        evTest_cond_dim=None,
        dt=None,
        # --- optional
        ev_dimRel_condDense_fr_dim_meanvar=None,
        dt_model=None,
        to_plot_internals=True,
        to_plot_params=True,
        to_plot_choice=True,
        # to_group_irr=False,
        group_dcond_irr=None,
        to_combine_ch_irr_cond=True,
        kw_plot_pred=(),
        kw_plot_pred_ch=(),
        kw_plot_data=(),
        axs=None,
):
    """
    Plot model predictions against training and held-out data, per dimension.

    Each relevant dimension gets its own column. The optional first row
    shows P(choice) vs. evidence; the next row shows mean RT vs. evidence.
    When ``to_plot_internals`` is set, rows 3-5 show the model's
    non-decision-time distributions and bounds. Model curves use style
    'pred', held-out data 'data_pred', and training data 'data_fit'.

    :param data: object providing ``get_data_by_cond`` and ``dt``. If None,
        the ``p*_cond_rt_chFlat`` / ``ev*_cond_dim`` arguments are used
        instead, and missing train/test sets default to the "all" set.
    :param pModel_cond_rt_chFlat: model prediction (presumably indexed
        [cond, rt, chFlat] per its name). Ignored when
        ``ev_dimRel_condDense_fr_dim_meanvar`` is given.
    :param model: fitted model. Required for internals and params.
    :param pModel_dimRel_condDense_chFlat: per-dimension predictions on a
        dense condition grid. Used with
        ``ev_dimRel_condDense_fr_dim_meanvar``.
    :param pAll_cond_rt_chFlat, evAll_cond_dim: all data (when data is None).
    :param pTrain_cond_rt_chFlat, evTrain_cond_dim: training data.
    :param pTest_cond_rt_chFlat, evTest_cond_dim: held-out data.
    :param dt: data frame duration in seconds (taken from ``data`` if given).
    :param ev_dimRel_condDense_fr_dim_meanvar: dense-grid evidence, 4- or 5-D.
    :param dt_model: model frame duration. Defaults to ``model.dt``, or to
        ``dt`` when there is no model.
    :param to_plot_internals: plot non-decision times and bounds.
    :param to_plot_params: plot model parameters in ``axs[0, -1]``.
    :param to_plot_choice: plot P(choice) in row 0.
    :param group_dcond_irr: passed through to the sim2d plotting helpers.
    :param to_combine_ch_irr_cond: pool the model's choice prediction over
        the irrelevant dimension before plotting.
    :param kw_plot_pred: plot kwargs for model curves.
    :param kw_plot_pred_ch: plot kwargs for the pooled model choice curve.
    :param kw_plot_data: plot kwargs for data points.
    :param axs: grid of axes indexed [row, col]. Created if None.
    :return: axs, rts, hs. ``rts`` holds one entry per dimension from
        ``sim2d.plot_rt_vs_ev``. ``hs`` is a dict: ``'rt'`` maps to the
        per-dimension RT plot handles, and ``'ch'`` to the per-dimension
        choice plot handles (filled only when ``to_plot_choice``).
    """
    if data is None:
        if pTrain_cond_rt_chFlat is None:
            pTrain_cond_rt_chFlat = pAll_cond_rt_chFlat
        if evTrain_cond_dim is None:
            evTrain_cond_dim = evAll_cond_dim
        if pTest_cond_rt_chFlat is None:
            pTest_cond_rt_chFlat = pAll_cond_rt_chFlat
        if evTest_cond_dim is None:
            evTest_cond_dim = evAll_cond_dim
    else:
        _, pAll_cond_rt_chFlat, _, _, evAll_cond_dim = \
            data.get_data_by_cond('all')
        _, pTrain_cond_rt_chFlat, _, _, evTrain_cond_dim = data.get_data_by_cond(
            'train_valid', mode_train='easiest')
        _, pTest_cond_rt_chFlat, _, _, evTest_cond_dim = data.get_data_by_cond(
            'test', mode_train='easiest')
        dt = data.dt
    hs = {'rt': [], 'ch': []}

    if model is None:
        assert not to_plot_internals
        assert not to_plot_params

    if dt_model is None:
        if model is None:
            dt_model = dt
        else:
            dt_model = model.dt

    if axs is None:
        if to_plot_internals:
            # Internals use rows 3..5 (axs[3 + ch1], axs[5]), so 6 rows are
            # needed; a 3-row grid would raise IndexError.
            axs = plt2.GridAxes(6, 3)
        elif to_plot_params:
            axs = plt2.GridAxes(3, 3)
        elif to_plot_choice:
            axs = plt2.GridAxes(2, 2)
        else:
            axs = plt2.GridAxes(1, 2)  # TODO: beautify ratios

    # --- data_pred may not have all conditions, so concatenate the rest
    #  of the conditions so that the color scale is correct. Then also
    #  concatenate p_rt_ch_data_pred1 with zeros so that nothing is
    #  plotted in the concatenated. (Loop-invariant, so computed once.)
    evTest_cond_dim1 = np.concatenate([
        evTest_cond_dim, evAll_cond_dim
    ], axis=0)
    pTest_cond_rt_chFlat1 = np.concatenate([
        pTest_cond_rt_chFlat, np.zeros_like(pAll_cond_rt_chFlat)
    ], axis=0)

    rts = []
    for dim_rel in range(consts.N_DIM):
        if ev_dimRel_condDense_fr_dim_meanvar is None:
            evModel_cond_dim = evAll_cond_dim
        else:
            if ev_dimRel_condDense_fr_dim_meanvar.ndim == 5:
                evModel_cond_dim = npy(ev_dimRel_condDense_fr_dim_meanvar[
                                           dim_rel][:, 0, :, 0])
            else:
                assert ev_dimRel_condDense_fr_dim_meanvar.ndim == 4
                evModel_cond_dim = npy(ev_dimRel_condDense_fr_dim_meanvar[
                                           dim_rel][:, 0, :])
            pModel_cond_rt_chFlat = npy(pModel_dimRel_condDense_chFlat[dim_rel])

        if to_plot_choice:
            # --- Plot choice
            ax = axs[0, dim_rel]
            plt.sca(ax)

            if to_combine_ch_irr_cond:
                ev_cond_model1, p_rt_ch_model1 = combine_irr_cond(
                    dim_rel, evModel_cond_dim, pModel_cond_rt_chFlat
                )

                sim2d.plot_p_ch_vs_ev(ev_cond_model1, p_rt_ch_model1,
                                      dim_rel=dim_rel, style='pred',
                                      group_dcond_irr=None,
                                      kw_plot=kw_plot_pred_ch,
                                      cmap=lambda n: lambda v: [0., 0., 0.],
                                      )
            else:
                sim2d.plot_p_ch_vs_ev(evModel_cond_dim, pModel_cond_rt_chFlat,
                                      dim_rel=dim_rel, style='pred',
                                      group_dcond_irr=group_dcond_irr,
                                      kw_plot=kw_plot_pred,
                                      cmap=cmaps[dim_rel]
                                      )
            # Keep these handles separate from the ``hs`` result dict;
            # rebinding ``hs`` here made ``hs['rt'].append`` below fail.
            hs_ch, conds_irr = sim2d.plot_p_ch_vs_ev(
                evTest_cond_dim1, pTest_cond_rt_chFlat1,
                dim_rel=dim_rel, style='data_pred',
                group_dcond_irr=group_dcond_irr,
                cmap=cmaps[dim_rel],
                kw_plot=kw_plot_data,
            )
            hs['ch'].append(hs_ch)
            hs1 = [h[0] for h in hs_ch]
            odim = 1 - dim_rel
            odim_name = consts.DIM_NAMES_LONG[odim]
            legend_odim(conds_irr, hs1, odim_name)
            sim2d.plot_p_ch_vs_ev(evTrain_cond_dim, pTrain_cond_rt_chFlat,
                                  dim_rel=dim_rel, style='data_fit',
                                  group_dcond_irr=group_dcond_irr,
                                  cmap=cmaps[dim_rel],
                                  kw_plot=kw_plot_data
                                  )
            plt2.detach_axis('x', np.amin(evTrain_cond_dim[:, dim_rel]),
                             np.amax(evTrain_cond_dim[:, dim_rel]))
            ax.set_xlabel('')
            ax.set_xticklabels([])
            if dim_rel != 0:
                plt2.box_off(['left'])
                plt.yticks([])

            ax.set_ylabel('P(%s choice)' % consts.CH_NAMES[dim_rel][1])

        # --- Plot RT (row 1 if the choice row exists, else row 0)
        ax = axs[int(to_plot_choice) + 0, dim_rel]
        plt.sca(ax)
        hs1, rts1 = sim2d.plot_rt_vs_ev(
            evModel_cond_dim,
            pModel_cond_rt_chFlat,
            dim_rel=dim_rel, style='pred',
            group_dcond_irr=group_dcond_irr,
            dt=dt_model,
            kw_plot=kw_plot_pred,
            cmap=cmaps[dim_rel]
        )
        hs['rt'].append(hs1)
        rts.append(rts1)

        sim2d.plot_rt_vs_ev(evTest_cond_dim1, pTest_cond_rt_chFlat1,
                            dim_rel=dim_rel, style='data_pred',
                            group_dcond_irr=group_dcond_irr,
                            dt=dt,
                            cmap=cmaps[dim_rel],
                            kw_plot=kw_plot_data
                            )
        sim2d.plot_rt_vs_ev(evTrain_cond_dim, pTrain_cond_rt_chFlat,
                            dim_rel=dim_rel, style='data_fit',
                            group_dcond_irr=group_dcond_irr,
                            dt=dt,
                            cmap=cmaps[dim_rel],
                            kw_plot=kw_plot_data
                            )
        plt2.detach_axis('x', np.amin(evTrain_cond_dim[:, dim_rel]),
                         np.amax(evTrain_cond_dim[:, dim_rel]))
        if dim_rel != 0:
            ax.set_ylabel('')
            plt2.box_off(['left'])
            plt.yticks([])

        ax.set_xlabel(consts.DIM_NAMES_LONG[dim_rel].lower() + ' strength')

        if dim_rel == 0:
            ax.set_ylabel('RT (s)')

        if to_plot_internals:
            for ch1 in range(consts.N_CH):
                ch0 = dim_rel
                ax = axs[3 + ch1, dim_rel]
                plt.sca(ax)

                ch_flat = consts.ch_by_dim2ch_flat(np.array([ch0, ch1]))
                model.tnds[ch_flat].plot_p_tnd()
                ax.set_xlabel('')
                ax.set_xticklabels([])
                ax.set_yticks([0, 1])
                if ch0 > 0:
                    ax.set_yticklabels([])

                # Both parts are raw strings: "\m" is an invalid escape.
                ax.set_ylabel(r"$\mathrm{P}(T^\mathrm{n} \mid"
                              r" \mathbf{z}=[%d,%d])$"
                              % (ch0, ch1))

            ax = axs[5, dim_rel]
            plt.sca(ax)
            if hasattr(model.dtb, 'dtb1ds'):
                model.dtb.dtb1ds[dim_rel].plot_bound(color='k')

    plt2.sameaxes(axs[-1, :consts.N_DIM], xy='y')

    if to_plot_params:
        ax = axs[0, -1]
        plt.sca(ax)
        model.plot_params()

    return axs, rts, hs
def plot_rt(
        parad,
        to_cumsum=False,
        corners_only=False,
        to_add_legend=True,
        use_easiest_only=0,
        plot_kind='rt_mean',
        model_names=('Ser', 'Par'),
        to_normalize_max=False,
        axs=None,
):
    """
    Load fits for every subject of a paradigm, pool them, and plot RTs.

    :param parad: paradigm name (a key of ``consts.SUBJS``).
        For 'RT', evidence is normalized to a max of 1 per dimension.
    :param to_cumsum: plot cumulative RT distributions ('rt_distrib' only).
    :param corners_only: passed to ``load_fit``.
    :param to_add_legend: add legends.
    :param use_easiest_only: passed to ``load_fit``.
    :param plot_kind: 'rt_mean' (mean RT vs. evidence, one row of axes) or
        'rt_distrib' (RT distributions per condition).
    :param model_names: model names passed to ``load_fit``.
    :param to_normalize_max: normalize each distribution to a max of 1
        ('rt_distrib' only).
    :param axs: axes to draw into ('rt_mean' only). Created if None.
    :return: axs
    :raises ValueError: for an unknown ``plot_kind``.
    """
    subjs = consts.SUBJS[parad]
    n_models = len(model_names)
    siz_cell = [n_models, len(subjs)]

    # dtype=object rather than np.object, which NumPy 1.24 removed.
    p_preds = np.empty(siz_cell, dtype=object)
    p_datas = np.empty(siz_cell, dtype=object)
    condss = np.empty(siz_cell, dtype=object)

    p_pred_trains = np.empty(siz_cell, dtype=object)
    p_data_trains = np.empty(siz_cell, dtype=object)
    p_pred_tests = np.empty(siz_cell, dtype=object)
    p_data_tests = np.empty(siz_cell, dtype=object)

    colors = ('red', 'blue')
    normalize_ev = parad == 'RT'

    for i_model, model_name in enumerate(model_names):
        for i_subj, subj in enumerate(subjs):
            (
                p_pred1,
                p_data1,
                conds1,
                _t1,
                _cond_ch_incl1,
                p_pred_train,
                p_data_train,
                p_pred_test,
                p_data_test,
            ) = load_fit(subj,
                         parad,
                         model_name,
                         corners_only=corners_only,
                         use_easiest_only=use_easiest_only)

            if normalize_ev:
                # conds1 is [dim, cond]: scale each dimension to a max of 1.
                conds1 = conds1 / np.amax(conds1, axis=1, keepdims=True)

            p_preds[i_model, i_subj] = p_pred1
            p_datas[i_model, i_subj] = p_data1
            condss[i_model, i_subj] = conds1

            p_pred_trains[i_model, i_subj] = p_pred_train
            p_data_trains[i_model, i_subj] = p_data_train
            p_pred_tests[i_model, i_subj] = p_pred_test
            p_data_tests[i_model, i_subj] = p_data_test

    siz0 = list(p_preds[0, 0].shape)

    def cell2array(p):
        # Object array of equally-shaped arrays -> [model, subj, *siz0] array.
        return np.stack(p.flatten()).reshape(siz_cell + siz0)

    p_preds = cell2array(p_preds)
    p_datas = cell2array(p_datas)

    p_pred_trains = cell2array(p_pred_trains)
    p_data_trains = cell2array(p_data_trains)

    p_pred_tests = cell2array(p_pred_tests)
    p_data_tests = cell2array(p_data_tests)

    def pool_subjs(p_datas, p_preds):
        # Each subject's prediction is weighted by that subject's total trial
        # count over axes (2, 3, 4). NOTE(review): confirm that a total over
        # all conditions (rather than per condition) is intended here.
        n_data_subj = np.sum(p_datas, (2, 3, 4), keepdims=True)

        # P(ch, rt | cond, subj, model)
        p_preds = np2.nan2v(p_preds / p_preds.sum((3, 4), keepdims=True))
        p_pred_avg = np.sum(p_preds * n_data_subj, 1)
        n_data_sum = np.sum(p_datas, 1)
        return p_pred_avg, n_data_sum

    p_pred_avg, n_data_sum = pool_subjs(p_datas, p_preds)
    p_pred_avg_train, n_data_sum_train = pool_subjs(p_data_trains,
                                                    p_pred_trains)
    p_pred_avg_test, n_data_sum_test = pool_subjs(p_data_tests, p_pred_tests)

    ev_cond_dim = np.stack(condss[0, 0], -1)
    dt = 1 / 75

    if plot_kind == 'rt_mean':
        if axs is None:
            axs = plt2.GridAxes(1,
                                2,
                                top=0.4,
                                left=0.6,
                                right=0.1,
                                bottom=0.65,
                                wspace=0.35,
                                heights=1.7,
                                widths=2.2)
        for i_model in range(n_models):
            hs = fit2d.plot_fit_combined(
                pAll_cond_rt_chFlat=n_data_sum[i_model],
                evAll_cond_dim=ev_cond_dim,
                pTrain_cond_rt_chFlat=n_data_sum_train[i_model],
                pTest_cond_rt_chFlat=n_data_sum_test[i_model],
                pModel_cond_rt_chFlat=p_pred_avg[i_model],
                dt=dt,
                to_plot_params=False,
                to_plot_internals=False,
                to_plot_choice=False,
                group_dcond_irr=None,
                kw_plot_pred={
                    # 'linestyle': ':',
                    # 'alpha': 0.7,
                    # 'linewidth': 2,
                    # 'linestyle': '--' if model_name == 'Par' else '-',
                },
                kw_plot_data={
                    'markersize': 4,
                },
                axs=axs,
            )[2]

            if to_add_legend:
                n_dim = 2
                for dim in range(n_dim):
                    plt.sca(axs[0, dim])
                    odim = consts.get_odim(dim)
                    conds_irr = np.unique(np.abs(ev_cond_dim[:, odim]))
                    hs1 = [v[0][0] for v in hs['rt'][dim]]
                    hs1 = hs1[len(conds_irr):(len(conds_irr) * 2)]

                    h = fit2d.legend_odim([np.round(v, 3) for v in conds_irr],
                                          hs1,
                                          '',
                                          loc='lower left',
                                          bbox_to_anchor=[1., 0.])
                    h.set_title(consts.DIM_NAMES_LONG[odim] + '\nstrength')
                    plt.setp(h.get_title(), multialignment='center')

        ev_max = np.amax(condss[0, 0], -1)
        for col in [0, 1]:
            plt.sca(axs[-1, col])
            txt = '%s strength' % consts.DIM_NAMES_LONG[col].lower()
            if normalize_ev:
                txt += '\n(a.u.)'
            plt.xlabel(txt)
            xticks = [-ev_max[col], 0, ev_max[col]]
            plt.xticks(xticks, ['%g' % v for v in xticks])

        from matplotlib.ticker import MultipleLocator
        axs[0, 0].yaxis.set_major_locator(MultipleLocator(1))
        print('--')

    elif plot_kind == 'rt_distrib':
        axs, hs = fit2d.plot_rt_distrib_pred_data(
            p_pred_avg,
            n_data_sum[0],
            ev_cond_dim,
            dt,
            to_normalize_max=to_normalize_max,
            to_cumsum=to_cumsum,
            to_skip_zero_trials=True,
            xlim=[0., 5.],
            kw_plot_data={
                'linewidth': 2,
                'linestyle': ':',
                'alpha': 0.75,
            },
            labels=['serial', 'parallel', 'data'],
            colors=colors,
        )
        if to_add_legend:
            plt.sca(axs[0, 0])
            locs = {
                'loc': 'center right',
                'bbox_to_anchor': (1.05, 0., 0., 1.)
            } if to_cumsum else {
                'loc': 'upper right',
                'bbox_to_anchor': (1.05, 1.01)
            }
            plt.legend(**locs,
                       handlelength=0.8,
                       handletextpad=0.5,
                       frameon=False,
                       borderpad=0.)
        print('--')

    else:
        raise ValueError('unknown plot_kind: %r' % (plot_kind,))

    return axs
def main_plot(to_use_easiest_only=0):
    """
    Make and save the summary RT figures for all paradigms.

    1. Mean RT vs. evidence: a single 4x4 grid with a 2x2 panel for each of
       'RT', 'unimanual' and 'bimanual'. In each panel the top row shows the
       serial model and the bottom row the parallel model. The unused
       bottom-right panel is hidden.
    2. RT distributions: one figure per paradigm (cumulative, corner
       conditions only).

    Each figure is saved as .png and .pdf via ``locfile.get_file_fig``.

    :param to_use_easiest_only: passed to ``plot_rt`` as ``use_easiest_only``
        and recorded in the output file names.
    """

    def save_fig(plot_kind, dict_file):
        # Save the current figure in every output format.
        for ext in ['.png', '.pdf']:
            file = locfile.get_file_fig(plot_kind, dict_file, ext=ext)
            localfile.mkdir4file(file)
            plt.savefig(file)
            print('Saved to %s' % file)

    # === Mean RT
    plot_kind = 'rt_mean_row_per_model'
    axs = plt2.GridAxes(4,
                        4,
                        top=0.6,
                        left=0.7,
                        right=0.7,
                        bottom=0.6,
                        wspace=[0.8, 1.6, 0.8],
                        hspace=[0.1, 1.2, 0.1],
                        heights=1.5,
                        widths=1.7)
    inds = ['A', 'B', 'C']
    parads = ['RT', 'unimanual', 'bimanual']
    rowcol_toplefts = [(0, 0), (0, 2), (2, 0)]
    models = ['Ser', 'Par']
    model_names = ['serial', 'parallel']
    n_dim = 2

    for rowcol_topleft, parad, ind in zip(rowcol_toplefts, parads, inds):
        for row_shift, model in enumerate(models):
            row = rowcol_topleft[0] + row_shift
            col = rowcol_topleft[1]
            axs1 = axs[row:(row + 1), col:(col + n_dim)]
            plot_rt(parad,
                    to_cumsum=False,
                    corners_only=False,
                    plot_kind='rt_mean',
                    model_names=[model],
                    use_easiest_only=to_use_easiest_only,
                    axs=axs1,
                    to_add_legend=row_shift == 0)

            if row_shift == 0:
                # The x axis is labelled on the bottom (second model) row only.
                for ax in axs1.flatten():
                    ax.set_xticklabels([])
                    ax.set_xlabel('')
                axs1.suptitle(get_title_parad(parad),
                              pad=0.05,
                              va='bottom',
                              xprop=0.55)
                axs1.suptitle(ind,
                              pad=0.05,
                              xprop=-0.05,
                              va='bottom',
                              fontweight='bold')

            axs1[0, 0].set_ylabel('Response time (s)\n(%s model)' %
                                  model_names[row_shift])

    for ax in axs[2:, 2:].flatten():
        ax.set_visible(False)

    save_fig(plot_kind, {
        'parad': parads,
        'easiest_only': to_use_easiest_only,
    })

    # === RT distrib
    plot_kind = 'rt_distrib'
    to_normalize_max = False
    for to_cumsum in [True]:  # [False, True]:
        for corners_only in [True]:  # , False]:
            for parad in ['RT', 'unimanual', 'bimanual']:
                axs = plot_rt(
                    parad,
                    to_cumsum=to_cumsum,
                    corners_only=corners_only,
                    plot_kind=plot_kind,
                    use_easiest_only=to_use_easiest_only,
                    to_add_legend=parad == 'RT',
                )
                axs.suptitle(get_title_parad(parad), pad=0.25)
                save_fig(plot_kind, {
                    'parad': parad,
                    'cumsum': to_cumsum,
                    'corner': corners_only,
                    'easiest_only': to_use_easiest_only,
                    'nrmmax': to_normalize_max,
                })

    print('--')
Esempio n. 5
0
def main_fit(
        subjs=None,
        skip_fit_if_absent=False,
        bufdur=None,
        bufdur_best=0.12,
        bufdurs_sim=None,
        bufdurs_fit=None,
        loss_kind='NLL',
        plot_kind='line_sim_fit',
        fix_post=None,
        seed_sims=(0, ),
        err_kind='std',
        base=10,
):
    """Collect losses of models fitted to simulated data and plot them.

    For every (seed, subject, fitted buffer duration, simulated buffer
    duration) combination, the result of ``get_fit_sim`` is loaded and its
    test NLL, number of test data points and number of parameters are
    stored. Entries stay NaN when ``get_fit_sim`` returns no result.
    Losses are optionally converted to BIC, expressed in units of
    ``log(base)``, and plotted according to ``plot_kind``. The figure is
    saved as PNG and PDF unless ``plot_kind == 'None'``.

    :param subjs: subjects to include; defaults to ``consts.SUBJS_VD``.
    :param skip_fit_if_absent: forwarded to ``get_fit_sim``.
    :param bufdur: if given, both the simulated and the fitted durations
        become the sorted union of ``bufdur_best`` and ``bufdur``.
    :param bufdur_best: buffer duration used as the reference ("best").
    :param bufdurs_sim: buffer durations of the simulated data; defaults
        to ``bufdurs0``.
    :param bufdurs_fit: buffer durations of the fitted models.
    :param loss_kind: 'NLL' or 'BIC'.
    :param plot_kind: 'loss_serpar', 'imshow_ser_buf_par',
        'line_sim_fit', 'bar_sim_fit', 'bar_ser_buf_par' or 'None'.
    :param fix_post: forwarded to ``get_fit_sim``.
    :param seed_sims: seeds of the simulations.
    :param err_kind: 'std' or 'sem'; error bars across seeds for
        'line_sim_fit'.
    :param base: losses are divided by ``log(base)``.
    :return: losses[seed, subj, sim_dur, fit_dur] in units of log(base).
    :raises ValueError: for an unsupported ``loss_kind`` or ``err_kind``,
        or when no fit metadata is available to name the output file.
    """
    # if torch.cuda.is_available():
    #     torch.set_default_tensor_type(torch.cuda.FloatTensor)
    torch.set_num_threads(1)
    torch.set_default_dtype(torch.double)

    dict_res_add = {}

    parad = 'VD'
    seed_sims = np.array(seed_sims)

    if subjs is None:
        subjs = consts.SUBJS_VD

    if bufdur is not None:
        # Sorted so that index 0 is the most serial and index -1 the most
        # parallel duration, as the plotting code below assumes.
        bufdurs_sim = sorted({bufdur_best, bufdur})
        bufdurs_fit = sorted({bufdur_best, bufdur})
    else:
        if bufdurs_sim is None:
            bufdurs_sim = bufdurs0
            if bufdurs_fit is None:
                bufdurs_fit = bufdurs0
        else:
            if bufdurs_fit is None:
                bufdurs_fit0 = [0., bufdur_best, 1.2]
                if bufdurs_sim[0] in bufdurs_fit0:
                    bufdurs_fit = list(bufdurs_fit0)
                else:
                    # list() so that tuple inputs concatenate as well
                    bufdurs_fit = (bufdurs_fit0[:2] + list(bufdurs_sim) +
                                   [1.2])

    n_dur_sim = len(bufdurs_sim)
    n_dur_fit = len(bufdurs_fit)

    # DEF: losses[seed, subj, simSerDurPar, fitDurs]
    size_all = [len(seed_sims), len(subjs), n_dur_sim, n_dur_fit]
    losses0 = np.full(size_all, np.nan)
    ns = np.full(size_all, np.nan)
    ks0 = np.full(size_all, np.nan)

    # Metadata of the last loaded fit; used to name the saved figure.
    dict_fit_sim = None
    dict_subdir_sim = None
    for i_seed, seed_sim in enumerate(seed_sims):
        for i_subj, subj in enumerate(subjs):
            for i_fit, bufdur_fit in enumerate(bufdurs_fit):
                for i_sim, bufdur_sim in enumerate(bufdurs_sim):
                    d, dict_fit_sim, dict_subdir_sim = get_fit_sim(
                        subj,
                        seed_sim,
                        bufdur_sim,
                        bufdur_fit,
                        parad=parad,
                        skip_fit_if_absent=skip_fit_if_absent,
                        fix_post=fix_post,
                    )
                    if d is not None:
                        losses0[i_seed, i_subj, i_sim,
                                i_fit] = d['loss_NLL_test']
                        ns[i_seed, i_subj, i_sim, i_fit] = d['loss_ndata_test']
                        ks0[i_seed, i_subj, i_sim, i_fit] = d['loss_nparam']
    ks = copy(ks0)

    if loss_kind == 'BIC':
        losses = ks * np.log(ns) + 2 * losses0

        # REF: Kass, Raftery 1995 https://doi.org/10.2307%2F2291091
        #   https://en.wikipedia.org/wiki/Bayesian_information_criterion#Gaussian_special_case
        thres_strong = 10. / np.log(base)
    elif loss_kind == 'NLL':
        losses = copy(losses0)
        thres_strong = np.log(100) / np.log(base)
    else:
        raise ValueError('Unsupported loss_kind: %s' % loss_kind)

    #%%
    losses = losses / np.log(base)

    if plot_kind == 'loss_serpar':
        n_row = len(subjs)
        n_col = 1
        axs = plt2.GridAxes(n_row, n_col, left=1, widths=2, bottom=0.75)
        for i_subj, subj in enumerate(subjs):
            plt.sca(axs[i_subj, 0])

            # First fitted duration = serial model; last = parallel model.
            loss_dur_sim_ser_fit = losses[:, i_subj, :, 0].mean(0)
            loss_dur_sim_par_fit = losses[:, i_subj, :, -1].mean(0)

            dloss_serpar = loss_dur_sim_par_fit - loss_dur_sim_ser_fit

            plt.bar(bufdurs_sim, dloss_serpar, color='k')
            plt.axhline(0, color='k', linestyle='-', linewidth=0.5)

            if i_subj < len(subjs) - 1:
                plt2.box_off(['top', 'right', 'bottom'])
                plt.xticks([])
            else:
                plt2.box_off()
                plt.xticks(np.arange(0, 1.4, 0.2), ['0\nser'] + [''] * 2 +
                           ['0.6'] + [''] * 2 + ['1.2\npar'])
                plt2.detach_axis('x', 0, 1.2)
            vdfit.patch_chance_level()

        plt.sca(axs[0, 0])
        plt.title('Support for Serial\n'
                  r'($\mathcal{L}_\mathrm{ser} - \mathcal{L}_\mathrm{par}$)')

        plt.sca(axs[-1, 0])
        plt.xlabel('true buffer duration (s)')

        # One row per entry of subjs (== consts.SUBJS_VD by default).
        plt2.rowtitle(subjs, axs)
    elif plot_kind == 'imshow_ser_buf_par':
        n_row = len(subjs)
        n_col = 1
        axs = plt2.GridAxes(n_row,
                            n_col,
                            left=1,
                            widths=2,
                            heights=2,
                            bottom=0.75)
        for i_subj, subj in enumerate(subjs):
            plt.sca(axs[i_subj, 0])

            # Loss relative to the matching (diagonal) sim/fit duration.
            dloss = losses[:, i_subj, :, :].mean(0)
            dloss = dloss - np.diag(dloss)[:, None]

            plt.imshow(dloss, cmap='bwr')

            # Symmetric color limits so that 0 maps to the center color.
            cmax = np.amax(np.abs(dloss))
            plt.clim(-cmax, +cmax)

        plt.sca(axs[-1, 0])
        plt.xlabel('true buffer duration (s)')
        plt.ylabel('model buffer duration (s)')

        plt2.rowtitle(subjs, axs)

    elif plot_kind == 'line_sim_fit':

        def err_across_seeds(v):
            """Spread of ``v`` across seeds (axis 0) per ``err_kind``."""
            if err_kind == 'std':
                return np.std(v, 0)
            if err_kind == 'sem':
                return np2.sem(v, 0)
            raise ValueError('Unsupported err_kind: %s' % err_kind)

        bufdur_bests = np.array([get_bufdur_best(subj) for subj in subjs])
        i_bests = np.array([
            list(bufdurs_fit).index(bufdur_best_subj)
            for bufdur_best_subj in bufdur_bests
        ])
        i_subjs = np.arange(len(subjs))

        # Data simulated with each subject's best duration, fitted with
        # every duration, relative to the best-duration fit.
        loss_sim_best_fit_best = losses[:, i_subjs, i_bests, i_bests]
        loss_sim_best_fit_rest = (losses[:, i_subjs, i_bests, :] -
                                  loss_sim_best_fit_best[:, :, None])
        mean_loss_sim_best_fit_rest = np.mean(loss_sim_best_fit_rest, 0)
        err_loss_sim_best_fit_rest = err_across_seeds(loss_sim_best_fit_rest)

        # Data simulated with every duration, fitted with the best duration,
        # relative to the fit with the matching duration.
        n_dur = len(bufdurs_fit)
        loss_sim_rest_fit_best = np.swapaxes(
            np.stack([
                losses[:, i_subj, np.arange(n_dur), i_best] -
                losses[:, i_subj,
                       np.arange(n_dur),
                       np.arange(n_dur)]
                for i_subj, i_best in zip(i_subjs, i_bests)
            ]), 0, 1)
        mean_loss_sim_rest_fit_best = np.mean(loss_sim_rest_fit_best, 0)
        err_loss_sim_rest_fit_best = err_across_seeds(loss_sim_rest_fit_best)

        dict_res_add['err'] = err_kind

        n_row = 2
        n_col = len(subjs)

        axs = plt2.GridAxes(n_row,
                            n_col,
                            left=1.15,
                            right=0.1,
                            widths=2,
                            heights=1.5,
                            wspace=0.35,
                            hspace=0.5,
                            top=0.25,
                            bottom=0.6)

        for row, (m, s, xlabel, ylabel) in enumerate([
            (mean_loss_sim_best_fit_rest, err_loss_sim_best_fit_rest,
             'model buffer capacity (s)',
             (r'$-\mathrm{log}_{%g}\mathrm{BF}$ given simulated' +
              '\nbest duration data') % base),
            (mean_loss_sim_rest_fit_best, err_loss_sim_rest_fit_best,
             'simulated data buffer capacity (s)',
             (r'$-\mathrm{log}_{%g}\mathrm{BF}$ of the' +
              '\nbest duration model') % base)
        ]):
            for i_subj, subj in enumerate(subjs):
                ax = axs[row, i_subj]
                plt.sca(ax)
                gs1 = axs.gs[row * 2 + 1, i_subj * 2 + 1]

                bax = vdfit.breakaxis(gs1)
                # x must match the last axis of m, which indexes
                # bufdurs_fit (equal to bufdurs0 by default).
                bax.axs[1].errorbar(bufdurs_fit,
                                    m[i_subj, :],
                                    yerr=s[i_subj, :],
                                    color='k',
                                    marker='.',
                                    linewidth=0.75,
                                    elinewidth=0.5,
                                    markersize=3)
                # Only the first three points are drawn on bax.axs[0].
                m1 = copy(m[i_subj, :])
                m1[3:] = np.nan
                s1 = copy(s[i_subj, :])
                s1[3:] = np.nan
                bax.axs[0].errorbar(bufdurs_fit,
                                    m1,
                                    yerr=s1,
                                    color='k',
                                    marker='.',
                                    linewidth=0.75,
                                    elinewidth=0.5,
                                    markersize=3)

                ax1 = bax.axs[1]  # type: plt.Axes
                plt.sca(ax1)
                vdfit.patch_chance_level(level=thres_strong, signs=[-1, 1])
                plt.axhline(0, color='k', linestyle='--', linewidth=0.5)
                vdfit.beautify_ticks(ax1, )
                vdfit.beautify(ax1)
                ax1.set_yticks([0, 20])
                if i_subj > 0:
                    ax1.set_yticklabels([])
                if row == 0:
                    ax1.set_xticklabels([])

                ax1 = bax.axs[0]  # type: plt.Axes
                plt.sca(ax1)
                ax1.set_yticks([40, 200])
                if i_subj > 0:
                    ax1.set_yticklabels([])

                plt.sca(ax)
                plt2.box_off('all')
                if row == 0:
                    plt.title(consts.SUBJS['VD'][i_subj])

                if i_subj == 0:
                    plt.xlabel(xlabel, labelpad=8 if row == 0 else 20)
                    plt.ylabel(ylabel, labelpad=30)

                plt2.sameaxes(bax.axs, xy='x')

    elif plot_kind == 'bar_sim_fit':
        # list() so that .index() also works when bufdurs_fit is an array.
        i_best = list(bufdurs_fit).index(bufdur_best)
        loss_sim_best_fit_best = losses[0, :, i_best, i_best]
        loss_sim_best_fit_rest = np2.nan2v(losses[0, :, i_best, :] -
                                           loss_sim_best_fit_best[:, None])
        n_dur = len(bufdurs_fit)
        loss_sim_rest_fit_best = np2.nan2v(
            losses[0, :, np.arange(n_dur), i_best] -
            losses[0, :, np.arange(n_dur),
                   np.arange(n_dur)]).T

        n_row = 2
        n_col = len(subjs)

        axs = plt2.GridAxes(n_row,
                            n_col,
                            left=1,
                            widths=3,
                            right=0.25,
                            heights=1,
                            wspace=0.3,
                            hspace=0.6,
                            top=0.25,
                            bottom=0.6)

        def plot_bars(gs1, bufdurs, losses1, i_subj, add_xticklabel=True):
            """Draw subject ``i_subj``'s bars on broken x/y axes in ``gs1``."""
            bufdurs = list(bufdurs)  # .index() must work for arrays too
            i_break = np.amax(np.nonzero(np.array(bufdurs) < 0.3)[0])
            bax = brokenaxes(
                subplot_spec=gs1,
                xlims=((-1, i_break + 0.5), (i_break + 0.5,
                                             len(bufdurs) - 0.5)),
                ylims=((-3, 20), (20, 1250 / 5)),
                height_ratios=(50 / 100, 500 / (1250 - 100)),
                hspace=0.15,
                wspace=0.075,
                d=0.005,
            )
            bax.bar(np.arange(len(bufdurs)), losses1[i_subj, :], color='k')

            ax11 = bax.axs[3]  # type: plt.Axes
            ax11.set_xticks([bufdurs.index(0.6), bufdurs.index(1.2)])
            if i_subj == 0 and add_xticklabel:
                ax11.set_xticklabels(['0.6', '1.2'])
            else:
                ax11.set_xticklabels([])

            ax00 = bax.axs[0]  # type: plt.Axes
            ax00.set_yticks([500, 1000])

            ax10 = bax.axs[2]  # type: plt.Axes
            ax10.set_yticks([0, 50])
            plt.sca(ax10)
            plt2.detach_axis('x', amin=-0.4, amax=i_break + 0.5)
            for ax_line in [ax10, ax11]:
                plt.sca(ax_line)
                plt.axhline(0, linewidth=0.5, color='k', linestyle='--')
                for sign in [-1, 1]:
                    plt.axhline(sign * thres_strong,
                                linewidth=0.5,
                                color='silver',
                                linestyle='--')
            ax10.set_xticks([bufdurs.index(0.), bufdurs.index(0.2)])
            if i_subj == 0:
                if add_xticklabel:
                    ax10.set_xticklabels(['0', '0.2'])
                else:
                    ax10.set_xticklabels([])
            else:
                ax10.set_yticklabels([])
                ax10.set_xticklabels([])
                ax00.set_yticklabels([])
            return bax

        for i_subj, subj in enumerate(subjs):
            plot_bars(axs.gs[1, i_subj * 2 + 1],
                      bufdurs_fit,
                      loss_sim_best_fit_rest,
                      i_subj,
                      add_xticklabel=False)

            ax = axs[0, i_subj]
            plt.sca(ax)
            plt2.box_off('all')
            plt.title(consts.SUBJS['VD'][consts.SUBJS['VD'].index(subj)])
            if i_subj == 0:
                ax.set_ylabel(
                    'misfit to simulated\nbest duration data\n'
                    r'($\Delta$BIC)',
                    labelpad=35)
                ax.set_xlabel('model buffer duration (s)', labelpad=8)

            plot_bars(axs.gs[3, i_subj * 2 + 1],
                      bufdurs_sim,
                      loss_sim_rest_fit_best,
                      i_subj,
                      add_xticklabel=True)
            ax = axs[1, i_subj]
            plt.sca(ax)
            plt2.box_off('all')
            if i_subj == 0:
                ax.set_ylabel(
                    'misfit of\nbest duration model\n'
                    r'($\Delta$BIC)',
                    labelpad=35)
                ax.set_xlabel('simulated data buffer duration (s)',
                              labelpad=20)

    elif plot_kind == 'bar_ser_buf_par':
        n_row = len(subjs)
        n_col = 1
        # One panel row per (subject, simulated duration): small gaps
        # within a subject and a larger gap between subjects.
        hspace = ([0.15] * (n_dur_sim - 1) + [0.5]) * n_row
        hspace = hspace[:-1]
        axs = plt2.GridAxes(n_row * n_dur_sim,
                            n_col,
                            left=1.25,
                            widths=1.5,
                            right=0.25,
                            heights=0.4,
                            hspace=hspace,
                            top=0.6,
                            bottom=0.5)
        row = -1
        for i_subj, subj in enumerate(subjs):
            for i_sim in range(n_dur_sim):
                row += 1
                plt.sca(axs[row, 0])

                # Loss of each fitted duration relative to the fit whose
                # index matches the simulated one.
                loss1 = losses[:, i_subj, i_sim, :].mean(0)
                dloss = loss1 - loss1[i_sim]

                x = np.arange(n_dur_fit)
                for x1, dloss1 in zip(x, dloss):
                    plt.bar(x1, dloss1, color='r' if dloss1 > 0 else 'b')
                plt.axhline(0, color='k', linewidth=0.5)
                vdfit.patch_chance_level(6.)
                plt.ylim([-100, 100])
                # One (blank) label per tick; a mismatch raises in newer
                # matplotlib.
                plt.yticks([-100, 0, 100], [''] * 3)
                plt.ylabel('%g' % bufdurs_sim[i_sim],
                           rotation=0,
                           va='center',
                           ha='right')

                if i_sim == 0:
                    plt.title(subj)

                if i_sim < n_dur_sim - 1 or i_subj < len(subjs) - 1:
                    plt2.box_off(['top', 'bottom', 'right'])
                    plt.xticks([])
                else:
                    plt2.box_off(['top', 'right'])

        # The bars are placed at arange(n_dur_fit): label them with the
        # fitted (model) durations.
        plt.sca(axs[-1, 0])
        plt.xticks(np.arange(n_dur_fit), ['%g' % v for v in bufdurs_fit])
        plt2.detach_axis('x', 0, n_dur_fit - 1)
        plt.xlabel('model buffer duration (s)', fontsize=10)

        c = axs[-1, 0].get_position().corners()
        plt.figtext(x=0.15,
                    y=np.mean(c[:2, 1]),
                    fontsize=10,
                    s='true\nbuffer\nduration\n(s)',
                    rotation=0,
                    va='center',
                    ha='center')

        plt.figtext(
            x=(1.25 + 1.5 / 2) / (1.25 + 1.5 + 0.25),
            y=0.98,
            s=r'$\mathrm{BIC} - \mathrm{BIC}_\mathrm{true}$',
            ha='center',
            va='top',
            fontsize=12,
        )

    if plot_kind != 'None':
        if dict_fit_sim is None:
            raise ValueError('No fit metadata was loaded; '
                             'cannot name the figure file.')
        dict_res = deepcopy(dict_fit_sim)
        for k in ['sbj', 'prd']:
            dict_res.pop(k)
        dict_res = {**dict_res, 'los': loss_kind}
        dict_res.update(dict_res_add)
        for ext in ['.png', '.pdf']:
            file = locfile.get_file_fig(plot_kind,
                                        dict_res,
                                        subdir=dict_subdir_sim,
                                        ext=ext)  # noqa
            plt.savefig(file, dpi=300)
            print('Saved to %s' % file)

    #%%
    print('--')
    return losses
Esempio n. 6
0
def main(base=10.):
    """Plot model comparison on the data next to model-recovery simulations.

    Loads ``dcost`` for the real data (``file_model_comp``) and for data
    simulated from the serial and parallel models
    (``files_model_recovery``), divides it by ``log(base)``, and draws one
    column of per-entry bars for each of the three data sets. It then draws
    a summary plot of mean +- SEM per data set. Both figures are saved as
    PDF and PNG, and the summary is written to a CSV file.

    :param base: base of the logarithm used to express ``dcost``.
    """
    ds = load_comp(os.path.join(locfile_in.pth_root, file_model_comp))
    ds['dcost'] = np.array(ds['dcost']) / np.log(base)

    vmax = 160

    # Mean and SEM of dcost for: data, simulated serial, simulated parallel.
    m = np.full(3, np.nan)
    e = np.full(3, np.nan)

    # This first plot is used only for its layout attributes (heights,
    # margins), which size the three-column figure created next.
    axs = plot_bar_dloss_across_subjs(dlosses=np.array(ds['dcost']),
                                      subj_parad_bis=ds['subj_parad_bi'],
                                      vmax=vmax,
                                      base=base)
    axs = plt2.GridAxes(nrows=1,
                        ncols=3,
                        heights=axs.heights,
                        widths=[2],
                        left=axs.left,
                        right=axs.right,
                        bottom=axs.bottom)
    plot_bar_dloss_across_subjs(
        dlosses=np.array(ds['dcost']),
        subj_parad_bis=ds['subj_parad_bi'],
        vmax=vmax,
        axs=axs[:, [0]],
        base=base,
    )
    plt.title('Data')

    m[0] = np.mean(ds['dcost'])
    e[0] = np2.sem(ds['dcost'])

    print('--')

    titles = ['Simulated\nSerial', 'Simulated\nParallel']
    for i in range(2):
        ds = load_comp(
            os.path.join(locfile_in.pth_root, files_model_recovery[i]))
        ds['dcost'] = np.array(ds['dcost']) / np.log(base)

        plot_bar_dloss_across_subjs(
            dlosses=ds['dcost'],
            subj_parad_bis=ds['subj_parad_bi'],
            vmax=vmax,
            axs=axs[:, [i + 1]],
            base=base,
        )
        plt.title(titles[i])

        m[i + 1] = np.mean(ds['dcost'])
        e[i + 1] = np2.sem(ds['dcost'])

        # The y tick labels are shown on the first column only.
        plt.sca(axs[0, i + 1])
        plt2.box_off(['left'])
        plt.yticks([])

    for ext in ['.pdf', '.png']:
        file = locfile_out.get_file_fig('model_comp_recovery',
                                        {'easiest_only': use_easiest_only},
                                        ext=ext)
        plt.savefig(file, dpi=300)
        print('Saved to %s' % file)

    # --- Print mean +- SEM to CSV
    csv_file = locfile_out.get_file_csv('model_comp_recovery',
                                        {'easiest_only': use_easiest_only})
    d_csv = np2.dictlist2listdict({
        'data': ['original', 'simulated serial', 'simulated parallel'],
        'mean_dcost': m,
        'sem_dcost': e,
    })
    # newline='' is required by the csv module so that its own line
    # terminators are not translated (blank rows on Windows otherwise).
    with open(csv_file, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=d_csv[0].keys())
        writer.writeheader()
        for d in d_csv:
            writer.writerow(d)
        print('Wrote to %s' % csv_file)

    # --- Mean NLL within each
    axs = plt2.GridAxes(
        1,
        1,
        left=0.85,
        right=0.25,
        heights=[1.5],
        top=0.1,
        bottom=0.9,
    )
    plt.sca(axs[0, 0])
    plt.barh(np.arange(3), m, xerr=e, color='w', edgecolor='k')
    plt.yticks(np.arange(3),
               ['data', 'simulated\nserial', 'simulated\nparallel'])
    plt2.box_off(['top', 'right'])
    # Symmetric x range covering every mean +- SEM.
    vmax = np.amax(np.abs(m) + np.abs(e))
    plt.xlim(np.array([-vmax, vmax]) * 1.1)
    plt2.detach_axis(xy='y', amin=0, amax=2)
    plt2.detach_axis(xy='x', amin=-vmax, amax=vmax)

    axvline_dcost()
    xticks_serial_vs_parallel(vmax, base)

    for ext in ['.pdf', '.png']:
        file = locfile_out.get_file_fig('model_comp_vs_recovery',
                                        {'easiest_only': use_easiest_only},
                                        ext=ext)
        plt.savefig(file, dpi=300)
        print('Saved to %s' % file)

    print('--')
def plot_bar_dloss_across_subjs(
    dlosses,
    elosses=None,
    ix_datas=None,
    subj_parad_bis: Iterable[Tuple[str, str, bool]] = None,
    axs: Union[plt2.GridAxes, plt2.AxesArray] = None,
    vmax=None,
    add_scale=True,
    base=10.,
):
    """Plot one horizontal bar per (subject, paradigm, bimanual) entry.

    Entries are reordered as follows:
      1. eye ('RT') entries, sorted by subject number;
      2. hand entries, as unimanual/bimanual pairs sorted by subject number;
      3. 'binary' entries, sorted by subject number.
    Bars whose magnitude exceeds ``vmax`` get a wave patch at both x limits.

    :param dlosses: [ix_data] value per entry, same order as
        ``subj_parad_bis``.
    :param elosses: [ix_data] error bar per entry, same order as
        ``dlosses``. If None, no error bars are drawn (zeros).
    :param ix_datas: unused.
    :param subj_parad_bis: [('subj', 'parad', is_bimanual), ...]. Defaults
        to ``subj_parad_bis0``.
    :param axs: axes to draw into (``axs[0, 0]`` is used). If None, a new
        ``plt2.GridAxes`` is created.
    :param vmax: half-width of the detached x axis. Defaults to
        max |dlosses|.
    :param add_scale: currently has no effect.
    :param base: forwarded to ``xticks_serial_vs_parallel``.
    :return: axs
    """
    import string

    if subj_parad_bis is None:
        subj_parad_bis = subj_parad_bis0
    if not isinstance(subj_parad_bis, np.ndarray):
        # Materialize so that it can be unpacked below and reindexed later.
        # Lists and generators do not support fancy indexing.
        subj_parad_bis = list(subj_parad_bis)
    dlosses = np.asarray(dlosses)
    if vmax is None:
        vmax = np.amax(np.abs(dlosses))

    # order: eye S1-S3, hand by ID, paired uni-bimanual
    subjs, parads, bis = zip(*subj_parad_bis)
    # Zero-pad single-digit IDs: 'ID1' -> 'ID01'
    subjs = np.array(
        ['ID0' + v[-1] if v[:2] == 'ID' and len(v) == 3 else v for v in subjs])
    parads = np.array(parads)
    bis = np.array(bis)

    is_eye = parads == 'RT'
    is_bin = parads == 'binary'
    ix_all = np.arange(len(subjs))

    def subj_number(subj):
        """Integer ID of a subject name, e.g. 'S2' -> 2, 'ID03' -> 3."""
        # Skip the first character plus any further letters (the 'D' of
        # 'ID'); int(subj[1:]) alone fails on the 'ID0x' names made above.
        return int(subj[1:].lstrip(string.ascii_letters))

    def filt_sort(filt):
        """Indices of the entries selected by ``filt``, by subject number."""
        ind = [subj_number(subj) for subj in subjs[filt]]
        return ix_all[filt][np.argsort(ind)]

    ix = np.concatenate([
        filt_sort(is_eye & ~is_bin),
        # Interleave unimanual and bimanual entries of the same subjects.
        np.stack([
            filt_sort(~is_eye & ~bis & ~is_bin),
            filt_sort(~is_eye & bis & ~is_bin)
        ], -1).flatten('C'),
        filt_sort(is_bin)
    ])
    subjs = subjs[ix]
    parads = parads[ix]
    bis = bis[ix]
    is_eye = is_eye[ix]
    dlosses = dlosses[ix]
    if isinstance(subj_parad_bis, np.ndarray):
        subj_parad_bis = subj_parad_bis[ix]
    else:
        subj_parad_bis = [subj_parad_bis[i] for i in ix]

    n_eye = int(np.sum(is_eye))
    n_hand = int(np.sum(~is_eye))

    y = np.empty([n_eye + n_hand])
    y[is_eye] = 1.5 + np.arange(n_eye)
    # Hand entries: larger gap (1.5) before each uni/bimanual pair, smaller
    # gap (1.) within it. Truncated so that odd counts also fit.
    gaps = ([1.5, 1.] * ((n_hand + 1) // 2))[:n_hand]
    y[~is_eye] = n_eye - 1 + 1.5 + np.cumsum(gaps)
    y_max = np.amax(y) + 1.5

    if axs is None:
        axs = plt2.GridAxes(nrows=1,
                            ncols=1,
                            heights=y_max * 0.2,
                            widths=2,
                            left=1.5,
                            right=0.25,
                            bottom=0.85)
    ax = axs[0, 0]
    plt.sca(ax)

    m = dlosses
    if elosses is None:
        e = np.zeros_like(m)
    else:
        # Reorder the error bars the same way as the values.
        e = np.asarray(elosses)[ix]

    for y1, m1, e1, parad1, bi1 in zip(y, m, e, parads, bis):
        plt.barh(y1,
                 m1,
                 xerr=e1,
                 color=colors_parad[(parad1, '%s' % bi1)],
                 edgecolor='None')

    axvline_dcost()

    x_lim = [-vmax * 1.2, vmax * 1.2]
    # Mark bars that extend beyond vmax with a wave patch at the limits.
    for ix_big in range(len(y)):
        if np.abs(m[ix_big]) > vmax:
            for i_sign, sign in enumerate([1, -1]):
                plt2.patch_wave(
                    y[ix_big],
                    x_lim[i_sign] * 1.01,
                    ax=ax,
                    color='w',
                    wave_margin=0.15,
                    wave_amplitude=sign * 0.025,
                )

    plt.xlim(x_lim)
    xticks_serial_vs_parallel(vmax, base)
    subj_parad_bi_str = get_subj_parad_bi_str(subj_parad_bis)
    plt.yticks(y, subj_parad_bi_str)
    plt2.detach_axis('y', y[0], y[-1])
    plt2.detach_axis('x', -vmax, vmax)
    # Inverted y limits: the first entry is drawn at the top.
    plt.ylim([y_max - 1, 1.])

    return axs
Esempio n. 8
0
def main_plot_models_on_columns(i_subjs=(0, 1), axs=None, coef='slope'):
    """Plot per-duration coefficients of data and models, one model per column.

    For each subject index in ``i_subjs``, loads the fitted 'buffer+serial',
    'parallel' and 'serial' models. It then draws one figure with a column
    per model and two rows (motion, color), showing either the sensitivity
    (``coef='slope'``) or the bias (``coef='bias'``) against stimulus
    duration. Each figure is saved as PNG and PDF.

    :param i_subjs: subject indices, passed to ``main_fit`` and used to
        index ``consts.SUBJS['VD']``.
    :param axs: unused; new axes are created for each subject.
    :param coef: 'slope' or 'bias'.
    :return: (models, data). ``models[i_model, k]`` is the model for the
        k-th entry of ``i_subjs``; ``data`` is the last subject's data.
    :raises ValueError: if ``coef`` is neither 'slope' nor 'bias'.
    """
    if coef not in ['slope', 'bias']:
        raise ValueError()

    n_subj = len(i_subjs)

    # ---- Load goodness-of-fit
    def load_gof(i_subj, buffix_dur):
        """Test NLL of the fit with buffer duration ``buffix_dur``.

        Returns NaN if the result table does not exist.
        """
        import csv
        import os

        fix = fix0_pre + [('buffix', buffix_dur)] + fix0_post
        dict_cache, subdir = main_fit(i_subj,
                                      fit_mode='dict_cache_only',
                                      fix=fix)

        file = locfile.get_file('tab',
                                'best_loss',
                                dict_cache,
                                ext='.csv',
                                subdir=subdir)
        if not os.path.exists(file):
            print('File absent - returning NaN: %s' % file)
            return np.nan
        with open(file, 'r', newline='') as csvfile:
            records = list(csv.DictReader(csvfile))
        gof_kind = 'loss_NLL_test'
        igof = [record['name'] for record in records].index(gof_kind)
        # NOTE(review): the first 3 characters of the ' value' column are
        # dropped before parsing; the prefix format is not visible here.
        return float(records[igof][' value'][3:])

    bufdurs = [0., bufdur_best, 1.2]
    nbufdurs = len(bufdurs)
    gofs = np.zeros([nbufdurs, n_subj])

    # NOTE(review): gofs is not used by the plots below.
    for col, i_subj in enumerate(i_subjs):
        for idur, bufdur in enumerate(bufdurs):
            gofs[idur, col] = load_gof(i_subj, bufdur)

    gofs = gofs - np.nanmin(gofs, 0, keepdims=True)

    # ---- Load coefs
    fixs = [
        ('buffer+serial', [
            ('buffix', bufdur_best),
        ]),
        ('parallel', [
            ('buffix', 1.2),
        ]),
        ('serial', [
            ('buffix', 0.),
        ]),
    ]
    model_names = [f[0] for f in fixs]
    n_models = len(model_names)

    # The builtin `object` replaces the `np.object` alias removed in
    # NumPy 1.24.
    models = np.empty([n_models, n_subj], dtype=object)
    dict_caches = np.empty([n_models, n_subj], dtype=object)
    ds = np.empty([n_models, n_subj], dtype=object)
    datas = np.empty([n_subj], dtype=object)

    subdir = ','.join(fix0) + '+buffix=' + ','.join(
        [('%1.2f' % v) for v in [0., bufdur_best, 1.2]])

    for col, i_subj in enumerate(i_subjs):
        for i_model, (_, fix1) in enumerate(fixs):
            fix = fix0_pre + fix1 + fix0_post
            model, data, dict_cache, d, _ = main_fit(i_subj,
                                                     fit_mode='load',
                                                     fix=fix)

            models[i_model, col] = model
            datas[col] = data
            dict_caches[i_model, col] = dict_cache
            ds[i_model, col] = d

    # ---- Plot coefs
    kw_plots = {
        'serial': {
            'linestyle': '-'
        },
        'buffer+serial': {
            'linestyle': '-'
        },
        'parallel': {
            'linestyle': '-'
        },
    }

    n_row = 2
    n_col = n_models

    def beautify(row, col, axs):
        plt.sca(axs[row, col])
        plt.xlim(xmax=1.2 * 1.05)
        plt2.detach_axis('x', 0, 1.2)
        plt2.box_off()

        # Tick labels only on the bottom-left panel.
        beautify_ticks(axs[row, col],
                       add_ticklabel=(row == n_row - 1) and (col == 0))

        if col > 0:
            plt.ylabel('')

    title_model = {
        'buffer+serial': 'parallel followed\nby serial',
        'parallel': 'strictly parallel',
        'serial': 'strictly serial'
    }

    for col, i_subj in enumerate(i_subjs):
        plt.figure(figsize=[2.25 * n_models, 5.25])
        axs = plt2.GridAxes(n_row,
                            n_col,
                            top=.7,
                            left=1.15,
                            right=0.1,
                            bottom=.5,
                            widths=[2],
                            wspace=0.35,
                            heights=[1.5],
                            hspace=[0.2])

        # ---- Plot slopes
        data = datas[col]
        for ii_model, model_name in enumerate(
            ['buffer+serial', 'parallel', 'serial']):
            i_model = model_names.index(model_name)

            plot_coefs_dur_odif_pred_data(
                data,
                models[i_model, col],
                axs=axs[:2, :][:, [ii_model]],
                to_plot_data=True,
                kw_plot_model=kw_plots[model_name],
                coefs_to_plot=[1 if coef == 'slope' else 0],
                add_rowtitle=False)

            for row in range(2):
                plt.sca(axs[row, ii_model])
                plt2.sameaxes(axs[row, :], xy='y')
                beautify(row, ii_model, axs)
                if ii_model > 0:
                    plt.gca().set_yticklabels([])

                if row == 0:
                    plt.title(title_model[model_name])
                    plt.xlabel('')
                else:
                    plt.title('')
                    if ii_model == 0:
                        plt.xlabel('stimulus duration (s)')
                    else:
                        plt.xlabel('')

        # Center the subject name over the middle column.
        bnd = axs[0, 1].get_position().bounds
        subj = consts.SUBJS['VD'][i_subj]
        plt.suptitle(subj, x=bnd[0] + bnd[2] / 2)

        rowtitles = [
            r'Motion sensitivity ($\beta$)',
            r'Color sensitivity ($\beta$)',
        ] if coef == 'slope' else [
            r'Motion bias ($\beta$)',
            r'Color bias ($\beta$)',
        ]
        for row, title in enumerate(rowtitles):
            plt.sca(axs[row, 0])
            plt.ylabel(title)

        dict_cache = dict_caches[0, 0]
        dict_cache.pop('fix', None)

        dict_cache['sbj'] = subj
        dict_cache['coef'] = coef

        for ext in ['.png', '.pdf']:
            file = locfile.get_file_fig('coefs_dur_odif',
                                        dict_cache,
                                        subdir=subdir,
                                        ext=ext)
            plt.savefig(file, dpi=300)
            print('Saved to %s' % file)
    print('--')

    return models, data
Esempio n. 9
0
def main_plot_across_models(
        i_subjs=i_subjs,
        axs=None,
        models_incl=('buffer+serial', ),
        base=10,
        **kwargs,
):
    """Plot goodness-of-fit vs. buffer capacity and model coefficients per subject.

    For each subject in ``i_subjs`` (one column each), this:
      * loads the test NLL (``loss_NLL_test``) of the fit for every buffer
        duration in ``bufdurs0`` and plots it, relative to the per-subject
        minimum, on a broken y-axis in the bottom row;
      * loads the models listed in ``fixs`` and plots the predicted and
        observed coefficients vs. stimulus duration (top two rows) for the
        models named in ``models_incl``;
    then saves the figure as .pdf and .png.

    :param i_subjs: subject indices passed to ``main_fit`` and used to look
        up names in ``consts.SUBJS['VD']``.
    :param axs: optional ``plt2.GridAxes`` with 3 rows and ``len(i_subjs)``
        columns; created when None.
    :param models_incl: names (as in ``fixs``) of models whose predictions
        are plotted.
    :param base: log base for the goodness-of-fit axis; NLL differences are
        divided by ``log(base)`` (the y-label assumes base 10).
    :param kwargs: forwarded to ``main_fit``.
    :return: (gofs, models, data) where ``gofs[i_bufdur, col]`` is the
        relative NLL, ``models[i_model, col]`` the loaded models, and
        ``data`` the data of the last subject.
    """
    n_subj = len(i_subjs)

    # ---- Load goodness-of-fit
    def load_gof(i_subj, buffix_dur):
        """Return the test NLL for one subject and buffer duration.

        Returns NaN (and prints a message) when the result table is absent;
        raises ValueError when the table lacks a ``loss_NLL_test`` row.
        """
        import csv
        import os

        fix = fix0_pre + [('buffix', buffix_dur)] + fix0_post
        dict_cache, subdir = main_fit(i_subj,
                                      fit_mode='dict_cache_only',
                                      fix=fix,
                                      **kwargs)

        file = locfile.get_file('tab',
                                'best_loss',
                                dict_cache,
                                ext='.csv',
                                subdir=subdir)
        if not os.path.exists(file):
            print('File absent - returning NaN: %s' % file)
            return np.nan

        gof_kind = 'loss_NLL_test'
        with open(file, 'r', newline='') as csvfile:
            for row in csv.DictReader(csvfile):
                if row['name'] == gof_kind:
                    # The first 3 characters of the ' value' field are
                    # skipped before parsing (format of the saved table -
                    # TODO confirm what the prefix is).
                    return float(row[' value'][3:])
        raise ValueError('%s not found in %s' % (gof_kind, file))

    nbufdurs = len(bufdurs0)
    gofs = np.zeros([nbufdurs, n_subj])

    for col, i_subj in enumerate(i_subjs):
        for idur, bufdur in enumerate(bufdurs0):
            gofs[idur, col] = load_gof(i_subj, bufdur)

    # NLL relative to the best buffer duration of each subject, in log-`base`
    # units.
    gofs = gofs - np.nanmin(gofs, 0, keepdims=True)
    gofs = gofs / np.log(base)

    # ---- Load slopes
    fixs = [
        ('buffer+serial', [
            ('buffix', bufdur_best),
        ]),
        # ('parallel', [('buffix', 1.2)]),
        # ('serial', [('buffix', 0.)]),
    ]
    model_names = [f[0] for f in fixs]
    n_models = len(model_names)

    # `np.object` was removed in NumPy 1.24; the builtin `object` is the same.
    models = np.empty([n_models, n_subj], dtype=object)
    dict_caches = np.empty([n_models, n_subj], dtype=object)
    ds = np.empty([n_models, n_subj], dtype=object)
    datas = np.empty([n_subj], dtype=object)

    kw_plots = {
        'serial': {
            'linestyle': '--'
        },
        'buffer+serial': {
            'linestyle': '-'
        },
        'parallel': {
            'linestyle': ':'
        },
    }

    subdir = ','.join(fix0) + '+buffix=' + ','.join(
        [('%1.2f' % v) for v in [0., bufdur_best, 1.2]])

    for col, i_subj in enumerate(i_subjs):
        for i_model, (_, fix1) in enumerate(fixs):
            fix = fix0_pre + fix1 + fix0_post
            model, data, dict_cache, d, _ = main_fit(i_subj,
                                                     fit_mode='load',
                                                     fix=fix,
                                                     **kwargs)
            models[i_model, col] = model
            datas[col] = data
            dict_caches[i_model, col] = dict_cache
            ds[i_model, col] = d

    # ---- Plot goodness-of-fit
    if axs is None:
        # Rows: motion coefficient, color coefficient, goodness-of-fit.
        axs = plt2.GridAxes(3,
                            n_subj,
                            top=.3,
                            left=1.15,
                            right=0.1,
                            bottom=.5,
                            widths=[2],
                            wspace=0.35,
                            heights=[1.5, 1.5, 1.5],
                            hspace=[0.2, 0.9])
    # Derived from the axes actually used so that a caller-supplied `axs`
    # works too (previously only defined when `axs` was None).
    n_row = axs.shape[0]

    kw_gof = dict(linewidth=0.75, markersize=4.5)
    for col in range(n_subj):
        # The plain axes in the last row is hidden; the goodness-of-fit is
        # drawn on a broken axis in its place.
        ax = axs[-1, col]
        plt.sca(ax)
        plt2.box_off('all')

        # Presumably the GridAxes gridspec interleaves margin rows/columns,
        # so this is the slot of axs[-1, col] - TODO confirm.
        gs1 = axs.gs[-2, col * 2 + 1]
        bax = breakaxis(gs1)
        ax0 = bax.axs[0]  # type: plt.Axes
        ax1 = bax.axs[1]  # type: plt.Axes

        ax0.plot(bufdurs0[:3], gofs[:3, col], 'k.-', **kw_gof)
        ax1.plot(bufdurs0, gofs[:, col], 'k.-', **kw_gof)

        plt.sca(ax1)
        # Shade the band within log_base(100), i.e. a Bayes factor of 100.
        patch_chance_level(level=np.log(100.) / np.log(base), signs=[-1, 1])
        plt.axhline(0, color='k', linestyle='--', linewidth=0.5)
        beautify(ax1)
        beautify_ticks(ax1, add_ticklabel=True)

        plt2.sameaxes([ax0, ax1], xy='x')

        ax1.set_yticks([0, 20])
        ax0.set_yticks([40, 200])

        if col == 0:
            plt.sca(ax1)
            plt.xlabel('buffer capacity (s)')

            plt.sca(ax)
            plt.ylabel(r'$-\mathrm{log}_{10}\mathrm{BF}$', labelpad=27)
        else:
            ax0.set_yticklabels([])
            ax1.set_yticklabels([])

        plt.sca(axs[0, col])
        beautify_ticks(axs[0, col], add_ticklabel=False)
        beautify_ticks(axs[1, col], add_ticklabel=True)

    plt.sca(axs[-2, 0])
    plt.xlabel('stimulus duration (s)')

    # ---- Plot slopes
    for col in range(n_subj):
        hss = []
        for model_name in models_incl:
            i_model = model_names.index(model_name)
            model = models[i_model, col]
            data = datas[col]

            if model is not None:
                _, hs = plot_coefs_dur_odif_pred_data(
                    data,
                    model,
                    axs=axs[:2, [col]],
                    to_plot_data=True,
                    kw_plot_model=kw_plots[model_name],
                    coefs_to_plot=[1],
                    add_rowtitle=False)
                hss.append(hs)

        # Legend only on the first column, and only if something was plotted.
        if col == 0 and hss:
            # hss[0] = hs['pred'|'data'][coef, dim, dif]
            hs1 = hss[0]['data']  # type: np.ndarray
            for dim in range(hs1.shape[1] - 1, -1, -1):
                hs = [hs1[0, dim, dif] for dif in range(hs1.shape[2])]
                # Legend entries describe the *other* dimension's strength.
                odim_name = consts.DIM_NAMES_LONG[1 - dim].lower()
                plt.sca(axs[dim, 0])

                plt.legend(hs, ['weak ' + odim_name, 'strong ' + odim_name],
                           bbox_to_anchor=[0., 0., 1., 1.],
                           handletextpad=0.35,
                           handlelength=0.5,
                           labelspacing=0.3,
                           borderpad=0.,
                           frameon=False,
                           loc='upper left')

        for row in range(n_row):
            beautify(axs[row, col])
            if col > 0:
                # Clear the label on this axes explicitly; plt.ylabel would
                # act on whichever axes happens to be current.
                axs[row, col].set_ylabel('')

    plt2.sameaxes(axs[-1, :])
    axs[-1, -1].set_yticklabels([])

    for col, i_subj in enumerate(i_subjs):
        axs[0, col].set_title(consts.SUBJS['VD'][i_subj])
        for row in range(1, n_row):
            axs[row, col].set_title('')

    for row, title in enumerate([
            r'Motion sensitivity ($\beta$)',
            r'Color sensitivity ($\beta$)',
    ]):
        plt.sca(axs[row, 0])
        plt.ylabel(title)

    # Subject- and fix-specific keys are removed so the file name describes
    # the across-subject figure.
    dict_cache = dict_caches[0, 0]
    for k in ['fix', 'sbj']:
        dict_cache.pop(k, None)

    from lib.pylabyk.cacheutil import mkdir4file
    for ext in ['.pdf', '.png']:
        file = locfile.get_file_fig('coefs_dur_odif_sbjs',
                                    dict_cache,
                                    subdir=subdir,
                                    ext=ext)
        mkdir4file(file)
        plt.savefig(file, dpi=300)
        print('Saved to %s' % file)

    print('--')

    return gofs, models, data