Example #1
    def _eager_subs_affine(self, subs, remaining_subs):
        # Extract an affine representation.
        affine = OrderedDict()
        for k, v in subs:
            const, coeffs = extract_affine(v)
            if (isinstance(const, Tensor) and all(
                    isinstance(coeff, Tensor)
                    for coeff, _ in coeffs.values())):
                affine[k] = const, coeffs
            else:
                remaining_subs += (k, v),
        if not affine:
            return reflect(Subs, self, remaining_subs)

        # Align integer dimensions.
        old_int_inputs = OrderedDict(
            (k, v) for k, v in self.inputs.items() if v.dtype != 'real')
        tensors = [
            Tensor(self.info_vec, old_int_inputs),
            Tensor(self.precision, old_int_inputs)
        ]
        for const, coeffs in affine.values():
            tensors.append(const)
            tensors.extend(coeff for coeff, _ in coeffs.values())
        new_int_inputs, tensors = align_tensors(*tensors, expand=True)
        tensors = (Tensor(x, new_int_inputs) for x in tensors)
        old_info_vec = next(tensors).data
        old_precision = next(tensors).data
        for old_k, (const, coeffs) in affine.items():
            const = next(tensors)
            for new_k, (coeff, eqn) in coeffs.items():
                coeff = next(tensors)
                coeffs[new_k] = coeff, eqn
            affine[old_k] = const, coeffs
        batch_shape = old_info_vec.shape[:-1]

        # Align real dimensions.
        old_real_inputs = OrderedDict(
            (k, v) for k, v in self.inputs.items() if v.dtype == 'real')
        new_real_inputs = old_real_inputs.copy()
        for old_k, (const, coeffs) in affine.items():
            del new_real_inputs[old_k]
            for new_k, (coeff, eqn) in coeffs.items():
                new_shape = coeff.shape[:len(eqn.split('->')[0].split(',')[1])]
                new_real_inputs[new_k] = reals(*new_shape)
        old_offsets, old_dim = _compute_offsets(old_real_inputs)
        new_offsets, new_dim = _compute_offsets(new_real_inputs)
        new_inputs = new_int_inputs.copy()
        new_inputs.update(new_real_inputs)

        # Construct a blockwise affine representation of the substitution.
        subs_vector = BlockVector(batch_shape + (old_dim, ))
        subs_matrix = BlockMatrix(batch_shape + (new_dim, old_dim))
        for old_k, old_offset in old_offsets.items():
            old_size = old_real_inputs[old_k].num_elements
            old_slice = slice(old_offset, old_offset + old_size)
            if old_k in new_real_inputs:
                new_offset = new_offsets[old_k]
                new_slice = slice(new_offset, new_offset + old_size)
                subs_matrix[..., new_slice, old_slice] = \
                    ops.new_eye(self.info_vec, batch_shape + (old_size,))
                continue
            const, coeffs = affine[old_k]
            old_shape = old_real_inputs[old_k].shape
            assert const.data.shape == batch_shape + old_shape
            subs_vector[..., old_slice] = const.data.reshape(batch_shape +
                                                             (old_size, ))
            for new_k, new_offset in new_offsets.items():
                if new_k in coeffs:
                    coeff, eqn = coeffs[new_k]
                    new_size = new_real_inputs[new_k].num_elements
                    new_slice = slice(new_offset, new_offset + new_size)
                    assert coeff.shape == new_real_inputs[
                        new_k].shape + old_shape
                    subs_matrix[..., new_slice, old_slice] = \
                        coeff.data.reshape(batch_shape + (new_size, old_size))
        subs_vector = subs_vector.as_tensor()
        subs_matrix = subs_matrix.as_tensor()
        subs_matrix_t = ops.transpose(subs_matrix, -1, -2)

        # Construct the new funsor. Suppose the old Gaussian funsor g has density
        #   g(x) = < x | i - 1/2 P x>
        # Now define a new funsor f by substituting x = A y + B:
        #   f(y) = g(A y + B)
        #        = < A y + B | i - 1/2 P (A y + B) >
        #        = < y | At (i - P B) - 1/2 At P A y > + < B | i - 1/2 P B >
        #        =: < y | i' - 1/2 P' y > + C
        # where  P' = At P A  and  i' = At (i - P B)  parametrize a new Gaussian
        # and  C = < B | i - 1/2 P B >  parametrizes a new Tensor.
        precision = subs_matrix @ old_precision @ subs_matrix_t
        info_vec = _mv(subs_matrix,
                       old_info_vec - _mv(old_precision, subs_vector))
        const = _vv(subs_vector,
                    old_info_vec - 0.5 * _mv(old_precision, subs_vector))
        result = Gaussian(info_vec, precision, new_inputs) + Tensor(
            const, new_int_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result
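A quick numeric check of the derivation in the comment above, in plain PyTorch rather than funsor (a hypothetical sketch; here A maps the new variable y to the old variable x, so x = A y + B):

import torch

d, e = 3, 2                               # old dim (x), new dim (y)
M = torch.randn(d, d)
P = M @ M.T + d * torch.eye(d)            # precision, positive definite
i = torch.randn(d)                        # info vector
A, B = torch.randn(d, e), torch.randn(d)  # substitution x = A y + B

def g(x):  # g(x) = < x | i - 1/2 P x >
    return x @ (i - 0.5 * (P @ x))

P_new = A.T @ P @ A          # P' = At P A
i_new = A.T @ (i - P @ B)    # i' = At (i - P B)
C = B @ (i - 0.5 * (P @ B))  # constant term C = < B | i - 1/2 P B >

y = torch.randn(e)
assert torch.allclose(g(A @ y + B),
                      y @ (i_new - 0.5 * (P_new @ y)) + C, atol=1e-5)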
Example #2
def test_cons_hash():
    x = randn((3, 3))
    assert Tensor(x) is Tensor(x)
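Funsor terms are cons-hashed: building a term twice from the same arguments returns the identical object, so identity checks with is double as structural equality. A hypothetical companion test (same imports as above), assuming distinct backing data yields distinct cache keys:

def test_cons_hash_distinct():
    x, y = randn((3, 3)), randn((3, 3))
    assert Tensor(x) is Tensor(x)       # same data -> the very same object
    assert Tensor(x) is not Tensor(y)   # distinct data -> distinct funsors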
Example #3
def eager_log_prob(cls, *params):
    inputs, tensors = align_tensors(*params)
    params = dict(zip(cls._ast_fields, tensors))
    value = params.pop('value')
    data = cls.dist_class(**params).log_prob(value)
    return Tensor(data, inputs)
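This pattern aligns the distribution's named inputs and then relies on backend broadcasting inside log_prob. Roughly the same computation with raw torch.distributions (a hedged sketch; the shapes are illustrative):

import torch

loc = torch.randn(3, 1)    # parameter batched over an "i" dimension
scale = torch.ones(1, 4)   # parameter batched over a "j" dimension
value = torch.randn(3, 4)
data = torch.distributions.Normal(loc, scale).log_prob(value)
assert data.shape == (3, 4)  # one log-density per (i, j) pair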
Example #4
def test_arange_2(start, stop, step):
    t = randn((10, 2))
    f = Tensor(t)["i"]
    actual = f(i=f.new_arange("j", start, stop, step, dtype=10))
    expected = Tensor(t[start:stop:step])["j"]
    assert_close(actual, expected)
Example #5
def test_to_data():
    data = zeros((3, 3))
    x = Tensor(data)
    assert funsor.to_data(x) is data
Example #6
    def initialize_params(self):

        # return a dict of per-site params as funsor.tensor.Tensors
        params = {
            "e_g": {},
            "theta_g": {},
            "eps_g": {},
            "e_i": {},
            "theta_i": {},
            "eps_i": {},
            "zi_step": {},
            "step": {},
            "angle": {},
            "zi_omega": {},
            "omega": {},
        }

        # size parameters
        N_v = self.config["sizes"]["random"]
        N_state = self.config["sizes"]["state"]

        # initialize group-level random effect parameters
        if self.config["group"]["random"] == "discrete":

            params["e_g"]["probs"] = Tensor(
                pyro.param("probs_e_g",
                           lambda: torch.randn((N_v,)).abs(),
                           constraint=constraints.simplex),
                OrderedDict(),
            )

            params["eps_g"]["theta"] = Tensor(
                pyro.param("theta_g",
                           lambda: torch.randn((N_v, N_state, N_state))),
                OrderedDict([("e_g", bint(N_v)), ("y_prev", bint(N_state))]),
            )

        elif self.config["group"]["random"] == "continuous":

            # note these are prior values, trainable versions live in guide
            params["eps_g"]["loc"] = Tensor(
                torch.zeros((N_state, N_state)),
                OrderedDict([("y_prev", bint(N_state))]),
            )

            params["eps_g"]["scale"] = Tensor(
                torch.ones((N_state, N_state)),
                OrderedDict([("y_prev", bint(N_state))]),
            )

        # initialize individual-level random effect parameters
        N_c = self.config["sizes"]["group"]
        if self.config["individual"]["random"] == "discrete":

            params["e_i"]["probs"] = Tensor(
                pyro.param("probs_e_i",
                           lambda: torch.randn((N_c, N_v,)).abs(),
                           constraint=constraints.simplex),
                OrderedDict([("g", bint(N_c))]),  # different value per group
            )

            params["eps_i"]["theta"] = Tensor(
                pyro.param("theta_i",
                           lambda: torch.randn((N_c, N_v, N_state, N_state))),
                OrderedDict([("g", bint(N_c)), ("e_i", bint(N_v)), ("y_prev", bint(N_state))]),
            )

        elif self.config["individual"]["random"] == "continuous":

            params["eps_i"]["loc"] = Tensor(
                torch.zeros((N_c, N_state, N_state)),
                OrderedDict([("g", bint(N_c)), ("y_prev", bint(N_state))]),
            )

            params["eps_i"]["scale"] = Tensor(
                torch.ones((N_c, N_state, N_state)),
                OrderedDict([("g", bint(N_c)), ("y_prev", bint(N_state))]),
            )

        # initialize likelihood parameters
        # observation 1: step size (step ~ Gamma)
        params["zi_step"]["zi_param"] = Tensor(
            pyro.param("step_zi_param",
                       lambda: torch.ones((N_state, 2)),
                       constraint=constraints.simplex),
            OrderedDict([("y_curr", bint(N_state))]),
        )

        params["step"]["concentration"] = Tensor(
            pyro.param("step_param_concentration",
                       lambda: torch.randn((N_state,)).abs(),
                       constraint=constraints.positive),
            OrderedDict([("y_curr", bint(N_state))]),
        )

        params["step"]["rate"] = Tensor(
            pyro.param("step_param_rate",
                       lambda: torch.randn((N_state,)).abs(),
                       constraint=constraints.positive),
            OrderedDict([("y_curr", bint(N_state))]),
        )

        # observation 2: step angle (angle ~ VonMises)
        params["angle"]["concentration"] = Tensor(
            pyro.param("angle_param_concentration",
                       lambda: torch.randn((N_state,)).abs(),
                       constraint=constraints.positive),
            OrderedDict([("y_curr", bint(N_state))]),
        )

        params["angle"]["loc"] = Tensor(
            pyro.param("angle_param_loc",
                       lambda: torch.randn((N_state,)).abs()),
            OrderedDict([("y_curr", bint(N_state))]),
        )

        # observation 3: dive activity (omega ~ Beta)
        params["zi_omega"]["zi_param"] = Tensor(
            pyro.param("omega_zi_param",
                       lambda: torch.ones((N_state, 2)),
                       constraint=constraints.simplex),
            OrderedDict([("y_curr", bint(N_state))]),
        )

        params["omega"]["concentration0"] = Tensor(
            pyro.param("omega_param_concentration0",
                       lambda: torch.randn((N_state,)).abs(),
                       constraint=constraints.positive),
            OrderedDict([("y_curr", bint(N_state))]),
        )

        params["omega"]["concentration1"] = Tensor(
            pyro.param("omega_param_concentration1",
                       lambda: torch.randn((N_state,)).abs(),
                       constraint=constraints.positive),
            OrderedDict([("y_curr", bint(N_state))]),
        )

        self.params = params
        return self.params
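A note on the pattern used throughout this example: the OrderedDict passed to Tensor names the leading dims of the data, and any remaining trailing dims become the funsor's output shape. A hedged sketch of the convention (hypothetical sizes; imports assume the funsor version used in this file):

import torch
from collections import OrderedDict
from funsor.domains import bint
from funsor.tensor import Tensor

data = torch.zeros(4, 3, 3)  # (N_v, N_state, N_state)
t = Tensor(data, OrderedDict([("e_g", bint(4)), ("y_prev", bint(3))]))
assert t.output.shape == (3,)  # the trailing N_state dim is the output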
Example #7
def test_affine_subs():
    # This was recorded from test_pyro_convert.
    x = Subs(
        Gaussian(
            torch.tensor(
                [1.3027106523513794, 1.4167094230651855, -0.9750942587852478,
                 0.5321089029312134, -0.9039931297302246],
                dtype=torch.float32),  # noqa
            torch.tensor(
                [[1.0199567079544067, 0.9840421676635742, -0.473368763923645,
                  0.34206756949424744, -0.7562517523765564],
                 [0.9840421676635742, 1.511502742767334, -1.7593903541564941,
                  0.6647964119911194, -0.5119513273239136],
                 [-0.4733688533306122, -1.7593903541564941, 3.2386727333068848,
                  -0.9345928430557251, -0.1534711718559265],
                 [0.34206756949424744, 0.6647964119911194, -0.9345928430557251,
                  0.3141004145145416, -0.12399007380008698],
                 [-0.7562517523765564, -0.5119513273239136, -0.1534711718559265,
                  -0.12399007380008698, 0.6450173854827881]],
                dtype=torch.float32),  # noqa
            (('state_1_b6', reals(3)),
             ('obs_b2', reals(2)))),
        (('obs_b2',
          Contraction(ops.nullop, ops.add, frozenset(),
                      (Variable('bias_b5', reals(2)),
                       Tensor(
                           torch.tensor(
                               [-2.1787893772125244, 0.5684312582015991],
                               dtype=torch.float32),  # noqa
                           (), 'real')))),))
    assert isinstance(x, (Gaussian, Contraction)), x.pretty()
Example #8
def _check_sample(funsor_dist,
                  sample_inputs,
                  inputs,
                  atol=1e-2,
                  rtol=None,
                  num_samples=100000,
                  statistic="mean",
                  skip_grad=False):
    """utility that compares a Monte Carlo estimate of a distribution mean with the true mean"""
    samples_per_dim = int(num_samples**(1. / max(1, len(sample_inputs))))
    sample_inputs = OrderedDict(
        (k, bint(samples_per_dim)) for k in sample_inputs)

    for tensor in list(funsor_dist.params.values())[:-1]:
        tensor.data.requires_grad_()

    sample_value = funsor_dist.sample(frozenset(['value']), sample_inputs)
    expected_inputs = OrderedDict(
        tuple(sample_inputs.items()) + tuple(inputs.items()) +
        (('value', funsor_dist.inputs['value']), ))
    check_funsor(sample_value, expected_inputs, reals())

    if sample_inputs:

        actual_mean = Integrate(
            sample_value,
            Variable('value', funsor_dist.inputs['value']),
            frozenset(['value']),
        ).reduce(ops.add, frozenset(sample_inputs))

        inputs, tensors = align_tensors(
            *list(funsor_dist.params.values())[:-1])
        raw_dist = funsor_dist.dist_class(
            **dict(zip(funsor_dist._ast_fields[:-1], tensors)))
        expected_mean = Tensor(raw_dist.mean, inputs)

        check_funsor(actual_mean, expected_mean.inputs, expected_mean.output)
        assert_close(actual_mean, expected_mean, atol=atol, rtol=rtol)

    if sample_inputs and not skip_grad:
        if statistic == "mean":
            actual_stat, expected_stat = actual_mean, expected_mean
        elif statistic == "variance":
            actual_stat = Integrate(
                sample_value,
                (Variable('value', funsor_dist.inputs['value']) - actual_mean) ** 2,
                frozenset(['value']),
            ).reduce(ops.add, frozenset(sample_inputs))
            expected_stat = Tensor(raw_dist.variance, inputs)
        elif statistic == "entropy":
            actual_stat = -Integrate(
                sample_value, funsor_dist, frozenset(['value']),
            ).reduce(ops.add, frozenset(sample_inputs))
            expected_stat = Tensor(raw_dist.entropy(), inputs)
        else:
            raise ValueError("invalid test statistic")

        grad_targets = [v.data for v in list(funsor_dist.params.values())[:-1]]
        actual_grads = torch.autograd.grad(
            actual_stat.reduce(ops.add).sum().data,
            grad_targets, allow_unused=True)
        expected_grads = torch.autograd.grad(
            expected_stat.reduce(ops.add).sum().data,
            grad_targets, allow_unused=True)

        assert_close(actual_stat, expected_stat, atol=atol, rtol=rtol)

        for actual_grad, expected_grad in zip(actual_grads, expected_grads):
            if expected_grad is not None:
                assert_close(actual_grad, expected_grad, atol=atol, rtol=rtol)
            else:
                assert_close(actual_grad,
                             torch.zeros_like(actual_grad),
                             atol=atol,
                             rtol=rtol)
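The default tolerances reflect Monte Carlo error: with num_samples = 100000, a sample mean typically lands within atol = 1e-2 of the exact value. A standalone illustration of that convergence in raw torch (hedged; not part of the test suite):

import torch

torch.manual_seed(0)
d = torch.distributions.Gamma(torch.tensor(2.0), torch.tensor(3.0))
samples = d.sample((100000,))
assert torch.allclose(samples.mean(), d.mean, atol=1e-2)  # exact mean is 2/3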
Example #9
def test_transform_exp(shape):
    point = Tensor(ops.abs(randn(shape)))
    x = Variable('x', reals(*shape))
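    # Substituting y = exp(x) moves the point to x = log(point); the added
    # log_density term point.log().sum() is the log-abs-det Jacobian of exp
    # at the new point.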
    actual = Delta('y', point)(y=ops.exp(x))
    expected = Delta('x', point.log(), point.log().sum())
    assert_close(actual, expected)
Example #10
    def forward(self, observations, add_bias=True):
        obs_dim = 2 * self.num_sensors
        bias_scale = self.log_bias_scale.exp()
        obs_noise = self.log_obs_noise.exp()
        trans_noise = self.log_trans_noise.exp()

        # bias distribution
        bias = Variable('bias', reals(obs_dim))
        assert not torch.isnan(bias_scale), "bias scale was nan"
        bias_dist = dist_to_funsor(
            dist.MultivariateNormal(
                torch.zeros(obs_dim),
                scale_tril=bias_scale * torch.eye(2 * self.num_sensors),
            ))(value=bias)

        init_dist = torch.distributions.MultivariateNormal(
            torch.zeros(4), scale_tril=100. * torch.eye(4))
        self.init = dist_to_funsor(init_dist)(value="state")

        # hidden states
        prev = Variable("prev", reals(4))
        curr = Variable("curr", reals(4))
        self.trans_dist = f_dist.MultivariateNormal(
            loc=prev @ NCV_TRANSITION_MATRIX,
            scale_tril=trans_noise * NCV_PROCESS_NOISE.cholesky(),
            value=curr)

        state = Variable('state', reals(4))
        obs = Variable("obs", reals(obs_dim))
        observation_matrix = Tensor(
            torch.eye(4, 2).unsqueeze(-1)
            .expand(-1, -1, self.num_sensors).reshape(4, -1))
        assert observation_matrix.output.shape == (4, obs_dim), \
            observation_matrix.output.shape
        obs_loc = state @ observation_matrix
        if add_bias:
            obs_loc += bias
        self.observation_dist = f_dist.MultivariateNormal(
            loc=obs_loc, scale_tril=obs_noise * torch.eye(obs_dim), value=obs)

        logp = bias_dist
        curr = "state_init"
        logp += self.init(state=curr)
        for t, x in enumerate(observations):
            prev, curr = curr, f"state_{t}"
            logp += self.trans_dist(prev=prev, curr=curr)
            logp += self.observation_dist(state=curr, obs=x)
            # marginalize out previous state
            logp = logp.reduce(ops.logaddexp, prev)
        # marginalize out bias variable
        logp = logp.reduce(ops.logaddexp, "bias")

        # save posterior over the final state
        assert set(logp.inputs) == {f'state_{len(observations) - 1}'}
        posterior = funsor_to_mvn(logp, ndims=0)

        # marginalize out remaining variables
        logp = logp.reduce(ops.logaddexp)
        assert isinstance(logp, Tensor) and logp.shape == (), logp.pretty()
        return logp.data, posterior
Example #11
def test_delta_delta():
    v = Variable('v', reals(2))
    point = Tensor(torch.randn(2))
    log_density = Tensor(torch.tensor(0.5))
    d = dist.Delta(point, log_density, v)
    assert d is Delta('v', point, log_density)
Example #12
def test_delta_delta():
    v = Variable('v', Reals[2])
    point = Tensor(randn(2))
    log_density = Tensor(numeric_array(0.5))
    d = dist.Delta(point, log_density, v)
    assert d is Delta('v', point, log_density)
Example #13
def adjoint_subs_gaussianmixture_gaussianmixture(adj_redop, adj_binop, out_adj,
                                                 arg, subs):

    if any(v.dtype == 'real' and not isinstance(v, Variable) for k, v in subs):
        raise NotImplementedError(
            "TODO implement adjoint for substitution into Gaussian real variable"
        )

    # invert renaming
    renames = tuple((v.name, k) for k, v in subs if isinstance(v, Variable))
    out_adj = Subs(out_adj, renames)

    # inverting advanced indexing
    slices = tuple((k, v) for k, v in subs if not isinstance(v, Variable))

    assert len(slices + renames) == len(subs)

    in_adj_discrete = adjoint_ops(Subs, adj_redop, adj_binop, out_adj.terms[0],
                                  arg.terms[0], subs)[arg.terms[0]]

    arg_int_inputs = OrderedDict(
        (k, v) for k, v in arg.inputs.items() if v.dtype != 'real')
    out_adj_int_inputs = OrderedDict(
        (k, v) for k, v in out_adj.inputs.items() if v.dtype != 'real')

    arg_real_inputs = OrderedDict(
        (k, v) for k, v in arg.inputs.items() if v.dtype == 'real')

    align_inputs = OrderedDict((k, v)
                               for k, v in out_adj.terms[1].inputs.items()
                               if v.dtype != 'real')
    align_inputs.update(arg_real_inputs)
    out_adj_info_vec, out_adj_precision = align_gaussian(
        align_inputs, out_adj.terms[1])

    in_adj_info_vec = list(
        adjoint_ops(
            Subs,
            adj_redop,
            adj_binop,  # ops.add, ops.mul,
            Tensor(out_adj_info_vec, out_adj_int_inputs),
            Tensor(arg.terms[1].info_vec, arg_int_inputs),
            slices).values())[0]

    in_adj_precision = list(
        adjoint_ops(
            Subs,
            adj_redop,
            adj_binop,  # ops.add, ops.mul,
            Tensor(out_adj_precision, out_adj_int_inputs),
            Tensor(arg.terms[1].precision, arg_int_inputs),
            slices).values())[0]

    assert isinstance(in_adj_info_vec, Tensor)
    assert isinstance(in_adj_precision, Tensor)

    in_adj_gaussian = Gaussian(in_adj_info_vec.data, in_adj_precision.data,
                               arg.inputs.copy())

    in_adj = in_adj_gaussian + in_adj_discrete
    return {arg: in_adj}
Example #14
def _scatter(src, res, subs):
    # inverse of advanced indexing
    # TODO check types of subs, in case some logic from eager_subs was accidentally left out?

    # use advanced indexing logic copied from Tensor.eager_subs:

    # materialize after checking for renaming case
    subs = OrderedDict((k, res.materialize(v)) for k, v in subs)

    # Compute result shapes.
    inputs = OrderedDict()
    for k, domain in res.inputs.items():
        inputs[k] = domain

    # Construct a dict with each input's positional dim,
    # counting from the right so as to support broadcasting.
    total_size = len(inputs) + len(
        res.output.shape)  # Assumes only scalar indices.
    new_dims = {}
    for k, domain in inputs.items():
        assert not domain.shape
        new_dims[k] = len(new_dims) - total_size

    # Use advanced indexing to construct a simultaneous substitution.
    index = []
    for k, domain in res.inputs.items():
        if k in subs:
            v = subs.get(k)
            if isinstance(v, Number):
                index.append(int(v.data))
            else:
                # Permute and expand v.data to end up at new_dims.
                assert isinstance(v, Tensor)
                v = v.align(tuple(k2 for k2 in inputs if k2 in v.inputs))
                assert isinstance(v, Tensor)
                v_shape = [1] * total_size
                for k2, size in zip(v.inputs, v.data.shape):
                    v_shape[new_dims[k2]] = size
                index.append(v.data.reshape(tuple(v_shape)))
        else:
            # Construct a [:] slice for this preserved input.
            offset_from_right = -1 - new_dims[k]
            index.append(
                ops.new_arange(
                    res.data,
                    domain.dtype).reshape((-1, ) + (1, ) * offset_from_right))

    # Construct a [:] slice for the output.
    for i, size in enumerate(res.output.shape):
        offset_from_right = len(res.output.shape) - i - 1
        index.append(
            ops.new_arange(res.data,
                           size).reshape((-1, ) + (1, ) * offset_from_right))

    # the only difference from Tensor.eager_subs is here:
    # instead of indexing the rhs (lhs = rhs[index]), we index the lhs (lhs[index] = rhs)

    # unsqueeze to make broadcasting work
    src_inputs, src_data = src.inputs.copy(), src.data
    for k, v in res.inputs.items():
        if k not in src.inputs and isinstance(subs[k], Number):
            src_inputs[k] = bint(1)
            src_data = src_data.unsqueeze(-1 - len(src.output.shape))
    src = Tensor(src_data, src_inputs,
                 src.output.dtype).align(tuple(res.inputs.keys()))

    data = res.data
    data[tuple(index)] = src.data
    return Tensor(data, inputs, res.dtype)
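The inversion described in the comments, reduced to a toy example in plain PyTorch (a hypothetical sketch):

import torch

rhs = torch.tensor([10., 20., 30., 40.])
index = torch.tensor([2, 0])
gathered = rhs[index]        # eager_subs direction: lhs = rhs[index]
scattered = torch.zeros(4)
scattered[index] = gathered  # _scatter direction:   lhs[index] = rhs
assert torch.equal(scattered, torch.tensor([10., 0., 30., 0.]))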
Example #15
    def eager_reduce(self, op, reduced_vars):
        if op is ops.logaddexp:
            # Marginalize out real variables, but keep mixtures lazy.
            assert all(v in self.inputs for v in reduced_vars)
            real_vars = frozenset(k for k, d in self.inputs.items()
                                  if d.dtype == "real")
            reduced_reals = reduced_vars & real_vars
            reduced_ints = reduced_vars - real_vars
            if not reduced_reals:
                return None  # defer to default implementation

            inputs = OrderedDict((k, d) for k, d in self.inputs.items()
                                 if k not in reduced_reals)
            if reduced_reals == real_vars:
                result = self.log_normalizer
            else:
                int_inputs = OrderedDict(
                    (k, v) for k, v in inputs.items() if v.dtype != 'real')
                offsets, _ = _compute_offsets(self.inputs)
                a = []
                b = []
                for key, domain in self.inputs.items():
                    if domain.dtype == 'real':
                        block = ops.new_arange(
                            self.info_vec, offsets[key],
                            offsets[key] + domain.num_elements, 1)
                        (b if key in reduced_vars else a).append(block)
                a = ops.cat(-1, *a)
                b = ops.cat(-1, *b)
                prec_aa = self.precision[..., a[..., None], a]
                prec_ba = self.precision[..., b[..., None], a]
                prec_bb = self.precision[..., b[..., None], b]
                prec_b = ops.cholesky(prec_bb)
                prec_a = ops.triangular_solve(prec_ba, prec_b)
                prec_at = ops.transpose(prec_a, -1, -2)
                precision = prec_aa - ops.matmul(prec_at, prec_a)

                info_a = self.info_vec[..., a]
                info_b = self.info_vec[..., b]
                b_tmp = ops.triangular_solve(info_b[..., None], prec_b)
                info_vec = info_a - ops.matmul(prec_at, b_tmp)[..., 0]

                log_prob = Tensor(
                    0.5 * len(b) * math.log(2 * math.pi) -
                    _log_det_tri(prec_b) + 0.5 * (b_tmp[..., 0]**2).sum(-1),
                    int_inputs)
                result = log_prob + Gaussian(info_vec, precision, inputs)

            return result.reduce(ops.logaddexp, reduced_ints)

        elif op is ops.add:
            for v in reduced_vars:
                if self.inputs[v].dtype == 'real':
                    raise ValueError(
                        "Cannot sum along a real dimension: {}".format(
                            repr(v)))

            # Fuse Gaussians along a plate. Compare to eager_add_gaussian_gaussian().
            old_ints = OrderedDict(
                (k, v) for k, v in self.inputs.items() if v.dtype != 'real')
            new_ints = OrderedDict(
                (k, v) for k, v in old_ints.items() if k not in reduced_vars)
            inputs = OrderedDict((k, v) for k, v in self.inputs.items()
                                 if k not in reduced_vars)

            info_vec = Tensor(self.info_vec,
                              old_ints).reduce(ops.add, reduced_vars)
            precision = Tensor(self.precision,
                               old_ints).reduce(ops.add, reduced_vars)
            assert info_vec.inputs == new_ints
            assert precision.inputs == new_ints
            return Gaussian(info_vec.data, precision.data, inputs)

        return None  # defer to default implementation
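The blockwise algebra in the logaddexp branch is standard Schur-complement marginalization in information form: the marginal over the kept block a has precision P_aa - P_ab P_bb^-1 P_ba and info vector i_a - P_ab P_bb^-1 i_b. A hedged numeric check in plain torch:

import torch

da, db = 2, 3
M = torch.randn(da + db, da + db)
P = M @ M.T + (da + db) * torch.eye(da + db)  # joint precision
i = torch.randn(da + db)                      # joint info vector
Paa, Pab = P[:da, :da], P[:da, da:]
Pba, Pbb = P[da:, :da], P[da:, da:]

S = Paa - Pab @ torch.linalg.solve(Pbb, Pba)             # marginal precision
info_a = i[:da] - Pab @ torch.linalg.solve(Pbb, i[da:])  # marginal info vec

mu = torch.linalg.solve(P, i)  # joint mean; its a-block must equal S^-1 info_a
assert torch.allclose(torch.linalg.solve(S, info_a), mu[:da], atol=1e-4)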
Example #16
def test_eager_subs_variable():
    v = Variable('v', Reals[3])
    point = Tensor(randn(3))
    d = Delta('foo', v)
    assert d(v=point) is Delta('foo', point)
Example #17
def test_bart(analytic_kl):
    global call_count
    call_count = 0

    with interpretation(reflect):
        q = Independent(
            Independent(
                Contraction(
                    ops.nullop,
                    ops.add,
                    frozenset(),
                    (
                        Tensor(
                            torch.tensor(
                                [[
                                    -0.6077086925506592, -1.1546266078948975,
                                    -0.7021151781082153, -0.5303535461425781,
                                    -0.6365622282028198, -1.2423288822174072,
                                    -0.9941254258155823, -0.6287292242050171
                                ],
                                 [
                                     -0.6987162828445435, -1.0875964164733887,
                                     -0.7337473630905151, -0.4713417589664459,
                                     -0.6674002408981323, -1.2478348016738892,
                                     -0.8939017057418823, -0.5238542556762695
                                 ]],
                                dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                            ),
                            'real'),
                        Gaussian(
                            torch.tensor([
                                [[-0.3536059558391571], [-0.21779225766658783],
                                 [0.2840439975261688], [0.4531521499156952],
                                 [-0.1220812276005745], [-0.05519985035061836],
                                 [0.10932210087776184], [0.6656699776649475]],
                                [[-0.39107921719551086], [
                                    -0.20241987705230713
                                ], [0.2170514464378357], [0.4500560462474823],
                                 [0.27945515513420105], [-0.0490039587020874],
                                 [-0.06399798393249512], [0.846565842628479]]
                            ],
                                         dtype=torch.float32),  # noqa
                            torch.tensor([
                                [[[1.984686255455017]], [[0.6699360013008118]],
                                 [[1.6215802431106567]], [[2.372016668319702]],
                                 [[1.77385413646698]], [[0.526767373085022]],
                                 [[0.8722561597824097]], [[2.1879124641418457]]
                                 ],
                                [[[1.6996612548828125]], [[
                                    0.7535632252693176
                                ]], [[1.4946647882461548]],
                                 [[2.642792224884033]], [[1.7301604747772217]],
                                 [[0.5203893780708313]], [[1.055436372756958]],
                                 [[2.8370864391326904]]]
                            ],
                                         dtype=torch.float32),  # noqa
                            (
                                (
                                    'time_b4',
                                    bint(2),
                                ),
                                (
                                    '_event_1_b2',
                                    bint(8),
                                ),
                                (
                                    'value_b1',
                                    reals(),
                                ),
                            )),
                    )),
                'gate_rate_b3',
                '_event_1_b2',
                'value_b1'),
            'gate_rate_t',
            'time_b4',
            'gate_rate_b3')
        p_prior = Contraction(
            ops.logaddexp,
            ops.add,
            frozenset({'state(time=1)_b11', 'state_b10'}),
            (
                MarkovProduct(
                    ops.logaddexp,
                    ops.add,
                    Contraction(
                        ops.nullop,
                        ops.add,
                        frozenset(),
                        (
                            Tensor(
                                torch.tensor(2.7672932147979736,
                                             dtype=torch.float32), (), 'real'),
                            Gaussian(
                                torch.tensor([-0.0, -0.0, 0.0, 0.0],
                                             dtype=torch.float32),
                                torch.tensor([[
                                    98.01002502441406, 0.0, -99.0000228881836,
                                    -0.0
                                ],
                                              [
                                                  0.0, 98.01002502441406, -0.0,
                                                  -99.0000228881836
                                              ],
                                              [
                                                  -99.0000228881836, -0.0,
                                                  100.0000228881836, 0.0
                                              ],
                                              [
                                                  -0.0, -99.0000228881836, 0.0,
                                                  100.0000228881836
                                              ]],
                                             dtype=torch.float32),  # noqa
                                (
                                    (
                                        'state_b7',
                                        reals(2, ),
                                    ),
                                    (
                                        'state(time=1)_b8',
                                        reals(2, ),
                                    ),
                                )),
                            Subs(
                                AffineNormal(
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.03488487750291824,
                                                0.07356668263673782,
                                                0.19946961104869843,
                                                0.5386509299278259,
                                                -0.708323061466217,
                                                0.24411526322364807,
                                                -0.20855577290058136,
                                                -0.2421337217092514
                                            ],
                                             [
                                                 0.41762110590934753,
                                                 0.5272183418273926,
                                                 -0.49835553765296936,
                                                 -0.0363837406039238,
                                                 -0.0005282597267068923,
                                                 0.2704298794269562,
                                                 -0.155222088098526,
                                                 -0.44802337884902954
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        (),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                -0.003566693514585495,
                                                -0.2848514914512634,
                                                0.037103548645973206,
                                                0.12648648023605347,
                                                -0.18501518666744232,
                                                -0.20899859070777893,
                                                0.04121830314397812,
                                                0.0054807960987091064
                                            ],
                                             [
                                                 0.0021788496524095535,
                                                 -0.18700894713401794,
                                                 0.08187370002269745,
                                                 0.13554862141609192,
                                                 -0.10477752983570099,
                                                 -0.20848378539085388,
                                                 -0.01393645629286766,
                                                 0.011670656502246857
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[
                                                0.5974780917167664,
                                                0.864071786403656,
                                                1.0236268043518066,
                                                0.7147538065910339,
                                                0.7423890233039856,
                                                0.9462157487869263,
                                                1.2132389545440674,
                                                1.0596832036972046
                                            ],
                                             [
                                                 0.5787821412086487,
                                                 0.9178534150123596,
                                                 0.9074794054031372,
                                                 0.6600189208984375,
                                                 0.8473222255706787,
                                                 0.8426999449729919,
                                                 1.194266438484192,
                                                 1.0471148490905762
                                             ]],
                                            dtype=torch.float32),  # noqa
                                        ((
                                            'time_b9',
                                            bint(2),
                                        ), ),
                                        'real'),
                                    Variable('state(time=1)_b8', reals(2, )),
                                    Variable('gate_rate_b6', reals(8, ))),
                                ((
                                    'gate_rate_b6',
                                    Binary(
                                        ops.GetitemOp(0),
                                        Variable('gate_rate_t', reals(2, 8)),
                                        Variable('time_b9', bint(2))),
                                ), )),
                        )),
                    Variable('time_b9', bint(2)),
                    frozenset({('state_b7', 'state(time=1)_b8')}),
                    frozenset({('state(time=1)_b8', 'state(time=1)_b11'),
                               ('state_b7', 'state_b10')})),  # noqa
                Subs(
                    dist.MultivariateNormal(
                        Tensor(torch.tensor([0.0, 0.0], dtype=torch.float32),
                               (), 'real'),
                        Tensor(
                            torch.tensor([[10.0, 0.0], [0.0, 10.0]],
                                         dtype=torch.float32),
                            (), 'real'), Variable('value_b5', reals(2, ))), ((
                                'value_b5',
                                Variable('state_b10', reals(2, )),
                            ), )),
            ))
        p_likelihood = Contraction(
            ops.add,
            ops.nullop,
            frozenset({'time_b17', 'destin_b16', 'origin_b15'}),
            (
                Contraction(
                    ops.logaddexp,
                    ops.add,
                    frozenset({'gated_b14'}),
                    (
                        dist.Categorical(
                            Binary(
                                ops.GetitemOp(0),
                                Binary(
                                    ops.GetitemOp(0),
                                    Subs(
                                        Function(
                                            unpack_gate_rate_0, reals(2, 2, 2),
                                            (Variable('gate_rate_b12',
                                                      reals(8, )), )),
                                        ((
                                            'gate_rate_b12',
                                            Binary(
                                                ops.GetitemOp(0),
                                                Variable(
                                                    'gate_rate_t', reals(2,
                                                                         8)),
                                                Variable('time_b17', bint(2))),
                                        ), )), Variable('origin_b15',
                                                        bint(2))),
                                Variable('destin_b16', bint(2))),
                            Variable('gated_b14', bint(2))),
                        Stack(
                            'gated_b14',
                            (
                                dist.Poisson(
                                    Binary(
                                        ops.GetitemOp(0),
                                        Binary(
                                            ops.GetitemOp(0),
                                            Subs(
                                                Function(
                                                    unpack_gate_rate_1,
                                                    reals(2, 2), (Variable(
                                                        'gate_rate_b13',
                                                        reals(8, )), )),
                                                ((
                                                    'gate_rate_b13',
                                                    Binary(
                                                        ops.GetitemOp(0),
                                                        Variable(
                                                            'gate_rate_t',
                                                            reals(2, 8)),
                                                        Variable(
                                                            'time_b17',
                                                            bint(2))),
                                                ), )),
                                            Variable('origin_b15', bint(2))),
                                        Variable('destin_b16', bint(2))),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                                dist.Delta(
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (), 'real'),
                                    Tensor(
                                        torch.tensor(0.0, dtype=torch.float32),
                                        (), 'real'),
                                    Tensor(
                                        torch.tensor(
                                            [[[1.0, 1.0], [5.0, 0.0]],
                                             [[0.0, 6.0], [19.0, 3.0]]],
                                            dtype=torch.float32),  # noqa
                                        (
                                            (
                                                'time_b17',
                                                bint(2),
                                            ),
                                            (
                                                'origin_b15',
                                                bint(2),
                                            ),
                                            (
                                                'destin_b16',
                                                bint(2),
                                            ),
                                        ),
                                        'real')),
                            )),
                    )), ))

    if analytic_kl:
        exact_part = funsor.Integrate(q, p_prior - q, "gate_rate_t")
        with interpretation(monte_carlo):
            approx_part = funsor.Integrate(q, p_likelihood, "gate_rate_t")
        elbo = exact_part + approx_part
    else:
        p = p_prior + p_likelihood
        with interpretation(monte_carlo):
            elbo = Integrate(q, p - q, "gate_rate_t")

    assert isinstance(elbo, Tensor), elbo.pretty()
    assert call_count == 1
Example #18
def test_eager_subs_ground(log_density):
    point1 = Tensor(randn(3))
    point2 = Tensor(randn(3))
    d = Delta('foo', point1, log_density)
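    # Substituting the delta's own point recovers log_density; substituting
    # any other point yields a log density of -inf.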
    check_funsor(d(foo=point1), {}, Real, numeric_array(float(log_density)))
    check_funsor(d(foo=point2), {}, Real, numeric_array(float('-inf')))
Example #19
    def __call__(self):

        # calls pyro.param so that params are exposed and constraints applied
        # should not create any new torch.Tensors after __init__
        self.initialize_params()

        N_state = self.config["sizes"]["state"]

        # initialize gamma to uniform
        gamma = Tensor(
            torch.zeros((N_state, N_state)),
            OrderedDict([("y_prev", bint(N_state))]),
        )

        N_v = self.config["sizes"]["random"]
        N_c = self.config["sizes"]["group"]
        log_prob = []

        plate_g = Tensor(torch.zeros(N_c), OrderedDict([("g", bint(N_c))]))

        # group-level random effects
        if self.config["group"]["random"] == "discrete":
            # group-level discrete effect
            e_g = Variable("e_g", bint(N_v))
            e_g_dist = plate_g + dist.Categorical(**self.params["e_g"])(value=e_g)

            log_prob.append(e_g_dist)

            eps_g = (plate_g + self.params["eps_g"]["theta"])(e_g=e_g)

        elif self.config["group"]["random"] == "continuous":
            eps_g = Variable("eps_g", reals(N_state))
            eps_g_dist = plate_g + dist.Normal(**self.params["eps_g"])(value=eps_g)

            log_prob.append(eps_g_dist)
        else:
            eps_g = to_funsor(0.)

        N_s = self.config["sizes"]["individual"]

        plate_i = Tensor(torch.zeros(N_s), OrderedDict([("i", bint(N_s))]))
        # individual-level random effects
        if self.config["individual"]["random"] == "discrete":
            # individual-level discrete effect
            e_i = Variable("e_i", bint(N_v))
            e_i_dist = plate_g + plate_i + dist.Categorical(
                **self.params["e_i"]
            )(value=e_i) * self.raggedness_masks["individual"](t=0)

            log_prob.append(e_i_dist)

            eps_i = (plate_i + plate_g + self.params["eps_i"]["theta"](e_i=e_i))

        elif self.config["individual"]["random"] == "continuous":
            eps_i = Variable("eps_i", reals(N_state))
            eps_i_dist = plate_g + plate_i + dist.Normal(**self.params["eps_i"])(value=eps_i)

            log_prob.append(eps_i_dist)
        else:
            eps_i = to_funsor(0.)

        # add group-level and individual-level random effects to gamma
        gamma = gamma + eps_g + eps_i

        N_state = self.config["sizes"]["state"]

        # we've accounted for all effects, now actually compute gamma_y
        gamma_y = gamma(y_prev="y(t=1)")

        y = Variable("y", bint(N_state))
        y_dist = plate_g + plate_i + dist.Categorical(
            probs=gamma_y.exp() / gamma_y.exp().sum()
        )(value=y)

        # observation 1: step size
        step_dist = plate_g + plate_i + dist.Gamma(
            **{k: v(y_curr=y) for k, v in self.params["step"].items()}
        )(value=self.observations["step"])

        # step size zero-inflation
        if self.config["zeroinflation"]:
            step_zi = dist.Categorical(probs=self.params["zi_step"]["zi_param"](y_curr=y))(
                value="zi_step")
            step_zi_dist = plate_g + plate_i + dist.Delta(self.config["MISSING"], 0.)(
                value=self.observations["step"])
            step_dist = (step_zi + Stack("zi_step", (step_dist, step_zi_dist))).reduce(ops.logaddexp, "zi_step")

        # observation 2: step angle
        angle_dist = plate_g + plate_i + dist.VonMises(
            **{k: v(y_curr=y) for k, v in self.params["angle"].items()}
        )(value=self.observations["angle"])

        # observation 3: dive activity
        omega_dist = plate_g + plate_i + dist.Beta(
            **{k: v(y_curr=y) for k, v in self.params["omega"].items()}
        )(value=self.observations["omega"])

        # dive activity zero-inflation
        if self.config["zeroinflation"]:
            omega_zi = dist.Categorical(probs=self.params["zi_omega"]["zi_param"](y_curr=y))(
                value="zi_omega")
            omega_zi_dist = plate_g + plate_i + dist.Delta(self.config["MISSING"], 0.)(
                value=self.observations["omega"])
            omega_dist = (omega_zi + Stack("zi_omega", (omega_dist, omega_zi_dist))).reduce(ops.logaddexp, "zi_omega")

        # finally, construct the term for parallel scan reduction
        hmm_factor = step_dist + angle_dist + omega_dist
        hmm_factor = hmm_factor * self.raggedness_masks["individual"]
        hmm_factor = hmm_factor * self.raggedness_masks["timestep"]
        # copy masking behavior of pyro.infer.TraceEnum_ELBO._compute_model_factors
        hmm_factor = hmm_factor + y_dist
        log_prob.insert(0, hmm_factor)

        return log_prob
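The zero-inflation blocks mix a point mass at the MISSING value with the base likelihood: the indicator's categorical log-prob is added to a Stack of the two branch log-densities, then the mixture index is logaddexp-reduced away. The same arithmetic in plain torch (a hedged sketch with made-up numbers):

import torch

log_mix = torch.log(torch.tensor([0.8, 0.2]))  # P(not inflated), P(inflated)
log_like = torch.tensor([-1.3, 0.0])           # base log-prob vs. delta at MISSING
total = torch.logsumexp(log_mix + log_like, dim=-1)  # marginal log-likelihood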
Example #20
def test_reduce():
    point = Tensor(randn(3))
    d = Delta('foo', point)
    assert d.reduce(ops.logaddexp, frozenset(['foo'])) is Number(0)
Example #21
def test_slice_2(start, stop, step):
    t = randn((10, 2))
    actual = Tensor(t)["i"](i=Slice("j", start, stop, step, dtype=10))
    expected = Tensor(t[start:stop:step])["j"]
    assert_close(actual, expected)
Example #22
def test_reduce_density(log_density):
    point = Tensor(randn(3))
    d = Delta('foo', point, log_density)
    # Note that log_density affects ground substitution but does not affect reduction.
    assert d.reduce(ops.logaddexp, frozenset(['foo'])) is Number(0)
Example #23
def test_getitem_string():
    data = randn((5, 4, 3, 2))
    x = Tensor(data)
    assert x['i'] is Tensor(data, OrderedDict([('i', bint(5))]))
    assert x['i', 'j'] is Tensor(data,
                                 OrderedDict([('i', bint(5)), ('j', bint(4))]))
Example #24
def test_transform_log(shape):
    point = Tensor(randn(shape))
    x = Variable('x', Reals[shape])
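    # Substituting y = log(x) moves the point to x = exp(point); the added
    # log_density term -point.sum() is the log-abs-det Jacobian of log at
    # the new point.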
    actual = Delta('y', point)(y=ops.log(x))
    expected = Delta('x', point.exp(), -point.sum())
    assert_close(actual, expected)
Example #25
def test_to_data_error():
    data = zeros((3, 3))
    x = Tensor(data, OrderedDict(i=bint(3)))
    with pytest.raises(ValueError):
        funsor.to_data(x)
Example #26
    def _eager_subs_real(self, subs, remaining_subs):
        # Broadcast all component tensors.
        subs = OrderedDict(subs)
        int_inputs = OrderedDict(
            (k, d) for k, d in self.inputs.items() if d.dtype != 'real')
        tensors = [
            Tensor(self.info_vec, int_inputs),
            Tensor(self.precision, int_inputs)
        ]
        tensors.extend(subs.values())
        int_inputs, tensors = align_tensors(*tensors)
        batch_dim = len(tensors[0].shape) - 1
        batch_shape = broadcast_shape(*(x.shape[:batch_dim] for x in tensors))
        (info_vec, precision), values = tensors[:2], tensors[2:]
        offsets, event_size = _compute_offsets(self.inputs)
        slices = [(k, slice(offset, offset + self.inputs[k].num_elements))
                  for k, offset in offsets.items()]

        # Expand all substituted values.
        values = OrderedDict(zip(subs, values))
        for k, value in values.items():
            value = value.reshape(value.shape[:batch_dim] + (-1, ))
            if not get_tracing_state():
                assert value.shape[-1] == self.inputs[k].num_elements
            values[k] = ops.expand(value, batch_shape + value.shape[-1:])

        # Try to perform a complete substitution of all real variables, resulting in a Tensor.
        if all(k in subs for k, d in self.inputs.items() if d.dtype == 'real'):
            # Form the concatenated value.
            value = BlockVector(batch_shape + (event_size, ))
            for k, i in slices:
                if k in values:
                    value[..., i] = values[k]
            value = value.as_tensor()

            # Evaluate the non-normalized log density.
            result = _vv(value, info_vec - 0.5 * _mv(precision, value))

            result = Tensor(result, int_inputs)
            assert result.output == reals()
            return Subs(result, remaining_subs) if remaining_subs else result

        # Perform a partial substitution of a subset of real variables, resulting in a Joint.
        # We split real inputs into two sets: a for the preserved and b for the substituted.
        b = frozenset(k for k, v in subs.items())
        a = frozenset(k for k, d in self.inputs.items()
                      if d.dtype == 'real' and k not in b)
        prec_aa = ops.cat(
            -2, *[
                ops.cat(
                    -1,
                    *[precision[..., i1, i2] for k2, i2 in slices if k2 in a])
                for k1, i1 in slices if k1 in a
            ])
        prec_ab = ops.cat(
            -2, *[
                ops.cat(
                    -1,
                    *[precision[..., i1, i2] for k2, i2 in slices if k2 in b])
                for k1, i1 in slices if k1 in a
            ])
        prec_bb = ops.cat(
            -2, *[
                ops.cat(
                    -1,
                    *[precision[..., i1, i2] for k2, i2 in slices if k2 in b])
                for k1, i1 in slices if k1 in b
            ])
        info_a = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in a])
        info_b = ops.cat(-1, *[info_vec[..., i] for k, i in slices if k in b])
        value_b = ops.cat(-1, *[values[k] for k, i in slices if k in b])
        info_vec = info_a - _mv(prec_ab, value_b)
        log_scale = _vv(value_b, info_b - 0.5 * _mv(prec_bb, value_b))
        precision = ops.expand(prec_aa, info_vec.shape + info_vec.shape[-1:])
        inputs = int_inputs.copy()
        for k, d in self.inputs.items():
            if k not in subs:
                inputs[k] = d
        result = Gaussian(info_vec, precision, inputs) + Tensor(
            log_scale, int_inputs)
        return Subs(result, remaining_subs) if remaining_subs else result
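In the complete-substitution branch above, the result is the unnormalized log density < x | i - 1/2 P x > evaluated at the concatenated value. A tiny hedged check of that expression in plain torch:

import torch

P = 2.0 * torch.eye(2)          # precision
i = torch.tensor([1.0, -1.0])   # info vector
x = torch.tensor([0.5, 0.25])
logp = x @ (i - 0.5 * (P @ x))  # mirrors _vv(value, info_vec - 0.5 * _mv(precision, value))
assert torch.isclose(logp, torch.tensor(-0.0625))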
Example #27
def test_tensor_stack(n, shape, dim):
    tensors = [randn(shape) for _ in range(n)]
    actual = stack(tuple(Tensor(t) for t in tensors), dim=dim)
    expected = Tensor(ops.stack(dim, *tensors))
    assert_close(actual, expected)
Example #28
def moment_matching_contract_joint(red_op, bin_op, reduced_vars, discrete,
                                   gaussian):

    approx_vars = frozenset(
        k for k in reduced_vars
        if k in gaussian.inputs and gaussian.inputs[k].dtype != 'real')
    exact_vars = reduced_vars - approx_vars

    if exact_vars and approx_vars:
        return Contraction(red_op, bin_op, exact_vars, discrete,
                           gaussian).reduce(red_op, approx_vars)

    if approx_vars and not exact_vars:
        discrete += gaussian.log_normalizer
        new_discrete = discrete.reduce(
            ops.logaddexp, approx_vars.intersection(discrete.inputs))
        num_elements = reduce(ops.mul, [
            gaussian.inputs[k].num_elements
            for k in approx_vars.difference(discrete.inputs)
        ], 1)
        if num_elements != 1:
            new_discrete -= math.log(num_elements)

        int_inputs = OrderedDict(
            (k, d) for k, d in gaussian.inputs.items() if d.dtype != 'real')
        probs = (discrete - new_discrete.clamp_finite()).exp()

        old_loc = Tensor(
            ops.cholesky_solve(ops.unsqueeze(gaussian.info_vec, -1),
                               gaussian._precision_chol).squeeze(-1),
            int_inputs)
        new_loc = (probs * old_loc).reduce(ops.add, approx_vars)
        old_cov = Tensor(ops.cholesky_inverse(gaussian._precision_chol),
                         int_inputs)
        diff = old_loc - new_loc
        outers = Tensor(
            ops.unsqueeze(diff.data, -1) * ops.unsqueeze(diff.data, -2),
            diff.inputs)
        new_cov = ((probs * old_cov).reduce(ops.add, approx_vars) +
                   (probs * outers).reduce(ops.add, approx_vars))

        # Numerically stabilize by adding bogus precision to empty components.
        total = probs.reduce(ops.add, approx_vars)
        mask = ops.unsqueeze(ops.unsqueeze((total.data == 0), -1), -1)
        new_cov.data = new_cov.data + mask * ops.new_eye(
            new_cov.data, new_cov.data.shape[-1:])

        new_precision = Tensor(
            ops.cholesky_inverse(ops.cholesky(new_cov.data)), new_cov.inputs)
        new_info_vec = (
            new_precision.data @ ops.unsqueeze(new_loc.data, -1)).squeeze(-1)
        new_inputs = new_loc.inputs.copy()
        new_inputs.update(
            (k, d) for k, d in gaussian.inputs.items() if d.dtype == 'real')
        new_gaussian = Gaussian(new_info_vec, new_precision.data, new_inputs)
        new_discrete -= new_gaussian.log_normalizer

        return new_discrete + new_gaussian

    return None
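The core of the moment-matching approximation: a mixture of Gaussians is collapsed to a single Gaussian matching its first two moments, E[x] = sum_k p_k mu_k and Cov[x] = sum_k p_k (Sigma_k + (mu_k - mu)(mu_k - mu)^T). A hedged sketch of just that step in plain torch:

import torch

K, d = 3, 2
probs = torch.softmax(torch.randn(K), dim=-1)  # mixture weights
locs = torch.randn(K, d)                       # component means
A = torch.randn(K, d, d)
covs = A @ A.transpose(-1, -2) + torch.eye(d)  # component covariances

new_loc = (probs[:, None] * locs).sum(0)
diff = locs - new_loc
outers = diff[:, :, None] * diff[:, None, :]
new_cov = (probs[:, None, None] * (covs + outers)).sum(0)
assert torch.allclose(new_cov, new_cov.T, atol=1e-6)  # symmetric by construction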