import torch
from torch.distributions import (AffineTransform, SigmoidTransform,
                                 TransformedDistribution, Uniform)


class StandardLogisticDistribution:
    def __init__(self, data_dim=28 * 28, device='cpu'):
        # Standard logistic via logit(Uniform(0, 1)); the affine transform
        # with loc=0, scale=1 leaves the distribution unchanged.
        self.m = TransformedDistribution(
            Uniform(torch.zeros(data_dim, device=device),
                    torch.ones(data_dim, device=device)), [
                        SigmoidTransform().inv,
                        AffineTransform(torch.zeros(data_dim, device=device),
                                        torch.ones(data_dim, device=device))
                    ])

    def log_pdf(self, z):
        # Expects z of shape (batch, data_dim); sums per-dimension log-densities.
        return self.m.log_prob(z).sum(dim=1)

    def sample(self):
        return self.m.sample()
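
A minimal usage sketch (hypothetical values, assuming the imports above): sample once and score the sample, adding the batch dimension that log_pdf expects.

dist = StandardLogisticDistribution(data_dim=4)
z = dist.sample()                    # shape (4,)
logp = dist.log_pdf(z.unsqueeze(0))  # shape (1,); log_pdf sums over dim=1
print(z.shape, logp)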
Example #2
import torch
from torch.distributions import (AffineTransform, SigmoidTransform,
                                 TransformedDistribution, Uniform)

# `Logistic` is the distribution class under test; it is imported from the
# package being tested (import not shown in the original snippet).


def test_logistic():
    base_distribution = Uniform(0, 1)
    transforms = [SigmoidTransform().inv, AffineTransform(loc=torch.tensor([2.]), scale=torch.tensor([1.]))]
    model = TransformedDistribution(base_distribution, transforms)
    transform = Logistic(2., 1.)

    x = model.sample((4,)).reshape(-1, 1)
    assert torch.all((transform.log_prob(x) - model.log_prob(x).view(-1)).abs() < 1e-4)

    x = transform.sample(4)
    assert x.shape == (4, 1)
    assert torch.all((transform.log_prob(x) - model.log_prob(x).view(-1)).abs() < 1e-4)

    x = transform.sample(1)
    assert x.shape == (1, 1)
    assert torch.all((transform.log_prob(x) - model.log_prob(x).view(-1)).abs() < 1e-4)

    transform.get_parameters()  # smoke test: parameter access should not raise
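
The `Logistic` class itself is not shown above. A minimal sketch of the interface the test assumes (hypothetical implementation, built on the same torch.distributions machinery) could look like this:

class Logistic:
    """Hypothetical sketch of the interface test_logistic() exercises."""

    def __init__(self, loc, scale):
        self.loc = torch.as_tensor(loc)
        self.scale = torch.as_tensor(scale)
        self._dist = TransformedDistribution(
            Uniform(0., 1.),
            [SigmoidTransform().inv, AffineTransform(self.loc, self.scale)])

    def log_prob(self, x):
        # Accepts (n, 1) inputs and returns a flat (n,) tensor, as the test expects.
        return self._dist.log_prob(x.view(-1))

    def sample(self, n):
        # Returns samples with an explicit feature dimension, shape (n, 1).
        return self._dist.sample((n,)).reshape(n, 1)

    def get_parameters(self):
        return {'loc': self.loc, 'scale': self.scale}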
Example #3
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import ExpTransform, Gamma, TransformedDistribution
from torch.utils.data import DataLoader, TensorDataset

# BNN (abstract base class), NN (MLP builder), and pSGLD (preconditioned SGLD
# optimizer) are project-local modules, assumed importable from the package.


class BNN_SGDMC(nn.Module, BNN):
    def __init__(self,
                 dim,
                 act=nn.ReLU(),
                 num_hiddens=[50],
                 nout=1,
                 conf=dict()):
        nn.Module.__init__(self)
        BNN.__init__(self)
        self.dim = dim
        self.act = act
        self.num_hiddens = num_hiddens
        self.nout = nout
        self.steps_burnin = conf.get('steps_burnin', 2500)
        self.steps = conf.get('steps', 2500)
        self.keep_every = conf.get('keep_every', 50)
        self.batch_size = conf.get('batch_size', 32)
        self.warm_start = conf.get('warm_start', False)

        self.lr_weight = np.float32(conf.get('lr_weight', 1e-3))
        self.lr_noise = np.float32(conf.get('lr_noise', 1e-3))
        self.lr_lambda = np.float32(conf.get('lr_lambda', 1e-3))
        self.alpha_w = torch.as_tensor(float(conf.get('alpha_w', 6.)))
        self.beta_w = torch.as_tensor(float(conf.get('beta_w', 6.)))
        self.alpha_n = torch.as_tensor(float(conf.get('alpha_n', 6.)))
        self.beta_n = torch.as_tensor(float(conf.get('beta_n', 6.)))
        self.noise_level = conf.get('noise_level', None)
        if self.noise_level is not None:
            # Pick Gamma(alpha_n, beta_n) so the noise-precision prior has
            # mean prec = 1 / noise_level^2 and standard deviation 0.25 * prec.
            prec = 1 / self.noise_level**2
            prec_var = (prec * 0.25)**2
            self.beta_n = torch.as_tensor(prec / prec_var)
            self.alpha_n = torch.as_tensor(prec * self.beta_n)
            print("Reset alpha_n = %g, beta_n = %g" %
                  (self.alpha_n, self.beta_n))

        self.prior_log_lambda = TransformedDistribution(
            Gamma(self.alpha_w, self.beta_w),
            ExpTransform().inv)  # log of gamma distribution
        self.prior_log_precision = TransformedDistribution(
            Gamma(self.alpha_n, self.beta_n),
            ExpTransform().inv)

        self.log_lambda = nn.Parameter(torch.tensor(0.))
        self.log_precs = nn.Parameter(torch.zeros(self.nout))
        self.nn = NN(dim, self.act, self.num_hiddens, self.nout)

        self.init_nn()

    def init_nn(self):
        # Draw hyper-parameters from their priors; weights start at
        # N(0, 1/lambda) and biases at zero.
        self.log_lambda.data = self.prior_log_lambda.sample()
        self.log_precs.data = self.prior_log_precision.sample((self.nout, ))
        for layer in self.nn.nn:
            if isinstance(layer, nn.Linear):
                layer.weight.data = torch.distributions.Normal(
                    0, 1 / self.log_lambda.exp().sqrt()).sample(
                        layer.weight.shape)
                layer.bias.data = torch.zeros(layer.bias.shape)

    def log_prior(self):
        log_p = self.prior_log_lambda.log_prob(self.log_lambda).sum()
        log_p += self.prior_log_precision.log_prob(self.log_precs).sum()

        lambd = self.log_lambda.exp()
        for n, p in self.nn.nn.named_parameters():
            if "weight" in n:
                log_p += -0.5 * lambd * torch.sum(p**2) + 0.5 * p.numel() * (
                    self.log_lambda - np.log(2 * np.pi))
        return log_p
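
    # Derivation note: for w ~ N(0, lambda^{-1/2}),
    #   log N(w; 0, lambda^{-1/2}) = -0.5 * lambda * w^2 + 0.5 * (log lambda - log 2*pi),
    # so the closed form above is the exact Gaussian log-prior summed over all
    # weight matrices (biases are excluded, matching init_nn).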

    def log_lik(self, X, y):
        y = y.view(-1, self.nout)
        pred = self.nn(X).view(-1, self.nout)
        precs = self.log_precs.exp()
        log_lik = -0.5 * precs * (
            y - pred)**2 + 0.5 * self.log_precs - 0.5 * np.log(2 * np.pi)
        return log_lik.sum()
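
    # Note: this is the closed form of the Gaussian log-likelihood
    #   sum_i log N(y_i; pred_i, precs^{-1/2}),
    # i.e. torch.distributions.Normal(pred, 1 / precs.sqrt()).log_prob(y).sum().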

    def sgld_steps(self, num_steps, num_train):
        step_cnt = 0
        loss = 0.
        while step_cnt < num_steps:
            for bx, by in self.loader:
                log_prior = self.log_prior()
                log_lik = self.log_lik(bx, by)
                # Scale the minibatch log-likelihood by num_train / batch_size:
                # an unbiased estimate of the full-data log-likelihood, so the
                # loss estimates the negative unnormalized log-posterior.
                loss = -1 * (log_lik * (num_train / bx.shape[0]) + log_prior)
                self.opt.zero_grad()
                loss.backward()
                self.opt.step()
                self.scheduler.step()
                step_cnt += 1
                if step_cnt >= num_steps:
                    break
        return loss

    def train(self, X, y):
        # NB: this overrides nn.Module.train(), shadowing the usual
        # train/eval mode switch on this class.
        y = y.view(-1, self.nout)
        num_train = X.shape[0]
        params = [{
            'params': self.nn.nn.parameters(),
            'lr': self.lr_weight
        }, {
            'params': self.log_precs,
            'lr': self.lr_noise
        }, {
            'params': self.log_lambda,
            'lr': self.lr_lambda
        }]
        # self.opt       = aSGHMC(params, num_burn_in_steps = self.steps_burnin)
        # self.scheduler = optim.lr_scheduler.LambdaLR(self.opt, lambda iter : np.float32(1.))

        self.opt = pSGLD(params)
        self.scheduler = optim.lr_scheduler.LambdaLR(
            self.opt, lambda step: np.float32((1 + step)**-0.33))

        self.loader = DataLoader(TensorDataset(X, y),
                                 batch_size=self.batch_size,
                                 shuffle=True)
        step_cnt = 0
        self.nns = []
        self.lrs = []
        if not self.warm_start:
            self.init_nn()

        _ = self.sgld_steps(self.steps_burnin, num_train)  # burn-in

        while step_cnt < self.steps:
            loss = self.sgld_steps(self.keep_every, num_train)
            step_cnt += self.keep_every
            prec = self.log_precs.exp().mean()
            wstd = 1 / self.log_lambda.exp().sqrt()
            print('Step %4d, loss = %8.2f, precision = %g, weight_std = %g' %
                  (step_cnt, loss, prec, wstd),
                  flush=True)
            self.nns.append(deepcopy(self.nn))
        print('Number of samples: %d' % len(self.nns))

    def sample(self, num_samples=1):
        assert num_samples <= len(self.nns)
        # Draw posterior networks without replacement.
        return np.random.permutation(self.nns)[:num_samples]

    def sample_predict(self, nns, input):
        num_samples = len(nns)
        num_x = input.shape[0]
        pred = torch.empty(num_samples, num_x, self.nout)
        with torch.no_grad():  # inference only; avoid building autograd graphs
            for i in range(num_samples):
                pred[i] = nns[i](input)
        return pred

    def report(self):
        print(self.nn.nn)
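
An end-to-end usage sketch on toy 1-D regression data (hypothetical settings; BNN, NN, and pSGLD must be importable from the project for this to run):

X = torch.linspace(-3, 3, 256).view(-1, 1)
y = torch.sin(X) + 0.1 * torch.randn_like(X)

model = BNN_SGDMC(dim=1, conf={'steps_burnin': 500, 'steps': 500, 'keep_every': 100})
model.train(X, y)                        # burn-in, then stores posterior nets in model.nns

nets = model.sample(num_samples=4)       # 4 posterior weight samples
preds = model.sample_predict(nets, X)    # shape (4, 256, 1)
mean, std = preds.mean(dim=0), preds.std(dim=0)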