Example #1
0
# Boosting loop with diagnostics: every iteration runs one `vb.update`, every
# 10th iteration also refreshes the BMC model and snapshots samples/evaluations,
# then weights are pruned and the tracking lists below are extended.
for i in range(100):
    tictoc.tic()
    _ = vb.update(maxiter_nc=300,
                  lr_nc=0.01,
                  n_samples_nc=500,
                  n_samples_sn=300,
                  n_iter_sn=300)

    if ((i + 1) % 10) == 0:
        try:
            vb.update_bmcmodel(acquisition=acquisition, vreg=1e-3)
        except Exception:
            # Best-effort: a failed model update abandons this iteration.
            # `Exception` (not a bare `except`) so that KeyboardInterrupt and
            # SystemExit still stop the run.
            # NOTE(review): `continue` also skips the weight pruning and every
            # tracking append below, so the tracking lists get no entry for
            # this step — confirm that is intended.
            continue
        # Snapshot keyed by the 1-based step number.
        samples_dict[i + 1] = (vb.samples.numpy(), vb.evaluations.numpy())
        # Disabled in the original run: vb.update_full()
    vb.cutweights(1e-3)

    # Save trackings
    elapsed = tictoc.toc(printing=False)
    dmeans.append(np.linalg.norm(vb.currentq_mean.cpu() - true_mean, 2))
    dcovs.append(np.linalg.norm(vb.currentq_cov.cpu() - true_cov, 2))
    elbo_list.append(vb.evidence_lower_bound(nsamples=10000).cpu().numpy())
    step_list.append(i + 1)
    time_list.append(elapsed)

    # Current log q evaluated on `delta_x`, reshaped to a single column.
    vbp = vb.current_logq(delta_x.reshape(
        -1, 1)).cpu().flatten().numpy().astype(float)
    vbp_list.append(vbp)

    # BMC model prediction on the same points, rescaled with the stored
    # evaluation std/mean.
    prediction_np = (
        vb.bmcmodel.prediction(delta_x.reshape(-1, 1), cov="none")
        * vb.evals_std + vb.evals_mean
    ).numpy().astype(float).flatten()
    prediction_list.append(prediction_np)
    # Disabled in the original run:
    #    vb.save_distribution("%s/distrib%i"%(folder_name,i+1))
    print(vb_utils.kl_vb_bmc(vb, 1000))
Example #2
0
# --- Problem setup ---------------------------------------------------------
# Dimension of the problem.
dim = 2
# Initial samples: 20 standard-normal draws of size `dim`.
samples = torch.randn(20, dim)
# Initial mean.
mu0 = torch.zeros(dim)
# Initial covariance (a length-`dim` vector of 20.0).
cov0 = torch.ones(dim) * 20.0
# Acquisition function used when choosing new evaluations.
acquisition = "prospective"

# --- Initialize algorithm --------------------------------------------------
vb = VariationalBoosting(dim, logjoint, samples, mu0, cov0)
# Optimize the GP model.
vb.optimize_bmc_model()
# Fit the first component.
vb.update_full()

# --- Training loop ---------------------------------------------------------
for i in range(100):
    # Choose a new boosting component.
    _ = vb.update()
    # Choose a new evaluation.
    vb.update_bmcmodel(acquisition=acquisition)
    # Prune small weights.
    vb.cutweights(1e-3)
    # Every 20th iteration, update all parameters jointly.
    joint_update_due = (i + 1) % 20 == 0
    if joint_update_due:
        vb.update_full(cutoff=1e-3)

# Save the final distribution.
vb.save_distribution("finaldistrib")
#%%
import math
# Reload the distribution saved by the training script.
# NOTE(review): torch.load unpickles the file; only load files you produced.
distrib = torch.load("finaldistrib")

# Regular nplot x nplot grid over [-6, 6] in both coordinates.
nplot = 21
x = torch.linspace(-6, 6, nplot)
y = torch.linspace(-6, 6, nplot)
X, Y = torch.meshgrid(x, y)

# Flatten the grid to a (nplot * nplot, 2) batch of points, built once and
# passed to both log-density evaluations.
grid_points = torch.stack([X, Y], dim=-1).reshape(-1, 2)

# Z1: `logjoint` on the grid, shifted by -2*log(pi)
# (presumably its log-normalising constant — confirm against `logjoint`).
Z1 = logjoint(grid_points).reshape(X.shape) - 2 * math.log(math.pi)
# Z2: log-probability of the loaded distribution on the same grid.
Z2 = distrib.logprob(grid_points).reshape(X.shape)
#%%
import matplotlib.pyplot as plt