Exemplo n.º 1
0
def plot_p_ch_vs_ev(ev_cond,
                    p_ch,
                    style='pred',
                    ax: plt.Axes = None,
                    **kwargs) -> plt.Line2D:
    """
    Plot the probability of choice 1 against evidence, one point per condition.

    @param ev_cond: [condition] or [condition, frame]; a 3-D input is first
        reduced with npt.p2st(...)[0] (presumably selecting one component of
        its last axis -- confirm with npt.p2st)
    @type ev_cond: torch.Tensor
    @param p_ch: [condition, ch] or [condition, rt_frame, ch]
    @type p_ch: torch.Tensor
    @param style: forwarded to get_kw_plot to build the plot keyword args
    @param ax: axes to draw on; the current axes when None
    @param kwargs: extra keyword args forwarded to get_kw_plot
    @return: the value returned by ax.plot
    """
    axes = plt.gca() if ax is None else ax

    # Reduce the evidence to a single value per condition.
    if ev_cond.ndim == 3:
        ev_cond = npt.p2st(ev_cond)[0]
    if ev_cond.ndim != 1:
        assert ev_cond.ndim == 2
        ev_cond = ev_cond.mean(1)

    # Marginalize over RT frames when they are present.
    if p_ch.ndim == 3:
        p_ch = p_ch.sum(1)
    assert p_ch.ndim == 2

    plot_kw = get_kw_plot(style, **kwargs)

    # Normalize across choices, then take the probability of choice index 1.
    prob_ch1 = npt.p2st(npt.sumto1(p_ch, -1))[1]
    h = axes.plot(*npys(ev_cond, prob_ch1), **plot_kw)

    plt2.box_off(ax=axes)
    x_lo, x_hi = axes.get_xlim()
    plt2.detach_axis('x', amin=x_lo, amax=x_hi, ax=axes)
    plt2.detach_axis('y', amin=0, amax=1, ax=axes)
    axes.set_yticks([0, 0.5, 1])
    axes.set_yticklabels(['0', '', '1'])
    axes.set_xlabel('evidence')
    axes.set_ylabel(r"$\mathrm{P}(z=1 \mid c)$")
    return h
Exemplo n.º 2
0
def beautify(ax):
    """
    Apply the standard x-range and spine styling to ``ax``.

    Makes ``ax`` the current axes, because the plt2 helpers below are called
    without an explicit ``ax`` (presumably acting on the current axes).
    It then extends the right x limit 5% past 1.2, detaches the x axis over
    [0, 1.2], and calls plt2.box_off().
    """
    plt.sca(ax)
    plt.xlim(xmax=1.2 * 1.05)
    plt2.detach_axis('x', 0, 1.2)
    plt2.box_off()
Exemplo n.º 3
0
    def plot_params(self,
                    named_bounded_params: Sequence[Tuple[
                        str, BoundedParameter]] = None,
                    exclude: Iterable[str] = (),
                    cmap='coolwarm',
                    ax: plt.Axes = None) -> mpl.container.BarContainer:
        """
        Draw one horizontal bar per bounded parameter.

        Each bar's length is the parameter's position within [lb, ub], and
        its color encodes the gradient. Each row is annotated with the lower
        bound, the value together with the gradient's order of magnitude
        (or "(fixed)"), and the upper bound.

        :param named_bounded_params: forwarded to
            self.get_named_bounded_params (None presumably means "all").
        :param exclude: parameter names to leave out (forwarded).
        :param cmap: colormap name used to color bars by gradient.
        :param ax: axes to draw on; the current axes when None. Previously a
            supplied ``ax`` was ignored.
        :return: the BarContainer returned by ax.barh.
        """
        if ax is None:
            ax = plt.gca()

        names, v, grad, lb, ub, requires_grad = self.get_named_bounded_params(
            named_bounded_params, exclude=exclude)
        max_grad = np.amax(np.abs(grad))
        if max_grad == 0:
            max_grad = 1.
        # Position of each value within its [lb, ub] range, in [0, 1].
        v01 = (v - lb) / (ub - lb)
        # Map gradients from [-max_grad, max_grad] onto [0, 1] for the colormap.
        grad01 = (grad + max_grad) / (max_grad * 2)
        n = len(v)

        # Fixed parameters have no meaningful gradient color; their bar color
        # is overwritten below.
        grad01[~requires_grad] = np.nan

        for i, (lb1, v1, ub1, g1, r1) in enumerate(
                zip(lb, v, ub, grad, requires_grad)):
            color = 'k' if r1 else 'gray'
            ax.text(0, i, '%1.2g' % lb1, ha='left', va='center', color=color)
            ax.text(1, i, '%1.2g' % ub1, ha='right', va='center', color=color)
            # Trainable params show the gradient's order of magnitude.
            grad_str = ('(e%1.0f)' % np.log10(np.abs(g1))) if r1 else '(fixed)'
            ax.text(0.5,
                    i,
                    '%1.2g %s' % (v1, grad_str),
                    ha='center',
                    va='center',
                    color=color)

        lut = 256
        colors = plt.get_cmap(cmap, lut)(grad01)
        for i, r in enumerate(requires_grad):
            if not r:
                # NaN maps to the colormap's "bad" color, which is fully
                # transparent by default. Set alpha as well as RGB so the
                # light-gray bar is actually visible.
                colors[i, :] = np.array([0.95, 0.95, 0.95, 1.])
        h = ax.barh(np.arange(n), v01, left=0, color=colors)
        ax.set_xlim(-0.025, 1)
        ax.set_xticks([])
        ax.set_yticks(np.arange(n))

        ax.set_yticklabels([name.replace('_', '-') for name in names])
        ax.set_ylim(n - 0.5, -0.5)
        # Pass ax explicitly so the styling lands on the axes drawn above.
        plt2.box_off(['top', 'right', 'bottom'], ax=ax)
        plt2.detach_axis('x', amin=0, amax=1, ax=ax)
        plt2.detach_axis('y', amin=0, amax=n - 1, ax=ax)

        return h
Exemplo n.º 4
0
    def beautify(row, col, axs):
        """
        Style panel (row, col) of ``axs`` and make it the current axes.

        Uses ``n_row`` and ``beautify_ticks`` from the enclosing scope.
        Tick labels are requested only for the bottom-left panel. The
        y label is cleared on every panel except those in column 0.
        """
        panel = axs[row, col]
        plt.sca(panel)
        plt.xlim(xmax=1.2 * 1.05)
        plt2.detach_axis('x', 0, 1.2)
        plt2.box_off()

        is_bottom_left = row == n_row - 1 and col == 0
        beautify_ticks(panel, add_ticklabel=is_bottom_left)

        if col > 0:
            plt.ylabel('')
Exemplo n.º 5
0
            def plot_bars(gs1, bufdurs, losses1, add_xticklabel=True):
                """
                Bar-plot ``losses1[i_subj, :]`` by duration index on broken axes.

                The x axis breaks after the last duration below 0.3 s, and
                the y axis breaks at 20. ``i_subj``, ``thres_strong`` and
                ``brokenaxes`` come from the enclosing scope. Y tick labels
                are kept only for subject 0. X tick labels are kept only for
                subject 0 and only when ``add_xticklabel``.

                :return: the brokenaxes object.
                """
                is_short = np.array(bufdurs) < 0.3
                i_break = np.amax(np.nonzero(is_short)[0])
                x_split = i_break + 0.5
                bax = brokenaxes(
                    subplot_spec=gs1,
                    xlims=((-1, x_split), (x_split, len(bufdurs) - 0.5)),
                    ylims=((-3, 20), (20, 1250 / 5)),
                    height_ratios=(50 / 100, 500 / (1250 - 100)),
                    hspace=0.15,
                    wspace=0.075,
                    d=0.005,
                )
                bax.bar(np.arange(len(bufdurs)), losses1[i_subj, :], color='k')

                show_xlabels = i_subj == 0 and add_xticklabel

                # bax.axs is presumably row-major: 0 = top-left,
                # 2 = bottom-left, 3 = bottom-right.
                ax_br = bax.axs[3]  # type: plt.Axes
                ax_br.set_xticks([bufdurs.index(0.6), bufdurs.index(1.2)])
                ax_br.set_xticklabels(['0.6', '1.2'] if show_xlabels else [])

                ax_tl = bax.axs[0]  # type: plt.Axes
                ax_tl.set_yticks([500, 1000])

                ax_bl = bax.axs[2]  # type: plt.Axes
                ax_bl.set_yticks([0, 50])
                plt.sca(ax_bl)
                plt2.detach_axis('x', amin=-0.4, amax=x_split)

                # Reference lines at zero and at +/- thres_strong.
                for ax_low in (ax_bl, ax_br):
                    plt.sca(ax_low)
                    plt.axhline(0, linewidth=0.5, color='k', linestyle='--')
                    for level in (-thres_strong, thres_strong):
                        plt.axhline(level,
                                    linewidth=0.5,
                                    color='silver',
                                    linestyle='--')

                ax_bl.set_xticks([bufdurs.index(0.), bufdurs.index(0.2)])
                ax_bl.set_xticklabels(['0', '0.2'] if show_xlabels else [])
                if i_subj != 0:
                    ax_bl.set_yticklabels([])
                    ax_tl.set_yticklabels([])
                return bax
Exemplo n.º 6
0
    def plot_bound(self,
                   t_all: Sequence[float] = None,
                   ax: plt.Axes = None,
                   **kwargs) -> plt.Line2D:
        """
        Plot the bound b(t) from self.get_bound against time.

        :param t_all: time points; self.t_all when None.
        :param ax: axes to draw on; the current axes when None.
        :param kwargs: forwarded to ax.plot after argsutil.kwdefault supplies
            color='k' and linestyle='-'.
        :return: the value returned by ax.plot.
        """
        if ax is None:
            ax = plt.gca()
        if t_all is None:
            t_all = self.t_all

        kwargs = argsutil.kwdefault(kwargs, color='k', linestyle='-')
        h = ax.plot(*npys(t_all, self.get_bound(t_all)), **kwargs)
        ax.set_xlabel('time (s)')
        ax.set_ylabel(r"$b(t)$")
        # Put the lower y limit 5% of the upper limit below zero.
        y_lim = ax.get_ylim()
        y_min = -y_lim[1] * 0.05
        ax.set_ylim(bottom=y_min)
        # Pass ax explicitly. Without it these helpers styled the current
        # axes, which need not be the ``ax`` that was drawn on.
        plt2.detach_axis('y', amin=0, ax=ax)
        plt2.detach_axis('x', amin=0, ax=ax)
        plt2.box_off(ax=ax)
        return h
Exemplo n.º 7
0
    def plot_p_tnd(self,
                   t_all: Sequence[float] = None,
                   ax: plt.Axes = None,
                   **kwargs) -> Union[plt.Line2D, Sequence[plt.Line2D]]:
        """
        Plot self.get_p_tnd(t_all) against time.

        :param t_all: time points; self.t_all when None.
        :param ax: axes to draw on; the current axes when None.
        :param kwargs: forwarded to ax.plot after argsutil.kwdefault supplies
            color='k' and linestyle='-'.
        :return: the value returned by ax.plot.
        """
        t_all = self.t_all if t_all is None else t_all
        ax = plt.gca() if ax is None else ax

        density = self.get_p_tnd(t_all)

        plot_kw = argsutil.kwdefault(kwargs, color='k', linestyle='-')
        lines = ax.plot(*npys(t_all, density), **plot_kw)

        ax.set_xlabel('time (s)')
        plt2.box_off(ax=ax)
        # Detach the y axis over [0, 1] and the x axis from 0 onward.
        plt2.detach_axis('y', amin=0, amax=1, ax=ax)
        plt2.detach_axis('x', amin=0, ax=ax)
        return lines
Exemplo n.º 8
0
def plot_rt_distrib(
    n_cond_rt_ch: np.ndarray,
    ev_cond_dim: np.ndarray,
    abs_cond=True,
    lump_wrong=True,
    dt=consts.DT,
    colors=None,
    alpha=1.,
    alpha_face=0.5,
    smooth_sigma_sec=0.05,
    to_normalize_max=False,
    to_cumsum=False,
    to_exclude_last_frame=True,
    to_skip_zero_trials=False,
    label='',
    # to_exclude_bins_wo_trials=10,
    kw_plot=(),
    fig=None,
    axs=None,
    to_use_sameaxes=True,
):
    """
    Plot RT distributions in a grid of panels, one per pair of absolute
    evidence levels on the two dimensions.

    :param n_cond_rt_ch: trial counts [condition, rt_frame, flat choice].
        Once flattened, it must line up element-for-element with a meshgrid
        over (unique ev_cond_dim[:, 0], unique ev_cond_dim[:, 1], frame,
        ch0 in {0, 1}, ch1 in {0, 1}). Presumably the conditions are ordered
        that way -- confirm with the caller.
    :param ev_cond_dim: [condition, 2] signed evidence per dimension.
    :param abs_cond: must be True. Conditions are folded by absolute value
        and choices recoded relative to the evidence sign. False raises
        ValueError.
    :param lump_wrong: if True, ch0 is replaced by "both dimensions'
        recoded choices are 1" (zero-evidence choices count as 1), ch1 is
        set to 1, and only ch1 == 1 is drawn.
    :param dt: frame duration (s).
    :param colors: None (['red', 'blue']), a single color string, or two
        colors indexed by ch1.
    :param alpha: line alpha.
    :param alpha_face: alpha of the area filled between 0 and each line.
    :param smooth_sigma_sec: SD (s) of the Gaussian kernel used to smooth
        along time; <= 0 disables smoothing.
    :param to_normalize_max: divide each panel by its max absolute value.
    :param to_cumsum: plot cumulative sums along time.
    :param to_exclude_last_frame: zero the counts in the last RT frame.
    :param to_skip_zero_trials: skip lines whose summed |y| is below 1e-2.
    :param label: legend label attached to the ch0 == 1, ch1 == 1 line.
    :param kw_plot: extra keyword args for ax.plot (pairs or a dict).
    :param fig: unused.
    :param axs: 2-D grid of axes; created with plt2.GridAxes when None.
    :param to_use_sameaxes: call plt2.sameaxes on the grid.
    :return: axs, p_cond01__rt_ch01, p_cond01__rt_ch01_sm, h. Here h is the
        handle from the last ax.plot call, or None if that line was skipped
        or nothing was drawn.
    """
    if colors is None:
        colors = ['red', 'blue']
    elif type(colors) is str:
        colors = [colors] * 2
    else:
        assert len(colors) == 2

    nt = n_cond_rt_ch.shape[1]
    t_all = np.arange(nt) * dt + dt

    # Labels for every (cond dim0, cond dim1, frame, ch0, ch1) combination,
    # in the same C order as n0.flatten() below.
    out = np.meshgrid(np.unique(ev_cond_dim[:, 0]),
                      np.unique(ev_cond_dim[:, 1]),
                      np.arange(nt),
                      np.arange(2),
                      np.arange(2),
                      indexing='ij')
    cond0, cond1, fr, ch0, ch1 = [v.flatten() for v in out]

    from copy import deepcopy
    n0 = deepcopy(n_cond_rt_ch)
    if to_exclude_last_frame:
        n0[:, -1, :] = 0.
    n0 = n0.flatten()

    def sign_cond(v):
        # Like np.sign, but zero evidence counts as positive.
        v1 = np.sign(v)
        v1[v == 0] = 1
        return v1

    if abs_cond:
        ch0 = consts.ch_bool2sign(ch0)
        ch1 = consts.ch_bool2sign(ch1)

        # 1 = correct, -1 = wrong
        ch0 = sign_cond(cond0) * ch0
        ch1 = sign_cond(cond1) * ch1

        cond0 = np.abs(cond0)
        cond1 = np.abs(cond1)

        # np.int was removed in NumPy 1.24; it was an alias of builtin int.
        ch0 = consts.ch_sign2bool(ch0).astype(int)
        ch1 = consts.ch_sign2bool(ch1).astype(int)
    else:
        raise ValueError('abs_cond=False is not supported')

    if lump_wrong:
        # treat all choices as correct when cond == 0
        ch00 = ch0 | (cond0 == 0)
        ch10 = ch1 | (cond1 == 0)

        ch0 = (ch00 & ch10)
        ch1 = np.ones_like(ch00, dtype=int)

    cond_dim = np.stack([cond0, cond1], -1)

    # Re-index each (absolute) condition dimension to 0..n-1.
    conds = []
    dcond_dim = []
    for cond in cond_dim.T:
        conds1, dcond1 = np.unique(cond, return_inverse=True)
        conds.append(conds1)
        dcond_dim.append(dcond1)
    dcond_dim = np.stack(dcond_dim)

    n_cond01_rt_ch01 = npg.aggregate(
        [*dcond_dim, fr, ch0, ch1], n0, 'sum',
        [*(np.amax(dcond_dim, 1) + 1), nt, consts.N_CH, consts.N_CH])

    # Normalize within each condition pair, over frames and both choices.
    p_cond01__rt_ch01 = np2.nan2v(n_cond01_rt_ch01 / n_cond01_rt_ch01.sum(
        (2, 3, 4), keepdims=True))

    n_conds = p_cond01__rt_ch01.shape[:2]

    if axs is None:
        axs = plt2.GridAxes(
            n_conds[1],
            n_conds[0],
            left=0.6,
            right=0.3,
            bottom=0.45,
            top=0.74,
            widths=[1],
            heights=[1],
            wspace=0.04,
            hspace=0.04,
        )

    kw_label = {
        'fontsize': 12,
    }
    pad = 8
    axs[0, 0].set_title('strong\nmotion', pad=pad, **kw_label)
    axs[0, -1].set_title('weak\nmotion', pad=pad, **kw_label)
    axs[0, 0].set_ylabel('strong\ncolor', labelpad=pad, **kw_label)
    axs[-1, 0].set_ylabel('weak\ncolor', labelpad=pad, **kw_label)

    if smooth_sigma_sec > 0:
        from scipy import signal, stats
        sigma_fr = smooth_sigma_sec / dt
        width = np.ceil(sigma_fr * 2.5).astype(int)
        kernel = stats.norm.pdf(np.arange(-width, width + 1), 0, sigma_fr)
        # Place the 1-D kernel on the time axis (axis 2 of 5).
        kernel = np2.vec_on(kernel, 2, 5)
        p_cond01__rt_ch01_sm = signal.convolve(p_cond01__rt_ch01,
                                               kernel,
                                               mode='same')
    else:
        p_cond01__rt_ch01_sm = p_cond01__rt_ch01.copy()

    if to_cumsum:
        p_cond01__rt_ch01_sm = np.cumsum(p_cond01__rt_ch01_sm, axis=2)

    if to_normalize_max:
        p_cond01__rt_ch01_sm = np2.nan2v(
            p_cond01__rt_ch01_sm /
            np.amax(np.abs(p_cond01__rt_ch01_sm), (2, 3, 4), keepdims=True))

    n_row = n_conds[1]
    n_col = n_conds[0]
    h = None  # stays None if no line is drawn
    for dcond0 in range(n_conds[0]):
        for dcond1 in range(n_conds[1]):
            row = n_row - 1 - dcond1
            col = n_col - 1 - dcond0

            ax = axs[row, col]  # type: plt.Axes

            for i_ch0 in [0, 1]:
                for i_ch1 in [0, 1]:
                    if lump_wrong and i_ch1 == 0:
                        continue

                    p1 = p_cond01__rt_ch01_sm[dcond0, dcond1, :, i_ch0, i_ch1]

                    kw = {
                        'linewidth': 1,
                        'color': colors[i_ch1],
                        'alpha': alpha,
                        'zorder': 1,
                        **dict(kw_plot)
                    }

                    # ch0 == 0 is drawn with the opposite sign
                    # (per consts.ch_bool2sign).
                    y = p1 * consts.ch_bool2sign(i_ch0)

                    p_cond01__rt_ch01_sm[dcond0, dcond1, :, i_ch0, i_ch1] = y

                    if to_skip_zero_trials and np.sum(np.abs(y)) < 1e-2:
                        h = None
                    else:
                        h = ax.plot(
                            t_all,
                            y,
                            label=label if i_ch0 == 1 and i_ch1 == 1 else None,
                            **kw)
                        ax.fill_between(t_all,
                                        0,
                                        y,
                                        ec='None',
                                        fc=kw['color'],
                                        alpha=alpha_face,
                                        zorder=-1)
                    plt2.box_off(ax=ax)

            ax.axhline(0, color='k', linewidth=0.5)
            ax.set_yticklabels([])
            if row < n_row - 1 or col > 0:
                ax.set_xticklabels([])
                ax.set_xticks([])
                plt2.box_off(['bottom'], ax=ax)
            else:
                ax.set_xlabel('RT (s)')
            # if col > 0:
            ax.set_yticks([])
            ax.set_yticklabels([])
            plt2.box_off(['left'], ax=ax)

            plt2.detach_axis('x', 0, 5, ax=ax)
    if to_use_sameaxes:
        plt2.sameaxes(axs)
    axs[-1, 0].set_xlabel('RT (s)')

    return axs, p_cond01__rt_ch01, p_cond01__rt_ch01_sm, h
Exemplo n.º 9
0
def plot_p_ch_vs_ev(
        ev_cond: Union[torch.Tensor, np.ndarray],
        n_ch: Union[torch.Tensor, np.ndarray],
        style='pred',
        ax: plt.Axes = None,
        dim_rel=0,
        group_dcond_irr: Iterable[Iterable[int]] = None,
        cmap: Union[str, Callable] = 'cool',
        kw_plot=(),
) -> Iterable[plt.Line2D]:
    """
    Plot the fraction of choice 1 on the relevant dimension against its
    evidence. One line is drawn per absolute evidence level of the other
    (irrelevant) dimension.

    @param ev_cond: [condition, dim] or [condition, frame, dim, (mean, var)]
    @type ev_cond: torch.Tensor
    @param n_ch: [condition, ch] or [condition, rt_frame, ch]
    @type n_ch: torch.Tensor
    @param style: forwarded to get_kw_plot
    @param ax: axes to draw on; the current axes when None
    @param dim_rel: index of the relevant dimension
    @param group_dcond_irr: if given, irrelevant levels are merged via
        group_conds
    @param cmap: colormap name, or a callable that takes the number of lines
        and returns a colormap
    @param kw_plot: extra plot keyword args (pairs or a dict)
    @return: hs[cond_irr][0] = Line2D, conds_irr
    """
    if ax is None:
        ax = plt.gca()

    if ev_cond.ndim != 2:
        assert ev_cond.ndim == 4
        # Average over frames, then take component 0 of the last axis
        # (presumably the mean).
        ev_cond = npt.p2st(ev_cond.mean(1))[0]
    if n_ch.ndim != 2:
        assert n_ch.ndim == 3
        n_ch = n_ch.sum(1)

    ev = npy(ev_cond)
    counts = npy(n_ch)
    n_cond_all = counts.shape[0]
    # Relevant-dimension choice for each flat choice column, per condition.
    ch_rel = np.tile(np.array(consts.CHS[dim_rel]), (n_cond_all, 1)).reshape([-1])
    counts = counts.reshape([-1])

    dim_irr = consts.get_odim(dim_rel)
    conds_rel, dcond_rel = np.unique(ev[:, dim_rel], return_inverse=True)
    conds_irr, dcond_irr = np.unique(np.abs(ev[:, dim_irr]),
                                     return_inverse=True)
    if group_dcond_irr is not None:
        conds_irr, dcond_irr = group_conds(conds_irr, dcond_irr,
                                           group_dcond_irr)

    n_rel = len(conds_rel)
    n_irr = len(conds_irr)

    def _per_flat_ch(dcond):
        # Repeat each condition index once per flat choice column.
        return np.repeat(dcond[:, None], consts.N_CH_FLAT, 1).flatten()

    index = np.stack([ch_rel, _per_flat_ch(dcond_irr), _per_flat_ch(dcond_rel)])
    n_ch_rel = npg.aggregate(index, counts, 'sum',
                             [consts.N_CH, n_irr, n_rel])
    # Fraction choosing option 1, indexed [cond_irr, cond_rel].
    p_ch_rel = n_ch_rel[1] / n_ch_rel.sum(0)

    hs = []
    for i_irr, p_row in enumerate(p_ch_rel):
        colormap = (plt.get_cmap(cmap, n_irr)
                    if type(cmap) is str else cmap(n_irr))
        line_kw = get_kw_plot(style, color=colormap(i_irr), **dict(kw_plot))
        hs.append(ax.plot(conds_rel, p_row, **line_kw))

    plt2.box_off(ax=ax)
    x_lo, x_hi = ax.get_xlim()
    plt2.detach_axis('x', amin=x_lo, amax=x_hi, ax=ax)
    plt2.detach_axis('y', amin=0, amax=1, ax=ax)
    ax.set_yticks([0, 0.5, 1])
    ax.set_yticklabels(['0', '', '1'])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{P}(z=1 \mid c)$")

    return hs, conds_irr
Exemplo n.º 10
0
def plot_coef_by_dur_vs_odif(coef,
                             se_coef,
                             normalize_bias=True,
                             coef_name='slope',
                             savefig=True,
                             difs=[[2, 1], [0]],
                             t_RDK_durs=np.arange(1, 11) * 0.12,
                             fig=None,
                             horizontal_panels=False):
    """
    Plot a regression coefficient against stimulus duration.

    There is one panel per dimension and one error-bar line per difficulty
    group of the other dimension.

    @param coef: indexed coef[kind][dim][odif] -> values over t_RDK_durs,
        with kind 0 = bias and 1 = slope. This layout is inferred from the
        indexing below; confirm with the caller. The caller's object is not
        modified.
    @param se_coef: standard errors, same layout as coef; not modified.
    @param normalize_bias: if True, re-express the bias as -bias / slope
        with unit '(coh)'. Its SE is scaled by 1 / |slope|.
    @param coef_name: 'bias' or 'slope'.
    @param savefig: unused.
    @param difs: difficulty groups; only len(difs) (1, 2 or 3) is used, to
        choose the line labels.
    @param t_RDK_durs: stimulus durations used as x values.
    @param fig: figure to draw into; a new one is created when None.
    @param horizontal_panels: lay the panels out side by side instead of
        stacked.
    @return: the GridSpec holding the panels.
    @raise ValueError: if len(difs) is not 1, 2 or 3, or coef_name is unknown.
    """
    from copy import deepcopy

    # Work on copies. Normalizing the bias used to overwrite the caller's
    # arrays, so a second call with the same data normalized twice.
    coef = deepcopy(coef)
    se_coef = deepcopy(se_coef)

    if fig is None:
        if horizontal_panels:
            fig = plt.figure(figsize=(7, 2))
        else:
            fig = plt.figure(figsize=(4, 3))
    if horizontal_panels:
        gs = plt.GridSpec(figure=fig,
                          nrows=1,
                          ncols=2,
                          hspace=0.3,
                          left=0.11,
                          right=0.98,
                          top=0.9,
                          bottom=0.15)
    else:
        gs = mpl.gridspec.GridSpec(figure=fig,
                                   nrows=2,
                                   ncols=1,
                                   left=0.22,
                                   right=0.98,
                                   top=0.9,
                                   bottom=0.15)

    coef_names = ['bias', 'slope']
    i_coef = coef_names.index(coef_name)

    units = ['(logit)', '(logit/coh)']
    if normalize_bias:
        coef[0] = -coef[0] / coef[1]
        # A standard error must be non-negative. The old -se / slope was
        # negative for positive slopes, and newer matplotlib rejects
        # negative yerr.
        se_coef[0] = np.abs(se_coef[0] / coef[1])
        units[0] = '(coh)'
    unit = units[i_coef]

    y = coef[i_coef]
    se = se_coef[i_coef]
    labels_by_n_difs = {
        1: ['all'],
        2: ['easy', 'hard'],
        3: ['easy', 'medium', 'hard'],
    }
    if len(difs) not in labels_by_n_difs:
        raise ValueError('difs must contain 1 to 3 difficulty groups; got %d' %
                         len(difs))
    labels = labels_by_n_difs[len(difs)]

    for dim, (slope1, se_slope1) in enumerate(zip(y, se)):
        # Index of the other dimension, whose difficulty groups the lines.
        odim = consts.N_DIM - 1 - dim
        if horizontal_panels:
            plt.subplot(gs[0, dim])
        else:
            plt.subplot(gs[dim, 0])

        for odif, (slope2, se_slope2) in enumerate(zip(slope1, se_slope1)):
            label = labels[odif] + ' ' + consts.DIM_NAMES_SHORT[odim]
            plt.errorbar(t_RDK_durs,
                         slope2,
                         se_slope2,
                         marker='o',
                         label=label)
        plt.ylabel(consts.DIM_NAMES_LONG[dim].lower() + ' ' + coef_name +
                   '\n' + unit)
        plt.axis('auto')
        plt.xlim(xmin=-0.02)
        if coef_name == 'slope':
            plt.ylim(ymin=-0.05)
        plt2.box_off()
        plt2.detach_axis('x')
        if dim == 0 and not horizontal_panels:
            # The stacked top panel shares the x axis with the bottom one.
            plt.gca().set_xticklabels([])
        plt.legend(handlelength=0.4, frameon=False, loc='upper left')
    plt.xlabel('stimulus duration (s)')
    return gs
Exemplo n.º 11
0
def main_fit(
        subjs=None,
        skip_fit_if_absent=False,
        bufdur=None,
        bufdur_best=0.12,
        bufdurs_sim=None,
        bufdurs_fit=None,
        loss_kind='NLL',
        plot_kind='line_sim_fit',
        fix_post=None,
        seed_sims=(0, ),
        err_kind='std',
        base=10,
):
    # if torch.cuda.is_available():
    #     torch.set_default_tensor_type(torch.cuda.FloatTensor)
    torch.set_num_threads(1)
    torch.set_default_dtype(torch.double)

    dict_res_add = {}

    parad = 'VD'
    seed_sims = np.array(seed_sims)

    if subjs is None:
        subjs = consts.SUBJS_VD

    if bufdur is not None:
        bufdurs_sim = list({bufdur_best}.union({bufdur}))
        bufdurs_fit = list({bufdur_best}.union({bufdur}))
    else:
        if bufdurs_sim is None:
            bufdurs_sim = bufdurs0  #
            if bufdurs_fit is None:
                bufdurs_fit = bufdurs0
        else:
            if bufdurs_fit is None:
                bufdurs_fit0 = [0., bufdur_best, 1.2]
                if bufdurs_sim[0] in bufdurs_fit0:
                    bufdurs_fit = copy(bufdurs_fit0)
                else:
                    bufdurs_fit = bufdurs_fit0[:2] + bufdurs_sim + [1.2]

    n_dur_sim = len(bufdurs_sim)
    n_dur_fit = len(bufdurs_fit)

    # DEF: losses[seed, subj, simSerDurPar, fitDurs]
    size_all = [len(seed_sims), len(subjs), n_dur_sim, n_dur_fit]
    losses0 = np.zeros(size_all) + np.nan
    ns = np.zeros(size_all) + np.nan
    ks0 = np.zeros(size_all) + np.nan

    for i_seed, seed_sim in enumerate(seed_sims):
        for i_subj, subj in enumerate(subjs):
            for i_fit, bufdur_fit in enumerate(bufdurs_fit):
                for i_sim, bufdur_sim in enumerate(bufdurs_sim):
                    d, dict_fit_sim, dict_subdir_sim = get_fit_sim(
                        subj,
                        seed_sim,
                        bufdur_sim,
                        bufdur_fit,
                        parad=parad,
                        skip_fit_if_absent=skip_fit_if_absent,
                        fix_post=fix_post,
                    )
                    if d is not None:
                        losses0[i_seed, i_subj, i_sim,
                                i_fit] = d['loss_NLL_test']
                        ns[i_seed, i_subj, i_sim, i_fit] = d['loss_ndata_test']
                        ks0[i_seed, i_subj, i_sim, i_fit] = d['loss_nparam']
    ks = copy(ks0)

    if loss_kind == 'BIC':
        losses = ks * np.log(ns) + 2 * losses0

        # REF: Kass, Raftery 1995 https://doi.org/10.2307%2F2291091
        #   https://en.wikipedia.org/wiki/Bayesian_information_criterion#Gaussian_special_case
        thres_strong = 10. / np.log(base)
    elif loss_kind == 'NLL':
        losses = copy(losses0)
        thres_strong = np.log(100) / np.log(base)
    else:
        raise ValueError('Unsupported loss_kind: %s' % loss_kind)

    #%%
    losses = losses / np.log(base)

    if plot_kind == 'loss_serpar':
        n_row = len(subjs)
        n_col = 1
        axs = plt2.GridAxes(n_row, n_col, left=1, widths=2, bottom=0.75)
        for i_subj, subj in enumerate(subjs):
            plt.sca(axs[i_subj, 0])

            loss_dur_sim_ser_fit = losses[:, i_subj, :, 0].mean(0)
            loss_dur_sim_par_fit = losses[:, i_subj, :, -1].mean(0)

            dloss_serpar = loss_dur_sim_par_fit - loss_dur_sim_ser_fit

            plt.bar(bufdurs_sim, dloss_serpar, color='k')
            plt.axhline(0, color='k', linestyle='-', linewidth=0.5)

            if i_subj < len(subjs) - 1:
                plt2.box_off(['top', 'right', 'bottom'])
                plt.xticks([])
            else:
                plt2.box_off()
                plt.xticks(np.arange(0, 1.4, 0.2), ['0\nser'] + [''] * 2 +
                           ['0.6'] + [''] * 2 + ['1.2\npar'])
                plt2.detach_axis('x', 0, 1.2)
            vdfit.patch_chance_level()

        plt.sca(axs[0, 0])
        plt.title('Support for Serial\n'
                  '($\mathcal{L}_\mathrm{ser} - \mathcal{L}_\mathrm{par}$)')

        plt.sca(axs[-1, 0])
        plt.xlabel('true buffer duration (s)')

        plt2.rowtitle(consts.SUBJS_VD, axs)
    elif plot_kind == 'imshow_ser_buf_par':
        n_row = len(subjs)
        n_col = 1
        axs = plt2.GridAxes(n_row,
                            n_col,
                            left=1,
                            widths=2,
                            heights=2,
                            bottom=0.75)
        for i_subj, subj in enumerate(subjs):
            plt.sca(axs[i_subj, 0])

            dloss = losses[:, i_subj, :, :].mean(0)
            dloss = dloss - np.diag(dloss)[:, None]

            plt.imshow(dloss, cmap='bwr')

            cmax = np.amax(np.abs(dloss))
            plt.clim(-cmax, +cmax)

        plt.sca(axs[0, 0])

        plt.sca(axs[-1, 0])
        plt.xlabel('true buffer duration (s)')
        plt.ylabel('model buffer duration (s)')

        plt2.rowtitle(consts.SUBJS_VD, axs)

    elif plot_kind == 'line_sim_fit':
        bufdur_bests = np.array([get_bufdur_best(subj) for subj in subjs])
        i_bests = np.array([
            list(bufdurs_fit).index(bufdur_best)
            for bufdur_best in bufdur_bests
        ])
        i_subjs = np.arange(len(subjs))
        loss_sim_best_fit_best = losses[:, i_subjs, i_bests, i_bests]
        loss_sim_best_fit_rest = (losses[:, i_subjs, i_bests, :] -
                                  loss_sim_best_fit_best[:, :, None])
        mean_loss_sim_best_fit_rest = np.mean(loss_sim_best_fit_rest, 0)
        if err_kind == 'std':
            err_loss_sim_best_fit_rest = np.std(loss_sim_best_fit_rest, 0)
        elif err_kind == 'sem':
            err_loss_sim_best_fit_rest = np2.sem(loss_sim_best_fit_rest, 0)

        n_dur = len(bufdurs_fit)
        loss_sim_rest_fit_best = np.swapaxes(
            np.stack([
                losses[:, i_subj, np.arange(n_dur), i_best] -
                losses[:, i_subj,
                       np.arange(n_dur),
                       np.arange(n_dur)]
                for i_subj, i_best in zip(i_subjs, i_bests)
            ]), 0, 1)
        mean_loss_sim_rest_fit_best = np.mean(loss_sim_rest_fit_best, 0)
        if err_kind == 'std':
            err_loss_sim_rest_fit_best = np.std(loss_sim_rest_fit_best, 0)
        elif err_kind == 'sem':
            err_loss_sim_rest_fit_best = np2.sem(loss_sim_rest_fit_best, 0)

        dict_res_add['err'] = err_kind

        n_row = 2
        n_col = len(subjs)

        axs = plt2.GridAxes(n_row,
                            n_col,
                            left=1.15,
                            right=0.1,
                            widths=2,
                            heights=1.5,
                            wspace=0.35,
                            hspace=0.5,
                            top=0.25,
                            bottom=0.6)

        for row, (m, s, xlabel, ylabel) in enumerate([
            (mean_loss_sim_best_fit_rest, err_loss_sim_best_fit_rest,
             'model buffer capacity (s)',
             (r'$-\mathrm{log}_{%g}\mathrm{BF}$ given simulated' +
              '\nbest duration data') % base),
            (mean_loss_sim_rest_fit_best, err_loss_sim_rest_fit_best,
             'simulated data buffer capacity (s)',
             (r'$-\mathrm{log}_{%g}\mathrm{BF}$ of the' +
              '\nbest duration model') % base)
        ]):
            for i_subj, subj in enumerate(subjs):
                ax = axs[row, i_subj]
                plt.sca(ax)
                gs1 = axs.gs[row * 2 + 1, i_subj * 2 + 1]

                bax = vdfit.breakaxis(gs1)
                bax.axs[1].errorbar(bufdurs0,
                                    m[i_subj, :],
                                    yerr=s[i_subj, :],
                                    color='k',
                                    marker='.',
                                    linewidth=0.75,
                                    elinewidth=0.5,
                                    markersize=3)
                m1 = copy(m[i_subj, :])
                m1[3:] = np.nan
                s1 = copy(s[i_subj, :])
                s1[3:] = np.nan
                bax.axs[0].errorbar(bufdurs0,
                                    m1,
                                    yerr=s1,
                                    color='k',
                                    marker='.',
                                    linewidth=0.75,
                                    elinewidth=0.5,
                                    markersize=3)

                ax1 = bax.axs[1]  # type: plt.Axes
                plt.sca(ax1)
                vdfit.patch_chance_level(level=thres_strong, signs=[-1, 1])
                plt.axhline(0, color='k', linestyle='--', linewidth=0.5)
                vdfit.beautify_ticks(ax1, )
                vdfit.beautify(ax1)
                ax1.set_yticks([0, 20])
                if i_subj > 0:
                    ax1.set_yticklabels([])
                if row == 0:
                    ax1.set_xticklabels([])

                ax1 = bax.axs[0]  # type: plt.Axes
                plt.sca(ax1)
                ax1.set_yticks([40, 200])
                if i_subj > 0:
                    ax1.set_yticklabels([])

                plt.sca(ax)
                plt2.box_off('all')
                if row == 0:
                    plt.title(consts.SUBJS['VD'][i_subj])

                if i_subj == 0:
                    plt.xlabel(xlabel, labelpad=8 if row == 0 else 20)
                    plt.ylabel(ylabel, labelpad=30)

                plt2.sameaxes(bax.axs, xy='x')

    elif plot_kind == 'bar_sim_fit':
        i_best = bufdurs_fit.index(bufdur_best)
        loss_sim_best_fit_best = losses[0, :, i_best, i_best]
        loss_sim_best_fit_rest = np2.nan2v(losses[0, :, i_best, :] -
                                           loss_sim_best_fit_best[:, None])
        n_dur = len(bufdurs_fit)
        loss_sim_rest_fit_best = np2.nan2v(
            losses[0, :, np.arange(n_dur), i_best] -
            losses[0, :, np.arange(n_dur),
                   np.arange(n_dur)]).T

        n_row = 2
        n_col = len(subjs)

        axs = plt2.GridAxes(n_row,
                            n_col,
                            left=1,
                            widths=3,
                            right=0.25,
                            heights=1,
                            wspace=0.3,
                            hspace=0.6,
                            top=0.25,
                            bottom=0.6)
        for i_subj, subj in enumerate(subjs):

            def plot_bars(gs1, bufdurs, losses1, add_xticklabel=True):
                i_break = np.amax(np.nonzero(np.array(bufdurs) < 0.3)[0])
                bax = brokenaxes(
                    subplot_spec=gs1,
                    xlims=((-1, i_break + 0.5), (i_break + 0.5,
                                                 len(bufdurs) - 0.5)),
                    ylims=((-3, 20), (20, 1250 / 5)),
                    height_ratios=(50 / 100, 500 / (1250 - 100)),
                    hspace=0.15,
                    wspace=0.075,
                    d=0.005,
                )
                bax.bar(np.arange(len(bufdurs)), losses1[i_subj, :], color='k')

                ax11 = bax.axs[3]  # type: plt.Axes
                ax11.set_xticks([bufdurs.index(0.6), bufdurs.index(1.2)])
                if i_subj == 0 and add_xticklabel:
                    ax11.set_xticklabels(['0.6', '1.2'])
                else:
                    ax11.set_xticklabels([])

                ax00 = bax.axs[0]  # type: plt.Axes
                ax00.set_yticks([500, 1000])

                ax10 = bax.axs[2]  # type: plt.Axes
                ax10.set_yticks([0, 50])
                plt.sca(ax10)
                plt2.detach_axis('x', amin=-0.4, amax=i_break + 0.5)
                for ax in [ax10, ax11]:
                    plt.sca(ax)
                    plt.axhline(0, linewidth=0.5, color='k', linestyle='--')
                    for sign in [-1, 1]:
                        plt.axhline(sign * thres_strong,
                                    linewidth=0.5,
                                    color='silver',
                                    linestyle='--')
                ax10.set_xticks([bufdurs.index(0.), bufdurs.index(0.2)])
                if i_subj == 0:
                    if add_xticklabel:
                        ax10.set_xticklabels(['0', '0.2'])
                    else:
                        ax10.set_xticklabels([])
                else:
                    ax10.set_yticklabels([])
                    ax10.set_xticklabels([])
                    ax00.set_yticklabels([])
                return bax

            bax = plot_bars(axs.gs[1, i_subj * 2 + 1],
                            bufdurs_fit,
                            loss_sim_best_fit_rest,
                            add_xticklabel=False)

            ax = axs[0, i_subj]
            plt.sca(ax)
            plt2.box_off('all')
            plt.title(consts.SUBJS['VD'][consts.SUBJS['VD'].index(subj)])
            if i_subj == 0:
                ax.set_ylabel(
                    'misfit to simulated\nbest duration data\n'
                    r'($\Delta$BIC)',
                    labelpad=35)
                ax.set_xlabel('model buffer duration (s)', labelpad=8)

            plot_bars(axs.gs[3, i_subj * 2 + 1],
                      bufdurs_sim,
                      loss_sim_rest_fit_best,
                      add_xticklabel=True)
            ax = axs[1, i_subj]
            plt.sca(ax)
            plt2.box_off('all')
            if i_subj == 0:
                ax.set_ylabel(
                    'misfit of\nbest duration model\n'
                    r'($\Delta$BIC)',
                    labelpad=35)
                ax.set_xlabel('simulated data buffer duration (s)',
                              labelpad=20)

    elif plot_kind == 'bar_ser_buf_par':
        n_row = len(subjs)
        n_col = 1
        axs = plt2.GridAxes(n_row * n_dur_sim,
                            n_col,
                            left=1.25,
                            widths=1.5,
                            right=0.25,
                            heights=0.4,
                            hspace=[0.15] * (n_dur_sim - 1) + [0.5] + [0.15] *
                            (n_dur_sim - 1),
                            top=0.6,
                            bottom=0.5)
        row = -1
        for i_subj, subj in enumerate(subjs):
            for i_sim in range(n_dur_sim):
                row += 1
                plt.sca(axs[row, 0])

                loss1 = losses[:, i_subj, i_sim, :].mean(0)
                dloss = loss1 - loss1[i_sim]

                x = np.arange(n_dur_fit)
                for x1, dloss1 in zip(x, dloss):
                    plt.bar(x1, dloss1, color='r' if dloss1 > 0 else 'b')
                plt.axhline(0, color='k', linewidth=0.5)
                vdfit.patch_chance_level(6.)
                plt.ylim([-100, 100])
                plt.yticks([-100, 0, 100], [''] * 2)
                plt.ylabel('%g' % bufdurs_sim[i_sim],
                           rotation=0,
                           va='center',
                           ha='right')

                if i_sim == 0:
                    plt.title(subj)

                if i_sim < n_dur_sim - 1 or i_subj < len(subjs) - 1:
                    plt2.box_off(['top', 'bottom', 'right'])
                    plt.xticks([])
                else:
                    plt2.box_off(['top', 'right'])

        plt.sca(axs[-1, 0])
        plt.xticks(np.arange(n_dur_sim), ['%g' % v for v in bufdurs_sim])
        plt2.detach_axis('x', 0, n_dur_sim - 1)
        plt.xlabel('model buffer duration (s)', fontsize=10)

        c = axs[-1, 0].get_position().corners()
        plt.figtext(x=0.15,
                    y=np.mean(c[:2, 1]),
                    fontsize=10,
                    s='true\nbuffer\nduration\n(s)',
                    rotation=0,
                    va='center',
                    ha='center')

        plt.figtext(
            x=(1.25 + 1.5 / 2) / (1.25 + 1.5 + 0.25),
            y=0.98,
            s=r'$\mathrm{BIC} - \mathrm{BIC}_\mathrm{true}$',
            ha='center',
            va='top',
            fontsize=12,
        )

    if plot_kind != 'None':
        dict_res = deepcopy(dict_fit_sim)  # noqa
        for k in ['sbj', 'prd']:
            dict_res.pop(k)
        dict_res = {**dict_res, 'los': loss_kind}
        dict_res.update(dict_res_add)
        for ext in ['.png', '.pdf']:
            file = locfile.get_file_fig(plot_kind,
                                        dict_res,
                                        subdir=dict_subdir_sim,
                                        ext=ext)  # noqa
            plt.savefig(file, dpi=300)
            print('Saved to %s' % file)

    #%%
    print('--')
    return losses
Exemplo n.º 12
0
def plot_coefs_dur_odif(
        n_cond_dur_ch: np.ndarray = None,
        data: Data2DVD = None,
        dif_irrs: Sequence[Union[Iterable[int], int]] = ((0, 1), (2, )),
        ev_cond_dim: np.ndarray = None,
        durs: np.ndarray = None,
        style='data',
        kw_plot=(),
        jitter0=0.,
        coefs_to_plot=(0, 1),
        add_rowtitle=True,
        dim_incl=(0, 1),
        axs=None,
        fig=None,
):
    """
    Plot duration-dependent choice coefficients, one line per group in
    ``dif_irrs``.

    Rows of ``axs`` correspond to ``dim_incl``; columns to ``coefs_to_plot``.

    :param n_cond_dur_ch: histogram passed to
        ``coefdur.get_coefs_mesh_from_histogram``; defaults to
        ``npy(data.get_data_by_cond('all')[1])``.
    :param data: used only to fill in omitted ``ev_cond_dim``, ``durs``
        and ``n_cond_dur_ch``.
    :param dif_irrs: groups passed as ``dif_irrs`` to the coefficient fit;
        one line is drawn per group (presumably difficulty levels of the
        other dimension -- confirm in coefdur).
    :param ev_cond_dim: evidence per condition; defaults to
        ``data.ev_cond_dim``.
    :param durs: x values (durations); defaults to ``npy(data.durs)``.
    :param style: passed to ``consts.get_kw_plot``; styles starting with
        'data' are drawn as error bars, others as plain lines.
    :param kw_plot: extra plot kwargs (mapping or sequence of pairs).
    :param jitter0: horizontal offset between lines, as a fraction of
        ``durs[1] - durs[0]``.
    :param coefs_to_plot: indices into ('bias', 'slope').
    :param add_rowtitle: whether to add ``consts.DIM_NAMES_LONG`` row titles.
    :param dim_incl: dimensions to plot, one row each.
    :param axs: [len(dim_incl), len(coefs_to_plot)] array of axes; created
        when omitted.
    :param fig: figure for newly created axes; created when omitted.
    :return: coef, se_coef, axs, hs[coef, dim, dif]
    """

    if ev_cond_dim is None:
        ev_cond_dim = data.ev_cond_dim
    if durs is None:
        durs = npy(data.durs)
    if n_cond_dur_ch is None:
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])

    # coef / se_coef are indexed below as [coef, dim, dif, dur].
    coef, se_coef = coefdur.get_coefs_mesh_from_histogram(
        n_cond_dur_ch, ev_cond_dim=ev_cond_dim,
        dif_irrs=dif_irrs)[:2]  # type: (np.ndarray, np.ndarray)

    coef_names = ['bias', 'slope']
    n_coef = len(coefs_to_plot)
    n_dim = len(dim_incl)
    n_dif = len(dif_irrs)

    if axs is None:
        if fig is None:
            fig = plt.figure(figsize=[6, 1.5 * n_coef])

        n_row = n_dim
        n_col = n_coef
        gs = plt.GridSpec(nrows=n_row,
                          ncols=n_col,
                          figure=fig,
                          left=0.25,
                          right=0.95,
                          bottom=0.15,
                          top=0.9)
        # `np.object` was removed in NumPy 1.24; builtin `object` is the
        # equivalent dtype.
        axs = np.empty([n_row, n_col], dtype=object)
        for row in range(n_row):
            for col in range(n_col):
                axs[row, col] = plt.subplot(gs[row, col])

    hs = np.empty([n_coef, n_dim, n_dif], dtype=object)

    # Loop-invariant x-axis quantities. With a single duration there is no
    # spacing to jitter by (indexing durs[1] would fail).
    ddur = durs[1] - durs[0] if len(durs) > 1 else 0.
    max_dur = np.amax(npy(durs))

    for ii_coef, i_coef in enumerate(coefs_to_plot):
        coef_name = coef_names[i_coef]
        for i_dim, dim_rel in enumerate(dim_incl):
            ax = axs[i_dim, ii_coef]  # type: plt.Axes
            plt.sca(ax)

            cmap = consts.CMAP_DIM[dim_rel](n_dif)

            for idif, dif_irr in enumerate(dif_irrs):
                y = coef[i_coef, dim_rel, idif, :]
                e = se_coef[i_coef, dim_rel, idif, :]
                # Center the jitter offsets around zero across groups.
                jitter = ddur * jitter0 * (idif - (n_dif - 1) / 2)
                kw = consts.get_kw_plot(style,
                                        color=cmap(idif),
                                        for_err=True,
                                        **dict(kw_plot))
                if style.startswith('data'):
                    h = plt.errorbar(durs + jitter, y, yerr=e, **kw)[0]
                else:
                    h = plt.plot(durs + jitter, y, **kw)
                hs[ii_coef, i_dim, idif] = h
            plt2.box_off()

            ax.set_xlim(-0.05 * max_dur, 1.05 * max_dur)
            plt2.detach_axis('x', 0, max_dur)

            if dim_rel == 0:
                ax.set_xticklabels([])
                ax.set_title(coef_name.lower())
            elif i_coef == 0:
                # NOTE(review): keyed on the coefficient index, not the
                # column (ii_coef); with coefs_to_plot not containing 0 no
                # xlabel is drawn -- confirm this is intended.
                ax.set_xlabel('duration (s)')

    if add_rowtitle:
        plt2.rowtitle(consts.DIM_NAMES_LONG, axes=axs)

    return coef, se_coef, axs, hs
Exemplo n.º 13
0
def main(base=10.):
    """
    Plot model-comparison bars for the real data and for the two
    model-recovery simulations, save the figures, write mean +- SEM of the
    cost differences to CSV, and plot those means as a summary figure.

    :param base: logarithm base for the cost differences; ``ds['dcost']``
        is divided by ``np.log(base)``.
    """
    ds = load_comp(os.path.join(locfile_in.pth_root, file_model_comp))
    ds['dcost'] = np.array(ds['dcost']) / np.log(base)

    vmax = 160

    # Mean and SEM of dcost for: data, simulated serial, simulated parallel.
    m = np.empty(3)
    e = np.empty(3)

    # This first call is only used for its layout (heights / margins),
    # which the 1x3 grid below reuses.
    axs = plot_bar_dloss_across_subjs(dlosses=np.array(ds['dcost']),
                                      subj_parad_bis=ds['subj_parad_bi'],
                                      vmax=vmax,
                                      base=base)
    axs = plt2.GridAxes(nrows=1,
                        ncols=3,
                        heights=axs.heights,
                        widths=[2],
                        left=axs.left,
                        right=axs.right,
                        bottom=axs.bottom)
    plot_bar_dloss_across_subjs(
        dlosses=np.array(ds['dcost']),
        subj_parad_bis=ds['subj_parad_bi'],
        vmax=vmax,
        axs=axs[:, [0]],
        base=base,
    )
    plt.title('Data')

    m[0] = np.mean(ds['dcost'])
    e[0] = np2.sem(ds['dcost'])

    print('--')

    titles = ['Simulated\nSerial', 'Simulated\nParallel']
    for i in range(2):
        ds = load_comp(
            os.path.join(locfile_in.pth_root, files_model_recovery[i]))
        ds['dcost'] = np.array(ds['dcost']) / np.log(base)

        plot_bar_dloss_across_subjs(
            dlosses=ds['dcost'],
            subj_parad_bis=ds['subj_parad_bi'],
            vmax=vmax,
            axs=axs[:, [i + 1]],
            base=base,
        )
        plt.title(titles[i])

        m[i + 1] = np.mean(ds['dcost'])
        e[i + 1] = np2.sem(ds['dcost'])

        # Share the y labels of the first column; hide them here.
        plt.sca(axs[0, i + 1])
        plt2.box_off(['left'])
        plt.yticks([])

    for ext in ['.pdf', '.png']:
        file = locfile_out.get_file_fig('model_comp_recovery',
                                        {'easiest_only': use_easiest_only},
                                        ext=ext)
        plt.savefig(file, dpi=300)
        print('Saved to %s' % file)

    # --- Print mean +- SEM to CSV
    csv_file = locfile_out.get_file_csv('model_comp_recovery',
                                        {'easiest_only': use_easiest_only})
    d_csv = np2.dictlist2listdict({
        'data': ['original', 'simulated serial', 'simulated parallel'],
        'mean_dcost':
        m,
        'sem_dcost':
        e,
    })
    # newline='' is what the csv module requires; without it an extra
    # blank row appears after every record on platforms that translate
    # line endings.
    with open(csv_file, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=d_csv[0].keys())
        writer.writeheader()
        for d in d_csv:
            writer.writerow(d)
        print('Wrote to %s' % csv_file)

    # --- Mean NLL within each
    axs = plt2.GridAxes(
        1,
        1,
        left=0.85,
        right=0.25,
        heights=[1.5],
        top=0.1,
        bottom=0.9,
    )
    plt.sca(axs[0, 0])
    plt.barh(np.arange(3), m, xerr=e, color='w', edgecolor='k')
    plt.yticks(np.arange(3),
               ['data', 'simulated\nserial', 'simulated\nparallel'])
    plt2.box_off(['top', 'right'])
    # Symmetric x-range covering every mean +- SEM.
    vmax = np.amax(np.abs(m) + np.abs(e))
    plt.xlim(np.array([-vmax, vmax]) * 1.1)
    plt2.detach_axis(xy='y', amin=0, amax=2)
    plt2.detach_axis(xy='x', amin=-vmax, amax=vmax)

    axvline_dcost()
    xticks_serial_vs_parallel(vmax, base)

    for ext in ['.pdf', '.png']:
        file = locfile_out.get_file_fig('model_comp_vs_recovery',
                                        {'easiest_only': use_easiest_only},
                                        ext=ext)
        plt.savefig(file, dpi=300)
        print('Saved to %s' % file)

    print('--')
Exemplo n.º 14
0
def plot_rt_vs_ev(
        ev_cond,
        n_cond__rt_ch: Union[torch.Tensor, np.ndarray],
        style='pred',
        pool='mean',
        dt=consts.DT,
        correct_only=True,
        thres_n_trial=10,
        color='k',
        color_ch=('tab:red', 'tab:blue'),
        ax: plt.Axes = None,
        kw_plot=(),
) -> (Sequence[plt.Line2D], List[np.ndarray]):
    """
    Plot the mean response time against evidence.

    @param ev_cond: [condition]; a [condition, frame] input is averaged over
        frames, and a 3-D input is first reduced with ``npt.p2st(...)[0]``.
    @type ev_cond: torch.Tensor
    @param n_cond__rt_ch: [condition, frame, ch]
    @type n_cond__rt_ch: torch.Tensor
    @param style: passed to ``get_kw_plot``. For styles starting with
        'data', conditions with fewer than ``thres_n_trial`` trials are set
        to NaN, and (when ``correct_only``) the zero-evidence condition uses
        counts summed over both choices.
    @param pool: only 'mean' is implemented.
    @param dt: duration of one frame; converts frame index to time.
    @param correct_only: if True, draw one line per choice over the
        conditions whose evidence sign is not opposite to
        ``consts.ch_bool2sign(ch)``, in ``color`` plus ``kw_plot``;
        otherwise draw every condition per choice in ``color_ch[ch]``.
    @param ax: axes to draw into; defaults to the current axes.
    @return: (hs, rts): plot handles per choice, and pooled RTs -- a list of
        per-choice arrays when ``correct_only`` (their lengths can differ),
        else an array stacked over choices.
    """
    if ax is None:
        ax = plt.gca()
    if ev_cond.ndim != 1:
        if ev_cond.ndim == 3:
            ev_cond = npt.p2st(ev_cond)[0]
        assert ev_cond.ndim == 2
        ev_cond = ev_cond.mean(1)
    assert n_cond__rt_ch.ndim == 3

    ev_cond = npy(ev_cond)
    n_cond__rt_ch = npy(n_cond__rt_ch)

    def plot_rt_given_cond_ch(ev_cond1, n_rt_given_cond_ch, **kw1):
        # n_rt_given_cond_ch[cond, fr]
        # p_rt_given_cond_ch[cond, fr]: normalized over frames
        p_rt_given_cond_ch = np2.sumto1(n_rt_given_cond_ch, 1)

        nt = n_rt_given_cond_ch.shape[1]
        t = np.arange(nt) * dt

        # n_in_cond_ch[cond]: total count per condition
        n_in_cond_ch = n_rt_given_cond_ch.sum(1)
        if pool == 'mean':
            rt_pooled = (t[None, :] * p_rt_given_cond_ch).sum(1)
        elif pool == 'var':
            raise NotImplementedError("pool='var' is not implemented")
        else:
            raise ValueError("pool must be 'mean', got %r" % (pool,))

        if style.startswith('data'):
            # Too few trials for a reliable mean: leave a gap in the line.
            rt_pooled[n_in_cond_ch < thres_n_trial] = np.nan

        kw = get_kw_plot(style, **kw1)
        h = ax.plot(ev_cond1, rt_pooled, **kw)
        return h, rt_pooled

    hs = []
    rtss = []
    if correct_only:
        n_cond__rt_ch1 = n_cond__rt_ch.copy()  # type: np.ndarray
        if style.startswith('data'):
            # For data, the zero-evidence condition uses the counts summed
            # over both choices (broadcast back to each choice).
            cond0 = ev_cond == 0
            n_cond__rt_ch1[cond0, :, :] = np.sum(n_cond__rt_ch1[cond0, :, :],
                                                 axis=-1,
                                                 keepdims=True)

        # -- Choose the ch with correct sign (or both chs if cond == 0)
        cond_sign = np.sign(ev_cond)
        for ch in range(consts.N_CH):
            ch_sign = consts.ch_bool2sign(ch)
            cond_accu = cond_sign != -ch_sign

            h, rts = plot_rt_given_cond_ch(ev_cond[cond_accu],
                                           n_cond__rt_ch1[cond_accu, :, ch],
                                           color=color,
                                           **dict(kw_plot))
            hs.append(h)
            rtss.append(rts)
        # Not stacked: the number of selected conditions can differ
        # between choices.
        rts = rtss

    else:
        for ch in range(consts.N_CH):
            # n_rt_given_cond_ch[cond, fr]
            n_rt_given_cond_ch = n_cond__rt_ch[:, :, ch]

            h, rts = plot_rt_given_cond_ch(ev_cond,
                                           n_rt_given_cond_ch,
                                           color=color_ch[ch])

            hs.append(h)
            rtss.append(rts)
        rts = np.stack(rtss)

    y_lim = ax.get_ylim()
    x_lim = ax.get_xlim()
    # Act on `ax` explicitly: it need not be the current axes.
    plt2.box_off(ax=ax)
    plt2.detach_axis('x', ax=ax, amin=x_lim[0], amax=x_lim[1])
    plt2.detach_axis('y', ax=ax, amin=y_lim[0], amax=y_lim[1])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{E}[T^\mathrm{r} \mid c]~(\mathrm{s})$")
    return hs, rts
def plot_bar_dloss_across_subjs(
    dlosses,
    elosses=None,
    ix_datas=None,
    subj_parad_bis: Iterable[Tuple[str, str, bool]] = None,
    axs: Union[plt2.GridAxes, plt2.AxesArray] = None,
    vmax=None,
    add_scale=True,
    base=10.,
):
    """
    Horizontal bar plot of loss differences, one bar per
    (subject, paradigm, bimanual) entry, regrouped for display.

    Display order: eye (parad == 'RT', non-binary) sorted by subject number;
    then hand entries interleaved as (unimanual, bimanual) pairs sorted by
    subject number; then 'binary' entries.

    :param dlosses: [ix_data] loss differences, one bar each.
    :param elosses: optional [ix_data] error-bar half-widths in the same
        order as ``dlosses``; reordered together with ``dlosses``.
    :param ix_datas: unused.
    :param subj_parad_bis: [('subj', 'parad', is_bimanual), ...];
        defaults to the module-level ``subj_parad_bis0``.
    :param axs: axes to draw into (``axs[0, 0]`` is used); a new 1x1
        ``plt2.GridAxes`` is created when omitted.
    :param vmax: half-width of the plotted x-range; defaults to
        max |dlosses|. Bars beyond it get a wave patch at the axis edge.
    :param add_scale: currently has no effect.
    :param base: forwarded to ``xticks_serial_vs_parallel``.
    :return: axs
    """
    import string

    if subj_parad_bis is None:
        subj_parad_bis = subj_parad_bis0
    if not isinstance(subj_parad_bis, np.ndarray):
        # Materialize: zip() below would exhaust a one-shot iterable, and
        # entries must be indexable for reordering.
        subj_parad_bis = list(subj_parad_bis)
    dlosses = np.asarray(dlosses)
    if vmax is None:
        vmax = np.amax(np.abs(dlosses))

    # order: eye S1-S3, hand by ID, paired uni-bimanual
    subjs, parads, bis = zip(*subj_parad_bis)
    subjs = np.array(
        ['ID0' + v[-1] if v[:2] == 'ID' and len(v) == 3 else v for v in subjs])
    parads = np.array(parads)
    bis = np.array(bis)

    is_eye = parads == 'RT'
    is_bin = parads == 'binary'
    ix = np.arange(len(subjs))

    def filt_sort(filt):
        # Sort by the numeric part of the name: strip the whole alphabetic
        # prefix ('S1' -> 1, 'ID01' -> 1); `subj[1:]` failed on 'ID..'.
        ind = [int(subj.lstrip(string.ascii_letters)) for subj in subjs[filt]]
        return ix[filt][np.argsort(ind)]

    ix = np.concatenate([
        filt_sort(is_eye & ~is_bin),
        np.stack([
            filt_sort(~is_eye & ~bis & ~is_bin),
            filt_sort(~is_eye & bis & ~is_bin)
        ], -1).flatten('C'),
        filt_sort(is_bin)
    ])
    subjs = subjs[ix]
    parads = parads[ix]
    bis = bis[ix]
    is_eye = is_eye[ix]
    dlosses = dlosses[ix]
    if isinstance(subj_parad_bis, np.ndarray):
        subj_parad_bis = subj_parad_bis[ix]
    else:
        subj_parad_bis = [subj_parad_bis[i] for i in ix]

    n_eye = int(np.sum(is_eye))
    n_hand = int(np.sum(~is_eye))

    y = np.empty([n_eye + n_hand])
    y[is_eye] = 1.5 + np.arange(n_eye)
    # Hand rows alternate gaps of 1.5 and 1.0 so each pair sits together;
    # truncating keeps the length right when n_hand is odd.
    hand_gaps = ([1.5, 1.] * ((n_hand + 1) // 2))[:n_hand]
    y[~is_eye] = n_eye - 1 + 1.5 + np.cumsum(hand_gaps)
    y_max = np.amax(y) + 1.5

    if axs is None:
        axs = plt2.GridAxes(nrows=1,
                            ncols=1,
                            heights=y_max * 0.2,
                            widths=2,
                            left=1.5,
                            right=0.25,
                            bottom=0.85)
    ax = axs[0, 0]
    plt.sca(ax)

    m = dlosses
    if elosses is None:
        e = np.zeros_like(m)
    else:
        # Apply the same permutation as dlosses so each error bar stays
        # with its own bar.
        e = np.asarray(elosses)[ix]

    for y1, m1, e1, parad1, bi1 in zip(y, m, e, parads, bis):
        plt.barh(y1,
                 m1,
                 xerr=e1,
                 color=colors_parad[(parad1, '%s' % bi1)],
                 edgecolor='None')

    axvline_dcost()

    x_lim = [-vmax * 1.2, vmax * 1.2]
    for ix_big in range(len(y)):
        if np.abs(m[ix_big]) > vmax:
            # Mark bars clipped by the x-range with a wave at both edges.
            for i_sign, sign in enumerate([1, -1]):
                plt2.patch_wave(
                    y[ix_big],
                    x_lim[i_sign] * 1.01,
                    ax=ax,
                    color='w',
                    wave_margin=0.15,
                    wave_amplitude=sign * 0.025,
                )

    plt.xlim(x_lim)
    xticks_serial_vs_parallel(vmax, base)
    subj_parad_bi_str = get_subj_parad_bi_str(subj_parad_bis)
    plt.yticks(y, subj_parad_bi_str)
    plt2.detach_axis('y', y[0], y[-1])
    plt2.detach_axis('x', -vmax, vmax)
    plt.ylim([y_max - 1, 1.])

    return axs
Exemplo n.º 16
0
def plot_fit_combined(
        data: Union[sim2d.Data2DRT, dict] = None,
        pModel_cond_rt_chFlat=None, model=None,
        pModel_dimRel_condDense_chFlat=None,
        # --- in place of data:
        pAll_cond_rt_chFlat=None,
        evAll_cond_dim=None,
        pTrain_cond_rt_chFlat=None,
        evTrain_cond_dim=None,
        pTest_cond_rt_chFlat=None,
        evTest_cond_dim=None,
        dt=None,
        # --- optional
        ev_dimRel_condDense_fr_dim_meanvar=None,
        dt_model=None,
        to_plot_internals=True,
        to_plot_params=True,
        to_plot_choice=True,
        # to_group_irr=False,
        group_dcond_irr=None,
        to_combine_ch_irr_cond=True,
        kw_plot_pred=(),
        kw_plot_pred_ch=(),
        kw_plot_data=(),
        axs=None,
):
    """
    Per dimension, plot model predictions ('pred') over test ('data_pred')
    and training ('data_fit') data: choice vs. evidence (optional) and RT vs.
    evidence; optionally model internals and parameters.

    :param data: source of all/train/test data and dt; when None, the
        p*_cond_rt_chFlat / ev*_cond_dim arguments are used instead, with
        train/test defaulting to the 'all' arrays.
    :param pModel_cond_rt_chFlat: model prediction per condition; replaced
        per dimension when ``ev_dimRel_condDense_fr_dim_meanvar`` is given.
    :param model: required when plotting internals or params; its ``dt``
        is the default ``dt_model``.
    :param pModel_dimRel_condDense_chFlat: per-dimension dense predictions,
        used with ``ev_dimRel_condDense_fr_dim_meanvar``.
    :param ev_dimRel_condDense_fr_dim_meanvar: 5-D or 4-D evidence per
        dimension for the dense predictions.
    :param to_plot_internals: plot non-decision-time distributions and
        bounds (uses axes rows 3-5).
    :param to_plot_params: plot ``model.plot_params()`` in ``axs[0, -1]``.
    :param to_plot_choice: plot the choice row (row 0); RT goes below it.
    :param group_dcond_irr: passed to the sim2d plotting functions.
    :param to_combine_ch_irr_cond: draw the choice prediction combined over
        the other dimension's conditions (black) instead of per condition.
    :param kw_plot_pred: extra kwargs for predicted curves.
    :param kw_plot_pred_ch: extra kwargs for the combined choice prediction.
    :param kw_plot_data: extra kwargs for data.
    :param axs: axes grid; created when omitted.
    :return: axs, rts (predicted RTs per dimension),
        hs (dict; hs['rt'] holds the predicted-RT handles per dimension)
    """
    if data is None:
        if pTrain_cond_rt_chFlat is None:
            pTrain_cond_rt_chFlat = pAll_cond_rt_chFlat
        if evTrain_cond_dim is None:
            evTrain_cond_dim = evAll_cond_dim
        if pTest_cond_rt_chFlat is None:
            pTest_cond_rt_chFlat = pAll_cond_rt_chFlat
        if evTest_cond_dim is None:
            evTest_cond_dim = evAll_cond_dim
    else:
        _, pAll_cond_rt_chFlat, _, _, evAll_cond_dim = \
            data.get_data_by_cond('all')
        _, pTrain_cond_rt_chFlat, _, _, evTrain_cond_dim = data.get_data_by_cond(
            'train_valid', mode_train='easiest')
        _, pTest_cond_rt_chFlat, _, _, evTest_cond_dim = data.get_data_by_cond(
            'test', mode_train='easiest')
        dt = data.dt
    hs = {}

    if model is None:
        assert not to_plot_internals
        assert not to_plot_params

    if dt_model is None:
        if model is None:
            dt_model = dt
        else:
            dt_model = model.dt

    if axs is None:
        # NOTE(review): the internals section below indexes rows 3-5, but
        # the default grid here has 3 rows -- pass `axs` with >= 6 rows when
        # to_plot_internals is True. TODO confirm the intended layout.
        if to_plot_params:
            axs = plt2.GridAxes(3, 3)
        else:
            if to_plot_internals:
                axs = plt2.GridAxes(3, 3)
            else:
                if to_plot_choice:
                    axs = plt2.GridAxes(2, 2)
                else:
                    axs = plt2.GridAxes(1, 2)  # TODO: beautify ratios

    # --- data_pred may not have all conditions, so concatenate the rest
    #  of the conditions so that the color scale is correct. Then also
    #  concatenate p_rt_ch_data_pred1 with zeros so that nothing is
    #  plotted in the concatenated. Independent of dim_rel: build once.
    evTest_cond_dim1 = np.concatenate([
        evTest_cond_dim, evAll_cond_dim
    ], axis=0)
    pTest_cond_rt_chFlat1 = np.concatenate([
        pTest_cond_rt_chFlat, np.zeros_like(pAll_cond_rt_chFlat)
    ], axis=0)

    rts = []
    hs['rt'] = []
    for dim_rel in range(consts.N_DIM):
        if ev_dimRel_condDense_fr_dim_meanvar is None:
            evModel_cond_dim = evAll_cond_dim
        else:
            if ev_dimRel_condDense_fr_dim_meanvar.ndim == 5:
                evModel_cond_dim = npy(ev_dimRel_condDense_fr_dim_meanvar[
                                           dim_rel][:, 0, :, 0])
            else:
                assert ev_dimRel_condDense_fr_dim_meanvar.ndim == 4
                evModel_cond_dim = npy(ev_dimRel_condDense_fr_dim_meanvar[
                                           dim_rel][:, 0, :])
            pModel_cond_rt_chFlat = npy(pModel_dimRel_condDense_chFlat[dim_rel])

        if to_plot_choice:
            # --- Plot choice
            ax = axs[0, dim_rel]
            plt.sca(ax)

            if to_combine_ch_irr_cond:
                ev_cond_model1, p_rt_ch_model1 = combine_irr_cond(
                    dim_rel, evModel_cond_dim, pModel_cond_rt_chFlat
                )

                sim2d.plot_p_ch_vs_ev(ev_cond_model1, p_rt_ch_model1,
                                      dim_rel=dim_rel, style='pred',
                                      group_dcond_irr=None,
                                      kw_plot=kw_plot_pred_ch,
                                      cmap=lambda n: lambda v: [0., 0., 0.],
                                      )
            else:
                sim2d.plot_p_ch_vs_ev(evModel_cond_dim, pModel_cond_rt_chFlat,
                                      dim_rel=dim_rel, style='pred',
                                      group_dcond_irr=group_dcond_irr,
                                      kw_plot=kw_plot_pred,
                                      cmap=cmaps[dim_rel]
                                      )
            # Separate name: assigning to `hs` here clobbered the result
            # dict and broke hs['rt'].append below.
            hs_ch, conds_irr = sim2d.plot_p_ch_vs_ev(
                evTest_cond_dim1, pTest_cond_rt_chFlat1,
                dim_rel=dim_rel, style='data_pred',
                group_dcond_irr=group_dcond_irr,
                cmap=cmaps[dim_rel],
                kw_plot=kw_plot_data,
            )
            hs1 = [h[0] for h in hs_ch]
            odim = 1 - dim_rel
            odim_name = consts.DIM_NAMES_LONG[odim]
            legend_odim(conds_irr, hs1, odim_name)
            sim2d.plot_p_ch_vs_ev(evTrain_cond_dim, pTrain_cond_rt_chFlat,
                                  dim_rel=dim_rel, style='data_fit',
                                  group_dcond_irr=group_dcond_irr,
                                  cmap=cmaps[dim_rel],
                                  kw_plot=kw_plot_data
                                  )
            plt2.detach_axis('x', np.amin(evTrain_cond_dim[:, dim_rel]),
                             np.amax(evTrain_cond_dim[:, dim_rel]))
            ax.set_xlabel('')
            ax.set_xticklabels([])
            if dim_rel != 0:
                plt2.box_off(['left'])
                plt.yticks([])

            ax.set_ylabel('P(%s choice)' % consts.CH_NAMES[dim_rel][1])

        # --- Plot RT
        ax = axs[int(to_plot_choice) + 0, dim_rel]
        plt.sca(ax)
        hs1, rts1 = sim2d.plot_rt_vs_ev(
            evModel_cond_dim,
            pModel_cond_rt_chFlat,
            dim_rel=dim_rel, style='pred',
            group_dcond_irr=group_dcond_irr,
            dt=dt_model,
            kw_plot=kw_plot_pred,
            cmap=cmaps[dim_rel]
        )
        hs['rt'].append(hs1)
        rts.append(rts1)

        sim2d.plot_rt_vs_ev(evTest_cond_dim1, pTest_cond_rt_chFlat1,
                            dim_rel=dim_rel, style='data_pred',
                            group_dcond_irr=group_dcond_irr,
                            dt=dt,
                            cmap=cmaps[dim_rel],
                            kw_plot=kw_plot_data
                            )
        sim2d.plot_rt_vs_ev(evTrain_cond_dim, pTrain_cond_rt_chFlat,
                            dim_rel=dim_rel, style='data_fit',
                            group_dcond_irr=group_dcond_irr,
                            dt=dt,
                            cmap=cmaps[dim_rel],
                            kw_plot=kw_plot_data
                            )
        plt2.detach_axis('x', np.amin(evTrain_cond_dim[:, dim_rel]),
                         np.amax(evTrain_cond_dim[:, dim_rel]))
        if dim_rel != 0:
            ax.set_ylabel('')
            plt2.box_off(['left'])
            plt.yticks([])

        ax.set_xlabel(consts.DIM_NAMES_LONG[dim_rel].lower() + ' strength')

        if dim_rel == 0:
            ax.set_ylabel('RT (s)')

        if to_plot_internals:
            for ch1 in range(consts.N_CH):
                ch0 = dim_rel
                ax = axs[3 + ch1, dim_rel]
                plt.sca(ax)

                ch_flat = consts.ch_by_dim2ch_flat(np.array([ch0, ch1]))
                model.tnds[ch_flat].plot_p_tnd()
                ax.set_xlabel('')
                ax.set_xticklabels([])
                ax.set_yticks([0, 1])
                if ch0 > 0:
                    ax.set_yticklabels([])

                # Both halves raw: "\m" in a plain literal is an invalid
                # escape sequence (SyntaxWarning on Python 3.12+).
                ax.set_ylabel(r"$\mathrm{P}(T^\mathrm{n} \mid"
                              r" \mathbf{z}=[%d,%d])$"
                              % (ch0, ch1))

            ax = axs[5, dim_rel]
            plt.sca(ax)
            if hasattr(model.dtb, 'dtb1ds'):
                model.dtb.dtb1ds[dim_rel].plot_bound(color='k')

    plt2.sameaxes(axs[-1, :consts.N_DIM], xy='y')

    if to_plot_params:
        ax = axs[0, -1]
        plt.sca(ax)
        model.plot_params()

    return axs, rts, hs
Exemplo n.º 17
0
def plot_rt_distrib_pred_data(
        p_pred_cond_rt_ch,
        n_cond_rt_ch, ev_cond_dim, dt_model, dt_data=None,
        smooth_sigma_sec=0.1,
        to_plot_scale=False,
        to_cumsum=False,
        to_normalize_max=True,
        xlim=None,
        colors=('magenta', 'cyan'),
        kw_plot_pred=(),
        kw_plot_data=(),
        to_skip_zero_trials=False,
        labels=None,
        **kwargs
):
    """
    Overlay predicted and observed RT distributions on shared axes.

    Each model's prediction is normalized over (rt, ch) within each
    condition, then scaled by the data's trial count in that condition.

    :param p_pred_cond_rt_ch: [model, cond, rt, ch] = P(rt, ch | cond, model)
    :param n_cond_rt_ch: [cond, rt, ch] = n_tr(cond, rt, ch)
    :param ev_cond_dim: passed to ``sim2d.plot_rt_distrib``
    :param dt_model: time step of the predictions
    :param dt_data: time step of the data; defaults to ``dt_model``
    :param smooth_sigma_sec: passed to ``sim2d.plot_rt_distrib``
    :param to_plot_scale: if True, draw a bar between the first two mean
        RTs on ``axs[-1, -1]``, labelled with their difference in ms
    :param to_cumsum: passed to ``sim2d.plot_rt_distrib``
    :param to_normalize_max: passed to ``sim2d.plot_rt_distrib``
    :param xlim: range of the detached x-axis; defaults to [0.5, 4.5]
    :param colors: one color per model
    :param kw_plot_pred: extra plot kwargs for predictions (linewidth 1.5
        unless overridden)
    :param kw_plot_data: extra plot kwargs for data (linewidth 0.5 unless
        overridden)
    :param to_skip_zero_trials: passed to ``sim2d.plot_rt_distrib``
    :param labels: one per model, then one for the data; default all ''
    :param kwargs: forwarded to ``sim2d.plot_rt_distrib``
    :return: axs, hss (handles for each model, then for the data)
    """

    axs = None
    ps0 = []
    hss = []

    # Normalize per (model, cond) over (rt, ch), then scale by the data's
    # trial count per condition.
    p_pred_cond_rt_ch = p_pred_cond_rt_ch / np.sum(
        p_pred_cond_rt_ch, (-1, -2), keepdims=True)
    n_preds1 = p_pred_cond_rt_ch * np.sum(
        n_cond_rt_ch, (-1, -2))[None, :, None, None]
    nt = p_pred_cond_rt_ch.shape[-2]
    if dt_data is None:
        dt_data = dt_model
    if labels is None:
        labels = [''] * (len(n_preds1) + 1)

    for i_pred, n_pred in enumerate(n_preds1):
        color = colors[i_pred]
        axs, p0, p1, hs = sim2d.plot_rt_distrib(
            n_pred, ev_cond_dim,
            dt=dt_model,
            axs=axs,
            alpha=1.,
            smooth_sigma_sec=smooth_sigma_sec,
            to_skip_zero_trials=to_skip_zero_trials,
            colors=color,
            alpha_face=0,
            to_normalize_max=to_normalize_max,
            to_cumsum=to_cumsum,
            to_use_sameaxes=False,
            kw_plot={
                'linewidth': 1.5,
                **dict(kw_plot_pred),
            },
            label=labels[i_pred],
            **kwargs,
        )[:4]
        ps0.append(p0)
        hss.append(hs)

    axs, p0, p1, hs = sim2d.plot_rt_distrib(
        n_cond_rt_ch, ev_cond_dim,
        dt=dt_data,
        axs=axs,
        smooth_sigma_sec=smooth_sigma_sec,
        colors='k',
        alpha_face=0.,
        to_normalize_max=to_normalize_max,  # normalize across preds and data instead
        to_cumsum=to_cumsum,
        to_skip_zero_trials=to_skip_zero_trials,
        kw_plot={
            'linewidth': 0.5,
            **dict(kw_plot_data),
        },
        label=labels[-1],
        **kwargs,
    )
    ps0.append(p0)
    hss.append(hs)

    ps0 = np.stack(ps0)

    if xlim is None:
        xlim = [0.5, 4.5]
    for ax in axs.flatten():
        plt2.detach_axis('x', *xlim, ax=ax)
        ax.set_xlim(xlim[0] - 0.1, xlim[1] + 0.1)

    axs[-1, 0].set_xticks(xlim)
    axs[-1, 0].set_xticklabels(['%g' % v for v in xlim])

    from lib.pylabyk import numpytorch as npt
    t = torch.arange(nt) * dt_model

    # Mean RT of the [0, 0] slice of each distribution (models, then data).
    mean_rts = []
    for p1 in ps0:
        p11 = npt.sumto1(torch.tensor(p1).sum([-1, -2])[0, 0, :])
        mean_rts.append(npy((t * p11).sum()))
    print('mean_rts:')
    print(mean_rts)
    if len(mean_rts) >= 2:
        print(mean_rts[1] - mean_rts[0])

    if to_plot_scale:
        ax_scale = axs[-1, -1]
        y = 0.8
        ax_scale.plot(mean_rts[:2], y + np.zeros(2), 'k-', linewidth=0.5)
        x = np.mean(mean_rts[:2])
        # Label on the same axes as the bar (not whichever axes is current).
        ax_scale.text(x, y + 0.1,
                      '%1.0f ms' % (np.abs(mean_rts[1] - mean_rts[0]) * 1e3),
                      ha='center', va='bottom')
    return axs, hss
Exemplo n.º 18
0
def plot_coefs_dur_irrixn(
        n_cond_dur_ch: np.ndarray = None,
        data: Data2DVD = None,
        ev_cond_dim: np.ndarray = None,
        durs: np.ndarray = None,
        style='data',
        kw_plot=(),
        jitter0=0.,
        coefs_to_plot=(2, ),
        axs=None,
        fig=None,
):
    """
    Plot selected irrelevant-dimension-interaction coefficients vs. duration.

    Coefficients and standard errors come from
    coefdur.get_coefs_irr_ixn_from_histogram() and are indexed as
    coef[i_coef, dim_rel, :], one value per entry of durs. One row of axes
    is drawn per dimension (consts.N_DIM) and one column per entry of
    coefs_to_plot.

    @param n_cond_dur_ch: histogram passed to
        get_coefs_irr_ixn_from_histogram; defaults to
        npy(data.get_data_by_cond('all')[1])
    @param data: used to fill in whichever of ev_cond_dim, durs and
        n_cond_dur_ch are None
    @param ev_cond_dim: defaults to data.ev_cond_dim
    @param durs: x values of the plot; defaults to npy(data.durs)
    @param style: passed to consts.get_kw_plot; styles starting with 'data'
        are drawn with plt.errorbar, all others with plt.plot
    @param kw_plot: extra plot kwargs (dict or sequence of key/value pairs)
        overriding the defaults color='k', for_err=True
    @param jitter0: currently unused
    @param coefs_to_plot: indices into
        ('bias', 'slope', 'rel x abs(irr)', 'rel x irr', 'abs(irr)', 'irr')
    @param axs: [consts.N_DIM, len(coefs_to_plot)] array of axes; created
        when None
    @param fig: figure on which to create axs when axs is None; a new
        figure is made when both are None
    @return: coef, se_coef, axs
    """
    if ev_cond_dim is None:
        ev_cond_dim = data.ev_cond_dim
    if durs is None:
        durs = npy(data.durs)
    if n_cond_dur_ch is None:
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])

    coef, se_coef = coefdur.get_coefs_irr_ixn_from_histogram(
        n_cond_dur_ch,
        ev_cond_dim=ev_cond_dim)[:2]  # type: (np.ndarray, np.ndarray)

    coef_names = [
        'bias', 'slope', 'rel x abs(irr)', 'rel x irr', 'abs(irr)', 'irr'
    ]
    n_coef = len(coefs_to_plot)

    if axs is None:
        if fig is None:
            fig = plt.figure(figsize=[6, 1.5 * n_coef])

        n_row = consts.N_DIM
        n_col = n_coef
        gs = plt.GridSpec(nrows=n_row,
                          ncols=n_col,
                          figure=fig,
                          left=0.25,
                          right=0.95,
                          bottom=0.15,
                          top=0.9)
        # Plain `object`: the `np.object` alias was deprecated in
        # NumPy 1.20 and removed in 1.24.
        axs = np.empty([n_row, n_col], dtype=object)
        for row in range(n_row):
            for col in range(n_col):
                axs[row, col] = plt.subplot(gs[row, col])

    # Loop-invariant values, computed once instead of per axes.
    # User-supplied kwargs override the defaults.
    kw_plot_merged = {'color': 'k', 'for_err': True, **dict(kw_plot)}
    kw = consts.get_kw_plot(style, **kw_plot_merged)
    jitter = 0.  # NOTE: jitter0 is not applied
    max_dur = np.amax(npy(durs))

    for ii_coef, i_coef in enumerate(coefs_to_plot):
        coef_name = coef_names[i_coef]
        for dim_rel in range(consts.N_DIM):
            ax = axs[dim_rel, ii_coef]  # type: plt.Axes
            plt.sca(ax)

            y = coef[i_coef, dim_rel, :]
            e = se_coef[i_coef, dim_rel, :]
            if style.startswith('data'):
                plt.errorbar(durs + jitter, y, yerr=e, **kw)
            else:
                plt.plot(durs + jitter, y, **kw)
            plt2.box_off()

            ax.set_xlim(-0.05 * max_dur, 1.05 * max_dur)
            # detach_axis acts on the current axes, set by plt.sca above.
            plt2.detach_axis('x', 0, max_dur)

            if dim_rel == 0:
                # Top row: no x tick labels; title the column instead.
                ax.set_xticklabels([])
                ax.set_title(coef_name.lower())
            elif ii_coef == 0:
                # Test the column position (ii_coef), not the coefficient
                # index, so the label appears for any coefs_to_plot
                # (e.g. the default (2,)), not only when 'bias' is plotted.
                ax.set_xlabel('duration (s)')

    plt2.rowtitle(consts.DIM_NAMES_LONG, axes=axs)

    return coef, se_coef, axs