Esempio n. 1
0
# %% Import packages

import numpy as np
import torch

from eeyore.kernels import NormalKernel
from eeyore.samplers import MetropolisHastings

from bnn_mcmc_examples.examples.mlp.noisy_xor.setting3.dataloaders import training_dataloader
from bnn_mcmc_examples.examples.mlp.noisy_xor.setting3.model import model

# %% Proposal kernel for the Metropolis-Hastings sampler
#
# Every parameter dimension gets the same proposal scale, sqrt(0.02), which
# corresponds to a proposal variance of 0.02 (assuming NormalKernel's second
# argument is a standard deviation -- confirm against eeyore). The first
# argument (presumably the kernel location) is a zero vector.

proposal_scale = np.sqrt(0.02)

kernel = NormalKernel(
    torch.zeros([model.num_params()], dtype=model.dtype),
    torch.full([model.num_params()], proposal_scale, dtype=model.dtype),
)

# %% Metropolis-Hastings sampler
#
# The chain is initialised with one draw from the model's prior and consumes
# batches from the training dataloader of this setting.

sampler = MetropolisHastings(
    model,
    kernel=kernel,
    dataloader=training_dataloader,
    theta0=model.prior.sample(),
)
Esempio n. 2
0
 def default_kernel(self, state):
     """Return a NormalKernel located at the current sample with unit scale.

     Args:
         state: mapping that holds the current chain position under the
             key ``'sample'``.

     Returns:
         ``NormalKernel(state['sample'], ones)``, where ``ones`` is a vector
         of length ``self.model.num_params()`` created with the model's dtype
         and on the model's device.
     """
     center = state['sample']
     num_dims = self.model.num_params()
     unit_scale = torch.ones(
         num_dims,
         dtype=self.model.dtype,
         device=self.model.device,
     )
     return NormalKernel(center, unit_scale)
Esempio n. 3
0
# Logistic regression on 4 inputs without a bias term; the loss is binary
# cross-entropy summed (not averaged) over the batch.
hparams = logistic_regression.Hyperparameters(input_size=4, bias=False)

model = logistic_regression.LogisticRegression(
    loss=lambda x, y: binary_cross_entropy(x, y, reduction='sum'),
    hparams=hparams,
    dtype=torch.float32,
)

# Prior over all parameters: zero location and scale sqrt(3) in every
# dimension. If Normal is torch.distributions.Normal (presumably -- confirm
# the import), the scale is a standard deviation, i.e. prior variance 3.
model.prior = Normal(
    torch.zeros(model.num_params(), dtype=model.dtype),
    torch.full([model.num_params()], 3.0, dtype=model.dtype).sqrt(),
)

# %% Metropolis-Hastings sampler with a constant proposal scale of 0.5

proposal_scale = 0.5
kernel = NormalKernel(
    torch.zeros(model.num_params(), dtype=model.dtype),
    torch.full([model.num_params()], proposal_scale, dtype=model.dtype),
)

# The chain starts from a single draw of the prior.
sampler = MetropolisHastings(
    model,
    theta0=model.prior.sample(),
    dataloader=dataloader,
    kernel=kernel,
)

# %% Run the sampler: 11000 epochs in total, of which the first 1000 are
# passed as burn-in; wall-clock time is recorded around the run.

start_time = timer()

sampler.run(num_epochs=11000, num_burnin_epochs=1000)

end_time = timer()
Esempio n. 4
0
 def default_kernel(self, state):
     """Return a NormalKernel centred on ``self.kernel_mean(state)``.

     Every one of the model's ``num_params()`` dimensions gets the same scale,
     ``sqrt(self.step)``. If NormalKernel's scale is a standard deviation
     (presumably -- confirm against eeyore), ``self.step`` acts as the
     proposal variance. The scale vector uses the model's dtype and device.

     Args:
         state: current sampler state, forwarded to ``self.kernel_mean``.

     Returns:
         The constructed NormalKernel.
     """
     mean = self.kernel_mean(state)
     num_dims = self.model.num_params()
     step_std = np.sqrt(self.step)
     scale = torch.full(
         [num_dims],
         step_std,
         dtype=self.model.dtype,
         device=self.model.device,
     )
     return NormalKernel(mean, scale)