def update_mvns(pi, mu_opt, tril_entries):
    # Rebuild the K mixture components from the current variational parameters
    # (K, D, and update_tril come from the enclosing scope).
    mvns = [
        MVN(loc=mu_opt[i],
            scale_tril=update_tril(tril_entries[i], D),
            validate_args=True)
        for i in range(K)
    ]
    # Distribution of the difference of each pair of components: the difference
    # of two independent Gaussians is Gaussian with the covariances summed.
    diffs = []
    for i in range(K):
        for j in range(K):
            m = MVN(mu_opt[i] - mu_opt[j],
                    covariance_matrix=mvns[i].covariance_matrix
                    + mvns[j].covariance_matrix,
                    validate_args=True)
            diffs.append((m, pi[i], pi[j]))
    return mvns, diffs
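# A minimal driver for update_mvns, a sketch assuming the enclosing scope
# provides K, D, the MVN import, and update_tril (defined later in this file);
# the sizes here are illustrative, not from the original code.
import torch
from torch.distributions import MultivariateNormal as MVN

K, D = 3, 2
pi = torch.softmax(torch.randn(K), dim=0)
mu_opt = torch.randn(K, D)
tril_entries = torch.randn(K, D * (D + 1) // 2)
mvns, diffs = update_mvns(pi, mu_opt, tril_entries)
print(len(mvns), len(diffs))  # 3 components, 9 pairwise-difference terms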
def A_new(self, X, k_new, Z, A):
    '''
    p(A_new | X, Z_new, Z_old, A_old) propto p(X | Z_new, Z_old, A_old, A_new) p(A_new)
                                      ~ N(mu, cov)
    Let ones   = k_new x k_new matrix of ones
        sig_n2 = sigma_n^2
        sig_A2 = sigma_A^2
    mu  = (ones + sig_n2/sig_A2 I)^{-1} Z_new^T (X - Z_old A_old)
    cov = sig_n2 (ones + sig_n2/sig_A2 I)^{-1}
    '''
    N, D = X.size()
    K = Z.size()[1]
    assert K == A.size()[0] + k_new

    # Z_new^T Z_new = ones because only a single data point owns the new features.
    ones = torch.ones(k_new, k_new)
    I = torch.eye(k_new)
    sig_n = self.sigma_n
    sig_a = self.sigma_a

    Z_new = Z[:, -k_new:]
    Z_old = Z[:, :-k_new]
    Z_new_T = Z_new.transpose(0, 1)

    # mu is k_new x D
    mu = (ones + (sig_n / sig_a).pow(2) * I).inverse() @ \
        Z_new_T @ (X - Z_old @ A)
    # cov is k_new x k_new
    cov = sig_n.pow(2) * (ones + (sig_n / sig_a).pow(2) * I).inverse()

    # Each output dimension shares the posterior covariance; sample them independently.
    A_new = torch.zeros(k_new, D)
    for d in range(D):
        p_A = MVN(mu[:, d], cov)
        A_new[:, d] = p_A.sample()
    return A_new
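# Hypothetical driver for A_new: propose k_new = 2 new features owned by a single
# data point i, calling the method as a plain function with a SimpleNamespace
# standing in for self (sigma_n / sigma_a stored as tensors is an assumption).
import torch
from types import SimpleNamespace
from torch.distributions import MultivariateNormal as MVN

self_ = SimpleNamespace(sigma_n=torch.tensor(0.1), sigma_a=torch.tensor(1.0))
N, D, K_old, k_new, i = 20, 3, 4, 2, 0
A_old = torch.randn(K_old, D)
Z_old = torch.bernoulli(0.5 * torch.ones(N, K_old))
Z_new = torch.zeros(N, k_new)
Z_new[i] = 1.0  # only row i activates the new features, so Z_new^T Z_new = ones
Z = torch.cat([Z_old, Z_new], dim=1)
X = Z_old @ A_old + self_.sigma_n * torch.randn(N, D)
print(A_new(self_, X, k_new, Z, A_old).shape)  # torch.Size([2, 3])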
def forward(self, x):
    assert 1 == len(x.size())
    N = x.size(0)
    sw = self.p.Pw.log_scale.exp()
    sz = self.p.PzGw.log_scale.exp()
    sx = self.p.PxGz.log_scale.exp()
    # Exact marginal likelihood: the shared latent w contributes sw^2 to every
    # covariance entry, while z and the noise add (sz^2 + sx^2) on the diagonal.
    dist = MVN(
        t.zeros(N, device=x.device),
        sw**2 * t.ones(N, N, device=x.device)
        + (sz**2 + sx**2) * t.eye(N, device=x.device))
    return dist.log_prob(x)
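# Standalone check of the covariance used above: with x_n = w + z_n + eps_n and a
# single shared w ~ N(0, sw^2), Cov(x_m, x_n) = sw^2 + (sz^2 + sx^2) * 1[m == n].
# A sketch with illustrative values, independent of the class above.
import torch as t
from torch.distributions import MultivariateNormal as MVN

N, sw, sz, sx = 4, 1.0, 1.0, 0.1
x = sw * t.randn(1) + sz * t.randn(N) + sx * t.randn(N)
cov = sw**2 * t.ones(N, N) + (sz**2 + sx**2) * t.eye(N)
print(MVN(t.zeros(N), cov).log_prob(x))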
def init_A(self, K, D):
    '''
    Sample from the prior p(A_k):
        A_k ~ N(0, sigma_A^2 I)
    '''
    Ak_mean = torch.zeros(D)
    Ak_cov = self.sigma_a.pow(2) * torch.eye(D)
    p_Ak = MVN(Ak_mean, Ak_cov)
    A = torch.zeros(K, D)
    for k in range(K):
        A[k] = p_Ak.sample()
    return A
def resample_A(self, X, Z):
    '''
    mu  = (Z^T Z + (sigma_n^2 / sigma_A^2) I)^{-1} Z^T X
    cov = sigma_n^2 (Z^T Z + (sigma_n^2 / sigma_A^2) I)^{-1}
    p(A | X, Z) = N(mu, cov)
    '''
    N, D = X.size()
    K = Z.size()[1]
    ZT = Z.transpose(0, 1)
    ZTZ = ZT @ Z
    I = torch.eye(K)
    sig_n = self.sigma_n
    sig_a = self.sigma_a

    # Invert the posterior precision once and reuse it for both mu and cov.
    inv = (ZTZ + (sig_n / sig_a).pow(2) * I).inverse()
    mu = inv @ ZT @ X
    cov = sig_n.pow(2) * inv

    A = torch.zeros(K, D)
    for d in range(D):
        p_A = MVN(mu[:, d], cov)
        A[:, d] = p_A.sample()
    return A
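# A small end-to-end check of init_A / resample_A above, calling the methods as
# plain functions with a SimpleNamespace in place of self (an assumption about
# how sigma_n / sigma_a are stored); sizes are illustrative.
import torch
from types import SimpleNamespace
from torch.distributions import MultivariateNormal as MVN

self_ = SimpleNamespace(sigma_n=torch.tensor(0.1), sigma_a=torch.tensor(1.0))
N, K, D = 50, 3, 2
Z = torch.bernoulli(0.5 * torch.ones(N, K))
A_true = init_A(self_, K, D)                        # draw weights from the prior
X = Z @ A_true + self_.sigma_n * torch.randn(N, D)  # noisy observations
A_post = resample_A(self_, X, Z)                    # one Gibbs update of A
print(A_post.shape)                                 # torch.Size([3, 2])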
# torch initialize to fit distribution: a mean vector plus the D*(D+1)//2 free
# entries of a lower-triangular Cholesky factor (data, D, K defined earlier).
from torch.nn.functional import softplus

mu_opt = torch.randn(D, requires_grad=True)
tril_entries = torch.randn(D * (D + 1) // 2, requires_grad=True)

def update_tril(entries, D):
    # Map the unconstrained entries to a valid Cholesky factor: softplus keeps
    # the diagonal positive; the rest fill the strict lower triangle.
    tril = torch.zeros(D, D)
    tril[range(D), range(D)] = softplus(entries[0:D])
    idx = torch.tril_indices(D, D)
    off_idx = idx[0] != idx[1]
    a, b = idx[:, off_idx]
    tril[a, b] = entries[D:]
    return tril

opt = torch.optim.Adam([mu_opt, tril_entries], lr=0.01)
print(opt.param_groups)
for i in range(10):
    opt.zero_grad()
    # Rebuild the distribution each step so scale_tril reflects the current
    # parameters; constructing it once outside the loop would freeze its value
    # (the original worked around this with retain_graph=True, which does not
    # update scale_tril either).
    mvn = MVN(loc=mu_opt, scale_tril=update_tril(tril_entries, D))
    loss = (-mvn.log_prob(torch.Tensor(data))).sum()
    loss.backward()
    opt.step()
    if i % 100 == 0:
        print(f"{i}: {loss.item()}")

print(mvn.scale_tril @ mvn.scale_tril.T)
print(tril_entries)
mvn = MVN(mu_opt, scale_tril=update_tril(tril_entries, D))
pi = torch.randn(K, requires_grad=True)
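# A quick sanity check of update_tril: with D = 3 there are D*(D+1)//2 = 6 free
# parameters; the first 3 become the softplus'd (positive) diagonal and the rest
# fill the strict lower triangle, so L @ L.T is a valid covariance.
import torch
from torch.nn.functional import softplus

L = update_tril(torch.arange(6.0), 3)
print(L)                                    # lower triangular, positive diagonal
print(torch.all(torch.diag(L) > 0).item())  # True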
from lvm import *
from torch.distributions import MultivariateNormal as MVN

t.manual_seed(1)
N = 100
sw = 1.
sz = 1.
sx = 0.1
p, q, x = pqx(N, sw=sw, sz=sz, sx=sx)
print(
    MVN(t.zeros(N),
        sw**2 * t.ones(N, N) + (sz**2 + sx**2) * t.eye(N)).log_prob(x.cpu()))

t.manual_seed(1)
vae = VAE(p, q, 10000).cuda()
for i in range(10):
    print(vae(x))

iters = 100
tmc = TMC(p, q, 501, 502).cuda()
tmcs = []
#print()
for i in range(iters):
    t.manual_seed(i)
    res = tmc(x)
    tmcs.append(res.detach().cpu().numpy())
    #print(res)

smc = SMC(p, q, 500).cuda()
# Draw per-curve standard deviations for the diagonal and strict lower triangle
# of each Cholesky factor.
signeddiagstdev = torch.randn((numcurves, numdiagvars), requires_grad=False).double().cuda()
diagstdev = scalediag * torch.abs(signeddiagstdev)
signedoffdiagstdev = torch.randn((numcurves, numcovars), requires_grad=False).double().cuda()
offdiagstdev = scaleoff * signedoffdiagstdev
print("diagstdev.shape: %s" % (str(diagstdev.shape), ))

# Assemble a batch of lower-triangular scale factors.
scale_tril = torch.diag_embed(diagstdev)
tril_indices = torch.tril_indices(row=numdims, col=numdims, offset=-1)
scale_tril[:, tril_indices[0], tril_indices[1]] = offdiagstdev
#print("scale_tril: %s" % (str(scale_tril),))
print("scale_tril.shape: %s" % (str(scale_tril.shape), ))
print("scale_tril.requires_grad: %s" % (str(scale_tril.requires_grad), ))

bcdist = MVN(bcflat, scale_tril=scale_tril, validate_args=True)
bcsamples_ = bcdist.sample((numsamples, ))
print("bcsamples_.shape: %s" % (str(bcsamples_.shape), ))

# Normalize each sample's density by the per-curve maximum for use as alphas.
logprob = bcdist.log_prob(bcsamples_)
pdfvals = torch.exp(logprob).t()
#print("pdfvals: %s" % (str(pdfvals),))
pdfmaxes, pdfmaxes_idx = torch.max(pdfvals, dim=1, keepdim=True)
alphas = pdfvals / pdfmaxes

print("first_points.shape: %s" % (str(first_points.shape), ))
#print("bcsamples_.shape: %s" % (str(bcsamples_.shape),))
# Prepend the fixed first control point, then map control points to curves.
bcsamples = torch.cat((first_points, bcsamples_.transpose(0, 1).reshape(
    numcurves, numsamples, kbezier, 2)), dim=2)
print("bcsamples.shape: %s" % (str(bcsamples.shape), ))
sample_curves = torch.matmul(M.unsqueeze(1), bcsamples)
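# The diag_embed + tril_indices pattern above works for any batch of Cholesky
# factors; a tiny standalone illustration with made-up sizes:
import torch
from torch.distributions import MultivariateNormal as MVN

B, n = 5, 3
diag = torch.rand(B, n) + 0.1                    # keep the diagonal positive
off = torch.randn(B, n * (n - 1) // 2)
L = torch.diag_embed(diag)
idx = torch.tril_indices(row=n, col=n, offset=-1)
L[:, idx[0], idx[1]] = off
dist = MVN(torch.zeros(B, n), scale_tril=L, validate_args=True)
print(dist.sample((7,)).shape)                   # torch.Size([7, 5, 3])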