コード例 #1
0
ファイル: dtb_2D_fit_VD.py プロジェクト: yulkang/2D_Decision
def beautify(ax, x_max=1.2):
    """
    Format a panel whose data run along x from 0 to ``x_max``.

    Makes ``ax`` the current axes, sets the right x-limit 5% beyond
    ``x_max``, detaches the x-axis over [0, x_max] with plt2.detach_axis and
    then calls plt2.box_off().

    :param ax: axes to format
    :param x_max: right end of the data range along x (default 1.2)
    """
    plt.sca(ax)
    # Small right margin so markers at x_max are not cut by the axes edge.
    plt.xlim(xmax=x_max * 1.05)
    plt2.detach_axis('x', 0, x_max)
    plt2.box_off()
コード例 #2
0
def axvline_dcost(BF=100., base=10., style='patch'):
    """
    Mark the band |x| < log_base(BF) on the current axes.

    The half-width of the band is ``log(BF) / log(base)``, i.e. the Bayes
    factor threshold ``BF`` expressed in log-``base`` units.

    :param BF: Bayes factor threshold (default 100)
    :param base: logarithm base of the x-axis units (default 10)
    :param style: 'line' draws a black dashed line at 0 and silver dashed
        lines at +/- threshold; 'patch' shades the band with a translucent
        gray rectangle spanning y in [-100, 100] data units
    :raises ValueError: if ``style`` is neither 'line' nor 'patch'
    """
    if style not in ('line', 'patch'):
        raise ValueError(
            "style must be 'line' or 'patch', got %r" % (style, ))

    thres = np.log(BF) / np.log(base)
    if style == 'line':
        plt.axvline(0, color='k', linewidth=0.5, linestyle='--', zorder=1)
        for sign in [-1, 1]:
            plt.axvline(sign * thres,
                        color='silver',
                        linewidth=0.5,
                        linestyle='--',
                        zorder=1)
    else:  # 'patch'
        import matplotlib.patches as patches
        ax = plt.gca()
        # NOTE(review): the rectangle has a fixed height in data coordinates
        # (y from -100 to +100); it will not cover data outside that range.
        ax.add_patch(
            patches.Rectangle(
                (-thres, -100),
                thres * 2,
                200,
                edgecolor='None',
                facecolor=[0., 0., 0., 0.4],
                fill=True,
            ))
    plt2.box_off()
コード例 #3
0
ファイル: dtb_1D_sim.py プロジェクト: yulkang/2D_Decision
def plot_p_ch_vs_ev(ev_cond,
                    p_ch,
                    style='pred',
                    ax: plt.Axes = None,
                    **kwargs) -> plt.Line2D:
    """
    Plot the probability of choice z=1 against evidence, one point per
    condition.

    @param ev_cond: [condition] or [condition, frame]. A 3-D input is first
        reduced with npt.p2st(ev_cond)[0]; a 2-D input is then averaged over
        its second axis.
    @type ev_cond: torch.Tensor
    @param p_ch: [condition, ch] or [condition, rt_frame, ch]. A 3-D input is
        summed over rt_frame.
    @type p_ch: torch.Tensor
    @param style: forwarded to get_kw_plot together with kwargs
    @param ax: axes to draw on; the current axes when omitted
    @return: the list of Line2D returned by ax.plot
    """
    ax = plt.gca() if ax is None else ax

    # Reduce the evidence to a single value per condition.
    if ev_cond.ndim != 1:
        ev_2d = npt.p2st(ev_cond)[0] if ev_cond.ndim == 3 else ev_cond
        assert ev_2d.ndim == 2
        ev_cond = ev_2d.mean(1)

    # Marginalize over RT frames when they are present.
    if p_ch.ndim == 2:
        p_ch_cond = p_ch
    else:
        assert p_ch.ndim == 3
        p_ch_cond = p_ch.sum(1)

    kw_line = get_kw_plot(style, **kwargs)

    # Normalize over the last (choice) axis, then take index 1 of p2st's
    # output -- presumably the probability of the second choice; confirm
    # against npt.p2st.
    p_ch1 = npt.p2st(npt.sumto1(p_ch_cond, -1))[1]
    lines = ax.plot(*npys(ev_cond, p_ch1), **kw_line)

    plt2.box_off(ax=ax)
    x_lo, x_hi = ax.get_xlim()
    plt2.detach_axis('x', amin=x_lo, amax=x_hi, ax=ax)
    plt2.detach_axis('y', amin=0, amax=1, ax=ax)
    ax.set_yticks([0, 0.5, 1])
    ax.set_yticklabels(['0', '', '1'])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{P}(z=1 \mid c)$")
    return lines
コード例 #4
0
    def plot_params(self,
                    named_bounded_params: Sequence[Tuple[
                        str, BoundedParameter]] = None,
                    exclude: Iterable[str] = (),
                    cmap='coolwarm',
                    ax: plt.Axes = None) -> mpl.container.BarContainer:
        """
        Draw bounded parameters as horizontal bars on ``ax``.

        Each bar's length is the parameter value rescaled so that its lower
        bound maps to 0 and its upper bound to 1. Bars are colored by the
        gradient, scaled symmetrically by the largest absolute gradient.
        Parameters whose requires_grad is False are drawn in light gray and
        labeled "(fixed)". Each row is annotated with the lower bound (left),
        upper bound (right), and value plus gradient magnitude (center).

        :param named_bounded_params: forwarded to get_named_bounded_params
        :param exclude: names forwarded to get_named_bounded_params
        :param cmap: colormap name used for the gradient colors
        :param ax: axes to draw on; defaults to the current axes
        :return: the BarContainer returned by ax.barh
        """
        if ax is None:
            ax = plt.gca()

        names, v, grad, lb, ub, requires_grad = self.get_named_bounded_params(
            named_bounded_params, exclude=exclude)

        # Map gradients symmetrically into [0, 1] for the colormap; avoid
        # dividing by zero when every gradient is zero.
        max_grad = np.amax(np.abs(grad))
        if max_grad == 0:
            max_grad = 1.
        v01 = (v - lb) / (ub - lb)
        grad01 = (grad + max_grad) / (max_grad * 2)
        n = len(v)

        # assumes requires_grad is a boolean array (it is used as a mask);
        # fixed parameters get NaN and their color is replaced below.
        grad01[~requires_grad] = np.nan

        for i, (lb1, v1, ub1, g1,
                r1) in enumerate(zip(lb, v, ub, grad, requires_grad)):
            color = 'k' if r1 else 'gray'
            ax.text(0, i, '%1.2g' % lb1, ha='left', va='center', color=color)
            ax.text(1, i, '%1.2g' % ub1, ha='right', va='center', color=color)
            # Value, then the order of magnitude of the gradient for free
            # parameters (or "(fixed)").
            ax.text(
                0.5,
                i,
                '%1.2g %s' %
                (v1, ('(e%1.0f)' % np.log10(np.abs(g1))) if r1 else '(fixed)'),
                ha='center',
                va='center',
                color=color)

        lut = 256
        colors = plt.get_cmap(cmap, lut)(grad01)
        for i, r in enumerate(requires_grad):
            if not r:
                colors[i, :3] = np.array([0.95, 0.95, 0.95])
        h = ax.barh(np.arange(n), v01, left=0, color=colors)
        ax.set_xlim(-0.025, 1)
        ax.set_xticks([])
        ax.set_yticks(np.arange(n))
        ax.set_yticklabels([name.replace('_', '-') for name in names])
        # First parameter at the top.
        ax.set_ylim(n - 0.5, -0.5)

        # Decorate the axes we drew on, not whichever axes is current.
        plt2.box_off(['top', 'right', 'bottom'], ax=ax)
        plt2.detach_axis('x', amin=0, amax=1, ax=ax)
        plt2.detach_axis('y', amin=0, amax=n - 1, ax=ax)
        return h
コード例 #5
0
ファイル: dtb_2D_fit_VD.py プロジェクト: yulkang/2D_Decision
    def beautify(row, col, axs):
        """
        Format panel axs[row, col]: x-range [0, 1.2] with a 5% right margin,
        spines via plt2, ticks via beautify_ticks (tick labels only on the
        bottom-left panel) and no y-label outside the first column.
        """
        panel = axs[row, col]
        plt.sca(panel)
        plt.xlim(xmax=1.2 * 1.05)
        plt2.detach_axis('x', 0, 1.2)
        plt2.box_off()

        # n_row comes from the enclosing scope.
        is_bottom_left = (row == n_row - 1) and (col == 0)
        beautify_ticks(panel, add_ticklabel=is_bottom_left)

        if col > 0:
            plt.ylabel('')
コード例 #6
0
ファイル: dtb_1D_sim.py プロジェクト: yulkang/2D_Decision
    def plot_bound(self,
                   t_all: Sequence[float] = None,
                   ax: plt.Axes = None,
                   **kwargs) -> plt.Line2D:
        """
        Plot the bound self.get_bound(t_all) against time on ``ax``.

        :param t_all: time points (s); defaults to self.t_all
        :param ax: axes to draw on; defaults to the current axes
        :param kwargs: forwarded to ax.plot (color='k', linestyle='-' unless
            overridden)
        :return: the list of Line2D returned by ax.plot
        """
        if ax is None:
            ax = plt.gca()
        if t_all is None:
            t_all = self.t_all

        kwargs = argsutil.kwdefault(kwargs, color='k', linestyle='-')
        h = ax.plot(*npys(t_all, self.get_bound(t_all)), **kwargs)
        ax.set_xlabel('time (s)')
        ax.set_ylabel(r"$b(t)$")

        # Start the y-axis slightly (5% of the top limit) below zero.
        y_lim = ax.get_ylim()
        y_min = -y_lim[1] * 0.05
        ax.set_ylim(ymin=y_min)

        # Decorate the axes we drew on (previously these acted on whatever
        # axes happened to be current).
        plt2.detach_axis('y', amin=0, ax=ax)
        plt2.detach_axis('x', amin=0, ax=ax)
        plt2.box_off(ax=ax)
        return h
コード例 #7
0
ファイル: dtb_1D_sim.py プロジェクト: yulkang/2D_Decision
    def plot_p_tnd(self,
                   t_all: Sequence[float] = None,
                   ax: plt.Axes = None,
                   **kwargs) -> Union[plt.Line2D, Sequence[plt.Line2D]]:
        """
        Plot self.get_p_tnd(t_all) against time (presumably the
        non-decision-time distribution -- confirm against get_p_tnd).

        :param t_all: time points (s); self.t_all when omitted
        :param ax: target axes; the current axes when omitted
        :param kwargs: forwarded to ax.plot; color='k' and linestyle='-'
            unless overridden
        :return: whatever ax.plot returns (a list of Line2D)
        """
        t_all = self.t_all if t_all is None else t_all
        ax = plt.gca() if ax is None else ax

        y_tnd = self.get_p_tnd(t_all)
        kw_line = argsutil.kwdefault(kwargs, color='k', linestyle='-')
        lines = ax.plot(*npys(t_all, y_tnd), **kw_line)

        ax.set_xlabel('time (s)')
        plt2.box_off(ax=ax)
        # Detach the y-axis over [0, 1] and the x-axis from 0 onward.
        plt2.detach_axis('y', amin=0, amax=1, ax=ax)
        plt2.detach_axis('x', amin=0, ax=ax)
        return lines
コード例 #8
0
ファイル: dtb_2D_sim.py プロジェクト: yulkang/2D_Decision
def plot_p_ch_vs_ev(
        ev_cond: Union[torch.Tensor, np.ndarray],
        n_ch: Union[torch.Tensor, np.ndarray],
        style='pred',
        ax: plt.Axes = None,
        dim_rel=0,
        group_dcond_irr: Iterable[Iterable[int]] = None,
        cmap: Union[str, Callable] = 'cool',
        kw_plot=(),
) -> Iterable[plt.Line2D]:
    """
    Plot the proportion of choice 1 on the relevant dimension against its
    evidence, with one curve per level of |evidence| on the irrelevant
    dimension.

    @param ev_cond: [condition, dim] or [condition, frame, dim, (mean, var)];
        the 4-D form is averaged over frames and its mean is taken
    @type ev_cond: torch.Tensor
    @param n_ch: [condition, ch] or [condition, rt_frame, ch]; the 3-D form
        is summed over rt_frame
    @type n_ch: torch.Tensor
    @param style: forwarded to get_kw_plot
    @param ax: axes to draw on; defaults to the current axes
    @param dim_rel: index of the relevant dimension
    @param group_dcond_irr: optional grouping of irrelevant-dimension
        condition indices, applied with group_conds
    @param cmap: colormap name (any str, including numpy.str_), or a
        callable taking the number of curves and returning a callable that
        maps a curve index to a color
    @param kw_plot: extra plot kwargs (a mapping or key-value pairs)
    @return: hs[cond_irr][0] = Line2D, conds_irr
    """
    if ax is None:
        ax = plt.gca()
    if ev_cond.ndim != 2:
        assert ev_cond.ndim == 4
        ev_cond = npt.p2st(ev_cond.mean(1))[0]
    if n_ch.ndim != 2:
        assert n_ch.ndim == 3
        n_ch = n_ch.sum(1)

    ev_cond = npy(ev_cond)
    n_ch = npy(n_ch)
    n_cond_all = n_ch.shape[0]
    # Choice on the relevant dimension for every flattened (cond, ch) cell.
    ch_rel = np.repeat(np.array(consts.CHS[dim_rel])[None, :], n_cond_all, 0)
    n_ch = n_ch.reshape([-1])
    ch_rel = ch_rel.reshape([-1])

    dim_irr = consts.get_odim(dim_rel)
    conds_rel, dcond_rel = np.unique(ev_cond[:, dim_rel], return_inverse=True)
    conds_irr, dcond_irr = np.unique(np.abs(ev_cond[:, dim_irr]),
                                     return_inverse=True)

    if group_dcond_irr is not None:
        conds_irr, dcond_irr = group_conds(conds_irr, dcond_irr,
                                           group_dcond_irr)

    n_conds = [len(conds_rel), len(conds_irr)]

    # Sum counts into [ch_rel, cond_irr, cond_rel].
    n_ch_rel = npg.aggregate(
        np.stack([
            ch_rel,
            np.repeat(dcond_irr[:, None], consts.N_CH_FLAT, 1).flatten(),
            np.repeat(dcond_rel[:, None], consts.N_CH_FLAT, 1).flatten(),
        ]), n_ch, 'sum', [consts.N_CH, n_conds[1], n_conds[0]])
    p_ch_rel = n_ch_rel[1] / n_ch_rel.sum(0)

    hs = []
    for dcond_irr1, p_ch1 in enumerate(p_ch_rel):
        # isinstance (not `type(...) is str`) so str subclasses such as
        # numpy.str_ are treated as colormap names rather than called.
        if isinstance(cmap, str):
            color = plt.get_cmap(cmap, n_conds[1])(dcond_irr1)
        else:
            color = cmap(n_conds[1])(dcond_irr1)
        kw1 = get_kw_plot(style, color=color, **dict(kw_plot))
        h = ax.plot(conds_rel, p_ch1, **kw1)
        hs.append(h)
    plt2.box_off(ax=ax)
    x_lim = ax.get_xlim()
    plt2.detach_axis('x', amin=x_lim[0], amax=x_lim[1], ax=ax)
    plt2.detach_axis('y', amin=0, amax=1, ax=ax)
    ax.set_yticks([0, 0.5, 1])
    ax.set_yticklabels(['0', '', '1'])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{P}(z=1 \mid c)$")

    return hs, conds_irr
コード例 #9
0
def plot_coef_by_dur_vs_odif(coef,
                             se_coef,
                             normalize_bias=True,
                             coef_name='slope',
                             savefig=True,
                             difs=((2, 1), (0, )),
                             t_RDK_durs=np.arange(1, 11) * 0.12,
                             fig=None,
                             horizontal_panels=False):
    """
    Plot a regression coefficient against stimulus duration.

    One panel per dimension (stacked vertically, or side by side when
    ``horizontal_panels``). Within a panel there is one errorbar curve per
    difficulty group of the other dimension.

    @param coef: indexed [coef (0: bias, 1: slope), dim, dif_group, dur]
    @param se_coef: standard errors, same layout as ``coef``
    @param normalize_bias: if True, plot the bias as -bias / slope in units
        of coh (the inputs are not modified)
    @param coef_name: 'bias' or 'slope'
    @param savefig: currently unused; kept for interface compatibility
    @param difs: difficulty groups; only len(difs) (1, 2 or 3) is used,
        to pick the curve labels
    @param t_RDK_durs: stimulus durations (s) for the x-axis
    @param fig: figure to draw on; a new one is created when None
    @param horizontal_panels: lay panels out in one row instead of one column
    @return: the GridSpec used for the panels
    @raise ValueError: if len(difs) is not 1, 2 or 3, or coef_name is unknown
    """
    if fig is None:
        if horizontal_panels:
            fig = plt.figure(figsize=(7, 2))
        else:
            fig = plt.figure(figsize=(4, 3))
    if horizontal_panels:
        gs = plt.GridSpec(figure=fig,
                          nrows=1,
                          ncols=2,
                          hspace=0.3,
                          left=0.11,
                          right=0.98,
                          top=0.9,
                          bottom=0.15)
    else:
        gs = mpl.gridspec.GridSpec(figure=fig,
                                   nrows=2,
                                   ncols=1,
                                   left=0.22,
                                   right=0.98,
                                   top=0.9,
                                   bottom=0.15)

    coef_names = ['bias', 'slope']
    i_coef = coef_names.index(coef_name)

    units = ['(logit)', '(logit/coh)']
    if normalize_bias:
        # Work on float copies: the caller's arrays must not change, and a
        # second call must not normalize the bias again.
        coef = np.array(coef, dtype=float)
        se_coef = np.array(se_coef, dtype=float)
        slope = coef[1]
        # Bias expressed as -bias / slope, in coh (the evidence at which the
        # logit is zero, assuming logit = bias + slope * coh). The SE uses
        # the first-order |se(bias) / slope| (slope uncertainty ignored);
        # abs() keeps it non-negative, which Axes.errorbar requires.
        se_coef[0] = np.abs(se_coef[0] / slope)
        coef[0] = -coef[0] / slope
        units[0] = '(coh)'
    unit = units[i_coef]

    y = coef[i_coef]
    se = se_coef[i_coef]
    labels_by_n_dif = {
        1: ['all'],
        2: ['easy', 'hard'],
        3: ['easy', 'medium', 'hard'],
    }
    if len(difs) not in labels_by_n_dif:
        raise ValueError('difs must have 1, 2 or 3 groups, got %d' %
                         len(difs))
    labels = labels_by_n_dif[len(difs)]

    for dim, (slope1, se_slope1) in enumerate(zip(y, se)):
        odim = consts.N_DIM - 1 - dim
        if horizontal_panels:
            plt.subplot(gs[0, dim])
        else:
            plt.subplot(gs[dim, 0])

        for odif, (slope2, se_slope2) in enumerate(zip(slope1, se_slope1)):
            label = labels[odif] + ' ' + consts.DIM_NAMES_SHORT[odim]
            plt.errorbar(t_RDK_durs,
                         slope2,
                         se_slope2,
                         marker='o',
                         label=label)
        plt.ylabel(consts.DIM_NAMES_LONG[dim].lower() + ' ' + coef_name +
                   '\n' + unit)
        plt.axis('auto')
        plt.xlim(xmin=-0.02)
        if coef_name == 'slope':
            plt.ylim(ymin=-0.05)
        plt2.box_off()
        plt2.detach_axis('x')
        if dim == 0 and not horizontal_panels:
            plt.gca().set_xticklabels([])
        plt.legend(handlelength=0.4, frameon=False, loc='upper left')
    plt.xlabel('stimulus duration (s)')
    return gs
コード例 #10
0
def main_fit(
        subjs=None,
        skip_fit_if_absent=False,
        bufdur=None,
        bufdur_best=0.12,
        bufdurs_sim=None,
        bufdurs_fit=None,
        loss_kind='NLL',
        plot_kind='line_sim_fit',
        fix_post=None,
        seed_sims=(0, ),
        err_kind='std',
        base=10,
):
    """
    Collect model-recovery fits over buffer durations, plot and save them.

    For every (seed, subject, simulated buffer duration, fitted buffer
    duration) the fit is loaded with get_fit_sim. Its test NLL, number of
    test data points and number of parameters are stored in arrays indexed
    [seed, subj, sim_dur, fit_dur]. The losses are optionally converted to
    BIC (k * log(n) + 2 * NLL) and divided by log(base). They are then
    plotted according to ``plot_kind`` and the figure is saved as .png and
    .pdf.

    :param subjs: subjects to include; defaults to consts.SUBJS_VD
    :param skip_fit_if_absent: forwarded to get_fit_sim
    :param bufdur: if given, both simulate and fit {bufdur_best, bufdur}
    :param bufdur_best: buffer duration (s) treated as the best one
    :param bufdurs_sim: simulated buffer durations; defaults to bufdurs0
    :param bufdurs_fit: fitted buffer durations. When None: bufdurs0 if
        bufdurs_sim was also None; otherwise [0, bufdur_best, 1.2], with
        bufdurs_sim spliced in before 1.2 unless bufdurs_sim[0] is already
        one of those values
    :param loss_kind: 'NLL' or 'BIC'
    :param plot_kind: 'loss_serpar', 'imshow_ser_buf_par', 'line_sim_fit',
        'bar_sim_fit', 'bar_ser_buf_par', or 'None' / None to skip plotting
        and saving
    :param fix_post: forwarded to get_fit_sim
    :param seed_sims: simulation seeds
    :param err_kind: 'std' or 'sem' across seeds (used by 'line_sim_fit')
    :param base: logarithm base in which losses are reported
    :return: losses[seed, subj, sim_dur, fit_dur] in log-``base`` units
    :raises ValueError: on an unsupported loss_kind, plot_kind or err_kind
    """
    # Validate string options before the (slow) loading loop so a typo
    # fails immediately instead of after every fit has been read.
    plot_kinds = ('loss_serpar', 'imshow_ser_buf_par', 'line_sim_fit',
                  'bar_sim_fit', 'bar_ser_buf_par', 'None', None)
    if loss_kind not in ('NLL', 'BIC'):
        raise ValueError('Unsupported loss_kind: %s' % loss_kind)
    if plot_kind not in plot_kinds:
        raise ValueError('Unsupported plot_kind: %s' % plot_kind)
    if plot_kind == 'line_sim_fit' and err_kind not in ('std', 'sem'):
        raise ValueError('Unsupported err_kind: %s' % err_kind)

    # if torch.cuda.is_available():
    #     torch.set_default_tensor_type(torch.cuda.FloatTensor)
    torch.set_num_threads(1)
    torch.set_default_dtype(torch.double)

    dict_res_add = {}

    parad = 'VD'
    seed_sims = np.array(seed_sims)

    if subjs is None:
        subjs = consts.SUBJS_VD

    if bufdur is not None:
        bufdurs_sim = list({bufdur_best}.union({bufdur}))
        bufdurs_fit = list({bufdur_best}.union({bufdur}))
    else:
        if bufdurs_sim is None:
            bufdurs_sim = bufdurs0  #
            if bufdurs_fit is None:
                bufdurs_fit = bufdurs0
        else:
            if bufdurs_fit is None:
                bufdurs_fit0 = [0., bufdur_best, 1.2]
                if bufdurs_sim[0] in bufdurs_fit0:
                    bufdurs_fit = copy(bufdurs_fit0)
                else:
                    bufdurs_fit = bufdurs_fit0[:2] + bufdurs_sim + [1.2]

    n_dur_sim = len(bufdurs_sim)
    n_dur_fit = len(bufdurs_fit)

    # DEF: losses[seed, subj, simSerDurPar, fitDurs]
    size_all = [len(seed_sims), len(subjs), n_dur_sim, n_dur_fit]
    losses0 = np.zeros(size_all) + np.nan
    ns = np.zeros(size_all) + np.nan
    ks0 = np.zeros(size_all) + np.nan

    for i_seed, seed_sim in enumerate(seed_sims):
        for i_subj, subj in enumerate(subjs):
            for i_fit, bufdur_fit in enumerate(bufdurs_fit):
                for i_sim, bufdur_sim in enumerate(bufdurs_sim):
                    d, dict_fit_sim, dict_subdir_sim = get_fit_sim(
                        subj,
                        seed_sim,
                        bufdur_sim,
                        bufdur_fit,
                        parad=parad,
                        skip_fit_if_absent=skip_fit_if_absent,
                        fix_post=fix_post,
                    )
                    if d is not None:
                        losses0[i_seed, i_subj, i_sim,
                                i_fit] = d['loss_NLL_test']
                        ns[i_seed, i_subj, i_sim, i_fit] = d['loss_ndata_test']
                        ks0[i_seed, i_subj, i_sim, i_fit] = d['loss_nparam']
    ks = copy(ks0)

    if loss_kind == 'BIC':
        losses = ks * np.log(ns) + 2 * losses0

        # REF: Kass, Raftery 1995 https://doi.org/10.2307%2F2291091
        #   https://en.wikipedia.org/wiki/Bayesian_information_criterion#Gaussian_special_case
        thres_strong = 10. / np.log(base)
    else:  # 'NLL' (validated above)
        losses = copy(losses0)
        thres_strong = np.log(100) / np.log(base)

    #%%
    losses = losses / np.log(base)

    if plot_kind == 'loss_serpar':
        n_row = len(subjs)
        n_col = 1
        axs = plt2.GridAxes(n_row, n_col, left=1, widths=2, bottom=0.75)
        for i_subj, subj in enumerate(subjs):
            plt.sca(axs[i_subj, 0])

            loss_dur_sim_ser_fit = losses[:, i_subj, :, 0].mean(0)
            loss_dur_sim_par_fit = losses[:, i_subj, :, -1].mean(0)

            dloss_serpar = loss_dur_sim_par_fit - loss_dur_sim_ser_fit

            plt.bar(bufdurs_sim, dloss_serpar, color='k')
            plt.axhline(0, color='k', linestyle='-', linewidth=0.5)

            if i_subj < len(subjs) - 1:
                plt2.box_off(['top', 'right', 'bottom'])
                plt.xticks([])
            else:
                plt2.box_off()
                plt.xticks(np.arange(0, 1.4, 0.2), ['0\nser'] + [''] * 2 +
                           ['0.6'] + [''] * 2 + ['1.2\npar'])
                plt2.detach_axis('x', 0, 1.2)
            vdfit.patch_chance_level()

        plt.sca(axs[0, 0])
        # Raw string for the LaTeX part: '\m' is an invalid escape sequence
        # (SyntaxWarning on Python 3.12+); the rendered text is unchanged.
        plt.title('Support for Serial\n'
                  r'($\mathcal{L}_\mathrm{ser} - \mathcal{L}_\mathrm{par}$)')

        plt.sca(axs[-1, 0])
        plt.xlabel('true buffer duration (s)')

        # Label only the rows that were actually plotted.
        plt2.rowtitle(list(subjs), axs)
    elif plot_kind == 'imshow_ser_buf_par':
        n_row = len(subjs)
        n_col = 1
        axs = plt2.GridAxes(n_row,
                            n_col,
                            left=1,
                            widths=2,
                            heights=2,
                            bottom=0.75)
        for i_subj, subj in enumerate(subjs):
            plt.sca(axs[i_subj, 0])

            dloss = losses[:, i_subj, :, :].mean(0)
            dloss = dloss - np.diag(dloss)[:, None]

            plt.imshow(dloss, cmap='bwr')

            cmax = np.amax(np.abs(dloss))
            plt.clim(-cmax, +cmax)

        plt.sca(axs[-1, 0])
        plt.xlabel('true buffer duration (s)')
        plt.ylabel('model buffer duration (s)')

        plt2.rowtitle(list(subjs), axs)

    elif plot_kind == 'line_sim_fit':
        bufdur_bests = np.array([get_bufdur_best(subj) for subj in subjs])
        i_bests = np.array([
            list(bufdurs_fit).index(bufdur_best)
            for bufdur_best in bufdur_bests
        ])
        i_subjs = np.arange(len(subjs))
        # Loss of each fitted duration on data simulated with the subject's
        # best duration, relative to fitting the best duration itself.
        loss_sim_best_fit_best = losses[:, i_subjs, i_bests, i_bests]
        loss_sim_best_fit_rest = (losses[:, i_subjs, i_bests, :] -
                                  loss_sim_best_fit_best[:, :, None])
        mean_loss_sim_best_fit_rest = np.mean(loss_sim_best_fit_rest, 0)
        if err_kind == 'std':
            err_loss_sim_best_fit_rest = np.std(loss_sim_best_fit_rest, 0)
        else:  # 'sem' (validated above)
            err_loss_sim_best_fit_rest = np2.sem(loss_sim_best_fit_rest, 0)

        # Loss of the best-duration model on data simulated with each
        # duration, relative to the matching (true-duration) model.
        n_dur = len(bufdurs_fit)
        loss_sim_rest_fit_best = np.swapaxes(
            np.stack([
                losses[:, i_subj, np.arange(n_dur), i_best] -
                losses[:, i_subj,
                       np.arange(n_dur),
                       np.arange(n_dur)]
                for i_subj, i_best in zip(i_subjs, i_bests)
            ]), 0, 1)
        mean_loss_sim_rest_fit_best = np.mean(loss_sim_rest_fit_best, 0)
        if err_kind == 'std':
            err_loss_sim_rest_fit_best = np.std(loss_sim_rest_fit_best, 0)
        else:  # 'sem' (validated above)
            err_loss_sim_rest_fit_best = np2.sem(loss_sim_rest_fit_best, 0)

        dict_res_add['err'] = err_kind

        n_row = 2
        n_col = len(subjs)

        axs = plt2.GridAxes(n_row,
                            n_col,
                            left=1.15,
                            right=0.1,
                            widths=2,
                            heights=1.5,
                            wspace=0.35,
                            hspace=0.5,
                            top=0.25,
                            bottom=0.6)

        for row, (m, s, xlabel, ylabel) in enumerate([
            (mean_loss_sim_best_fit_rest, err_loss_sim_best_fit_rest,
             'model buffer capacity (s)',
             (r'$-\mathrm{log}_{%g}\mathrm{BF}$ given simulated' +
              '\nbest duration data') % base),
            (mean_loss_sim_rest_fit_best, err_loss_sim_rest_fit_best,
             'simulated data buffer capacity (s)',
             (r'$-\mathrm{log}_{%g}\mathrm{BF}$ of the' +
              '\nbest duration model') % base)
        ]):
            for i_subj, subj in enumerate(subjs):
                ax = axs[row, i_subj]
                plt.sca(ax)
                gs1 = axs.gs[row * 2 + 1, i_subj * 2 + 1]

                # NOTE(review): x values are bufdurs0, which assumes
                # bufdurs_fit == bufdurs0 (the default) -- confirm when
                # calling with custom durations.
                bax = vdfit.breakaxis(gs1)
                bax.axs[1].errorbar(bufdurs0,
                                    m[i_subj, :],
                                    yerr=s[i_subj, :],
                                    color='k',
                                    marker='.',
                                    linewidth=0.75,
                                    elinewidth=0.5,
                                    markersize=3)
                m1 = copy(m[i_subj, :])
                m1[3:] = np.nan
                s1 = copy(s[i_subj, :])
                s1[3:] = np.nan
                bax.axs[0].errorbar(bufdurs0,
                                    m1,
                                    yerr=s1,
                                    color='k',
                                    marker='.',
                                    linewidth=0.75,
                                    elinewidth=0.5,
                                    markersize=3)

                ax1 = bax.axs[1]  # type: plt.Axes
                plt.sca(ax1)
                vdfit.patch_chance_level(level=thres_strong, signs=[-1, 1])
                plt.axhline(0, color='k', linestyle='--', linewidth=0.5)
                vdfit.beautify_ticks(ax1, )
                vdfit.beautify(ax1)
                ax1.set_yticks([0, 20])
                if i_subj > 0:
                    ax1.set_yticklabels([])
                if row == 0:
                    ax1.set_xticklabels([])

                ax1 = bax.axs[0]  # type: plt.Axes
                plt.sca(ax1)
                ax1.set_yticks([40, 200])
                if i_subj > 0:
                    ax1.set_yticklabels([])

                plt.sca(ax)
                plt2.box_off('all')
                if row == 0:
                    plt.title(consts.SUBJS['VD'][i_subj])

                if i_subj == 0:
                    plt.xlabel(xlabel, labelpad=8 if row == 0 else 20)
                    plt.ylabel(ylabel, labelpad=30)

                plt2.sameaxes(bax.axs, xy='x')

    elif plot_kind == 'bar_sim_fit':
        # list() so this also works when bufdurs_fit is an ndarray.
        i_best = list(bufdurs_fit).index(bufdur_best)
        loss_sim_best_fit_best = losses[0, :, i_best, i_best]
        loss_sim_best_fit_rest = np2.nan2v(losses[0, :, i_best, :] -
                                           loss_sim_best_fit_best[:, None])
        n_dur = len(bufdurs_fit)
        loss_sim_rest_fit_best = np2.nan2v(
            losses[0, :, np.arange(n_dur), i_best] -
            losses[0, :, np.arange(n_dur),
                   np.arange(n_dur)]).T

        n_row = 2
        n_col = len(subjs)

        axs = plt2.GridAxes(n_row,
                            n_col,
                            left=1,
                            widths=3,
                            right=0.25,
                            heights=1,
                            wspace=0.3,
                            hspace=0.6,
                            top=0.25,
                            bottom=0.6)
        for i_subj, subj in enumerate(subjs):

            def plot_bars(gs1, bufdurs, losses1, add_xticklabel=True):
                # Draw subject i_subj's losses on broken axes in gs1.
                bufdurs = list(bufdurs)  # .index() below needs a list
                i_break = np.amax(np.nonzero(np.array(bufdurs) < 0.3)[0])
                bax = brokenaxes(
                    subplot_spec=gs1,
                    xlims=((-1, i_break + 0.5), (i_break + 0.5,
                                                 len(bufdurs) - 0.5)),
                    ylims=((-3, 20), (20, 1250 / 5)),
                    height_ratios=(50 / 100, 500 / (1250 - 100)),
                    hspace=0.15,
                    wspace=0.075,
                    d=0.005,
                )
                bax.bar(np.arange(len(bufdurs)), losses1[i_subj, :], color='k')

                ax11 = bax.axs[3]  # type: plt.Axes
                ax11.set_xticks([bufdurs.index(0.6), bufdurs.index(1.2)])
                if i_subj == 0 and add_xticklabel:
                    ax11.set_xticklabels(['0.6', '1.2'])
                else:
                    ax11.set_xticklabels([])

                ax00 = bax.axs[0]  # type: plt.Axes
                ax00.set_yticks([500, 1000])

                ax10 = bax.axs[2]  # type: plt.Axes
                ax10.set_yticks([0, 50])
                plt.sca(ax10)
                plt2.detach_axis('x', amin=-0.4, amax=i_break + 0.5)
                for ax in [ax10, ax11]:
                    plt.sca(ax)
                    plt.axhline(0, linewidth=0.5, color='k', linestyle='--')
                    for sign in [-1, 1]:
                        plt.axhline(sign * thres_strong,
                                    linewidth=0.5,
                                    color='silver',
                                    linestyle='--')
                ax10.set_xticks([bufdurs.index(0.), bufdurs.index(0.2)])
                if i_subj == 0:
                    if add_xticklabel:
                        ax10.set_xticklabels(['0', '0.2'])
                    else:
                        ax10.set_xticklabels([])
                else:
                    ax10.set_yticklabels([])
                    ax10.set_xticklabels([])
                    ax00.set_yticklabels([])
                return bax

            bax = plot_bars(axs.gs[1, i_subj * 2 + 1],
                            bufdurs_fit,
                            loss_sim_best_fit_rest,
                            add_xticklabel=False)

            ax = axs[0, i_subj]
            plt.sca(ax)
            plt2.box_off('all')
            plt.title(consts.SUBJS['VD'][consts.SUBJS['VD'].index(subj)])
            if i_subj == 0:
                ax.set_ylabel(
                    'misfit to simulated\nbest duration data\n'
                    r'($\Delta$BIC)',
                    labelpad=35)
                ax.set_xlabel('model buffer duration (s)', labelpad=8)

            plot_bars(axs.gs[3, i_subj * 2 + 1],
                      bufdurs_sim,
                      loss_sim_rest_fit_best,
                      add_xticklabel=True)
            ax = axs[1, i_subj]
            plt.sca(ax)
            plt2.box_off('all')
            if i_subj == 0:
                ax.set_ylabel(
                    'misfit of\nbest duration model\n'
                    r'($\Delta$BIC)',
                    labelpad=35)
                ax.set_xlabel('simulated data buffer duration (s)',
                              labelpad=20)

    elif plot_kind == 'bar_ser_buf_par':
        n_row = len(subjs)
        n_col = 1
        axs = plt2.GridAxes(n_row * n_dur_sim,
                            n_col,
                            left=1.25,
                            widths=1.5,
                            right=0.25,
                            heights=0.4,
                            hspace=[0.15] * (n_dur_sim - 1) + [0.5] + [0.15] *
                            (n_dur_sim - 1),
                            top=0.6,
                            bottom=0.5)
        row = -1
        for i_subj, subj in enumerate(subjs):
            for i_sim in range(n_dur_sim):
                row += 1
                plt.sca(axs[row, 0])

                loss1 = losses[:, i_subj, i_sim, :].mean(0)
                dloss = loss1 - loss1[i_sim]

                x = np.arange(n_dur_fit)
                for x1, dloss1 in zip(x, dloss):
                    plt.bar(x1, dloss1, color='r' if dloss1 > 0 else 'b')
                plt.axhline(0, color='k', linewidth=0.5)
                vdfit.patch_chance_level(6.)
                plt.ylim([-100, 100])
                plt.yticks([-100, 0, 100], [''] * 2)
                plt.ylabel('%g' % bufdurs_sim[i_sim],
                           rotation=0,
                           va='center',
                           ha='right')

                if i_sim == 0:
                    plt.title(subj)

                if i_sim < n_dur_sim - 1 or i_subj < len(subjs) - 1:
                    plt2.box_off(['top', 'bottom', 'right'])
                    plt.xticks([])
                else:
                    plt2.box_off(['top', 'right'])

        plt.sca(axs[-1, 0])
        plt.xticks(np.arange(n_dur_sim), ['%g' % v for v in bufdurs_sim])
        plt2.detach_axis('x', 0, n_dur_sim - 1)
        plt.xlabel('model buffer duration (s)', fontsize=10)

        c = axs[-1, 0].get_position().corners()
        plt.figtext(x=0.15,
                    y=np.mean(c[:2, 1]),
                    fontsize=10,
                    s='true\nbuffer\nduration\n(s)',
                    rotation=0,
                    va='center',
                    ha='center')

        plt.figtext(
            x=(1.25 + 1.5 / 2) / (1.25 + 1.5 + 0.25),
            y=0.98,
            s=r'$\mathrm{BIC} - \mathrm{BIC}_\mathrm{true}$',
            ha='center',
            va='top',
            fontsize=12,
        )

    if plot_kind not in ('None', None):
        # dict_fit_sim / dict_subdir_sim come from the last get_fit_sim call.
        dict_res = deepcopy(dict_fit_sim)  # noqa
        for k in ['sbj', 'prd']:
            dict_res.pop(k)
        dict_res = {**dict_res, 'los': loss_kind}
        dict_res.update(dict_res_add)
        for ext in ['.png', '.pdf']:
            file = locfile.get_file_fig(plot_kind,
                                        dict_res,
                                        subdir=dict_subdir_sim,
                                        ext=ext)  # noqa
            plt.savefig(file, dpi=300)
            print('Saved to %s' % file)

    #%%
    print('--')
    return losses
コード例 #11
0
def plot_ch_ev_by_dur(n_cond_dur_ch: np.ndarray = None,
                      data: Data2DVD = None,
                      dif_irrs: Sequence[Union[Iterable[int],
                                               int]] = ((0, 1), (2, )),
                      ev_cond_dim: np.ndarray = None,
                      durs: np.ndarray = None,
                      dur_prct_groups=((0, 33), (33, 67), (67, 100)),
                      style='data',
                      kw_plot=(),
                      jitter=0.,
                      axs=None,
                      fig=None):
    """
    Plot the proportion of choice 1 against evidence, split by duration.

    Panels are axs[dim_rel, i_dur]: rows are the relevant dimension and
    columns are duration groups. Within each panel there is one curve per
    group of irrelevant-dimension difficulty in ``dif_irrs``, colored with
    consts.CMAP_DIM[dim_rel].

    :param n_cond_dur_ch: choice counts; read from ``data`` when None
    :param data: source for n_cond_dur_ch, ev_cond_dim and durs when those
        are None
    :param dif_irrs: groups of indices into the sorted unique |evidence|
        values of the irrelevant dimension
    :param ev_cond_dim: evidence per condition; presumably
        [condition, dim] (its transpose is iterated per dimension)
    :param durs: durations; only len(durs) is used
    :param dur_prct_groups: (low, high) percentiles of the duration index
        that define each duration group (both ends inclusive)
    :param style: forwarded to consts.get_kw_plot
    :param kw_plot: extra plot kwargs (a mapping or key-value pairs)
    :param jitter: currently unused; kept for interface compatibility
    :param axs: [row, col] array of axes; created on ``fig`` when None
    :param fig: figure for new axes; created when both axs and fig are None
    :return: axs
    """
    if ev_cond_dim is None:
        ev_cond_dim = data.ev_cond_dim
    if durs is None:
        durs = data.durs
    if n_cond_dur_ch is None:
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])

    n_conds_dur_chs = get_n_conds_dur_chs(n_cond_dur_ch)
    conds_dim = [np.unique(cond1) for cond1 in ev_cond_dim.T]
    n_dur = len(dur_prct_groups)
    n_dif = len(dif_irrs)

    if axs is None:
        if fig is None:
            fig = plt.figure(figsize=[6, 4])

        n_row = consts.N_DIM
        n_col = n_dur
        gs = plt.GridSpec(nrows=n_row, ncols=n_col, figure=fig)
        # Builtin `object`: the np.object alias was removed in NumPy 1.24.
        axs = np.empty([n_row, n_col], dtype=object)
        for row in range(n_row):
            for col in range(n_col):
                axs[row, col] = plt.subplot(gs[row, col])

    # Duration groups are percentile ranges of the duration *index*.
    ix_dur = np.arange(len(durs))

    for dim_rel in range(consts.N_DIM):
        conds_rel = conds_dim[dim_rel]
        dim_irr = consts.get_odim(dim_rel)
        _, cond_irr = np.unique(np.abs(conds_dim[dim_irr]),
                                return_inverse=True)
        cmap = consts.CMAP_DIM[dim_rel](n_dif)

        for idif, dif_irr in enumerate(dif_irrs):
            incl_irr = np.isin(cond_irr, dif_irr)
            # Sum over the included irrelevant conditions and over the
            # irrelevant-dimension choice axis.
            if dim_rel == 0:
                n_cond_dur_ch1 = n_conds_dur_chs[:, incl_irr, :, :, :].sum(
                    (1, -1))
            else:
                n_cond_dur_ch1 = n_conds_dur_chs[incl_irr, :, :, :, :].sum(
                    (0, -2))
            kw = consts.get_kw_plot(style,
                                    color=cmap(idif),
                                    **dict(kw_plot))

            for i_dur, dur_prct_group in enumerate(dur_prct_groups):
                ax = axs[dim_rel, i_dur]
                plt.sca(ax)

                dur_incl = (
                    (ix_dur >= np.percentile(ix_dur, dur_prct_group[0]))
                    & (ix_dur <= np.percentile(ix_dur, dur_prct_group[1])))
                n_cond_ch1 = n_cond_dur_ch1[:, dur_incl, :].sum(1)
                p_cond__ch1 = n_cond_ch1[:, 1] / n_cond_ch1.sum(1)

                plt.plot(conds_rel, p_cond__ch1, **kw)

                plt2.box_off(ax=ax)
                plt2.detach_yaxis(0, 1, ax=ax)

                ax.set_yticks([0, 0.5, 1.])
                if i_dur == 0:
                    ax.set_yticklabels(['0', '', '1'])
                else:
                    ax.set_yticklabels([])
                    ax.set_xticklabels([])

    return axs
コード例 #12
0
def plot_coefs_dur_irrixn(
        n_cond_dur_ch: np.ndarray = None,
        data: Data2DVD = None,
        ev_cond_dim: np.ndarray = None,
        durs: np.ndarray = None,
        style='data',
        kw_plot=(),
        jitter0=0.,
        coefs_to_plot=(2, ),
        axs=None,
        fig=None,
):
    """
    Plot regression coefficients that include interactions with the
    irrelevant dimension, as a function of stimulus duration.

    Rows are dimensions (consts.N_DIM rows); columns are the entries of
    coefs_to_plot.

    :param n_cond_dur_ch: histogram passed to
        coefdur.get_coefs_irr_ixn_from_histogram; defaults to
        data.get_data_by_cond('all')[1].
    :param data: only used to fill in arguments that are None.
    :param ev_cond_dim: defaults to data.ev_cond_dim.
    :param durs: durations used as the x coordinates; defaults to data.durs.
    :param style: passed to consts.get_kw_plot; styles starting with 'data'
        are drawn as error bars, others as plain lines.
    :param kw_plot: extra plot keywords; they override the defaults
        color='k', for_err=True.
    :param jitter0: unused; kept for interface compatibility.
    :param coefs_to_plot: indices into
        ('bias', 'slope', 'rel x abs(irr)', 'rel x irr', 'abs(irr)', 'irr').
    :param axs: [N_DIM, len(coefs_to_plot)] array of axes; created if None.
    :param fig: figure used to create axs when axs is None.
    :return: coef, se_coef, axs
    """
    if ev_cond_dim is None:
        ev_cond_dim = data.ev_cond_dim
    if durs is None:
        durs = npy(data.durs)
    if n_cond_dur_ch is None:
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])

    # coef[coef, dim_rel, duration] and its standard error
    coef, se_coef = coefdur.get_coefs_irr_ixn_from_histogram(
        n_cond_dur_ch,
        ev_cond_dim=ev_cond_dim)[:2]  # type: (np.ndarray, np.ndarray)

    coef_names = [
        'bias', 'slope', 'rel x abs(irr)', 'rel x irr', 'abs(irr)', 'irr'
    ]
    n_coef = len(coefs_to_plot)

    if axs is None:
        if fig is None:
            fig = plt.figure(figsize=[6, 1.5 * n_coef])

        n_row = consts.N_DIM
        n_col = n_coef
        gs = plt.GridSpec(nrows=n_row,
                          ncols=n_col,
                          figure=fig,
                          left=0.25,
                          right=0.95,
                          bottom=0.15,
                          top=0.9)
        # np.object was removed in NumPy 1.24; the builtin is equivalent.
        axs = np.empty([n_row, n_col], dtype=object)
        for row in range(n_row):
            for col in range(n_col):
                axs[row, col] = plt.subplot(gs[row, col])

    # User-supplied keywords override the defaults.
    kw_plot1 = {'color': 'k', 'for_err': True, **dict(kw_plot)}
    max_dur = np.amax(npy(durs))

    for ii_coef, i_coef in enumerate(coefs_to_plot):
        coef_name = coef_names[i_coef]
        for dim_rel in range(consts.N_DIM):
            ax = axs[dim_rel, ii_coef]  # type: plt.Axes
            plt.sca(ax)

            y = coef[i_coef, dim_rel, :]
            e = se_coef[i_coef, dim_rel, :]
            kw = consts.get_kw_plot(style, **kw_plot1)
            if style.startswith('data'):
                plt.errorbar(durs, y, yerr=e, **kw)
            else:
                plt.plot(durs, y, **kw)
            plt2.box_off()

            ax.set_xlim(-0.05 * max_dur, 1.05 * max_dur)
            plt2.detach_axis('x', 0, max_dur)

            if dim_rel == 0:
                ax.set_xticklabels([])
                ax.set_title(coef_name.lower())
            elif ii_coef == 0:
                # Label the first *column*; comparing the coefficient index
                # here meant the label never appeared unless coefficient 0
                # was plotted.
                ax.set_xlabel('duration (s)')

    plt2.rowtitle(consts.DIM_NAMES_LONG, axes=axs)

    return coef, se_coef, axs
コード例 #13
0
def main(base=10.):
    """
    Plot model comparison for the data and for both model-recovery
    simulations, save the figures, and write mean +- SEM of dcost to CSV.

    Reads `file_model_comp` and `files_model_recovery[0:2]` under
    `locfile_in.pth_root`. It writes the 'model_comp_recovery' figure and
    CSV, and the 'model_comp_vs_recovery' figure, via `locfile_out`.

    :param base: dcost is divided by log(base) (presumably converting from
        natural-log units to log-`base` units). It is also passed to the
        plotting helpers and to the threshold patch so that every plot uses
        the same base.
    """
    ds = load_comp(os.path.join(locfile_in.pth_root, file_model_comp))
    ds['dcost'] = np.array(ds['dcost']) / np.log(base)

    vmax = 160

    # Mean and SEM of dcost for: data, simulated serial, simulated parallel
    m = np.empty(3)
    e = np.empty(3)

    # First call only establishes the layout (heights/margins) reused below.
    axs = plot_bar_dloss_across_subjs(dlosses=np.array(ds['dcost']),
                                      subj_parad_bis=ds['subj_parad_bi'],
                                      vmax=vmax,
                                      base=base)
    axs = plt2.GridAxes(nrows=1,
                        ncols=3,
                        heights=axs.heights,
                        widths=[2],
                        left=axs.left,
                        right=axs.right,
                        bottom=axs.bottom)
    plot_bar_dloss_across_subjs(
        dlosses=np.array(ds['dcost']),
        subj_parad_bis=ds['subj_parad_bi'],
        vmax=vmax,
        axs=axs[:, [0]],
        base=base,
    )
    plt.title('Data')

    m[0] = np.mean(ds['dcost'])
    e[0] = np2.sem(ds['dcost'])

    print('--')

    titles = ['Simulated\nSerial', 'Simulated\nParallel']
    for i in range(2):
        ds = load_comp(
            os.path.join(locfile_in.pth_root, files_model_recovery[i]))
        ds['dcost'] = np.array(ds['dcost']) / np.log(base)

        plot_bar_dloss_across_subjs(
            dlosses=ds['dcost'],
            subj_parad_bis=ds['subj_parad_bi'],
            vmax=vmax,
            axs=axs[:, [i + 1]],
            base=base,
        )
        plt.title(titles[i])

        m[i + 1] = np.mean(ds['dcost'])
        e[i + 1] = np2.sem(ds['dcost'])

        # Only the first column keeps the subject labels on the y-axis.
        plt.sca(axs[0, i + 1])
        plt2.box_off(['left'])
        plt.yticks([])

    for ext in ['.pdf', '.png']:
        file = locfile_out.get_file_fig('model_comp_recovery',
                                        {'easiest_only': use_easiest_only},
                                        ext=ext)
        plt.savefig(file, dpi=300)
        print('Saved to %s' % file)

    # --- Print mean +- SEM to CSV
    csv_file = locfile_out.get_file_csv('model_comp_recovery',
                                        {'easiest_only': use_easiest_only})
    d_csv = np2.dictlist2listdict({
        'data': ['original', 'simulated serial', 'simulated parallel'],
        'mean_dcost': m,
        'sem_dcost': e,
    })
    # newline='' is required by the csv module to avoid blank rows on
    # platforms that translate line endings.
    with open(csv_file, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=d_csv[0].keys())
        writer.writeheader()
        for d in d_csv:
            writer.writerow(d)
        print('Wrote to %s' % csv_file)

    # --- Mean NLL within each
    axs = plt2.GridAxes(
        1,
        1,
        left=0.85,
        right=0.25,
        heights=[1.5],
        top=0.1,
        bottom=0.9,
    )
    plt.sca(axs[0, 0])
    plt.barh(np.arange(3), m, xerr=e, color='w', edgecolor='k')
    plt.yticks(np.arange(3),
               ['data', 'simulated\nserial', 'simulated\nparallel'])
    plt2.box_off(['top', 'right'])
    vmax = np.amax(np.abs(m) + np.abs(e))
    plt.xlim(np.array([-vmax, vmax]) * 1.1)
    plt2.detach_axis(xy='y', amin=0, amax=2)
    plt2.detach_axis(xy='x', amin=-vmax, amax=vmax)

    # The threshold must use the same base as the rescaled dcost above.
    axvline_dcost(base=base)
    xticks_serial_vs_parallel(vmax, base)

    for ext in ['.pdf', '.png']:
        file = locfile_out.get_file_fig('model_comp_vs_recovery',
                                        {'easiest_only': use_easiest_only},
                                        ext=ext)
        plt.savefig(file, dpi=300)
        print('Saved to %s' % file)

    print('--')
コード例 #14
0
ファイル: dtb_1D_sim.py プロジェクト: yulkang/2D_Decision
def plot_rt_vs_ev(
        ev_cond,
        n_cond__rt_ch: Union[torch.Tensor, np.ndarray],
        style='pred',
        pool='mean',
        dt=consts.DT,
        correct_only=True,
        thres_n_trial=10,
        color='k',
        color_ch=('tab:red', 'tab:blue'),
        ax: plt.Axes = None,
        kw_plot=(),
) -> (Sequence[plt.Line2D], List[np.ndarray]):
    """
    Plot the mean RT against evidence, one line per choice.

    @param ev_cond: [condition]; 2D input is averaged over axis 1, and 3D
        input is first reduced with npt.p2st(...)[0].
    @type ev_cond: torch.Tensor
    @param n_cond__rt_ch: [condition, frame, ch]
    @type n_cond__rt_ch: torch.Tensor
    @param style: passed to get_kw_plot. For styles starting with 'data':
        RTs of conditions with fewer than thres_n_trial trials are set to
        NaN, and (with correct_only) both choices are pooled at zero
        evidence.
    @param pool: only 'mean' is implemented.
    @param dt: frame duration used to convert frame indices to time.
    @param correct_only: if True, plot each choice only for conditions whose
        evidence sign agrees with it (zero evidence is included for both),
        using `color`; otherwise plot each choice for all conditions using
        color_ch[ch].
    @param ax: axes to draw on; defaults to plt.gca().
    @param kw_plot: extra plot keywords. With correct_only they may override
        `color`; otherwise color_ch[ch] determines the color.
    @return: hs (handles per choice), rts (list of per-choice arrays if
        correct_only, else an array stacked over choices)
    """
    if ax is None:
        ax = plt.gca()
    if ev_cond.ndim != 1:
        if ev_cond.ndim == 3:
            ev_cond = npt.p2st(ev_cond)[0]
        assert ev_cond.ndim == 2
        ev_cond = ev_cond.mean(1)
    assert n_cond__rt_ch.ndim == 3

    ev_cond = npy(ev_cond)
    n_cond__rt_ch = npy(n_cond__rt_ch)
    kw_plot = dict(kw_plot)

    def plot_rt_given_cond_ch(ev_cond1, n_rt_given_cond_ch, **kw1):
        """Plot the pooled RT of one choice; return (handles, rt_pooled)."""
        # n_rt_given_cond_ch[cond, fr] -> p_rt_given_cond_ch[cond, fr]
        p_rt_given_cond_ch = np2.sumto1(n_rt_given_cond_ch, 1)

        nt = n_rt_given_cond_ch.shape[1]
        t = np.arange(nt) * dt

        # n_in_cond_ch[cond]: total count per condition for this choice
        n_in_cond_ch = n_rt_given_cond_ch.sum(1)
        if pool == 'mean':
            rt_pooled = (t[None, :] * p_rt_given_cond_ch).sum(1)
        elif pool == 'var':
            raise NotImplementedError()
        else:
            raise ValueError('Unknown pool: %s' % pool)

        if style.startswith('data'):
            # Too few trials to estimate a mean RT reliably.
            rt_pooled[n_in_cond_ch < thres_n_trial] = np.nan

        kw = get_kw_plot(style, **kw1)
        h = ax.plot(ev_cond1, rt_pooled, **kw)
        return h, rt_pooled

    hs = []
    rtss = []
    if correct_only:
        n_cond__rt_ch1 = n_cond__rt_ch.copy()  # type: np.ndarray
        if style.startswith('data'):
            # At zero evidence no choice is correct: pool both choices and
            # assign the pooled counts to each of them.
            cond0 = ev_cond == 0
            n_cond__rt_ch1[cond0, :, :] = np.sum(n_cond__rt_ch1[cond0, :, :],
                                                 axis=-1,
                                                 keepdims=True)

        cond_sign = np.sign(ev_cond)
        # kw_plot may override the default color instead of raising a
        # duplicate-keyword TypeError.
        kw_correct = {'color': color, **kw_plot}
        for ch in range(consts.N_CH):
            # Keep conditions whose sign matches the choice (or is zero).
            ch_sign = consts.ch_bool2sign(ch)
            cond_accu = cond_sign != -ch_sign

            h, rts1 = plot_rt_given_cond_ch(ev_cond[cond_accu],
                                            n_cond__rt_ch1[cond_accu, :, ch],
                                            **kw_correct)
            hs.append(h)
            rtss.append(rts1)
        # Per-choice arrays can differ in length, so they are not stacked.
        rts = rtss

    else:
        for ch in range(consts.N_CH):
            # Previously kw_plot was ignored here; the per-choice color
            # still takes precedence so the choices stay distinguishable.
            h, rts1 = plot_rt_given_cond_ch(ev_cond,
                                            n_cond__rt_ch[:, :, ch],
                                            **{**kw_plot,
                                               'color': color_ch[ch]})
            hs.append(h)
            rtss.append(rts1)
        rts = np.stack(rtss)

    y_lim = ax.get_ylim()
    x_lim = ax.get_xlim()
    # Act on `ax`, which may not be the current axes.
    plt2.box_off(ax=ax)
    plt2.detach_axis('x', ax=ax, amin=x_lim[0], amax=x_lim[1])
    plt2.detach_axis('y', ax=ax, amin=y_lim[0], amax=y_lim[1])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{E}[T^\mathrm{r} \mid c]~(\mathrm{s})$")
    return hs, rts
コード例 #15
0
ファイル: dtb_2D_sim.py プロジェクト: yulkang/2D_Decision
def plot_rt_distrib(
    n_cond_rt_ch: np.ndarray,
    ev_cond_dim: np.ndarray,
    abs_cond=True,
    lump_wrong=True,
    dt=consts.DT,
    colors=None,
    alpha=1.,
    alpha_face=0.5,
    smooth_sigma_sec=0.05,
    to_normalize_max=False,
    to_cumsum=False,
    to_exclude_last_frame=True,
    to_skip_zero_trials=False,
    label='',
    kw_plot=(),
    fig=None,
    axs=None,
    to_use_sameaxes=True,
):
    """
    Plot RT distributions in a grid with one panel per pair of absolute
    condition strengths. Dim 0 runs along columns and dim 1 along rows,
    with the strongest condition at the top-left.

    Conditions are folded to their absolute values. Each dimension's choice
    is recoded as correct (1) / wrong (0). Distributions with ch0 == 0 are
    drawn sign-flipped, below zero.

    :param n_cond_rt_ch: counts [condition, frame, ch_flat]. Assumes the
        conditions form the full grid of unique ev_cond_dim[:, 0] x unique
        ev_cond_dim[:, 1] in C order, and ch_flat has 4 entries ordered
        (ch0, ch1) in C order. TODO: confirm with caller.
    :param ev_cond_dim: [condition, dim] evidence of each condition.
    :param abs_cond: must be True; False raises ValueError.
    :param lump_wrong: if True, only 'both correct' (ch0=1) vs 'any wrong'
        (ch0=0) is distinguished, with ch1 set to 1. Choices at zero
        strength then count as correct.
    :param dt: frame duration (s).
    :param colors: None (red/blue), one color, or two colors indexed by ch1.
    :param alpha: line alpha.
    :param alpha_face: alpha of the filled area between each line and 0.
    :param smooth_sigma_sec: SD (s) of the Gaussian kernel used to smooth
        along time; <= 0 disables smoothing.
    :param to_normalize_max: divide each panel by its max absolute value.
    :param to_cumsum: plot cumulative sums over time.
    :param to_exclude_last_frame: zero the counts in the last frame.
    :param to_skip_zero_trials: skip lines whose summed |value| < 1e-2.
    :param label: legend label given to the (ch0=1, ch1=1) lines.
    :param kw_plot: extra keywords for ax.plot; they override the defaults.
    :param fig: unused.
    :param axs: grid of axes [n_cond1, n_cond0]; created if None.
    :param to_use_sameaxes: make all panels share axis limits.
    :return: axs, p_cond01__rt_ch01, p_cond01__rt_ch01_sm, h. Here h is
        the return value of ax.plot for the last line considered, or None
        if it was skipped or nothing was plotted.
    """
    if colors is None:
        colors = ['red', 'blue']
    elif isinstance(colors, str):
        colors = [colors] * 2
    else:
        assert len(colors) == 2

    nt = n_cond_rt_ch.shape[1]
    t_all = np.arange(nt) * dt + dt

    # Index grids whose flattened order matches n_cond_rt_ch.flatten().
    out = np.meshgrid(np.unique(ev_cond_dim[:, 0]),
                      np.unique(ev_cond_dim[:, 1]),
                      np.arange(nt),
                      np.arange(2),
                      np.arange(2),
                      indexing='ij')
    cond0, cond1, fr, ch0, ch1 = [v.flatten() for v in out]

    from copy import deepcopy
    n0 = deepcopy(n_cond_rt_ch)
    if to_exclude_last_frame:
        n0[:, -1, :] = 0.
    n0 = n0.flatten()

    def sign_cond(v):
        """Sign of v, with zeros mapped to +1."""
        v1 = np.sign(v)
        v1[v == 0] = 1
        return v1

    if abs_cond:
        ch0 = consts.ch_bool2sign(ch0)
        ch1 = consts.ch_bool2sign(ch1)

        # 1 = correct, -1 = wrong
        ch0 = sign_cond(cond0) * ch0
        ch1 = sign_cond(cond1) * ch1

        cond0 = np.abs(cond0)
        cond1 = np.abs(cond1)

        # np.int was removed in NumPy 1.24; the builtin int is equivalent.
        ch0 = consts.ch_sign2bool(ch0).astype(int)
        ch1 = consts.ch_sign2bool(ch1).astype(int)
    else:
        raise ValueError('abs_cond=False is not supported')

    if lump_wrong:
        # treat all choices as correct when cond == 0
        ch00 = ch0 | (cond0 == 0)
        ch10 = ch1 | (cond1 == 0)

        ch0 = (ch00 & ch10)
        ch1 = np.ones_like(ch00, dtype=int)

    cond_dim = np.stack([cond0, cond1], -1)

    # Map absolute condition values to consecutive indices per dimension.
    conds = []
    dcond_dim = []
    for cond in cond_dim.T:
        conds1, dcond1 = np.unique(cond, return_inverse=True)
        conds.append(conds1)
        dcond_dim.append(dcond1)
    dcond_dim = np.stack(dcond_dim)

    n_cond01_rt_ch01 = npg.aggregate(
        [*dcond_dim, fr, ch0, ch1], n0, 'sum',
        [*(np.amax(dcond_dim, 1) + 1), nt, consts.N_CH, consts.N_CH])

    # Normalize within each (dcond0, dcond1) panel.
    p_cond01__rt_ch01 = np2.nan2v(n_cond01_rt_ch01 / n_cond01_rt_ch01.sum(
        (2, 3, 4), keepdims=True))

    n_conds = p_cond01__rt_ch01.shape[:2]

    if axs is None:
        axs = plt2.GridAxes(
            n_conds[1],
            n_conds[0],
            left=0.6,
            right=0.3,
            bottom=0.45,
            top=0.74,
            widths=[1],
            heights=[1],
            wspace=0.04,
            hspace=0.04,
        )

    kw_label = {
        'fontsize': 12,
    }
    pad = 8
    axs[0, 0].set_title('strong\nmotion', pad=pad, **kw_label)
    axs[0, -1].set_title('weak\nmotion', pad=pad, **kw_label)
    axs[0, 0].set_ylabel('strong\ncolor', labelpad=pad, **kw_label)
    axs[-1, 0].set_ylabel('weak\ncolor', labelpad=pad, **kw_label)

    if smooth_sigma_sec > 0:
        from scipy import signal, stats
        sigma_fr = smooth_sigma_sec / dt
        width = np.ceil(sigma_fr * 2.5).astype(int)
        kernel = stats.norm.pdf(np.arange(-width, width + 1), 0, sigma_fr)
        # Place the kernel along the time axis (axis 2 of 5).
        kernel = np2.vec_on(kernel, 2, 5)
        p_cond01__rt_ch01_sm = signal.convolve(p_cond01__rt_ch01,
                                               kernel,
                                               mode='same')
    else:
        p_cond01__rt_ch01_sm = p_cond01__rt_ch01.copy()

    if to_cumsum:
        p_cond01__rt_ch01_sm = np.cumsum(p_cond01__rt_ch01_sm, axis=2)

    if to_normalize_max:
        p_cond01__rt_ch01_sm = np2.nan2v(
            p_cond01__rt_ch01_sm /
            np.amax(np.abs(p_cond01__rt_ch01_sm), (2, 3, 4), keepdims=True))

    h = None
    n_row = n_conds[1]
    n_col = n_conds[0]
    for dcond0 in range(n_conds[0]):
        for dcond1 in range(n_conds[1]):
            # Strongest conditions go to the top-left panel.
            row = n_row - 1 - dcond1
            col = n_col - 1 - dcond0

            ax = axs[row, col]  # type: plt.Axes

            for i_ch0 in [0, 1]:
                for i_ch1 in [0, 1]:
                    if lump_wrong and i_ch1 == 0:
                        continue

                    p1 = p_cond01__rt_ch01_sm[dcond0, dcond1, :, i_ch0, i_ch1]

                    kw = {
                        'linewidth': 1,
                        'color': colors[i_ch1],
                        'alpha': alpha,
                        'zorder': 1,
                        **dict(kw_plot)
                    }

                    # Wrong (i_ch0 == 0) distributions are drawn below zero.
                    y = p1 * consts.ch_bool2sign(i_ch0)

                    p_cond01__rt_ch01_sm[dcond0, dcond1, :, i_ch0, i_ch1] = y

                    if to_skip_zero_trials and np.sum(np.abs(y)) < 1e-2:
                        h = None
                    else:
                        h = ax.plot(
                            t_all,
                            y,
                            label=label if i_ch0 == 1 and i_ch1 == 1 else None,
                            **kw)
                        ax.fill_between(t_all,
                                        0,
                                        y,
                                        ec='None',
                                        fc=kw['color'],
                                        alpha=alpha_face,
                                        zorder=-1)
            plt2.box_off(ax=ax)

            ax.axhline(0, color='k', linewidth=0.5)
            if row < n_row - 1 or col > 0:
                ax.set_xticklabels([])
                ax.set_xticks([])
                plt2.box_off(['bottom'], ax=ax)
            else:
                ax.set_xlabel('RT (s)')
            ax.set_yticks([])
            ax.set_yticklabels([])
            plt2.box_off(['left'], ax=ax)

            plt2.detach_axis('x', 0, 5, ax=ax)
    if to_use_sameaxes:
        plt2.sameaxes(axs)
    axs[-1, 0].set_xlabel('RT (s)')

    return axs, p_cond01__rt_ch01, p_cond01__rt_ch01_sm, h
コード例 #16
0
def plot_coefs_dur_odif(
        n_cond_dur_ch: np.ndarray = None,
        data: Data2DVD = None,
        dif_irrs: Sequence[Union[Iterable[int], int]] = ((0, 1), (2, )),
        ev_cond_dim: np.ndarray = None,
        durs: np.ndarray = None,
        style='data',
        kw_plot=(),
        jitter0=0.,
        coefs_to_plot=(0, 1),
        add_rowtitle=True,
        dim_incl=(0, 1),
        axs=None,
        fig=None,
):
    """
    Plot bias/slope coefficients as a function of duration, with one line
    per difficulty group of the irrelevant dimension.

    Rows are the dimensions in dim_incl; columns are coefs_to_plot.

    :param n_cond_dur_ch: histogram passed to
        coefdur.get_coefs_mesh_from_histogram; defaults to
        data.get_data_by_cond('all')[1].
    :param data: only used to fill in arguments that are None.
    :param dif_irrs: groups of irrelevant-dimension difficulty indices; one
        line (and color from consts.CMAP_DIM) per group.
    :param ev_cond_dim: defaults to data.ev_cond_dim.
    :param durs: durations used as x coordinates; defaults to data.durs.
    :param style: passed to consts.get_kw_plot; 'data*' styles are drawn as
        error bars, others as lines.
    :param kw_plot: extra plot keywords.
    :param jitter0: horizontal offset between groups, in units of the
        spacing between the first two durations.
    :param coefs_to_plot: indices into ('bias', 'slope').
    :param add_rowtitle: label each row with its dimension name.
    :param dim_incl: dimensions to plot, one per row.
    :param axs: [len(dim_incl), len(coefs_to_plot)] axes; created if None.
    :param fig: figure used to create axs when axs is None.
    :return: coef, se_coef, axs, hs[coef, dim, dif]
    """

    if ev_cond_dim is None:
        ev_cond_dim = data.ev_cond_dim
    if durs is None:
        durs = npy(data.durs)
    if n_cond_dur_ch is None:
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])

    # coef[coef, dim_rel, dif, duration] and its standard error
    coef, se_coef = coefdur.get_coefs_mesh_from_histogram(
        n_cond_dur_ch, ev_cond_dim=ev_cond_dim,
        dif_irrs=dif_irrs)[:2]  # type: (np.ndarray, np.ndarray)

    coef_names = ['bias', 'slope']
    n_coef = len(coefs_to_plot)
    n_dim = len(dim_incl)

    if axs is None:
        if fig is None:
            fig = plt.figure(figsize=[6, 1.5 * n_coef])

        n_row = n_dim
        n_col = n_coef
        gs = plt.GridSpec(nrows=n_row,
                          ncols=n_col,
                          figure=fig,
                          left=0.25,
                          right=0.95,
                          bottom=0.15,
                          top=0.9)
        # np.object was removed in NumPy 1.24; the builtin is equivalent.
        axs = np.empty([n_row, n_col], dtype=object)
        for row in range(n_row):
            for col in range(n_col):
                axs[row, col] = plt.subplot(gs[row, col])

    n_dif = len(dif_irrs)
    hs = np.empty([n_coef, n_dim, n_dif], dtype=object)

    # Loop invariants; a single duration has no spacing to jitter by.
    ddur = durs[1] - durs[0] if len(durs) > 1 else 0.
    max_dur = np.amax(npy(durs))

    for ii_coef, i_coef in enumerate(coefs_to_plot):
        coef_name = coef_names[i_coef]
        for i_dim, dim_rel in enumerate(dim_incl):
            ax = axs[i_dim, ii_coef]  # type: plt.Axes
            plt.sca(ax)

            cmap = consts.CMAP_DIM[dim_rel](n_dif)

            for idif, dif_irr in enumerate(dif_irrs):
                y = coef[i_coef, dim_rel, idif, :]
                e = se_coef[i_coef, dim_rel, idif, :]
                # Center the jitter around zero across groups.
                jitter = ddur * jitter0 * (idif - (n_dif - 1) / 2)
                kw = consts.get_kw_plot(style,
                                        color=cmap(idif),
                                        for_err=True,
                                        **dict(kw_plot))
                if style.startswith('data'):
                    h = plt.errorbar(durs + jitter, y, yerr=e, **kw)[0]
                else:
                    # NOTE: plt.plot returns a list of lines (unlike the
                    # errorbar branch); kept as-is for existing callers.
                    h = plt.plot(durs + jitter, y, **kw)
                hs[ii_coef, i_dim, idif] = h
            plt2.box_off()

            ax.set_xlim(-0.05 * max_dur, 1.05 * max_dur)
            plt2.detach_axis('x', 0, max_dur)

            # Decide by grid position rather than by dim/coef index so that
            # custom dim_incl / coefs_to_plot get titles on the top row and
            # the x label on the bottom-left panel.
            if i_dim == 0:
                ax.set_title(coef_name.lower())
            if i_dim < n_dim - 1:
                ax.set_xticklabels([])
            elif ii_coef == 0:
                ax.set_xlabel('duration (s)')

    if add_rowtitle:
        plt2.rowtitle([consts.DIM_NAMES_LONG[d] for d in dim_incl], axes=axs)

    return coef, se_coef, axs, hs
コード例 #17
0
def plot_fit_combined(
        data: Union[sim2d.Data2DRT, dict] = None,
        pModel_cond_rt_chFlat=None, model=None,
        pModel_dimRel_condDense_chFlat=None,
        # --- in place of data:
        pAll_cond_rt_chFlat=None,
        evAll_cond_dim=None,
        pTrain_cond_rt_chFlat=None,
        evTrain_cond_dim=None,
        pTest_cond_rt_chFlat=None,
        evTest_cond_dim=None,
        dt=None,
        # --- optional
        ev_dimRel_condDense_fr_dim_meanvar=None,
        dt_model=None,
        to_plot_internals=True,
        to_plot_params=True,
        to_plot_choice=True,
        group_dcond_irr=None,
        to_combine_ch_irr_cond=True,
        kw_plot_pred=(),
        kw_plot_pred_ch=(),
        kw_plot_data=(),
        axs=None,
):
    """
    Plot model predictions against data for each dimension. This covers
    choice (optional) and RT vs. evidence, plus optional model internals
    and parameters.

    Data come either from `data`, split with mode_train='easiest' into
    'train_valid' and 'test', or from the p*/ev* arguments. When omitted,
    the train/test arrays default to the 'All' arrays.

    :param data: object providing get_data_by_cond() and dt; when given it
        overrides the p*/ev*/dt arguments.
    :param pModel_cond_rt_chFlat: model prediction used when
        ev_dimRel_condDense_fr_dim_meanvar is None.
    :param model: required when to_plot_internals or to_plot_params.
    :param pModel_dimRel_condDense_chFlat: per-dim_rel predictions on a
        dense condition grid, used with ev_dimRel_condDense_fr_dim_meanvar.
    :param ev_dimRel_condDense_fr_dim_meanvar: 4D or 5D evidence of the
        dense grid. Frame 0 is used, plus meanvar index 0 for 5D
        (presumably the mean).
    :param dt_model: defaults to model.dt, or to dt when model is None.
    :param to_plot_internals: plot Tnd distributions and bounds.
    :param to_plot_params: plot model parameters in axs[0, -1].
    :param to_plot_choice: plot P(choice) in the first row.
    :param group_dcond_irr: passed to the sim2d plotting functions.
    :param to_combine_ch_irr_cond: combine irrelevant-dimension conditions
        (via combine_irr_cond) for the predicted choice curve.
    :param kw_plot_pred: plot keywords for predictions.
    :param kw_plot_pred_ch: plot keywords for the combined predicted choice.
    :param kw_plot_data: plot keywords for data.
    :param axs: grid of axes; created if None.
    :return: axs, rts (predicted RTs per dim_rel), hs (dict; hs['rt'] holds
        the predicted-RT handles per dim_rel)
    """
    if data is None:
        if pTrain_cond_rt_chFlat is None:
            pTrain_cond_rt_chFlat = pAll_cond_rt_chFlat
        if evTrain_cond_dim is None:
            evTrain_cond_dim = evAll_cond_dim
        if pTest_cond_rt_chFlat is None:
            pTest_cond_rt_chFlat = pAll_cond_rt_chFlat
        if evTest_cond_dim is None:
            evTest_cond_dim = evAll_cond_dim
    else:
        _, pAll_cond_rt_chFlat, _, _, evAll_cond_dim = \
            data.get_data_by_cond('all')
        _, pTrain_cond_rt_chFlat, _, _, evTrain_cond_dim = \
            data.get_data_by_cond('train_valid', mode_train='easiest')
        _, pTest_cond_rt_chFlat, _, _, evTest_cond_dim = \
            data.get_data_by_cond('test', mode_train='easiest')
        dt = data.dt
    hs = {'rt': []}

    if model is None:
        assert not to_plot_internals
        assert not to_plot_params

    if dt_model is None:
        dt_model = dt if model is None else model.dt

    if axs is None:
        if to_plot_params or to_plot_internals:
            axs = plt2.GridAxes(3, 3)
        elif to_plot_choice:
            axs = plt2.GridAxes(2, 2)
        else:
            axs = plt2.GridAxes(1, 2)  # TODO: beautify ratios

    # The test data may not contain all conditions. Append every condition
    # with zero counts so that the color scale is correct while nothing
    # extra is plotted. This does not depend on dim_rel.
    evTest_cond_dim1 = np.concatenate([
        evTest_cond_dim, evAll_cond_dim
    ], axis=0)
    pTest_cond_rt_chFlat1 = np.concatenate([
        pTest_cond_rt_chFlat, np.zeros_like(pAll_cond_rt_chFlat)
    ], axis=0)

    rts = []
    for dim_rel in range(consts.N_DIM):
        if ev_dimRel_condDense_fr_dim_meanvar is None:
            evModel_cond_dim = evAll_cond_dim
        else:
            if ev_dimRel_condDense_fr_dim_meanvar.ndim == 5:
                evModel_cond_dim = npy(ev_dimRel_condDense_fr_dim_meanvar[
                                           dim_rel][:, 0, :, 0])
            else:
                assert ev_dimRel_condDense_fr_dim_meanvar.ndim == 4
                evModel_cond_dim = npy(ev_dimRel_condDense_fr_dim_meanvar[
                                           dim_rel][:, 0, :])
            pModel_cond_rt_chFlat = npy(pModel_dimRel_condDense_chFlat[dim_rel])

        if to_plot_choice:
            # --- Plot choice
            ax = axs[0, dim_rel]
            plt.sca(ax)

            if to_combine_ch_irr_cond:
                ev_cond_model1, p_rt_ch_model1 = combine_irr_cond(
                    dim_rel, evModel_cond_dim, pModel_cond_rt_chFlat
                )

                sim2d.plot_p_ch_vs_ev(ev_cond_model1, p_rt_ch_model1,
                                      dim_rel=dim_rel, style='pred',
                                      group_dcond_irr=None,
                                      kw_plot=kw_plot_pred_ch,
                                      cmap=lambda n: lambda v: [0., 0., 0.],
                                      )
            else:
                sim2d.plot_p_ch_vs_ev(evModel_cond_dim, pModel_cond_rt_chFlat,
                                      dim_rel=dim_rel, style='pred',
                                      group_dcond_irr=group_dcond_irr,
                                      kw_plot=kw_plot_pred,
                                      cmap=cmaps[dim_rel]
                                      )
            # Use a separate name: assigning to `hs` here used to replace
            # the returned dict, so hs['rt'] below raised a TypeError.
            hs_ch_data, conds_irr = sim2d.plot_p_ch_vs_ev(
                evTest_cond_dim1, pTest_cond_rt_chFlat1,
                dim_rel=dim_rel, style='data_pred',
                group_dcond_irr=group_dcond_irr,
                cmap=cmaps[dim_rel],
                kw_plot=kw_plot_data,
            )
            hs_ch_data1 = [h[0] for h in hs_ch_data]
            odim = 1 - dim_rel
            odim_name = consts.DIM_NAMES_LONG[odim]
            legend_odim(conds_irr, hs_ch_data1, odim_name)
            sim2d.plot_p_ch_vs_ev(evTrain_cond_dim, pTrain_cond_rt_chFlat,
                                  dim_rel=dim_rel, style='data_fit',
                                  group_dcond_irr=group_dcond_irr,
                                  cmap=cmaps[dim_rel],
                                  kw_plot=kw_plot_data
                                  )
            plt2.detach_axis('x', np.amin(evTrain_cond_dim[:, dim_rel]),
                             np.amax(evTrain_cond_dim[:, dim_rel]))
            ax.set_xlabel('')
            ax.set_xticklabels([])
            if dim_rel != 0:
                plt2.box_off(['left'])
                plt.yticks([])

            ax.set_ylabel('P(%s choice)' % consts.CH_NAMES[dim_rel][1])

        # --- Plot RT
        ax = axs[int(to_plot_choice) + 0, dim_rel]
        plt.sca(ax)
        hs_rt1, rts1 = sim2d.plot_rt_vs_ev(
            evModel_cond_dim,
            pModel_cond_rt_chFlat,
            dim_rel=dim_rel, style='pred',
            group_dcond_irr=group_dcond_irr,
            dt=dt_model,
            kw_plot=kw_plot_pred,
            cmap=cmaps[dim_rel]
        )
        hs['rt'].append(hs_rt1)
        rts.append(rts1)

        sim2d.plot_rt_vs_ev(evTest_cond_dim1, pTest_cond_rt_chFlat1,
                            dim_rel=dim_rel, style='data_pred',
                            group_dcond_irr=group_dcond_irr,
                            dt=dt,
                            cmap=cmaps[dim_rel],
                            kw_plot=kw_plot_data
                            )
        sim2d.plot_rt_vs_ev(evTrain_cond_dim, pTrain_cond_rt_chFlat,
                            dim_rel=dim_rel, style='data_fit',
                            group_dcond_irr=group_dcond_irr,
                            dt=dt,
                            cmap=cmaps[dim_rel],
                            kw_plot=kw_plot_data
                            )
        plt2.detach_axis('x', np.amin(evTrain_cond_dim[:, dim_rel]),
                         np.amax(evTrain_cond_dim[:, dim_rel]))
        if dim_rel != 0:
            ax.set_ylabel('')
            plt2.box_off(['left'])
            plt.yticks([])

        ax.set_xlabel(consts.DIM_NAMES_LONG[dim_rel].lower() + ' strength')

        if dim_rel == 0:
            ax.set_ylabel('RT (s)')

        if to_plot_internals:
            # NOTE(review): rows 3-5 are indexed here, but the default
            # GridAxes(3, 3) created above has only 3 rows. Confirm the
            # intended layout, or pass `axs` with >= 6 rows.
            for ch1 in range(consts.N_CH):
                ch0 = dim_rel
                ax = axs[3 + ch1, dim_rel]
                plt.sca(ax)

                ch_flat = consts.ch_by_dim2ch_flat(np.array([ch0, ch1]))
                model.tnds[ch_flat].plot_p_tnd()
                ax.set_xlabel('')
                ax.set_xticklabels([])
                ax.set_yticks([0, 1])
                if ch0 > 0:
                    ax.set_yticklabels([])

                # Raw strings: '\m' is an invalid escape in a normal string.
                ax.set_ylabel(r"$\mathrm{P}(T^\mathrm{n} \mid"
                              r" \mathbf{z}=[%d,%d])$"
                              % (ch0, ch1))

            ax = axs[5, dim_rel]
            plt.sca(ax)
            if hasattr(model.dtb, 'dtb1ds'):
                model.dtb.dtb1ds[dim_rel].plot_bound(color='k')

    plt2.sameaxes(axs[-1, :consts.N_DIM], xy='y')

    if to_plot_params:
        ax = axs[0, -1]
        plt.sca(ax)
        model.plot_params()

    return axs, rts, hs
コード例 #18
0
ファイル: dtb_2D_fit_VD.py プロジェクト: yulkang/2D_Decision
def main_plot_across_models(
        i_subjs=i_subjs,
        axs=None,
        models_incl=('buffer+serial', ),
        base=10,
        **kwargs,
):
    """
    Compare buffer capacities across subjects and plot fitted sensitivities.

    For every subject and every buffer duration in bufdurs0, read the test
    NLL ('loss_NLL_test') from the 'best_loss' table located via
    main_fit(fit_mode='dict_cache_only'). The NLL is expressed relative to
    that subject's best duration, divided by ln(base), and plotted against
    buffer capacity on a broken axis. Then the models listed in `fixs` are
    loaded, and the sensitivity coefficients of those named in
    `models_incl` are plotted together with the data. Finally the figure is
    saved as .pdf and .png.

    @param i_subjs: subject indices passed to main_fit(); column `col` of
        the returned arrays and of `axs` corresponds to i_subjs[col].
    @param axs: plt2.GridAxes to draw into; created with a fixed layout
        if None.
    @param models_incl: names of entries in `fixs` whose slopes are plotted.
    @param base: base of the logarithm used for the goodness-of-fit axis.
    @param kwargs: forwarded to every main_fit() call.
    @return: gofs[bufdur, subj], models[model, subj] (object array), and
        the data loaded for the last subject (kept for backward
        compatibility with the previous return value).
    """
    import copy
    import csv
    import os

    n_subj = len(i_subjs)

    # ---- Load goodness-of-fit
    def load_gof(i_subj, buffix_dur):
        """Return the saved 'loss_NLL_test' for one subject and buffer
        duration, or NaN if the results table is missing or empty."""
        fix = fix0_pre + [('buffix', buffix_dur)] + fix0_post
        dict_cache, subdir = main_fit(i_subj,
                                      fit_mode='dict_cache_only',
                                      fix=fix,
                                      **kwargs)

        file = locfile.get_file('tab',
                                'best_loss',
                                dict_cache,
                                ext='.csv',
                                subdir=subdir)
        if not os.path.exists(file):
            print('File absent - returning NaN: %s' % file)
            return np.nan
        with open(file, 'r', newline='') as csvfile:
            rows = list(csv.DictReader(csvfile))
        if not rows:
            # Previously this raised TypeError; treat like a missing file.
            print('File empty - returning NaN: %s' % file)
            return np.nan
        gof_kind = 'loss_NLL_test'
        igof = [row['name'] for row in rows].index(gof_kind)
        # The value column is read from the ' value' header (leading space)
        # and its first 3 characters are skipped before parsing --
        # presumably a non-numeric prefix; TODO confirm the table format.
        return float(rows[igof][' value'][3:])

    nbufdurs = len(bufdurs0)
    gofs = np.zeros([nbufdurs, n_subj])

    for col, i_subj in enumerate(i_subjs):
        for idur, bufdur in enumerate(bufdurs0):
            gofs[idur, col] = load_gof(int(i_subj), bufdur)

    # Express each subject's NLL relative to its best buffer duration, then
    # divide by ln(base) so the axis reads in log-`base` units.
    gofs = gofs - np.nanmin(gofs, 0, keepdims=True)
    gofs = gofs / np.log(base)

    # ---- Load slopes
    # Other candidates ('parallel': buffix=1.2, 'serial': buffix=0.) can be
    # added here; kw_plots below already defines line styles for them.
    fixs = [
        ('buffer+serial', [
            ('buffix', bufdur_best),
        ]),
    ]
    model_names = [f[0] for f in fixs]
    n_models = len(model_names)

    # `object` instead of the removed `np.object` alias (NumPy >= 1.24).
    models = np.empty([n_models, n_subj], dtype=object)
    dict_caches = np.empty([n_models, n_subj], dtype=object)
    datas = np.empty([n_subj], dtype=object)

    kw_plots = {
        'serial': {
            'linestyle': '--'
        },
        'buffer+serial': {
            'linestyle': '-'
        },
        'parallel': {
            'linestyle': ':'
        },
    }

    subdir = ','.join(fix0) + '+buffix=' + ','.join(
        [('%1.2f' % v) for v in [0., bufdur_best, 1.2]])

    for col, i_subj in enumerate(i_subjs):
        for i_model, (_, fix1) in enumerate(fixs):
            fix = fix0_pre + fix1 + fix0_post
            model, data, dict_cache, _, _ = main_fit(int(i_subj),
                                                     fit_mode='load',
                                                     fix=fix,
                                                     **kwargs)
            models[i_model, col] = model
            datas[col] = data
            dict_caches[i_model, col] = dict_cache

    # ---- Plot goodness-of-fit
    if axs is None:
        axs = plt2.GridAxes(3,
                            2,
                            top=.3,
                            left=1.15,
                            right=0.1,
                            bottom=.5,
                            widths=[2],
                            wspace=0.35,
                            heights=[1.5, 1.5, 1.5],
                            hspace=[0.2, 0.9])
    n_row = axs.shape[0]

    for col in range(n_subj):
        ax = axs[-1, col]
        plt.sca(ax)
        plt2.box_off('all')

        gs1 = axs.gs[-2, col * 2 + 1]
        bax = breakaxis(gs1)
        ax0 = bax.axs[0]  # type: plt.Axes
        ax1 = bax.axs[1]  # type: plt.Axes

        # Same data on both halves of the broken axis; ax0 gets only the
        # first 3 durations.
        ax0.plot(bufdurs0[:3],
                 gofs[:3, col],
                 'k.-',
                 linewidth=0.75,
                 markersize=4.5)
        ax1.plot(bufdurs0,
                 gofs[:, col],
                 'k.-',
                 linewidth=0.75,
                 markersize=4.5)

        plt.sca(ax1)
        patch_chance_level(level=np.log(100.) / np.log(base), signs=[-1, 1])
        plt.axhline(0, color='k', linestyle='--', linewidth=0.5)
        beautify(ax1)
        beautify_ticks(ax1, add_ticklabel=True)

        plt2.sameaxes([ax0, ax1], xy='x')

        ax1.set_yticks([0, 20])
        ax0.set_yticks([40, 200])

        if col == 0:
            plt.sca(ax1)
            plt.xlabel('buffer capacity (s)')

            plt.sca(ax)
            plt.ylabel(r'$-\mathrm{log}_{10}\mathrm{BF}$', labelpad=27)
        else:
            ax0.set_yticklabels([])
            ax1.set_yticklabels([])

        plt.sca(axs[0, col])
        beautify_ticks(axs[0, col], add_ticklabel=False)
        beautify_ticks(axs[1, col], add_ticklabel=True)

    plt.sca(axs[-2, 0])
    plt.xlabel('stimulus duration (s)')

    # ---- Plot slopes
    for col in range(n_subj):
        hss = []
        for model_name in models_incl:
            i_model = model_names.index(model_name)
            model = models[i_model, col]
            if model is None:
                continue
            _, hs = plot_coefs_dur_odif_pred_data(
                datas[col],
                model,
                axs=axs[:2, [col]],
                to_plot_data=True,
                kw_plot_model=kw_plots[model_name],
                coefs_to_plot=[1],
                add_rowtitle=False)
            hss.append(hs)

        # Guard: hss is empty if no model could be plotted for this subject.
        if col == 0 and hss:
            # hss[0] = hs['pred'|'data'][coef, dim, dif]
            hs1 = hss[0]['data']  # type: np.ndarray
            for dim in reversed(range(hs1.shape[1])):
                handles = [hs1[0, dim, dif] for dif in range(hs1.shape[2])]
                odim_name = consts.DIM_NAMES_LONG[1 - dim].lower()
                plt.sca(axs[dim, 0])
                plt.legend(handles,
                           ['weak ' + odim_name, 'strong ' + odim_name],
                           bbox_to_anchor=[0., 0., 1., 1.],
                           handletextpad=0.35,
                           handlelength=0.5,
                           labelspacing=0.3,
                           borderpad=0.,
                           frameon=False,
                           loc='upper left')

        for row in range(n_row):
            beautify(axs[row, col])
            if col > 0:
                axs[row, col].set_ylabel('')

    plt2.sameaxes(axs[-1, :])
    axs[-1, -1].set_yticklabels([])

    for col, i_subj in enumerate(i_subjs):
        axs[0, col].set_title(consts.SUBJS['VD'][int(i_subj)])
        for row in range(1, n_row):
            axs[row, col].set_title('')

    for row, title in enumerate([
            r'Motion sensitivity ($\beta$)',
            r'Color sensitivity ($\beta$)',
    ]):
        plt.sca(axs[row, 0])
        plt.ylabel(title)

    # Copy so that dropping keys does not mutate the dict main_fit returned.
    dict_cache = copy.copy(dict_caches[0, 0])
    for k in ['fix', 'sbj']:
        dict_cache.pop(k, None)

    from lib.pylabyk.cacheutil import mkdir4file
    for ext in ['.pdf', '.png']:
        file = locfile.get_file_fig('coefs_dur_odif_sbjs',
                                    dict_cache,
                                    subdir=subdir,
                                    ext=ext)
        mkdir4file(file)
        plt.savefig(file, dpi=300)
        print('Saved to %s' % file)

    print('--')

    return gofs, models, datas[-1]