Example #1
def show_alignment_plot(record):

    duck = record.get_result()
    alignments = duck[:, 'alignments'].break_in().to_array().T

    with vstack_plots(grid=True, spacing=0.1, bottom_pad=0.1, left_pad=0.1):

        add_subplot()
        plt.plot(duck[:, 'epoch'],
                 duck[:, 'test_init_error'],
                 label='Init Eq. Prop: $s^f$',
                 color='C2')
        plt.plot(duck[:, 'epoch'],
                 duck[:, 'test_neg_error'],
                 label='Init Eq. Prop: $s^-$',
                 color='C1')
        plt.ylabel('Classification Test Error')
        plt.ylim(0, 10)
        plt.legend()
        add_subplot()

        n_layers = len(alignments)
        colors = list(get_color_cycle_map('jet', n_layers + 3))
        for i, (alignment, c) in enumerate(zip(alignments[:-1], colors)):
            plt.plot(
                duck[:, 'epoch'],
                alignment,
                label=rf'$S(\nabla_{{\phi_{i+1}}} L_{i+1}, \nabla_{{\phi_{i+1}}} L_{{{i+2}:{n_layers}}})$',
                color=c)

        plt.xlabel('Epoch')
        plt.ylabel('Gradient Alignment')
        plt.legend()
    plt.show()
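
The $S(\cdot,\cdot)$ in the legend labels is an alignment measure between gradient vectors. A minimal sketch of what it presumably computes, assuming it is the cosine similarity between a layer's local-loss gradient and the corresponding downstream-loss gradient (the helper itself is not shown in these examples):

import numpy as np

def gradient_alignment(g_local, g_downstream, eps=1e-12):
    # S(a, b) = <a, b> / (|a| |b|): cosine similarity of the flattened gradients.
    a, b = np.ravel(g_local), np.ravel(g_downstream)
    return a.dot(b) / (np.linalg.norm(a) * np.linalg.norm(b) + eps)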
Example #2
def figure_alignment_during_training(orientation='h'):

    from init_eqprop.demo_energy_eq_prop_AH_grad_alignment import ex_large_alignment

    result_local = ex_large_alignment.get_latest_record(
        if_none='run').get_result()  # type: Duck
    result_global = ex_large_alignment.get_variant(
        local_loss=False).get_latest_record(
            if_none='run').get_result()  # type: Duck

    alignments = result_local[:, 'alignments'].break_in().to_array()[:, :-1]

    length = min(len(result_local), len(result_global))
    plt.figure(figsize=(8, 6) if orientation == 'v' else (10, 2.3))

    context = \
        vstack_plots(grid=True, xlabel='Epochs', left_pad=0.15, bottom_pad=0.1, right_pad=0.1, spacing=0.05) if orientation == 'v' else \
        hstack_plots(grid=True, xlabel='Epochs', left_pad=0.1, bottom_pad=0.2, right_pad=0.1, spacing=0.4, sharey=False, show_y=True)

    with context:
        add_subplot()
        plt.plot(result_local[:length, 'epoch'],
                 result_local[:length, 'test_init_error'],
                 label='Local $s^f$',
                 color='C0')
        plt.plot(result_local[:length, 'epoch'],
                 result_local[:length, 'test_neg_error'],
                 label='Local $s^-$',
                 color=modify_color('C0', modifier=None),
                 linestyle=':')
        plt.plot(result_global[:length, 'epoch'],
                 result_global[:length, 'test_init_error'],
                 label='Global $s^f$',
                 color='C1')
        plt.plot(result_global[:length, 'epoch'],
                 result_global[:length, 'test_neg_error'],
                 label='Global $s^-$',
                 color=modify_color('C1', modifier=None),
                 linestyle=':')
        plt.ylim(0, 10)
        plt.legend(loc='upper right')
        plt.ylabel('Test Classification Error')

        add_subplot()
        for i, c in zip(
                range(alignments.shape[1]),
                get_color_cycle_map('jet', length=alignments.shape[1] + 4)):
            plt.plot(
                result_local[:length, 'epoch'],
                alignments[:length, i],
                color=c,
                label=rf'$S(\nabla_{{\phi_{i+1}}} L_{i+1}, \nabla_{{\phi_{i+1}}} L_{{{i+2}:{alignments.shape[1]+1}}})$'
            )
        plt.ylabel('Alignment')
        plt.legend()
    plt.show()
Example #3
def demo_why_kp_explanation(
        n_steps=2000,
        kd=1.,
        kp_values=[0, 0.01, 0.1],
        x_cutoff=0.03,
        w_cutoff=0.01,
        w_fixed=False,
        seed=1234,
        ):
    """
    We have time-varying signals x, w.  See how different choices of kp, kd, and quantization affect our
    ability to approximate the time-varying quantity x*w.
    """

    rng = np.random.RandomState(seed)

    x = lowpass_random(n_samples=n_steps, cutoff=x_cutoff, normalize=True, rng=rng)
    w = lowpass_random(n_samples=n_steps, cutoff=w_cutoff, normalize=True, rng=rng)+1 if not w_fixed else np.ones(n_steps)
    xw = x*w

    plt.figure(figsize=(10, 4))
    with vstack_plots(sharex=True, sharey=True, left=0.09, right=.98, spacing=0.02, remove_ticks=False):
        ax = add_subplot()
        plt.plot(x, label='$x_t$')
        plt.plot(w, label='$w_t$')
        plt.title('In all plots, $k_d={}$'.format(kd), loc='left')
        plt.grid()

        for kp in kp_values:
            s = pid_encode(x, kp=kp, kd=kd, quantization='herd')
            zprime = pid_decode(s*w, kp=kp, kd=kd)
            ax_mult = add_subplot()
            plt.plot(xw, label=r'$z_t=x_t\cdot w_t$', color='C2')
            plt.plot(zprime, label=r'$\hat z_t = dec_{k_p k_d}(Q(enc_{k_p k_d}(x_t))\cdot w_t)$', color='C3')
            plt.ylabel('$k_p={}$'.format(kp))
            plt.grid()
        plt.xlabel('t')

    ax_mult.set_ylim(-4.5, 4.5)
    handles, labels = ax.get_legend_handles_labels()
    handles2, labels2 = ax_mult.get_legend_handles_labels()
    plt.legend(handles + handles2, labels + labels2, bbox_to_anchor=(.99, .99), bbox_transform=plt.gcf().transFigure, ncol=len(handles + handles2), loc='upper right')

    plt.show()
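
These demos rely on the pid_encode/pid_decode pair. A minimal sketch of what the pair presumably computes, reconstructed from the decoder kernel used in demo_pd_stdp_equivalence below ($k_t = r^t/(k_p+k_d)$ with $r = k_d/(k_p+k_d)$); the herding quantizer shown here is an assumption:

import numpy as np

def pid_encode_sketch(x, kp, kd):
    # Proportional-derivative code: a_t = kp*x_t + kd*(x_t - x_{t-1}),
    # quantized by herding (the rounding residual is carried forward).
    a = kp * x + kd * np.diff(x, prepend=0.)
    s, phi = np.zeros_like(a), 0.
    for i, a_t in enumerate(a):
        phi += a_t
        s[i] = np.round(phi)
        phi -= s[i]
    return s

def pid_decode_sketch(s, kp, kd):
    # Inverse filter (exact when quantization is skipped):
    # x_hat_t = r * x_hat_{t-1} + s_t/(kp + kd), with r = kd/(kp + kd).
    r, acc = kd / (kp + kd), 0.
    x_hat = np.zeros(len(s))
    for i, s_t in enumerate(s):
        acc = r * acc + s_t / (kp + kd)
        x_hat[i] = acc
    return x_hat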
Example #4
def demo_create_signal_figure(
    w=[-.7, .8, .5],
    eps=0.2,
    lambda_schedule='1/t**.75',
    eps_schedule='1/t**.4',
    n_samples=200,
    seed=1247,
    input_convergence_speed=3,
    scale=0.3,
):
    rng = np.random.RandomState(seed)

    varying_sig = lowpass_random(n_samples=n_samples,
                                 cutoff=0.03,
                                 n_dim=len(w),
                                 normalize=(-scale, scale),
                                 rng=rng)
    frac = 1 - np.linspace(1, 0, n_samples)[:, None]**input_convergence_speed
    x = np.clip((1 - frac) * varying_sig + frac * scale * rng.rand(len(w)), 0, 1)
    true_z = [
        s for s in [0] for xt in x
        for s in [np.clip((1 - eps) * s + eps * xt.dot(w), 0, 1)]
    ]

    # Encode x into spikes, project through w, then decode and integrate with adaptive step sizes.
    eps_stepper = create_step_sizer(eps_schedule)  # type: IStepSizer
    lambda_stepper = create_step_sizer(lambda_schedule)  # type: IStepSizer
    encoder = PredictiveEncoder(lambda_stepper=lambda_stepper,
                                quantizer=SigmaDeltaQuantizer())
    decoder = PredictiveDecoder(lambda_stepper=lambda_stepper)
    q = np.array(
        [qt for enc in [encoder] for xt in x for enc, qt in [enc(xt)]])
    inputs = [qt.dot(w) for qt in q]
    sig, epsilons, lambdaas = unzip([
        (s, eps, dec.lambda_stepper(inp)[1])
        for s, dec, eps_func in [[0, decoder, eps_stepper]] for inp in inputs
        for dec, decoded_input in [dec(inp)]
        for eps_func, eps in [eps_func(decoded_input)]
        for s in [np.clip((1 - eps) * s + eps * decoded_input, 0, 1)]
    ])

    fig = plt.figure(figsize=(3, 4.5))
    set_figure_border_size(0.02, bottom=0.1)

    with vstack_plots(spacing=0.1, xlabel='t', bottom_pad=0.1):

        ax = add_subplot()

        sep = np.max(x) * 1.1
        plot_stacked_signals(x, sep=sep, labels=False)
        plt.gca().set_prop_cycle(None)

        event_raster_plot(
            events=[np.nonzero(q[:, i])[0] for i in range(len(w))],
            sep=sep,
            s=100)
        ax.legend(labels=[f'$s_{i}$' for i in range(1, len(w) + 1)],
                  loc='lower left')

        ax = add_subplot()
        stemlight(inputs, ax=ax, label='$u_j$', color='k')
        # plt.plot(inputs, label='$u_j$', color='k')
        ax.axhline(0, color='k')
        ax.tick_params(axis='y', labelleft=False)
        ax.legend(loc='upper left')
        # plt.grid()

        ax = add_subplot()
        # ax.plot([eps_func(t) for t in range(n_samples)], label=r'$\epsilon$', color='k')
        ax.plot(epsilons, label=r'$\epsilon$', color='k')
        ax.plot(lambdaas, label=r'$\lambda$', color='b')
        ax.axhline(0, color='k')
        ax.legend(loc='upper right')
        ax.tick_params(axis='y', labelleft=False)

        ax = add_subplot()
        ax.plot(true_z, label='$s_j$ (real)', color='r')
        ax.plot(sig, label='$s_j$ (binary)', color='k')
        ax.legend(loc='lower right')
        ax.tick_params(axis='y', labelleft=False)
        ax.axhline(0, color='k')

    plt.show()
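
The comprehensions in demo_create_signal_figure use the "for s in [...]" pattern as a stateful scan, rebinding s (and the encoder, decoder, and step-sizer) on each pass. For readability, the true_z comprehension is equivalent to this explicit loop over the same names:

true_z = []
s = 0.
for xt in x:
    # Leaky integration of the projected input, clipped to [0, 1]
    s = np.clip((1 - eps) * s + eps * xt.dot(w), 0, 1)
    true_z.append(s)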
Example #5
def demo_pd_stdp_equivalence(kp=0.03, kd=2., p=0.05, n_samples=2000):

    r = kd/float(kp+kd)

    rng = np.random.RandomState(1236)
    x_spikes = rng.choice((-1, 0, 1), size=n_samples, p=[p/2, 1-p, p/2])
    e_spikes = rng.choice((-1, 0, 1), size=n_samples, p=[p/2, 1-p, p/2])

    t_k = np.linspace(-5, 5, 201)

    x_spikes[n_samples * 3 // 4:] = 0
    e_spikes[n_samples * 3 // 4:] = 0

    x_spikes[n_samples * 7 // 8] = 1


    t = np.arange(-500,  501)
    k = r**t * (t>=0)

    x_hat = np.convolve(x_spikes, k, mode='same')/(kp+kd)
    e_hat = np.convolve(e_spikes, k, mode='same')/(kp+kd)

    x_hat = pid_decode(x_spikes, kp=kp, kd=kd)
    e_hat = pid_decode(e_spikes, kp=kp, kd=kd)
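    # Note: the pid_decode calls recompute x_hat/e_hat with the same exponential
    # filter in recursive form, overwriting the convolution-based estimates above
    # (the two agree up to truncation of the kernel at +/-500 steps).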


    # kbb = k+(r**(-t) * (t<0))
    kbb = r**(np.abs(t))

    x_conv_kbb = np.convolve(x_spikes, kbb, mode='same')
    sig_future =(x_hat*e_spikes + x_spikes*e_hat - x_spikes*e_spikes/(kp+kd)) * (kp+kd)/float(kp**2 + 2*kp*kd)
    sig_stdp = x_conv_kbb*e_spikes * 1./float(kp**2 + 2*kp*kd)

    sig_past = pd_weight_grads(xc=x_spikes[:, None], ec=e_spikes[:, None], kp=kp, kd=kd)[:, 0, 0]


    plt.figure(figsize=(10, 6))
    with vstack_plots(grid=False, xlabel='t', show_x=False, show_y=False, spacing=0.05, left=0.1, right=0.93, top=0.95):

        add_subplot()
        plt.plot(x_spikes, label=r'$\bar x_t$')
        plt.plot(x_hat, label=r'$\hat x_t$')
        plt.axhline(0, linewidth=2, color='k')
        plt.legend(loc='lower right')
        plt.ylim(-2, 2)
        plt.ylabel('Presynaptic\nSignal')
        # plt.ylabel('$\\bar x_t$')

        add_subplot()
        plt.plot(e_spikes, label=r'$\bar e_t$')
        plt.plot(e_hat, label=r'$\hat e_t$')
        plt.axhline(0, linewidth=2, color='k')
        plt.legend(loc='lower right')
        plt.ylim(-2, 2)
        plt.ylabel('Postsynaptic\nSignal')
        # plt.ylabel('$\\bar e_t$')

        add_subplot()
        plt.plot(-np.cumsum(x_hat * e_hat), label='recon')
        plt.plot(-np.cumsum(sig_stdp), label='STDP')
        plt.plot(-sig_past, label='past')
        plt.plot(-np.cumsum(sig_future), label='future')
        plt.axhline(0, linewidth=2, color='k')
        plt.ylabel(r'$\sum_{\tau=0}^t \Delta w_\tau$')
        plt.legend(loc='lower right')

    plt.show()
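
Why the symmetric kernel kbb = $r^{|t|}$ and the $1/(k_p^2 + 2 k_p k_d)$ factor give an STDP rule matching the reconstruction product: with both spike trains decoded by the causal kernel $k_t = r^t/(k_p+k_d)$, a spike pair at times $\tau_1, \tau_2$ contributes to $\sum_t \hat x_t \hat e_t$ the amount

$$\sum_{t \ge \max(\tau_1, \tau_2)} \frac{r^{t-\tau_1}\, r^{t-\tau_2}}{(k_p+k_d)^2} = \frac{r^{|\tau_1-\tau_2|}}{(k_p+k_d)^2 (1-r^2)} = \frac{r^{|\tau_1-\tau_2|}}{k_p^2 + 2 k_p k_d},$$

since $(k_p+k_d)^2(1-r^2) = (k_p+k_d)^2 - k_d^2$ for $r = k_d/(k_p+k_d)$. Each spike pair is thus weighted by an exponential window in the spike-time difference, which is exactly what sig_stdp computes.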
Example #6
def demo_why_kp_explanation(
    n_steps=2000,
    kd=1.,
    kp_values=[0, 0.01, 0.1],
    x_cutoff=0.03,
    w_cutoff=0.01,
    w_fixed=False,
    seed=1234,
):
    """
    We have time-varying signals x, w.  See how different choices of kp, kd, and quantization affect our
    ability to approximate the time-varying quantity x*w.
    """
    rng = np.random.RandomState(seed)

    x = lowpass_random(n_samples=n_steps,
                       cutoff=x_cutoff,
                       normalize=True,
                       rng=rng)
    w = lowpass_random(
        n_samples=n_steps, cutoff=w_cutoff, normalize=True,
        rng=rng) + 1 if not w_fixed else np.ones(n_steps)
    xw = x * w

    plt.figure(figsize=(10, 4))
    with vstack_plots(sharex=True,
                      sharey=True,
                      left=0.09,
                      right=.98,
                      spacing=0.02,
                      remove_ticks=False):
        ax = add_subplot()
        plt.plot(x, label='$x_t$')
        plt.plot(w, label='$w_t$')
        plt.title('In all plots, $k_d={}$'.format(kd), loc='left')
        plt.grid()

        for kp in kp_values:
            s = pid_encode(x, kp=kp, kd=kd, quantization='herd')
            zprime = pid_decode(s * w, kp=kp, kd=kd)
            ax_mult = add_subplot()
            plt.plot(xw, label=r'$z_t=x_t\cdot w_t$', color='C2')
            plt.plot(
                zprime,
                label=r'$\hat z_t = dec_{k_p k_d}(Q(enc_{k_p k_d}(x_t))\cdot w_t)$',
                color='C3')
            plt.ylabel('$k_p={}$'.format(kp))
            plt.grid()
        plt.xlabel('t')

    ax_mult.set_ylim(-4.5, 4.5)
    handles, labels = ax.get_legend_handles_labels()
    handles2, labels2 = ax_mult.get_legend_handles_labels()

    plt.legend(handles + handles2,
               labels + labels2,
               bbox_to_anchor=(.99, .99),
               bbox_transform=plt.gcf().transFigure,
               ncol=len(handles + handles2),
               loc='upper right')

    plt.show()
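
The reason $k_p$ matters, which this demo illustrates: with $k_p = 0$ we get $r = 1$, the decoder becomes a pure accumulator, and the encoder transmits only changes in $x$. The decoded product then approximates

$$dec(enc(x) \cdot w)_t \approx \sum_{\tau \le t} (x_\tau - x_{\tau-1})\, w_\tau \approx \int w\, dx,$$

whereas the target is $x_t w_t = \int (w\, dx + x\, dw)$, so the estimate drifts by roughly $\int x\, dw$ whenever $w$ varies. A nonzero $k_p$ mixes the current value of $x$ back into the code, which bounds this drift.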
Example #7
def demo_pd_stdp_equivalence(kp=0.03, kd=2., p=0.05, n_samples=2000):

    r = kd / float(kp + kd)

    rng = np.random.RandomState(1236)
    x_spikes = rng.choice((-1, 0, 1), size=n_samples, p=[p / 2, 1 - p, p / 2])
    e_spikes = rng.choice((-1, 0, 1), size=n_samples, p=[p / 2, 1 - p, p / 2])

    t_k = np.linspace(-5, 5, 201)

    x_spikes[n_samples * 3 // 4:] = 0
    e_spikes[n_samples * 3 // 4:] = 0

    x_spikes[n_samples * 7 // 8] = 1

    t = np.arange(-500, 501)
    k = r**t * (t >= 0)

    x_hat = np.convolve(x_spikes, k, mode='same') / (kp + kd)
    e_hat = np.convolve(e_spikes, k, mode='same') / (kp + kd)

    x_hat = pid_decode(x_spikes, kp=kp, kd=kd)
    e_hat = pid_decode(e_spikes, kp=kp, kd=kd)

    # kbb = k+(r**(-t) * (t<0))
    kbb = r**(np.abs(t))

    x_conv_kbb = np.convolve(x_spikes, kbb, mode='same')
    sig_future = (x_hat * e_spikes + x_spikes * e_hat - x_spikes * e_spikes /
                  (kp + kd)) * (kp + kd) / float(kp**2 + 2 * kp * kd)
    sig_stdp = x_conv_kbb * e_spikes * 1. / float(kp**2 + 2 * kp * kd)

    sig_past = pd_weight_grads(xc=x_spikes[:, None],
                               ec=e_spikes[:, None],
                               kp=kp,
                               kd=kd)[:, 0, 0]

    plt.figure(figsize=(10, 6))
    with vstack_plots(grid=False,
                      xlabel='t',
                      show_x=False,
                      show_y=False,
                      spacing=0.05,
                      left=0.1,
                      right=0.93,
                      top=0.95):

        add_subplot()
        plt.plot(x_spikes, label=r'$\bar x_t$')
        plt.plot(x_hat, label=r'$\hat x_t$')
        plt.axhline(0, linewidth=2, color='k')
        plt.legend(loc='lower right')
        plt.ylim(-2, 2)
        plt.ylabel('Presynaptic\nSignal')
        # plt.ylabel('$\\bar x_t$')

        add_subplot()
        plt.plot(e_spikes, label=r'$\bar e_t$')
        plt.plot(e_hat, label=r'$\hat e_t$')
        plt.axhline(0, linewidth=2, color='k')
        plt.legend(loc='lower right')
        plt.ylim(-2, 2)
        plt.ylabel('Postsynaptic\nSignal')
        # plt.ylabel('$\\bar e_t$')

        add_subplot()
        plt.plot(-np.cumsum(x_hat * e_hat), label='recon')
        plt.plot(-np.cumsum(sig_stdp), label='STDP')
        plt.plot(-sig_past, label='past')
        plt.plot(-np.cumsum(sig_future), label='future')
        plt.axhline(0, linewidth=2, color='k')
        plt.ylabel(r'$\sum_{\tau=0}^t \Delta w_\tau$')

        plt.legend(loc='lower right')

    plt.show()
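
For reference, the estimators compared against the reconstruction product $\hat x_t \hat e_t$ in the final panel, transcribed from the code above (sig_past comes from pd_weight_grads, whose internals are not shown here):

$$\Delta w_t^{stdp} = \frac{\bar e_t \sum_\tau \bar x_\tau\, r^{|t-\tau|}}{k_p^2 + 2 k_p k_d}, \qquad \Delta w_t^{future} = \frac{(k_p+k_d)(\hat x_t \bar e_t + \bar x_t \hat e_t) - \bar x_t \bar e_t}{k_p^2 + 2 k_p k_d}.$$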
Example #8
def demo_weight_update_figures(
        n_samples=1000,
        seed=1278,
        kpx=.0015,
        kdx=.5,
        kpe=.0015,
        kde=.5,
        warmup=500,
        future_fill_settings=dict(color='lightsteelblue', hatch='//', edgecolor='w'),
        past_fill_settings=dict(color='lightpink', hatch='\\\\', edgecolor='w'),
        plot=True
        ):
    """
    What this test shows:
    (1) The FutureWeightGradCalculator indeed perfectly calculates the product of the two reconstructions.
    (2) The FutureWeightGradCalculator with "true" values plugged in place of the reconstructions isn't actually all that great.
    (3) We implemented the "true"-value idea correctly: when we plug the reconstructions back in, the result is identical to actually using the reconstructions.
    """
    rng = np.random.RandomState(seed)

    linewidth=2

    matplotlib.rcParams['hatch.color'] = 'w'
    matplotlib.rcParams['hatch.linewidth'] = 2.

    rx = kdx/float(kpx+kdx)
    re = kde/float(kpe+kde)

    t = np.arange(n_samples)
    x = lowpass_random(n_samples+warmup, cutoff=0.0003, rng=rng, normalize=True)[warmup:]
    e = lowpass_random(n_samples+warmup, cutoff=0.0003, rng=rng, normalize=True)[warmup:]
    if x.mean()<0:
        x=-x
    if e.mean()<0:
        e=-e
    # x[-int(n_samples/4):]=0
    # e[-int(n_samples/4):]=0
    xc = pid_encode(x, kp=kpx, kd=kdx, quantization='herd')
    ec = pid_encode(e, kp=kpe, kd=kde, quantization='herd')

    xd = pid_decode(xc, kp=kpx, kd=kdx)
    ed = pid_decode(ec, kp=kpe, kd=kde)
    w_true = x*e
    w_recon = xd*ed

    fig = plt.figure(figsize=(7, 3))
    with vstack_plots(grid=False, sharex=False, spacing=0.05, xlabel='t', xlim=(0, n_samples), show_x=False, show_y=False, left=0.01, right=0.98, top=.96, bottom=.08):

        ix = np.nonzero(ec)[0][1]

        future_top_e = ed[ix]*re**(np.arange(n_samples-ix))
        future_bottom_e = ed[ix-1]*re**(np.arange(n_samples-ix)+1)
        future_top_x = xd[ix]*rx**(np.arange(n_samples-ix))

        ix_xlast = np.nonzero(xc[:ix])[0][-1]

        past_top_x = xd[ix_xlast:ix]
        past_top_e = ed[ix_xlast:ix]

        past_top_area = past_top_e*past_top_x

        future_top_area = future_top_e*future_top_x
        future_bottom_area = future_bottom_e*future_top_x

        ax0=ax=add_subplot()
        plt.plot(x, label='$x$', linewidth=linewidth)
        plt.plot(xc, color='k', label='$\\bar x$', linewidth=linewidth+1)
        plt.plot(xd, label=r'$\hat x$', linewidth=linewidth)
        ax.fill_between(t[ix_xlast:ix], 0., past_top_x, **past_fill_settings)
        ax.fill_between(t[ix:], 0., future_top_x, **future_fill_settings)
        plt.legend(loc='upper left')
        plt.axhline(0, color='k', linewidth=2)

        # ax.arrow(ix-10, 1, ix, 00, head_width=0.05, head_length=0.1, fc='k', ec='k')
        # ax.set_ylim(bottom=-.5, top=4)

        ax1=ax=add_subplot()
        plt.plot(e, label='$e$', linewidth=linewidth)
        plt.plot(ec, color='k', label='$\\bar e$', linewidth=linewidth+1)
        plt.plot(ed, label=r'$\hat e$', linewidth=linewidth)
        ax.fill_between(t[ix_xlast:ix], 0., past_top_e, **past_fill_settings)
        ax.fill_between(t[ix:], future_bottom_e, future_top_e, **future_fill_settings)
        plt.legend(loc='upper left')
        ax.annotate('spike', xy=(ix+4, 0.4), xytext=(ix+70, 1.), fontsize=10, fontweight='bold',
            arrowprops=dict(facecolor='black', shrink=0.05),
            )


        # plt.ylim(-.5, 4)

        ax2=ax=add_subplot()
        plt.plot(w_true, linewidth=linewidth, label=r'$\frac{\partial \mathcal{L}}{\partial w}_t=x_t e_t$')
        plt.plot(w_recon, linewidth=linewidth, label=r'$\widehat{\frac{\partial \mathcal{L}}{\partial w}}_t = \hat x_t \hat e_t$')
        ax.fill_between(t[ix_xlast:ix], 0., past_top_area, label=r'$\widehat{\frac{\partial \mathcal{L}}{\partial w}}_{t,past}$', **past_fill_settings)
        ax.fill_between(t[ix:], future_bottom_area, future_top_area, label=r'$\widehat{\frac{\partial \mathcal{L}}{\partial w}}_{t,future}$', **future_fill_settings)
        plt.axhline(0, color='k', linewidth=2)
        plt.legend(loc='upper left', ncol=2)

    ax0.set_xlim(0, n_samples*3/4)
    ax1.set_xlim(0, n_samples*3/4)
    ax2.set_xlim(0, n_samples*3/4)
    ax0.set_ylim(-.5, 4)
    ax1.set_ylim(-.5, 4)

    # add_subplot()
    # plt.plot(np.cumsum(x*e))
    # plt.plot(w_recon[:, 0, 0])

    plt.show()
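
The shaded regions have a useful closed form: after the spike at index ix, the reconstructions decay geometrically ($\hat x_{ix+n} = \hat x_{ix} r_x^n$ and $\hat e_{ix+n} = \hat e_{ix} r_e^n$), so the total "future" contribution to the weight gradient is

$$\sum_{n \ge 0} \hat x_{ix}\, \hat e_{ix}\, (r_x r_e)^n = \frac{\hat x_{ix}\, \hat e_{ix}}{1 - r_x r_e},$$

which is what lets the future part of the update be computed in one step at spike time rather than accumulated step by step.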
Example #9
def demo_weight_update_figures(n_samples=1000,
                               seed=1278,
                               kpx=.0015,
                               kdx=.5,
                               kpe=.0015,
                               kde=.5,
                               warmup=500,
                               future_fill_settings=dict(
                                   color='lightsteelblue',
                                   hatch='//',
                                   edgecolor='w'),
                               past_fill_settings=dict(color='lightpink',
                                                       hatch='\\\\',
                                                       edgecolor='w'),
                               plot=True):
    """
    What this test shows:
    (1) The FutureWeightGradCalculator indeed perfectly calculates the product of the two reconstructions.
    (2) The FutureWeightGradCalculator with "true" values plugged in place of the reconstructions isn't actually all that great.
    (3) We implemented the "true"-value idea correctly: when we plug the reconstructions back in, the result is identical to actually using the reconstructions.
    """
    rng = np.random.RandomState(seed)

    linewidth = 2

    matplotlib.rcParams['hatch.color'] = 'w'
    matplotlib.rcParams['hatch.linewidth'] = 2.

    rx = kdx / float(kpx + kdx)
    re = kde / float(kpe + kde)

    t = np.arange(n_samples)
    x = lowpass_random(n_samples + warmup,
                       cutoff=0.0003,
                       rng=rng,
                       normalize=True)[warmup:]
    e = lowpass_random(n_samples + warmup,
                       cutoff=0.0003,
                       rng=rng,
                       normalize=True)[warmup:]
    if x.mean() < 0:
        x = -x
    if e.mean() < 0:
        e = -e
    # x[-int(n_samples/4):]=0
    # e[-int(n_samples/4):]=0
    xc = pid_encode(x, kp=kpx, kd=kdx, quantization='herd')
    ec = pid_encode(e, kp=kpe, kd=kde, quantization='herd')

    xd = pid_decode(xc, kp=kpx, kd=kdx)
    ed = pid_decode(ec, kp=kpe, kd=kde)
    w_true = x * e
    w_recon = xd * ed

    fig = plt.figure(figsize=(7, 3))
    with vstack_plots(grid=False,
                      sharex=False,
                      spacing=0.05,
                      xlabel='t',
                      xlim=(0, n_samples),
                      show_x=False,
                      show_y=False,
                      left=0.01,
                      right=0.98,
                      top=.96,
                      bottom=.08):

        ix = np.nonzero(ec)[0][1]

        future_top_e = ed[ix] * re**(np.arange(n_samples - ix))
        future_bottom_e = ed[ix - 1] * re**(np.arange(n_samples - ix) + 1)
        future_top_x = xd[ix] * rx**(np.arange(n_samples - ix))

        ix_xlast = np.nonzero(xc[:ix])[0][-1]

        past_top_x = xd[ix_xlast:ix]
        past_top_e = ed[ix_xlast:ix]

        past_top_area = past_top_e * past_top_x

        future_top_area = future_top_e * future_top_x
        future_bottom_area = future_bottom_e * future_top_x

        ax0 = ax = add_subplot()
        plt.plot(x, label='$x$', linewidth=linewidth)
        plt.plot(xc, color='k', label='$\\bar x$', linewidth=linewidth + 1)
        plt.plot(xd, label=r'$\hat x$', linewidth=linewidth)
        ax.fill_between(t[ix_xlast:ix], 0., past_top_x, **past_fill_settings)
        ax.fill_between(t[ix:], 0., future_top_x, **future_fill_settings)
        plt.legend(loc='upper left')
        plt.axhline(0, color='k', linewidth=2)

        # ax.arrow(ix-10, 1, ix, 00, head_width=0.05, head_length=0.1, fc='k', ec='k')
        # ax.set_ylim(bottom=-.5, top=4)

        ax1 = ax = add_subplot()
        plt.plot(e, label='$e$', linewidth=linewidth)
        plt.plot(ec, color='k', label='$\\bar e$', linewidth=linewidth + 1)
        plt.plot(ed, label=r'$\hat e$', linewidth=linewidth)
        ax.fill_between(t[ix_xlast:ix], 0., past_top_e, **past_fill_settings)
        ax.fill_between(t[ix:], future_bottom_e, future_top_e,
                        **future_fill_settings)
        plt.legend(loc='upper left')
        ax.annotate(
            'spike',
            xy=(ix + 4, 0.4),
            xytext=(ix + 70, 1.),
            fontsize=10,
            fontweight='bold',
            arrowprops=dict(facecolor='black', shrink=0.05),
        )

        # plt.ylim(-.5, 4)

        ax2 = ax = add_subplot()
        plt.plot(w_true,
                 linewidth=linewidth,
                 label=r'$\frac{\partial \mathcal{L}}{\partial w}_t=x_t e_t$')
        plt.plot(w_recon,
                 linewidth=linewidth,
                 label=r'$\widehat{\frac{\partial \mathcal{L}}{\partial w}}_t = \hat x_t \hat e_t$')
        ax.fill_between(t[ix_xlast:ix],
                        0.,
                        past_top_area,
                        label=r'$\widehat{\frac{\partial \mathcal{L}}{\partial w}}_{t,past}$',
                        **past_fill_settings)
        ax.fill_between(t[ix:],
                        future_bottom_area,
                        future_top_area,
                        label=r'$\widehat{\frac{\partial \mathcal{L}}{\partial w}}_{t,future}$',
                        **future_fill_settings)
        plt.axhline(0, color='k', linewidth=2)
        plt.legend(loc='upper left', ncol=2)

    ax0.set_xlim(0, n_samples * 3 / 4)
    ax1.set_xlim(0, n_samples * 3 / 4)
    ax2.set_xlim(0, n_samples * 3 / 4)
    ax0.set_ylim(-.5, 4)
    ax1.set_ylim(-.5, 4)

    # add_subplot()
    # plt.plot(np.cumsum(x*e))
    # plt.plot(w_recon[:, 0, 0])

    plt.show()
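
A minimal sketch of that closed form in code (a hypothetical helper, not part of the example above):

def future_weight_grad(x_hat_ix, e_hat_ix, rx, re):
    # Total shaded "future" area: sum_n (x_hat*rx**n) * (e_hat*re**n)
    # = x_hat * e_hat / (1 - rx*re), assuming no further spikes arrive.
    return x_hat_ix * e_hat_ix / (1 - rx * re)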