def plot_samples(params, x, n_samples, layer_sizes, act, ker, save=None):
    """Plot GP prior samples next to two sets of BNN function samples.

    Three groups of functions are evaluated at ``x`` and handed to
    ``p.plot_priors`` (transposed) in this order:

    1. draws from ``sample_gpp`` using kernel ``ker``;
    2. BNN outputs whose weights come from ``sample_normal`` applied to
       ``(zeros, ones)`` over the weight vector (described in the original
       code as a standard-normal BNN prior);
    3. BNN outputs whose weights come from ``sample_normal(params, ...)``.

    Args:
        params: parameters of the weight distribution, forwarded to
            ``sample_normal``.
        x: inputs at which every function is evaluated.
        n_samples: number of functions drawn for each group.
        layer_sizes: BNN architecture forwarded to ``bnn_predict``.
        act: BNN activation name.
        ker: kernel forwarded to ``sample_gpp``.
        save: optional save target forwarded to ``p.plot_priors``.
    """
    # Functions from the distribution described by ``params``.
    learned_weights = sample_normal(params, n_samples)
    learned_fns = bnn_predict(learned_weights, x, layer_sizes, act)[:, :, 0]  # [ns, nd]

    # Functions from the GP prior.
    gp_fns = sample_gpp(x, n_samples, ker)

    # Reference BNN prior over a weight vector of the same length.
    # NOTE(review): if sample_normal reads its second entry as a log std,
    # ``ones`` means std = e rather than 1 -- confirm against sample_normal.
    n_weights = learned_weights.shape[1]
    reference = (np.zeros(n_weights), np.ones(n_weights))
    reference_weights = sample_normal(reference, n_samples)
    reference_fns = bnn_predict(reference_weights, x, layer_sizes, act)[:, :, 0]

    p.plot_priors(x, (gp_fns.T, reference_fns.T, learned_fns.T), save)
def plot_samples(params, x, layer_sizes, n_samples=1, act='tanh', save=None):
    """Overlay GP prior samples (green) and BNN samples (red) on one plot.

    BNN weights are drawn with ``sample_full_normal(params, n_samples)`` and
    evaluated at ``x`` with ``bnn_predict``. GP functions come from
    ``sample_gpp(x, n_samples)`` with its default kernel.

    Args:
        params: parameters of the weight distribution, forwarded to
            ``sample_full_normal``.
        x: inputs at which the functions are evaluated (flattened for
            plotting).
        layer_sizes: BNN architecture forwarded to ``bnn_predict``.
        n_samples: number of functions drawn from each source.
        act: BNN activation name.
        save: if not None, the figure is saved there before being shown.
    """
    ws = sample_full_normal(params, n_samples)
    fnn = bnn_predict(ws, x, layer_sizes, act)[:, :, 0]  # [ns, nd]
    fgp = sample_gpp(x, n_samples)

    fig = plt.figure(figsize=(12, 8), facecolor='white')
    # Draw both sample sets on a single axes so they share one y-scale.
    # The previous code stacked two add_subplot(111) axes. Recent matplotlib
    # creates a new Axes per call, and each one autoscaled independently,
    # which made the GP and BNN amplitudes incomparable.
    ax = fig.add_subplot(111, frameon=False)
    ax.plot(x.ravel(), fgp.T, color='green')
    ax.plot(x.ravel(), fnn.T, color='red')
    if save is not None:
        plt.savefig(save)
    plt.show()
def train(n_data=50, n_data_test=100, n_functions=500, nn_arch=None,
          hyper_arch=None, act='rbf', ker='per', lr=0.01, iters=200,
          exp=1, run=1, feed_x=True, plot=True, save=False):
    """Train a hypernetwork that maps sampled datasets to BNN weights.

    ``n_functions`` datasets are drawn with ``sample_data``. A network with
    input width ``2 * n_data`` and output width equal to the number of BNN
    weights is then fitted with Adam on ``hyper_loss``. Afterwards, weights
    are generated for 10000 fresh datasets and plotted in weight space and
    function space.

    Args:
        n_data: number of points per sampled function.
        n_data_test: unused here; kept for call compatibility.
        n_functions: number of training datasets drawn with ``sample_data``.
        nn_arch: layer sizes of the target BNN; defaults to ``[1, 15, 15, 1]``.
        hyper_arch: hidden layer sizes of the hypernetwork; defaults to
            ``[20]``. The input width ``2 * n_data`` and the output width
            (number of BNN weights) are added automatically.
        act: activation name passed to ``hyper_loss``, ``hyper_predict``,
            ``nn_predict`` and ``bnn_predict``.
        ker: kernel name forwarded to ``sample_data``.
        lr: Adam step size.
        iters: number of Adam iterations.
        exp, run, feed_x, save: unused here; kept for call compatibility.
        plot: if true, plot a prediction on a freshly sampled dataset at
            every iteration.

    Returns:
        ``(ws, var_params)``: the weights generated for the 10000 evaluation
        datasets and the trained hypernetwork parameters.
    """
    # Fresh lists per call instead of shared mutable defaults.
    if nn_arch is None:
        nn_arch = [1, 15, 15, 1]
    if hyper_arch is None:
        hyper_arch = [20]

    _, num_weights = shapes_and_num(nn_arch)
    hyper_arch = [2 * n_data] + list(hyper_arch) + [num_weights]
    xs, ys, xys = sample_data(n_functions, n_data, ker=ker)
    save_name = get_save_name(n_data, n_functions, act, ker, nn_arch, hyper_arch)

    if plot:
        _, ax = setup_plot()

    def objective(params, t):
        return hyper_loss(params, xs, ys, xys, nn_arch, act)

    def callback(params, t, g):
        # Monitor the model on a single freshly sampled dataset.
        x, y, xy = sample_data(1, n_data, ker=ker)
        preds = hyper_predict(params, x, xy, nn_arch, act)  # [1, nd]
        if plot:
            p.plot_iter(ax, x[0], x[0], y, preds)
        # Sample variance of the targets minus that of the predictions.
        # This was previously printed as a hard-coded 1.
        cov_diff = np.cov(y.ravel()) - np.cov(preds.ravel())
        print("ITER {} | OBJ {} COV DIFF {}".format(t, objective(params, t), cov_diff))

    var_params = adam(grad(objective), init_random_params(hyper_arch),
                      step_size=lr, num_iters=iters, callback=callback)

    # Generate weights for many fresh datasets and inspect the resulting
    # weight distribution and the functions those weights produce.
    eval_xs, _, eval_xys = sample_data(10000, n_data, ker=ker)
    ws = nn_predict(var_params, eval_xys, act)  # [ns, nw]
    fs = bnn_predict(ws, eval_xs, nn_arch, act)[:, :, 0]  # [nf, nd]
    p.plot_weights(ws, save_name)
    p.plot_weights_function_space(ws, fs, save_name)
    return ws, var_params
def sample_random_functions(x, n_samples=10, arch=None, act='tanh'):
    """Evaluate BNNs with randomly drawn weights at ``x``.

    One weight vector per sample is drawn with ``rs.randn``, and each is
    evaluated with ``bnn_predict``.

    Args:
        x: inputs forwarded to ``bnn_predict``.
        n_samples: number of weight vectors (functions) to draw.
        arch: layer sizes; defaults to ``[1, 1]``.
        act: activation name forwarded to ``bnn_predict``.

    Returns:
        ``bnn_predict(w, x, arch, act)[:, :, 0]`` -- one row per sampled
        weight vector (``[ns, nd]`` per the original comment).
    """
    if arch is None:
        arch = [1, 1]  # fresh list per call instead of a shared mutable default
    _, n_weights = shapes_and_num(arch)
    w = rs.randn(n_samples, n_weights)
    return bnn_predict(w, x, arch, act)[:, :, 0]  # [ns, nd]
def callback(params, t, g):
    """Optimiser hook: optionally plot the current fit, then log the objective.

    Uses the free names ``x``, ``y``, ``nn_arch``, ``act``, ``plot``, ``ax``,
    ``plot_iter`` and ``objective`` from the surrounding scope. ``g``
    (presumably the gradient supplied by the optimiser) is not used.
    """
    current_fit = bnn_predict(params, x, nn_arch, act)[:, :, 0]  # [1, nd]
    if plot:
        plot_iter(ax, x.ravel(), x.ravel(), y, current_fit[0])
    print("ITER {} | OBJ {}".format(t, objective(params, t)))
def total_loss(weights, x, y, arch, act):
    """Return the negated ``log_gaussian`` score of ``y`` under the BNN outputs.

    The network is evaluated with ``bnn_predict`` at ``x``, and output
    channel 0 of the last axis is compared to ``y`` via
    ``log_gaussian(..., 1)``. The result is negated so that minimising the
    loss maximises the score.
    """
    network_out = bnn_predict(weights, x, arch, act)
    score = log_gaussian(network_out[:, :, 0], y, 1)
    return -score
def hyper_predict(params, x, xy, nn_arch, nn_act):
    """Predict function values using BNN weights generated by a hypernetwork.

    ``nn_predict`` maps each row of ``xy`` to a weight vector, and
    ``bnn_predict`` evaluates those weights at ``x``. The same activation
    ``nn_act`` is used for both networks.

    Shapes per the original comments: ``xy`` is ``[nf, 2*nd]``, the generated
    weights are ``[nf, nw]``, and the result is ``[nf, nd]``.
    """
    generated = nn_predict(params, xy, nn_act)  # [nf, nw]
    outputs = bnn_predict(generated, x, nn_arch, nn_act)
    return outputs[:, :, 0]  # [nf, nd]
def hyper_predict(params, x, y, nn_arch, nn_act, hyper_act='relu'):
    """Predict function values using BNN weights generated from observations ``y``.

    ``nn_predict`` (the hypernetwork) maps each row of ``y`` to a weight
    vector, and ``bnn_predict`` evaluates those weights at ``x``.

    Args:
        params: hypernetwork parameters.
        x: inputs at which the generated BNNs are evaluated.
        y: hypernetwork inputs (``[nf, nd]`` per the original comment).
        nn_arch: BNN layer sizes.
        nn_act: BNN activation name.
        hyper_act: hypernetwork activation name. Defaults to ``'relu'``, the
            value that was previously hard-coded, so existing calls behave
            the same.

    Returns:
        ``bnn_predict(...)[:, :, 0]`` (``[nf, nd]`` per the original comment).
    """
    weights = nn_predict(params, y, hyper_act)  # [nf, nw]
    return bnn_predict(weights, x, nn_arch, nn_act)[:, :, 0]  # [nf, nd]
cd = np.cov(y.ravel()) - np.cov(preds.ravel()) print("ITER {} | OBJ {} COV DIFF {}".format(t, objective(params, t), cd)) var_params = adam(grad(objective), init_random_params(hyper_arch), step_size=0.005, num_iters=200, callback=callback) xtest = np.linspace(-10, 10, n_data_test).reshape(n_data_test, 1) fgps = sample_gpp(x, n_samples=500, kernel=ker) #fgps = sample_function(x, 500) ws = nn_predict(var_params, fgps, "relu") # [ns, nw] fs = bnn_predict(ws, x, nn_arch, act)[:, :, 0] #p.plot_weights_function_space(ws, fs, save_name) moments = get_moments(ws, full_cov=True) #p.plot_heatmap(moments, "heatmap"+save_name) # PLOT HYPERNET #fgp = sample_gpp(x, n_samples=2, kernel=ker) #fnns = hyper_predict(var_params, xtest,fgp, nn_arch, act) #p.plot_fs(xtest, fnns, x, fgp, save_name) #plot_heatmap(moments,"Cov-heatmap"+save_name+'.pdf') #plot_samples(moments, xtest, 5, nn_arch, act=act, ker=ker, save = save_name+'.pdf') #p.plot_dark_contour(ws) #p.plot_weights(moments, num_weights, save_name) moments = fit_one_gmm(ws)