Example 1
    def callback(params, t, g):
        print("iteration {} Log likelihood {}".format(t, objective(params, t)))

        if plot_during:
            plt.cla()
            x_plot = np.reshape(np.linspace(-8, 8, 400), (400, 1))
            pred_mean, pred_cov = predict(params, X, y, x_plot)  # shapes [N_plot], [N_plot, N_plot]
            std = np.sqrt(np.diag(pred_cov))  # shape [N_plot]
            ax.plot(x_plot, pred_mean, 'b')
            ax.fill_between(x_plot.ravel(),
                            pred_mean - 1.96*std,
                            pred_mean + 1.96*std,
                            color=sns.xkcd_rgb["sky blue"])

            # Show sampled functions from posterior.
            sf = mvnorm(pred_mean, pred_cov, size=5)  # shape = [samples, N_plot]
            ax.plot(x_plot, sf.T)

            ax.plot(X, y, 'k.')
            ax.set_ylim([-2, 3])
            ax.set_xticks([])
            ax.set_yticks([])
            plt.draw()
            plt.pause(1.0/60.0)
            if t == 1:
                D = X, y[:, None]
                p = sample_f(params, X, y, x_plot, samples)
                plotting.plot_deciles(x_plot.ravel(), p, D, save_dir, plot="gp")
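The `predict` helper used above is defined elsewhere in the script. A minimal sketch of what it presumably computes, assuming an RBF kernel and a parameter vector of the form `params = [log_noise, log_amplitude, log_lengthscale]` (the repository's actual parameterization may differ; `mvnorm` above is presumably an alias for `numpy.random.multivariate_normal`):

import autograd.numpy as np

def rbf_kernel(kernel_params, x, xp):
    # Squared-exponential kernel; x and xp have shape [n, d].
    output_scale = np.exp(kernel_params[0])
    lengthscale = np.exp(kernel_params[1])
    diffs = (x[:, None, :] - xp[None, :, :]) / lengthscale
    return output_scale * np.exp(-0.5 * np.sum(diffs ** 2, axis=2))

def gp_predict(params, X, y, x_star, jitter=1e-6):
    # Standard GP conditional: mean = K_*n (K_nn + noise I)^-1 y,
    #                          cov  = K_** - K_*n (K_nn + noise I)^-1 K_n*.
    noise_var = np.exp(params[0]) + jitter
    K_nn = rbf_kernel(params[1:], X, X) + noise_var * np.eye(len(X))
    K_sn = rbf_kernel(params[1:], x_star, X)
    K_ss = rbf_kernel(params[1:], x_star, x_star)
    pred_mean = np.dot(K_sn, np.linalg.solve(K_nn, y))
    pred_cov = K_ss - np.dot(K_sn, np.linalg.solve(K_nn, K_sn.T))
    return pred_mean, pred_cov

`mvnorm(pred_mean, pred_cov, size=5)` then draws five functions jointly from this posterior for plotting.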
Example 2
    def callback(params, t, g, objective):
        # Sample functions from posterior f ~ p(f|phi) or p(f|varphi)
        N_samples = 5
        plot_inputs = np.linspace(-8, 8, num=400)
        f_bnn = sample_bnn(plot_inputs, N_samples, params)

        # Plot data and functions.
        if plot_during_:
            plt.cla()
            ax.plot(inputs.ravel(), targets.ravel(), 'k.')
            ax.plot(plot_inputs, f_bnn, color='r')
            ax.set_title("fitting to toy data")
            ax.set_ylim([-5, 5])
            plt.draw()
            plt.pause(1.0 / 60.0)

        if t > 25:
            D = (inputs.ravel(), targets.ravel())
            plotting.plot_deciles(plot_inputs,
                                  f_bnn,
                                  D,
                                  save_dir + "iter {}".format(t),
                                  plot="bnn")

        print("Iteration {} | vlb {}".format(t, -objective(params, t)))
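`sample_bnn` comes from the repository's BNN module, and its signature even differs between examples (`sample_bnn(x, n_samples, params)` here versus `sample_bnn(params, x, n_samples, arch, act)` in Examples 4 and 5). A minimal sketch of the form used here, assuming a fully factorized Gaussian variational posterior over the weights and a fixed `[1, 20, 20, 1]` architecture; names, defaults, and shapes are illustrative, not the repository's actual API:

import autograd.numpy as np
import autograd.numpy.random as npr

def bnn_predict(weights, x, arch, act=np.tanh):
    # weights: [n_samples, n_weights], x: [n_data, 1] -> outputs [n_samples, n_data].
    inputs = np.tile(x[None, :, :], (weights.shape[0], 1, 1))
    idx = 0
    for m, n in zip(arch[:-1], arch[1:]):
        W = np.reshape(weights[:, idx:idx + m * n], (-1, m, n))
        idx += m * n
        b = np.reshape(weights[:, idx:idx + n], (-1, 1, n))
        idx += n
        outputs = np.einsum('smi,sin->smn', inputs, W) + b
        inputs = act(outputs)
    return outputs[:, :, 0]  # final layer is left linear

def sample_bnn(x, n_samples, params, arch=[1, 20, 20, 1]):
    # x: 1-D array of inputs. Draw weights w ~ q(w) = N(mean, diag(exp(log_std)^2))
    # and evaluate each sampled network at x; returns [n_data, n_samples],
    # so each column is one sampled function (as the plotting code above expects).
    mean, log_std = params
    ws = mean + np.exp(log_std) * npr.randn(n_samples, len(mean))
    return bnn_predict(ws, x[:, None], arch).T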
Example 3
        def save(params, t):
            D = (inputs.ravel(), targets.ravel())
            x_plot = np.linspace(-8, 8, num=400)
            save_title = "exp-" + str(exp_num) + "iter " + str(t)

            # predictions from posterior of bnn
            p = sample_bnn(x_plot, 5, params)

            save_dir = os.path.join(os.getcwd(), 'plots', 'gpp-bnn', save_title + data)
            plotting.plot_deciles(x_plot, p, D, save_dir, plot="gpp")
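`plotting.plot_deciles` is the repository's own helper and is not shown in these examples. As a rough stand-in, it presumably shades nested percentile bands of the sampled functions around their mean, overlays the training data, and saves the figure. A minimal sketch, assuming `fs` has shape `[n_plot, n_samples]` and that `save_path` is a path prefix (the real helper's file naming and styling are unknown):

import numpy as np
import matplotlib.pyplot as plt

def plot_deciles(x, fs, data, save_path, plot="gpp"):
    # fs: [n_plot, n_samples] sampled functions evaluated at x.
    x_data, y_data = data
    fig, ax = plt.subplots()
    for lo in [5, 15, 25, 35, 45]:
        lower = np.percentile(fs, lo, axis=1)
        upper = np.percentile(fs, 100 - lo, axis=1)
        ax.fill_between(x, lower, upper, color='steelblue', alpha=0.15)
    ax.plot(x, np.mean(fs, axis=1), 'b')   # mean of the sampled functions
    ax.plot(x_data, y_data, 'k.')          # training data
    fig.savefig(save_path + "_" + plot + ".pdf", bbox_inches='tight')
    plt.close(fig)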
Example 4
    def callback(params, t, g):
        plot_inputs = np.linspace(-8, 8, num=400)[:, None]
        f_bnn = sample_bnn(params, plot_inputs, 5, arch, act)

        # Plot data and functions.
        p.plot_iter(ax, inputs, plot_inputs, targets, f_bnn)
        print("ITER {} | LOSS {}".format(t, -loss(params, t)))
        if t > 50:
            D = inputs, targets
            x_plot = np.reshape(np.linspace(-8, 8, 400), (400, 1))
            pred = sample_bnn(params, x_plot, 5, arch, act)
            p.plot_deciles(x_plot.ravel(), pred.T, D, str(t) + "bnnpostfullprior", plot="gpp")
Example 5
def train_bnn(data='expx', n_data=20, n_samples=5, arch=[1,20,20,1],
              prior_params=None, prior_type=None, act='rbf',
              iters=65, lr=0.07, plot=True, save=False):

    if isinstance(data, str):
        inputs, targets = build_toy_dataset()
    else:
        inputs, targets = data

    if plot: fig, ax = p.setup_plot()

    def loss(params, t):
        return vlb_objective(params, inputs, targets, arch, n_samples, act=act,
                             prior_params=prior_params, prior_type=prior_type)

    def callback(params, t, g):
        plot_inputs = np.linspace(-8, 8, num=400)[:, None]
        f_bnn = sample_bnn(params, plot_inputs, 5, arch, act)

        # Plot data and functions (ax is only created when plot=True).
        if plot:
            p.plot_iter(ax, inputs, plot_inputs, targets, f_bnn)
        print("ITER {} | LOSS {}".format(t, -loss(params, t)))
        if t > 50:
            D = inputs, targets
            x_plot = np.reshape(np.linspace(-8, 8, 400), (400, 1))
            pred = sample_bnn(params, x_plot, 5, arch, act)
            p.plot_deciles(x_plot.ravel(), pred.T, D, str(t) + "bnnpostfullprior", plot="gpp")

    var_params = adam(grad(loss), init_var_params(arch),
                      step_size=lr, num_iters=iters, callback=callback)


    D = inputs, targets
    x_plot = np.reshape(np.linspace(-8, 8, 400), (400, 1))
    pred = sample_bnn(var_params, x_plot, 5, arch, act)
    p.plot_deciles(x_plot.ravel(), pred.T, D, "bnnpostfullprior", plot="gpp")
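`vlb_objective` and `init_var_params` are imported from the repository's BNN module. A minimal sketch of the variational lower bound being minimized here, assuming the mean-field Gaussian posterior from the sketch after Example 2 (its `bnn_predict` helper is reused), a standard normal weight prior when `prior_type` is None, and a fixed observation noise; the GP-derived priors selected via `prior_params`/`prior_type` are not reproduced. Note that `act` is treated as a callable here, whereas the repository passes an activation name such as 'rbf':

import autograd.numpy as np
import autograd.numpy.random as npr
from autograd.scipy.stats import norm

def num_weights(arch):
    return sum((m + 1) * n for m, n in zip(arch[:-1], arch[1:]))

def init_var_params(arch, init_log_std=-5.0):
    # One Gaussian mean and log-std per network weight.
    n_w = num_weights(arch)
    return 0.1 * npr.randn(n_w), init_log_std * np.ones(n_w)

def gaussian_entropy(log_std):
    return 0.5 * len(log_std) * (1.0 + np.log(2 * np.pi)) + np.sum(log_std)

def vlb_objective(params, x, y, arch, n_samples, act=np.tanh,
                  prior_params=None, prior_type=None, noise_std=0.1):
    # Negative ELBO: -(E_q[log p(y|w) + log p(w)] + H[q]), estimated with
    # n_samples Monte Carlo draws from q(w) = N(mean, diag(exp(log_std)^2)).
    mean, log_std = params
    ws = mean + np.exp(log_std) * npr.randn(n_samples, len(mean))
    preds = bnn_predict(ws, x, arch, act)                 # [n_samples, n_data]
    log_lik = np.sum(norm.logpdf(y.ravel(), preds, noise_std), axis=1)
    log_prior = np.sum(norm.logpdf(ws, 0.0, 1.0), axis=1)
    return -(np.mean(log_lik + log_prior) + gaussian_entropy(log_std))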
Example 6
            ax.set_ylim([-2, 3])
            ax.set_xticks([])
            ax.set_yticks([])
            plt.draw()
            plt.pause(1.0/60.0)
            if t == 1:
                D = X, y[:, None]
                p = sample_f(params, X, y, x_plot, samples)
                plotting.plot_deciles(x_plot.ravel(), p, D, save_dir, plot="gp")


    # Initialize covariance parameters
    rs = npr.RandomState(0)
    init_params = 0.1 * rs.randn(num_params)
    cov_params = adam(grad(objective), init_params,
                      step_size=0.1, num_iters=iters, callback=callback)

    if save_plots:
        D = X, y[:, None]
        x_plot = np.reshape(np.linspace(-8, 8, 400), (400, 1))
        p = sample_f(cov_params, X, y, x_plot, samples)
        print(p.shape)
        plotting.plot_deciles(x_plot.ravel(), p, D, save_dir, plot="gp")
Example 7
                            pred_mean - 1.96 * std,
                            pred_mean + 1.96 * std,
                            color='b')

            # Show sampled functions from posterior.
            sf = mvnorm(pred_mean, pred_cov, size=5)  # shape [n_samples, N_plot]
            ax.plot(x_plot, sf.T)

            ax.plot(X, y, 'k.')
            ax.set_ylim([-2, 3])
            ax.set_xticks([])
            ax.set_yticks([])
            plt.draw()
            plt.pause(1.0 / 60.0)

    # Initialize covariance parameters
    rs = npr.RandomState(0)
    init_params = 0.1 * rs.randn(num_params)
    cov_params = adam(grad(objective),
                      init_params,
                      step_size=0.01,
                      num_iters=iters,
                      callback=callback)

    if save_plots:
        D = X, y[:, None]
        x_plot = np.reshape(np.linspace(-8, 8, 400), (400, 1))
        p = sample_functions(cov_params, X, y, x_plot, samples)
        print(p.shape)
        plotting.plot_deciles(x_plot.ravel(), p, D, "gppost", plot="gp")
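The `objective` optimized in Examples 6 and 7 is not shown. It is presumably the negative log marginal likelihood of the GP; a plausible sketch, reusing the `rbf_kernel` from the sketch after Example 1 and the same `[log_noise, log_amplitude, log_lengthscale]` parameterization (so `num_params` would be 3), with `X` and `y` taken from the enclosing scope as in the snippets above:

import autograd.numpy as np

def log_marginal_likelihood(params, X, y, jitter=1e-6):
    # log p(y | X, params) for a GP with an RBF kernel and Gaussian noise.
    noise_var = np.exp(params[0]) + jitter
    K = rbf_kernel(params[1:], X, X) + noise_var * np.eye(len(X))
    L = np.linalg.cholesky(K)
    alpha = np.linalg.solve(K, y)
    return (-0.5 * np.dot(y, alpha)
            - np.sum(np.log(np.diag(L)))
            - 0.5 * len(X) * np.log(2 * np.pi))

def objective(params, t):
    # adam minimizes, so negate the marginal likelihood; t is unused.
    return -log_marginal_likelihood(params, X, y)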
Example 8
            print("Iteration {} | vlb {}".format(t, -objective(params, t)))

        callback_vlb = lambda params, t, g: callback(params, t, g, vlb)

        init_var_params = init_bnn_params(num_weights)

        var_params = adam(grad(vlb), init_var_params,
                          step_size=0.1, num_iters=iters_2, callback=callback_vlb)

    # PLOT STUFF BELOW HERE


    if save_plot:
        N_data = 400
        N_samples = 5
        D = (inputs.ravel(), targets.ravel())
        x_plot = np.linspace(-8, 8, num=N_data)
        save_title = "exp-" + str(exp_num)


        # predictions from posterior of bnn
        p = sample_bnn(x_plot, N_samples, var_params)

        save_dir = os.path.join(os.getcwd(), 'plots', 'gpp-bnn', save_title+data)
        plotting.plot_deciles(x_plot, p, D, save_dir, plot="gpp")