class BLinear(nn.Module):
    """Bayesian linear layer; the default prior is a standard Gaussian for weights and bias."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Weight parameters
        self.weight_mu = nn.Parameter(
            torch.Tensor(out_features, in_features).uniform_(-0.2, 0.2))
        self.weight_rho = nn.Parameter(
            torch.Tensor(out_features, in_features).uniform_(1e-1, 2))
        # Variational posterior for the weights
        self.weight = Gaussian(self.weight_mu, self.weight_rho)
        # Bias parameters
        self.bias_mu = nn.Parameter(
            torch.Tensor(out_features).uniform_(-0.2, 0.2))
        self.bias_rho = nn.Parameter(
            torch.Tensor(out_features).uniform_(1e-1, 2))
        # Variational posterior for the bias
        self.bias = Gaussian(self.bias_mu, self.bias_rho)
        # Prior distributions
        self.weight_prior = Gaussian(torch.Tensor([0.]), torch.Tensor([1.]))
        self.bias_prior = Gaussian(torch.Tensor([0.]), torch.Tensor([1.]))
        # Initialize log_prior and log_variational_posterior to 0
        self.log_prior = 0
        self.log_variational_posterior = 0

    def forward(self, input, sample=False, calculate_log_probs=False):
        # 1. Sample weights and bias from the variational posterior
        if self.training or sample:
            weight = self.weight.sample()
            bias = self.bias.sample()
        else:
            weight = self.weight.mu
            bias = self.bias.mu
        # 2. Update log_prior and log_variational_posterior under the current approximation
        if self.training or calculate_log_probs:
            self.log_prior = self.weight_prior.log_prob(
                weight) + self.bias_prior.log_prob(bias)
            self.log_variational_posterior = self.weight.log_prob(
                weight) + self.bias.log_prob(bias)
        else:
            self.log_prior, self.log_variational_posterior = 0, 0
        # 3. Do a forward pass through the layer
        return F.linear(input, weight, bias)
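# --- Usage sketch (illustrative, not from the original source) ---------------
# A minimal example of how BLinear layers could be combined and trained with a
# Bayes-by-backprop style objective: the per-layer (log q - log p) terms that
# forward() accumulates approximate KL(q || p) at the sampled weights and are
# added to the negative log-likelihood. BayesianMLP and elbo_loss are
# hypothetical names; only BLinear comes from the code above.
import torch
import torch.nn as nn
import torch.nn.functional as F


class BayesianMLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.l1 = BLinear(in_dim, hidden_dim)
        self.l2 = BLinear(hidden_dim, out_dim)

    def forward(self, x, sample=False):
        h = F.relu(self.l1(x, sample=sample))
        return self.l2(h, sample=sample)

    def kl(self):
        # Sum the per-layer (log q - log p) terms accumulated during forward().
        return sum(l.log_variational_posterior - l.log_prior
                   for l in (self.l1, self.l2))


def elbo_loss(model, x, y, n_samples=3, kl_weight=1.0):
    # Monte Carlo estimate of the negative ELBO over n_samples weight draws
    # (assumes model.train() so the layers record their log-probabilities).
    nll, kl = 0.0, 0.0
    for _ in range(n_samples):
        logits = model(x, sample=True)
        nll = nll + F.cross_entropy(logits, y, reduction='sum')
        kl = kl + model.kl()
    return (nll + kl_weight * kl) / n_samples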
class VAE(nn.Module):
    def __init__(self, img_params, model_params, latent_params):
        super(VAE, self).__init__()
        image_dim = img_params['image_dim']
        image_size = img_params['image_size']

        n_downsample = model_params['n_downsample']
        dim = model_params['dim']
        n_res = model_params['n_res']
        norm = model_params['norm']
        activ = model_params['activ']
        pad_type = model_params['pad_type']
        n_mlp = model_params['n_mlp']
        mlp_dim = model_params['mlp_dim']

        self.latent_dim = latent_params['latent_dim']
        self.prior = Gaussian(self.latent_dim)

        self.encoder = Encoder(n_downsample, n_res, n_mlp, image_size, image_dim, dim,
                               mlp_dim, self.latent_dim, norm, activ, pad_type)
        conv_inp_size = image_size // (2 ** n_downsample)
        self.decoder = Decoder(n_downsample, n_res, n_mlp, self.latent_dim, mlp_dim,
                               conv_inp_size, dim, image_dim, norm, activ, pad_type)

    def forward(self, x):
        # Encode to raw distribution parameters, activate them into the prior's
        # parameterization, then sample via the reparameterization trick.
        latent_distr = self.encoder(x)
        latent_distr = self.prior.activate(latent_distr)
        samples = self.prior.sample(latent_distr)
        return self.decoder(samples), latent_distr, samples
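# --- Illustrative sketch (assumptions noted) ---------------------------------
# The Gaussian helper used by VAE is not shown here. A minimal stand-in for a
# diagonal Gaussian with the activate()/sample() interface called above might
# look like the following, assuming the encoder emits mean and log-variance
# concatenated along the feature dimension. The class name DiagonalGaussian,
# the 'log_var' key, and kl_to_standard_normal() are assumptions; only the
# 'mean' key appears elsewhere in this code.
import torch


class DiagonalGaussian:
    """Minimal stand-in for the Gaussian helper used above (interface assumed)."""

    def __init__(self, latent_dim):
        self.latent_dim = latent_dim

    def activate(self, raw):
        # Split the encoder output into mean and log-variance halves.
        mean, log_var = raw[:, :self.latent_dim], raw[:, self.latent_dim:]
        return {'mean': mean, 'log_var': log_var}

    def sample(self, distr):
        # Reparameterization trick: z = mu + sigma * eps, eps ~ N(0, I).
        std = torch.exp(0.5 * distr['log_var'])
        return distr['mean'] + std * torch.randn_like(std)

    def kl_to_standard_normal(self, distr):
        # KL(N(mu, sigma^2) || N(0, I)) summed over latent dimensions.
        mean, log_var = distr['mean'], distr['log_var']
        return 0.5 * torch.sum(log_var.exp() + mean ** 2 - 1.0 - log_var, dim=1)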
class TestGaussian(unittest.TestCase):

    def setUp(self):
        self.g1f = Gaussian(1, False)
        self.g2f = Gaussian(2, False)
        self.g4t = Gaussian(4, True)
        self.param1f = np.array([0, 2])
        self.param2f = np.array([np.arange(6), 2 * np.arange(6)])
        self.param4t = np.array([np.arange(8), -np.arange(8)])

    def test_reparam(self):
        t1f = {"g_mu": np.array([0]),
               "g_Sig": np.array([4]),
               "g_Siginv": np.array([0.25])}
        theta1f = self.g1f.reparam(self.param1f)
        self.assertDictEqual(t1f, theta1f)

        t2f = {"g_mu": np.array([[0, 1], [0, 2]]),
               "g_Sig": np.array([[[13, 23], [23, 41]],
                                  [[52, 92], [92, 164]]]),
               "g_Siginv": np.array([[[10.25, -5.75], [-5.75, 3.25]],
                                     [[2.5625, -1.4375], [-1.4375, 0.8125]]])}
        theta2f = self.g2f.reparam(self.param2f)
        for key in t2f.keys():
            with self.subTest(key=key):
                self.assertTrue(np.all(np.round(t2f[key], 2) == np.round(theta2f[key], 2)))

        t4t = {"g_mu": np.array([[0, 1, 2, 3], [0, -1, -2, -3]]),
               "g_Sig": np.array([np.exp([4, 5, 6, 7]), np.exp([-4, -5, -6, -7])])}
        theta4t = self.g4t.reparam(self.param4t)
        for key in t4t.keys():
            with self.subTest(key=key):
                self.assertTrue(np.all(t4t[key] == theta4t[key]))

    def test_logpdf(self):
        X1f = np.random.randn(5)
        p1f = norm.logpdf(X1f, 0, 2)
        logp1f = self.g1f.logpdf(self.param1f, X1f)
        self.assertTrue(np.all(np.round(p1f[:, np.newaxis], 5) == np.round(logp1f, 5)))

        X2f = np.random.randn(2)
        p2f1 = multivariate_normal.logpdf(X2f, np.array([0, 1]), np.array([[13, 23], [23, 41]]))
        p2f2 = multivariate_normal.logpdf(X2f, np.array([0, 2]), np.array([[52, 92], [92, 164]]))
        p2f = np.array([p2f1, p2f2])
        logp2f = self.g2f.logpdf(self.param2f, X2f)
        self.assertTrue(np.all(np.round(p2f[np.newaxis, :], 5) == np.round(logp2f, 5)))

        X4t = np.random.randn(8).reshape((2, 4))
        p4t1 = multivariate_normal.logpdf(X4t, np.array([0, 1, 2, 3]),
                                          np.diag(np.exp(np.array([4, 5, 6, 7]))))
        p4t2 = multivariate_normal.logpdf(X4t, np.array([0, -1, -2, -3]),
                                          np.diag(np.exp(np.array([-4, -5, -6, -7]))))
        p4t = np.hstack((p4t1[:, np.newaxis], p4t2[:, np.newaxis]))
        logp4t = self.g4t.logpdf(self.param4t, X4t)
        self.assertTrue(np.all(np.around(p4t, 5) == np.around(logp4t, 5)))

    def test_sample(self):
        np.random.seed(1)
        x1 = 0 + 2 * np.random.randn(5)
        x2 = np.array([0, 1]) + np.dot(np.random.randn(1, 2), np.array([[2, 3], [4, 5]]))
        x4 = np.array([0, 1, 2, 3]) + np.dot(np.random.randn(10, 4),
                                             np.diag(np.exp(np.array([4, 5, 6, 7]) / 2)))
        np.random.seed(1)
        samp1 = self.g1f.sample(self.param1f, 5)
        samp2 = self.g2f.sample(self.param2f[0, :], 1)
        samp4 = self.g4t.sample(self.param4t[0, :], 10)
        self.assertTrue(np.all(np.round(x1[:, np.newaxis], 3) == np.round(samp1, 3)))
        self.assertTrue(np.all(np.round(x2, 3) == np.round(samp2, 3)))
        self.assertTrue(np.all(np.round(x4, 3) == np.round(samp4, 3)))

    def test_cross_sample(self):
        var1f = 2 / np.array([1 / 4 + 1 / 25])
        mu1f = np.array([0])
        cov2f = 2 * np.linalg.inv(np.array([[10.25, -5.75], [-5.75, 3.25]])
                                  + np.array([[2.5625, -1.4375], [-1.4375, 0.8125]]))
        mu2f = 0.5 * np.dot(cov2f,
                            np.dot(np.array([[10.25, -5.75], [-5.75, 3.25]]), np.array([0, 1]))
                            + np.dot(np.array([[2.5625, -1.4375], [-1.4375, 0.8125]]), np.array([0, 2])))
        cov4t = 2 / (np.exp(np.arange(-4, -8, -1)) + np.exp(np.arange(4, 8)))
        mu4t = 0.5 * (cov4t * (np.arange(4) * np.exp(np.arange(-4, -8, -1))
                               + np.arange(0, -4, -1) * np.exp(np.arange(4, 8))))
        np.random.seed(1)
        x1 = np.random.multivariate_normal(mu1f, np.atleast_2d(var1f), 5)
        np.random.seed(2)
        x2 = np.random.multivariate_normal(mu2f, cov2f, 1)
        np.random.seed(0)
        x4 = mu4t + np.sqrt(cov4t) * np.random.randn(10, 4)
        np.random.seed(1)
        samp1 = self.g1f.cross_sample(self.param1f, -2.5 * self.param1f, 5)
        np.random.seed(2)
        samp2 = self.g2f.cross_sample(self.param2f[0, :], self.param2f[1, :], 1)
        np.random.seed(0)
        samp4 = self.g4t.cross_sample(self.param4t[0, :], self.param4t[1, :], 10)
        self.assertTrue(np.all(np.round(x1, 3) == np.round(samp1, 3)))
        self.assertTrue(np.all(np.round(x2, 3) == np.round(samp2, 3)))
        self.assertTrue(np.all(np.round(x4, 3) == np.round(samp4, 3)))

    def test_log_sqrt_pair_integral(self):
        l1 = -0.5 * np.log(10 / (2 * 4))
        difmu2 = np.array([0, 1])
        s = np.array([[13, 23], [23, 41]])
        S = np.array([[52, 92], [92, 164]])
        S2 = np.array([[32.5, 57.5], [57.5, 102.5]])
        l21 = (-0.125 * (difmu2 * np.linalg.solve(S2, difmu2)).sum()
               - 0.5 * np.linalg.slogdet(S2)[1]
               + 0.25 * np.linalg.slogdet(s)[1]
               + 0.25 * np.linalg.slogdet(S)[1])
        l2 = np.array([0, l21])
        difmu4 = np.array([[0, -0.5, -1, -1.5], [0, 2.5, 5, 7.5]])
        lsig = self.param4t[0, 4:]
        Lsig = self.param4t[:, 4:] * 1.5
        lSig2 = np.log(0.5) + np.logaddexp(lsig, Lsig)
        l4 = (-0.125 * np.sum(np.exp(-lSig2) * difmu4 ** 2, axis=1)
              - 0.5 * np.sum(lSig2, axis=1)
              + 0.25 * np.sum(lsig)
              + 0.25 * np.sum(Lsig, axis=1))
        la1f = self.g1f.log_sqrt_pair_integral(self.param1f, -2 * self.param1f)
        la2f = self.g2f.log_sqrt_pair_integral(self.param2f[0, :], self.param2f)
        la4t = self.g4t.log_sqrt_pair_integral(self.param4t[0, :], 1.5 * self.param4t)
        self.assertTrue(np.all(np.round(l1, 5) == np.round(la1f, 5)))
        self.assertTrue(np.all(np.round(l2, 5) == np.round(la2f, 5)))
        self.assertTrue(np.all(np.round(l4, 5) == np.round(la4t, 5)))

    def test_params_init(self):
        np.random.seed(1)
        m0 = np.random.multivariate_normal(np.zeros(2), 4 * np.eye(2))
        prm0f = np.concatenate((m0, np.array([1, 0, 0, 1])))
        np.random.seed(2)
        m0 = np.random.multivariate_normal(np.zeros(2), 4 * np.eye(2))
        prm0t = np.concatenate((m0, np.zeros(2)))
        np.random.seed(4)
        mu = np.array([[0, 1, 2, 3], [0, -1, -2, -3]])
        k = np.random.choice(2, p=np.array([0.5, 0.5]))
        lsig = np.array([[4, 5, 6, 7], [-4, -5, -6, -7]])
        mu0 = mu[k] + np.random.randn(4) * np.sqrt(10) * np.exp(lsig[k, :])
        LSig = np.random.randn(4) + lsig[k]
        prm4t = np.hstack((mu0, LSig))
        g2t = Gaussian(2, True)
        np.random.seed(1)
        par0f = self.g2f.params_init(np.empty((0, 6)), np.empty((0, 1)), 4)
        np.random.seed(2)
        par0t = g2t.params_init(np.empty((0, 6)), np.empty((0, 1)), 4)
        np.random.seed(4)
        par4t = self.g4t.params_init(self.param4t, np.array([0.1, 0.9]), 10)
        self.assertTrue(np.all(np.round(par0f, 3) == np.round(prm0f, 3)))
        self.assertTrue(np.all(np.round(par0t, 3) == np.round(prm0t, 3)))
        self.assertTrue(np.all(np.round(par4t, 3) == np.round(prm4t, 3)))
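# --- Test harness notes (illustrative) ---------------------------------------
# The suite above relies on unittest, NumPy, and SciPy reference densities. The
# flat parameter vectors it exercises appear to be the mean followed by a
# flattened covariance factor L (with Sig = L L^T) in the full-covariance case
# (Gaussian(d, False)), or by per-dimension log-variances in the diagonal case
# (Gaussian(d, True)). A minimal runner is sketched below; the import path for
# Gaussian is an assumption and must match this repository's layout.
import unittest

import numpy as np
from scipy.stats import norm, multivariate_normal

# from distributions import Gaussian  # assumed location; adjust as needed

if __name__ == '__main__':
    unittest.main()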
class CatVAE(nn.Module):
    """Auto-encoder with a joint continuous (Gaussian) and categorical (Gumbel-Softmax) latent code."""

    def __init__(self, img_params, model_params, latent_params):
        super(CatVAE, self).__init__()
        image_dim = img_params['image_dim']
        image_size = img_params['image_size']

        n_downsample = model_params['n_downsample']
        dim = model_params['dim']
        n_res = model_params['n_res']
        norm = model_params['norm']
        activ = model_params['activ']
        pad_type = model_params['pad_type']
        n_mlp = model_params['n_mlp']
        mlp_dim = model_params['mlp_dim']

        self.continious_dim = latent_params['continious']
        self.prior_cont = Gaussian(self.continious_dim)
        self.categorical_dim = latent_params['categorical']
        self.prior_catg = Categorical(self.categorical_dim)
        self.gumbel = Gumbel(self.categorical_dim)

        self.encoder = CatEncoder(n_downsample, n_res, n_mlp, image_size, image_dim, dim,
                                  mlp_dim, latent_params, norm, activ, pad_type)
        conv_inp_size = image_size // (2 ** n_downsample)
        decoder_inp_dim = self.continious_dim + self.categorical_dim
        self.decoder = Decoder(n_downsample, n_res, n_mlp, decoder_inp_dim, mlp_dim,
                               conv_inp_size, dim, image_dim, norm, activ, pad_type)

    def forward(self, x, tempr):
        latent_distr = self.encoder(x)
        # Categorical part: activated logits are kept for the KL term,
        # Gumbel-Softmax samples are used for reconstruction.
        categorical_distr = latent_distr[:, -self.categorical_dim:]
        categorical_distr_act = self.prior_catg.activate(categorical_distr)
        catg_samples = self.gumbel.gumbel_softmax_sample(categorical_distr, tempr)
        # Continuous part: activate and sample via the reparameterization trick.
        continious_distr = latent_distr[:, :-self.categorical_dim]
        continious_distr_act = self.prior_cont.activate(continious_distr)
        cont_samples = self.prior_cont.sample(continious_distr_act)
        # Concatenate into the full latent code and decode.
        full_samples = torch.cat([cont_samples, catg_samples], 1)
        recons = self.decoder(full_samples)
        return recons, full_samples, categorical_distr_act, continious_distr_act

    def encode_decode(self, x, tempr=0.4, hard_catg=True):
        latent_distr = self.encoder(x)
        # Categorical part
        categorical_distr = latent_distr[:, -self.categorical_dim:]
        if hard_catg:
            # Hard one-hot vector from the logits
            catg_samples = self.prior_catg.logits_to_onehot(categorical_distr)
        else:
            # Smoothed one-hot via softmax
            catg_samples = self.prior_catg.activate(categorical_distr)['prob']
        # Continuous part: use the posterior mean instead of sampling
        continious_distr = latent_distr[:, :-self.categorical_dim]
        continious_distr_act = self.prior_cont.activate(continious_distr)
        cont_samples = continious_distr_act['mean']
        # Concatenate into the full latent code and decode
        full_samples = torch.cat([cont_samples, catg_samples], 1)
        recons = self.decoder(full_samples)
        return recons, full_samples

    def sample_full_prior(self, batch_size, device='cuda:0'):
        cont_samples = self.prior_cont.sample_prior(batch_size, device=device)
        catg_samples = self.prior_catg.sample_prior(batch_size, device=device)
        full_samples = torch.cat([cont_samples, catg_samples], 1)
        return full_samples
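# --- Illustrative sketch (assumptions noted) ---------------------------------
# The Gumbel helper used above is not shown here. A standard Gumbel-Softmax
# (Concrete) sampler consistent with the gumbel_softmax_sample(logits, tempr)
# call could look like the following; this function is a sketch, not the
# repository's implementation.
import torch
import torch.nn.functional as F


def gumbel_softmax_sample(logits, temperature, eps=1e-20):
    # Sample Gumbel(0, 1) noise: g = -log(-log(u)), u ~ Uniform(0, 1).
    u = torch.rand_like(logits)
    gumbel_noise = -torch.log(-torch.log(u + eps) + eps)
    # Softmax over the perturbed logits; lower temperature -> closer to one-hot.
    return F.softmax((logits + gumbel_noise) / temperature, dim=-1)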
class GBN(Layer):
    def __init__(self, dim_in, dim_h, posterior=None, conditional=None,
                 prior=None, name='gbn', **kwargs):
        self.dim_in = dim_in
        self.dim_h = dim_h
        self.posterior = posterior
        self.conditional = conditional
        self.prior = prior

        kwargs = init_weights(self, **kwargs)
        kwargs = init_rngs(self, **kwargs)

        super(GBN, self).__init__(name=name)

    @staticmethod
    def mlp_factory(dim_h, dims, distributions, prototype=None,
                    recognition_net=None, generation_net=None):
        mlps = {}

        if recognition_net is not None:
            t = recognition_net.get('type', None)
            if t == 'mmmlp':
                raise NotImplementedError()
            elif t == 'lfmlp':
                recognition_net['prototype'] = prototype
            input_name = recognition_net.get('input_layer')
            recognition_net['distribution'] = 'gaussian'
            recognition_net['dim_in'] = dims[input_name]
            recognition_net['dim_out'] = dim_h
            posterior = resolve_mlp(t).factory(**recognition_net)
            mlps['posterior'] = posterior

        if generation_net is not None:
            output_name = generation_net['output']
            generation_net['dim_in'] = dim_h
            t = generation_net.get('type', None)
            if t == 'mmmlp':
                for out in generation_net['graph']['outputs']:
                    generation_net['graph']['outs'][out] = dict(
                        dim=dims[out], distribution=distributions[out])
                conditional = resolve_mlp(t).factory(**generation_net)
            else:
                if t == 'lfmlp':
                    generation_net['filter_in'] = False
                    generation_net['prototype'] = prototype
                generation_net['dim_out'] = dims[output_name]
                generation_net['distribution'] = distributions[output_name]
                conditional = resolve_mlp(t).factory(**generation_net)
            mlps['conditional'] = conditional

        return mlps

    def set_params(self):
        self.params = OrderedDict()
        if self.prior is None:
            self.prior = Gaussian(self.dim_h)
        if self.posterior is None:
            self.posterior = MLP(self.dim_in, self.dim_h, dim_hs=[],
                                 rng=self.rng, trng=self.trng,
                                 h_act='T.nnet.sigmoid',
                                 distribution='gaussian')
        if self.conditional is None:
            self.conditional = MLP(self.dim_h, self.dim_in, dim_hs=[],
                                   rng=self.rng, trng=self.trng,
                                   h_act='T.nnet.sigmoid',
                                   distribution='binomial')
        self.posterior.name = self.name + '_posterior'
        self.conditional.name = self.name + '_conditional'

    def set_tparams(self, excludes=[]):
        tparams = super(GBN, self).set_tparams()
        tparams.update(**self.posterior.set_tparams())
        tparams.update(**self.conditional.set_tparams())
        tparams.update(**self.prior.set_tparams())
        tparams = OrderedDict(
            (k, v) for k, v in tparams.iteritems() if k not in excludes)
        return tparams

    def get_params(self):
        params = self.prior.get_params() + self.conditional.get_params()
        return params

    def get_prior_params(self, *params):
        params = list(params)
        return params[:self.prior.n_params]

    # Latent sampling ---------------------------------------------------------

    def sample_from_prior(self, n_samples=100):
        h, updates = self.prior.sample(n_samples)
        py = self.conditional.feed(h)
        return self.get_center(py), updates

    def visualize_latents(self):
        h0 = self.prior.mu
        y0_mu, y0_logsigma = self.conditional.distribution.split_prob(
            self.conditional.feed(h0))
        sigma = T.nlinalg.AllocDiag()(T.exp(self.prior.log_sigma))
        h = 10 * sigma.astype(floatX) + h0[None, :]
        y_mu, y_logsigma = self.conditional.distribution.split_prob(
            self.conditional.feed(h))
        py = y_mu - y0_mu[None, :]
        return py

    # Misc --------------------------------------------------------------------

    def get_center(self, p):
        return self.conditional.get_center(p)

    def p_y_given_h(self, h, *params):
        start = self.prior.n_params
        stop = start + self.conditional.n_params
        params = params[start:stop]
        return self.conditional.step_feed(h, *params)

    def kl_divergence(self, p, q):
        # Analytic KL between two diagonal Gaussians parameterized as
        # (mu, log_sigma) slices of p and q, summed over the latent dimension.
        dim = self.dim_h
        mu_p = _slice(p, 0, dim)
        log_sigma_p = _slice(p, 1, dim)
        mu_q = _slice(q, 0, dim)
        log_sigma_q = _slice(q, 1, dim)

        kl = log_sigma_q - log_sigma_p + 0.5 * (
            (T.exp(2 * log_sigma_p) + (mu_p - mu_q) ** 2) /
            T.exp(2 * log_sigma_q) - 1)
        return kl.sum(axis=kl.ndim - 1)

    def l2_decay(self, rate):
        rec_l2_cost = self.posterior.get_L2_weight_cost(rate)
        gen_l2_cost = self.conditional.get_L2_weight_cost(rate)
        rval = OrderedDict(rec_l2_cost=rec_l2_cost,
                           gen_l2_cost=gen_l2_cost,
                           cost=rec_l2_cost + gen_l2_cost)
        return rval

    def init_inference_samples(self, size):
        return self.posterior.distribution.prototype_samples(size)

    def __call__(self, x, y, qk=None, n_posterior_samples=10,
                 pass_gradients=False):
        q0 = self.posterior.feed(x)
        if qk is None:
            qk = q0

        h, updates = self.posterior.sample(qk, n_samples=n_posterior_samples)
        py = self.conditional.feed(h)

        log_py_h = -self.conditional.neg_log_prob(y[None, :, :], py)
        KL_qk_p = self.prior.kl_divergence(qk)
        qk_c = qk.copy()
        if pass_gradients:
            KL_qk_q0 = T.constant(0.).astype(floatX)
        else:
            KL_qk_q0 = self.kl_divergence(qk_c, q0)
        log_ph = -self.prior.neg_log_prob(h)
        log_qh = -self.posterior.neg_log_prob(h, qk[None, :, :])
        # Importance-weighted estimate of log p(x).
        log_p = (log_sum_exp(log_py_h + log_ph - log_qh, axis=0)
                 - T.log(n_posterior_samples))
        y_energy = -log_py_h.mean(axis=0)
        p_entropy = self.prior.entropy()
        q_entropy = self.posterior.entropy(qk)
        nll = -log_p
        lower_bound = -(y_energy + KL_qk_p).mean()

        if pass_gradients:
            cost = (y_energy + KL_qk_p).mean(0)
        else:
            cost = (y_energy + KL_qk_p + KL_qk_q0).mean(0)

        mu, log_sigma = self.posterior.distribution.split_prob(qk)

        results = OrderedDict({
            'mu': mu.mean(),
            'log_sigma': log_sigma.mean(),
            '-log p(x|h)': y_energy.mean(0),
            '-log p(x)': nll.mean(0),
            '-log p(h)': log_ph.mean(),
            '-log q(h)': log_qh.mean(),
            'KL(q_k||p)': KL_qk_p.mean(0),
            'KL(q_k||q_0)': KL_qk_q0.mean(),
            'H(p)': p_entropy,
            'H(q)': q_entropy.mean(0),
            'lower_bound': lower_bound,
            'cost': cost})

        samples = OrderedDict(py=py, batch_energies=y_energy)
        constants = [qk_c]

        return results, samples, constants
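# --- Sanity-check sketch (illustrative, NumPy) -------------------------------
# GBN.kl_divergence implements the closed-form KL between diagonal Gaussians,
#   KL( N(mu_p, s_p^2) || N(mu_q, s_q^2) )
#     = log s_q - log s_p + (s_p^2 + (mu_p - mu_q)^2) / (2 s_q^2) - 1/2,
# summed over latent dimensions. The NumPy transcription below is only for
# checking the formula outside Theano; diag_gaussian_kl is a hypothetical name.
import numpy as np


def diag_gaussian_kl(mu_p, log_sigma_p, mu_q, log_sigma_q):
    # Mirrors GBN.kl_divergence above, summed over the last axis.
    kl = (log_sigma_q - log_sigma_p
          + 0.5 * ((np.exp(2 * log_sigma_p) + (mu_p - mu_q) ** 2)
                   / np.exp(2 * log_sigma_q) - 1.0))
    return kl.sum(axis=-1)


# Quick sanity check: the KL of a distribution with itself is zero.
_mu = np.array([0.3, -1.2])
_log_sigma = np.array([0.1, -0.5])
assert np.allclose(diag_gaussian_kl(_mu, _log_sigma, _mu, _log_sigma), 0.0)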