Example #1

# NumPy and matplotlib are required by the examples below; the remaining helpers
# (lowpass_random, pid_encode, pid_decode, Herder, vstack_plots, add_subplot, point_space,
# ProgressIndicator, cosine_distance) come from the surrounding project and are assumed
# to already be in scope.
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

def demo_visualize_k_effects(kps=[0., 0.01, .1, 2.],
                             kds=[0, 1., 4.],
                             cutoff=0.005,
                             n_samples=550,
                             s_as_triangles=False,
                             seed=1234):

    x = lowpass_random(n_samples=n_samples,
                       cutoff=cutoff,
                       rng=seed,
                       normalize=True)

    plt.figure(figsize=(10, 6))
    plt.subplots_adjust(wspace=0.01,
                        hspace=0.01,
                        left=0.08,
                        right=.98,
                        top=.92)
    ax = plt.subplot2grid((len(kps), len(kds)), (0, 0))
    for i, kp in enumerate(kps):
        for j, kd in enumerate(kds):
            xe = pid_encode(x, kp=kp, kd=kd)
            h = Herder()
            xc = np.array([h(xet) for xet in xe])  # herded (quantized) encoding, as an array so the sign tests below work
            xd = pid_decode(xc, kp=kp, kd=kd)
            this_ax = plt.subplot2grid((len(kps), len(kds)),
                                       (len(kps) - i - 1, j),
                                       sharex=ax,
                                       sharey=ax)
            plt.plot(xd, color='C1', label=r'$\hat x_t$')

            # Report the mean reconstruction error and the total spike count N.
            plt.text(
                .01,
                .01,
                r'$\left<|x_t-\hat x_t|\right>_t={:.2g}, \;\;\;  N={}$'.format(
                    np.abs(x - xd).mean(), int(np.sum(np.abs(xc)))),
                ha='left',
                va='bottom',
                transform=this_ax.transAxes,
                bbox=dict(boxstyle='square',
                          facecolor='w',
                          edgecolor='none',
                          alpha=0.8,
                          pad=0.0))

            if s_as_triangles:
                up_spikes = np.nonzero(xc > 0)[0]
                down_spikes = np.nonzero(xc < 0)[0]
                plt.plot(up_spikes,
                         np.zeros(up_spikes.shape),
                         '^',
                         color='k',
                         label='$s_t^+$')
                plt.plot(down_spikes,
                         np.zeros(down_spikes.shape),
                         'v',
                         color='r',
                         label='$s_t^-$')
            else:
                plt.plot(xc, color='k', label='$s_t$')
            plt.plot(x, color='C0', label='$x_t$')
            plt.grid()
            if i > 0:
                plt.tick_params('x', labelbottom=False)
            else:
                plt.xlabel('$k_d={}$'.format(kd))
            if j > 0:
                plt.tick_params('y', labelleft=False)
            else:
                plt.ylabel('$k_p={}$'.format(kp))

    ax.set_xlim(0, n_samples)
    ax.set_ylim(np.min(x) * 1.1, np.max(x) * 1.1)
    handles, labels = plt.gca().get_legend_handles_labels()
    plt.legend(handles[::-1],
               labels[::-1],
               bbox_to_anchor=(1, 1),
               bbox_transform=plt.gcf().transFigure,
               ncol=len(handles[::-1]),
               loc='upper right')
    plt.show()
Example #2
def demo_why_kp_explanation(
    n_steps=2000,
    kd=1.,
    kp_values=[0, 0.01, 0.1],
    x_cutoff=0.03,
    w_cutoff=0.01,
    w_fixed=False,
    seed=1234,
):
    """
    We have time varying signals x, w.  See how different choices of kp, kd, and quantization affect our
    ability to approximate the time-varying quantity x*w.
    """
    rng = np.random.RandomState(seed)

    x = lowpass_random(n_samples=n_steps,
                       cutoff=x_cutoff,
                       normalize=True,
                       rng=rng)
    if w_fixed:
        w = np.ones(n_steps)
    else:
        w = lowpass_random(n_samples=n_steps, cutoff=w_cutoff,
                           normalize=True, rng=rng) + 1
    xw = x * w

    plt.figure(figsize=(10, 4))
    with vstack_plots(sharex=True,
                      sharey=True,
                      left=0.09,
                      right=.98,
                      spacing=0.02,
                      remove_ticks=False):
        ax = add_subplot()
        plt.plot(x, label='$x_t$')
        plt.plot(w, label='$w_t$')
        plt.title('In all plots, $k_d={}$'.format(kd), loc='left')
        plt.grid()

        for kp in kp_values:
            s = pid_encode(x, kp=kp, kd=kd, quantization='herd')
            zprime = pid_decode(s * w, kp=kp, kd=kd)
            ax_mult = add_subplot()
            plt.plot(xw, label=r'$z_t=x_t\cdot w_t$', color='C2')
            plt.plot(zprime,
                     label=r'$\hat z_t = dec_{k_p k_d}(Q(enc_{k_p k_d}(x_t))\cdot w_t)$',
                     color='C3')
            plt.ylabel('$k_p={}$'.format(kp))
            plt.grid()
        plt.xlabel('t')

    ax_mult.set_ylim(-4.5, 4.5)
    handles, labels = ax.get_legend_handles_labels()
    handles2, labels2 = ax_mult.get_legend_handles_labels()

    plt.legend(handles + handles2,
               labels + labels2,
               bbox_to_anchor=(.99, .99),
               bbox_transform=plt.gcf().transFigure,
               ncol=len(handles + handles2),
               loc='upper right')

    plt.show()
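
A minimal sketch of the encode/decode convention assumed in these demos (referenced in the docstring above). The sketch_pid_encode / sketch_pid_decode names are hypothetical stand-ins, not the project's pid_encode / pid_decode: encoding is assumed to be kp*x[t] + kd*(x[t] - x[t-1]), and decoding its inverse, a leaky reconstruction with decay kd/(kp + kd), consistent with the rx, re ratios used in Example #4.

# Hypothetical sketch of the assumed encode/decode pair (not the project's own functions).
import numpy as np

def sketch_pid_encode(x, kp, kd):
    x_prev = np.concatenate(([0.], x[:-1]))      # x[t-1], with x[-1] taken to be 0
    return kp * x + kd * (x - x_prev)

def sketch_pid_decode(a, kp, kd):
    x_hat = np.zeros_like(a)
    acc = 0.
    for t, a_t in enumerate(a):
        acc = (a_t + kd * acc) / (kp + kd)       # leaky reconstruction, decay ratio kd / (kp + kd)
        x_hat[t] = acc
    return x_hat

# Without quantization the pair is an exact round trip, so decoding enc(x)*w can
# approximate x*w whenever w drifts slowly relative to the decay kd / (kp + kd):
x = np.sin(np.linspace(0, 6 * np.pi, 500))
assert np.allclose(sketch_pid_decode(sketch_pid_encode(x, kp=.01, kd=1.), kp=.01, kd=1.), x)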
Example #3
def demo_kd_too_large(n_steps=20000,
                      kp=.01,
                      kd=1.,
                      kp_scan_range=(.001, .1),
                      kd_scan_range=(.1, 10),
                      n_k_points=32,
                      x_cutoff=0.01,
                      w_cutoff=0.002,
                      w_fixed=False,
                      k_spacing='log',
                      seed=1238):
    """
    We have time varying signals x, w.  See how different choices of kp, kd, and quantization affect our
    ability to approximate the time-varying quantity x*w.
    """

    rng = np.random.RandomState(seed)

    x = lowpass_random(n_samples=n_steps,
                       cutoff=x_cutoff,
                       normalize=True,
                       rng=rng)
    if w_fixed:
        w = np.ones(n_steps)
    else:
        w = lowpass_random(n_samples=n_steps, cutoff=w_cutoff,
                           normalize=True, rng=rng)
    x_w = x * w

    distance_mat_nonquantized = np.zeros((n_k_points, n_k_points))
    distance_mat_quantized = np.zeros((n_k_points, n_k_points))
    distance_mat_recon = np.zeros((n_k_points, n_k_points))
    n_spikes = np.zeros((n_k_points, n_k_points))

    pi = ProgressIndicator(n_k_points**2)

    kp_values = point_space(kp_scan_range[0],
                            kp_scan_range[1],
                            n_points=n_k_points,
                            spacing=k_spacing)
    kd_values = point_space(kd_scan_range[0],
                            kd_scan_range[1],
                            n_points=n_k_points,
                            spacing=k_spacing)

    for i, kpi in enumerate(kp_values):
        for j, kdj in enumerate(kd_values):
            pi.print_update(i * n_k_points + j)
            x_enc = pid_encode(x, kp=kpi, kd=kdj, quantization=None)
            x_enc_quantized = pid_encode(x,
                                         kp=kpi,
                                         kd=kdj,
                                         quantization='herd')
            x_enc_w = pid_decode(x_enc * w, kp=kpi, kd=kdj)
            x_enc_quantized_w_dec = pid_decode(x_enc_quantized * w,
                                               kp=kpi,
                                               kd=kdj)
            x_enc_quantized_dec_w = pid_decode(x_enc_quantized, kp=kpi,
                                               kd=kdj) * w
            distance_mat_nonquantized[i, j] = cosine_distance(x_w, x_enc_w)
            distance_mat_quantized[i, j] = cosine_distance(x_w, x_enc_quantized_w_dec)
            distance_mat_recon[i, j] = cosine_distance(x_w, x_enc_quantized_dec_w)
            n_spikes[i, j] = np.abs(x_enc_quantized).sum()

    x_enc_quantized = pid_encode(x, kp=kp, kd=kd, quantization='herd')
    x_enc = pid_encode(x, kp=kp, kd=kd, quantization=None)
    xwq = pid_decode(x_enc_quantized * w, kp=kp, kd=kd)
    xwn = pid_decode(x_enc * w, kp=kp, kd=kd)

    return ((x, w, x_w, x_enc_quantized, x_enc, xwq, xwn),
            (distance_mat_nonquantized, distance_mat_quantized, distance_mat_recon, n_spikes),
            (kp, kd, kd_values, kp_values))
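
The scan above scores each (kp, kd) pair with cosine_distance and builds the kp/kd grids with point_space; below is a hypothetical sketch of the standard definitions these are assumed to follow (the sketch_* names are stand-ins, not the project's functions).

# Hypothetical stand-ins: the usual cosine distance 1 - <a, b>/(|a||b|), and a log- or
# linearly-spaced parameter grid.
import numpy as np

def sketch_cosine_distance(a, b):
    return 1. - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

def sketch_point_space(lo, hi, n_points, spacing='log'):
    if spacing == 'log':
        return np.logspace(np.log10(lo), np.log10(hi), n_points)
    return np.linspace(lo, hi, n_points)

# Identical signals score 0, orthogonal ones score 1:
print(sketch_cosine_distance(np.array([1., 2.]), np.array([1., 2.])))   # ~0.0
print(sketch_cosine_distance(np.array([1., 0.]), np.array([0., 1.])))   # 1.0
print(sketch_point_space(.001, .1, n_points=3, spacing='log'))          # [0.001, 0.01, 0.1]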
Example #4
def demo_weight_update_figures(n_samples=1000,
                               seed=1278,
                               kpx=.0015,
                               kdx=.5,
                               kpe=.0015,
                               kde=.5,
                               warmup=500,
                               future_fill_settings=dict(
                                   color='lightsteelblue',
                                   hatch='//',
                                   edgecolor='w'),
                               past_fill_settings=dict(color='lightpink',
                                                       hatch='\\\\',
                                                       edgecolor='w'),
                               plot=True):
    """
    What this test shows:
    (1) The FutureWeightGradCalculator indeed perfectly calculates the product between the two reconstructions
    (2) The FutureWeightGradCalculator with "true" values plugged in place of reconstructions isn't actually all that great.
    (3) We implemented the "true" value idea correctly, because we try plugging in the reconstructions it is identical to actually using the reconstructions.
    """
    rng = np.random.RandomState(seed)

    linewidth = 2

    matplotlib.rcParams['hatch.color'] = 'w'
    matplotlib.rcParams['hatch.linewidth'] = 2.

    rx = kdx / float(kpx + kdx)
    re = kde / float(kpe + kde)

    t = np.arange(n_samples)
    x = lowpass_random(n_samples + warmup,
                       cutoff=0.0003,
                       rng=rng,
                       normalize=True)[warmup:]
    e = lowpass_random(n_samples + warmup,
                       cutoff=0.0003,
                       rng=rng,
                       normalize=True)[warmup:]
    if x.mean() < 0:
        x = -x
    if e.mean() < 0:
        e = -e
    # x[-int(n_samples/4):]=0
    # e[-int(n_samples/4):]=0
    xc = pid_encode(x, kp=kpx, kd=kdx, quantization='herd')
    ec = pid_encode(e, kp=kpe, kd=kde, quantization='herd')

    xd = pid_decode(xc, kp=kpx, kd=kdx)
    ed = pid_decode(ec, kp=kpe, kd=kde)
    w_true = x * e
    w_recon = xd * ed

    fig = plt.figure(figsize=(7, 3))
    with vstack_plots(grid=False,
                      sharex=False,
                      spacing=0.05,
                      xlabel='t',
                      xlim=(0, n_samples),
                      show_x=False,
                      show_y=False,
                      left=0.01,
                      right=0.98,
                      top=.96,
                      bottom=.08):

        ix = np.nonzero(ec)[0][1]  # time index of the second spike in the encoded error signal

        # After that spike, the decoded signals decay geometrically with ratios re and rx.
        future_top_e = ed[ix] * re**(np.arange(n_samples - ix))
        future_bottom_e = ed[ix - 1] * re**(np.arange(n_samples - ix) + 1)
        future_top_x = xd[ix] * rx**(np.arange(n_samples - ix))

        ix_xlast = np.nonzero(xc[:ix])[0][-1]  # last x-spike before the e-spike

        # Between the last x-spike and the e-spike, the shaded region follows the decoded traces.
        past_top_x = xd[ix_xlast:ix]
        past_top_e = ed[ix_xlast:ix]

        past_top_area = past_top_e * past_top_x

        future_top_area = future_top_e * future_top_x
        future_bottom_area = future_bottom_e * future_top_x

        ax0 = ax = add_subplot()
        plt.plot(x, label='$x$', linewidth=linewidth)
        plt.plot(xc, color='k', label='$\\bar x$', linewidth=linewidth + 1)
        plt.plot(xd, label=r'$\hat x$', linewidth=linewidth)
        ax.fill_between(t[ix_xlast:ix], 0., past_top_x, **past_fill_settings)
        ax.fill_between(t[ix:], 0., future_top_x, **future_fill_settings)
        plt.legend(loc='upper left')
        plt.axhline(0, color='k', linewidth=2)

        # ax.arrow(ix-10, 1, ix, 00, head_width=0.05, head_length=0.1, fc='k', ec='k')
        # ax.set_ylim(bottom=-.5, top=4)

        ax1 = ax = add_subplot()
        plt.plot(e, label='$e$', linewidth=linewidth)
        plt.plot(ec, color='k', label='$\\bar e$', linewidth=linewidth + 1)
        plt.plot(ed, label=r'$\hat e$', linewidth=linewidth)
        ax.fill_between(t[ix_xlast:ix], 0., past_top_e, **past_fill_settings)
        ax.fill_between(t[ix:], future_bottom_e, future_top_e,
                        **future_fill_settings)
        plt.legend(loc='upper left')
        ax.annotate(
            'spike',
            xy=(ix + 4, 0.4),
            xytext=(ix + 70, 1.),
            fontsize=10,
            fontweight='bold',
            arrowprops=dict(facecolor='black', shrink=0.05),
        )

        # plt.ylim(-.5, 4)

        ax2 = ax = add_subplot()
        plt.plot(w_true,
                 linewidth=linewidth,
                 label=r'$\frac{\partial \mathcal{L}}{\partial w}_t=x_t e_t$')
        plt.plot(w_recon,
                 linewidth=linewidth,
                 label=r'$\widehat{\frac{\partial \mathcal{L}}{\partial w}}_t = \hat x_t \hat e_t$')
        ax.fill_between(t[ix_xlast:ix],
                        0.,
                        past_top_area,
                        label=r'$\widehat{\frac{\partial \mathcal{L}}{\partial w}}_{t,past}$',
                        **past_fill_settings)
        ax.fill_between(t[ix:],
                        future_bottom_area,
                        future_top_area,
                        label=r'$\widehat{\frac{\partial \mathcal{L}}{\partial w}}_{t,future}$',
                        **future_fill_settings)
        plt.axhline(0, color='k', linewidth=2)
        plt.legend(loc='upper left', ncol=2)

    ax0.set_xlim(0, n_samples * 3 / 4)
    ax1.set_xlim(0, n_samples * 3 / 4)
    ax2.set_xlim(0, n_samples * 3 / 4)
    ax0.set_ylim(-.5, 4)
    ax1.set_ylim(-.5, 4)

    # add_subplot()
    # plt.plot(np.cumsum(x*e))
    # plt.plot(w_recon[:, 0, 0])

    plt.show()
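
As noted in the docstring, the shaded "future" region in the last panel is a geometric series: after the spike at index ix the reconstructions decay as xd[ix]*rx**n and ed[ix]*re**n, so the per-step products sum in closed form to xd[ix]*ed[ix]/(1 - rx*re). Below is a hypothetical sketch of that identity (not the FutureWeightGradCalculator itself).

# Closed-form sum of the shaded "future" weight-gradient contribution, assuming the
# geometric decays used in the figure above (rx = kdx/(kpx+kdx), re = kde/(kpe+kde)).
import numpy as np

def sketch_future_weight_grad(x_hat_now, e_hat_now, rx, re):
    # sum_{n>=0} (x_hat_now * rx**n) * (e_hat_now * re**n) = x_hat_now * e_hat_now / (1 - rx*re)
    return x_hat_now * e_hat_now / (1. - rx * re)

# Check against an explicit (truncated) sum, using the kp/kd defaults from above:
rx = .5 / (.0015 + .5)
re = .5 / (.0015 + .5)
n = np.arange(20000)
explicit = np.sum((0.7 * rx**n) * (1.2 * re**n))
assert np.isclose(explicit, sketch_future_weight_grad(0.7, 1.2, rx, re))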