def test_reduce_moment_matching_univariate():
    """Moment matching a two-component 1-D mixture over ``i`` must preserve
    total mass and reproduce the mixture's mean and variance."""
    int_spec = [('i', bint(2))]
    real_spec = [('x', reals())]
    inputs = OrderedDict(int_spec + real_spec)
    int_inputs = OrderedDict(int_spec)
    real_inputs = OrderedDict(real_spec)

    p, t = 0.8, 1.234
    s1, s2, s3 = 2.0, 3.0, 4.0

    # Component 0: mean -s1, std s2 (weight 1-p); component 1: mean +s1, std s3 (weight p).
    loc = torch.tensor([[-s1], [s1]])
    precision = torch.tensor([[[s2 ** -2]], [[s3 ** -2]]])
    info_vec = precision.matmul(loc.unsqueeze(-1)).squeeze(-1)
    discrete = Tensor(torch.tensor([1 - p, p]).log() + t, int_inputs)
    gaussian = Gaussian(info_vec, precision, inputs)
    gaussian -= gaussian.log_normalizer
    joint = discrete + gaussian

    with interpretation(moment_matching):
        actual = joint.reduce(ops.logaddexp, 'i')
    assert_close(actual.reduce(ops.logaddexp), joint.reduce(ops.logaddexp))

    # Mixture mean, and variance = spread of means + weighted within-component variance.
    mixture_mean = torch.tensor([(2 * p - 1) * s1])
    mixture_variance = (4 * p * (1 - p) * s1 ** 2
                        + (1 - p) * s2 ** 2
                        + p * s3 ** 2)
    mixture_precision = torch.tensor([[1 / mixture_variance]])
    mixture_info_vec = mixture_precision.matmul(
        mixture_mean.unsqueeze(-1)).squeeze(-1)
    expected_gaussian = Gaussian(mixture_info_vec, mixture_precision, real_inputs)
    expected_gaussian -= expected_gaussian.log_normalizer
    expected = Tensor(torch.tensor(t)) + expected_gaussian
    assert_close(actual, expected, atol=1e-5, rtol=None)
def test_reduce_moment_matching_multivariate():
    """Moment matching four unit-precision 2-D components placed at the
    corners of a rectangle yields one zero-mean Gaussian with the mixture
    covariance, while preserving total mass."""
    int_pairs = [('i', Bint[4])]
    real_pairs = [('x', Reals[2])]
    inputs = OrderedDict(int_pairs + real_pairs)
    int_inputs = OrderedDict(int_pairs)
    real_inputs = OrderedDict(real_pairs)

    corners = numeric_array([[-10., -1.],
                             [+10., -1.],
                             [+10., +1.],
                             [-10., +1.]])
    # Identity precision for every component, so the info vector equals the mean.
    precision = zeros(4, 1, 1) + ops.new_eye(corners, (2, ))
    discrete = Tensor(zeros(4), int_inputs)
    gaussian = Gaussian(corners, precision, inputs)
    gaussian -= gaussian.log_normalizer
    joint = discrete + gaussian

    with interpretation(moment_matching):
        actual = joint.reduce(ops.logaddexp, 'i')
    assert_close(actual.reduce(ops.logaddexp), joint.reduce(ops.logaddexp))

    # Spread of the corners (100, 1) plus unit within-component variance.
    expected_mean = zeros(2)
    expected_covariance = numeric_array([[101., 0.], [0., 2.]])
    expected_gaussian = Gaussian(expected_mean, _inverse(expected_covariance), real_inputs)
    expected_gaussian -= expected_gaussian.log_normalizer
    expected = Tensor(ops.log(numeric_array(4.))) + expected_gaussian
    assert_close(actual, expected, atol=1e-5, rtol=None)
def test_smoke(expr, expected_type):
    """Evaluate ``expr`` against a fixed set of funsors and check the type.

    ``expr`` is run through :func:`eval`, so it may refer to the local names
    ``g1``, ``g2``, ``shift``, ``i0``, ``x0`` and ``y0``; those names must not
    be renamed.
    """
    prec3 = [[1.0, 0.1, 0.2], [0.1, 1.0, 0.3], [0.2, 0.3, 1.0]]
    g1 = Gaussian(
        info_vec=numeric_array([[0.0, 0.1, 0.2], [2.0, 3.0, 4.0]]),
        precision=numeric_array([prec3, prec3]),
        inputs=OrderedDict([('i', bint(2)), ('x', reals(3))]))

    prec2 = [[1.0, 0.2], [0.2, 1.0]]
    g2 = Gaussian(
        info_vec=numeric_array([[0.0, 0.1], [2.0, 3.0]]),
        precision=numeric_array([prec2, prec2]),
        inputs=OrderedDict([('i', bint(2)), ('y', reals(2))]))

    shift = Tensor(numeric_array([-1., 1.]), OrderedDict([('i', bint(2))]))
    i0 = Number(1, 2)
    x0 = Tensor(numeric_array([0.5, 0.6, 0.7]))
    y0 = Tensor(numeric_array([[0.2, 0.3], [0.8, 0.9]]),
                inputs=OrderedDict([('i', bint(2))]))

    for obj, kind in ((g1, Gaussian), (g2, Gaussian), (shift, Tensor),
                      (i0, Number), (x0, Tensor), (y0, Tensor)):
        assert isinstance(obj, kind)

    result = eval(expr)
    assert isinstance(result, expected_type)
def test_reduce_moment_matching_multivariate():
    """Moment matching four unit-precision 2-D components at rectangle
    corners gives a zero-mean Gaussian with covariance diag(101, 2)."""
    int_pairs = [('i', bint(4))]
    real_pairs = [('x', reals(2))]
    inputs = OrderedDict(int_pairs + real_pairs)
    int_inputs = OrderedDict(int_pairs)
    real_inputs = OrderedDict(real_pairs)

    corners = torch.tensor([[-10., -1.],
                            [+10., -1.],
                            [+10., +1.],
                            [-10., +1.]])
    # Identity precision, so passing the means as the first argument is exact.
    precision = torch.zeros(4, 1, 1) + torch.eye(2, 2)
    discrete = Tensor(torch.zeros(4), int_inputs)
    gaussian = Gaussian(corners, precision, inputs)
    gaussian -= gaussian.log_normalizer
    joint = discrete + gaussian

    with interpretation(moment_matching):
        actual = joint.reduce(ops.logaddexp, 'i')
    assert_close(actual.reduce(ops.logaddexp), joint.reduce(ops.logaddexp))

    expected_mean = torch.zeros(2)
    expected_covariance = torch.tensor([[101., 0.], [0., 2.]])
    expected_gaussian = Gaussian(expected_mean, expected_covariance.inverse(), real_inputs)
    expected_gaussian -= expected_gaussian.log_normalizer
    expected = Tensor(torch.tensor(4.).log()) + expected_gaussian
    assert_close(actual, expected, atol=1e-5, rtol=None)
def test_mc_plate_gaussian():
    """Monte Carlo integration of a plated Gaussian integrand, multiplied
    over the ``data`` plate, must not produce infinities."""
    log_measure = Gaussian(torch.tensor([0.]), torch.tensor([[1.]]),
                           (('loc', reals()),)) + torch.tensor(-0.9189)
    integrand = Gaussian(torch.randn((100, 1)) + 3., torch.ones((100, 1, 1)),
                         (('data', bint(100)), ('loc', reals())))

    loc_vars = frozenset({'loc'})
    sampled_measure = log_measure.sample(loc_vars)
    integral = Integrate(sampled_measure, integrand, loc_vars)
    res = integral.reduce(ops.mul, frozenset({'data'}))
    assert not torch.isinf(res).any()
def test_mc_plate_gaussian():
    """Backend-generic check that Monte Carlo integration over a plated
    Gaussian stays finite after multiplying across ``data``."""
    log_measure = Gaussian(numeric_array([0.]), numeric_array([[1.]]),
                           (('loc', Real),)) + numeric_array(-0.9189)
    integrand = Gaussian(randn((100, 1)) + 3., ones((100, 1, 1)),
                         (('data', Bint[100]), ('loc', Real)))

    # The jax backend needs an explicit PRNG key; other backends take None.
    if get_backend() == 'jax':
        rng_key = np.array([0, 0], dtype=np.uint32)
    else:
        rng_key = None
    sampled_measure = log_measure.sample('loc', rng_key=rng_key)
    res = Integrate(sampled_measure, integrand, 'loc').reduce(ops.mul, 'data')
    is_infinite = (res == float('inf')) | (res == float('-inf'))
    assert not is_infinite.any()
def mvn_to_funsor(pyro_dist, event_inputs=(), real_inputs=None):
    """
    Convert a joint :class:`torch.distributions.MultivariateNormal`
    distribution into a :class:`~funsor.terms.Funsor` with multiple real
    inputs.

    This should satisfy::

        sum(d.num_elements for d in real_inputs.values())
          == pyro_dist.event_shape[0]

    :param torch.distributions.MultivariateNormal pyro_dist: A
        multivariate normal distribution over one or more variables of real
        or vector or tensor type.
    :param tuple event_inputs: A tuple of names for rightmost dimensions.
        These will be assigned to ``result.inputs`` of type ``Bint``.
    :param OrderedDict real_inputs: A dict mapping real variable name to
        appropriately sized ``Real``. The sum of all ``.numel()`` of all real
        inputs should be equal to the ``pyro_dist`` dimension. Defaults to
        ``None``, meaning no relabelling: the funsor returned by
        ``to_funsor`` is returned as is.
    :return: A funsor with given ``real_inputs`` and possibly additional
        Bint inputs.
    :rtype: funsor.terms.Funsor
    """
    # Build a fresh dict per call instead of sharing one mutable default.
    if real_inputs is None:
        real_inputs = OrderedDict()
    assert isinstance(pyro_dist, torch.distributions.MultivariateNormal)
    assert isinstance(event_inputs, tuple)
    assert isinstance(real_inputs, OrderedDict)
    dim_to_name = default_dim_to_name(pyro_dist.batch_shape, event_inputs)
    funsor_dist = to_funsor(pyro_dist, Real, dim_to_name)
    if not real_inputs:
        return funsor_dist

    # Bind the distribution's value to a variable named "value", split the
    # result into its discrete and Gaussian terms, and rebuild the Gaussian
    # with the requested real inputs in place of its original real input(s).
    discrete, gaussian = funsor_dist(value="value").terms
    inputs = OrderedDict(
        (k, v) for k, v in gaussian.inputs.items() if v.dtype != 'real')
    inputs.update(real_inputs)
    return discrete + Gaussian(gaussian.info_vec, gaussian.precision, inputs)
def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete, gaussian):
    """
    Moment-matching reduction of ``discrete + gaussian``.

    Discrete inputs of ``gaussian`` listed in ``reduced_vars`` are eliminated
    approximately: the mixture of Gaussians is replaced by a single Gaussian
    with the mixture's mean and covariance, and the discrete part keeps the
    total log-mass. Any other reduced variables are first eliminated exactly
    via ``Contraction``.

    :param red_op: Reduction op. Only ``ops.logaddexp`` is approximated;
        anything else defers to the default implementation.
    :param bin_op: Binary op joining the terms (passed to ``Contraction``
        for the exact part).
    :param frozenset reduced_vars: Names of variables to eliminate.
    :param discrete: Discrete log-weights.
    :param gaussian: Gaussian term in information (``info_vec`` /
        ``precision``) form.
    :return: The approximate result, or ``None`` to defer.
    """
    # The approximation below hard-codes logaddexp semantics; for any other
    # reduction it would silently compute the wrong quantity.
    if red_op is not ops.logaddexp:
        return None  # defer to default implementation

    approx_vars = frozenset(
        k for k in reduced_vars
        if k in gaussian.inputs and gaussian.inputs[k].dtype != 'real')
    exact_vars = reduced_vars - approx_vars

    if exact_vars and approx_vars:
        # Eliminate exact variables first, then recurse on the remaining ones.
        return Contraction(red_op, bin_op, exact_vars, discrete, gaussian).reduce(
            red_op, approx_vars)

    if approx_vars and not exact_vars:
        # Weight each component by its total mass, i.e. discrete weight plus
        # the Gaussian's log normalizer.
        discrete += gaussian.log_normalizer
        new_discrete = discrete.reduce(
            ops.logaddexp, approx_vars.intersection(discrete.inputs))
        num_elements = reduce(ops.mul, [
            gaussian.inputs[k].num_elements
            for k in approx_vars.difference(discrete.inputs)], 1)
        if num_elements != 1:
            new_discrete -= math.log(num_elements)

        int_inputs = OrderedDict(
            (k, d) for k, d in gaussian.inputs.items() if d.dtype != 'real')
        # Normalized mixture weights. clamp_finite presumably guards against
        # non-finite totals producing NaN here -- confirm.
        probs = (discrete - new_discrete.clamp_finite()).exp()

        # Per-component mean (precision^-1 @ info_vec) and covariance.
        old_loc = Tensor(
            gaussian.info_vec.unsqueeze(-1).cholesky_solve(
                gaussian._precision_chol).squeeze(-1),
            int_inputs)
        new_loc = (probs * old_loc).reduce(ops.add, approx_vars)
        old_cov = Tensor(cholesky_inverse(gaussian._precision_chol), int_inputs)
        # Law of total covariance: mean covariance plus spread of the means.
        diff = old_loc - new_loc
        outers = Tensor(
            diff.data.unsqueeze(-1) * diff.data.unsqueeze(-2), diff.inputs)
        new_cov = ((probs * old_cov).reduce(ops.add, approx_vars) +
                   (probs * outers).reduce(ops.add, approx_vars))

        # Numerically stabilize by adding bogus precision to empty components.
        # Build a new array rather than mutating new_cov.data in place, and
        # make the identity on the covariance's dtype/device.
        total = probs.reduce(ops.add, approx_vars)
        cov_data = new_cov.data
        mask = (total.data == 0).to(
            total.data.dtype).unsqueeze(-1).unsqueeze(-1)
        eye = torch.eye(cov_data.size(-1), dtype=cov_data.dtype,
                        device=cov_data.device)
        cov_data = cov_data + mask * eye

        new_precision = cholesky_inverse(cholesky(cov_data))
        new_info_vec = new_precision.matmul(
            new_loc.data.unsqueeze(-1)).squeeze(-1)
        new_inputs = new_loc.inputs.copy()
        new_inputs.update(
            (k, d) for k, d in gaussian.inputs.items() if d.dtype == 'real')
        new_gaussian = Gaussian(new_info_vec, new_precision, new_inputs)
        # Keep the total mass: subtract the new Gaussian's own normalizer.
        new_discrete -= new_gaussian.log_normalizer
        return new_discrete + new_gaussian

    return None
def test_smoke(expr, expected_type):
    """Evaluate ``expr`` against Delta/Tensor/Gaussian fixtures and check the type.

    ``expr`` is run through :func:`eval`, so it may refer to the local names
    ``dx``, ``dy``, ``t``, ``g``, ``i0`` and ``x0``; those names must not be
    renamed. The random draws happen in the order dx, dy, t.
    """
    dx = Delta('x', Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2))])))
    dy = Delta('y', Tensor(torch.randn(3, 4), OrderedDict([('j', bint(3))])))
    t = Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2)), ('j', bint(3))]))

    prec3 = [[1.0, 0.1, 0.2], [0.1, 1.0, 0.3], [0.2, 0.3, 1.0]]
    g = Gaussian(
        info_vec=torch.tensor([[0.0, 0.1, 0.2], [2.0, 3.0, 4.0]]),
        precision=torch.tensor([prec3, prec3]),
        inputs=OrderedDict([('i', bint(2)), ('x', reals(3))]))
    i0 = Number(1, 2)
    x0 = Tensor(torch.tensor([0.5, 0.6, 0.7]))

    for obj, kind in ((dx, Delta), (dy, Delta), (t, Tensor),
                      (g, Gaussian), (i0, Number), (x0, Tensor)):
        assert isinstance(obj, kind)

    result = eval(expr)
    assert isinstance(result, expected_type)
def eager_integrate(log_measure, integrand, reduced_vars):
    """
    Eagerly integrate a Gaussian ``integrand`` against a Gaussian
    ``log_measure``.

    Only the case where ``reduced_vars`` covers every real input of
    ``log_measure``, and those include every real input of ``integrand``, is
    handled. Any discrete variables in ``reduced_vars`` are then summed out.

    :param log_measure: Gaussian log-measure to integrate against.
    :param integrand: Gaussian whose (quadratic-form) log-density is
        integrated.
    :param frozenset reduced_vars: Names of the variables to integrate out.
    :return: A Tensor, or ``None`` to defer to the default implementation
        when no real input of ``log_measure`` is reduced.
    :raises NotImplementedError: For partial integration over only some of
        the real inputs.
    """
    # Guard the lookup: reduced_vars may name variables log_measure lacks.
    real_vars = frozenset(
        k for k in reduced_vars
        if k in log_measure.inputs and log_measure.inputs[k].dtype == 'real')
    if not real_vars:
        return None  # defer to default implementation

    lhs_reals = frozenset(k for k, d in log_measure.inputs.items() if d.dtype == 'real')
    rhs_reals = frozenset(k for k, d in integrand.inputs.items() if d.dtype == 'real')
    if not (lhs_reals == real_vars and rhs_reals <= real_vars):
        raise NotImplementedError('TODO implement partial integration')

    inputs = OrderedDict((k, d) for t in (log_measure, integrand)
                         for k, d in t.inputs.items())
    lhs_info_vec, lhs_precision = align_gaussian(inputs, log_measure)
    rhs_info_vec, rhs_precision = align_gaussian(inputs, integrand)
    lhs = Gaussian(lhs_info_vec, lhs_precision, inputs)

    # Compute the expectation of a non-normalized quadratic form.
    # See "The Matrix Cookbook" (November 15, 2012) ss. 8.2.2 eq. 380.
    # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
    norm = lhs.log_normalizer.data.exp()
    lhs_cov = cholesky_inverse(lhs._precision_chol)
    lhs_loc = lhs.info_vec.unsqueeze(-1).cholesky_solve(
        lhs._precision_chol).squeeze(-1)
    vmv_term = _vv(lhs_loc, rhs_info_vec - 0.5 * _mv(rhs_precision, lhs_loc))
    data = norm * (vmv_term - 0.5 * _trace_mm(rhs_precision, lhs_cov))

    # ``data`` is batched over every discrete input, reduced or not. So drop
    # only the (integrated) real inputs here and let .reduce() sum out the
    # discrete ones. Dropping all of reduced_vars would mislabel those dims
    # and turn the reduce below into a no-op.
    int_inputs = OrderedDict(
        (k, d) for k, d in inputs.items() if k not in real_vars)
    result = Tensor(data, int_inputs)
    return result.reduce(ops.add, reduced_vars - real_vars)
def torchmvn_to_funsor(pyro_dist, output=None, dim_to_name=None, real_inputs=None):
    """
    Convert a multivariate normal distribution to a funsor, optionally
    relabelling its real input(s) as the given named real inputs.

    NOTE(review): judging by the name, ``pyro_dist`` is a
    ``torch.distributions.MultivariateNormal`` -- confirm against the
    dispatch registration.

    :param pyro_dist: The distribution to convert.
    :param output: Forwarded to ``torchdistribution_to_funsor``.
    :param dim_to_name: Forwarded to ``torchdistribution_to_funsor``.
    :param OrderedDict real_inputs: Optional mapping from real variable name
        to domain. When ``None`` (the default) or empty, the result of
        ``torchdistribution_to_funsor`` is returned unchanged.
    :return: A funsor.
    """
    funsor_dist = torchdistribution_to_funsor(pyro_dist, output=output,
                                              dim_to_name=dim_to_name)
    # ``None`` replaces a shared mutable OrderedDict() default.
    if not real_inputs:
        return funsor_dist

    # Bind the value to a variable named "value", split off the discrete and
    # Gaussian terms, and rebuild the Gaussian with the requested real inputs.
    discrete, gaussian = funsor_dist(value="value").terms
    inputs = OrderedDict(
        (k, v) for k, v in gaussian.inputs.items() if v.dtype != 'real')
    inputs.update(real_inputs)
    return discrete + Gaussian(gaussian.info_vec, gaussian.precision, inputs)
def adjoint_subs_gaussianmixture_discrete(adj_redop, adj_binop, out_adj, arg, subs):
    """
    Propagate an adjoint back through ``Subs(arg, subs)``.

    NOTE(review): judging by the name and the use of ``arg.terms[0]`` /
    ``arg.terms[1]``, ``arg`` is presumably a GaussianMixture (discrete term,
    Gaussian term), and ``out_adj`` is a discrete funsor -- confirm against
    the dispatch registration.

    Only renamings (``Variable`` values) and substitutions into non-real
    inputs are supported.

    :param adj_redop: Adjoint reduction op, forwarded to ``adjoint_ops``.
    :param adj_binop: Adjoint binary op. ``ops.UNITS[adj_binop]`` (presumably
        its identity element) fills the placeholder tensor below.
    :param out_adj: Adjoint of the substitution's output.
    :param arg: The funsor that was substituted into.
    :param subs: Tuple of ``(name, value)`` substitution pairs.
    :return: Dict mapping ``arg`` to its input adjoint (Gaussian + discrete).
    :raises NotImplementedError: If a non-Variable value is substituted into
        a real input.
    """
    if any(v.dtype == 'real' and not isinstance(v, Variable) for k, v in subs):
        raise NotImplementedError(
            "TODO implement adjoint for substitution into Gaussian real variable"
        )
    # invert renaming
    renames = tuple((v.name, k) for k, v in subs if isinstance(v, Variable))
    out_adj = Subs(out_adj, renames)
    # inverting advanced indexing
    slices = tuple((k, v) for k, v in subs if not isinstance(v, Variable))
    arg_int_inputs = OrderedDict(
        (k, v) for k, v in arg.inputs.items() if v.dtype != 'real')
    # Placeholder shaped like the Gaussian term's batch (info_vec minus its
    # last dim), filled with ops.UNITS[adj_binop], then sliced like the output.
    zeros_like_out = Subs(
        Tensor(
            arg.terms[1].info_vec.new_full(arg.terms[1].info_vec.shape[:-1],
                                           ops.UNITS[adj_binop]),
            arg_int_inputs),
        slices)
    # Presumably broadcasts out_adj to the sliced batch shape without changing
    # its values, assuming UNITS holds adj_binop's identity -- confirm.
    out_adj = adj_binop(out_adj, zeros_like_out)
    in_adj_discrete = adjoint_ops(Subs, adj_redop, adj_binop, out_adj,
                                  arg.terms[0], subs)[arg.terms[0]]
    # invert the slicing for the Gaussian term even though the message does not affect the values
    in_adj_info_vec = list(
        adjoint_ops(
            Subs, adj_redop, adj_binop,  # ops.add, ops.mul,
            zeros_like_out,
            Tensor(arg.terms[1].info_vec, arg_int_inputs),
            slices).values())[0]
    in_adj_precision = list(
        adjoint_ops(
            Subs, adj_redop, adj_binop,  # ops.add, ops.mul,
            zeros_like_out,
            Tensor(arg.terms[1].precision, arg_int_inputs),
            slices).values())[0]
    assert isinstance(in_adj_info_vec, Tensor)
    assert isinstance(in_adj_precision, Tensor)
    # Reassemble a Gaussian over all of arg's inputs from the un-sliced arrays.
    in_adj_gaussian = Gaussian(in_adj_info_vec.data, in_adj_precision.data,
                               arg.inputs.copy())
    in_adj = in_adj_gaussian + in_adj_discrete
    return {arg: in_adj}
def moment_matching_reduce(self, op, reduced_vars):
    """
    Reduce ``self`` under the moment-matching interpretation.

    For ``op is ops.logaddexp``, real variables and discrete variables not in
    the Gaussian are reduced exactly via ``eager_reduce``. Discrete variables
    of the Gaussian are then eliminated approximately, by replacing the
    mixture with a single Gaussian that has the mixture's mean and
    covariance. Other ops return ``None`` (defer to default).

    NOTE(review): the mixture weights below are ``exp(self.discrete)`` only.
    Per-component Gaussian normalizers are not folded in, and ``new_discrete``
    is not adjusted for the new Gaussian's normalizer. If total mass is meant
    to be preserved with components of differing precision, this may need
    ``gaussian.log_normalizer`` terms -- confirm.

    :param op: Reduction op.
    :param frozenset reduced_vars: Names of variables to reduce.
    :return: The reduced funsor, or ``None`` to defer.
    :raises NotImplementedError: If a reduced variable touches a Delta.
    """
    if not reduced_vars:
        return self
    if op is ops.logaddexp:
        if not all(reduced_vars.isdisjoint(d.inputs) for d in self.deltas):
            raise NotImplementedError(
                'TODO handle moment_matching with Deltas')
        # Given the check above, this intersection is always empty here.
        lazy_vars = frozenset().union(
            *(d.inputs for d in self.deltas)).intersection(reduced_vars)
        # NOTE(review): self.inputs[k] is looked up before the membership
        # test, so a reduced var absent from self.inputs raises KeyError.
        approx_vars = frozenset(k for k in reduced_vars - lazy_vars
                                if self.inputs[k].dtype != 'real'
                                if k in self.gaussian.inputs)
        exact_vars = reduced_vars - lazy_vars - approx_vars
        if exact_vars:
            return self.eager_reduce(op, exact_vars).reduce(
                op, approx_vars | lazy_vars)

        # Moment-matching approximation.
        assert approx_vars and not exact_vars
        discrete = self.discrete
        new_discrete = discrete.reduce(
            ops.logaddexp, approx_vars.intersection(discrete.inputs))
        # Approx vars that discrete does not depend on: divide by their count
        # so that probs below sums to one over all approx_vars.
        num_elements = reduce(ops.mul, [
            self.inputs[k].num_elements
            for k in approx_vars.difference(discrete.inputs)
        ], 1)
        if num_elements != 1:
            new_discrete -= math.log(num_elements)
        gaussian = self.gaussian
        int_inputs = OrderedDict((k, d) for k, d in gaussian.inputs.items()
                                 if d.dtype != 'real')
        probs = (discrete - new_discrete).exp()
        old_loc = Tensor(gaussian.loc, int_inputs)
        new_loc = (probs * old_loc).reduce(ops.add, approx_vars)
        old_cov = Tensor(sym_inverse(gaussian.precision), int_inputs)
        # Law of total covariance: weighted covariances + spread of the means.
        diff = old_loc - new_loc
        outers = Tensor(
            diff.data.unsqueeze(-1) * diff.data.unsqueeze(-2), diff.inputs)
        new_cov = ((probs * old_cov).reduce(ops.add, approx_vars) +
                   (probs * outers).reduce(ops.add, approx_vars))
        new_precision = Tensor(sym_inverse(new_cov.data), new_cov.inputs)
        new_inputs = new_loc.inputs.copy()
        new_inputs.update((k, d) for k, d in self.gaussian.inputs.items()
                          if d.dtype == 'real')
        new_gaussian = Gaussian(new_loc.data, new_precision.data, new_inputs)
        result = Joint(self.deltas, new_discrete, new_gaussian)
        return result.reduce(ops.logaddexp, lazy_vars)
    return None  # defer to default implementation
def eager_cat_homogeneous(name, part_name, *parts):
    """
    Eagerly concatenate Gaussian / GaussianMixture parts along a discrete input.

    Every part must share the same output and have an input ``part_name``.
    The parts are concatenated along that input. The result exposes the
    concatenated axis as its first input ``name``: a ``Bint`` whose size is
    the sum of the parts' ``part_name`` sizes.

    :param str name: Name of the concatenated input in the result.
    :param str part_name: Name of the input to concatenate along in each part.
    :param parts: One or more ``Gaussian`` or ``GaussianMixture`` funsors.
    :return: A ``Gaussian``, or ``Gaussian + Tensor`` if any part had a
        discrete term.
    :raises NotImplementedError: If a part is neither of those types.
    """
    assert parts
    output = parts[0].output
    # part_name goes first so it becomes batch dim 0, the axis concatenated below.
    inputs = OrderedDict([(part_name, None)])
    for part in parts:
        assert part.output == output
        assert part_name in part.inputs
        inputs.update(part.inputs)

    int_inputs = OrderedDict(
        (k, v) for k, v in inputs.items() if v.dtype != "real")
    real_inputs = OrderedDict(
        (k, v) for k, v in inputs.items() if v.dtype == "real")
    inputs = int_inputs.copy()
    inputs.update(real_inputs)

    discretes = []
    info_vecs = []
    precisions = []
    for part in parts:
        # Only part_name's size varies between parts; align each to it.
        inputs[part_name] = part.inputs[part_name]
        int_inputs[part_name] = inputs[part_name]
        shape = tuple(d.size for d in int_inputs.values())
        if isinstance(part, Gaussian):
            discrete = None
            gaussian = part
        elif issubclass(type(part), GaussianMixture):  # TODO figure out why isinstance isn't working
            discrete, gaussian = part.terms[0], part.terms[1]
            discrete = ops.expand(align_tensor(int_inputs, discrete), shape)
        else:
            raise NotImplementedError("TODO")
        discretes.append(discrete)
        info_vec, precision = align_gaussian(inputs, gaussian)
        info_vecs.append(ops.expand(info_vec, shape + (-1, )))
        precisions.append(ops.expand(precision, shape + (-1, -1)))

    dim = 0
    info_vec = ops.cat(dim, *info_vecs)
    precision = ops.cat(dim, *precisions)

    # The concatenated axis is batch dim 0, so ``name`` must be the FIRST
    # input. Deleting part_name and then assigning inputs[name] would move
    # ``name`` to the end whenever part_name != name, misaligning the data.
    cat_domain = Bint[info_vec.shape[dim]]
    int_inputs = OrderedDict(
        [(name, cat_domain)] +
        [(k, v) for k, v in int_inputs.items() if k not in (name, part_name)])
    inputs = int_inputs.copy()
    inputs.update(real_inputs)

    result = Gaussian(info_vec, precision, inputs)
    if any(d is not None for d in discretes):
        # Pure-Gaussian parts contribute zero log-weight.
        discretes = [
            ops.new_zeros(iv, iv.shape[:-1]) if d is None else d
            for d, iv in zip(discretes, info_vecs)
        ]
        discrete = ops.cat(dim, *discretes)
        result = result + Tensor(discrete, int_inputs)
    return result
def random_gaussian(inputs):
    """
    Build a random :class:`funsor.gaussian.Gaussian` over ``inputs``.

    The mean is standard normal. The precision is ``A @ A.T + 0.05 * I`` for
    a standard-normal ``A``, which keeps it positive definite.
    """
    assert isinstance(inputs, OrderedDict)
    # For discrete domains the dtype is used directly as the batch size.
    batch_shape = tuple(d.dtype for d in inputs.values() if d.dtype != 'real')
    event_size = sum(d.num_elements for d in inputs.values() if d.dtype == 'real')
    event_shape = (event_size,)

    loc = torch.randn(batch_shape + event_shape)
    root = torch.randn(batch_shape + event_shape + event_shape)
    gram = root.matmul(root.transpose(-1, -2))
    precision = gram + 0.05 * torch.eye(event_size)
    return Gaussian(loc, precision, inputs)
def eager_normal(loc, scale, value):
    """
    Eagerly convert ``Normal(loc, scale, value)`` into a Tensor log-weight
    plus a (loc, precision)-form Gaussian.
    """
    # The density is symmetric in loc and value; move a Variable loc into
    # the value slot.
    if isinstance(loc, Variable):
        loc, value = value, loc

    inputs, aligned = align_tensors(loc, scale)
    loc, scale = torch.broadcast_tensors(*aligned)
    inputs.update(value.inputs)
    int_inputs = OrderedDict(
        (name, domain) for name, domain in inputs.items() if domain.dtype != 'real')

    log_weight = Tensor(-0.5 * math.log(2 * math.pi) - scale.log(), int_inputs)
    mean = loc.unsqueeze(-1)
    precision = scale.pow(-2).unsqueeze(-1).unsqueeze(-1)
    return log_weight + Gaussian(mean, precision, inputs)
def eager_mvn(loc, scale_tril, value):
    """
    Eagerly convert ``MultivariateNormal(loc, scale_tril, value)`` into a
    Tensor log-weight plus a (loc, precision)-form Gaussian.
    """
    if isinstance(loc, Variable):
        loc, value = value, loc
    dim, = loc.output.shape
    inputs, (loc, scale_tril) = align_tensors(loc, scale_tril)
    inputs.update(value.inputs)
    int_inputs = OrderedDict(
        (name, domain) for name, domain in inputs.items() if domain.dtype != 'real')

    # sum(log(diag(L))) is half the log-determinant of the covariance L @ L.T.
    half_log_det = scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1)
    log_prob = -0.5 * dim * math.log(2 * math.pi) - half_log_det

    # precision = inv(L).T @ inv(L) = inv(L @ L.T)
    inv_tril = torch.inverse(scale_tril)
    precision = inv_tril.transpose(-1, -2).matmul(inv_tril)
    return Tensor(log_prob, int_inputs) + Gaussian(loc, precision, inputs)
def random_gaussian(inputs):
    """
    Build a random :class:`funsor.gaussian.Gaussian` over ``inputs``.

    The precision is ``A @ A.T + 0.5 * I`` for a standard-normal ``A``, which
    keeps it positive definite. The mean is standard normal and is returned
    in information form (``info_vec = precision @ mean``).
    """
    assert isinstance(inputs, OrderedDict)
    batch_shape = tuple(d.dtype for d in inputs.values() if d.dtype != 'real')
    event_shape = (sum(d.num_elements for d in inputs.values() if d.dtype == 'real'), )

    root = randn(batch_shape + event_shape + event_shape)
    gram = ops.matmul(root, ops.transpose(root, -1, -2))
    precision = gram + 0.5 * ops.new_eye(gram, event_shape[:1])

    mean = randn(batch_shape + event_shape)
    info_vec = ops.matmul(precision, ops.unsqueeze(mean, -1)).squeeze(-1)
    return Gaussian(info_vec, precision, inputs)
def eager_normal(loc, scale, value):
    """
    Eagerly convert ``Normal(loc, scale, value)`` when ``loc`` and ``value``
    are affine.

    Builds a zero-mean Gaussian over a fresh real variable and substitutes
    ``value - loc`` into it. Returns ``None`` (stay lazy) otherwise.
    """
    assert loc.output == Real
    assert scale.output == Real
    assert value.output == Real
    if not (is_affine(loc) and is_affine(value)):
        return None  # lazy

    scale_data = scale.data
    info_vec = ops.new_zeros(scale_data, scale_data.shape + (1, ))
    precision = ops.pow(scale_data, -2).reshape(scale_data.shape + (1, 1))
    log_prob = -0.5 * math.log(2 * math.pi) - ops.log(scale).sum()

    fresh = gensym('value')
    inputs = scale.inputs.copy()
    inputs[fresh] = Real
    centered = log_prob + Gaussian(info_vec, precision, inputs)
    return centered(**{fresh: value - loc})
def eager_normal(loc, scale, value):
    """
    Eagerly convert ``Normal(loc, scale, value)`` into a Tensor log-weight
    plus an information-form Gaussian.

    Assuming ``Gaussian(info_vec, precision)`` encodes
    ``info_vec . x - 0.5 * x' precision x``, the Normal log-density splits as
    ``[-0.5 log(2 pi) - log(scale) - 0.5 loc**2 / scale**2]`` (the Tensor
    part) plus ``loc / scale**2 * x - 0.5 * x**2 / scale**2`` (the Gaussian
    part).
    """
    # The density is symmetric in loc and value; move a Variable loc into
    # the value slot.
    if isinstance(loc, Variable):
        loc, value = value, loc

    inputs, (loc, scale) = align_tensors(loc, scale, expand=True)
    inputs.update(value.inputs)
    int_inputs = OrderedDict(
        (k, v) for k, v in inputs.items() if v.dtype != 'real')

    prec = scale.pow(-2)
    info = prec * loc  # same batch shape as loc
    # Elementwise 0.5 * loc**2 / scale**2. The previous
    # ``(loc * info.unsqueeze(-1)).squeeze(-1)`` broadcast a batch_shape
    # tensor against batch_shape + (1,), giving a wrong shape or an error
    # whenever there was a batch dim.
    log_prob = -0.5 * math.log(2 * math.pi) - scale.log() - 0.5 * loc * info

    info_vec = info.unsqueeze(-1)
    precision = prec.unsqueeze(-1).unsqueeze(-1)
    return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, inputs)
def eager_mvn(loc, scale_tril, value):
    """
    Eagerly convert ``MultivariateNormal(loc, scale_tril, value)`` when
    ``loc`` and ``value`` are affine.

    Builds a zero-mean Gaussian over a fresh vector variable and substitutes
    ``value - loc`` into it. Returns ``None`` (stay lazy) otherwise.
    """
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not (is_affine(loc) and is_affine(value)):
        return None  # lazy

    tril = scale_tril.data
    info_vec = tril.new_zeros(tril.shape[:-1])
    # Assuming ops.cholesky_inverse matches torch.cholesky_inverse, this is
    # inv(L @ L.T), the precision matrix.
    precision = ops.cholesky_inverse(tril)
    scale_diag = Tensor(tril.diagonal(dim1=-1, dim2=-2), scale_tril.inputs)
    event_size = scale_diag.shape[0]
    log_prob = -0.5 * event_size * math.log(2 * math.pi) - scale_diag.log().sum()

    fresh = gensym('value')
    inputs = scale_tril.inputs.copy()
    inputs[fresh] = reals(event_size)
    centered = log_prob + Gaussian(info_vec, precision, inputs)
    return centered(**{fresh: value - loc})
def test_affine_subs():
    """Substituting an affine expression into a Gaussian's real input must
    stay a Gaussian (or a Contraction of one)."""
    # This was recorded from test_pyro_convert.
    info_vec = torch.tensor([1.3027106523513794, 1.4167094230651855, -0.9750942587852478, 0.5321089029312134, -0.9039931297302246], dtype=torch.float32)  # noqa
    precision = torch.tensor([
        [1.0199567079544067, 0.9840421676635742, -0.473368763923645, 0.34206756949424744, -0.7562517523765564],  # noqa
        [0.9840421676635742, 1.511502742767334, -1.7593903541564941, 0.6647964119911194, -0.5119513273239136],  # noqa
        [-0.4733688533306122, -1.7593903541564941, 3.2386727333068848, -0.9345928430557251, -0.1534711718559265],  # noqa
        [0.34206756949424744, 0.6647964119911194, -0.9345928430557251, 0.3141004145145416, -0.12399007380008698],  # noqa
        [-0.7562517523765564, -0.5119513273239136, -0.1534711718559265, -0.12399007380008698, 0.6450173854827881],  # noqa
    ], dtype=torch.float32)
    gaussian = Gaussian(
        info_vec, precision,
        (('state_1_b6', reals(3)), ('obs_b2', reals(2))))

    bias = Tensor(
        torch.tensor([-2.1787893772125244, 0.5684312582015991], dtype=torch.float32),  # noqa
        (), 'real')
    shifted_obs = Contraction(ops.nullop, ops.add, frozenset(),
                              (Variable('bias_b5', reals(2)), bias))

    x = Subs(gaussian, (('obs_b2', shifted_obs),))
    assert isinstance(x, (Gaussian, Contraction)), x.pretty()
def adjoint_subs_gaussianmixture_gaussianmixture(adj_redop, adj_binop, out_adj, arg, subs):
    """
    Propagate an adjoint back through ``Subs(arg, subs)``, where both ``arg``
    and ``out_adj`` have a discrete term (``.terms[0]``) and a Gaussian term
    (``.terms[1]``).

    NOTE(review): by its name, both are presumably GaussianMixture funsors --
    confirm against the dispatch registration.

    Only renamings (``Variable`` values) and substitutions into non-real
    inputs are supported.

    :param adj_redop: Adjoint reduction op, forwarded to ``adjoint_ops``.
    :param adj_binop: Adjoint binary op, forwarded to ``adjoint_ops``.
    :param out_adj: Adjoint of the substitution's output.
    :param arg: The funsor that was substituted into.
    :param subs: Tuple of ``(name, value)`` substitution pairs.
    :return: Dict mapping ``arg`` to its input adjoint (Gaussian + discrete).
    :raises NotImplementedError: If a non-Variable value is substituted into
        a real input.
    """
    if any(v.dtype == 'real' and not isinstance(v, Variable) for k, v in subs):
        raise NotImplementedError("TODO implement adjoint for substitution into Gaussian real variable")

    # invert renaming
    renames = tuple((v.name, k) for k, v in subs if isinstance(v, Variable))
    out_adj = Subs(out_adj, renames)

    # inverting advanced indexing
    slices = tuple((k, v) for k, v in subs if not isinstance(v, Variable))

    # Every substitution is either a rename or a slice.
    assert len(slices + renames) == len(subs)

    in_adj_discrete = adjoint_ops(Subs, adj_redop, adj_binop, out_adj.terms[0],
                                  arg.terms[0], subs)[arg.terms[0]]

    arg_int_inputs = OrderedDict((k, v) for k, v in arg.inputs.items() if v.dtype != 'real')
    out_adj_int_inputs = OrderedDict((k, v) for k, v in out_adj.inputs.items() if v.dtype != 'real')

    # Align out_adj's Gaussian term to its own discrete inputs plus arg's
    # real inputs, so its arrays can be sliced like arg's Gaussian arrays.
    arg_real_inputs = OrderedDict((k, v) for k, v in arg.inputs.items() if v.dtype == 'real')
    align_inputs = OrderedDict((k, v) for k, v in out_adj.terms[1].inputs.items() if v.dtype != 'real')
    align_inputs.update(arg_real_inputs)
    out_adj_info_vec, out_adj_precision = align_gaussian(align_inputs, out_adj.terms[1])

    # NOTE(review): the aligned arrays follow the discrete-input order of
    # out_adj.terms[1], but are wrapped below with out_adj_int_inputs, which
    # follows out_adj.inputs. Confirm these orders always agree.
    in_adj_info_vec = list(adjoint_ops(Subs, adj_redop, adj_binop,  # ops.add, ops.mul,
                                       Tensor(out_adj_info_vec, out_adj_int_inputs),
                                       Tensor(arg.terms[1].info_vec, arg_int_inputs),
                                       slices).values())[0]

    in_adj_precision = list(adjoint_ops(Subs, adj_redop, adj_binop,  # ops.add, ops.mul,
                                        Tensor(out_adj_precision, out_adj_int_inputs),
                                        Tensor(arg.terms[1].precision, arg_int_inputs),
                                        slices).values())[0]

    assert isinstance(in_adj_info_vec, Tensor)
    assert isinstance(in_adj_precision, Tensor)

    # Reassemble a Gaussian over all of arg's inputs from the un-sliced arrays.
    in_adj_gaussian = Gaussian(in_adj_info_vec.data, in_adj_precision.data, arg.inputs.copy())

    in_adj = in_adj_gaussian + in_adj_discrete
    return {arg: in_adj}
def eager_affine_normal(matrix, loc, scale, value_x, value_y):
    """
    Eagerly express ``log N(value_y; value_x @ matrix + loc, diag(scale**2))``
    as a Tensor log-weight plus a Gaussian in ``value_x``.

    This assumes ``Gaussian(info_vec, precision)`` encodes
    ``info_vec . x - 0.5 * x' precision x``. The ``value_x`` input is first
    represented by a fresh variable, which is then substituted.
    """
    assert len(matrix.output.shape) == 2
    assert value_x.output == reals(matrix.output.shape[0])
    assert value_y.output == reals(matrix.output.shape[1])

    int_inputs, (matrix, loc, scale, value_y) = align_tensors(matrix, loc, scale, value_y)
    assert value_y.size(-1) == loc.size(-1)

    # Whitened coefficients and residual: divide each output coordinate by its scale.
    whitened = matrix / scale.unsqueeze(-2)
    residual = (value_y - loc) / scale
    info_vec = whitened.matmul(residual.unsqueeze(-1)).squeeze(-1)
    precision = whitened.matmul(whitened.transpose(-1, -2)).expand(info_vec.shape + (-1,))
    log_normalizer = (-0.5 * loc.size(-1) * math.log(2 * math.pi)
                      - 0.5 * residual.pow(2).sum(-1)
                      - scale.log().sum(-1)).expand(info_vec.shape[:-1])

    x_inputs = int_inputs.copy()
    x_name = gensym("x")
    x_inputs[x_name] = value_x.output
    x_dist = Tensor(log_normalizer, int_inputs) + Gaussian(info_vec, precision, x_inputs)
    return x_dist(**{x_name: value_x})
def mvn_to_funsor(pyro_dist, event_dims=(), real_inputs=None):
    """
    Convert a joint :class:`torch.distributions.MultivariateNormal`
    distribution into a :class:`~funsor.terms.Funsor` with multiple real
    inputs.

    This should satisfy::

        sum(d.num_elements for d in real_inputs.values())
          == pyro_dist.event_shape[0]

    :param torch.distributions.MultivariateNormal pyro_dist: A
        multivariate normal distribution over one or more variables of real
        or vector or tensor type.
    :param tuple event_dims: A tuple of names for rightmost dimensions.
        These will be assigned to ``result.inputs`` of type ``bint``.
    :param OrderedDict real_inputs: A dict mapping real variable name to
        appropriately sized ``reals()``. The sum of all ``.numel()`` of all
        real inputs should be equal to the ``pyro_dist`` dimension. Defaults
        to ``None``, which is treated as an empty ``OrderedDict``.
    :return: A funsor with given ``real_inputs`` and possibly additional
        bint inputs.
    :rtype: funsor.terms.Funsor
    """
    # Build a fresh dict per call instead of sharing one mutable default.
    if real_inputs is None:
        real_inputs = OrderedDict()
    assert isinstance(pyro_dist, torch.distributions.MultivariateNormal)
    assert isinstance(event_dims, tuple)
    assert isinstance(real_inputs, OrderedDict)
    loc = tensor_to_funsor(pyro_dist.loc, event_dims, 1)
    scale_tril = tensor_to_funsor(pyro_dist.scale_tril, event_dims, 2)
    precision = tensor_to_funsor(pyro_dist.precision_matrix, event_dims, 2)
    assert loc.inputs == scale_tril.inputs
    assert loc.inputs == precision.inputs

    # Information form: info_vec = precision @ loc. The constant part of the
    # log-density (normalizer and -0.5 * loc' precision loc) goes into a Tensor.
    info_vec = precision.data.matmul(loc.data.unsqueeze(-1)).squeeze(-1)
    log_prob = (-0.5 * loc.output.shape[0] * math.log(2 * math.pi)
                - scale_tril.data.diagonal(dim1=-1, dim2=-2).log().sum(-1)
                - 0.5 * (info_vec * loc.data).sum(-1))
    inputs = loc.inputs.copy()
    inputs.update(real_inputs)
    return Tensor(log_prob, loc.inputs) + Gaussian(info_vec, precision.data, inputs)
def eager_normal(loc, scale, value):
    """
    Convert ``Normal(loc, scale, value)`` via the affine form of
    ``(loc - value) / scale`` into a log-weight plus a (loc, precision)-form
    Gaussian over its scalar real inputs.

    NOTE(review): coefficients come from ``affine.coeffs`` order but are
    paired with ``real_inputs`` order, and only ``mean[..., 0]`` is set.
    This looks correct only for a single real input -- confirm.
    """
    affine = (loc - value) / scale
    assert isinstance(affine, Affine)
    real_inputs = OrderedDict(
        (name, domain) for name, domain in affine.inputs.items() if domain.dtype == 'real')
    assert not any(domain.shape for domain in real_inputs.values())

    _, aligned = align_tensors(affine.const, *affine.coeffs.values())
    broadcast = torch.broadcast_tensors(*aligned)
    const, coeffs = broadcast[0], broadcast[1:]
    dim = sum(domain.num_elements for domain in real_inputs.values())

    mean = BlockVector(const.shape + (dim,))
    mean[..., 0] = -const / coeffs[0]

    # Rank-one precision built from the outer product of the coefficients.
    paired = [c for _, c in zip(real_inputs, coeffs)]
    precision = BlockMatrix(const.shape + (dim, dim))
    for row, c_row in enumerate(paired):
        for col, c_col in enumerate(paired):
            precision[..., row, col] = c_row * c_col

    mean = mean.as_tensor()
    precision = precision.as_tensor()
    log_prob = -0.5 * math.log(2 * math.pi) - scale.log()
    return log_prob + Gaussian(mean, precision, affine.inputs)
def eager_normal(loc, scale, value):
    """
    Convert ``Normal(loc, scale, value)`` into a Tensor log-weight plus an
    information-form Gaussian when ``(loc - value) / scale`` is affine in its
    scalar real inputs. Returns ``None`` otherwise.

    The affine map is recovered by probing it: evaluate at all-zeros for the
    constant, and at each unit vector (minus the constant) for that input's
    coefficient.
    """
    affine = (loc - value) / scale
    if not affine.is_affine:
        return None

    real_inputs = OrderedDict(
        (name, domain) for name, domain in affine.inputs.items() if domain.dtype == 'real')
    int_inputs = OrderedDict(
        (name, domain) for name, domain in affine.inputs.items() if domain.dtype != 'real')
    assert not any(domain.shape for domain in real_inputs.values())

    const = affine(**dict.fromkeys(real_inputs, 0.))
    coeff_by_name = OrderedDict()
    for target in real_inputs.keys():
        unit = {name: 1. if target == name else 0. for name in real_inputs.keys()}
        coeff_by_name[target] = affine(**unit) - const

    _, aligned = align_tensors(const, *coeff_by_name.values(), expand=True)
    const, coeffs = aligned[0], aligned[1:]
    dim = sum(domain.num_elements for domain in real_inputs.values())

    mean = BlockVector(const.shape + (dim, ))
    mean[..., 0] = -const / coeffs[0]

    # Rank-one precision built from the outer product of the coefficients.
    paired = [c for _, c in zip(real_inputs, coeffs)]
    precision = BlockMatrix(const.shape + (dim, dim))
    for row, c_row in enumerate(paired):
        for col, c_col in enumerate(paired):
            precision[..., row, col] = c_row * c_col

    mean = mean.as_tensor()
    precision = precision.as_tensor()
    info_vec = precision.matmul(mean.unsqueeze(-1)).squeeze(-1)
    log_prob = (-0.5 * math.log(2 * math.pi) - scale.data.log()
                - 0.5 * (mean * info_vec).sum(-1))
    return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, affine.inputs)
def test_bart(analytic_kl):
    """
    Evaluate an ELBO for a model given as hand-built lazy funsor terms.

    ``q``, ``p_prior`` and ``p_likelihood`` are constructed under the
    ``reflect`` interpretation, so they are built as lazy terms rather than
    evaluated. The ELBO is then computed in one of two ways:

    * ``analytic_kl`` true: an exact ``Integrate(q, p_prior - q)`` plus a
      Monte Carlo ``Integrate(q, p_likelihood)``.
    * ``analytic_kl`` false: a single Monte Carlo ``Integrate(q, p - q)``
      with ``p = p_prior + p_likelihood``.

    The test asserts that the ELBO evaluates to a ``Tensor`` and that the
    module-level ``call_count`` equals 1 afterwards.
    NOTE(review): ``call_count`` is reset here but incremented elsewhere in
    the file (presumably by ``unpack_gate_rate_0``/``unpack_gate_rate_1``),
    so confirm against their definitions.
    """
    global call_count
    call_count = 0
    with interpretation(reflect):
        # Guide: a Tensor + Gaussian contraction over 'value_b1' with
        # time/event batch inputs, wrapped in two nested Independent terms.
        q = Independent(
            Independent(
                Contraction(
                    ops.nullop,
                    ops.add,
                    frozenset(),
                    (
                        Tensor(
                            torch.tensor(
                                [[
                                    -0.6077086925506592, -1.1546266078948975,
                                    -0.7021151781082153, -0.5303535461425781,
                                    -0.6365622282028198, -1.2423288822174072,
                                    -0.9941254258155823, -0.6287292242050171
                                ], [
                                    -0.6987162828445435, -1.0875964164733887,
                                    -0.7337473630905151, -0.4713417589664459,
                                    -0.6674002408981323, -1.2478348016738892,
                                    -0.8939017057418823, -0.5238542556762695
                                ]],
                                dtype=torch.float32),  # noqa
                            (
                                ('time_b4', bint(2), ),
                                ('_event_1_b2', bint(8), ),
                            ),
                            'real'),
                        Gaussian(
                            torch.tensor([
                                [[-0.3536059558391571], [-0.21779225766658783],
                                 [0.2840439975261688], [0.4531521499156952],
                                 [-0.1220812276005745], [-0.05519985035061836],
                                 [0.10932210087776184], [0.6656699776649475]],
                                [[-0.39107921719551086], [-0.20241987705230713],
                                 [0.2170514464378357], [0.4500560462474823],
                                 [0.27945515513420105], [-0.0490039587020874],
                                 [-0.06399798393249512], [0.846565842628479]]
                            ], dtype=torch.float32),  # noqa
                            torch.tensor([
                                [[[1.984686255455017]], [[0.6699360013008118]],
                                 [[1.6215802431106567]], [[2.372016668319702]],
                                 [[1.77385413646698]], [[0.526767373085022]],
                                 [[0.8722561597824097]], [[2.1879124641418457]]],
                                [[[1.6996612548828125]], [[0.7535632252693176]],
                                 [[1.4946647882461548]], [[2.642792224884033]],
                                 [[1.7301604747772217]], [[0.5203893780708313]],
                                 [[1.055436372756958]], [[2.8370864391326904]]]
                            ], dtype=torch.float32),  # noqa
                            (
                                ('time_b4', bint(2), ),
                                ('_event_1_b2', bint(8), ),
                                ('value_b1', reals(), ),
                            )),
                    )),
                'gate_rate_b3',
                '_event_1_b2',
                'value_b1'),
            'gate_rate_t',
            'time_b4',
            'gate_rate_b3')
        # Prior: a MarkovProduct over 'time_b9' of a state-transition term,
        # contracted with a MultivariateNormal on the initial state.
        p_prior = Contraction(
            ops.logaddexp,
            ops.add,
            frozenset({'state(time=1)_b11', 'state_b10'}),
            (
                MarkovProduct(
                    ops.logaddexp,
                    ops.add,
                    Contraction(
                        ops.nullop,
                        ops.add,
                        frozenset(),
                        (
                            Tensor(
                                torch.tensor(2.7672932147979736, dtype=torch.float32),
                                (),
                                'real'),
                            Gaussian(
                                torch.tensor([-0.0, -0.0, 0.0, 0.0], dtype=torch.float32),
                                torch.tensor([[98.01002502441406, 0.0, -99.0000228881836, -0.0],
                                              [0.0, 98.01002502441406, -0.0, -99.0000228881836],
                                              [-99.0000228881836, -0.0, 100.0000228881836, 0.0],
                                              [-0.0, -99.0000228881836, 0.0, 100.0000228881836]],
                                             dtype=torch.float32),  # noqa
                                (
                                    ('state_b7', reals(2, ), ),
                                    ('state(time=1)_b8', reals(2, ), ),
                                )),
                            Subs(
                                AffineNormal(
                                    Tensor(
                                        torch.tensor(
                                            [[0.03488487750291824, 0.07356668263673782,
                                              0.19946961104869843, 0.5386509299278259,
                                              -0.708323061466217, 0.24411526322364807,
                                              -0.20855577290058136, -0.2421337217092514],
                                             [0.41762110590934753, 0.5272183418273926,
                                              -0.49835553765296936, -0.0363837406039238,
                                              -0.0005282597267068923, 0.2704298794269562,
                                              -0.155222088098526, -0.44802337884902954]],
                                            dtype=torch.float32),  # noqa
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[-0.003566693514585495, -0.2848514914512634,
                                              0.037103548645973206, 0.12648648023605347,
                                              -0.18501518666744232, -0.20899859070777893,
                                              0.04121830314397812, 0.0054807960987091064],
                                             [0.0021788496524095535, -0.18700894713401794,
                                              0.08187370002269745, 0.13554862141609192,
                                              -0.10477752983570099, -0.20848378539085388,
                                              -0.01393645629286766, 0.011670656502246857]],
                                            dtype=torch.float32),  # noqa
                                        (('time_b9', bint(2), ), ),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[0.5974780917167664, 0.864071786403656,
                                              1.0236268043518066, 0.7147538065910339,
                                              0.7423890233039856, 0.9462157487869263,
                                              1.2132389545440674, 1.0596832036972046],
                                             [0.5787821412086487, 0.9178534150123596,
                                              0.9074794054031372, 0.6600189208984375,
                                              0.8473222255706787, 0.8426999449729919,
                                              1.194266438484192, 1.0471148490905762]],
                                            dtype=torch.float32),  # noqa
                                        (('time_b9', bint(2), ), ),
                                        'real'),
                                    Variable('state(time=1)_b8', reals(2, )),
                                    Variable('gate_rate_b6', reals(8, ))),
                                # Index the global 'gate_rate_t' by the local time step.
                                (('gate_rate_b6',
                                  Binary(ops.GetitemOp(0),
                                         Variable('gate_rate_t', reals(2, 8)),
                                         Variable('time_b9', bint(2))), ), )),
                        )),
                    Variable('time_b9', bint(2)),
                    frozenset({('state_b7', 'state(time=1)_b8')}),
                    frozenset({('state(time=1)_b8', 'state(time=1)_b11'),
                               ('state_b7', 'state_b10')})),  # noqa
                Subs(
                    dist.MultivariateNormal(
                        Tensor(torch.tensor([0.0, 0.0], dtype=torch.float32), (), 'real'),
                        Tensor(
                            torch.tensor([[10.0, 0.0], [0.0, 10.0]], dtype=torch.float32),
                            (),
                            'real'),
                        Variable('value_b5', reals(2, ))),
                    (('value_b5', Variable('state_b10', reals(2, )), ), )),
            ))
        # Likelihood: per (time, origin, destin) cell, a two-way 'gated_b14'
        # mixture of a Poisson and a Delta on the same observed counts.
        p_likelihood = Contraction(
            ops.add,
            ops.nullop,
            frozenset({'time_b17', 'destin_b16', 'origin_b15'}),
            (
                Contraction(
                    ops.logaddexp,
                    ops.add,
                    frozenset({'gated_b14'}),
                    (
                        dist.Categorical(
                            Binary(
                                ops.GetitemOp(0),
                                Binary(
                                    ops.GetitemOp(0),
                                    Subs(
                                        Function(
                                            unpack_gate_rate_0,
                                            reals(2, 2, 2),
                                            (Variable('gate_rate_b12', reals(8, )), )),
                                        (('gate_rate_b12',
                                          Binary(
                                              ops.GetitemOp(0),
                                              Variable('gate_rate_t', reals(2, 8)),
                                              Variable('time_b17', bint(2))), ), )),
                                    Variable('origin_b15', bint(2))),
                                Variable('destin_b16', bint(2))),
                            Variable('gated_b14', bint(2))),
                        Stack(
                            'gated_b14',
                            (
                                dist.Poisson(
                                    Binary(
                                        ops.GetitemOp(0),
                                        Binary(
                                            ops.GetitemOp(0),
                                            Subs(
                                                Function(
                                                    unpack_gate_rate_1,
                                                    reals(2, 2),
                                                    (Variable('gate_rate_b13', reals(8, )), )),
                                                (('gate_rate_b13',
                                                  Binary(
                                                      ops.GetitemOp(0),
                                                      Variable('gate_rate_t', reals(2, 8)),
                                                      Variable('time_b17', bint(2))), ), )),
                                            Variable('origin_b15', bint(2))),
                                        Variable('destin_b16', bint(2))),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]], [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            ('time_b17', bint(2), ),
                                            ('origin_b15', bint(2), ),
                                            ('destin_b16', bint(2), ),
                                        ),
                                        'real')),
                                dist.Delta(
                                    Tensor(torch.tensor(0.0, dtype=torch.float32), (), 'real'),
                                    Tensor(torch.tensor(0.0, dtype=torch.float32), (), 'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]], [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            ('time_b17', bint(2), ),
                                            ('origin_b15', bint(2), ),
                                            ('destin_b16', bint(2), ),
                                        ),
                                        'real')),
                            )),
                    )),
            ))
    if analytic_kl:
        # Prior-minus-guide term integrated under the default interpretation;
        # only the likelihood term is estimated by Monte Carlo.
        exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
        with interpretation(monte_carlo):
            approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
        elbo = exact_part + approx_part
    else:
        # Whole ELBO integrand estimated by Monte Carlo.
        p = p_prior + p_likelihood
        with interpretation(monte_carlo):
            elbo = Integrate(q, p - q, "gate_rate_t")
    assert isinstance(elbo, Tensor), elbo.pretty()
    assert call_count == 1
def matrix_and_mvn_to_funsor(matrix, mvn, event_dims=(), x_name="value_x", y_name="value_y"):
    """
    Convert a noisy affine function to a Gaussian.

    The noisy affine function is defined as::

        y = x @ matrix + mvn.sample()

    The result is a non-normalized Gaussian funsor with two real inputs,
    ``x_name`` and ``y_name``, corresponding to a conditional distribution of
    real vector ``y`` given real vector ``x``. Diagonal noise (an
    ``Independent`` of ``Normal``) is returned as an ``AffineNormal`` term
    rather than a dense ``Gaussian``.

    :param torch.Tensor matrix: A matrix with rightmost shape ``(x_size, y_size)``.
    :param mvn: A multivariate normal distribution with ``event_shape == (y_size,)``.
    :type mvn: torch.distributions.MultivariateNormal or
        torch.distributions.Independent of torch.distributions.Normal
    :param tuple event_dims: A tuple of names for rightmost dimensions.
        These will be assigned to ``result.inputs`` of type ``bint``.
    :param str x_name: The name of the ``x`` random variable.
    :param str y_name: The name of the ``y`` random variable.
    :return: A funsor with real inputs ``x_name`` and ``y_name`` and possibly
        additional bint inputs.
    :rtype: funsor.terms.Funsor
    """
    is_dense = isinstance(mvn, torch.distributions.MultivariateNormal)
    is_diagonal = (isinstance(mvn, torch.distributions.Independent)
                   and isinstance(mvn.base_dist, torch.distributions.Normal))
    assert is_dense or is_diagonal
    assert isinstance(matrix, torch.Tensor)
    x_size, y_size = matrix.shape[-2:]
    assert mvn.event_shape == (y_size,)

    # Diagonal noise: delegate to AffineNormal and skip the dense algebra.
    if isinstance(mvn, torch.distributions.Independent):
        base = mvn.base_dist
        return AffineNormal(tensor_to_funsor(matrix, event_dims, 2),
                            tensor_to_funsor(base.loc, event_dims, 1),
                            tensor_to_funsor(base.scale, event_dims, 1),
                            Variable(x_name, reals(x_size)),
                            Variable(y_name, reals(y_size)))

    # Dense noise with mean mu and precision P = (L L^T)^{-1}:
    # info_noise = P mu (via a Cholesky solve against L = scale_tril), and
    # log_prob collects the x/y-independent part of the log density.
    mu = mvn.loc
    tril = mvn.scale_tril
    info_noise = mu.unsqueeze(-1).cholesky_solve(tril).squeeze(-1)
    log_prob = (-0.5 * y_size * math.log(2 * math.pi)
                - tril.diagonal(dim1=-1, dim2=-2).log().sum(-1)
                - 0.5 * (info_noise * mu).sum(-1))

    # Joint precision over (x, y) for the residual y - x @ matrix:
    #   [[M P M^T, -M P], [-(M P)^T, P]]
    batch_shape = broadcast_shape(matrix.shape[:-2], mvn.batch_shape)
    prec_yy = mvn.precision_matrix.expand(batch_shape + (y_size, y_size))
    m_prec = matrix.matmul(prec_yy)
    prec_xx = m_prec.matmul(matrix.transpose(-1, -2))
    prec_xy = -m_prec
    top = torch.cat([prec_xx, prec_xy], -1)
    bottom = torch.cat([prec_xy.transpose(-1, -2), prec_yy], -1)
    precision = torch.cat([top, bottom], -2)

    # Joint information vector: [-M (P mu), P mu].
    info_y = info_noise.expand(batch_shape + (y_size,))
    info_x = -matrix.matmul(info_y.unsqueeze(-1)).squeeze(-1)
    joint_info = torch.cat([info_x, info_y], -1)

    info_funsor = tensor_to_funsor(joint_info, event_dims, 1)
    precision_funsor = tensor_to_funsor(precision, event_dims, 2)
    gaussian_inputs = info_funsor.inputs.copy()
    gaussian_inputs[x_name] = reals(x_size)
    gaussian_inputs[y_name] = reals(y_size)
    return (tensor_to_funsor(log_prob, event_dims)
            + Gaussian(info_funsor.data, precision_funsor.data, gaussian_inputs))
def eager_mvn(loc, scale_tril, value):
    """
    Eagerly convert a MultivariateNormal with affine loc/value to a Gaussian.

    The whitened residual ``L^{-1} (loc - value)`` (``L = scale_tril``) is
    split by ``extract_affine`` into a constant and one ``(coeff, equation)``
    pair per real input. The log density
    ``-0.5 * n * log(2 * pi) - sum(log(diag(L))) - 0.5 * |residual| ** 2``
    is returned as a constant term plus an information-form ``Gaussian`` with
    ``info_vec = C^T (-const)`` and ``precision = C^T C``.

    Returns ``None`` (leaving the term lazy) when ``loc`` or ``value`` is not
    affine, or when the extracted constant/coefficients are not ``Tensor``s.
    """
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not is_affine(loc) or not is_affine(value):
        return None  # lazy

    # Extract an affine representation of the whitened residual.
    # prec_sqrt = scale_tril^{-1}, from a lower-triangular solve against the
    # identity. The identity must share scale_tril's dtype and device;
    # torch.eye's defaults would make the solve fail for e.g. float64 or CUDA.
    tril = scale_tril.data
    eye = torch.eye(tril.size(-1), dtype=tril.dtype,
                    device=tril.device).expand(tril.shape)
    prec_sqrt = Tensor(
        eye.triangular_solve(tril, upper=False).solution,
        scale_tril.inputs)
    affine = prec_sqrt @ (loc - value)
    const, coeffs = extract_affine(affine)
    if not isinstance(const, Tensor):
        return None  # lazy
    if not all(isinstance(coeff, Tensor) for coeff, _ in coeffs.values()):
        return None  # lazy

    # Compute log_prob using funsors.
    scale_diag = Tensor(tril.diagonal(dim1=-1, dim2=-2), scale_tril.inputs)
    log_prob = (-0.5 * scale_diag.shape[0] * math.log(2 * math.pi)
                - scale_diag.log().sum() - 0.5 * (const**2).sum())

    # Dovetail to avoid variable name collision in einsum: letters of the
    # first operand's equations map to even offsets from 'a', the second
    # operand's to odd offsets.
    # NOTE(review): letters after 'm' map past 'z' under this scheme.
    equations1 = [
        ''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a')) for c in eqn)
        for _, eqn in coeffs.values()
    ]
    equations2 = [
        ''.join(c if c in ',->' else chr(ord(c) * 2 - ord('a') + 1) for c in eqn)
        for _, eqn in coeffs.values()
    ]

    real_inputs = OrderedDict(
        (k, v) for k, v in affine.inputs.items() if v.dtype == 'real')
    assert tuple(real_inputs) == tuple(coeffs)

    # Align and broadcast tensors.
    neg_const = -const
    tensors = [neg_const] + [coeff for coeff, _ in coeffs.values()]
    inputs, tensors = align_tensors(*tensors, expand=True)
    neg_const, coeff_tensors = tensors[0], tensors[1:]

    dim = sum(d.num_elements for d in real_inputs.values())
    batch_shape = neg_const.shape[:-1]  # drop the residual's single event dim

    info_vec = BlockVector(batch_shape + (dim, ))
    precision = BlockMatrix(batch_shape + (dim, dim))
    offset1 = 0
    for i1, (v1, c1) in enumerate(zip(real_inputs, coeff_tensors)):
        size1 = real_inputs[v1].num_elements
        slice1 = slice(offset1, offset1 + size1)
        inputs1, output1 = equations1[i1].split('->')
        input11, input12 = inputs1.split(',')
        assert input11 == input12 + output1
        # info_vec block = C1^T (-const), summing over the residual dims.
        info_vec[..., slice1] = torch.einsum(
            f'...{input11},...{output1}->...{input12}', c1, neg_const,
        ).reshape(batch_shape + (size1, ))
        offset2 = 0
        for i2, (v2, c2) in enumerate(zip(real_inputs, coeff_tensors)):
            size2 = real_inputs[v2].num_elements
            slice2 = slice(offset2, offset2 + size2)
            inputs2, output2 = equations2[i2].split('->')
            input21, input22 = inputs2.split(',')
            assert input21 == input22 + output2
            # precision block = C1^T C2: c2's residual dims are labelled with
            # output1 so einsum sums over them jointly with c1's.
            precision[..., slice1, slice2] = torch.einsum(
                f'...{input11},...{input22}{output1}->...{input12}{input22}', c1, c2,
            ).reshape(batch_shape + (size1, size2))
            offset2 += size2
        offset1 += size1

    info_vec = info_vec.as_tensor()
    precision = precision.as_tensor()
    inputs.update(real_inputs)
    return log_prob + Gaussian(info_vec, precision, inputs)