コード例 #1
0
ファイル: dtb_1D_sim.py プロジェクト: yulkang/2D_Decision
def plot_p_ch_vs_ev(ev_cond,
                    p_ch,
                    style='pred',
                    ax: plt.Axes = None,
                    **kwargs) -> plt.Line2D:
    """
    Plot choice probability against evidence on a single axes.

    @param ev_cond: [condition] or [condition, frame]. A 3-dimensional input
        is first reduced with npt.p2st(ev_cond)[0]; a 2-dimensional one is
        averaged over its second axis.
    @type ev_cond: torch.Tensor
    @param p_ch: [condition, ch] or [condition, rt_frame, ch]. When the
        rt_frame axis is present it is summed out.
    @type p_ch: torch.Tensor
    @param style: forwarded to get_kw_plot() to pick the plot keywords.
    @param ax: axes to draw on; the current axes are used when None.
    @param kwargs: extra keywords forwarded to get_kw_plot().
    @return: the value returned by ax.plot() for the drawn curve.
    """
    ax = plt.gca() if ax is None else ax

    # Reduce the evidence to one value per condition.
    if ev_cond.ndim != 1:
        if ev_cond.ndim == 3:
            ev_cond = npt.p2st(ev_cond)[0]
        assert ev_cond.ndim == 2
        ev_cond = ev_cond.mean(1)

    # Collapse the RT-frame axis so that p_ch is [condition, ch].
    if p_ch.ndim != 2:
        assert p_ch.ndim == 3
        p_ch = p_ch.sum(1)

    kwargs = get_kw_plot(style, **kwargs)

    # Normalise over the last axis, then take entry 1 of the p2st() output
    # (presumably the probability of choice 1 -- confirm against npt.p2st).
    p_ch_normalized = npt.sumto1(p_ch, -1)
    p_ch1 = npt.p2st(p_ch_normalized)[1]
    x, y = npys(ev_cond, p_ch1)
    h = ax.plot(x, y, **kwargs)

    # Cosmetics: detached axes spanning the data range in x and [0, 1] in y.
    plt2.box_off(ax=ax)
    x_min, x_max = ax.get_xlim()
    plt2.detach_axis('x', amin=x_min, amax=x_max, ax=ax)
    plt2.detach_axis('y', amin=0, amax=1, ax=ax)
    ax.set_yticks([0, 0.5, 1])
    ax.set_yticklabels(['0', '', '1'])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{P}(z=1 \mid c)$")
    return h
コード例 #2
0
ファイル: dtb_1D_sim.py プロジェクト: yulkang/2D_Decision
    def plot_rt_given_cond_ch(ev_cond1, n_rt_given_cond_ch, **kw1):
        """
        Plot a pooled RT statistic per condition and return it.

        Relies on ``dt``, ``pool``, ``style``, ``thres_n_trial`` and ``ax``
        from the enclosing scope.

        @param ev_cond1: x-values, one per condition.
        @param n_rt_given_cond_ch: [cond, fr] counts per RT frame.
        @param kw1: extra keywords forwarded to get_kw_plot().
        @return: (handle returned by ax.plot(), rt_pooled[cond])
        """
        # p_rt[cond, fr]: counts passed through np2.sumto1 along axis 1
        # (presumably normalised to sum to 1 per condition -- confirm).
        p_rt = np2.sumto1(n_rt_given_cond_ch, 1)

        n_frame = n_rt_given_cond_ch.shape[1]
        t_frame = np.arange(n_frame) * dt

        # n_trial[cond]: total count in each condition
        n_trial = n_rt_given_cond_ch.sum(1)

        # Only the 'mean' pooling is implemented.
        if pool == 'var':
            raise NotImplementedError()
        if pool != 'mean':
            raise ValueError()
        rt_pooled = (t_frame[None, :] * p_rt).sum(1)

        # For data, blank out conditions with fewer than thres_n_trial counts.
        if style.startswith('data'):
            too_few = n_trial < thres_n_trial
            rt_pooled[too_few] = np.nan

        plot_kw = get_kw_plot(style, **kw1)
        h = ax.plot(ev_cond1, rt_pooled, **plot_kw)
        return h, rt_pooled
コード例 #3
0
ファイル: dtb_2D_sim.py プロジェクト: yulkang/2D_Decision
def plot_p_ch_vs_ev(
        ev_cond: Union[torch.Tensor, np.ndarray],
        n_ch: Union[torch.Tensor, np.ndarray],
        style='pred',
        ax: plt.Axes = None,
        dim_rel=0,
        group_dcond_irr: Iterable[Iterable[int]] = None,
        cmap: Union[str, Callable] = 'cool',
        kw_plot=(),
) -> Iterable[plt.Line2D]:
    """
    Plot choice proportion against the relevant dimension's evidence, with
    one curve per (absolute) evidence level of the other dimension.

    @param ev_cond: [condition, dim] or [condition, frame, dim, (mean, var)]
    @type ev_cond: torch.Tensor
    @param n_ch: [condition, ch] or [condition, rt_frame, ch]
    @type n_ch: torch.Tensor
    @param style: forwarded to get_kw_plot() to pick the plot keywords.
    @param ax: axes to draw on; the current axes are used when None.
    @param dim_rel: index of the dimension plotted on the x-axis; the other
        dimension (consts.get_odim(dim_rel)) determines the curves.
    @param group_dcond_irr: if given, passed to group_conds() to merge
        levels of the other dimension before plotting.
    @param cmap: a colormap name, or a callable that takes the number of
        curves and returns a callable mapping a curve index to a color.
    @param kw_plot: dict (or iterable of pairs) of extra keywords for
        get_kw_plot().
    @return: (hs, conds_irr), where hs[cond_irr][0] is a Line2D and conds_irr
        holds the levels of the other dimension, one per curve.
        NOTE(review): the return annotation predates the tuple return and is
        left untouched here to avoid requiring a new typing import.
    """
    if ax is None:
        ax = plt.gca()
    if ev_cond.ndim != 2:
        assert ev_cond.ndim == 4
        # Average over frames, then keep entry 0 of p2st()'s output
        # (presumably the mean of (mean, var) -- confirm against npt.p2st).
        ev_cond = npt.p2st(ev_cond.mean(1))[0]
    if n_ch.ndim != 2:
        assert n_ch.ndim == 3
        n_ch = n_ch.sum(1)

    ev_cond = npy(ev_cond)
    n_ch = npy(n_ch)
    n_cond_all = n_ch.shape[0]

    # Flatten counts and the matching relevant-dimension choice labels so
    # that they can be aggregated by (choice, irrelevant cond, relevant cond).
    ch_rel = np.repeat(np.array(consts.CHS[dim_rel])[None, :], n_cond_all, 0)
    n_ch = n_ch.reshape([-1])
    ch_rel = ch_rel.reshape([-1])

    dim_irr = consts.get_odim(dim_rel)
    conds_rel, dcond_rel = np.unique(ev_cond[:, dim_rel], return_inverse=True)
    # The sign of the other dimension's evidence is ignored.
    conds_irr, dcond_irr = np.unique(np.abs(ev_cond[:, dim_irr]),
                                     return_inverse=True)

    if group_dcond_irr is not None:
        conds_irr, dcond_irr = group_conds(conds_irr, dcond_irr,
                                           group_dcond_irr)

    n_conds = [len(conds_rel), len(conds_irr)]

    # n_ch_rel[ch, cond_irr, cond_rel]
    n_ch_rel = npg.aggregate(
        np.stack([
            ch_rel,
            np.repeat(dcond_irr[:, None], consts.N_CH_FLAT, 1).flatten(),
            np.repeat(dcond_rel[:, None], consts.N_CH_FLAT, 1).flatten(),
        ]), n_ch, 'sum', [consts.N_CH, n_conds[1], n_conds[0]])
    # p_ch_rel[cond_irr, cond_rel]: proportion of choice 1
    p_ch_rel = n_ch_rel[1] / n_ch_rel.sum(0)

    # Build the colormap once; isinstance() also accepts str subclasses.
    if isinstance(cmap, str):
        color_of = plt.get_cmap(cmap, n_conds[1])
    else:
        color_of = cmap(n_conds[1])

    hs = []
    for dcond_irr1, p_ch1 in enumerate(p_ch_rel):
        kw1 = get_kw_plot(style, color=color_of(dcond_irr1),
                          **dict(kw_plot))
        h = ax.plot(conds_rel, p_ch1, **kw1)
        hs.append(h)

    # Cosmetics: detached axes spanning the data range in x and [0, 1] in y.
    plt2.box_off(ax=ax)
    x_lim = ax.get_xlim()
    plt2.detach_axis('x', amin=x_lim[0], amax=x_lim[1], ax=ax)
    plt2.detach_axis('y', amin=0, amax=1, ax=ax)
    ax.set_yticks([0, 0.5, 1])
    ax.set_yticklabels(['0', '', '1'])
    ax.set_xlabel('evidence')
    ax.set_ylabel(r"$\mathrm{P}(z=1 \mid c)$")

    return hs, conds_irr
コード例 #4
0
def plot_ch_ev_by_dur(n_cond_dur_ch: np.ndarray = None,
                      data: Data2DVD = None,
                      dif_irrs: Sequence[Union[Iterable[int],
                                               int]] = ((0, 1), (2, )),
                      ev_cond_dim: np.ndarray = None,
                      durs: np.ndarray = None,
                      dur_prct_groups=((0, 33), (33, 67), (67, 100)),
                      style='data',
                      kw_plot=(),
                      jitter=0.,
                      axs=None,
                      fig=None):
    """
    Plot choice proportion against relevant evidence, split by duration.

    Panels are axs[dim, dur_group]; within each panel one curve is drawn per
    entry of ``dif_irrs`` (a group of difficulty levels of the other
    dimension).

    :param n_cond_dur_ch: counts handed to get_n_conds_dur_chs(); taken from
        ``data.get_data_by_cond('all')[1]`` when None.
    :param data: fallback source for ev_cond_dim, durs and n_cond_dur_ch;
        only accessed when one of them is None.
    :param dif_irrs: groups of indices into the unique absolute evidence
        levels of the other dimension; one curve per group.
    :param ev_cond_dim: [condition, dim] evidence; unique values per column
        define the conditions of each dimension.
    :param durs: durations; only their count and order are used here, to
        form percentile groups over the duration index.
    :param dur_prct_groups: (low, high) percentile bounds (inclusive) on the
        duration index, one panel column per group.
    :param style: forwarded to consts.get_kw_plot().
    :param kw_plot: dict (or iterable of pairs) of extra keywords for
        consts.get_kw_plot().
    :param jitter: accepted for interface compatibility; currently unused.
    :param axs: [row, col] array of axes; created when None.
    :param fig: figure used when axs is created; a new one is made if None.
    :return: axs
    """
    if ev_cond_dim is None:
        ev_cond_dim = data.ev_cond_dim
    if durs is None:
        durs = data.durs
    if n_cond_dur_ch is None:
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])

    n_conds_dur_chs = get_n_conds_dur_chs(n_cond_dur_ch)
    conds_dim = [np.unique(cond1) for cond1 in ev_cond_dim.T]
    n_dur = len(dur_prct_groups)
    n_dif = len(dif_irrs)

    if axs is None:
        if fig is None:
            fig = plt.figure(figsize=[6, 4])

        n_row = consts.N_DIM
        n_col = n_dur
        gs = plt.GridSpec(nrows=n_row, ncols=n_col, figure=fig)
        # Builtin ``object``: the ``np.object`` alias was removed in
        # NumPy 1.24.
        axs = np.empty([n_row, n_col], dtype=object)
        for row in range(n_row):
            for col in range(n_col):
                axs[row, col] = plt.subplot(gs[row, col])

    # Boolean masks over the duration index, one per percentile group.
    # They depend only on durs and dur_prct_groups, so build them once.
    ix_dur = np.arange(len(durs))
    durs_incl = [
        ((ix_dur >= np.percentile(ix_dur, dur_prct_group[0]))
         & (ix_dur <= np.percentile(ix_dur, dur_prct_group[1])))
        for dur_prct_group in dur_prct_groups
    ]

    for dim_rel in range(consts.N_DIM):
        conds_rel = conds_dim[dim_rel]
        dim_irr = consts.get_odim(dim_rel)
        # Index of each other-dimension condition among its unique
        # absolute values.
        _, cond_irr = np.unique(np.abs(conds_dim[dim_irr]),
                                return_inverse=True)
        cmap = consts.CMAP_DIM[dim_rel](n_dif)

        for idif, dif_irr in enumerate(dif_irrs):
            incl_irr = np.isin(cond_irr, dif_irr)

            # Pool over the selected conditions and over the choices of the
            # other dimension, leaving [cond_rel, dur, ch_rel].
            if dim_rel == 0:
                n_cond_dur_ch1 = n_conds_dur_chs[:, incl_irr, :, :, :].sum(
                    (1, -1))
            else:
                n_cond_dur_ch1 = n_conds_dur_chs[incl_irr, :, :, :, :].sum(
                    (0, -2))

            kw = consts.get_kw_plot(style,
                                    color=cmap(idif),
                                    **dict(kw_plot))

            for i_dur, dur_incl in enumerate(durs_incl):
                ax = axs[dim_rel, i_dur]
                plt.sca(ax)

                n_cond_ch1 = n_cond_dur_ch1[:, dur_incl, :].sum(1)
                p_cond__ch1 = n_cond_ch1[:, 1] / n_cond_ch1.sum(1)

                # The same call is used for every style.
                plt.plot(conds_rel, p_cond__ch1, **kw)

                plt2.box_off(ax=ax)
                plt2.detach_yaxis(0, 1, ax=ax)

                ax.set_yticks([0, 0.5, 1.])
                if i_dur == 0:
                    ax.set_yticklabels(['0', '', '1'])
                else:
                    # Tick labels are shown on the first column only.
                    ax.set_yticklabels([])
                    ax.set_xticklabels([])

    return axs
コード例 #5
0
def plot_coefs_dur_irrixn(
        n_cond_dur_ch: np.ndarray = None,
        data: Data2DVD = None,
        ev_cond_dim: np.ndarray = None,
        durs: np.ndarray = None,
        style='data',
        kw_plot=(),
        jitter0=0.,
        coefs_to_plot=(2, ),
        axs=None,
        fig=None,
):
    """
    Plot selected coefficients from
    coefdur.get_coefs_irr_ixn_from_histogram() as a function of duration.

    Panels are axs[dim, coef]: one row per dimension, one column per entry
    of ``coefs_to_plot``.

    :param n_cond_dur_ch: counts handed to the coefficient fit; taken from
        ``data.get_data_by_cond('all')[1]`` when None.
    :param data: fallback source for ev_cond_dim, durs and n_cond_dur_ch;
        only accessed when one of them is None.
    :param ev_cond_dim: evidence per condition, forwarded to the fit.
    :param durs: durations used as x-values.
    :param style: forwarded to consts.get_kw_plot(); styles starting with
        'data' are drawn with error bars, others as plain lines.
    :param kw_plot: dict (or iterable of pairs) of extra keywords for
        consts.get_kw_plot(); overrides the defaults color='k' and
        for_err=True.
    :param jitter0: accepted for interface compatibility; currently unused.
    :param coefs_to_plot: indices into ('bias', 'slope', 'rel x abs(irr)',
        'rel x irr', 'abs(irr)', 'irr').
    :param axs: [row, col] array of axes; created when None.
    :param fig: figure used when axs is created; a new one is made if None.
    :return: coef, se_coef, axs
    """
    if ev_cond_dim is None:
        ev_cond_dim = data.ev_cond_dim
    if durs is None:
        durs = npy(data.durs)
    if n_cond_dur_ch is None:
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])

    coef, se_coef = coefdur.get_coefs_irr_ixn_from_histogram(
        n_cond_dur_ch,
        ev_cond_dim=ev_cond_dim)[:2]  # type: (np.ndarray, np.ndarray)

    coef_names = [
        'bias', 'slope', 'rel x abs(irr)', 'rel x irr', 'abs(irr)', 'irr'
    ]
    n_coef = len(coefs_to_plot)

    if axs is None:
        if fig is None:
            fig = plt.figure(figsize=[6, 1.5 * n_coef])

        n_row = consts.N_DIM
        n_col = n_coef
        gs = plt.GridSpec(nrows=n_row,
                          ncols=n_col,
                          figure=fig,
                          left=0.25,
                          right=0.95,
                          bottom=0.15,
                          top=0.9)
        # Builtin ``object``: the ``np.object`` alias was removed in
        # NumPy 1.24.
        axs = np.empty([n_row, n_col], dtype=object)
        for row in range(n_row):
            for col in range(n_col):
                axs[row, col] = plt.subplot(gs[row, col])

    # The plot keywords and x-range are the same for every panel, so they
    # are resolved once instead of rebinding ``kw_plot`` inside the loop.
    kw = consts.get_kw_plot(
        style, **{'color': 'k', 'for_err': True, **dict(kw_plot)})
    max_dur = np.amax(npy(durs))
    with_errorbar = style.startswith('data')

    for ii_coef, i_coef in enumerate(coefs_to_plot):
        coef_name = coef_names[i_coef]
        for dim_rel in range(consts.N_DIM):
            ax = axs[dim_rel, ii_coef]  # type: plt.Axes
            plt.sca(ax)

            y = coef[i_coef, dim_rel, :]
            e = se_coef[i_coef, dim_rel, :]
            if with_errorbar:
                plt.errorbar(durs, y, yerr=e, **kw)
            else:
                plt.plot(durs, y, **kw)
            plt2.box_off()

            ax.set_xlim(-0.05 * max_dur, 1.05 * max_dur)
            plt2.detach_axis('x', 0, max_dur)

            if dim_rel == 0:
                ax.set_xticklabels([])
                ax.set_title(coef_name.lower())
            elif i_coef == 0:
                # NOTE(review): this tests the coefficient index, not the
                # column index, so with the default coefs_to_plot=(2,) the
                # x-label is never drawn -- confirm whether ii_coef == 0
                # was intended.
                ax.set_xlabel('duration (s)')

    plt2.rowtitle(consts.DIM_NAMES_LONG, axes=axs)

    return coef, se_coef, axs
コード例 #6
0
def plot_coefs_dur_odif(
        n_cond_dur_ch: np.ndarray = None,
        data: Data2DVD = None,
        dif_irrs: Sequence[Union[Iterable[int], int]] = ((0, 1), (2, )),
        ev_cond_dim: np.ndarray = None,
        durs: np.ndarray = None,
        style='data',
        kw_plot=(),
        jitter0=0.,
        coefs_to_plot=(0, 1),
        add_rowtitle=True,
        dim_incl=(0, 1),
        axs=None,
        fig=None,
):
    """
    Plot coefficients from coefdur.get_coefs_mesh_from_histogram() as a
    function of duration, one curve per group in ``dif_irrs``.

    Panels are axs[i_dim, coef]: one row per entry of ``dim_incl``, one
    column per entry of ``coefs_to_plot``.

    :param n_cond_dur_ch: counts handed to the coefficient fit; taken from
        ``data.get_data_by_cond('all')[1]`` when None.
    :param data: fallback source for ev_cond_dim, durs and n_cond_dur_ch;
        only accessed when one of them is None.
    :param dif_irrs: difficulty groups forwarded to the fit; one curve each.
    :param ev_cond_dim: evidence per condition, forwarded to the fit.
    :param durs: durations used as x-values.
    :param style: forwarded to consts.get_kw_plot(); styles starting with
        'data' are drawn with error bars, others as plain lines.
    :param kw_plot: dict (or iterable of pairs) of extra keywords for
        consts.get_kw_plot().
    :param jitter0: horizontal offset between curves, in units of the
        spacing between the first two durations; curves are spread
        symmetrically around the true x-values.
    :param coefs_to_plot: indices into ('bias', 'slope').
    :param add_rowtitle: if True, add consts.DIM_NAMES_LONG as row titles.
    :param dim_incl: dimensions to plot, one row each.
    :param axs: [row, col] array of axes; created when None.
    :param fig: figure used when axs is created; a new one is made if None.
    :return: coef, se_coef, axs, hs[coef, dim, dif]. For styles starting
        with 'data' each hs entry is element 0 of the plt.errorbar() result;
        otherwise it is the list returned by plt.plot().
    """

    if ev_cond_dim is None:
        ev_cond_dim = data.ev_cond_dim
    if durs is None:
        durs = npy(data.durs)
    if n_cond_dur_ch is None:
        n_cond_dur_ch = npy(data.get_data_by_cond('all')[1])

    coef, se_coef = coefdur.get_coefs_mesh_from_histogram(
        n_cond_dur_ch, ev_cond_dim=ev_cond_dim,
        dif_irrs=dif_irrs)[:2]  # type: (np.ndarray, np.ndarray)

    coef_names = ['bias', 'slope']
    n_coef = len(coefs_to_plot)
    n_dim = len(dim_incl)

    if axs is None:
        if fig is None:
            fig = plt.figure(figsize=[6, 1.5 * n_coef])

        n_row = n_dim
        n_col = n_coef
        gs = plt.GridSpec(nrows=n_row,
                          ncols=n_col,
                          figure=fig,
                          left=0.25,
                          right=0.95,
                          bottom=0.15,
                          top=0.9)
        # Builtin ``object``: the ``np.object`` alias was removed in
        # NumPy 1.24.
        axs = np.empty([n_row, n_col], dtype=object)
        for row in range(n_row):
            for col in range(n_col):
                axs[row, col] = plt.subplot(gs[row, col])

    n_dif = len(dif_irrs)
    hs = np.empty([n_coef, n_dim, n_dif], dtype=object)

    # Loop-invariant quantities. With a single duration there is no spacing
    # to scale the jitter by, so fall back to no jitter instead of failing.
    ddur = durs[1] - durs[0] if len(durs) > 1 else 0.
    max_dur = np.amax(npy(durs))
    with_errorbar = style.startswith('data')

    for ii_coef, i_coef in enumerate(coefs_to_plot):
        coef_name = coef_names[i_coef]
        for i_dim, dim_rel in enumerate(dim_incl):
            ax = axs[i_dim, ii_coef]  # type: plt.Axes
            plt.sca(ax)

            cmap = consts.CMAP_DIM[dim_rel](n_dif)

            for idif, dif_irr in enumerate(dif_irrs):
                y = coef[i_coef, dim_rel, idif, :]
                e = se_coef[i_coef, dim_rel, idif, :]
                # Centre the offsets so the curves straddle the true x.
                jitter = ddur * jitter0 * (idif - (n_dif - 1) / 2)
                kw = consts.get_kw_plot(style,
                                        color=cmap(idif),
                                        for_err=True,
                                        **dict(kw_plot))
                if with_errorbar:
                    h = plt.errorbar(durs + jitter, y, yerr=e, **kw)[0]
                else:
                    h = plt.plot(durs + jitter, y, **kw)
                hs[ii_coef, i_dim, idif] = h
            plt2.box_off()

            ax.set_xlim(-0.05 * max_dur, 1.05 * max_dur)
            plt2.detach_axis('x', 0, max_dur)

            if dim_rel == 0:
                ax.set_xticklabels([])
                ax.set_title(coef_name.lower())
            elif i_coef == 0:
                ax.set_xlabel('duration (s)')

    if add_rowtitle:
        plt2.rowtitle(consts.DIM_NAMES_LONG, axes=axs)

    return coef, se_coef, axs, hs