Example #1
def update_mvns(pi, mu_opt, tril_entries):
    # Rebuild one MVN per mixture component from the current parameters.
    # D (dimension) and K (number of components) are globals from the
    # surrounding script; update_tril is the helper defined in Example #6.
    mvns = [
        MVN(loc=mu_opt[i],
            scale_tril=update_tril(tril_entries[i], D),
            validate_args=True) for i in range(K)
    ]

    # Difference of two independent Gaussians is Gaussian:
    # X_i - X_j ~ N(mu_i - mu_j, Sigma_i + Sigma_j).
    diffs = []
    for i in range(K):
        for j in range(K):
            m = MVN(mu_opt[i] - mu_opt[j],
                    covariance_matrix=mvns[i].covariance_matrix +
                    mvns[j].covariance_matrix,
                    validate_args=True)
            diffs.append((m, pi[i], pi[j]))
    return mvns, diffs
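A minimal driver for `update_mvns` might look like the sketch below; `D`, `K`, and `update_tril` are assumptions carried over from Example #6, not part of this snippet.

# Hypothetical usage sketch; D, K, and update_tril are assumed to be the
# globals/helper from Example #6.
import torch
from torch.distributions import MultivariateNormal as MVN

D, K = 2, 3
pi = torch.softmax(torch.randn(K), dim=0)          # mixture weights
mu_opt = torch.randn(K, D)                         # one mean per component
tril_entries = torch.randn(K, D * (D + 1) // 2)    # one factor per component
mvns, diffs = update_mvns(pi, mu_opt, tril_entries)
print(len(mvns), len(diffs))                       # K and K*K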
Example #2
def A_new(self, X, k_new, Z, A):
    '''
    p(A_new|X,Z_new,Z_old,A_old) propto
        p(X|Z_new,Z_old,A_old,A_new) p(A_new)
    ~ N(mu, cov)
        let ones   = k_new x k_new matrix of ones
        let sig_n2 = sigma_n^2
        let sig_A2 = sigma_A^2
        mu  = (ones + sig_n2/sig_A2 I)^{-1} Z_new^T (X - Z_old A_old)
        cov = sig_n2 (ones + sig_n2/sig_A2 I)^{-1}
    '''
    N, D = X.size()
    K = Z.size(1)
    assert K == A.size(0) + k_new
    ones = torch.ones(k_new, k_new)
    I = torch.eye(k_new)
    sig_n = self.sigma_n
    sig_a = self.sigma_a
    Z_new = Z[:, -k_new:]   # columns of the newly created features
    Z_old = Z[:, :-k_new]
    Z_new_T = Z_new.transpose(0, 1)
    # mu is k_new x D
    mu = (ones + (sig_n / sig_a).pow(2) * I).inverse() @ \
        Z_new_T @ (X - Z_old @ A)
    # cov is k_new x k_new, shared across the D output dimensions
    cov = sig_n.pow(2) * (ones + (sig_n / sig_a).pow(2) * I).inverse()
    A_new = torch.zeros(k_new, D)
    for d in range(D):
        p_A = MVN(mu[:, d], cov)
        A_new[:, d] = p_A.sample()
    return A_new
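One step in the derivation is easy to miss: if the `k_new` new features are activated by a single data point (the usual IBP-style move), then `Z_new` has exactly one nonzero row, so `Z_new^T Z_new` is the all-ones matrix that appears as `ones` above. A quick check of that claim:

import torch

k_new, N = 3, 5
Z_new = torch.zeros(N, k_new)
Z_new[2] = 1.0   # only one row owns the new features
assert torch.equal(Z_new.t() @ Z_new, torch.ones(k_new, k_new))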
Example #3
def forward(self, x):
    # Collapsed marginal likelihood of the linear Gaussian LVM:
    # x ~ N(0, sw^2 * 11^T + (sz^2 + sx^2) * I).
    assert 1 == len(x.size())
    N = x.size(0)
    sw = self.p.Pw.log_scale.exp()
    sz = self.p.PzGw.log_scale.exp()
    sx = self.p.PxGz.log_scale.exp()
    dist = MVN(
        t.zeros(N, device=x.device),
        sw**2 * t.ones(N, N, device=x.device) +
        (sz**2 + sx**2) * t.eye(N, device=x.device))
    return dist.log_prob(x)
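Where that covariance comes from, assuming the hierarchical model w ~ N(0, sw^2), z_i ~ N(w, sz^2), x_i ~ N(z_i, sx^2): marginalizing w and z gives Cov(x) = sw^2 * 11^T + (sz^2 + sx^2) * I. A Monte Carlo sketch of the same identity:

import torch as t

t.manual_seed(0)
N, sw, sz, sx = 4, 1.0, 1.0, 0.1
S = 200000                                # number of Monte Carlo samples
w = sw * t.randn(S, 1)                    # shared latent, broadcast over N
z = w + sz * t.randn(S, N)
x = z + sx * t.randn(S, N)
emp_cov = (x.T @ x) / S                   # empirical covariance (zero mean)
exact = sw**2 * t.ones(N, N) + (sz**2 + sx**2) * t.eye(N)
print((emp_cov - exact).abs().max())      # small, up to Monte Carlo error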
Example #4
def init_A(self, K, D):
    '''
    Sample from prior p(A_k)
    A_k ~ N(0, sigma_A^2 I)
    '''
    # K i.i.d. rows, each drawn from the same D-dimensional Gaussian prior.
    Ak_mean = torch.zeros(D)
    Ak_cov = self.sigma_a.pow(2) * torch.eye(D)
    p_Ak = MVN(Ak_mean, Ak_cov)
    A = torch.zeros(K, D)
    for k in range(K):
        A[k] = p_Ak.sample()
    return A
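Because the K rows are i.i.d., the loop can be collapsed into one batched draw; a sketch with the same names as the method above (`sigma_a` stands in for `self.sigma_a`):

# Equivalent vectorized draw: sample((K,)) returns a K x D matrix directly.
A = MVN(torch.zeros(D), sigma_a.pow(2) * torch.eye(D)).sample((K,))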
Example #5
def resample_A(self, X, Z):
    '''
    mu  = (Z^T Z + (sigma_n^2 / sigma_A^2) I)^{-1} Z^T X
    cov = sigma_n^2 (Z^T Z + (sigma_n^2 / sigma_A^2) I)^{-1}
    p(A|X,Z) = N(mu, cov)
    '''
    N, D = X.size()
    K = Z.size(1)
    ZT = Z.transpose(0, 1)
    ZTZ = ZT @ Z
    I = torch.eye(K)
    sig_n = self.sigma_n
    sig_a = self.sigma_a
    mu = (ZTZ + (sig_n / sig_a).pow(2) * I).inverse() @ ZT @ X
    cov = sig_n.pow(2) * (ZTZ + (sig_n / sig_a).pow(2) * I).inverse()
    # The K x K posterior covariance is shared across the D columns,
    # so sample each column of A independently.
    A = torch.zeros(K, D)
    for d in range(D):
        p_A = MVN(mu[:, d], cov)
        A[:, d] = p_A.sample()
    return A
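Since the same matrix is inverted twice, a numerically safer variant (a sketch reusing the names defined in the method) factors the posterior precision once with a Cholesky decomposition:

P = ZTZ + (sig_n / sig_a).pow(2) * I                    # precision, up to sig_n^2
L = torch.linalg.cholesky(P)
mu = torch.cholesky_solve(ZT @ X, L)                    # solves P @ mu = Z^T X
cov = sig_n.pow(2) * torch.cholesky_solve(torch.eye(K), L)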
Example #6
import torch
from torch.distributions import MultivariateNormal as MVN
from torch.nn.functional import softplus

# torch: fit the distribution's parameters by maximum likelihood.
# D, K, and data come from the surrounding script (a hypothetical
# setup follows this example).
mu_opt = torch.randn(D, requires_grad=True)
tril_entries = torch.randn(D * (D + 1) // 2, requires_grad=True)


def update_tril(entries, D):
    # Build a lower-triangular scale factor from an unconstrained vector;
    # softplus keeps the diagonal strictly positive.
    tril = torch.zeros(D, D)
    tril[range(D), range(D)] = softplus(entries[:D])
    rows, cols = torch.tril_indices(D, D)
    off_diag = rows != cols
    tril[rows[off_diag], cols[off_diag]] = entries[D:]
    return tril


opt = torch.optim.Adam([mu_opt, tril_entries], lr=0.01)
print(opt.param_groups)
for i in range(1000):
    # Rebuild the distribution each step so the graph uses the updated
    # parameters; no retain_graph is needed then.
    mvn = MVN(loc=mu_opt, scale_tril=update_tril(tril_entries, D))
    opt.zero_grad()
    loss = (-mvn.log_prob(torch.as_tensor(data, dtype=torch.float32))).sum()
    loss.backward()
    opt.step()
    if i % 100 == 0:
        print(f"{i}: {loss.item()}")
        print(mvn.scale_tril @ mvn.scale_tril.T)  # implied covariance
        print(tril_entries)

pi = torch.randn(K, requires_grad=True)
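The snippet assumes `D`, `K`, and `data` already exist in the script; a minimal hypothetical setup could be:

import torch
from torch.distributions import MultivariateNormal as MVN

D, K = 2, 3
true_dist = MVN(torch.zeros(D), torch.eye(D))
data = true_dist.sample((500,))   # toy stand-in for the real dataset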
Example #7
from lvm import *
import torch as t
from torch.distributions import MultivariateNormal as MVN

t.manual_seed(1)
N = 100
sw = 1.
sz = 1.
sx = 0.1
p, q, x = pqx(N, sw=sw, sz=sz, sx=sx)

# Exact log marginal likelihood under the collapsed model (cf. Example #3).
print(
    MVN(t.zeros(N),
        sw**2 * t.ones(N, N) + (sz**2 + sx**2) * t.eye(N)).log_prob(x.cpu()))

t.manual_seed(1)
vae = VAE(p, q, 10000).cuda()
for i in range(10):
    print(vae(x))

iters = 100

tmc = TMC(p, q, 501, 502).cuda()
tmcs = []
for i in range(iters):
    t.manual_seed(i)   # fix the seed per repeat for reproducibility
    res = tmc(x)
    tmcs.append(res.detach().cpu().numpy())

smc = SMC(p, q, 500).cuda()
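The excerpt stops right after constructing the SMC estimator; presumably it is evaluated the same way as TMC above. A sketch of that missing loop:

smcs = []
for i in range(iters):
    t.manual_seed(i)
    smcs.append(smc(x).detach().cpu().numpy())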
Example #8
# Random diagonal and off-diagonal entries for a batch of scale_tril
# factors; numcurves, numdiagvars, numcovars, numdims, scalediag, and
# scaleoff come from earlier context.
signeddiagstdev = torch.randn((numcurves, numdiagvars),
                              requires_grad=False).double().cuda()
diagstdev = scalediag * torch.abs(signeddiagstdev)  # diagonal must be positive

signedoffdiagstdev = torch.randn((numcurves, numcovars),
                                 requires_grad=False).double().cuda()
offdiagstdev = scaleoff * signedoffdiagstdev
print("diagstdev.shape: %s" % (str(diagstdev.shape), ))
scale_tril = torch.diag_embed(diagstdev)
tril_indices = torch.tril_indices(row=numdims, col=numdims, offset=-1)
scale_tril[:, tril_indices[0], tril_indices[1]] = offdiagstdev
print("scale_tril.shape: %s" % (str(scale_tril.shape), ))
print("scale_tril.requires_grad: %s" % (str(scale_tril.requires_grad), ))

bcdist = MVN(bcflat, scale_tril=scale_tril, validate_args=True)
bcsamples_ = bcdist.sample((numsamples, ))
print("bcsamples_.shape: %s" % (str(bcsamples_.shape), ))
logprob = bcdist.log_prob(bcsamples_)
pdfvals = torch.exp(logprob).t()
# Normalize each curve's sample densities to [0, 1] for use as alphas.
pdfmaxes, pdfmaxes_idx = torch.max(pdfvals, dim=1, keepdim=True)
alphas = pdfvals / pdfmaxes
print("first_points.shape: %s" % (str(first_points.shape), ))
# Prepend the fixed first control point, then evaluate the Bezier curves.
bcsamples = torch.cat((first_points, bcsamples_.transpose(0, 1).reshape(
    numcurves, numsamples, kbezier, 2)),
                      dim=2)
print("bcsamples.shape: %s" % (str(bcsamples.shape), ))
sample_curves = torch.matmul(M.unsqueeze(1), bcsamples)
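`M` maps the sampled control points onto curve points; it is built elsewhere, but a hypothetical construction from the Bernstein basis of degree `kbezier` would look like:

import math
import torch

def bezier_design_matrix(svals, kbezier):
    # svals: 1-D tensor of curve parameters in [0, 1]; row i holds the
    # kbezier + 1 Bernstein basis values at svals[i].
    cols = [math.comb(kbezier, j) * svals.pow(j) * (1 - svals).pow(kbezier - j)
            for j in range(kbezier + 1)]
    return torch.stack(cols, dim=1)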