def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gaussian):
    """Contract ``discrete`` with ``gaussian`` using moment matching.

    Reduced variables that are non-real inputs of ``gaussian`` are
    eliminated approximately: the weighted collection of Gaussian
    components is replaced by a single Gaussian whose mean and covariance
    equal the probability-weighted mean and covariance of the components.

    :param red_op: Reduction op, forwarded to ``Contraction`` / ``reduce``
        when some variables must be reduced exactly.
    :param bin_op: Binary op, forwarded to ``Contraction``.
    :param frozenset reduced_vars: Names of the variables being reduced.
    :param discrete: Term carrying the log-weights of the components.
    :param gaussian: Gaussian term whose non-real inputs index components.
    :return: ``new_discrete + new_gaussian`` when every reduced variable is
        approximated; a ``Contraction`` reduced over the approximated
        variables when both exact and approximated variables are present;
        ``None`` when there is nothing to approximate (the caller is
        presumably expected to fall back to another implementation).
    """
    approx_vars = frozenset(
        k for k in reduced_vars
        if k in gaussian.inputs and gaussian.inputs[k].dtype != 'real')
    exact_vars = reduced_vars - approx_vars

    if exact_vars and approx_vars:
        # Reduce the exact variables first, then approximate the rest.
        return Contraction(
            red_op, bin_op, exact_vars, discrete, gaussian,
        ).reduce(red_op, approx_vars)

    if not approx_vars:
        # Equivalent to the original ``approx_vars and not exact_vars``
        # test failing: the branch above already handled the mixed case.
        return None

    discrete += gaussian.log_normalizer
    new_discrete = discrete.reduce(
        ops.logaddexp, approx_vars.intersection(discrete.inputs))
    # Approximated variables that ``discrete`` does not depend on contribute
    # the log of the product of their sizes, which is subtracted here.
    num_elements = reduce(ops.mul, [
        gaussian.inputs[k].num_elements
        for k in approx_vars.difference(discrete.inputs)
    ], 1)
    if num_elements != 1:
        new_discrete -= math.log(num_elements)

    int_inputs = OrderedDict(
        (k, d) for k, d in gaussian.inputs.items() if d.dtype != 'real')
    probs = (discrete - new_discrete.clamp_finite()).exp()

    # Per-component mean: solve ``precision @ loc = info_vec`` using the
    # Cholesky factor of the precision.
    old_loc = Tensor(
        gaussian.info_vec.unsqueeze(-1).cholesky_solve(
            gaussian._precision_chol).squeeze(-1),
        int_inputs)
    new_loc = (probs * old_loc).reduce(ops.add, approx_vars)
    old_cov = Tensor(cholesky_inverse(gaussian._precision_chol), int_inputs)
    diff = old_loc - new_loc
    outers = Tensor(
        diff.data.unsqueeze(-1) * diff.data.unsqueeze(-2), diff.inputs)
    # Matched covariance = weighted within-component covariance plus the
    # weighted outer products of the mean offsets.
    new_cov = ((probs * old_cov).reduce(ops.add, approx_vars) +
               (probs * outers).reduce(ops.add, approx_vars))

    # Numerically stabilize by adding bogus precision to empty components.
    total = probs.reduce(ops.add, approx_vars)
    mask = (total.data == 0).to(
        total.data.dtype).unsqueeze(-1).unsqueeze(-1)
    # Build the identity with the covariance's dtype/device so the in-place
    # add also works for non-default dtypes and non-CPU tensors.
    eye = torch.eye(
        new_cov.data.size(-1),
        dtype=new_cov.data.dtype,
        device=new_cov.data.device)
    new_cov.data += mask * eye

    new_precision = Tensor(
        cholesky_inverse(cholesky(new_cov.data)), new_cov.inputs)
    new_info_vec = new_precision.data.matmul(
        new_loc.data.unsqueeze(-1)).squeeze(-1)
    new_inputs = new_loc.inputs.copy()
    new_inputs.update(
        (k, d) for k, d in gaussian.inputs.items() if d.dtype == 'real')
    new_gaussian = Gaussian(new_info_vec, new_precision.data, new_inputs)
    new_discrete -= new_gaussian.log_normalizer

    return new_discrete + new_gaussian
def test_cholesky_inverse(batch_shape, size, requires_grad):
    """Compare ``cholesky_inverse`` with ``naive_cholesky_inverse``.

    A random batch of Gram matrices ``x^T x`` is factorized, and both
    implementations are evaluated on the resulting Cholesky factor.  When
    ``requires_grad`` is set, a backward pass through ``cholesky_inverse``
    is run as well to check that it is differentiable.

    :param tuple batch_shape: Leading batch dimensions of the random input.
    :param int size: Side length of each square matrix.
    :param bool requires_grad: Whether to also exercise the backward pass.
    """
    noise = torch.randn(batch_shape + (size, size))
    gram = noise.transpose(-1, -2).matmul(noise)
    factor = gram.cholesky()
    if requires_grad:
        factor.requires_grad_()

    actual = cholesky_inverse(factor)
    expected = naive_cholesky_inverse(factor)
    assert_close(actual, expected)

    if not requires_grad:
        return
    # Gradient smoke test: only checks that backward runs without error.
    cholesky_inverse(factor).sum().backward()
def eager_integrate(log_measure, integrand, reduced_vars):
    """Eagerly integrate ``integrand`` against ``log_measure``.

    Only the case where every real input of ``log_measure`` is reduced, and
    ``integrand`` has no real inputs outside the reduced ones, is handled.

    :param log_measure: Term exposing ``.inputs``; passed to
        ``align_gaussian`` (presumably a Gaussian -- confirm against the
        registration of this function).
    :param integrand: Term exposing ``.inputs``; passed to
        ``align_gaussian`` (presumably a Gaussian -- confirm likewise).
    :param frozenset reduced_vars: Names of the variables to integrate or
        sum out.  Each must be an input of ``log_measure``.
    :return: The integrated result with any remaining reduced variables
        summed out, or ``None`` if no reduced variable is real-valued.
    :raises NotImplementedError: If only part of the real inputs would be
        integrated out.
    """
    real_vars = frozenset(
        name for name in reduced_vars
        if log_measure.inputs[name].dtype == 'real')
    if not real_vars:
        return None  # defer to default implementation

    measure_reals = frozenset(
        name for name, domain in log_measure.inputs.items()
        if domain.dtype == 'real')
    integrand_reals = frozenset(
        name for name, domain in integrand.inputs.items()
        if domain.dtype == 'real')
    if not (measure_reals == real_vars and integrand_reals <= real_vars):
        raise NotImplementedError('TODO implement partial integration')

    # Union of both terms' inputs, measure first.
    joint_inputs = OrderedDict()
    for term in (log_measure, integrand):
        joint_inputs.update(term.inputs)

    lhs_info_vec, lhs_precision = align_gaussian(joint_inputs, log_measure)
    rhs_info_vec, rhs_precision = align_gaussian(joint_inputs, integrand)
    lhs = Gaussian(lhs_info_vec, lhs_precision, joint_inputs)

    # Compute the expectation of a non-normalized quadratic form.
    # See "The Matrix Cookbook" (November 15, 2012) ss. 8.2.2 eq. 380.
    # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
    norm = lhs.log_normalizer.data.exp()
    lhs_cov = cholesky_inverse(lhs._precision_chol)
    lhs_loc = lhs.info_vec.unsqueeze(-1).cholesky_solve(
        lhs._precision_chol).squeeze(-1)
    vmv_term = _vv(
        lhs_loc, rhs_info_vec - 0.5 * _mv(rhs_precision, lhs_loc))
    data = norm * (vmv_term - 0.5 * _trace_mm(rhs_precision, lhs_cov))

    kept_inputs = OrderedDict(
        (name, domain) for name, domain in joint_inputs.items()
        if name not in reduced_vars)
    # Non-real reduced variables are summed out of the resulting tensor.
    return Tensor(data, kept_inputs).reduce(ops.add, reduced_vars - real_vars)
def eager_mvn(loc, scale_tril, value):
    """Eagerly build the log-density of a multivariate normal as a Gaussian.

    A zero-mean Gaussian over a fresh variable is created from
    ``scale_tril`` and then evaluated at ``value - loc``.

    :param loc: Location term; must have a rank-1 ``shape``.
    :param scale_tril: Scale factor term; must have a rank-2 ``shape`` and
        expose ``.data`` and ``.inputs``.
    :param value: Term at which to evaluate; its ``output`` must equal
        ``loc.output``.
    :return: The Gaussian log-density with ``value - loc`` substituted for
        the fresh variable, or ``None`` when ``loc`` or ``value`` is not
        affine (leaving the expression lazy).
    """
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not (is_affine(loc) and is_affine(value)):
        return None  # lazy

    tril = scale_tril.data
    # Zero information vector: the Gaussian is centered at the origin and
    # the shift by ``loc`` happens through the substitution at the end.
    info_vec = tril.new_zeros(tril.shape[:-1])
    precision = cholesky_inverse(tril)

    scale_diag = Tensor(tril.diagonal(dim1=-1, dim2=-2), scale_tril.inputs)
    # Constant term: -(d/2) * log(2*pi) minus the sum of log-diagonal
    # entries of the scale factor.
    log_prob = (-0.5 * scale_diag.shape[0] * math.log(2 * math.pi)
                - scale_diag.log().sum())

    gaussian_inputs = scale_tril.inputs.copy()
    fresh_name = gensym('value')
    gaussian_inputs[fresh_name] = reals(scale_diag.shape[0])
    density = log_prob + Gaussian(info_vec, precision, gaussian_inputs)
    return density(**{fresh_name: value - loc})