Example 1: gpnet_nonconj
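# Non-conjugate variant: directly optimizes a proximal lower bound built from
# expected log-likelihood, prior cross-entropy, proximity and entropy terms.
# Requires torch, gpytorch and fastprogress (master_bar / progress_bar);
# FirstOrder, DeepKernel, RFExpansion, entropy, cross_entropy, kl_div,
# expected_log_prob and evaluate are project-local helpers (see the sketch
# after this example).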
def gpnet_nonconj(args, dataloader, test_x, prior_gp):
    N = len(dataloader.dataset)
    x_dim = 1
    prior_gp.train()

    if args.net == 'tangent':
        kernel = prior_gp.covar_module
        bnn_prev = FirstOrder([x_dim] + [args.n_hidden] * args.n_layer,
                              mvn=False)
        bnn = FirstOrder([x_dim] + [args.n_hidden] * args.n_layer, mvn=True)
    elif args.net == 'deep':
        kernel = prior_gp.covar_module
        bnn_prev = DeepKernel([x_dim] + [args.n_hidden] * args.n_layer,
                              mvn=False)
        bnn = DeepKernel([x_dim] + [args.n_hidden] * args.n_layer, mvn=True)
    elif args.net == 'rf':
        kernel = ScaleKernel(RBFKernel())
        kernel_prev = ScaleKernel(RBFKernel())
        bnn_prev = RFExpansion(x_dim,
                               args.n_hidden,
                               kernel_prev,
                               mvn=False,
                               fix_ls=args.fix_rf_ls,
                               residual=args.residual)
        bnn = RFExpansion(x_dim,
                          args.n_hidden,
                          kernel,
                          fix_ls=args.fix_rf_ls,
                          residual=args.residual)
        bnn_prev.load_state_dict(bnn.state_dict())
    else:
        raise NotImplementedError('Unknown inference net')

    bnn = bnn.to(args.device)
    bnn_prev = bnn_prev.to(args.device)
    prior_gp = prior_gp.to(args.device)

    infer_gpnet_optimizer = optim.Adam(bnn.parameters(), lr=args.learning_rate)
    hyper_opt_optimizer = optim.Adam(prior_gp.parameters(), lr=args.hyper_rate)

    x_min, x_max = dataloader.dataset.range

    bnn.train()
    bnn_prev.train()
    prior_gp.train()

    mb = master_bar(range(1, args.n_iters + 1))

    for t in mb:
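        # Decaying step size beta_t = beta0 / (1 + gamma * sqrt(t - 1)) for
        # the proximal update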
        beta = args.beta0 * 1. / (1. + args.gamma * math.sqrt(t - 1))
        dl_bar = progress_bar(dataloader, parent=mb)
        for x, y in dl_bar:
            n = x.size(0)
            x, y = x.to(args.device), y.to(args.device)
            # Measurement points sampled uniformly over the input range
            x_star = torch.Tensor(args.measurement_size,
                                  x_dim).uniform_(x_min, x_max).to(args.device)
            xx = torch.cat([x, x_star], 0)

            # inference net
            infer_gpnet_optimizer.zero_grad()
            hyper_opt_optimizer.zero_grad()

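            # bnn (mvn=True) returns a MultivariateNormal; bnn_prev (mvn=False)
            # returns (mean, covariance) tensors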
            qff = bnn(xx)
            qff_mean_prev, K_prox = bnn_prev(xx)
            qf_mean, qf_var = bnn(x, full_cov=False)

            # Prior p(f) over batch + measurement points (Eq.(6)); enters the
            # bound tempered by beta as in Eq.(8)
            K_prior = kernel(xx, xx).add_jitter(1e-6)
            pff = MultivariateNormal(
                torch.zeros(xx.size(0)).to(args.device), K_prior)

            f_term = torch.sum(
                expected_log_prob(prior_gp.likelihood, qf_mean, qf_var,
                                  y.squeeze(-1)))
            f_term *= N / x.size(0) * beta

            prior_term = -beta * cross_entropy(qff, pff)

            qff_prev = MultivariateNormal(qff_mean_prev, K_prox)
            prox_term = -(1 - beta) * cross_entropy(qff, qff_prev)

            entropy_term = entropy(qff)

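            # Lower bound: beta * (N/n * sum E_q[log p(y|f)] + E_q[log p(f)])
            #   + (1 - beta) * E_q[log q_prev(f)] + H[q]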
            lower_bound = f_term + prior_term + prox_term + entropy_term
            loss = -lower_bound / n

            loss.backward(retain_graph=True)

            infer_gpnet_optimizer.step()

            # Hyper-parameter update
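            # Maximize the collapsed bound under the frozen previous posterior:
            # E_{q_prev}[log p(y|f)] - KL(q_prev || p(f))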
            Kn_prior = K_prior[:n, :n]
            pf = MultivariateNormal(torch.zeros(n).to(args.device), Kn_prior)
            Kn_prox = K_prox[:n, :n]
            qf_prev_mean = qff_mean_prev[:n]
            qf_prev_var = torch.diagonal(Kn_prox)
            qf_prev = MultivariateNormal(qf_prev_mean, Kn_prox)
            hyper_obj = expected_log_prob(
                prior_gp.likelihood, qf_prev_mean, qf_prev_var,
                y.squeeze(-1)).sum() - kl_div(qf_prev, pf)
            hyper_obj = -hyper_obj
            hyper_obj.backward()
            hyper_opt_optimizer.step()

        bnn_prev.load_state_dict(bnn.state_dict())
        if args.net == 'rf':
            kernel_prev.load_state_dict(kernel.state_dict())
        if t % 50 == 0:
            mb.write("Iter {}/{}, lower_bound = {:.4f}, noise = {:.4f}".format(
                t, args.n_iters, lower_bound.item(),
                prior_gp.likelihood.noise.item()))
    test_x = test_x.to(args.device)
    test_stats = evaluate(bnn, prior_gp.likelihood, test_x,
                          args.net == 'tangent')

    return test_stats
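Both examples call Gaussian helper functions (entropy, cross_entropy, kl_div,
expected_log_prob) that are project-local and not shown above. A minimal
sketch of what they might compute, assuming MultivariateNormal from
torch.distributions and a Gaussian likelihood exposing a scalar noise
variance (the gpytorch distributions used above subclass the torch ones, so
kl_divergence applies):

import math

import torch
from torch.distributions import MultivariateNormal, kl_divergence


def entropy(q: MultivariateNormal) -> torch.Tensor:
    # Differential entropy H[q] of a multivariate Gaussian.
    return q.entropy()


def kl_div(q: MultivariateNormal, p: MultivariateNormal) -> torch.Tensor:
    # KL(q || p), closed form for Gaussians.
    return kl_divergence(q, p)


def cross_entropy(q: MultivariateNormal, p: MultivariateNormal) -> torch.Tensor:
    # H(q, p) = -E_q[log p(f)] = KL(q || p) + H[q].
    return kl_divergence(q, p) + q.entropy()


def expected_log_prob(likelihood, f_mean, f_var, y):
    # Pointwise E_{N(f; f_mean, f_var)}[log N(y; f, noise)]:
    #   -0.5 * log(2 * pi * noise) - ((y - f_mean)^2 + f_var) / (2 * noise)
    noise = likelihood.noise
    return -0.5 * (math.log(2 * math.pi) + noise.log()
                   + ((y - f_mean) ** 2 + f_var) / noise)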
Example 2: gpnet
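# Conjugate variant: adapts the prior with the previous posterior (Eq.(8)),
# forms the closed-form Gaussian target at the measurement points (Eq.(2)),
# and fits the inference net by minimizing the KL divergence of Eq.(11).
# Also uses the project-local helpers product_gaussians, split_gaussian and
# cholesky_solve (see the sketch after this example).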
def gpnet(args, dataloader, test_x, prior_gp):
    N = len(dataloader.dataset)
    x_dim = 1
    prior_gp.train()

    if args.net == 'tangent':
        kernel = prior_gp.covar_module
        bnn_prev = FirstOrder([x_dim] + [args.n_hidden] * args.n_layer,
                              mvn=False)
        bnn = FirstOrder([x_dim] + [args.n_hidden] * args.n_layer, mvn=True)
    elif args.net == 'deep':
        kernel = prior_gp.covar_module
        bnn_prev = DeepKernel([x_dim] + [args.n_hidden] * args.n_layer,
                              mvn=False)
        bnn = DeepKernel([x_dim] + [args.n_hidden] * args.n_layer, mvn=True)
    elif args.net == 'rf':
        kernel = ScaleKernel(RBFKernel())
        kernel_prev = ScaleKernel(RBFKernel())
        bnn_prev = RFExpansion(x_dim,
                               args.n_hidden,
                               kernel_prev,
                               mvn=False,
                               fix_ls=args.fix_rf_ls,
                               residual=args.residual)
        bnn = RFExpansion(x_dim,
                          args.n_hidden,
                          kernel,
                          fix_ls=args.fix_rf_ls,
                          residual=args.residual)
        bnn_prev.load_state_dict(bnn.state_dict())
    else:
        raise NotImplementedError('Unknown inference net')
    bnn = bnn.to(args.device)
    bnn_prev = bnn_prev.to(args.device)
    prior_gp = prior_gp.to(args.device)

    infer_gpnet_optimizer = optim.Adam(bnn.parameters(), lr=args.learning_rate)
    hyper_opt_optimizer = optim.Adam(prior_gp.parameters(), lr=args.hyper_rate)

    x_min, x_max = dataloader.dataset.range

    bnn.train()
    bnn_prev.train()
    prior_gp.train()

    mb = master_bar(range(1, args.n_iters + 1))

    for t in mb:
        # Decaying step size beta_t = beta0 / (1 + gamma * sqrt(t - 1))
        beta = args.beta0 * 1. / (1. + args.gamma * math.sqrt(t - 1))
        dl_bar = progress_bar(dataloader, parent=mb)
        for x, y in dl_bar:
            observed_size = x.size(0)
            x, y = x.to(args.device), y.to(args.device)
            x_star = torch.Tensor(args.measurement_size,
                                  x_dim).uniform_(x_min, x_max).to(args.device)
            # [Batch + Measurement Points x x_dims]
            xx = torch.cat([x, x_star], 0)

            infer_gpnet_optimizer.zero_grad()
            hyper_opt_optimizer.zero_grad()

            # inference net
            # Eq.(6) Prior p(f)
            # \mu_1=0, \Sigma_1
            mean_prior = torch.zeros(xx.size(0)).to(args.device)
            K_prior = kernel(xx, xx).add_jitter(1e-6)

            # Previous inference net q_{\gamma_t}(f_n, f_M) = N(\mu_2, \Sigma_2)
            # evaluated at [x_n, x_M]
            qff_mean_prev, K_prox = bnn_prev(xx)

            # Eq.(8) adapt prior; p(f)^\beta x q(f)^{1 - \beta}
            mean_adapt, K_adapt = product_gaussians(mu1=mean_prior,
                                                    sigma1=K_prior,
                                                    mu2=qff_mean_prev,
                                                    sigma2=K_prox,
                                                    beta=beta)

            # Partition the adapted joint into observed (n) and measurement
            # (M) blocks
            (mean_n, mean_m), (Knn, Knm,
                               Kmm) = split_gaussian(mean_adapt, K_adapt,
                                                     observed_size)

            # Eq.(2): K_nn + noise / (beta_t * N / n) * I
            Ky = Knn + torch.eye(observed_size).to(
                args.device) * prior_gp.likelihood.noise / (N / observed_size *
                                                            beta)
            Ky_tril = torch.cholesky(Ky)

            # Eq.(2)
            mean_target = Knm.t().mm(cholesky_solve(y - mean_n,
                                                    Ky_tril)) + mean_m
            mean_target = mean_target.squeeze(-1)
            K_target = gpytorch.add_jitter(
                Kmm - Knm.t().mm(cholesky_solve(Knm, Ky_tril)), 1e-6)
            # \hat{q}_{t+1} (f_M)
            target_pf_star = MultivariateNormal(mean_target, K_target)

            # q_\gamma (f_M)
            qf_star = bnn(x_star)

            # Eq.(11): fit the inference net by minimizing
            # KL(q_\gamma(f_M) || \hat{q}_{t+1}(f_M))
            kl_obj = kl_div(qf_star, target_pf_star).sum()

            kl_obj.backward(retain_graph=True)
            infer_gpnet_optimizer.step()

            # Hyper-parameter update
            (mean_n_prior, _), (Kn_prior, _,
                                _) = split_gaussian(mean_prior, K_prior,
                                                    observed_size)
            pf = MultivariateNormal(mean_n_prior, Kn_prior)

            (qf_prev_mean, _), (Kn_prox, _,
                                _) = split_gaussian(qff_mean_prev, K_prox,
                                                    observed_size)
            qf_prev = MultivariateNormal(qf_prev_mean, Kn_prox)

            hyper_obj = -(prior_gp.likelihood.expected_log_prob(
                y.squeeze(-1), qf_prev) - kl_div(qf_prev, pf))
            hyper_obj.backward(retain_graph=True)
            hyper_opt_optimizer.step()

            mb.child.comment = "kl_obj = {:.3f}, obs_var={:.3f}".format(
                kl_obj.item(), prior_gp.likelihood.noise.item())

        # update q_{\gamma_t} to q_{\gamma_{t+1}}
        bnn_prev.load_state_dict(bnn.state_dict())
        if args.net == 'rf':
            kernel_prev.load_state_dict(kernel.state_dict())
        if t % 50 == 0:
            mb.write("Iter {}/{}, kl_obj = {:.4f}, noise = {:.4f}".format(
                t, args.n_iters, kl_obj.item(),
                prior_gp.likelihood.noise.item()))

    test_x = test_x.to(args.device)
    test_stats = evaluate(bnn, prior_gp.likelihood, test_x,
                          args.net == 'tangent')
    return test_stats
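The adapted prior of Eq.(8) and the block partitioning rely on two more
project-local helpers. A minimal sketch of the math they implement, assuming
dense torch tensors and 1-D means (the project's real helpers may use other
shape conventions, e.g. column-vector means, and would need to evaluate lazy
gpytorch kernel tensors before inverting):

import torch


def product_gaussians(mu1, sigma1, mu2, sigma2, beta):
    # Tempered product N(mu1, sigma1)^beta * N(mu2, sigma2)^(1 - beta) is
    # Gaussian with linearly combined natural parameters:
    #   Sigma^{-1} = beta * sigma1^{-1} + (1 - beta) * sigma2^{-1}
    #   Sigma^{-1} mu = beta * sigma1^{-1} mu1 + (1 - beta) * sigma2^{-1} mu2
    prec1 = torch.inverse(sigma1)
    prec2 = torch.inverse(sigma2)
    sigma = torch.inverse(beta * prec1 + (1. - beta) * prec2)
    mu = sigma @ (beta * (prec1 @ mu1) + (1. - beta) * (prec2 @ mu2))
    return mu, sigma


def split_gaussian(mean, cov, n):
    # Partition a joint Gaussian over [f_n, f_M] into its mean and
    # covariance blocks: (mean_n, mean_M) and (K_nn, K_nM, K_MM).
    return ((mean[:n], mean[n:]),
            (cov[:n, :n], cov[:n, n:], cov[n:, n:]))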