예제 #1
0
def test_reduce_moment_matching_univariate():
    """Moment-match a two-component univariate mixture over ``i``.

    Checks that reducing ``i`` under the ``moment_matching`` interpretation
    (a) preserves the total mass of the joint and (b) yields a single
    Gaussian whose mean and variance are the closed-form mixture moments,
    offset by the constant ``log_mass``.
    """
    i_domain, x_domain = bint(2), reals()
    discrete_inputs = OrderedDict([('i', i_domain)])
    continuous_inputs = OrderedDict([('x', x_domain)])
    joint_inputs = OrderedDict([('i', i_domain), ('x', x_domain)])

    # Mixture weight of the second component, a constant log-mass offset,
    # and the half-distance between means plus the two component scales.
    weight = 0.8
    log_mass = 1.234
    half_gap, scale_lo, scale_hi = 2.0, 3.0, 4.0

    component_locs = torch.tensor([[-half_gap], [half_gap]])
    component_precisions = torch.tensor(
        [[[scale_lo**-2]], [[scale_hi**-2]]])
    component_info_vecs = component_precisions.matmul(
        component_locs.unsqueeze(-1)).squeeze(-1)
    log_weights = Tensor(
        torch.tensor([1 - weight, weight]).log() + log_mass, discrete_inputs)
    mixture = Gaussian(component_info_vecs, component_precisions,
                       joint_inputs)
    mixture -= mixture.log_normalizer
    joint = log_weights + mixture

    with interpretation(moment_matching):
        actual = joint.reduce(ops.logaddexp, 'i')

    # The approximation must not change the total mass.
    matched_total = actual.reduce(ops.logaddexp)
    exact_total = joint.reduce(ops.logaddexp)
    assert_close(matched_total, exact_total)

    # Closed-form mean and variance of the two-component mixture.
    target_loc = torch.tensor([(2 * weight - 1) * half_gap])
    target_variance = (4 * weight * (1 - weight) * half_gap**2 +
                       (1 - weight) * scale_lo**2 + weight * scale_hi**2)
    target_precision = torch.tensor([[1 / target_variance]])
    target_info_vec = target_precision.matmul(
        target_loc.unsqueeze(-1)).squeeze(-1)
    target_gaussian = Gaussian(target_info_vec, target_precision,
                               continuous_inputs)
    target_gaussian -= target_gaussian.log_normalizer
    target_discrete = Tensor(torch.tensor(log_mass))
    expected = target_discrete + target_gaussian
    assert_close(actual, expected, atol=1e-5, rtol=None)
예제 #2
0
def test_reduce_moment_matching_multivariate():
    """Moment-match four equally weighted 2-d components over ``i``.

    The four components sit at (+-10, +-1) with identity precision. The
    moment-matched result must preserve total mass and equal a zero-centred
    Gaussian with covariance ``[[101, 0], [0, 2]]`` plus ``log(4)``.
    """
    i_domain, x_domain = Bint[4], Reals[2]
    discrete_inputs = OrderedDict([('i', i_domain)])
    continuous_inputs = OrderedDict([('x', x_domain)])
    joint_inputs = OrderedDict([('i', i_domain), ('x', x_domain)])

    corners = numeric_array(
        [[-10., -1.], [+10., -1.], [+10., +1.], [-10., +1.]])
    unit_precisions = zeros(4, 1, 1) + ops.new_eye(corners, (2, ))
    uniform_logits = Tensor(zeros(4), discrete_inputs)
    mixture = Gaussian(corners, unit_precisions, joint_inputs)
    mixture -= mixture.log_normalizer
    joint = uniform_logits + mixture

    with interpretation(moment_matching):
        actual = joint.reduce(ops.logaddexp, 'i')

    # The approximation must not change the total mass.
    matched_total = actual.reduce(ops.logaddexp)
    exact_total = joint.reduce(ops.logaddexp)
    assert_close(matched_total, exact_total)

    target_loc = zeros(2)
    target_covariance = numeric_array([[101., 0.], [0., 2.]])
    target_precision = _inverse(target_covariance)
    target_gaussian = Gaussian(target_loc, target_precision,
                               continuous_inputs)
    target_gaussian -= target_gaussian.log_normalizer
    target_discrete = Tensor(ops.log(numeric_array(4.)))
    expected = target_discrete + target_gaussian
    assert_close(actual, expected, atol=1e-5, rtol=None)
예제 #3
0
def test_smoke(expr, expected_type):
    """Evaluate ``expr`` against a fixed set of fixtures and check its type.

    :param str expr: Python source evaluated with ``eval``. It can refer to
        the local fixtures built below (``g1``, ``g2``, ``shift``, ``i0``,
        ``x0``, ``y0``) by name, so those names must not be changed.
    :param expected_type: Type (or tuple of types) the evaluated result must
        be an instance of.
    """
    # Gaussian with inputs i (bint(2)) and x (reals(3)).
    g1 = Gaussian(info_vec=numeric_array([[0.0, 0.1, 0.2], [2.0, 3.0, 4.0]]),
                  precision=numeric_array([[[1.0, 0.1, 0.2], [0.1, 1.0, 0.3],
                                            [0.2, 0.3, 1.0]],
                                           [[1.0, 0.1, 0.2], [0.1, 1.0, 0.3],
                                            [0.2, 0.3, 1.0]]]),
                  inputs=OrderedDict([('i', bint(2)), ('x', reals(3))]))
    assert isinstance(g1, Gaussian)

    # Gaussian with inputs i (bint(2)) and y (reals(2)).
    g2 = Gaussian(info_vec=numeric_array([[0.0, 0.1], [2.0, 3.0]]),
                  precision=numeric_array([[[1.0, 0.2], [0.2, 1.0]],
                                           [[1.0, 0.2], [0.2, 1.0]]]),
                  inputs=OrderedDict([('i', bint(2)), ('y', reals(2))]))
    assert isinstance(g2, Gaussian)

    # Tensor indexed by i only.
    shift = Tensor(numeric_array([-1., 1.]), OrderedDict([('i', bint(2))]))
    assert isinstance(shift, Tensor)

    # presumably the value 1 with integer dtype 2 (matching input i) -- TODO confirm
    i0 = Number(1, 2)
    assert isinstance(i0, Number)

    # Tensor with no inputs, built from a length-3 array.
    x0 = Tensor(numeric_array([0.5, 0.6, 0.7]))
    assert isinstance(x0, Tensor)

    # Tensor indexed by i, built from a 2x2 array.
    y0 = Tensor(numeric_array([[0.2, 0.3], [0.8, 0.9]]),
                inputs=OrderedDict([('i', bint(2))]))
    assert isinstance(y0, Tensor)

    # NOTE: eval executes arbitrary code; expr is expected to come from the
    # test parametrization only, never from external input.
    result = eval(expr)
    assert isinstance(result, expected_type)
예제 #4
0
def test_reduce_moment_matching_multivariate():
    """Moment-match four equally weighted 2-d components over ``i``.

    The four components sit at (+-10, +-1) with identity precision. The
    moment-matched result must preserve total mass and equal a zero-centred
    Gaussian with covariance ``[[101, 0], [0, 2]]`` plus ``log(4)``.
    """
    i_domain, x_domain = bint(4), reals(2)
    discrete_inputs = OrderedDict([('i', i_domain)])
    continuous_inputs = OrderedDict([('x', x_domain)])
    joint_inputs = OrderedDict([('i', i_domain), ('x', x_domain)])

    corners = torch.tensor(
        [[-10., -1.], [+10., -1.], [+10., +1.], [-10., +1.]])
    unit_precisions = torch.zeros(4, 1, 1) + torch.eye(2, 2)
    uniform_logits = Tensor(torch.zeros(4), discrete_inputs)
    mixture = Gaussian(corners, unit_precisions, joint_inputs)
    mixture -= mixture.log_normalizer
    joint = uniform_logits + mixture

    with interpretation(moment_matching):
        actual = joint.reduce(ops.logaddexp, 'i')

    # The approximation must not change the total mass.
    matched_total = actual.reduce(ops.logaddexp)
    exact_total = joint.reduce(ops.logaddexp)
    assert_close(matched_total, exact_total)

    target_loc = torch.zeros(2)
    target_covariance = torch.tensor([[101., 0.], [0., 2.]])
    target_precision = target_covariance.inverse()
    target_gaussian = Gaussian(target_loc, target_precision,
                               continuous_inputs)
    target_gaussian -= target_gaussian.log_normalizer
    target_discrete = Tensor(torch.tensor(4.).log())
    expected = target_discrete + target_gaussian
    assert_close(actual, expected, atol=1e-5, rtol=None)
예제 #5
0
def test_mc_plate_gaussian():
    """Smoke-test Monte Carlo integration over a plate of 100 Gaussians.

    Samples ``loc`` from the measure, integrates the plated integrand against
    that sample, multiplies over the ``data`` plate and checks that no entry
    of the result is infinite.
    """
    measure_gaussian = Gaussian(
        torch.tensor([0.]), torch.tensor([[1.]]), (('loc', reals()),))
    log_measure = measure_gaussian + torch.tensor(-0.9189)

    plate_info_vec = torch.randn((100, 1)) + 3.
    plate_precision = torch.ones((100, 1, 1))
    plate_inputs = (('data', bint(100)), ('loc', reals()))
    integrand = Gaussian(plate_info_vec, plate_precision, plate_inputs)

    sampled_measure = log_measure.sample(frozenset({'loc'}))
    integral = Integrate(sampled_measure, integrand, frozenset({'loc'}))
    product = integral.reduce(ops.mul, frozenset({'data'}))
    assert not torch.isinf(product).any()
예제 #6
0
def test_mc_plate_gaussian():
    """Smoke-test Monte Carlo integration over a plate of 100 Gaussians.

    Samples ``loc`` from the measure (passing an explicit ``rng_key`` only on
    the ``'jax'`` backend), integrates the plated integrand against that
    sample, multiplies over the ``data`` plate and checks that no entry of
    the result equals positive or negative infinity.
    """
    measure_gaussian = Gaussian(
        numeric_array([0.]), numeric_array([[1.]]), (('loc', Real),))
    log_measure = measure_gaussian + numeric_array(-0.9189)

    plate_info_vec = randn((100, 1)) + 3.
    plate_precision = ones((100, 1, 1))
    plate_inputs = (('data', Bint[100]), ('loc', Real))
    integrand = Gaussian(plate_info_vec, plate_precision, plate_inputs)

    if get_backend() != 'jax':
        rng_key = None
    else:
        rng_key = np.array([0, 0], dtype=np.uint32)

    sampled_measure = log_measure.sample('loc', rng_key=rng_key)
    integral = Integrate(sampled_measure, integrand, 'loc')
    product = integral.reduce(ops.mul, 'data')

    is_pos_inf = product == float('inf')
    is_neg_inf = product == float('-inf')
    assert not (is_pos_inf | is_neg_inf).any()
예제 #7
0
def mvn_to_funsor(pyro_dist, event_inputs=(), real_inputs=OrderedDict()):
    """
    Turn a joint :class:`torch.distributions.MultivariateNormal` into a
    :class:`~funsor.terms.Funsor` whose event is split across several named
    real inputs.

    The sizes are expected to satisfy::

        sum(d.num_elements for d in real_inputs.values())
          == pyro_dist.event_shape[0]

    :param torch.distributions.MultivariateNormal pyro_dist: The
        multivariate normal distribution to convert.
    :param tuple event_inputs: Names for the rightmost batch dimensions;
        they become ``Bint`` inputs of the result.
    :param OrderedDict real_inputs: Mapping from real variable name to its
        ``Real`` domain. The default is a shared empty ``OrderedDict``; it
        is only read here, never modified. When empty, the plain converted
        distribution is returned unchanged.
    :return: A funsor with the given ``real_inputs`` and possibly additional
        Bint inputs.
    :rtype: funsor.terms.Funsor
    """
    assert isinstance(pyro_dist, torch.distributions.MultivariateNormal)
    assert isinstance(event_inputs, tuple)
    assert isinstance(real_inputs, OrderedDict)

    names = default_dim_to_name(pyro_dist.batch_shape, event_inputs)
    converted = to_funsor(pyro_dist, Real, names)
    if not real_inputs:
        return converted

    # Bind the value to a named variable and split the result into its
    # discrete and Gaussian terms.
    discrete, gaussian = converted(value="value").terms

    # Keep the non-real inputs, then replace the real ones with the
    # caller-provided split.
    split_inputs = OrderedDict()
    for key, domain in gaussian.inputs.items():
        if domain.dtype != 'real':
            split_inputs[key] = domain
    split_inputs.update(real_inputs)

    regrouped = Gaussian(gaussian.info_vec, gaussian.precision, split_inputs)
    return discrete + regrouped
예제 #8
0
def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete,
                                   gaussian):
    """
    Contract ``discrete`` with ``gaussian`` over ``reduced_vars``, replacing
    the reduction over non-real Gaussian inputs by a moment-matched Gaussian.

    Variables in ``reduced_vars`` that are non-real inputs of ``gaussian``
    are handled approximately; all others are handled exactly first.

    :param red_op: Reduction op, forwarded to ``Contraction`` / ``reduce``
        when exact variables are present.
    :param bin_op: Binary op, forwarded to ``Contraction`` when exact
        variables are present.
    :param frozenset reduced_vars: Names of the variables to reduce.
    :param discrete: The discrete (log-weight) term.
    :param gaussian: The Gaussian term; must expose ``inputs``,
        ``info_vec``, ``_precision_chol`` and ``log_normalizer``.
    :return: ``new_discrete + new_gaussian`` for the approximate case, the
        re-dispatched contraction when both exact and approximate variables
        are present, or ``None`` when there is nothing to approximate.
    """
    approx_vars = frozenset(
        k for k in reduced_vars
        if k in gaussian.inputs and gaussian.inputs[k].dtype != 'real')
    exact_vars = reduced_vars - approx_vars

    if not approx_vars:
        return None

    if exact_vars:
        # Reduce the exact variables first, then retry on the remainder.
        return Contraction(red_op, bin_op, exact_vars, discrete,
                           gaussian).reduce(red_op, approx_vars)

    # Moment-matching approximation (approx_vars only).
    discrete += gaussian.log_normalizer
    new_discrete = discrete.reduce(
        ops.logaddexp, approx_vars.intersection(discrete.inputs))
    # Variables reduced here but absent from ``discrete`` contribute a
    # uniform factor; account for it in the normalizer.
    num_elements = reduce(ops.mul, [
        gaussian.inputs[k].num_elements
        for k in approx_vars.difference(discrete.inputs)
    ], 1)
    if num_elements != 1:
        new_discrete -= math.log(num_elements)

    int_inputs = OrderedDict(
        (k, d) for k, d in gaussian.inputs.items() if d.dtype != 'real')
    probs = (discrete - new_discrete.clamp_finite()).exp()

    # Per-component mean and covariance recovered from the information form.
    old_loc = Tensor(
        gaussian.info_vec.unsqueeze(-1).cholesky_solve(
            gaussian._precision_chol).squeeze(-1), int_inputs)
    new_loc = (probs * old_loc).reduce(ops.add, approx_vars)
    old_cov = Tensor(cholesky_inverse(gaussian._precision_chol),
                     int_inputs)
    diff = old_loc - new_loc
    outers = Tensor(
        diff.data.unsqueeze(-1) * diff.data.unsqueeze(-2), diff.inputs)
    # Mixture covariance = weighted within-component covariance
    # + weighted spread of the component means.
    new_cov = ((probs * old_cov).reduce(ops.add, approx_vars) +
               (probs * outers).reduce(ops.add, approx_vars))

    # Numerically stabilize by adding bogus precision to empty components.
    total = probs.reduce(ops.add, approx_vars)
    mask = (total.data == 0).to(
        total.data.dtype).unsqueeze(-1).unsqueeze(-1)
    # Build the identity with the covariance's own dtype/device so the
    # in-place add also works for non-default dtypes and non-CPU tensors.
    eye = torch.eye(new_cov.data.size(-1),
                    dtype=new_cov.data.dtype,
                    device=new_cov.data.device)
    new_cov.data += mask * eye

    new_precision = Tensor(cholesky_inverse(cholesky(new_cov.data)),
                           new_cov.inputs)
    new_info_vec = new_precision.data.matmul(
        new_loc.data.unsqueeze(-1)).squeeze(-1)
    new_inputs = new_loc.inputs.copy()
    new_inputs.update(
        (k, d) for k, d in gaussian.inputs.items() if d.dtype == 'real')
    new_gaussian = Gaussian(new_info_vec, new_precision.data, new_inputs)
    new_discrete -= new_gaussian.log_normalizer

    return new_discrete + new_gaussian
예제 #9
0
def test_smoke(expr, expected_type):
    """Evaluate ``expr`` against a fixed set of fixtures and check its type.

    :param str expr: Python source evaluated with ``eval``. It can refer to
        the local fixtures built below (``dx``, ``dy``, ``t``, ``g``, ``i0``,
        ``x0``) by name, so those names must not be changed.
    :param expected_type: Type (or tuple of types) the evaluated result must
        be an instance of.
    """
    # Delta named 'x' whose point is a random Tensor indexed by i.
    dx = Delta('x', Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2))])))
    assert isinstance(dx, Delta)

    # Delta named 'y' whose point is a random Tensor indexed by j.
    dy = Delta('y', Tensor(torch.randn(3, 4), OrderedDict([('j', bint(3))])))
    assert isinstance(dy, Delta)

    # Random Tensor indexed by both i and j.
    t = Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2)),
                                               ('j', bint(3))]))
    assert isinstance(t, Tensor)

    # Gaussian with inputs i (bint(2)) and x (reals(3)).
    g = Gaussian(info_vec=torch.tensor([[0.0, 0.1, 0.2], [2.0, 3.0, 4.0]]),
                 precision=torch.tensor([[[1.0, 0.1, 0.2], [0.1, 1.0, 0.3],
                                          [0.2, 0.3, 1.0]],
                                         [[1.0, 0.1, 0.2], [0.1, 1.0, 0.3],
                                          [0.2, 0.3, 1.0]]]),
                 inputs=OrderedDict([('i', bint(2)), ('x', reals(3))]))
    assert isinstance(g, Gaussian)

    # presumably the value 1 with integer dtype 2 (matching input i) -- TODO confirm
    i0 = Number(1, 2)
    assert isinstance(i0, Number)

    # Tensor with no inputs, built from a length-3 tensor.
    x0 = Tensor(torch.tensor([0.5, 0.6, 0.7]))
    assert isinstance(x0, Tensor)

    # NOTE: eval executes arbitrary code; expr is expected to come from the
    # test parametrization only, never from external input.
    result = eval(expr)
    assert isinstance(result, expected_type)
예제 #10
0
def eager_integrate(log_measure, integrand, reduced_vars):
    """
    Integrate a Gaussian ``integrand`` against a Gaussian ``log_measure``.

    :param log_measure: Gaussian log-measure.
    :param integrand: Gaussian integrand (a non-normalized quadratic form).
    :param frozenset reduced_vars: Names of the variables integrated out.
    :return: A ``Tensor`` holding the integral, further summed over the
        non-real members of ``reduced_vars``; ``None`` when no real
        variable is being reduced (defer to the default implementation).
    :raises NotImplementedError: When only part of the real variables would
        be integrated out.
    """
    real_vars = frozenset(name for name in reduced_vars
                          if log_measure.inputs[name].dtype == 'real')
    if not real_vars:
        return None  # defer to default implementation

    measure_reals = frozenset(name
                              for name, dom in log_measure.inputs.items()
                              if dom.dtype == 'real')
    integrand_reals = frozenset(name
                                for name, dom in integrand.inputs.items()
                                if dom.dtype == 'real')
    fully_integrated = (measure_reals == real_vars
                        and integrand_reals <= real_vars)
    if not fully_integrated:
        raise NotImplementedError('TODO implement partial integration')

    # Align both Gaussians to a common input ordering.
    joint_inputs = OrderedDict()
    joint_inputs.update(log_measure.inputs)
    joint_inputs.update(integrand.inputs)
    lhs_info_vec, lhs_precision = align_gaussian(joint_inputs, log_measure)
    rhs_info_vec, rhs_precision = align_gaussian(joint_inputs, integrand)
    measure = Gaussian(lhs_info_vec, lhs_precision, joint_inputs)

    # Compute the expectation of a non-normalized quadratic form.
    # See "The Matrix Cookbook" (November 15, 2012) ss. 8.2.2 eq. 380.
    # http://www.math.uwaterloo.ca/~hwolkowi/matrixcookbook.pdf
    norm = measure.log_normalizer.data.exp()
    measure_cov = cholesky_inverse(measure._precision_chol)
    measure_loc = measure.info_vec.unsqueeze(-1).cholesky_solve(
        measure._precision_chol).squeeze(-1)
    linear_minus_half_quad = rhs_info_vec - 0.5 * _mv(rhs_precision,
                                                      measure_loc)
    vmv_term = _vv(measure_loc, linear_minus_half_quad)
    trace_term = 0.5 * _trace_mm(rhs_precision, measure_cov)
    data = norm * (vmv_term - trace_term)

    remaining_inputs = OrderedDict()
    for name, dom in joint_inputs.items():
        if name not in reduced_vars:
            remaining_inputs[name] = dom
    integral = Tensor(data, remaining_inputs)
    return integral.reduce(ops.add, reduced_vars - real_vars)
예제 #11
0
def torchmvn_to_funsor(pyro_dist, output=None, dim_to_name=None, real_inputs=OrderedDict()):
    """
    Convert a multivariate normal distribution to a funsor, optionally
    splitting its real event across the named ``real_inputs``.

    :param pyro_dist: Distribution forwarded to
        ``torchdistribution_to_funsor``.
    :param output: Forwarded to ``torchdistribution_to_funsor``.
    :param dim_to_name: Forwarded to ``torchdistribution_to_funsor``.
    :param OrderedDict real_inputs: Mapping from real variable name to
        domain. The default is a shared empty ``OrderedDict``; it is only
        read here. When empty, the plain conversion is returned.
    :return: The converted funsor.
    """
    converted = torchdistribution_to_funsor(
        pyro_dist, output=output, dim_to_name=dim_to_name)
    if not len(real_inputs):
        return converted

    # Bind the value to a named variable and split into discrete/Gaussian.
    discrete, gaussian = converted(value="value").terms

    # Keep non-real inputs, then install the caller-provided real inputs.
    split_inputs = OrderedDict()
    for key, domain in gaussian.inputs.items():
        if domain.dtype != 'real':
            split_inputs[key] = domain
    split_inputs.update(real_inputs)

    regrouped = Gaussian(gaussian.info_vec, gaussian.precision, split_inputs)
    return discrete + regrouped
예제 #12
0
def adjoint_subs_gaussianmixture_discrete(adj_redop, adj_binop, out_adj, arg,
                                          subs):
    """
    Adjoint of a substitution into ``arg``, where ``arg.terms[0]`` is used as
    the discrete term and ``arg.terms[1]`` as the Gaussian term (it must
    expose ``info_vec`` and ``precision``).

    :param adj_redop: Adjoint reduction op, forwarded to ``adjoint_ops``.
    :param adj_binop: Adjoint binary op; also used to combine ``out_adj``
        with a unit-filled tensor and to look up ``ops.UNITS``.
    :param out_adj: Incoming adjoint of the substitution's output.
    :param arg: The funsor that was substituted into.
    :param subs: Iterable of ``(name, value)`` pairs that were substituted.
    :return: A dict ``{arg: in_adj}`` with the adjoint for ``arg``.
    :raises NotImplementedError: If any substituted value is real-valued and
        is not a plain ``Variable``.
    """

    if any(v.dtype == 'real' and not isinstance(v, Variable) for k, v in subs):
        raise NotImplementedError(
            "TODO implement adjoint for substitution into Gaussian real variable"
        )

    # invert renaming
    renames = tuple((v.name, k) for k, v in subs if isinstance(v, Variable))
    out_adj = Subs(out_adj, renames)

    # inverting advanced indexing
    slices = tuple((k, v) for k, v in subs if not isinstance(v, Variable))

    # Non-real inputs of arg; these index the batch dims of the Gaussian data.
    arg_int_inputs = OrderedDict(
        (k, v) for k, v in arg.inputs.items() if v.dtype != 'real')

    # Tensor filled with the unit of adj_binop, shaped like the batch part of
    # the Gaussian's info_vec, with the same slices applied.
    zeros_like_out = Subs(
        Tensor(
            arg.terms[1].info_vec.new_full(arg.terms[1].info_vec.shape[:-1],
                                           ops.UNITS[adj_binop]),
            arg_int_inputs), slices)
    out_adj = adj_binop(out_adj, zeros_like_out)

    # Adjoint for the discrete term via the generic Subs rule.
    in_adj_discrete = adjoint_ops(Subs, adj_redop, adj_binop, out_adj,
                                  arg.terms[0], subs)[arg.terms[0]]

    # invert the slicing for the Gaussian term even though the message does not affect the values
    in_adj_info_vec = list(
        adjoint_ops(
            Subs,
            adj_redop,
            adj_binop,  # ops.add, ops.mul,
            zeros_like_out,
            Tensor(arg.terms[1].info_vec, arg_int_inputs),
            slices).values())[0]

    in_adj_precision = list(
        adjoint_ops(
            Subs,
            adj_redop,
            adj_binop,  # ops.add, ops.mul,
            zeros_like_out,
            Tensor(arg.terms[1].precision, arg_int_inputs),
            slices).values())[0]

    assert isinstance(in_adj_info_vec, Tensor)
    assert isinstance(in_adj_precision, Tensor)

    # Reassemble a Gaussian over arg's full inputs from the un-sliced data.
    in_adj_gaussian = Gaussian(in_adj_info_vec.data, in_adj_precision.data,
                               arg.inputs.copy())

    in_adj = in_adj_gaussian + in_adj_discrete
    return {arg: in_adj}
예제 #13
0
파일: joint.py 프로젝트: fehiepsi/funsor
    def moment_matching_reduce(self, op, reduced_vars):
        """
        Reduce ``reduced_vars`` with ``op``, approximating ``logaddexp``
        reductions over non-real Gaussian inputs by a single moment-matched
        Gaussian.

        :param op: The reduction op; only ``ops.logaddexp`` is handled here.
        :param frozenset reduced_vars: Names of the variables to reduce.
        :return: ``self`` when ``reduced_vars`` is empty, the reduced result
            for ``ops.logaddexp``, or ``None`` to defer to the default
            implementation for any other op.
        :raises NotImplementedError: If any reduced variable is an input of
            one of ``self.deltas``.
        """
        if not reduced_vars:
            return self
        if op is ops.logaddexp:
            if not all(reduced_vars.isdisjoint(d.inputs) for d in self.deltas):
                raise NotImplementedError(
                    'TODO handle moment_matching with Deltas')
            # Given the check above, this intersection is empty.
            lazy_vars = frozenset().union(
                *(d.inputs for d in self.deltas)).intersection(reduced_vars)
            # Non-real variables that index the Gaussian are approximated.
            approx_vars = frozenset(k for k in reduced_vars - lazy_vars
                                    if self.inputs[k].dtype != 'real'
                                    if k in self.gaussian.inputs)
            exact_vars = reduced_vars - lazy_vars - approx_vars
            if exact_vars:
                # Reduce exact variables eagerly, then retry on the rest.
                return self.eager_reduce(op, exact_vars).reduce(
                    op, approx_vars | lazy_vars)

            # Moment-matching approximation.
            assert approx_vars and not exact_vars
            discrete = self.discrete

            new_discrete = discrete.reduce(
                ops.logaddexp, approx_vars.intersection(discrete.inputs))
            # Variables reduced here but absent from discrete contribute a
            # uniform factor; subtract its log from the new discrete term.
            num_elements = reduce(ops.mul, [
                self.inputs[k].num_elements
                for k in approx_vars.difference(discrete.inputs)
            ], 1)
            if num_elements != 1:
                new_discrete -= math.log(num_elements)

            gaussian = self.gaussian
            int_inputs = OrderedDict((k, d)
                                     for k, d in gaussian.inputs.items()
                                     if d.dtype != 'real')
            # Component weights relative to the reduced discrete term.
            probs = (discrete - new_discrete).exp()
            old_loc = Tensor(gaussian.loc, int_inputs)
            new_loc = (probs * old_loc).reduce(ops.add, approx_vars)
            old_cov = Tensor(sym_inverse(gaussian.precision), int_inputs)
            diff = old_loc - new_loc
            outers = Tensor(
                diff.data.unsqueeze(-1) * diff.data.unsqueeze(-2), diff.inputs)
            # Weighted within-component covariance plus weighted outer
            # products of the mean offsets.
            new_cov = ((probs * old_cov).reduce(ops.add, approx_vars) +
                       (probs * outers).reduce(ops.add, approx_vars))
            new_precision = Tensor(sym_inverse(new_cov.data), new_cov.inputs)
            new_inputs = new_loc.inputs.copy()
            new_inputs.update((k, d) for k, d in self.gaussian.inputs.items()
                              if d.dtype == 'real')
            new_gaussian = Gaussian(new_loc.data, new_precision.data,
                                    new_inputs)
            result = Joint(self.deltas, new_discrete, new_gaussian)
            return result.reduce(ops.logaddexp, lazy_vars)

        return None  # defer to default implementation
예제 #14
0
파일: joint.py 프로젝트: ordabayevy/funsor
def eager_cat_homogeneous(name, part_name, *parts):
    """
    Concatenate Gaussian or Gaussian-mixture ``parts`` along their
    ``part_name`` input, producing a result indexed by ``name``.

    :param str name: Name of the concatenated input in the result.
    :param str part_name: Name of the input to concatenate along; must be an
        input of every part.
    :param parts: One or more funsors with identical ``output``; each must
        be a ``Gaussian`` or a ``GaussianMixture``.
    :return: A ``Gaussian``, plus a ``Tensor`` of discrete terms when at
        least one part carried a discrete term.
    :raises NotImplementedError: For parts of any other type.
    """
    assert parts
    output = parts[0].output
    # Placeholder so part_name occupies the first position; every part is
    # asserted to contain part_name, so the None is overwritten by update().
    inputs = OrderedDict([(part_name, None)])
    for part in parts:
        assert part.output == output
        assert part_name in part.inputs
        inputs.update(part.inputs)

    # Reorder the combined inputs: non-real inputs first, then real inputs.
    int_inputs = OrderedDict(
        (k, v) for k, v in inputs.items() if v.dtype != "real")
    real_inputs = OrderedDict(
        (k, v) for k, v in inputs.items() if v.dtype == "real")
    inputs = int_inputs.copy()
    inputs.update(real_inputs)
    discretes = []
    info_vecs = []
    precisions = []
    for part in parts:
        # Use this part's own domain for part_name so that the expanded
        # shape below matches this part's extent along that input.
        inputs[part_name] = part.inputs[part_name]
        int_inputs[part_name] = inputs[part_name]
        shape = tuple(d.size for d in int_inputs.values())
        if isinstance(part, Gaussian):
            discrete = None
            gaussian = part
        elif issubclass(type(part), GaussianMixture
                        ):  # TODO figure out why isinstance isn't working
            discrete, gaussian = part.terms[0], part.terms[1]
            discrete = ops.expand(align_tensor(int_inputs, discrete), shape)
        else:
            raise NotImplementedError("TODO")
        discretes.append(discrete)
        info_vec, precision = align_gaussian(inputs, gaussian)
        info_vecs.append(ops.expand(info_vec, shape + (-1, )))
        precisions.append(ops.expand(precision, shape + (-1, -1)))
    # NOTE(review): when part_name != name, `name` is re-inserted below at the
    # END of both dicts while the data is concatenated along dim 0 -- confirm
    # this ordering is what callers rely on.
    if part_name != name:
        del inputs[part_name]
        del int_inputs[part_name]

    dim = 0
    info_vec = ops.cat(dim, *info_vecs)
    precision = ops.cat(dim, *precisions)
    # The concatenated input's size is the total extent along dim 0.
    inputs[name] = Bint[info_vec.shape[dim]]
    int_inputs[name] = inputs[name]
    result = Gaussian(info_vec, precision, inputs)
    if any(d is not None for d in discretes):
        # Parts without a discrete term contribute zeros of matching batch
        # shape so that all discrete pieces can be concatenated.
        for i, d in enumerate(discretes):
            if d is None:
                discretes[i] = ops.new_zeros(info_vecs[i],
                                             info_vecs[i].shape[:-1])
        discrete = ops.cat(dim, *discretes)
        result = result + Tensor(discrete, int_inputs)
    return result
예제 #15
0
def random_gaussian(inputs):
    """
    Build a randomly parameterized :class:`funsor.gaussian.Gaussian` over
    the given ``inputs``.

    The batch shape comes from the ``dtype`` of each non-real input and the
    event size is the total ``num_elements`` of the real inputs. The
    precision is a random Gram matrix plus ``0.05`` times the identity.
    """
    assert isinstance(inputs, OrderedDict)

    batch_sizes = []
    event_size = 0
    for domain in inputs.values():
        if domain.dtype == 'real':
            event_size += domain.num_elements
        else:
            batch_sizes.append(domain.dtype)
    batch_shape = tuple(batch_sizes)
    event_shape = (event_size,)

    # Draw loc before prec_sqrt to keep the random stream order stable.
    loc = torch.randn(batch_shape + event_shape)
    prec_sqrt = torch.randn(batch_shape + event_shape + event_shape)
    gram = torch.matmul(prec_sqrt, prec_sqrt.transpose(-1, -2))
    jitter = 0.05 * torch.eye(event_shape[0])
    return Gaussian(loc, gram + jitter, inputs)
예제 #16
0
def eager_normal(loc, scale, value):
    """
    Eagerly express a Normal log-density as a ``Tensor`` of log-normalizers
    plus a one-dimensional ``Gaussian``.

    If ``loc`` is a ``Variable``, the roles of ``loc`` and ``value`` are
    swapped first.
    """
    if isinstance(loc, Variable):
        value, loc = loc, value

    inputs, aligned = align_tensors(loc, scale)
    loc_data, scale_data = torch.broadcast_tensors(*aligned)
    inputs.update(value.inputs)

    int_inputs = OrderedDict()
    for name, domain in inputs.items():
        if domain.dtype != 'real':
            int_inputs[name] = domain

    log_prob = -0.5 * math.log(2 * math.pi) - scale_data.log()
    # Append the event dimension(s) expected by Gaussian.
    event_loc = loc_data.unsqueeze(-1)
    event_precision = scale_data.pow(-2).unsqueeze(-1).unsqueeze(-1)

    normalizer = Tensor(log_prob, int_inputs)
    return normalizer + Gaussian(event_loc, event_precision, inputs)
예제 #17
0
def eager_mvn(loc, scale_tril, value):
    """
    Eagerly express a multivariate normal log-density as a ``Tensor`` of
    log-normalizers plus a ``Gaussian``.

    If ``loc`` is a ``Variable``, the roles of ``loc`` and ``value`` are
    swapped first. The precision is computed as ``L^-T L^-1`` from the
    inverse of ``scale_tril``.
    """
    if isinstance(loc, Variable):
        value, loc = loc, value

    (dim,) = loc.output.shape
    inputs, (loc_data, tril_data) = align_tensors(loc, scale_tril)
    inputs.update(value.inputs)

    int_inputs = OrderedDict()
    for name, domain in inputs.items():
        if domain.dtype != 'real':
            int_inputs[name] = domain

    log_det_scale = tril_data.diagonal(dim1=-1, dim2=-2).log().sum(-1)
    log_prob = -0.5 * dim * math.log(2 * math.pi) - log_det_scale

    tril_inverse = torch.inverse(tril_data)
    precision = torch.matmul(tril_inverse.transpose(-1, -2), tril_inverse)

    normalizer = Tensor(log_prob, int_inputs)
    return normalizer + Gaussian(loc_data, precision, inputs)
예제 #18
0
def random_gaussian(inputs):
    """
    Build a randomly parameterized :class:`funsor.gaussian.Gaussian` over
    the given ``inputs``.

    The batch shape comes from the ``dtype`` of each non-real input and the
    event size is the total ``num_elements`` of the real inputs. The
    precision is a random Gram matrix plus ``0.5`` times the identity, and
    the information vector is that precision applied to a random location.
    """
    assert isinstance(inputs, OrderedDict)

    batch_sizes = []
    event_size = 0
    for domain in inputs.values():
        if domain.dtype == 'real':
            event_size += domain.num_elements
        else:
            batch_sizes.append(domain.dtype)
    batch_shape = tuple(batch_sizes)
    event_shape = (event_size, )

    # Draw prec_sqrt before loc to keep the random stream order stable.
    prec_sqrt = randn(batch_shape + event_shape + event_shape)
    gram = ops.matmul(prec_sqrt, ops.transpose(prec_sqrt, -1, -2))
    precision = gram + 0.5 * ops.new_eye(gram, event_shape[:1])

    loc = randn(batch_shape + event_shape)
    loc_column = ops.unsqueeze(loc, -1)
    info_vec = ops.matmul(precision, loc_column).squeeze(-1)
    return Gaussian(info_vec, precision, inputs)
예제 #19
0
def eager_normal(loc, scale, value):
    """
    Eagerly express a Normal log-density as a zero-centred ``Gaussian`` in a
    fresh variable, then substitute ``value - loc`` for that variable.

    Returns ``None`` (stay lazy) unless both ``loc`` and ``value`` are
    affine.
    """
    assert loc.output == Real
    assert scale.output == Real
    assert value.output == Real
    if not (is_affine(loc) and is_affine(value)):
        return None  # lazy

    batch_shape = scale.data.shape
    zero_info_vec = ops.new_zeros(scale.data, batch_shape + (1, ))
    precision = ops.pow(scale.data, -2).reshape(batch_shape + (1, 1))
    log_prob = -0.5 * math.log(2 * math.pi) - ops.log(scale).sum()

    # Introduce a fresh real variable standing for the residual.
    residual_name = gensym('value')
    gaussian_inputs = scale.inputs.copy()
    gaussian_inputs[residual_name] = Real

    density = log_prob + Gaussian(zero_info_vec, precision, gaussian_inputs)
    return density(**{residual_name: value - loc})
예제 #20
0
def eager_normal(loc, scale, value):
    """
    Eagerly express a Normal log-density as a ``Tensor`` of log-normalizers
    plus a one-dimensional ``Gaussian`` in information form.

    If ``loc`` is a ``Variable``, the roles of ``loc`` and ``value`` are
    swapped first.

    :param loc: Location funsor (or the ``Variable`` when roles are swapped).
    :param scale: Scale funsor.
    :param value: Value funsor (or the location when roles are swapped).
    :return: ``Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision,
        inputs)``.
    """
    if isinstance(loc, Variable):
        loc, value = value, loc

    # assumes align_tensors(..., expand=True) returns loc and scale with the
    # same (batch) shape -- TODO confirm
    inputs, (loc, scale) = align_tensors(loc, scale, expand=True)
    inputs.update(value.inputs)
    int_inputs = OrderedDict(
        (k, v) for k, v in inputs.items() if v.dtype != 'real')

    precision = scale.pow(-2)
    info = precision * loc
    # Compute the quadratic term loc * precision * loc elementwise on the
    # batch-shaped tensors. Multiplying `loc` by the already-unsqueezed
    # info_vec (shape + (1,)) would broadcast to an outer product whenever a
    # batch dimension is larger than 1, giving log_prob the wrong shape.
    log_prob = (-0.5 * math.log(2 * math.pi) - scale.log()
                - 0.5 * (loc * info))
    # Append the event dimension(s) expected by Gaussian.
    info_vec = info.unsqueeze(-1)
    precision = precision.unsqueeze(-1).unsqueeze(-1)
    return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision, inputs)
예제 #21
0
def eager_mvn(loc, scale_tril, value):
    """
    Eagerly express a multivariate normal log-density as a zero-centred
    ``Gaussian`` in a fresh variable, then substitute ``value - loc`` for
    that variable.

    Returns ``None`` (stay lazy) unless both ``loc`` and ``value`` are
    affine.
    """
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not (is_affine(loc) and is_affine(value)):
        return None  # lazy

    tril_data = scale_tril.data
    zero_info_vec = tril_data.new_zeros(tril_data.shape[:-1])
    precision = ops.cholesky_inverse(tril_data)

    scale_diag = Tensor(tril_data.diagonal(dim1=-1, dim2=-2),
                        scale_tril.inputs)
    event_size = scale_diag.shape[0]
    log_prob = (-0.5 * event_size * math.log(2 * math.pi)
                - scale_diag.log().sum())

    # Introduce a fresh real variable standing for the residual.
    residual_name = gensym('value')
    gaussian_inputs = scale_tril.inputs.copy()
    gaussian_inputs[residual_name] = reals(event_size)

    density = log_prob + Gaussian(zero_info_vec, precision, gaussian_inputs)
    return density(**{residual_name: value - loc})
예제 #22
0
def test_affine_subs():
    """Regression test: substituting an affine expression into a Gaussian.

    Builds a Gaussian over ``state_1_b6`` (reals(3)) and ``obs_b2``
    (reals(2)) and substitutes, for ``obs_b2``, a ``Contraction`` combining
    the variable ``bias_b5`` with a constant tensor under ``ops.add``. The
    result must be a ``Gaussian`` or a ``Contraction``.
    """
    # This was recorded from test_pyro_convert.
    x = Subs(
     Gaussian(
      torch.tensor([1.3027106523513794, 1.4167094230651855, -0.9750942587852478, 0.5321089029312134, -0.9039931297302246], dtype=torch.float32),  # noqa
      torch.tensor([[1.0199567079544067, 0.9840421676635742, -0.473368763923645, 0.34206756949424744, -0.7562517523765564], [0.9840421676635742, 1.511502742767334, -1.7593903541564941, 0.6647964119911194, -0.5119513273239136], [-0.4733688533306122, -1.7593903541564941, 3.2386727333068848, -0.9345928430557251, -0.1534711718559265], [0.34206756949424744, 0.6647964119911194, -0.9345928430557251, 0.3141004145145416, -0.12399007380008698], [-0.7562517523765564, -0.5119513273239136, -0.1534711718559265, -0.12399007380008698, 0.6450173854827881]], dtype=torch.float32),  # noqa
      (('state_1_b6',
        reals(3,),),
       ('obs_b2',
        reals(2,),),)),
     (('obs_b2',
       Contraction(ops.nullop, ops.add,
        frozenset(),
        (Variable('bias_b5', reals(2,)),
         Tensor(
          torch.tensor([-2.1787893772125244, 0.5684312582015991], dtype=torch.float32),  # noqa
          (),
          'real'),)),),))
    # On failure, show the pretty-printed result for debugging.
    assert isinstance(x, (Gaussian, Contraction)), x.pretty()
예제 #23
0
def adjoint_subs_gaussianmixture_gaussianmixture(adj_redop, adj_binop, out_adj, arg, subs):
    """
    Adjoint of a substitution into ``arg`` when the incoming adjoint
    ``out_adj`` also has a discrete term (``terms[0]``) and a Gaussian term
    (``terms[1]``).

    :param adj_redop: Adjoint reduction op, forwarded to ``adjoint_ops``.
    :param adj_binop: Adjoint binary op, forwarded to ``adjoint_ops``.
    :param out_adj: Incoming adjoint of the substitution's output.
    :param arg: The funsor that was substituted into.
    :param subs: Iterable of ``(name, value)`` pairs that were substituted.
    :return: A dict ``{arg: in_adj}`` with the adjoint for ``arg``.
    :raises NotImplementedError: If any substituted value is real-valued and
        is not a plain ``Variable``.
    """

    if any(v.dtype == 'real' and not isinstance(v, Variable) for k, v in subs):
        raise NotImplementedError("TODO implement adjoint for substitution into Gaussian real variable")

    # invert renaming
    renames = tuple((v.name, k) for k, v in subs if isinstance(v, Variable))
    out_adj = Subs(out_adj, renames)

    # inverting advanced indexing
    slices = tuple((k, v) for k, v in subs if not isinstance(v, Variable))

    # Every substitution is either a rename or a slice.
    assert len(slices + renames) == len(subs)

    # Adjoint for the discrete term via the generic Subs rule.
    in_adj_discrete = adjoint_ops(Subs, adj_redop, adj_binop, out_adj.terms[0], arg.terms[0], subs)[arg.terms[0]]

    # Non-real inputs index the batch dims of the Gaussian data.
    arg_int_inputs = OrderedDict((k, v) for k, v in arg.inputs.items() if v.dtype != 'real')
    out_adj_int_inputs = OrderedDict((k, v) for k, v in out_adj.inputs.items() if v.dtype != 'real')

    arg_real_inputs = OrderedDict((k, v) for k, v in arg.inputs.items() if v.dtype == 'real')

    # Align the adjoint's Gaussian term to its own non-real inputs followed
    # by arg's real inputs.
    align_inputs = OrderedDict((k, v) for k, v in out_adj.terms[1].inputs.items() if v.dtype != 'real')
    align_inputs.update(arg_real_inputs)
    out_adj_info_vec, out_adj_precision = align_gaussian(align_inputs, out_adj.terms[1])

    # Undo the slicing separately for the info_vec and precision data.
    in_adj_info_vec = list(adjoint_ops(Subs, adj_redop, adj_binop,  # ops.add, ops.mul,
                                       Tensor(out_adj_info_vec, out_adj_int_inputs),
                                       Tensor(arg.terms[1].info_vec, arg_int_inputs),
                                       slices).values())[0]

    in_adj_precision = list(adjoint_ops(Subs, adj_redop, adj_binop,  # ops.add, ops.mul,
                                        Tensor(out_adj_precision, out_adj_int_inputs),
                                        Tensor(arg.terms[1].precision, arg_int_inputs),
                                        slices).values())[0]

    assert isinstance(in_adj_info_vec, Tensor)
    assert isinstance(in_adj_precision, Tensor)

    # Reassemble a Gaussian over arg's full inputs from the un-sliced data.
    in_adj_gaussian = Gaussian(in_adj_info_vec.data, in_adj_precision.data, arg.inputs.copy())

    in_adj = in_adj_gaussian + in_adj_discrete
    return {arg: in_adj}
예제 #24
0
def eager_affine_normal(matrix, loc, scale, value_x, value_y):
    """
    Eagerly build the log-density over ``value_x`` of a Normal observation
    ``value_y`` with location ``value_x @ matrix + loc`` and the given
    ``scale``, as a ``Tensor`` plus a ``Gaussian`` in information form.

    A fresh variable is introduced for ``x`` and then substituted with
    ``value_x``.
    """
    rows, cols = matrix.output.shape[0], matrix.output.shape[1]
    assert len(matrix.output.shape) == 2
    assert value_x.output == reals(rows)
    assert value_y.output == reals(cols)

    int_inputs, aligned = align_tensors(matrix, loc, scale, value_y)
    matrix, loc, scale, value_y = aligned

    assert value_y.size(-1) == loc.size(-1)
    # Whiten the linear map by the observation scale.
    prec_sqrt = matrix / scale.unsqueeze(-2)
    precision = prec_sqrt.matmul(prec_sqrt.transpose(-1, -2))
    delta = (value_y - loc) / scale
    info_vec = prec_sqrt.matmul(delta.unsqueeze(-1)).squeeze(-1)
    log_normalizer = (-0.5 * loc.size(-1) * math.log(2 * math.pi)
                      - 0.5 * delta.pow(2).sum(-1) - scale.log().sum(-1))

    # Broadcast everything to the batch shape of info_vec.
    batch_shape = info_vec.shape[:-1]
    precision = precision.expand(info_vec.shape + (-1,))
    log_normalizer = log_normalizer.expand(batch_shape)

    x_name = gensym("x")
    gaussian_inputs = int_inputs.copy()
    gaussian_inputs[x_name] = value_x.output

    normalizer = Tensor(log_normalizer, int_inputs)
    x_dist = normalizer + Gaussian(info_vec, precision, gaussian_inputs)
    return x_dist(**{x_name: value_x})
예제 #25
0
def mvn_to_funsor(pyro_dist, event_dims=(), real_inputs=OrderedDict()):
    """
    Turn a joint :class:`torch.distributions.MultivariateNormal` into a
    :class:`~funsor.terms.Funsor` whose event is split across several named
    real inputs.

    The sizes are expected to satisfy::

        sum(d.num_elements for d in real_inputs.values())
          == pyro_dist.event_shape[0]

    :param torch.distributions.MultivariateNormal pyro_dist: The
        multivariate normal distribution to convert.
    :param tuple event_dims: Names for the rightmost batch dimensions; they
        become ``bint`` inputs of the result.
    :param OrderedDict real_inputs: Mapping from real variable name to its
        ``reals()`` domain. The default is a shared empty ``OrderedDict``;
        it is only read here, never modified.
    :return: A funsor with the given ``real_inputs`` and possibly additional
        bint inputs.
    :rtype: funsor.terms.Funsor
    """
    assert isinstance(pyro_dist, torch.distributions.MultivariateNormal)
    assert isinstance(event_dims, tuple)
    assert isinstance(real_inputs, OrderedDict)

    loc = tensor_to_funsor(pyro_dist.loc, event_dims, 1)
    scale_tril = tensor_to_funsor(pyro_dist.scale_tril, event_dims, 2)
    precision = tensor_to_funsor(pyro_dist.precision_matrix, event_dims, 2)
    batch_inputs = loc.inputs
    assert batch_inputs == scale_tril.inputs
    assert batch_inputs == precision.inputs

    # Information vector: precision applied to the mean.
    info_vec = precision.data.matmul(loc.data.unsqueeze(-1)).squeeze(-1)

    gauss_const = -0.5 * loc.output.shape[0] * math.log(2 * math.pi)
    log_det_scale = scale_tril.data.diagonal(dim1=-1, dim2=-2).log().sum(-1)
    half_quad = 0.5 * (info_vec * loc.data).sum(-1)
    log_prob = gauss_const - log_det_scale - half_quad

    gaussian_inputs = batch_inputs.copy()
    gaussian_inputs.update(real_inputs)
    normalizer = Tensor(log_prob, batch_inputs)
    return normalizer + Gaussian(info_vec, precision.data, gaussian_inputs)
예제 #26
0
def eager_normal(loc, scale, value):
    """
    Express a normal log-density as ``log_prob + Gaussian(...)``.

    The standardized residual ``(loc - value) / scale`` must be an
    ``Affine`` term whose real inputs are all scalar (empty shape).

    :param loc: Location term.
    :param scale: Scale term; its ``.log()`` enters the normalizer.
    :param value: Value term.
    :return: The sum of the log-normalizer and a ``Gaussian`` over
        ``affine.inputs``.
    """
    affine = (loc - value) / scale
    assert isinstance(affine, Affine)

    # Keep only the real-valued inputs, preserving their order.
    real_inputs = OrderedDict()
    for name, domain in affine.inputs.items():
        if domain.dtype == 'real':
            real_inputs[name] = domain
    assert all(not domain.shape for domain in real_inputs.values())

    # Align the constant and all coefficients, then broadcast them together.
    unaligned = [affine.const]
    unaligned.extend(coeff for _, coeff in affine.coeffs.items())
    _, aligned = align_tensors(*unaligned)
    const, *coeffs = torch.broadcast_tensors(*aligned)

    dim = sum(domain.num_elements for domain in real_inputs.values())
    batch_shape = const.shape

    # NOTE(review): only entry 0 of the location block is filled here.
    loc_block = BlockVector(batch_shape + (dim,))
    loc_block[..., 0] = -const / coeffs[0]

    # Precision is the outer product of the coefficient vector with itself.
    paired = [coeff for _, coeff in zip(real_inputs, coeffs)]
    precision_block = BlockMatrix(batch_shape + (dim, dim))
    for row, coeff_row in enumerate(paired):
        for col, coeff_col in enumerate(paired):
            precision_block[..., row, col] = coeff_row * coeff_col

    loc_tensor = loc_block.as_tensor()
    precision_tensor = precision_block.as_tensor()

    log_prob = -0.5 * math.log(2 * math.pi) - scale.log()
    return log_prob + Gaussian(loc_tensor, precision_tensor, affine.inputs)
예제 #27
0
def eager_normal(loc, scale, value):
    """
    Express a normal log-density as ``Tensor + Gaussian`` in information
    form, or return ``None`` when that is not possible.

    The standardized residual ``(loc - value) / scale`` is probed by
    substituting 0/1 values for its real inputs to recover a constant and
    one coefficient per real input.

    :param loc: Location term.
    :param scale: Scale term; ``scale.data.log()`` enters the normalizer.
    :param value: Value term.
    :return: ``Tensor(log_prob, int_inputs) + Gaussian(...)``, or ``None``
        if the residual reports ``is_affine`` as false.
    """
    affine = (loc - value) / scale
    if not affine.is_affine:
        return None

    # Partition the inputs by dtype in one pass, preserving order.
    real_inputs = OrderedDict()
    int_inputs = OrderedDict()
    for name, domain in affine.inputs.items():
        if domain.dtype == 'real':
            real_inputs[name] = domain
        else:
            int_inputs[name] = domain
    assert all(not domain.shape for domain in real_inputs.values())

    # Constant term: every real input set to zero.
    const_term = affine(**{name: 0. for name in real_inputs})
    # Coefficient of each input: set it to one, the rest to zero, and
    # subtract the constant.
    coeff_terms = OrderedDict()
    for target in real_inputs:
        one_hot = {name: (1. if target == name else 0.)
                   for name in real_inputs}
        coeff_terms[target] = affine(**one_hot) - const_term

    _, aligned = align_tensors(const_term, *coeff_terms.values(), expand=True)
    const, coeffs = aligned[0], aligned[1:]

    dim = sum(domain.num_elements for domain in real_inputs.values())
    batch_shape = const.shape

    # NOTE(review): only entry 0 of the mean block is filled here.
    mean_block = BlockVector(batch_shape + (dim, ))
    mean_block[..., 0] = -const / coeffs[0]

    # Precision is the outer product of the coefficient vector with itself.
    paired = [coeff for _, coeff in zip(real_inputs, coeffs)]
    precision_block = BlockMatrix(batch_shape + (dim, dim))
    for row, coeff_row in enumerate(paired):
        for col, coeff_col in enumerate(paired):
            precision_block[..., row, col] = coeff_row * coeff_col

    mean = mean_block.as_tensor()
    precision = precision_block.as_tensor()
    info_vec = precision.matmul(mean.unsqueeze(-1)).squeeze(-1)

    base_term = -0.5 * math.log(2 * math.pi)
    log_prob = base_term - scale.data.log() - 0.5 * (mean * info_vec).sum(-1)
    return Tensor(log_prob, int_inputs) + Gaussian(info_vec, precision,
                                                   affine.inputs)
예제 #28
0
def test_bart(analytic_kl):
    """
    Build three hard-coded funsor expressions (``q``, ``p_prior`` and
    ``p_likelihood``) and check that integrating over ``"gate_rate_t"``
    yields a ``Tensor``.

    :param analytic_kl: If truthy, the ``p_prior - q`` part is integrated
        outside the ``monte_carlo`` interpretation and only the likelihood
        part inside it; otherwise the whole ``p - q`` integrand is
        integrated under ``monte_carlo``. Presumably supplied by a pytest
        parametrization that is not visible here.

    The test also resets the module-level ``call_count`` to 0 and asserts it
    equals 1 at the end. NOTE(review): the counter is incremented elsewhere,
    presumably by ``unpack_gate_rate_0`` / ``unpack_gate_rate_1`` -- confirm.
    """
    global call_count
    call_count = 0

    # All three terms are constructed under the ``reflect`` interpretation.
    with interpretation(reflect):
        # q: two nested Independent terms around a Contraction of a Tensor
        # and a Gaussian, both batched over time_b4 (bint(2)) and
        # _event_1_b2 (bint(8)); the Gaussian also has real input value_b1.
        q = Independent(
            Independent(
                Contraction(
                    ops.nullop,
                    ops.add,
                    frozenset(),
                    (
                        Tensor(
                            torch.tensor(
                                [[
                                    -0.6077086925506592, -1.1546266078948975,
                                    -0.7021151781082153, -0.5303535461425781,
                                    -0.6365622282028198, -1.2423288822174072,
                                    -0.9941254258155823, -0.6287292242050171
                                ],
                                 [
                                     -0.6987162828445435, -1.0875964164733887,
                                     -0.7337473630905151, -0.4713417589664459,
                                     -0.6674002408981323, -1.2478348016738892,
                                     -0.8939017057418823, -0.5238542556762695
                                 ]],
                                dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                            ),
                            'real'),
                        Gaussian(
                            torch.tensor([
                                [[-0.3536059558391571], [-0.21779225766658783],
                                 [0.2840439975261688], [0.4531521499156952],
                                 [-0.1220812276005745], [-0.05519985035061836],
                                 [0.10932210087776184], [0.6656699776649475]],
                                [[-0.39107921719551086], [
                                    -0.20241987705230713
                                ], [0.2170514464378357], [0.4500560462474823],
                                 [0.27945515513420105], [-0.0490039587020874],
                                 [-0.06399798393249512], [0.846565842628479]]
                            ],
                                         dtype=torch.float32),  # noqa
                            torch.tensor([
                                [[[1.984686255455017]], [[0.6699360013008118]],
                                 [[1.6215802431106567]], [[2.372016668319702]],
                                 [[1.77385413646698]], [[0.526767373085022]],
                                 [[0.8722561597824097]], [[2.1879124641418457]]
                                 ],
                                [[[1.6996612548828125]], [[
                                    0.7535632252693176
                                ]], [[1.4946647882461548]],
                                 [[2.642792224884033]], [[1.7301604747772217]],
                                 [[0.5203893780708313]], [[1.055436372756958]],
                                 [[2.8370864391326904]]]
                            ],
                                         dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                                (
                                    'value_b1',
                                    reals(),
                                ),
                            )),
                    )),
                'gate_rate_b3',
                '_event_1_b2',
                'value_b1'),
            'gate_rate_t',
            'time_b4',
            'gate_rate_b3')
        # p_prior: a logaddexp/add Contraction that reduces the two state
        # variables over (1) a MarkovProduct along time_b9 and (2) a
        # MultivariateNormal evaluated at state_b10 via Subs.
        p_prior = Contraction(
            ops.logaddexp,
            ops.add,
            frozenset({'state(time=1)_b11', 'state_b10'}),
            (
                MarkovProduct(
                    ops.logaddexp,
                    ops.add,
                    Contraction(
                        ops.nullop,
                        ops.add,
                        frozenset(),
                        (
                            Tensor(
                                torch.tensor(2.7672932147979736,
                                             dtype=torch.float32), (), 'real'),
                            Gaussian(
                                torch.tensor([-0.0, -0.0, 0.0, 0.0],
                                             dtype=torch.float32),
                                torch.tensor([[
                                    98.01002502441406, 0.0, -99.0000228881836,
                                    -0.0
                                ],
                                              [
                                                  0.0, 98.01002502441406, -0.0,
                                                  -99.0000228881836
                                              ],
                                              [
                                                  -99.0000228881836, -0.0,
                                                  100.0000228881836, 0.0
                                              ],
                                              [
                                                  -0.0, -99.0000228881836, 0.0,
                                                  100.0000228881836
                                              ]],
                                             dtype=torch.float32),  # noqa
                                (
                                    (
                                        'state_b7',
                                        reals(2, ),
                                    ),
                                    (
                                        'state(time=1)_b8',
                                        reals(2, ),
                                    ),
                                )),
                            Subs(
                                AffineNormal(
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.03488487750291824,
                                                0.07356668263673782,
                                                0.19946961104869843,
                                                0.5386509299278259,
                                                -0.708323061466217,
                                                0.24411526322364807,
                                                -0.20855577290058136,
                                                -0.2421337217092514
                                            ],
                                             [
                                                 0.41762110590934753,
                                                 0.5272183418273926,
                                                 -0.49835553765296936,
                                                 -0.0363837406039238,
                                                 -0.0005282597267068923,
                                                 0.2704298794269562,
                                                 -0.155222088098526,
                                                 -0.44802337884902954
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                -0.003566693514585495,
                                                -0.2848514914512634,
                                                0.037103548645973206,
                                                0.12648648023605347,
                                                -0.18501518666744232,
                                                -0.20899859070777893,
                                                0.04121830314397812,
                                                0.0054807960987091064
                                            ],
                                             [
                                                 0.0021788496524095535,
                                                 -0.18700894713401794,
                                                 0.08187370002269745,
                                                 0.13554862141609192,
                                                 -0.10477752983570099,
                                                 -0.20848378539085388,
                                                 -0.01393645629286766,
                                                 0.011670656502246857
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.5974780917167664,
                                                0.864071786403656,
                                                1.0236268043518066,
                                                0.7147538065910339,
                                                0.7423890233039856,
                                                0.9462157487869263,
                                                1.2132389545440674,
                                                1.0596832036972046
                                            ],
                                             [
                                                 0.5787821412086487,
                                                 0.9178534150123596,
                                                 0.9074794054031372,
                                                 0.6600189208984375,
                                                 0.8473222255706787,
                                                 0.8426999449729919,
                                                 1.194266438484192,
                                                 1.0471148490905762
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Variable('state(time=1)_b8', reals(2, )),
                                    Variable('gate_rate_b6', reals(8, ))),
                                ((
                                    'gate_rate_b6',
                                    Binary(
                                        ops.GetitemOp(0),
                                        Variable('gate_rate_t', reals(2, 8)),
                                        Variable('time_b9', bint(2))),
                                ), )),
                        )),
                    Variable('time_b9', bint(2)),
                    frozenset({('state_b7', 'state(time=1)_b8')}),
                    frozenset({('state(time=1)_b8', 'state(time=1)_b11'),
                               ('state_b7', 'state_b10')})),  # noqa
                Subs(
                    dist.MultivariateNormal(
                        Tensor(torch.tensor([0.0, 0.0], dtype=torch.float32),
                               (), 'real'),
                        Tensor(
                            torch.tensor([[10.0, 0.0], [0.0, 10.0]],
                                         dtype=torch.float32),
                            (), 'real'), Variable('value_b5', reals(2, ))), ((
                                'value_b5',
                                Variable('state_b10', reals(2, )),
                            ), )),
            ))
        # p_likelihood: an add-reduction over time_b17, destin_b16 and
        # origin_b15 of an inner logaddexp Contraction over gated_b14, whose
        # factors are a Categorical and a Stack of a Poisson and a Delta.
        # Both the Categorical and the Poisson take their parameter from
        # gate_rate_t indexed by time_b17 and passed through
        # unpack_gate_rate_0 / unpack_gate_rate_1.
        p_likelihood = Contraction(
            ops.add,
            ops.nullop,
            frozenset({'time_b17', 'destin_b16', 'origin_b15'}),
            (
                Contraction(
                    ops.logaddexp,
                    ops.add,
                    frozenset({'gated_b14'}),
                    (
                        dist.Categorical(
                            Binary(
                                ops.GetitemOp(0),
                                Binary(
                                    ops.GetitemOp(0),
                                    Subs(
                                        Function(
                                            unpack_gate_rate_0, reals(2, 2, 2),
                                            (Variable('gate_rate_b12',
                                                      reals(8, )), )),
                                        ((
                                            'gate_rate_b12',
                                            Binary(
                                                ops.GetitemOp(0),
                                                Variable(
                                                    'gate_rate_t', reals(2,
                                                                         8)),
                                                Variable('time_b17', bint(2))),
                                        ), )), Variable('origin_b15',
                                                        bint(2))),
                                Variable('destin_b16', bint(2))),
                            Variable('gated_b14', bint(2))),
                        Stack(
                            'gated_b14',
                            (
                                dist.Poisson(
                                    Binary(
                                        ops.GetitemOp(0),
                                        Binary(
                                            ops.GetitemOp(0),
                                            Subs(
                                                Function(
                                                    unpack_gate_rate_1,
                                                    reals(2, 2), (Variable(
                                                        'gate_rate_b13',
                                                        reals(8, )), )),
                                                ((
                                                    'gate_rate_b13',
                                                    Binary(
                                                        ops.GetitemOp(0),
                                                        Variable(
                                                            'gate_rate_t',
                                                            reals(2, 8)),
                                                        Variable(
                                                            'time_b17',
                                                            bint(2))),
                                                ), )),
                                            Variable('origin_b15', bint(2))),
                                        Variable('destin_b16', bint(2))),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                                dist.Delta(
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (), 'real'),
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (), 'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                            )),
                    )), ))

    # Assemble the ELBO by integrating against q over gate_rate_t.
    if analytic_kl:
        # Prior-minus-q term outside monte_carlo; likelihood term inside it.
        exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
        with interpretation(monte_carlo):
            approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
        elbo = exact_part + approx_part
    else:
        # Whole integrand (prior + likelihood - q) under monte_carlo.
        p = p_prior + p_likelihood
        with interpretation(monte_carlo):
            elbo = Integrate(q, p - q, "gate_rate_t")

    # The result must be a Tensor; on failure the message is elbo.pretty().
    assert isinstance(elbo, Tensor), elbo.pretty()
    assert call_count == 1
예제 #29
0
def matrix_and_mvn_to_funsor(matrix, mvn, event_dims=(), x_name="value_x", y_name="value_y"):
    """
    Convert a noisy affine function to a Gaussian. The noisy affine function is
    defined as::

        y = x @ matrix + mvn.sample()

    The result is a non-normalized Gaussian funsor with two real inputs,
    ``x_name`` and ``y_name``, corresponding to a conditional distribution of
    real vector ``y`` given real vector ``x``.

    :param torch.Tensor matrix: A matrix with rightmost shape ``(x_size, y_size)``.
    :param mvn: A multivariate normal distribution with
        ``event_shape == (y_size,)``.
    :type mvn: torch.distributions.MultivariateNormal or
        torch.distributions.Independent of torch.distributions.Normal
    :param tuple event_dims: A tuple of names for rightmost dimensions.
        These will be assigned to ``result.inputs`` of type ``bint``.
    :param str x_name: The name of the ``x`` random variable.
    :param str y_name: The name of the ``y`` random variable.
    :return: A funsor with real inputs ``x_name`` and ``y_name`` and possibly
        additional bint inputs.
    :rtype: funsor.terms.Funsor
    """
    is_full_mvn = isinstance(mvn, torch.distributions.MultivariateNormal)
    is_diag_normal = (isinstance(mvn, torch.distributions.Independent) and
                      isinstance(mvn.base_dist, torch.distributions.Normal))
    assert is_full_mvn or is_diag_normal
    assert isinstance(matrix, torch.Tensor)
    x_size, y_size = matrix.shape[-2:]
    assert mvn.event_shape == (y_size,)

    # Diagonal noise is handled by the dedicated AffineNormal term.
    if isinstance(mvn, torch.distributions.Independent):
        diag = mvn.base_dist
        return AffineNormal(
            tensor_to_funsor(matrix, event_dims, 2),
            tensor_to_funsor(diag.loc, event_dims, 1),
            tensor_to_funsor(diag.scale, event_dims, 1),
            Variable(x_name, reals(x_size)),
            Variable(y_name, reals(y_size)),
        )

    mean = mvn.loc
    scale_tril = mvn.scale_tril

    # Information vector of the noise, via a Cholesky solve against scale_tril.
    noise_info = mean.unsqueeze(-1).cholesky_solve(scale_tril).squeeze(-1)
    log_prob = (-0.5 * y_size * math.log(2 * math.pi)
                - scale_tril.diagonal(dim1=-1, dim2=-2).log().sum(-1)
                - 0.5 * (noise_info * mean).sum(-1))

    # Joint precision over (x, y), assembled from four blocks.
    batch_shape = broadcast_shape(matrix.shape[:-2], mvn.batch_shape)
    prec_yy = mvn.precision_matrix.expand(batch_shape + (y_size, y_size))
    matrix_prec = matrix.matmul(prec_yy)
    prec_xy = -matrix_prec
    prec_yx = prec_xy.transpose(-1, -2)
    prec_xx = matrix_prec.matmul(matrix.transpose(-1, -2))
    top_rows = torch.cat([prec_xx, prec_xy], -1)
    bottom_rows = torch.cat([prec_yx, prec_yy], -1)
    joint_precision = torch.cat([top_rows, bottom_rows], -2)

    # Joint information vector, x part first and y part second.
    info_y = noise_info.expand(batch_shape + (y_size,))
    info_x = -matrix.matmul(info_y.unsqueeze(-1)).squeeze(-1)
    joint_info = torch.cat([info_x, info_y], -1)

    # Round-trip through tensor_to_funsor to obtain the named batch inputs.
    info_funsor = tensor_to_funsor(joint_info, event_dims, 1)
    precision_funsor = tensor_to_funsor(joint_precision, event_dims, 2)
    inputs = info_funsor.inputs.copy()
    inputs[x_name] = reals(x_size)
    inputs[y_name] = reals(y_size)

    normalizer = tensor_to_funsor(log_prob, event_dims)
    return normalizer + Gaussian(info_funsor.data, precision_funsor.data, inputs)
예제 #30
0
def eager_mvn(loc, scale_tril, value):
    """
    Try to express a multivariate normal log-density as
    ``log_prob + Gaussian(info_vec, precision, inputs)``.

    The residual ``loc - value`` is multiplied by the triangular-solve
    inverse of ``scale_tril`` and decomposed with ``extract_affine`` into a
    constant and one ``(coefficient, einsum equation)`` pair per real input.

    :param loc: Term with a one-dimensional output shape.
    :param scale_tril: Term with a two-dimensional output shape, exposing
        ``.data`` and ``.inputs``.
    :param value: Term whose ``output`` equals ``loc.output``.
    :return: ``log_prob + Gaussian(...)``, or ``None`` (stay lazy) when
        ``loc`` or ``value`` is not affine, or when the extracted constant
        or any coefficient is not a ``Tensor``.
    """
    assert len(loc.shape) == 1
    assert len(scale_tril.shape) == 2
    assert value.output == loc.output
    if not (is_affine(loc) and is_affine(value)):
        return None  # lazy

    # Extract an affine representation.
    tril_data = scale_tril.data
    identity = torch.eye(tril_data.size(-1)).expand(tril_data.shape)
    prec_sqrt = Tensor(
        identity.triangular_solve(tril_data, upper=False).solution,
        scale_tril.inputs)
    affine = prec_sqrt @ (loc - value)
    const, coeff_terms = extract_affine(affine)
    if not isinstance(const, Tensor):
        return None  # lazy
    for coeff, _ in coeff_terms.values():
        if not isinstance(coeff, Tensor):
            return None  # lazy

    # Compute log_prob using funsors.
    scale_diag = Tensor(tril_data.diagonal(dim1=-1, dim2=-2),
                        scale_tril.inputs)
    log_prob = (-0.5 * scale_diag.shape[0] * math.log(2 * math.pi) -
                scale_diag.log().sum() - 0.5 * (const**2).sum())

    # Dovetail to avoid variable name collision in einsum: every index letter
    # is remapped to a code point of the given parity, so the two equation
    # sets never share an index.
    def _dovetail(eqn, parity):
        return ''.join(
            c if c in ',->' else chr(ord(c) * 2 - ord('a') + parity)
            for c in eqn)

    raw_equations = [eqn for _, eqn in coeff_terms.values()]
    equations1 = [_dovetail(eqn, 0) for eqn in raw_equations]
    equations2 = [_dovetail(eqn, 1) for eqn in raw_equations]

    real_inputs = OrderedDict(
        (k, v) for k, v in affine.inputs.items() if v.dtype == 'real')
    assert tuple(real_inputs) == tuple(coeff_terms)

    # Align and broadcast tensors.
    unaligned = [-const]
    unaligned.extend(coeff for coeff, _ in coeff_terms.values())
    inputs, aligned = align_tensors(*unaligned, expand=True)
    neg_const, coeff_data = aligned[0], aligned[1:]
    dim = sum(d.num_elements for d in real_inputs.values())
    batch_shape = neg_const.shape[:-1]

    info_vec = BlockVector(batch_shape + (dim, ))
    precision = BlockMatrix(batch_shape + (dim, dim))
    offset1 = 0
    for v1, c1, eqn1 in zip(real_inputs, coeff_data, equations1):
        size1 = real_inputs[v1].num_elements
        slice1 = slice(offset1, offset1 + size1)
        inputs1, output1 = eqn1.split('->')
        input11, input12 = inputs1.split(',')
        assert input11 == input12 + output1
        info_block = torch.einsum(
            f'...{input11},...{output1}->...{input12}', c1, neg_const)
        info_vec[..., slice1] = info_block.reshape(batch_shape + (size1,))

        offset2 = 0
        for v2, c2, eqn2 in zip(real_inputs, coeff_data, equations2):
            size2 = real_inputs[v2].num_elements
            slice2 = slice(offset2, offset2 + size2)
            inputs2, output2 = eqn2.split('->')
            input21, input22 = inputs2.split(',')
            assert input21 == input22 + output2
            prec_block = torch.einsum(
                f'...{input11},...{input22}{output1}->...{input12}{input22}', c1, c2)
            precision[..., slice1, slice2] = prec_block.reshape(
                batch_shape + (size1, size2))
            offset2 += size2
        offset1 += size1

    info_vec = info_vec.as_tensor()
    precision = precision.as_tensor()
    inputs.update(real_inputs)
    return log_prob + Gaussian(info_vec, precision, inputs)