import torch
import torch.nn as nn
import torch.nn.functional as F

# `Gaussian` is the project-specific distribution wrapper used below
# (a hypothetical sketch of it follows this example).


class BLinear(nn.Module):
    """Bayesian linear layer; the default prior is a Gaussian over the weights and bias."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Weight parameters
        self.weight_mu = nn.Parameter(
            torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2))
        self.weight_rho = nn.Parameter(
            torch.Tensor(out_features, in_features).uniform_(1e-1, 2))
        # variational posterior for the weights
        self.weight = Gaussian(self.weight_mu, self.weight_rho)

        # Bias parameters
        self.bias_mu = nn.Parameter(
            torch.Tensor(out_features).uniform_(-0.2, 0.2))
        self.bias_rho = nn.Parameter(
            torch.Tensor(out_features).uniform_(1e-1, 2))
        # variational posterior for the bias
        self.bias = Gaussian(self.bias_mu, self.bias_rho)

        # Prior distributions
        self.weight_prior = Gaussian(torch.Tensor([0.]), torch.Tensor([1.]))
        self.bias_prior = Gaussian(torch.Tensor([0.]), torch.Tensor([1.]))

        # initialize log_prior and log_posterior as 0
        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, input, sample=False, calculate_log_probs=False):
        # 1. Sample weights and bias from variational posterior
        if self.training or sample:
            weight = self.weight.sample()
            bias = self.bias.sample()
        else:
            weight = self.weight.mu
            bias = self.bias.mu

        # 2. Update log_prior and log_posterior according to current approximation
        if self.training or calculate_log_probs:
            self.log_prior = self.weight_prior.log_prob(
                weight) + self.bias_prior.log_prob(bias)
            self.log_variational_posterior = self.weight.log_prob(
                weight) + self.bias.log_prob(bias)
        else:
            self.log_prior, self.log_variational_posterior = 0, 0

        # 3. Do a forward pass through the layer
        return F.linear(input, weight, bias)
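The `Gaussian` variational-posterior class used by BLinear is not shown in this example. A minimal, hypothetical sketch consistent with how it is used above (a reparameterized `sample()` and a `log_prob()` summed over all elements), assuming `rho` parameterizes the standard deviation through a softplus:

import math

class Gaussian:
    """Hypothetical diagonal Gaussian used as variational posterior/prior above."""

    def __init__(self, mu, rho):
        self.mu = mu
        self.rho = rho
        self.normal = torch.distributions.Normal(0, 1)

    @property
    def sigma(self):
        # softplus keeps the standard deviation positive
        return torch.log1p(torch.exp(self.rho))

    def sample(self):
        # reparameterization trick: mu + sigma * eps, with eps ~ N(0, 1)
        epsilon = self.normal.sample(self.rho.size()).to(self.rho.device)
        return self.mu + self.sigma * epsilon

    def log_prob(self, value):
        # sum of element-wise Gaussian log-densities
        return (-math.log(math.sqrt(2 * math.pi))
                - torch.log(self.sigma)
                - ((value - self.mu) ** 2) / (2 * self.sigma ** 2)).sum()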
Example 2
class VAE(nn.Module):
    def __init__(self,img_params,model_params,latent_params):
        super(VAE, self).__init__()
        image_dim  = img_params['image_dim']
        image_size = img_params['image_size']

        n_downsample = model_params['n_downsample']
        dim      = model_params['dim']
        n_res    = model_params['n_res']
        norm     = model_params['norm']
        activ    = model_params['activ']
        pad_type = model_params['pad_type']
        n_mlp    = model_params['n_mlp']
        mlp_dim  = model_params['mlp_dim']

        self.latent_dim = latent_params['latent_dim']
        self.prior = Gaussian(self.latent_dim)

        self.encoder = Encoder(n_downsample,n_res,n_mlp,image_size,image_dim,dim,mlp_dim,
                               self.latent_dim,norm,activ,pad_type)

        conv_inp_size = image_size // (2**n_downsample)
        self.decoder = Decoder(n_downsample,n_res,n_mlp,self.latent_dim,mlp_dim,conv_inp_size,
                               dim,image_dim,norm,activ,pad_type)

    def forward(self,x):
        latent_distr = self.encoder(x)
        latent_distr = self.prior.activate(latent_distr)
        samples = self.prior.sample(latent_distr)
        return self.decoder(samples),latent_distr,samples
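A hypothetical usage sketch of the VAE above; the parameter values, the dummy batch, and the reconstruction loss below are illustrative assumptions (a full training objective would add a KL term computed from `latent_distr`):

import torch
import torch.nn.functional as F

img_params    = {'image_dim': 3, 'image_size': 64}
model_params  = {'n_downsample': 2, 'dim': 64, 'n_res': 2, 'norm': 'bn',
                 'activ': 'relu', 'pad_type': 'zero', 'n_mlp': 2, 'mlp_dim': 256}
latent_params = {'latent_dim': 16}

vae = VAE(img_params, model_params, latent_params)
x = torch.randn(8, 3, 64, 64)        # dummy batch of 64x64 RGB images
recons, latent_distr, samples = vae(x)
recon_loss = F.mse_loss(recons, x)   # reconstruction term of the ELBO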
Example 3
import unittest

import numpy as np
from scipy.stats import norm, multivariate_normal

# `Gaussian` is the project-specific distribution class under test.


class TestGaussian(unittest.TestCase):

    def setUp(self):
        self.g1f = Gaussian(1, False)
        self.g2f = Gaussian(2, False)
        self.g4t = Gaussian(4, True)
        self.param1f = np.array([0,2])
        self.param2f = np.array([np.arange(6), 2*np.arange(6)])
        self.param4t = np.array([np.arange(8), -np.arange(8)])
       
    def test_reparam(self):
        t1f = {"g_mu":np.array([0]), "g_Sig":np.array([4]), "g_Siginv":np.array([0.25])}
        theta1f = self.g1f.reparam(self.param1f)
        self.assertDictEqual(t1f, theta1f)
        
        t2f = {"g_mu":np.array([[0,1],[0,2]]), "g_Sig":np.array([[[13,23],[23,41]],[[52,92],[92,164]]]), "g_Siginv":np.array([[[10.25,-5.75],[-5.75,3.25]],[[2.5625,-1.4375],[-1.4375,0.8125]]])}
        theta2f = self.g2f.reparam(self.param2f)
        for key in t2f.keys():
            with self.subTest(key=key):
                self.assertTrue(np.all(np.round(t2f[key],2)==np.round(theta2f[key],2)))
        
        t4t = {"g_mu":np.array([[0,1,2,3],[0,-1,-2,-3]]), "g_Sig":np.array([np.exp([4,5,6,7]),np.exp([-4,-5,-6,-7])])}
        theta4t = self.g4t.reparam(self.param4t)
        for key in t4t.keys():
            with self.subTest(key=key):
                self.assertTrue(np.all(t4t[key]==theta4t[key]))
    
    def test_logpdf(self):
        X1f = np.random.randn(5)
        p1f = norm.logpdf(X1f, 0, 2)
        logp1f = self.g1f.logpdf(self.param1f, X1f)
        self.assertTrue(np.all(np.round(p1f[:,np.newaxis],5)==np.round(logp1f,5)))
        
        X2f = np.random.randn(2)
        p2f1 = multivariate_normal.logpdf(X2f, np.array([0,1]), np.array([[13,23],[23,41]]))
        p2f2 = multivariate_normal.logpdf(X2f, np.array([0,2]), np.array([[52,92],[92,164]]))
        p2f = np.array([p2f1, p2f2])
        logp2f = self.g2f.logpdf(self.param2f, X2f)
        self.assertTrue(np.all(np.round(p2f[np.newaxis,:],5)==np.round(logp2f,5)))
        
        X4t = np.random.randn(8).reshape((2,4))
        p4t1 = multivariate_normal.logpdf(X4t, np.array([0,1,2,3]),np.diag(np.exp(np.array([4,5,6,7]))))
        p4t2 = multivariate_normal.logpdf(X4t, np.array([0,-1,-2,-3]),np.diag(np.exp(np.array([-4,-5,-6,-7]))))
        p4t = np.hstack((p4t1[:,np.newaxis], p4t2[:,np.newaxis]))
        logp4t = self.g4t.logpdf(self.param4t, X4t)
        self.assertTrue(np.all(np.around(p4t,5)==np.around(logp4t,5)))
    
    def test_sample(self):
        np.random.seed(1)
        x1 = 0 + 2* np.random.randn(5)
        x2 = np.array([0,1]) + np.dot(np.random.randn(1,2), np.array([[2,3],[4,5]]))
        x4 = np.array([0,1,2,3]) + np.dot(np.random.randn(10,4), np.diag(np.exp(np.array([4,5,6,7])/2)))
        
        np.random.seed(1)
        samp1 = self.g1f.sample(self.param1f, 5)
        samp2 = self.g2f.sample(self.param2f[0,:], 1)
        samp4 = self.g4t.sample(self.param4t[0,:], 10)
        
        self.assertTrue(np.all(np.round(x1[:,np.newaxis],3)==np.round(samp1,3)))
        self.assertTrue(np.all(np.round(x2,3)==np.round(samp2,3)))
        self.assertTrue(np.all(np.round(x4,3)==np.round(samp4,3)))
    
    def test_cross_sample(self):
        var1f = 2 / np.array([1/4+1/25])
        mu1f = np.array([0])
        cov2f = 2 * np.linalg.inv(np.array([[10.25,-5.75],[-5.75,3.25]])+np.array([[2.5625,-1.4375],[-1.4375,0.8125]]))
        mu2f = 0.5*np.dot(cov2f, np.dot(np.array([[10.25,-5.75],[-5.75,3.25]]),np.array([0,1])) + np.dot(np.array([[2.5625,-1.4375],[-1.4375,0.8125]]),np.array([0,2])))
        cov4t = 2 /(np.exp(np.arange(-4,-8,-1)) + np.exp(np.arange(4,8)))
        mu4t = 0.5*(cov4t*(np.arange(4)*np.exp(np.arange(-4,-8,-1))+np.arange(0,-4,-1)*np.exp(np.arange(4,8))))               
        np.random.seed(1)
        x1 = np.random.multivariate_normal(mu1f, np.atleast_2d(var1f), 5)
        np.random.seed(2)
        x2 = np.random.multivariate_normal(mu2f, cov2f, 1)
        np.random.seed(0)
        x4 = mu4t + np.sqrt(cov4t)*np.random.randn(10, 4)
        
        np.random.seed(1)
        samp1 = self.g1f.cross_sample(self.param1f, -2.5*self.param1f, 5)
        np.random.seed(2)
        samp2 = self.g2f.cross_sample(self.param2f[0,:], self.param2f[1,:], 1)
        np.random.seed(0)
        samp4 = self.g4t.cross_sample(self.param4t[0,:], self.param4t[1,:], 10)
        
        self.assertTrue(np.all(np.round(x1,3)==np.round(samp1,3)))
        self.assertTrue(np.all(np.round(x2,3)==np.round(samp2,3)))
        self.assertTrue(np.all(np.round(x4,3)==np.round(samp4,3)))
    
    def test_log_sqrt_pair_integral(self):
        l1 = -0.5*np.log(10/(2*4))
        difmu2 = np.array([0,1])
        s = np.array([[13,23],[23,41]])
        S = np.array([[52,92],[92,164]])
        S2 = np.array([[32.5,57.5],[57.5,102.5]])
        l21 = -0.125*(difmu2*np.linalg.solve(S2,difmu2)).sum()- 0.5*np.linalg.slogdet(S2)[1] + 0.25*np.linalg.slogdet(s)[1] + 0.25*np.linalg.slogdet(S)[1]
        l2 = np.array([0,l21])
        difmu4 = np.array([[0,-0.5,-1,-1.5],[0,2.5,5,7.5]])
        lsig = self.param4t[0, 4:]
        Lsig = self.param4t[:,4:]*1.5
        lSig2 = np.log(0.5)+np.logaddexp(lsig, Lsig)
        l4 = -0.125*np.sum(np.exp(-lSig2)*difmu4**2, axis=1) - 0.5*np.sum(lSig2, axis=1) + 0.25*np.sum(lsig) + 0.25*np.sum(Lsig, axis=1)
        
        la1f = self.g1f.log_sqrt_pair_integral(self.param1f, -2*self.param1f)
        la2f = self.g2f.log_sqrt_pair_integral(self.param2f[0,:], self.param2f)
        la4t = self.g4t.log_sqrt_pair_integral(self.param4t[0,:], 1.5*self.param4t)
        
        self.assertTrue(np.all(np.round(l1,5)==np.round(la1f,5)))
        self.assertTrue(np.all(np.round(l2,5)==np.round(la2f,5)))
        self.assertTrue(np.all(np.round(l4,5)==np.round(la4t,5)))
    
    def test_params_init(self):
        np.random.seed(1)
        m0 = np.random.multivariate_normal(np.zeros(2), 4*np.eye(2))
        prm0f = np.concatenate((m0, np.array([1,0,0,1])))
        np.random.seed(2)
        m0 = np.random.multivariate_normal(np.zeros(2), 4*np.eye(2))
        prm0t = np.concatenate((m0, np.zeros(2)))
        np.random.seed(4)
        mu = np.array([[0,1,2,3],[0,-1,-2,-3]])
        k = np.random.choice(2, p=np.array([0.5,0.5]))
        lsig = np.array([[4,5,6,7],[-4,-5,-6,-7]])
        mu0 = mu[k]+np.random.randn(4)*np.sqrt(10)*np.exp(lsig[k,:])
        LSig = np.random.randn(4)+lsig[k] 
        prm4t = np.hstack((mu0, LSig))
        
        g2t = Gaussian(2, True)
        np.random.seed(1)
        par0f = self.g2f.params_init(np.empty((0,6)), np.empty((0,1)), 4)
        np.random.seed(2)
        par0t = g2t.params_init(np.empty((0,6)), np.empty((0,1)), 4)
        np.random.seed(4)
        par4t = self.g4t.params_init(self.param4t, np.array([0.1, 0.9]), 10)
        
        self.assertTrue(np.all(np.round(par0f,3)==np.round(prm0f,3)))
        self.assertTrue(np.all(np.round(par0t,3)==np.round(prm0t,3)))
        self.assertTrue(np.all(np.round(par4t,3)==np.round(prm4t,3)))
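These tests can be run with the standard unittest entry point, assuming numpy, scipy, and the project's `Gaussian` implementation are importable:

if __name__ == '__main__':
    unittest.main()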
Example 4
class CatVAE(nn.Module):
    # Auto-encoder architecture
    def __init__(self,img_params,model_params,latent_params):
        super(CatVAE, self).__init__()
        image_dim  = img_params['image_dim']
        image_size = img_params['image_size']

        n_downsample = model_params['n_downsample']
        dim      = model_params['dim']
        n_res    = model_params['n_res']
        norm     = model_params['norm']
        activ    = model_params['activ']
        pad_type = model_params['pad_type']
        n_mlp    = model_params['n_mlp']
        mlp_dim  = model_params['mlp_dim']

        self.continious_dim = latent_params['continious']
        self.prior_cont = Gaussian(self.continious_dim)

        self.categorical_dim = latent_params['categorical']
        self.prior_catg = Categorical(self.categorical_dim)
        self.gumbel     = Gumbel(self.categorical_dim)

        self.encoder = CatEncoder(n_downsample,n_res,n_mlp,image_size,image_dim,dim,mlp_dim,
                                  latent_params,norm,activ,pad_type)

        conv_inp_size = image_size // (2**n_downsample)
        decoder_inp_dim = self.continious_dim + self.categorical_dim
        self.decoder = Decoder(n_downsample,n_res,n_mlp,decoder_inp_dim,mlp_dim,conv_inp_size,
                               dim,image_dim,norm,activ,pad_type)

    def forward(self, x, tempr):
        latent_distr = self.encoder(x)

        #categorical distr
        categorical_distr     = latent_distr[:,-self.categorical_dim:]
        categorical_distr_act = self.prior_catg.activate(categorical_distr) # needed for the KL term

        catg_samples = self.gumbel.gumbel_softmax_sample(categorical_distr,tempr) # categorical sampling, used for reconstruction

        #continuous distr
        continious_distr     = latent_distr[:,:-self.categorical_dim]
        continious_distr_act = self.prior_cont.activate(continious_distr)
        cont_samples         = self.prior_cont.sample(continious_distr_act)

        #create full latent code
        full_samples = torch.cat([cont_samples,catg_samples],1)

        recons = self.decoder(full_samples)

        return recons, full_samples, categorical_distr_act, continious_distr_act
        
    def encode_decode(self, x, tempr=0.4, hard_catg=True):
        latent_distr = self.encoder(x)

        #categorical distr stuff
        categorical_distr = latent_distr[:,-self.categorical_dim:]
        if hard_catg:
            # just make a hard one-hot vector
            catg_samples = self.prior_catg.logits_to_onehot(categorical_distr)
        else:
            # make a smoothed one-hot vector via softmax
            catg_samples = self.prior_catg.activate(categorical_distr)['prob']

        #continuous distr stuff
        continious_distr     = latent_distr[:,:-self.categorical_dim]
        continious_distr_act = self.prior_cont.activate(continious_distr)
        cont_samples         = continious_distr_act['mean']

        #create full latent code
        full_samples = torch.cat([cont_samples,catg_samples],1)
        recons = self.decoder(full_samples)

        return recons, full_samples#, categorical_distr_act, continious_distr_act

    def sample_full_prior(self, batch_size, device='cuda:0'):
        cont_samples = self.prior_cont.sample_prior(batch_size, device=device)
        catg_samples = self.prior_catg.sample_prior(batch_size, device=device)
        full_samples = torch.cat([cont_samples,catg_samples],1)
        return full_samples
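The `Gumbel.gumbel_softmax_sample` method used in `forward` is not shown here. A minimal sketch of the standard Gumbel-softmax relaxation it presumably implements (an assumption, not the original class):

import torch
import torch.nn.functional as F

def gumbel_softmax_sample(logits, temperature, eps=1e-20):
    # draw Gumbel(0, 1) noise and relax the categorical sample with a
    # temperature-scaled softmax
    u = torch.rand_like(logits)
    gumbel_noise = -torch.log(-torch.log(u + eps) + eps)
    return F.softmax((logits + gumbel_noise) / temperature, dim=-1)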
Example 5
class GBN(Layer):
    def __init__(self,
                 dim_in,
                 dim_h,
                 posterior=None,
                 conditional=None,
                 prior=None,
                 name='gbn',
                 **kwargs):

        self.dim_in = dim_in
        self.dim_h = dim_h

        self.posterior = posterior
        self.conditional = conditional
        self.prior = prior

        kwargs = init_weights(self, **kwargs)
        kwargs = init_rngs(self, **kwargs)

        super(GBN, self).__init__(name=name)

    @staticmethod
    def mlp_factory(dim_h,
                    dims,
                    distributions,
                    prototype=None,
                    recognition_net=None,
                    generation_net=None):
        mlps = {}

        if recognition_net is not None:
            t = recognition_net.get('type', None)
            if t == 'mmmlp':
                raise NotImplementedError()
            elif t == 'lfmlp':
                recognition_net['prototype'] = prototype

            input_name = recognition_net.get('input_layer')
            recognition_net['distribution'] = 'gaussian'
            recognition_net['dim_in'] = dims[input_name]
            recognition_net['dim_out'] = dim_h
            posterior = resolve_mlp(t).factory(**recognition_net)
            mlps['posterior'] = posterior

        if generation_net is not None:
            output_name = generation_net['output']
            generation_net['dim_in'] = dim_h

            t = generation_net.get('type', None)

            if t == 'mmmlp':
                for out in generation_net['graph']['outputs']:
                    generation_net['graph']['outs'][out] = dict(
                        dim=dims[out], distribution=distributions[out])
                conditional = resolve_mlp(t).factory(**generation_net)
            else:
                if t == 'lfmlp':
                    generation_net['filter_in'] = False
                    generation_net['prototype'] = prototype
                generation_net['dim_out'] = dims[output_name]
                generation_net['distribution'] = distributions[output_name]
                conditional = resolve_mlp(t).factory(**generation_net)
            mlps['conditional'] = conditional

        return mlps

    def set_params(self):
        self.params = OrderedDict()

        if self.prior is None:
            self.prior = Gaussian(self.dim_h)

        if self.posterior is None:
            self.posterior = MLP(self.dim_in,
                                 self.dim_h,
                                 dim_hs=[],
                                 rng=self.rng,
                                 trng=self.trng,
                                 h_act='T.nnet.sigmoid',
                                 distribution='gaussian')
        if self.conditional is None:
            self.conditional = MLP(self.dim_h,
                                   self.dim_in,
                                   dim_hs=[],
                                   rng=self.rng,
                                   trng=self.trng,
                                   h_act='T.nnet.sigmoid',
                                   distribution='binomial')

        self.posterior.name = self.name + '_posterior'
        self.conditional.name = self.name + '_conditional'

    def set_tparams(self, excludes=[]):
        tparams = super(GBN, self).set_tparams()
        tparams.update(**self.posterior.set_tparams())
        tparams.update(**self.conditional.set_tparams())
        tparams.update(**self.prior.set_tparams())
        tparams = OrderedDict(
            (k, v) for k, v in tparams.iteritems() if k not in excludes)

        return tparams

    def get_params(self):
        params = self.prior.get_params() + self.conditional.get_params()
        return params

    def get_prior_params(self, *params):
        params = list(params)
        return params[:self.prior.n_params]

    # Latent sampling ---------------------------------------------------------

    def sample_from_prior(self, n_samples=100):
        h, updates = self.prior.sample(n_samples)
        py = self.conditional.feed(h)
        return self.get_center(py), updates

    def visualize_latents(self):
        h0 = self.prior.mu
        y0_mu, y0_logsigma = self.conditional.distribution.split_prob(
            self.conditional.feed(h0))

        sigma = T.nlinalg.AllocDiag()(T.exp(self.prior.log_sigma))
        h = 10 * sigma.astype(floatX) + h0[None, :]
        y_mu, y_logsigma = self.conditional.distribution.split_prob(
            self.conditional.feed(h))
        py = y_mu - y0_mu[None, :]
        return py  # / py.std()#T.exp(y_logsigma)

    # Misc --------------------------------------------------------------------

    def get_center(self, p):
        return self.conditional.get_center(p)

    def p_y_given_h(self, h, *params):
        start = self.prior.n_params
        stop = start + self.conditional.n_params
        params = params[start:stop]
        return self.conditional.step_feed(h, *params)

    def kl_divergence(self, p, q):
        dim = self.dim_h
        mu_p = _slice(p, 0, dim)
        log_sigma_p = _slice(p, 1, dim)
        mu_q = _slice(q, 0, dim)
        log_sigma_q = _slice(q, 1, dim)
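        # Closed-form KL between two diagonal Gaussians parameterised as
        # (mu, log_sigma):
        #   KL(p || q) = log(sig_q) - log(sig_p)
        #                + (sig_p**2 + (mu_p - mu_q)**2) / (2 * sig_q**2) - 1/2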

        kl = log_sigma_q - log_sigma_p + 0.5 * (
            (T.exp(2 * log_sigma_p) +
             (mu_p - mu_q)**2) / T.exp(2 * log_sigma_q) - 1)
        return kl.sum(axis=kl.ndim - 1)

    def l2_decay(self, rate):
        rec_l2_cost = self.posterior.get_L2_weight_cost(rate)
        gen_l2_cost = self.conditional.get_L2_weight_cost(rate)

        rval = OrderedDict(rec_l2_cost=rec_l2_cost,
                           gen_l2_cost=gen_l2_cost,
                           cost=rec_l2_cost + gen_l2_cost)

        return rval

    def init_inference_samples(self, size):
        return self.posterior.distribution.prototype_samples(size)

    def __call__(self,
                 x,
                 y,
                 qk=None,
                 n_posterior_samples=10,
                 pass_gradients=False):
        q0 = self.posterior.feed(x)

        if qk is None:
            qk = q0

        h, updates = self.posterior.sample(qk, n_samples=n_posterior_samples)
        py = self.conditional.feed(h)

        log_py_h = -self.conditional.neg_log_prob(y[None, :, :], py)
        KL_qk_p = self.prior.kl_divergence(qk)
        qk_c = qk.copy()
        if pass_gradients:
            KL_qk_q0 = T.constant(0.).astype(floatX)
        else:
            KL_qk_q0 = self.kl_divergence(qk_c, q0)

        log_ph = -self.prior.neg_log_prob(h)
        log_qh = -self.posterior.neg_log_prob(h, qk[None, :, :])

        log_p = (log_sum_exp(log_py_h + log_ph - log_qh, axis=0) -
                 T.log(n_posterior_samples))

        y_energy = -log_py_h.mean(axis=0)
        p_entropy = self.prior.entropy()
        q_entropy = self.posterior.entropy(qk)
        nll = -log_p

        lower_bound = -(y_energy + KL_qk_p).mean()
        if pass_gradients:
            cost = (y_energy + KL_qk_p).mean(0)
        else:
            cost = (y_energy + KL_qk_p + KL_qk_q0).mean(0)

        mu, log_sigma = self.posterior.distribution.split_prob(qk)

        results = OrderedDict({
            'mu': mu.mean(),
            'log_sigma': log_sigma.mean(),
            '-log p(x|h)': y_energy.mean(0),
            '-log p(x)': nll.mean(0),
            '-log p(h)': log_ph.mean(),
            '-log q(h)': log_qh.mean(),
            'KL(q_k||p)': KL_qk_p.mean(0),
            'KL(q_k||q_0)': KL_qk_q0.mean(),
            'H(p)': p_entropy,
            'H(q)': q_entropy.mean(0),
            'lower_bound': lower_bound,
            'cost': cost
        })

        samples = OrderedDict(py=py, batch_energies=y_energy)

        constants = [qk_c]
        return results, samples, constants