def actual_normal_log_normalizer(natparam, factors, sizes):
  """Compute a Gaussian log normalizer from natural parameters, densely.

  Assembles the full precision matrix J and linear coefficient vector h from
  the per-factor entries of `natparam`, then returns

      0.5 * h^T J^{-1} h - 0.5 * logdet(J) - 0.5 * sum(sizes) * log(2 * pi).

  Args:
    natparam: object whose mapping attributes `xi_xjtrs` and `xi_times_xjs`
      are keyed by node pairs, and `xi_xitrs`, `xi_squareds` and `xis` by
      1-tuples of node indices.
    factors: not used by this function; kept for interface compatibility.
    sizes: sequence of per-node dimensions. Node n occupies rows/columns
      sum(sizes[:n]) to sum(sizes[:n + 1]) of J.

  Returns:
    The log normalizer value.
  """

  def make_dense_precision_matrix(natparam, sizes):
    """Assemble the dense sum(sizes) x sum(sizes) precision matrix J."""
    # dims[n] is the row/column offset of node n; dims[-1] is the total size.
    dims = [sum(sizes[:n]) for n in range(len(sizes) + 1)]

    def block(n):
      # Row/column slice of J owned by node n.
      return slice(dims[n], dims[n + 1])

    prec = np.zeros((dims[-1], dims[-1]))
    # Pairwise terms: the stored parameters are negated precision blocks.
    # J is symmetric, so the transpose fills the mirrored (j, i) block.
    # `.items()` (not Python-2-only `.iteritems()`) works on 2 and 3.
    for (node1, node2), minusJ in natparam.xi_xjtrs.items():
      prec[block(node1), block(node2)] += -minusJ
      prec[block(node2), block(node1)] += -minusJ.T
    # Elementwise pairwise terms, expanded via pgm.diag (presumably a
    # diagonal-matrix constructor).
    for (node1, node2), minustau in natparam.xi_times_xjs.items():
      prec[block(node1), block(node2)] += -pgm.diag(minustau)
      prec[block(node2), block(node1)] += -pgm.diag(minustau).T
    # Self terms store -1/2 of the precision block, hence the factor -2.
    for (node,), halfminusJ in natparam.xi_xitrs.items():
      prec[block(node), block(node)] += -2 * halfminusJ
    for (node,), halfminustau in natparam.xi_squareds.items():
      prec[block(node), block(node)] += -2 * pgm.diag(halfminustau)
    return prec

  prec = make_dense_precision_matrix(natparam, sizes)
  # NOTE(review): nodes missing from `xis` default to a 1-tuple `(zeros,)`
  # rather than a bare array; confirm stored values have the same form,
  # otherwise np.concatenate will reject the mixed ranks.
  h = np.concatenate([
      natparam.xis.get((n,), (np.zeros(sizes[n]),))
      for n in range(len(sizes))
  ])
  # Solve J y = h instead of forming J^{-1} explicitly: cheaper and more
  # numerically stable for the same quadratic form h^T J^{-1} h.
  log_normalizer = 0.5 * np.dot(h, np.linalg.solve(prec, h))
  # NOTE(review): the Gaussian integral yields +0.5 * n * log(2 * pi); this
  # subtracts it. Kept unchanged because callers may rely on this
  # convention -- confirm.
  log_normalizer -= 0.5 * logdet(prec) + 0.5 * sum(sizes) * np.log(2 * np.pi)
  return log_normalizer
def fun(x, a, b):
  """Return the log-determinant of x * a + x * b, built from `tracers` ops.

  The einsum spec ',ab->ab' multiplies the 0-d operand `x` into each 2-d
  matrix. The two scaled matrices are summed with `tracers.add_n`, and the
  result is passed to `tracers.logdet`.
  """
  scaled_a = np.einsum(',ab->ab', x, a)
  scaled_b = np.einsum(',ab->ab', x, b)
  total = tracers.add_n(scaled_a, scaled_b)
  return tracers.logdet(total)
def log_joint(z, x):
  """Log joint density of a linear-Gaussian model, omitting 2*pi constants.

  Prior term: -0.5 * z.z. Likelihood term: the Gaussian log density of `x`
  with mean A z and covariance Sigma, i.e.
  -0.5 * r^T Sigma^{-1} r - 0.5 * logdet(Sigma) with r = x - A z.
  `A`, `Sigma` and `logdet` are taken from the enclosing scope.
  """
  prior_term = -1./2 * np.dot(z, z)
  residual = x - np.dot(A, z)
  # Operation order matches the original expression exactly, in case this
  # function is traced.
  quadratic = np.dot(residual, np.dot(np.linalg.inv(Sigma), residual))
  likelihood_term = -1./2 * quadratic - 1./2 * logdet(Sigma)
  return prior_term + likelihood_term