def condition(self, value):
    """
    Condition this Gaussian on a trailing subset of its state.
    This should satisfy::

        g.condition(y).dim() == g.dim() - y.size(-1)

    Note that since this is a non-normalized Gaussian, we include the
    density of ``y`` in the result. Thus :meth:`condition` is similar to a
    ``functools.partial`` binding of arguments::

        left = x[..., :n]
        right = x[..., n:]
        g.log_density(x) == g.condition(right).log_density(left)
    """
    assert isinstance(value, torch.Tensor)
    right = value.size(-1)
    dim = self.dim()
    assert right <= dim

    n = dim - right
    info_a = self.info_vec[..., :n]
    info_b = self.info_vec[..., n:]
    P_aa = self.precision[..., :n, :n]
    P_ab = self.precision[..., :n, n:]
    P_bb = self.precision[..., n:, n:]
    b = value

    info_vec = info_a - matvecmul(P_ab, b)
    precision = P_aa
    log_normalizer = (self.log_normalizer +
                      -0.5 * matvecmul(P_bb, b).mul(b).sum(-1) +
                      b.mul(info_b).sum(-1))
    return Gaussian(log_normalizer, info_vec, precision)
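
# A minimal sketch (hypothetical helper, not part of the module) checking the
# partial-binding identity from the docstring above; it relies on this
# module's own ``Gaussian`` class and ``torch`` import.
def _demo_condition():
    a = torch.randn(5, 5)
    precision = a @ a.transpose(-1, -2) + 5 * torch.eye(5)  # well-conditioned
    g = Gaussian(torch.tensor(0.0), torch.randn(5), precision)
    x = torch.randn(5)
    left, right = x[..., :3], x[..., 3:]
    # condition() folds the density of ``right`` into the result:
    assert torch.allclose(g.log_density(x),
                          g.condition(right).log_density(left))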
def mvn_to_gaussian(mvn):
    """
    Convert a MultivariateNormal distribution to a Gaussian.

    :param ~torch.distributions.MultivariateNormal mvn: A multivariate normal
        distribution.
    :return: An equivalent Gaussian object.
    :rtype: ~pyro.ops.gaussian.Gaussian
    """
    assert (isinstance(mvn, torch.distributions.MultivariateNormal) or
            (isinstance(mvn, torch.distributions.Independent) and
             isinstance(mvn.base_dist, torch.distributions.Normal)))
    if isinstance(mvn, torch.distributions.Independent):
        mvn = mvn.base_dist
        precision_diag = mvn.scale.pow(-2)
        precision = precision_diag.diag_embed()
        info_vec = mvn.loc * precision_diag
        scale_diag = mvn.scale
    else:
        precision = mvn.precision_matrix
        info_vec = matvecmul(precision, mvn.loc)
        scale_diag = mvn.scale_tril.diagonal(dim1=-2, dim2=-1)

    n = mvn.loc.size(-1)
    log_normalizer = (-0.5 * n * math.log(2 * math.pi) +
                      -0.5 * (info_vec * mvn.loc).sum(-1) -
                      scale_diag.log().sum(-1))
    return Gaussian(log_normalizer, info_vec, precision)
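
# A quick sanity check (hypothetical, not part of the module): since the
# input distribution is properly normalized, the converted Gaussian's log
# density should agree with ``mvn.log_prob``.
def _demo_mvn_to_gaussian():
    mvn = torch.distributions.MultivariateNormal(
        torch.randn(3), covariance_matrix=0.5 * torch.eye(3))
    g = mvn_to_gaussian(mvn)
    x = torch.randn(3)
    assert torch.allclose(g.log_density(x), mvn.log_prob(x))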
def condition(self, value):
    if value.size(-1) == self.loc.size(-1):
        # Fast path: condition on a full observation ``y``, yielding a
        # Gaussian over ``x`` in information form without materializing
        # the joint precision matrix.
        prec_sqrt = self.matrix / self.scale.unsqueeze(-2)
        precision = matmul(prec_sqrt, prec_sqrt.transpose(-1, -2))
        delta = (value - self.loc) / self.scale
        info_vec = matvecmul(prec_sqrt, delta)
        log_normalizer = (-0.5 * self.loc.size(-1) * math.log(2 * math.pi)
                          - 0.5 * delta.pow(2).sum(-1)
                          - self.scale.log().sum(-1))
        return Gaussian(log_normalizer, info_vec, precision)
    else:
        return self.to_gaussian().condition(value)
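
# A hedged sketch (hypothetical) of the fast path above: conditioning an
# ``AffineNormal`` on ``y`` should match evaluating a diagonal Normal with
# mean ``x @ matrix + loc`` at ``y``.
def _demo_affine_condition():
    x_dim, y_dim = 4, 2
    matrix, loc = torch.randn(x_dim, y_dim), torch.randn(y_dim)
    scale = 0.5 + torch.rand(y_dim)
    aff = AffineNormal(matrix, loc, scale)
    x, y = torch.randn(x_dim), torch.randn(y_dim)
    expected = torch.distributions.Normal(
        x @ matrix + loc, scale).log_prob(y).sum(-1)
    assert torch.allclose(aff.condition(y).log_density(x), expected)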
def left_condition(self, value):
    """
    If ``value.size(-1) == x_dim``, this returns a Normal distribution with
    ``event_dim=1``. After applying this method, the cost to draw a sample is
    ``O(y_dim)`` instead of ``O(y_dim ** 3)``.
    """
    if value.size(-1) == self.matrix.size(-2):
        loc = matvecmul(self.matrix.transpose(-1, -2), value) + self.loc
        matrix = value.new_zeros(loc.shape[:-1] + (0, loc.size(-1)))
        scale = self.scale.expand(loc.shape)
        return AffineNormal(matrix, loc, scale)
    else:
        return self.to_gaussian().left_condition(value)
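
# A small illustration (hypothetical) of the fast path above: for a full
# ``x`` observation, the result is effectively a diagonal Normal over ``y``
# whose mean is the affine prediction ``x @ matrix + loc``.
def _demo_left_condition():
    x_dim, y_dim = 4, 2
    matrix, loc = torch.randn(x_dim, y_dim), torch.randn(y_dim)
    scale = 0.5 + torch.rand(y_dim)
    aff = AffineNormal(matrix, loc, scale)
    x = torch.randn(x_dim)
    cond = aff.left_condition(x)  # O(y_dim), no matrix factorization
    assert torch.allclose(cond.loc, x @ matrix + loc)
    assert cond.matrix.size(-2) == 0  # only a trivial x part remains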
def _matrix_and_gaussian_to_gaussian(matrix, y_gaussian):
    # Embed a Gaussian over ``y`` into a joint Gaussian over ``(x, y)``
    # such that the joint log density at ``(x, y)`` equals
    # ``y_gaussian.log_density(y - x @ matrix)``.
    P_yy = y_gaussian.precision
    neg_P_xy = matmul(matrix, P_yy)
    P_xy = -neg_P_xy
    P_yx = P_xy.transpose(-1, -2)
    P_xx = matmul(neg_P_xy, matrix.transpose(-1, -2))
    precision = torch.cat([torch.cat([P_xx, P_xy], -1),
                           torch.cat([P_yx, P_yy], -1)], -2)
    info_y = y_gaussian.info_vec
    info_x = -matvecmul(matrix, info_y)
    info_vec = torch.cat([info_x, info_y], -1)
    log_normalizer = y_gaussian.log_normalizer
    result = Gaussian(log_normalizer, info_vec, precision)
    return result
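
# A sanity check (hypothetical) of the construction above: the joint density
# at ``(x, y)`` should equal the ``y`` density evaluated at the residual
# ``y - x @ matrix``.
def _demo_matrix_and_gaussian():
    x_dim, y_dim = 2, 3
    matrix = torch.randn(x_dim, y_dim)
    y_g = mvn_to_gaussian(torch.distributions.MultivariateNormal(
        torch.zeros(y_dim), torch.eye(y_dim)))
    joint = _matrix_and_gaussian_to_gaussian(matrix, y_g)
    x, y = torch.randn(x_dim), torch.randn(y_dim)
    assert torch.allclose(joint.log_density(torch.cat([x, y], -1)),
                          y_g.log_density(y - x @ matrix))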
def log_density(self, value):
    """
    Evaluate the log density of this Gaussian at a point value::

        -0.5 * value.T @ precision @ value + value.T @ info_vec + log_normalizer

    This is mainly used for testing.
    """
    if value.size(-1) == 0:
        batch_shape = broadcast_shape(value.shape[:-1], self.batch_shape)
        return self.log_normalizer.expand(batch_shape)
    result = (-0.5) * matvecmul(self.precision, value)
    result = result + self.info_vec
    result = (value * result).sum(-1)
    return result + self.log_normalizer
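
# A direct check (hypothetical) of the quadratic formula in the docstring
# above, on a small hand-picked example.
def _demo_log_density():
    precision = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
    info_vec = torch.tensor([1.0, -1.0])
    g = Gaussian(torch.tensor(0.3), info_vec, precision)
    v = torch.tensor([0.5, 2.0])
    expected = -0.5 * v @ precision @ v + v @ info_vec + 0.3
    assert torch.allclose(g.log_density(v), expected)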