import matplotlib
import numpy as np
from matplotlib import pyplot as plt

# Note: the project-local helpers used below (lowpass_random, pid_encode, pid_decode, Herder,
# vstack_plots, add_subplot, ProgressIndicator, point_space, cosine_distance) are assumed to be
# imported elsewhere in this module/repo.


def demo_visualize_k_effects(kps=[0., 0.01, .1, 2.], kds=[0, 1., 4.], cutoff=0.005, n_samples=550,
                             s_as_triangles=False, seed=1234):
    """
    Visualize, over a grid of (kp, kd) values, the quantized encoding s_t of a low-pass random
    signal x_t, its reconstruction, the mean reconstruction error, and the total spike count N.
    """
    x = lowpass_random(n_samples=n_samples, cutoff=cutoff, rng=seed, normalize=True)
    plt.figure(figsize=(10, 6))
    plt.subplots_adjust(wspace=0.01, hspace=0.01, left=0.08, right=.98, top=.92)
    ax = plt.subplot2grid((len(kps), len(kds)), (0, 0))
    for i, kp in enumerate(kps):
        for j, kd in enumerate(kds):
            xe = pid_encode(x, kp=kp, kd=kd)
            h = Herder()
            xc = np.array([h(xet) for xet in xe])  # Quantize the encoded signal sample-by-sample.
            xd = pid_decode(xc, kp=kp, kd=kd)
            this_ax = plt.subplot2grid((len(kps), len(kds)), (len(kps) - i - 1, j), sharex=ax, sharey=ax)
            plt.plot(xd, color='C1', label=r'$\hat x_t$')
            plt.text(
                .01, .01,
                r'$\left<|x_t-\hat x_t|\right>_t={:.2g}, \;\;\; N={}$'.format(
                    np.abs(x - xd).mean(), int(np.sum(np.abs(xc)))),
                ha='left', va='bottom', transform=this_ax.transAxes,
                bbox=dict(boxstyle='square', facecolor='w', edgecolor='none', alpha=0.8, pad=0.0))
            if s_as_triangles:
                up_spikes = np.nonzero(xc > 0)[0]
                down_spikes = np.nonzero(xc < 0)[0]
                plt.plot(up_spikes, np.zeros(up_spikes.shape), '^', color='k', label='$s_t^+$')
                plt.plot(down_spikes, np.zeros(down_spikes.shape), 'v', color='r', label='$s_t^-$')
            else:
                plt.plot(xc, color='k', label='$s_t$')
            plt.plot(x, color='C0', label='$x_t$')
            plt.grid()
            if i > 0:
                plt.tick_params('x', labelbottom=False)
            else:
                plt.xlabel('$k_d={}$'.format(kd))
            if j > 0:
                plt.tick_params('y', labelleft=False)
            else:
                plt.ylabel('$k_p={}$'.format(kp))
    ax.set_xlim(0, n_samples)
    ax.set_ylim(np.min(x) * 1.1, np.max(x) * 1.1)
    handles, labels = plt.gca().get_legend_handles_labels()
    plt.legend(handles[::-1], labels[::-1], bbox_to_anchor=(1, 1), bbox_transform=plt.gcf().transFigure,
               ncol=len(handles[::-1]), loc='upper right')
    plt.show()
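
# The demo above depends entirely on what pid_encode / pid_decode / Herder do. As a point of
# reference, here is a minimal self-contained sketch of the encode -> quantize -> decode
# round-trip they are assumed to perform (the formulas below are an assumption for illustration,
# not necessarily the exact implementation in this repo): the encoder forms
# a_t = kp*x_t + kd*(x_t - x_{t-1}), the herding step emits integer values whose running sum
# tracks that of a_t, and the decoder inverts the encoder with
# xhat_t = (s_t + kd*xhat_{t-1}) / (kp + kd).
def _pid_roundtrip_reference_sketch(x, kp=0.01, kd=1.0):
    x = np.asarray(x, dtype=float)
    a = kp * x + kd * np.concatenate(([x[0]], np.diff(x)))  # proportional + derivative encoding
    s = np.zeros_like(a)
    phi = 0.0
    for i, ai in enumerate(a):  # herding (sigma-delta style) quantization to integer "spike counts"
        phi += ai
        s[i] = np.round(phi)
        phi -= s[i]
    xhat = np.zeros_like(s)
    prev = 0.0
    for i, si in enumerate(s):  # decode: inverse filter of the encoder
        prev = (si + kd * prev) / (kp + kd)
        xhat[i] = prev
    return s, xhat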
def demo_why_kp_explanation(n_steps=2000, kd=1., kp_values=[0, 0.01, 0.1], x_cutoff=0.03, w_cutoff=0.01,
                            w_fixed=False, seed=1234):
    """
    We have time-varying signals x, w.  See how different choices of kp, kd, and quantization affect
    our ability to approximate the time-varying quantity x*w.
    """
    rng = np.random.RandomState(seed)
    x = lowpass_random(n_samples=n_steps, cutoff=x_cutoff, normalize=True, rng=rng)
    w = (lowpass_random(n_samples=n_steps, cutoff=w_cutoff, normalize=True, rng=rng) + 1) if not w_fixed \
        else np.ones(n_steps)
    xw = x * w
    plt.figure(figsize=(10, 4))
    with vstack_plots(sharex=True, sharey=True, left=0.09, right=.98, spacing=0.02, remove_ticks=False):
        ax = add_subplot()
        plt.plot(x, label='$x_t$')
        plt.plot(w, label='$w_t$')
        plt.title('In all plots, $k_d={}$'.format(kd), loc='left')
        plt.grid()
        for kp in kp_values:
            s = pid_encode(x, kp=kp, kd=kd, quantization='herd')
            zprime = pid_decode(s * w, kp=kp, kd=kd)
            ax_mult = add_subplot()
            plt.plot(xw, label=r'$z_t=x_t\cdot w_t$', color='C2')
            plt.plot(zprime, label=r'$\hat z_t = dec_{k_p k_d}(Q(enc_{k_p k_d}(x_t))\cdot w_t)$', color='C3')
            plt.ylabel('$k_p={}$'.format(kp))
            plt.grid()
        plt.xlabel('t')
        ax_mult.set_ylim(-4.5, 4.5)
    handles, labels = ax.get_legend_handles_labels()
    handles2, labels2 = ax_mult.get_legend_handles_labels()
    plt.legend(handles + handles2, labels + labels2, bbox_to_anchor=(.99, .99),
               bbox_transform=plt.gcf().transFigure, ncol=len(handles + handles2), loc='upper right')
    plt.show()
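
# A quick numerical companion to the figure above: for a single (kp, kd), measure how far the
# decoded product dec(Q(enc(x)) * w) is from the true product x*w. This reuses the same
# module-level helpers as the demo; the mean-absolute-error metric and the default arguments
# are illustrative choices, not part of the original experiment.
def _product_approximation_error_sketch(kp=0.01, kd=1.0, n_steps=2000, x_cutoff=0.03, w_cutoff=0.01, seed=1234):
    rng = np.random.RandomState(seed)
    x = lowpass_random(n_samples=n_steps, cutoff=x_cutoff, normalize=True, rng=rng)
    w = lowpass_random(n_samples=n_steps, cutoff=w_cutoff, normalize=True, rng=rng) + 1
    s = pid_encode(x, kp=kp, kd=kd, quantization='herd')  # quantized encoding of x
    zhat = pid_decode(s * w, kp=kp, kd=kd)                # decode the spike-times-weight product
    return np.mean(np.abs(x * w - zhat))                  # mean absolute approximation error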
def demo_kd_too_large(n_steps=20000, kp=.01, kd=1., kp_scan_range=(.001, .1), kd_scan_range=(.1, 10),
                      n_k_points=32, x_cutoff=0.01, w_cutoff=0.002, w_fixed=False, k_spacing='log', seed=1238):
    """
    We have time varying signals x, w.  See how different choices of kp, kd, and quantization affect
    our ability to approximate the time-varying quantity x*w.
    """
    rng = np.random.RandomState(seed)
    x = lowpass_random(n_samples=n_steps, cutoff=x_cutoff, normalize=True, rng=rng)
    w = lowpass_random(n_samples=n_steps, cutoff=w_cutoff, normalize=True, rng=rng) if not w_fixed else np.ones(n_steps)
    x_w = x * w
    distance_mat_nonquantized = np.zeros((n_k_points, n_k_points))
    distance_mat_quantized = np.zeros((n_k_points, n_k_points))
    distance_mat_recon = np.zeros((n_k_points, n_k_points))
    n_spikes = np.zeros((n_k_points, n_k_points))
    pi = ProgressIndicator(n_k_points**2)
    kp_values = point_space(kp_scan_range[0], kp_scan_range[1], n_points=n_k_points, spacing=k_spacing)
    kd_values = point_space(kd_scan_range[0], kd_scan_range[1], n_points=n_k_points, spacing=k_spacing)
    for i, kpi in enumerate(kp_values):
        for j, kdj in enumerate(kd_values):
            pi.print_update(i * n_k_points + j)
            x_enc = pid_encode(x, kp=kpi, kd=kdj, quantization=None)
            x_enc_quantized = pid_encode(x, kp=kpi, kd=kdj, quantization='herd')
            x_enc_w = pid_decode(x_enc * w, kp=kpi, kd=kdj)
            x_enc_quantized_w_dec = pid_decode(x_enc_quantized * w, kp=kpi, kd=kdj)
            x_enc_quantized_dec_w = pid_decode(x_enc_quantized, kp=kpi, kd=kdj) * w
            distance_mat_nonquantized[i, j] = cosine_distance(x_w, x_enc_w)
            distance_mat_quantized[i, j] = cosine_distance(x_w, x_enc_quantized_w_dec)
            distance_mat_recon[i, j] = cosine_distance(x_w, x_enc_quantized_dec_w)
            n_spikes[i, j] = np.abs(x_enc_quantized).sum()
    x_enc_quantized = pid_encode(x, kp=kp, kd=kd, quantization='herd')
    x_enc = pid_encode(x, kp=kp, kd=kd, quantization=None)
    xwq = pid_decode(x_enc_quantized * w, kp=kp, kd=kd)
    xwn = pid_decode(x_enc * w, kp=kp, kd=kd)
    return (x, w, x_w, x_enc_quantized, x_enc, xwq, xwn), \
           (distance_mat_nonquantized, distance_mat_quantized, distance_mat_recon, n_spikes), \
           (kp, kd, kd_values, kp_values)
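
# demo_kd_too_large computes the (kp, kd) scan but leaves plotting to the caller. Below is a
# minimal sketch of one way to display the three cosine-distance matrices it returns; the use
# of imshow, the panel titles, and the index-based axes are illustrative choices only.
def _plot_kd_scan_sketch():
    _signals, (d_nonq, d_quant, d_recon, _n_spikes), (_kp, _kd, kd_values, kp_values) = demo_kd_too_large()
    plt.figure(figsize=(10, 3))
    for k, (mat, title) in enumerate([(d_nonq, 'non-quantized'), (d_quant, 'quantized'), (d_recon, 'decode-then-multiply')]):
        plt.subplot(1, 3, k + 1)
        plt.imshow(mat, aspect='auto', origin='lower')  # rows: kp index, columns: kd index
        plt.xlabel('$k_d$ index (of {} values)'.format(len(kd_values)))
        plt.ylabel('$k_p$ index (of {} values)'.format(len(kp_values)))
        plt.title('Cosine distance: ' + title)
        plt.colorbar()
    plt.show()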
def demo_weight_update_figures(n_samples=1000, seed=1278, kpx=.0015, kdx=.5, kpe=.0015, kde=.5, warmup=500,
                               future_fill_settings=dict(color='lightsteelblue', hatch='//', edgecolor='w'),
                               past_fill_settings=dict(color='lightpink', hatch='\\\\', edgecolor='w'),
                               plot=True):
    """
    What this test shows:
    (1) The FutureWeightGradCalculator indeed perfectly calculates the product between the two reconstructions.
    (2) The FutureWeightGradCalculator with "true" values plugged in place of reconstructions isn't actually
        all that great.
    (3) We implemented the "true" value idea correctly, because when we try plugging in the reconstructions,
        the result is identical to actually using the reconstructions.
    """
    rng = np.random.RandomState(seed)
    linewidth = 2
    matplotlib.rcParams['hatch.color'] = 'w'
    matplotlib.rcParams['hatch.linewidth'] = 2.
    rx = kdx / float(kpx + kdx)
    re = kde / float(kpe + kde)
    t = np.arange(n_samples)
    x = lowpass_random(n_samples + warmup, cutoff=0.0003, rng=rng, normalize=True)[warmup:]
    e = lowpass_random(n_samples + warmup, cutoff=0.0003, rng=rng, normalize=True)[warmup:]
    if x.mean() < 0:
        x = -x
    if e.mean() < 0:
        e = -e
    # x[-int(n_samples/4):]=0
    # e[-int(n_samples/4):]=0
    xc = pid_encode(x, kp=kpx, kd=kdx, quantization='herd')
    ec = pid_encode(e, kp=kpe, kd=kde, quantization='herd')
    xd = pid_decode(xc, kp=kpx, kd=kdx)
    ed = pid_decode(ec, kp=kpe, kd=kde)
    w_true = x * e
    w_recon = xd * ed
    fig = plt.figure(figsize=(7, 3))
    with vstack_plots(grid=False, sharex=False, spacing=0.05, xlabel='t', xlim=(0, n_samples), show_x=False,
                      show_y=False, left=0.01, right=0.98, top=.96, bottom=.08):
        ix = np.nonzero(ec)[0][1]
        future_top_e = ed[ix] * re**(np.arange(n_samples - ix))
        future_bottom_e = ed[ix - 1] * re**(np.arange(n_samples - ix) + 1)
        future_top_x = xd[ix] * rx**(np.arange(n_samples - ix))
        ix_xlast = np.nonzero(xc[:ix])[0][-1]
        past_top_x = xd[ix_xlast:ix]
        past_top_e = ed[ix_xlast:ix]
        past_top_area = past_top_e * past_top_x
        future_top_area = future_top_e * future_top_x
        future_bottom_area = future_bottom_e * future_top_x
        ax0 = ax = add_subplot()
        plt.plot(x, label='$x$', linewidth=linewidth)
        plt.plot(xc, color='k', label=r'$\bar x$', linewidth=linewidth + 1)
        plt.plot(xd, label=r'$\hat x$', linewidth=linewidth)
        ax.fill_between(t[ix_xlast:ix], 0., past_top_x, **past_fill_settings)
        ax.fill_between(t[ix:], 0., future_top_x, **future_fill_settings)
        plt.legend(loc='upper left')
        plt.axhline(0, color='k', linewidth=2)
        # ax.arrow(ix-10, 1, ix, 00, head_width=0.05, head_length=0.1, fc='k', ec='k')
        # ax.set_ylim(bottom=-.5, top=4)
        ax1 = ax = add_subplot()
        plt.plot(e, label='$e$', linewidth=linewidth)
        plt.plot(ec, color='k', label=r'$\bar e$', linewidth=linewidth + 1)
        plt.plot(ed, label=r'$\hat e$', linewidth=linewidth)
        ax.fill_between(t[ix_xlast:ix], 0., past_top_e, **past_fill_settings)
        ax.fill_between(t[ix:], future_bottom_e, future_top_e, **future_fill_settings)
        plt.legend(loc='upper left')
        ax.annotate('spike', xy=(ix + 4, 0.4), xytext=(ix + 70, 1.), fontsize=10, fontweight='bold',
                    arrowprops=dict(facecolor='black', shrink=0.05))
        # plt.ylim(-.5, 4)
        ax2 = ax = add_subplot()
        plt.plot(w_true, linewidth=linewidth, label=r'$\frac{\partial \mathcal{L}}{\partial w}_t=x_t e_t$')
        plt.plot(w_recon, linewidth=linewidth,
                 label=r'$\widehat{\frac{\partial \mathcal{L}}{\partial w}}_t = \hat x_t \hat e_t$')
        ax.fill_between(t[ix_xlast:ix], 0., past_top_area,
                        label=r'$\widehat{\frac{\partial \mathcal{L}}{\partial w}}_{t,past}$',
                        **past_fill_settings)
        ax.fill_between(t[ix:], future_bottom_area, future_top_area,
                        label=r'$\widehat{\frac{\partial \mathcal{L}}{\partial w}}_{t,future}$',
                        **future_fill_settings)
        plt.axhline(0, color='k', linewidth=2)
        plt.legend(loc='upper left', ncol=2)
        ax0.set_xlim(0, n_samples * 3 / 4)
        ax1.set_xlim(0, n_samples * 3 / 4)
        ax2.set_xlim(0, n_samples * 3 / 4)
        ax0.set_ylim(-.5, 4)
        ax1.set_ylim(-.5, 4)
        # add_subplot()
        # plt.plot(np.cumsum(x*e))
        # plt.plot(w_recon[:, 0, 0])
    plt.show()
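
# One possible command-line entry point (which demos to run, and in what order, is an arbitrary
# choice here; demo_kd_too_large is omitted because its 32x32 parameter scan is slow and it
# returns data rather than plotting):
if __name__ == '__main__':
    demo_visualize_k_effects()
    demo_why_kp_explanation()
    demo_weight_update_figures()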