def test_reduce_moment_matching_multivariate():
    int_inputs = [('i', bint(4))]
    real_inputs = [('x', reals(2))]
    inputs = OrderedDict(int_inputs + real_inputs)
    int_inputs = OrderedDict(int_inputs)
    real_inputs = OrderedDict(real_inputs)

    loc = numeric_array([[-10., -1.], [+10., -1.], [+10., +1.], [-10., +1.]])
    precision = zeros(4, 1, 1) + ops.new_eye(loc, (2,))
    discrete = Tensor(zeros(4), int_inputs)
    gaussian = Gaussian(loc, precision, inputs)
    gaussian -= gaussian.log_normalizer
    joint = discrete + gaussian
    with interpretation(moment_matching):
        actual = joint.reduce(ops.logaddexp, 'i')
    assert_close(actual.reduce(ops.logaddexp), joint.reduce(ops.logaddexp))

    expected_loc = zeros(2)
    expected_covariance = numeric_array([[101., 0.], [0., 2.]])
    expected_precision = _inverse(expected_covariance)
    expected_gaussian = Gaussian(expected_loc, expected_precision, real_inputs)
    expected_gaussian -= expected_gaussian.log_normalizer
    expected_discrete = Tensor(ops.log(numeric_array(4.)))
    expected = expected_discrete + expected_gaussian
    assert_close(actual, expected, atol=1e-5, rtol=None)
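

# Illustration (not part of the original suite; the helper name below is
# hypothetical): where the hard-coded expected_covariance above comes from.
# The four mixture components have locs (+/-10, -/+1) and identity covariance,
# so moment matching yields covariance E[cov_k] + Cov(loc_k).
def _sketch_expected_covariance():
    import numpy as np
    locs = np.array([[-10., -1.], [+10., -1.], [+10., +1.], [-10., +1.]])
    mean = locs.mean(0)                                    # == [0., 0.]
    cov = np.eye(2) + (locs - mean).T @ (locs - mean) / 4  # E[cov] + Cov(loc)
    assert np.allclose(cov, [[101., 0.], [0., 2.]])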


def eager_dirichlet_categorical(red_op, bin_op, reduced_vars, x, y):
    dirichlet_reduction = frozenset(x.inputs).intersection(reduced_vars)
    if dirichlet_reduction:
        backend_dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
        identity = Tensor(ops.new_eye(funsor.tensor.get_default_prototype(),
                                      x.concentration.shape))
        return backend_dist.DirichletMultinomial(concentration=x.concentration,
                                                 total_count=1,
                                                 value=identity[y.value])
    else:
        return eager(Contraction, red_op, bin_op, reduced_vars, (x, y))
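

# The identity[y.value] expression above one-hot encodes the categorical value
# by indexing rows of an identity matrix; a plain numpy analogue (illustration
# only, hypothetical helper name):
def _sketch_one_hot_via_eye():
    import numpy as np
    value = np.array([2, 0, 1])
    one_hot = np.eye(3)[value]          # rows are one-hot encodings of value
    assert one_hot[0, 2] == 1.0 and one_hot.sum() == 3.0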


def random_gaussian(inputs):
    """
    Creates a random :class:`funsor.gaussian.Gaussian` with given inputs.
    """
    assert isinstance(inputs, OrderedDict)
    batch_shape = tuple(d.dtype for d in inputs.values() if d.dtype != 'real')
    event_shape = (sum(d.num_elements for d in inputs.values() if d.dtype == 'real'),)
    prec_sqrt = randn(batch_shape + event_shape + event_shape)
    precision = ops.matmul(prec_sqrt, ops.transpose(prec_sqrt, -1, -2))
    precision = precision + 0.5 * ops.new_eye(precision, event_shape[:1])
    loc = randn(batch_shape + event_shape)
    info_vec = ops.matmul(precision, ops.unsqueeze(loc, -1)).squeeze(-1)
    return Gaussian(info_vec, precision, inputs)
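

# Example usage of random_gaussian (a sketch assuming the Bint/Reals domains
# used elsewhere in this codebase):
#
#     inputs = OrderedDict([('i', Bint[3]), ('x', Reals[2])])
#     g = random_gaussian(inputs)   # batch_shape == (3,), event size == 2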


def random_mvn(batch_shape, dim, diag=False):
    """
    Generate a random :class:`torch.distributions.MultivariateNormal`
    (or :class:`numpyro.distributions.MultivariateNormal` under the jax
    backend) with given shape.
    """
    backend = get_backend()
    rank = dim + dim
    loc = randn(batch_shape + (dim,))
    cov = randn(batch_shape + (dim, rank))
    cov = cov @ ops.transpose(cov, -1, -2)
    if diag:
        cov = cov * ops.new_eye(cov, (dim,))
    if backend == "torch":
        import pyro
        return pyro.distributions.MultivariateNormal(loc, cov)
    elif backend == "jax":
        import numpyro
        return numpyro.distributions.MultivariateNormal(loc, cov)
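

# With diag=True, the elementwise product with an identity matrix zeroes the
# off-diagonal entries, leaving a diagonal covariance; a numpy analogue
# (illustration only):
#
#     import numpy as np
#     cov = np.array([[2., 0.3], [0.3, 1.]])
#     assert np.allclose(cov * np.eye(2), np.diag(np.diag(cov)))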


def test_dirichlet_categorical_conjugate(batch_shape, size):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))
    full_shape = batch_shape + (size,)
    prior = Variable("prior", Reals[size])
    concentration = Tensor(ops.exp(randn(full_shape)), inputs)
    value = random_tensor(inputs, Bint[size])
    latent = dist.Dirichlet(concentration, value=prior)
    conditional = dist.Categorical(probs=prior)
    reduced = (latent + conditional).reduce(ops.logaddexp, set(["prior"]))
    assert isinstance(reduced, Tensor)
    actual = reduced(value=value)
    expected = dist.DirichletMultinomial(concentration=concentration, total_count=1)(
        value=Tensor(ops.new_eye(concentration.data, (size,)))[value])
    # TODO: investigate why jax backend gives inconsistent results on Travis
    assert_close(actual, expected, rtol=1e-5 if get_backend() == "jax" else 1e-6)

    obs = random_tensor(inputs, Bint[size])
    _assert_conjugate_density_ok(latent, conditional, obs)
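

# Analytic background for the test above (illustration only): integrating a
# Dirichlet(alpha) prior against a Categorical likelihood gives the marginal
# p(value = v) = alpha_v / alpha.sum(-1), which is exactly what
# DirichletMultinomial(alpha, total_count=1) assigns to the one-hot of v.
# A numpy Monte Carlo sketch of the first identity (hypothetical helper name):
def _sketch_dirichlet_categorical_marginal():
    import numpy as np
    rng = np.random.default_rng(0)
    alpha = np.array([1., 2., 3.])
    probs = rng.dirichlet(alpha, size=200000)    # samples of the latent p
    empirical = probs.mean(0)                    # estimates E[p_v] = p(value = v)
    assert np.allclose(empirical, alpha / alpha.sum(), atol=1e-2)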


def _eager_subs_affine(self, subs, remaining_subs):
    # Extract an affine representation.
    affine = OrderedDict()
    for k, v in subs:
        const, coeffs = extract_affine(v)
        if (isinstance(const, Tensor)
                and all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values())):
            affine[k] = const, coeffs
        else:
            remaining_subs += (k, v),
    if not affine:
        return reflect(Subs, self, remaining_subs)

    # Align integer dimensions.
    old_int_inputs = OrderedDict((k, v) for k, v in self.inputs.items()
                                 if v.dtype != 'real')
    tensors = [Tensor(self.info_vec, old_int_inputs),
               Tensor(self.precision, old_int_inputs)]
    for const, coeffs in affine.values():
        tensors.append(const)
        tensors.extend(coeff for coeff, _ in coeffs.values())
    new_int_inputs, tensors = align_tensors(*tensors, expand=True)
    tensors = (Tensor(x, new_int_inputs) for x in tensors)
    old_info_vec = next(tensors).data
    old_precision = next(tensors).data
    for old_k, (const, coeffs) in affine.items():
        const = next(tensors)
        for new_k, (coeff, eqn) in coeffs.items():
            coeff = next(tensors)
            coeffs[new_k] = coeff, eqn
        affine[old_k] = const, coeffs
    batch_shape = old_info_vec.shape[:-1]

    # Align real dimensions.
    old_real_inputs = OrderedDict((k, v) for k, v in self.inputs.items()
                                  if v.dtype == 'real')
    new_real_inputs = old_real_inputs.copy()
    for old_k, (const, coeffs) in affine.items():
        del new_real_inputs[old_k]
        for new_k, (coeff, eqn) in coeffs.items():
            new_shape = coeff.shape[:len(eqn.split('->')[0].split(',')[1])]
            new_real_inputs[new_k] = Reals[new_shape]
    old_offsets, old_dim = _compute_offsets(old_real_inputs)
    new_offsets, new_dim = _compute_offsets(new_real_inputs)
    new_inputs = new_int_inputs.copy()
    new_inputs.update(new_real_inputs)

    # Construct a blockwise affine representation of the substitution.
    subs_vector = BlockVector(batch_shape + (old_dim,))
    subs_matrix = BlockMatrix(batch_shape + (new_dim, old_dim))
    for old_k, old_offset in old_offsets.items():
        old_size = old_real_inputs[old_k].num_elements
        old_slice = slice(old_offset, old_offset + old_size)
        if old_k in new_real_inputs:
            new_offset = new_offsets[old_k]
            new_slice = slice(new_offset, new_offset + old_size)
            subs_matrix[..., new_slice, old_slice] = \
                ops.new_eye(self.info_vec, batch_shape + (old_size,))
            continue
        const, coeffs = affine[old_k]
        old_shape = old_real_inputs[old_k].shape
        assert const.data.shape == batch_shape + old_shape
        subs_vector[..., old_slice] = const.data.reshape(batch_shape + (old_size,))
        for new_k, new_offset in new_offsets.items():
            if new_k in coeffs:
                coeff, eqn = coeffs[new_k]
                new_size = new_real_inputs[new_k].num_elements
                new_slice = slice(new_offset, new_offset + new_size)
                assert coeff.shape == new_real_inputs[new_k].shape + old_shape
                subs_matrix[..., new_slice, old_slice] = \
                    coeff.data.reshape(batch_shape + (new_size, old_size))
    subs_vector = subs_vector.as_tensor()
    subs_matrix = subs_matrix.as_tensor()
    subs_matrix_t = ops.transpose(subs_matrix, -1, -2)

    # Construct the new funsor. Suppose the old Gaussian funsor g has density
    #   g(x) = < x | i - 1/2 P x >
    # Now define a new funsor f by substituting x = A y + B:
    #   f(y) = g(A y + B)
    #        = < A y + B | i - 1/2 P (A y + B) >
    #        = < y | At (i - P B) - 1/2 At P A y > + < B | i - 1/2 P B >
    #        =: < y | i' - 1/2 P' y > + C
    # where P' = At P A and i' = At (i - P B) parametrize a new Gaussian
    # and C = < B | i - 1/2 P B > parametrizes a new Tensor.
    precision = subs_matrix @ old_precision @ subs_matrix_t
    info_vec = _mv(subs_matrix, old_info_vec - _mv(old_precision, subs_vector))
    const = _vv(subs_vector, old_info_vec - 0.5 * _mv(old_precision, subs_vector))
    result = Gaussian(info_vec, precision, new_inputs) + Tensor(const, new_int_inputs)
    return Subs(result, remaining_subs) if remaining_subs else result
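

# A numpy check of the identity derived in the comment above (illustration
# only; hypothetical helper name): for g(x) = <x, i> - 1/2 <x, P x> and
# x = A y + B, we have g(A y + B) = <y, i'> - 1/2 <y, P' y> + C with
# P' = At P A, i' = At (i - P B), and C = <B, i - 1/2 P B>.
def _sketch_affine_subs_identity():
    import numpy as np
    rng = np.random.default_rng(0)
    P = rng.normal(size=(3, 3))
    P = P @ P.T                                  # symmetric positive definite
    i, B = rng.normal(size=3), rng.normal(size=3)
    A, y = rng.normal(size=(3, 2)), rng.normal(size=2)
    g = lambda x: x @ i - 0.5 * x @ P @ x
    P2, i2 = A.T @ P @ A, A.T @ (i - P @ B)
    C = B @ (i - 0.5 * P @ B)
    assert np.isclose(g(A @ y + B), y @ i2 - 0.5 * y @ P2 @ y + C)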


def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gaussian):
    approx_vars = frozenset(k for k in reduced_vars
                            if k in gaussian.inputs
                            and gaussian.inputs[k].dtype != 'real')
    exact_vars = reduced_vars - approx_vars

    if exact_vars and approx_vars:
        return Contraction(red_op, bin_op, exact_vars, discrete,
                           gaussian).reduce(red_op, approx_vars)

    if approx_vars and not exact_vars:
        discrete += gaussian.log_normalizer
        new_discrete = discrete.reduce(ops.logaddexp,
                                       approx_vars.intersection(discrete.inputs))
        num_elements = reduce(ops.mul, [
            gaussian.inputs[k].num_elements
            for k in approx_vars.difference(discrete.inputs)
        ], 1)
        if num_elements != 1:
            new_discrete -= math.log(num_elements)

        int_inputs = OrderedDict((k, d) for k, d in gaussian.inputs.items()
                                 if d.dtype != 'real')
        probs = (discrete - new_discrete.clamp_finite()).exp()

        old_loc = Tensor(
            ops.cholesky_solve(ops.unsqueeze(gaussian.info_vec, -1),
                               gaussian._precision_chol).squeeze(-1),
            int_inputs)
        new_loc = (probs * old_loc).reduce(ops.add, approx_vars)
        old_cov = Tensor(ops.cholesky_inverse(gaussian._precision_chol), int_inputs)
        diff = old_loc - new_loc
        outers = Tensor(ops.unsqueeze(diff.data, -1) * ops.unsqueeze(diff.data, -2),
                        diff.inputs)
        new_cov = ((probs * old_cov).reduce(ops.add, approx_vars) +
                   (probs * outers).reduce(ops.add, approx_vars))

        # Numerically stabilize by adding bogus precision to empty components.
        total = probs.reduce(ops.add, approx_vars)
        mask = ops.unsqueeze(ops.unsqueeze((total.data == 0), -1), -1)
        new_cov.data = new_cov.data + mask * ops.new_eye(new_cov.data,
                                                         new_cov.data.shape[-1:])

        new_precision = Tensor(ops.cholesky_inverse(ops.cholesky(new_cov.data)),
                               new_cov.inputs)
        new_info_vec = (new_precision.data @ ops.unsqueeze(new_loc.data, -1)).squeeze(-1)
        new_inputs = new_loc.inputs.copy()
        new_inputs.update((k, d) for k, d in gaussian.inputs.items()
                          if d.dtype == 'real')
        new_gaussian = Gaussian(new_info_vec, new_precision.data, new_inputs)
        new_discrete -= new_gaussian.log_normalizer
        return new_discrete + new_gaussian

    return None
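

# The approx_vars branch above is standard Gaussian-mixture moment matching:
# collapse sum_k p_k N(loc_k, cov_k) to a single Gaussian with the mixture's
# mean and covariance. A numpy sketch of the same update (illustration only;
# hypothetical helper name):
def _sketch_mixture_moment_matching():
    import numpy as np
    p = np.array([0.25, 0.75])                   # mixture weights, sum to 1
    locs = np.array([[0., 0.], [2., 1.]])
    covs = np.stack([np.eye(2), 2 * np.eye(2)])
    new_loc = (p[:, None] * locs).sum(0)         # matched first moment
    diffs = locs - new_loc
    outers = diffs[:, :, None] * diffs[:, None, :]
    new_cov = (p[:, None, None] * (covs + outers)).sum(0)  # matched second moment
    assert np.allclose(new_cov, new_cov.T)       # still a valid covariance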