Example #1
0
def plot_error_swap_distribs_err(errs,
                                 dist_errs,
                                 axs=None,
                                 fwid=3,
                                 label='',
                                 model_data=None,
                                 color=None,
                                 model_derr=None):
    """Plot density histograms of report errors and distractor distances.

    The left axis shows ``errs`` and the right axis shows ``dist_errs``.
    Optional model values are overlaid on each as dashed black step
    histograms.

    Parameters
    ----------
    errs : array-like
        Values histogrammed on ``axs[0]`` (x label 'error (rads)').
    dist_errs : array-like
        Values histogrammed on ``axs[1]`` (x label 'distractor distance
        (rads)'), drawn with ``label``.
    axs : sequence of two Axes, optional
        Target axes. When None, a new 1x2 figure with shared x and y axes
        and size ``(2 * fwid, fwid)`` is created.
    fwid : float
        Height of a newly created figure; its width is ``2 * fwid``.
    label : str
        Legend label for the ``dist_errs`` histogram.
    model_data : ndarray, optional
        Model errors overlaid (flattened) on ``axs[0]``.
    color : optional
        Matplotlib color for both data histograms.
    model_derr : ndarray, optional
        ``model_derr - model_data`` is passed through
        ``u.normalize_periodic_range`` and overlaid (flattened) on
        ``axs[1]``. Requires ``model_data``.

    Returns
    -------
    axs
        The axes that were drawn on.

    Raises
    ------
    ValueError
        If ``model_derr`` is given without ``model_data``.
    """
    if model_derr is not None and model_data is None:
        # model_derr is only used relative to model_data; fail clearly here
        # instead of with an opaque error from ``model_derr - None``.
        raise ValueError('model_derr requires model_data')
    if axs is None:
        fsize = (2 * fwid, fwid)
        _, axs = plt.subplots(1, 2, figsize=fsize, sharey=True, sharex=True)
    # shared style for the model overlays
    model_style = dict(histtype='step',
                       density=True,
                       color='k',
                       linestyle='dashed')
    axs[0].hist(errs, density=True, color=color)
    if model_data is not None:
        axs[0].hist(model_data.flatten(), **model_style)
    axs[1].hist(dist_errs, label=label, density=True, color=color)
    if model_derr is not None:
        m_derr = u.normalize_periodic_range(model_derr - model_data)
        axs[1].hist(m_derr.flatten(), **model_style)
    # Build a legend only when some artist on the axis carries a label; with
    # the default empty label, matplotlib would otherwise warn that no
    # labeled artists were found.
    if axs[1].get_legend_handles_labels()[0]:
        axs[1].legend(frameon=False)
    axs[0].set_xlabel('error (rads)')
    axs[0].set_ylabel('density')
    axs[1].set_xlabel('distractor distance (rads)')
    gpl.clean_plot(axs[0], 0)
    gpl.clean_plot(axs[1], 1)
    return axs
Example #2
0
def visualize_simplex_2d(pts,
                         ax=None,
                         ax_labels=None,
                         thr=.5,
                         pt_grey_col=(.7, .7, .7),
                         line_grey_col=(.6, .6, .6),
                         colors=None,
                         bottom_x=.8,
                         bottom_y=-1.1,
                         top_x=.35,
                         top_y=1,
                         legend=False):
    """Scatter three-component points on a 2D projection of the simplex.

    Each row ``(p0, p1, p2, ...)`` is mapped to ``x = p1 - p0`` and
    ``y = p2 - (p0 + p1)``. This puts the unit vectors at the triangle
    corners (-1, -1), (1, -1) and (0, 1), and that triangle is outlined in
    ``line_grey_col``. All points are drawn in ``pt_grey_col``. Points
    whose column ``i`` exceeds ``thr`` are then redrawn in ``colors[i]``
    with label ``ax_labels[i]``.

    Parameters
    ----------
    pts : ndarray
        2D array; columns 0-2 are used for the projection (presumably
        mixture weights summing to 1 — not checked here).
    ax : Axes, optional
        Target axis; a new figure/axis is created when None.
    ax_labels : sequence of str, optional
        One label per column, used for the legend and for the corner text.
        Index 0 goes bottom-left, 1 bottom-right, 2 top. Defaults to empty
        strings.
    thr : float
        Threshold a column must exceed for a point to be highlighted.
    pt_grey_col, line_grey_col : color
        Colors of the background points and of the triangle outline.
    colors : sequence, optional
        Highlight color per column (None lets matplotlib choose).
    bottom_x, bottom_y, top_x, top_y : float
        Positions of the corner labels.
    legend : bool
        Whether to draw a legend.

    Returns
    -------
    ax
        The axis that was drawn on.
    """
    if ax is None:
        # the default previously crashed on ``ax.plot``; create an axis like
        # the other plotting helpers do
        _, ax = plt.subplots(1, 1)
    if colors is None:
        colors = (None, ) * pts.shape[1]
    if ax_labels is None:
        ax_labels = ('', ) * pts.shape[1]
    pts_x = pts[:, 1] - pts[:, 0]
    pts_y = pts[:, 2] - (pts[:, 0] + pts[:, 1])
    ax.plot(pts_x, pts_y, 'o', color=pt_grey_col)
    for i in range(pts.shape[1]):
        mask = pts[:, i] > thr
        ax.plot(pts_x[mask],
                pts_y[mask],
                'o',
                color=colors[i],
                label=ax_labels[i])
    # triangle outline: left edge, right edge, base
    ax.plot([-1, 0], [-1, 1], color=line_grey_col)
    ax.plot([0, 1], [1, -1], color=line_grey_col)
    ax.plot([-1, 1], [-1, -1], color=line_grey_col)
    if legend:
        ax.legend(frameon=False)
    gpl.clean_plot(ax, 1)
    gpl.clean_plot_bottom(ax, 0)
    ax.text(bottom_x,
            bottom_y,
            ax_labels[1],
            verticalalignment='top',
            horizontalalignment='center')
    ax.text(-bottom_x,
            bottom_y,
            ax_labels[0],
            verticalalignment='top',
            horizontalalignment='center')
    ax.text(top_x,
            top_y,
            ax_labels[2],
            verticalalignment='top',
            horizontalalignment='center',
            rotation=-60)
    return ax
Example #3
0
def make_color_circle(ax=None, px=1000, r_cent=350, r_wid=100):
    """Draw an annulus of pixels colored by angle with the 'hsv' colormap.

    A ``px`` x ``px`` image is built. Each pixel's value is its signed angle
    from the +y (row-index) direction about the image center: positive
    where ``x >= px / 2`` and negative on the left half. These values are
    rescaled to [0, 1]. Pixels whose distance from the center lies outside
    ``r_cent +/- r_wid / 2`` are set to NaN, so only a ring is shown.

    Parameters
    ----------
    ax : Axes, optional
        Target axis; a new figure/axis is created when None.
    px : int
        Image side length in pixels.
    r_cent : float
        Ring center radius in pixels.
    r_wid : float
        Ring width in pixels.

    Returns
    -------
    ax
        The axis that was drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1)

    x, y = np.meshgrid(np.arange(px), np.arange(px))
    cent = np.stack((x, y), axis=2) - px / 2
    r = np.sqrt(np.sum(cent**2, axis=2))
    # For even px the center pixel has zero length, so its direction is
    # 0/0 = NaN. That NaN is harmless (nanmin/nanmax skip it), so silence
    # the RuntimeWarning instead of emitting it on every call.
    with np.errstate(invalid='ignore'):
        unit = cent / r[..., np.newaxis]
    up = np.expand_dims([0, 1], axis=(0, 1))
    # chord between unit vectors is 2*sin(theta/2), so this is theta / 2
    chord = np.sqrt(np.sum((unit - up)**2, axis=2))
    sim = np.arcsin(chord / 2)
    left = x < px / 2
    sim[left] = -sim[left]
    sim = sim - np.nanmin(sim)
    sim = sim / np.nanmax(sim)

    outside = np.logical_or(r < r_cent - r_wid / 2, r > r_cent + r_wid / 2)
    sim[outside] = np.nan
    ax.imshow(sim, cmap='hsv')
    gpl.clean_plot(ax, 1)
    gpl.clean_plot_bottom(ax, 0)
    return ax
Example #4
0
def plot_model_probs(*args,
                     plot_keys=('swap_prob', 'guess_prob'),
                     ax=None,
                     sep=.5,
                     comb_func=np.median,
                     colors=None,
                     sub_x=-1,
                     labels=('swaps', 'guesses'),
                     total_label='correct',
                     arg_names=('Elmo', 'Waldorf'),
                     ms=3,
                     monkey_colors=None):
    """Swarm-plot summarized model probabilities for one or more subjects.

    For each positional argument ``m``, every key ``pk`` in ``plot_keys`` is
    summarized with ``comb_func(m[pk], axis=(0, 1))``. The original code
    names these values "sessions" — presumably one value per session. The
    values are placed at ``x = offset + j`` for key index ``j``. Their
    complement, ``1 - sum over keys``, is placed at ``x = offset + sub_x``.
    Here ``offset = (i - len(args) / 2) * sep``. Points are colored by
    ``arg_names[i]`` using ``monkey_colors``.

    Parameters
    ----------
    *args
        Objects indexable by each entry of ``plot_keys``.
    plot_keys : sequence of str
        Keys to plot; must be non-empty.
    ax : Axes, optional
        Target axis; a new figure/axis is created when None.
    sep : float
        Horizontal spacing between subjects.
    comb_func : callable
        Reduction applied with ``axis=(0, 1)``.
    colors : optional
        Accepted for backward compatibility; not used.
    sub_x : float
        x offset of the complement ("correct") column.
    labels, total_label : str
        Tick labels for the keys and for the complement column.
    arg_names : sequence of str
        Hue name for each positional argument.
    ms : float
        Marker size.
    monkey_colors : optional
        Palette passed to seaborn.

    Returns
    -------
    ax
        The axis that was drawn on.

    Raises
    ------
    ValueError
        If ``plot_keys`` is empty.
    """
    if len(plot_keys) == 0:
        raise ValueError('plot_keys must contain at least one key')
    if ax is None:
        _, ax = plt.subplots(1, 1)
    cents = np.arange(0, len(plot_keys))
    n_clusters = len(args)
    # Hue labels are kept as a Python list: concatenating strings onto an
    # empty float array relied on numeric->string dtype promotion.
    swarm_full = {'x': np.array([]), 'y': np.array([]), 'monkey': []}
    for i, m in enumerate(args):
        offset = (i - n_clusters / 2) * sep
        xs, ys = _model_prob_points(m, plot_keys, comb_func, offset + cents,
                                    offset + sub_x)
        swarm_full['x'] = np.concatenate((swarm_full['x'], xs))
        swarm_full['y'] = np.concatenate((swarm_full['y'], ys))
        swarm_full['monkey'].extend([arg_names[i]] * len(xs))

    sns.swarmplot(data=swarm_full,
                  x='x',
                  y='y',
                  hue='monkey',
                  palette=monkey_colors,
                  ax=ax,
                  size=ms)
    ax.legend(frameon=False)
    ax.set_ylim([0, 1])
    # NOTE(review): these tick positions assume seaborn places the numeric x
    # values as categories 0..5, i.e. two subjects and two plot_keys with
    # the default sep/sub_x.
    ax.set_xticks([.5, 2.5, 4.5])
    ax.set_xticklabels((total_label, *labels))
    ax.set_ylabel('probability')
    gpl.clean_plot(ax, 0)
    return ax


def _model_prob_points(m, plot_keys, comb_func, key_xs, total_x):
    """Return (xs, ys) swarm coordinates for one model's probabilities.

    Key ``plot_keys[j]`` contributes its summarized values at ``key_xs[j]``.
    The complement ``1 - sum over keys`` follows at ``total_x``.
    """
    xs, ys = np.array([]), np.array([])
    totals = None
    for pk, key_x in zip(plot_keys, key_xs):
        pk_sessions = comb_func(m[pk], axis=(0, 1))
        if totals is None:
            totals = np.zeros_like(pk_sessions)
        totals = totals + pk_sessions
        xs = np.concatenate((xs, np.full(len(pk_sessions), key_x,
                                         dtype=float)))
        ys = np.concatenate((ys, pk_sessions))
    xs = np.concatenate((xs, np.full(len(totals), total_x, dtype=float)))
    ys = np.concatenate((ys, 1 - totals))
    return xs, ys
Example #5
0
def plot_noise_eg(stim_locs,
                  dx,
                  dy,
                  axs,
                  x_max=100,
                  y_max=6,
                  noise_delts=None,
                  n_pts=1000,
                  d_height=5,
                  d_alpha=.6,
                  d_color=(.85, .85, .85),
                  txt_offset=.2,
                  r_c1=(0, 0, 0),
                  r_c2=(0, 0, 0)):
    """Illustrate noisy estimates of stimulus locations on two features.

    ``axs`` is unpacked as ``(ts, x_ax, y_ax)``.

    - ``ts`` shows the true ``stim_locs`` as vertical lines.
    - ``x_ax`` and ``y_ax`` show the locations shifted by ``noise_delts[0]``
      and ``noise_delts[1]``. Each true location also gets a shaded Gaussian
      bump, and each shifted line a label ``\\hat{X}_i`` / ``\\hat{Y}_i``.

    Parameters
    ----------
    stim_locs : array-like
        True stimulus locations.
    dx, dy : float
        Used as the variance of the x / y noise (normal scale
        ``sqrt(dx)`` / ``sqrt(dy)``). Also used as the half-width of the
        plotted bump.
    axs : sequence of three Axes
        ``(ts, x_ax, y_ax)``.
    x_max, y_max : float
        Axis limits applied to all three axes.
    noise_delts : array-like, optional
        Shape ``(2, len(stim_locs))`` offsets. When None, they are sampled
        from the normal distributions above.
    n_pts : int
        Number of points per bump.
    d_height, d_alpha, d_color
        Peak height, alpha and color of the bumps.
    txt_offset : float
        Vertical gap between the lines and their labels.
    r_c1, r_c2 : color
        Line colors on ``y_ax`` and ``x_ax`` respectively.

    Returns
    -------
    axs
        The axes that were drawn on.
    """
    ts, x_ax, y_ax = axs
    # Coerce inputs: ``noise_delts[0, i]`` needs an ndarray, and a list of
    # locations plus a list of offsets would concatenate instead of adding.
    stim_locs = np.asarray(stim_locs)
    n_stim = len(stim_locs)

    ts.vlines(stim_locs, 0, 1, zorder=10)
    if noise_delts is None:
        noise_delts = np.zeros((2, n_stim))
        noise_delts[0] = sts.norm(0, np.sqrt(dx)).rvs(n_stim)
        noise_delts[1] = sts.norm(0, np.sqrt(dy)).rvs(n_stim)
    else:
        noise_delts = np.asarray(noise_delts)
    x_est = stim_locs + noise_delts[0]
    y_est = stim_locs + noise_delts[1]
    x_ax.vlines(x_est, 0, 1, zorder=10, color=r_c2)
    y_ax.vlines(y_est, 0, 1, zorder=10, color=r_c1)

    # bump shapes are the same for every location; compute them once,
    # scaled so the peak equals d_height
    x_pts = np.linspace(-dx, dx, n_pts)
    y_pts = np.linspace(-dy, dy, n_pts)
    x_distr_pts = sts.norm(0, np.sqrt(dx)).pdf(x_pts)
    x_distr_pts = d_height * x_distr_pts / np.max(x_distr_pts)
    y_distr_pts = sts.norm(0, np.sqrt(dy)).pdf(y_pts)
    y_distr_pts = d_height * y_distr_pts / np.max(y_distr_pts)
    for i, sl in enumerate(stim_locs):
        x_ax.fill_between(x_pts + sl,
                          x_distr_pts,
                          alpha=d_alpha,
                          color=d_color)
        y_ax.fill_between(y_pts + sl,
                          y_distr_pts,
                          alpha=d_alpha,
                          color=d_color)
        label_x = r'$\hat{X}_{' + str(i + 1) + '}$'
        label_y = r'$\hat{Y}_{' + str(i + 1) + '}$'
        x_ax.text(x_est[i], 0 - txt_offset, label_x, ha='center', va='top')
        y_ax.text(y_est[i],
                  1 + txt_offset,
                  label_y,
                  ha='center',
                  va='bottom')

    for ax in (ts, x_ax, y_ax):
        ax.set_xlim([0, x_max])
        ax.set_ylim([0, y_max])
        ax.set_xticks([])
    ts.set_xlabel('common feature')

    gpl.clean_plot(y_ax, 1)
    gpl.clean_plot(x_ax, 1)
    gpl.clean_plot(ts, 1)
    return axs
Example #6
0
def plot_stan_model(model,
                    ae_ax,
                    dist_ax,
                    uni_ax,
                    n=4,
                    spacing=np.pi / 4,
                    sz=8):
    """Plot assignment-error, distortion and guess-rate summaries of a fit.

    Uses ``model[0].samples``, averaged over axis 0 for the keys
    'report_bits', 'dist_bits' and 'stim_mem'. Presumably axis 0 indexes
    posterior draws — confirm. From these averages it computes:

    - ``da.ae_var_discrete`` -> ``ae_ax``
    - ``da.dr_gaussian`` -> ``dist_ax``
    - ``da.uniform_prob`` -> ``uni_ax``

    Each axis gets a violin plus points with a 95% interval at x = 0, has
    its frame cleaned, and gets a y scale bar.

    Parameters
    ----------
    model : sequence
        Only ``model[0]`` is used; it must expose ``.samples``.
    ae_ax, dist_ax, uni_ax : Axes
        Target axes for the three panels.
    n : int
        Passed to the ``da`` helpers (number of items, presumably).
    spacing, sz
        Passed to ``da.ae_var_discrete``.
    """
    m = model[0]
    rb_means = np.mean(m.samples['report_bits'], axis=0)
    db_means = np.mean(m.samples['dist_bits'], axis=0)
    sm_means = np.mean(m.samples['stim_mem'], axis=0)
    ae_prob, _ = da.ae_var_discrete(db_means, n, spacing=spacing, sz=sz)
    unif_prob = da.uniform_prob(sm_means, n)
    dist = da.dr_gaussian(rb_means, n)
    # (The previous unused ``np.random.randn`` draw was removed; it only
    # advanced the global RNG state.)
    x_pos = np.array([0])
    # axis, values, scale-bar magnitude, scale-bar label, text_buff
    panels = (
        (ae_ax, ae_prob, .2, 'assignment\nerror rate', .95),
        (dist_ax, dist, .5, 'local distortion\n(MSE)', .8),
        (uni_ax, unif_prob, .2, 'guess rate', .7),
    )
    for ax, vals, magnitude, label, text_buff in panels:
        ax.violinplot(vals, positions=x_pos, showextrema=False)
        # plot_trace_werr receives the values as a single column at x_pos
        gpl.plot_trace_werr(x_pos,
                            np.expand_dims(vals, 1),
                            points=True,
                            ax=ax,
                            error_func=gpl.conf95_interval)
        gpl.clean_plot(ax, 0)
        gpl.clean_plot_bottom(ax)
        gpl.make_yaxis_scale_bar(ax,
                                 anchor=0,
                                 magnitude=magnitude,
                                 double=False,
                                 label=label,
                                 text_buff=text_buff)