Example no. 1
def test_sequential_sum_product_multi(impl, x_domain, y_domain, batch_inputs,
                                      num_steps):
    sum_op = ops.logaddexp
    prod_op = ops.add
    inputs = OrderedDict(batch_inputs)
    inputs.update(x_prev=x_domain,
                  x_curr=x_domain,
                  y_prev=y_domain,
                  y_curr=y_domain)
    if num_steps is None:
        num_steps = 1
    else:
        inputs["time"] = bint(num_steps)
    if any(v.dtype == "real" for v in inputs.values()):
        trans = random_gaussian(inputs)
    else:
        trans = random_tensor(inputs)
    time = Variable("time", bint(num_steps))
    step = {"x_prev": "x_curr", "y_prev": "y_curr"}

    with interpretation(moment_matching):
        actual = impl(sum_op, prod_op, trans, time, step)
        expected_inputs = batch_inputs.copy()
        expected_inputs.update(x_prev=x_domain,
                               x_curr=x_domain,
                               y_prev=y_domain,
                               y_curr=y_domain)
        assert dict(actual.inputs) == expected_inputs

        # Check against contract.
        operands = tuple(
            trans(time=t,
                  x_prev="x_{}".format(t),
                  x_curr="x_{}".format(t + 1),
                  y_prev="y_{}".format(t),
                  y_curr="y_{}".format(t + 1)) for t in range(num_steps))
        reduce_vars = frozenset("x_{}".format(t)
                                for t in range(1, num_steps)).union(
                                    "y_{}".format(t)
                                    for t in range(1, num_steps))
        expected = sum_product(sum_op, prod_op, operands, reduce_vars)
        expected = expected(
            **{
                "x_0": "x_prev",
                "x_{}".format(num_steps): "x_curr",
                "y_0": "y_prev",
                "y_{}".format(num_steps): "y_curr"
            })
        expected = expected.align(tuple(actual.inputs.keys()))
Example no. 2
    def __init__(self,
                 initial_dist,
                 transition_matrix,
                 transition_dist,
                 observation_matrix,
                 observation_dist,
                 validate_args=None):
        assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
        assert isinstance(transition_matrix, torch.Tensor)
        assert isinstance(transition_dist,
                          torch.distributions.MultivariateNormal)
        assert isinstance(observation_matrix, torch.Tensor)
        assert isinstance(observation_dist,
                          torch.distributions.MultivariateNormal)
        hidden_dim, obs_dim = observation_matrix.shape[-2:]
        assert obs_dim >= hidden_dim // 2, "obs_dim must be at least half of hidden_dim"
        assert initial_dist.event_shape == (hidden_dim, )
        assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
        assert transition_dist.event_shape == (hidden_dim, )
        assert observation_dist.event_shape == (obs_dim, )
        shape = broadcast_shape(initial_dist.batch_shape + (1, ),
                                transition_matrix.shape[:-2],
                                transition_dist.batch_shape,
                                observation_matrix.shape[:-2],
                                observation_dist.batch_shape)
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + (obs_dim, )

        # Convert distributions to funsors.
        init = dist_to_funsor(initial_dist)(value="state")
        trans = matrix_and_mvn_to_funsor(transition_matrix, transition_dist,
                                         ("time", ), "state", "state(time=1)")
        obs = matrix_and_mvn_to_funsor(observation_matrix, observation_dist,
                                       ("time", ), "state(time=1)", "value")
        dtype = "real"

        # Construct the joint funsor.
        with interpretation(lazy):
            value = Variable("value", reals(time_shape[0], obs_dim))
            result = trans + obs(value=value["time"])
            result = MarkovProduct(ops.logaddexp, ops.add, result, "time",
                                   {"state": "state(time=1)"})
            result = init + result.reduce(ops.logaddexp, "state(time=1)")
            funsor_dist = result.reduce(ops.logaddexp, "state")

        super(GaussianHMM, self).__init__(funsor_dist, batch_shape,
                                          event_shape, dtype, validate_args)
        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
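
A minimal usage sketch for the constructor above, assuming the class is the GaussianHMM wrapper from funsor.pyro.hmm and that log_prob is provided by the FunsorDistribution base class; the dimensions and tensors below are purely illustrative.

import torch
from funsor.pyro.hmm import GaussianHMM  # assumed import path

hidden_dim, obs_dim, num_steps = 3, 2, 5
initial_dist = torch.distributions.MultivariateNormal(
    torch.zeros(hidden_dim), torch.eye(hidden_dim))
transition_matrix = torch.randn(num_steps, hidden_dim, hidden_dim)
transition_dist = torch.distributions.MultivariateNormal(
    torch.zeros(hidden_dim), torch.eye(hidden_dim))
observation_matrix = torch.randn(hidden_dim, obs_dim)
observation_dist = torch.distributions.MultivariateNormal(
    torch.zeros(obs_dim), torch.eye(obs_dim))

hmm = GaussianHMM(initial_dist, transition_matrix, transition_dist,
                  observation_matrix, observation_dist)
data = torch.randn(num_steps, obs_dim)  # one observed sequence
log_prob = hmm.log_prob(data)  # marginal log-likelihood, hidden states integrated out over "time"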
Example no. 3
def test_lognormal_distribution(moment):
    num_samples = 100000
    inputs = OrderedDict(batch=bint(10))
    loc = random_tensor(inputs)
    scale = random_tensor(inputs).exp()

    log_measure = dist.LogNormal(loc, scale)(value='x')
    probe = Variable('x', reals())**moment
    with monte_carlo_interpretation(particle=bint(num_samples)):
        with xfail_if_not_implemented():
            actual = Integrate(log_measure, probe, frozenset(['x']))

    samples = backend_dist.LogNormal(loc, scale).sample((num_samples, ))
    expected = (samples**moment).mean(0)
    assert_close(actual.data, expected, atol=1e-2, rtol=1e-2)
Example no. 4
def test_function_lazy_matmul():
    @funsor.torch.function(reals(3, 4), reals(4, 5), reals(3, 5))
    def matmul(x, y):
        return torch.matmul(x, y)

    x_lazy = Variable('x', reals(3, 4))
    y = Tensor(torch.randn(4, 5))
    actual_lazy = matmul(x_lazy, y)
    check_funsor(actual_lazy, {'x': reals(3, 4)}, reals(3, 5))
    assert isinstance(actual_lazy, funsor.torch.Function)

    x = Tensor(torch.randn(3, 4))
    actual = actual_lazy(x=x)
    expected_data = torch.matmul(x.data, y.data)
    check_funsor(actual, {}, reals(3, 5), expected_data)
Example no. 5
def test_function_lazy_matmul():
    @funsor.function(reals(3, 4), reals(4, 5), reals(3, 5))
    def matmul(x, y):
        return x @ y

    x_lazy = Variable('x', reals(3, 4))
    y = Tensor(randn((4, 5)))
    actual_lazy = matmul(x_lazy, y)
    check_funsor(actual_lazy, {'x': reals(3, 4)}, reals(3, 5))
    assert isinstance(actual_lazy, funsor.tensor.Function)

    x = Tensor(randn((3, 4)))
    actual = actual_lazy(x=x)
    expected_data = x.data @ y.data
    check_funsor(actual, {}, reals(3, 5), expected_data)
Example no. 6
def test_affine_subs(expr, expected_type, expected_inputs):

    expected_output = reals()

    t = Tensor(torch.randn(2, 3), OrderedDict([('i', bint(2)),
                                               ('j', bint(3))]))
    assert isinstance(t, Tensor)

    n = Number(2.)
    assert isinstance(n, Number)

    x = Variable('x', reals())
    assert isinstance(x, Variable)

    y = Variable('y', reals())
    assert isinstance(y, Variable)

    z = Variable('z', reals())
    assert isinstance(z, Variable)

    result = eval(expr)
    assert isinstance(result, expected_type)
    check_funsor(result, expected_inputs, expected_output)
    assert is_affine(result)
Example no. 7
def test_gamma_poisson_conjugate(batch_shape):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))
    full_shape = batch_shape
    prior = Variable("prior", Real)
    concentration = Tensor(ops.exp(randn(full_shape)), inputs)
    rate = Tensor(ops.exp(randn(full_shape)), inputs)
    latent = dist.Gamma(concentration, rate, value=prior)
    conditional = dist.Poisson(rate=prior)
    reduced = (latent + conditional).reduce(ops.logaddexp, set(["prior"]))
    assert isinstance(reduced, dist.GammaPoisson)
    assert_close(reduced.concentration, concentration)
    assert_close(reduced.rate, rate)

    obs = Tensor(ops.astype(ops.astype(ops.exp(randn(batch_shape)), 'int32'), 'float32'), inputs)
    _assert_conjugate_density_ok(latent, conditional, obs)
Example no. 8
def test_nested_einsum(eqn1, eqn2, optimize1, optimize2, backend1, backend2,
                       einsum_impl):
    inputs1, outputs1, sizes1, operands1, _ = make_einsum_example(eqn1,
                                                                  sizes=(3, ))
    inputs2, outputs2, sizes2, operands2, funsor_operands2 = make_einsum_example(
        eqn2, sizes=(3, ))

    # normalize the probs for ground-truth comparison
    operands1 = [
        operand.abs() / operand.abs().sum(-1, keepdim=True)
        for operand in operands1
    ]

    expected1 = pyro_einsum(eqn1,
                            *operands1,
                            backend=backend1,
                            modulo_total=True)[0]
    expected2 = pyro_einsum(outputs1[0] + "," + eqn2,
                            *([expected1] + operands2),
                            backend=backend2,
                            modulo_total=True)[0]

    with interpretation(normalize):
        funsor_operands1 = [
            Categorical(probs=Tensor(operand,
                                     inputs=OrderedDict([(d, Bint[sizes1[d]])
                                                         for d in inp[:-1]])))
            (value=Variable(inp[-1], Bint[sizes1[inp[-1]]])).exp()
            for inp, operand in zip(inputs1, operands1)
        ]

        output1_naive = einsum_impl(eqn1, *funsor_operands1, backend=backend1)
        with interpretation(reflect):
            output1 = apply_optimizer(
                output1_naive) if optimize1 else output1_naive
        output2_naive = einsum_impl(outputs1[0] + "," + eqn2,
                                    *([output1] + funsor_operands2),
                                    backend=backend2)
        with interpretation(reflect):
            output2 = apply_optimizer(
                output2_naive) if optimize2 else output2_naive

    actual1 = reinterpret(output1)
    actual2 = reinterpret(output2)

    assert torch.allclose(expected1, actual1.data)
    assert torch.allclose(expected2, actual2.data)
Example no. 9
def extract_affine(fn):
    """
    Extracts an affine representation of a funsor, satisfying::

        x = ...
        const, coeffs = extract_affine(x)
        y = const + sum(Einsum(eqn, (coeff, Variable(var, coeff.output)))
                        for var, (coeff, eqn) in coeffs.items())
        assert_close(y, x)
        assert frozenset(coeffs) == affine_inputs(x)

    The ``coeffs`` will have one key per input wrt which ``fn`` is known to be
    affine (via :func:`affine_inputs` ), and ``const`` and ``coeffs.values``
    will all be constant wrt these inputs.

    The affine approximation is computed by evaluating ``fn`` at
    zero and each basis vector. To improve performance, users may want to run
    under the :func:`~funsor.memoize.memoize` interpretation.

    :param Funsor fn: A funsor that is affine wrt the (add,mul) semiring in
        some subset of its inputs.
    :return: A pair ``(const, coeffs)`` where const is a funsor with no real
        inputs and ``coeffs`` is an OrderedDict mapping input name to a
        ``(coefficient, eqn)`` pair in einsum form.
    :rtype: tuple
    """
    # Determine constant part by evaluating fn at zero.
    inputs = affine_inputs(fn)
    inputs = OrderedDict((k, v) for k, v in fn.inputs.items() if k in inputs)
    zeros = {k: Tensor(torch.zeros(v.shape)) for k, v in inputs.items()}
    const = fn(**zeros)

    # Determine linear coefficients by evaluating fn on basis vectors.
    name = gensym('probe')
    coeffs = OrderedDict()
    for k, v in inputs.items():
        dim = v.num_elements
        var = Variable(name, bint(dim))
        subs = zeros.copy()
        subs[k] = Tensor(torch.eye(dim).reshape((dim,) + v.shape))[var]
        coeff = Lambda(var, fn(**subs) - const).reshape(v.shape + const.shape)
        inputs1 = ''.join(map(opt_einsum.get_symbol, range(len(coeff.shape))))
        inputs2 = inputs1[:len(v.shape)]
        output = inputs1[len(v.shape):]
        eqn = f'{inputs1},{inputs2}->{output}'
        coeffs[k] = coeff, eqn
    return const, coeffs
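
A small worked instance of the docstring contract above, assuming the torch backend and the Variable, Tensor, Einsum, reals and assert_close names used in the surrounding examples. The expression is a scalar affine function of x, so the constant and the single coefficient both come out as plain tensors.

import torch

x = Variable('x', reals())
fn = Tensor(torch.tensor(2.0)) * x + Tensor(torch.tensor(3.0))  # affine in x

const, coeffs = extract_affine(fn)
assert frozenset(coeffs) == frozenset(['x'])
coeff, eqn = coeffs['x']  # expected: coeff.data == 2.0, eqn == ',->'

# Reconstruct fn from the extracted parts and compare at a concrete point.
reconstructed = const + Einsum(eqn, (coeff, x))
point = Tensor(torch.tensor(0.5))
assert_close(reconstructed(x=point), fn(x=point))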
Example no. 10
def test_cons_hash():
    assert Variable('x', bint(3)) is Variable('x', bint(3))
    assert Variable('x', reals()) is Variable('x', reals())
    assert Variable('x', reals()) is not Variable('x', bint(3))
    assert Number(0, 3) is Number(0, 3)
    assert Number(0.) is Number(0.)
    assert Number(0.) is not Number(0, 3)
    assert Slice('x', 10) is Slice('x', 10)
    assert Slice('x', 10) is Slice('x', 0, 10)
    assert Slice('x', 10, 10) is not Slice('x', 0, 10)
    assert Slice('x', 2, 10, 1) is Slice('x', 2, 10)
Example no. 11
def test_cons_hash():
    assert Variable('x', Bint[3]) is Variable('x', Bint[3])
    assert Variable('x', Real) is Variable('x', Real)
    assert Variable('x', Real) is not Variable('x', Bint[3])
    assert Number(0, 3) is Number(0, 3)
    assert Number(0.) is Number(0.)
    assert Number(0.) is not Number(0, 3)
    assert Slice('x', 10) is Slice('x', 10)
    assert Slice('x', 10) is Slice('x', 0, 10)
    assert Slice('x', 10, 10) is not Slice('x', 0, 10)
    assert Slice('x', 2, 10, 1) is Slice('x', 2, 10)
Example no. 12
def test_lognormal_distribution(moment):
    num_samples = 100000
    inputs = OrderedDict(batch=Bint[10])
    loc = random_tensor(inputs)
    scale = random_tensor(inputs).exp()

    log_measure = dist.LogNormal(loc, scale)(value='x')
    probe = Variable('x', Real)**moment
    with interpretation(MonteCarlo(particle=Bint[num_samples])):
        with xfail_if_not_implemented():
            actual = Integrate(log_measure, probe, frozenset(['x']))

    _, (loc_data, scale_data) = align_tensors(loc, scale)
    samples = backend_dist.LogNormal(loc_data, scale_data).sample(
        (num_samples, ))
    expected = (samples**moment).mean(0)
    assert_close(actual.data, expected, atol=1e-2, rtol=1e-2)
Example no. 13
def test_integrate_variable(int_inputs, real_inputs):
    int_inputs = OrderedDict(sorted(int_inputs.items()))
    real_inputs = OrderedDict(sorted(real_inputs.items()))
    inputs = int_inputs.copy()
    inputs.update(real_inputs)

    log_measure = random_gaussian(inputs)
    integrand = reduce(ops.add, [Variable(k, d) for k, d in real_inputs.items()])
    reduced_vars = frozenset(real_inputs)

    sampled_log_measure = log_measure.sample(reduced_vars, OrderedDict(particle=bint(100000)))
    approx = Integrate(sampled_log_measure, integrand, reduced_vars | {'particle'})
    assert isinstance(approx, Tensor)

    exact = Integrate(log_measure, integrand, reduced_vars)
    assert isinstance(exact, Tensor)
    assert_close(approx, exact, atol=0.1, rtol=0.1)
Example no. 14
def test_normal_gaussian_2(batch_shape):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))

    loc = Tensor(randn(batch_shape), inputs)
    scale = Tensor(ops.exp(randn(batch_shape)), inputs)
    value = Tensor(randn(batch_shape), inputs)

    expected = dist.Normal(loc, scale, value)
    assert isinstance(expected, Tensor)
    check_funsor(expected, inputs, Real)

    g = dist.Normal(Variable('value', Real), scale, loc)
    assert isinstance(g, Contraction)
    actual = g(value=value)
    check_funsor(actual, inputs, Real)

    assert_close(actual, expected, atol=1e-4)
Example no. 15
def test_normal_gaussian_3(batch_shape):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

    loc = Tensor(torch.randn(batch_shape), inputs)
    scale = Tensor(torch.randn(batch_shape).exp(), inputs)
    value = Tensor(torch.randn(batch_shape), inputs)

    expected = dist.Normal(loc, scale, value)
    assert isinstance(expected, Tensor)
    check_funsor(expected, inputs, reals())

    g = dist.Normal(Variable('loc', reals()), scale, 'value')
    assert isinstance(g, Contraction)
    actual = g(loc=loc, value=value)
    check_funsor(actual, inputs, reals())

    assert_close(actual, expected, atol=1e-4)
Example no. 16
def test_reduce_moment_matching_moments():
    x = Variable('x', reals(2))
    gaussian = random_gaussian(
        OrderedDict([('i', bint(2)), ('j', bint(3)), ('x', reals(2))]))
    with interpretation(moment_matching):
        approx = gaussian.reduce(ops.logaddexp, 'j')
    with monte_carlo_interpretation(s=bint(100000)):
        actual = Integrate(approx, Number(1.), 'x')
        expected = Integrate(gaussian, Number(1.), {'j', 'x'})
        assert_close(actual, expected, atol=1e-3, rtol=1e-3)

        actual = Integrate(approx, x, 'x')
        expected = Integrate(gaussian, x, {'j', 'x'})
        assert_close(actual, expected, atol=1e-2, rtol=1e-2)

        actual = Integrate(approx, x * x, 'x')
        expected = Integrate(gaussian, x * x, {'j', 'x'})
        assert_close(actual, expected, atol=1e-2, rtol=1e-2)
Example no. 17
def test_reduce_moment_matching_moments():
    x = Variable('x', Reals[2])
    gaussian = random_gaussian(
        OrderedDict([('i', Bint[2]), ('j', Bint[3]), ('x', Reals[2])]))
    with interpretation(moment_matching):
        approx = gaussian.reduce(ops.logaddexp, 'j')
    with interpretation(MonteCarlo(s=Bint[100000])):
        actual = Integrate(approx, Number(1.), 'x')
        expected = Integrate(gaussian, Number(1.), {'j', 'x'})
        assert_close(actual, expected, atol=1e-3, rtol=1e-3)

        actual = Integrate(approx, x, 'x')
        expected = Integrate(gaussian, x, {'j', 'x'})
        assert_close(actual, expected, atol=1e-2, rtol=1e-2)

        actual = Integrate(approx, x * x, 'x')
        expected = Integrate(gaussian, x * x, {'j', 'x'})
        assert_close(actual, expected, atol=1e-2, rtol=1e-2)
Example no. 18
def test_integrate_variable(int_inputs, real_inputs):
    int_inputs = OrderedDict(sorted(int_inputs.items()))
    real_inputs = OrderedDict(sorted(real_inputs.items()))
    inputs = int_inputs.copy()
    inputs.update(real_inputs)

    log_measure = random_gaussian(inputs)
    integrand = reduce(ops.add, [Variable(k, d) for k, d in real_inputs.items()])
    reduced_vars = frozenset(real_inputs)

    rng_key = None if get_backend() != 'jax' else np.array([0, 0], dtype=np.uint32)
    sampled_log_measure = log_measure.sample(reduced_vars, OrderedDict(particle=Bint[100000]), rng_key=rng_key)
    approx = Integrate(sampled_log_measure, integrand, reduced_vars | {'particle'})
    assert isinstance(approx, Tensor)

    exact = Integrate(log_measure, integrand, reduced_vars)
    assert isinstance(exact, Tensor)
    assert_close(approx, exact, atol=0.1, rtol=0.1)
Example no. 19
def test_function_nested_lazy():
    @funsor.function(Reals[8], (Real, Bint[8]))
    def max_and_argmax(x):
        return tuple(_numeric_max_and_argmax(x))

    x_lazy = Variable('x', Reals[8])
    lazy_max, lazy_argmax = max_and_argmax(x_lazy)
    assert isinstance(lazy_max, funsor.tensor.Function)
    assert isinstance(lazy_argmax, funsor.tensor.Function)
    check_funsor(lazy_max, {'x': Reals[8]}, Real)
    check_funsor(lazy_argmax, {'x': Reals[8]}, Bint[8])

    inputs = OrderedDict([('i', Bint[2]), ('j', Bint[3])])
    y = Tensor(randn((2, 3, 8)), inputs)
    actual_max = lazy_max(x=y)
    actual_argmax = lazy_argmax(x=y)
    expected_max, expected_argmax = max_and_argmax(y)
    assert_close(actual_max, expected_max)
    assert_close(actual_argmax, expected_argmax)
Example no. 20
    def log_prob(self, value):
        if self._validate_args:
            self._validate_sample(value)
        ndims = max(len(self.batch_shape), value.dim() - self.event_dim)
        time = Variable("time", Bint[self.event_shape[0]])
        value = tensor_to_funsor(value, ("time", ),
                                 event_output=self.event_dim - 1,
                                 dtype=self.dtype)

        # Compare with pyro.distributions.hmm.DiscreteHMM.log_prob().
        obs = self._obs(value=value)
        result = self._trans + obs
        result = sequential_sum_product(ops.logaddexp, ops.add, result, time,
                                        {"state": "state(time=1)"})
        result = self._init + result.reduce(ops.logaddexp, "state(time=1)")
        result = result.reduce(ops.logaddexp, "state")

        result = funsor_to_tensor(result, ndims=ndims)
        return result
Example no. 21
def test_affine_subs():
    # This was recorded from test_pyro_convert.
    x = Subs(
     Gaussian(
      torch.tensor([1.3027106523513794, 1.4167094230651855, -0.9750942587852478, 0.5321089029312134, -0.9039931297302246], dtype=torch.float32),  # noqa
      torch.tensor([[1.0199567079544067, 0.9840421676635742, -0.473368763923645, 0.34206756949424744, -0.7562517523765564], [0.9840421676635742, 1.511502742767334, -1.7593903541564941, 0.6647964119911194, -0.5119513273239136], [-0.4733688533306122, -1.7593903541564941, 3.2386727333068848, -0.9345928430557251, -0.1534711718559265], [0.34206756949424744, 0.6647964119911194, -0.9345928430557251, 0.3141004145145416, -0.12399007380008698], [-0.7562517523765564, -0.5119513273239136, -0.1534711718559265, -0.12399007380008698, 0.6450173854827881]], dtype=torch.float32),  # noqa
      (('state_1_b6',
        reals(3,),),
       ('obs_b2',
        reals(2,),),)),
     (('obs_b2',
       Contraction(ops.nullop, ops.add,
        frozenset(),
        (Variable('bias_b5', reals(2,)),
         Tensor(
          torch.tensor([-2.1787893772125244, 0.5684312582015991], dtype=torch.float32),  # noqa
          (),
          'real'),)),),))
    assert isinstance(x, (Gaussian, Contraction)), x.pretty()
Example no. 22
def test_function_nested_lazy():
    @funsor.function(reals(8), (reals(), bint(8)))
    def max_and_argmax(x):
        return tuple(_numeric_max_and_argmax(x))

    x_lazy = Variable('x', reals(8))
    lazy_max, lazy_argmax = max_and_argmax(x_lazy)
    assert isinstance(lazy_max, funsor.tensor.Function)
    assert isinstance(lazy_argmax, funsor.tensor.Function)
    check_funsor(lazy_max, {'x': reals(8)}, reals())
    check_funsor(lazy_argmax, {'x': reals(8)}, bint(8))

    inputs = OrderedDict([('i', bint(2)), ('j', bint(3))])
    y = Tensor(randn((2, 3, 8)), inputs)
    actual_max = lazy_max(x=y)
    actual_argmax = lazy_argmax(x=y)
    expected_max, expected_argmax = max_and_argmax(y)
    assert_close(actual_max, expected_max)
    assert_close(actual_argmax, expected_argmax)
Example no. 23
def test_einsum_adjoint_unary_marginals(einsum_impl, equation, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(
        equation)
    equation = ",".join(inputs) + "->"

    targets = [Variable(k, bint(sizes[k])) for k in set(sizes)]
    with interpretation(reflect):
        fwd_expr = einsum_impl(equation, *funsor_operands, backend=backend)
    actuals = adjoint(fwd_expr, targets)

    for target in targets:
        actual = actuals[target]

        expected = opt_einsum.contract(equation + target.name,
                                       *operands,
                                       backend=backend)
        assert isinstance(actual, funsor.Tensor)
        assert expected.shape == actual.data.shape
        assert torch.allclose(expected, actual.data, atol=1e-7)
Example no. 24
    def __init__(self,
                 initial_logits,
                 transition_logits,
                 observation_dist,
                 validate_args=None):
        assert isinstance(initial_logits, torch.Tensor)
        assert isinstance(transition_logits, torch.Tensor)
        assert isinstance(observation_dist, torch.distributions.Distribution)
        assert initial_logits.dim() >= 1
        assert transition_logits.dim() >= 2
        assert len(observation_dist.batch_shape) >= 1
        shape = broadcast_shape(initial_logits.shape[:-1] + (1, ),
                                transition_logits.shape[:-2],
                                observation_dist.batch_shape[:-1])
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + observation_dist.event_shape
        self._has_rsample = observation_dist.has_rsample

        # Normalize.
        initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
        transition_logits = transition_logits - transition_logits.logsumexp(
            -1, True)

        # Convert tensors and distributions to funsors.
        init = tensor_to_funsor(initial_logits, ("state", ))
        trans = tensor_to_funsor(transition_logits,
                                 ("time", "state", "state(time=1)"))
        obs = dist_to_funsor(observation_dist, ("time", "state(time=1)"))
        dtype = obs.inputs["value"].dtype

        # Construct the joint funsor.
        with interpretation(lazy):
            # TODO perform math here once sequential_sum_product has been
            #   implemented as a first-class funsor.
            funsor_dist = Variable("value",
                                   obs.inputs["value"])  # a bogus value
            # Until funsor_dist is defined, we save factors for hand-computation in .log_prob().
            self._init = init
            self._trans = trans
            self._obs = obs

        super(DiscreteHMM, self).__init__(funsor_dist, batch_shape,
                                          event_shape, dtype, validate_args)
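
A minimal usage sketch for the constructor above, assuming it is the DiscreteHMM wrapper from funsor.pyro.hmm, so that the hand-computed .log_prob() shown in example no. 20 applies; the Categorical observation model and all shapes are illustrative.

import torch
from funsor.pyro.hmm import DiscreteHMM  # assumed import path

num_steps, num_states, num_categories = 5, 3, 4
initial_logits = torch.zeros(num_states)
transition_logits = torch.zeros(num_steps, num_states, num_states)
observation_dist = torch.distributions.Categorical(
    logits=torch.randn(num_steps, num_states, num_categories))

hmm = DiscreteHMM(initial_logits, transition_logits, observation_dist)
data = torch.randint(num_categories, (num_steps,))  # one observed sequence
log_prob = hmm.log_prob(data)  # hidden states summed out along "time"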
Example no. 25
def test_sequential_sum_product_adjoint(impl, sum_op, prod_op, batch_inputs, state_domain, num_steps):
    # test mostly copied from test_sum_product.py
    inputs = OrderedDict(batch_inputs)
    inputs.update(prev=state_domain, curr=state_domain)
    inputs["time"] = bint(num_steps)
    if state_domain.dtype == "real":
        trans = random_gaussian(inputs)
    else:
        trans = random_tensor(inputs)
    time = Variable("time", bint(num_steps))

    with AdjointTape() as actual_tape:
        actual = impl(sum_op, prod_op, trans, time, {"prev": "curr"})

    expected_inputs = batch_inputs.copy()
    expected_inputs.update(prev=state_domain, curr=state_domain)
    assert dict(actual.inputs) == expected_inputs

    # Check against contract.
    operands = tuple(trans(time=t, prev="t_{}".format(t), curr="t_{}".format(t+1))
                     for t in range(num_steps))
    reduce_vars = frozenset("t_{}".format(t) for t in range(1, num_steps))
    with AdjointTape() as expected_tape:
        with interpretation(reflect):
            expected = sum_product(sum_op, prod_op, operands, reduce_vars)
        expected = apply_optimizer(expected)
        expected = expected(**{"t_0": "prev", "t_{}".format(num_steps): "curr"})
        expected = expected.align(tuple(actual.inputs.keys()))

    # check forward pass (sanity check)
    assert_close(actual, expected, rtol=5e-4 * num_steps)

    # perform backward passes only after the sanity check
    expected_bwds = expected_tape.adjoint(sum_op, prod_op, expected, operands)
    actual_bwd = actual_tape.adjoint(sum_op, prod_op, actual, (trans,))[trans]

    # check backward pass
    for t, operand in enumerate(operands):
        actual_bwd_t = actual_bwd(time=t, prev="t_{}".format(t), curr="t_{}".format(t+1))
        expected_bwd = expected_bwds[operand].align(tuple(actual_bwd_t.inputs.keys()))
        check_funsor(actual_bwd_t, expected_bwd.inputs, expected_bwd.output)
        assert_close(actual_bwd_t, expected_bwd, rtol=5e-4 * num_steps)
Example no. 26
def test_einsum_categorical(equation):
    if get_backend() == "jax":
        from funsor.jax.distributions import Categorical
    else:
        from funsor.torch.distributions import Categorical

    inputs, outputs, sizes, operands, _ = make_einsum_example(equation)
    operands = [ops.abs(operand) / ops.abs(operand).sum(-1)[..., None]
                for operand in operands]

    expected = opt_einsum.contract(equation, *operands,
                                   backend=BACKEND_TO_EINSUM_BACKEND[get_backend()])

    with interpretation(reflect):
        funsor_operands = [
            Categorical(probs=Tensor(
                operand,
                inputs=OrderedDict([(d, Bint[sizes[d]]) for d in inp[:-1]])
            ))(value=Variable(inp[-1], Bint[sizes[inp[-1]]])).exp()
            for inp, operand in zip(inputs, operands)
        ]

        naive_ast = naive_einsum(equation, *funsor_operands)
        optimized_ast = apply_optimizer(naive_ast)

    print("Naive expression: {}".format(naive_ast))
    print("Optimized expression: {}".format(optimized_ast))
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_einsum(equation, *map(reinterpret, funsor_operands))

    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    assert_close(actual, actual_optimized, atol=1e-4)

    assert expected.shape == actual.data.shape
    assert_close(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
Example no. 27
def test_dirichlet_categorical_conjugate(batch_shape, size):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))

    full_shape = batch_shape + (size,)
    prior = Variable("prior", Reals[size])
    concentration = Tensor(ops.exp(randn(full_shape)), inputs)
    value = random_tensor(inputs, Bint[size])
    latent = dist.Dirichlet(concentration, value=prior)
    conditional = dist.Categorical(probs=prior)
    reduced = (latent + conditional).reduce(ops.logaddexp, set(["prior"]))
    assert isinstance(reduced, Tensor)
    actual = reduced(value=value)
    expected = dist.DirichletMultinomial(concentration=concentration, total_count=1)(
        value=Tensor(ops.new_eye(concentration.data, (size,)))[value])
    # TODO: investigate why jax backend gives inconsistent results on Travis
    assert_close(actual, expected, rtol=1e-5 if get_backend() == "jax" else 1e-6)

    obs = random_tensor(inputs, Bint[size])
    _assert_conjugate_density_ok(latent, conditional, obs)
Example no. 28
def test_dirichlet_multinomial_conjugate_plate(batch_shape, size):
    max_count = 10
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))
    full_shape = batch_shape + (size,)
    prior = Variable("prior", Reals[size])
    concentration = Tensor(ops.exp(randn(full_shape)), inputs)
    value_data = ops.astype(randint(0, max_count, size=batch_shape + (7, size)), 'float32')
    obs_inputs = inputs.copy()
    obs_inputs['plate'] = Bint[7]
    obs = Tensor(value_data, obs_inputs)
    total_count_data = value_data.sum(-1)
    total_count = Tensor(total_count_data, obs_inputs)
    latent = dist.Dirichlet(concentration, value=prior)
    conditional = dist.Multinomial(probs=prior, total_count=total_count, value=obs)
    p = latent + conditional.reduce(ops.add, 'plate')
    reduced = p.reduce(ops.logaddexp, 'prior')
    assert isinstance(reduced, Tensor)

    _assert_conjugate_density_ok(latent, conditional, obs)
Example no. 29
def test_beta_bernoulli_conjugate(batch_shape):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))
    full_shape = batch_shape
    prior = Variable("prior", Real)
    concentration1 = Tensor(ops.exp(randn(full_shape)), inputs)
    concentration0 = Tensor(ops.exp(randn(full_shape)), inputs)
    latent = dist.Beta(concentration1, concentration0, value=prior)
    conditional = dist.Bernoulli(probs=prior)
    reduced = (latent + conditional).reduce(ops.logaddexp, set(["prior"]))
    assert isinstance(reduced, dist.DirichletMultinomial)
    concentration = stack((concentration0, concentration1), dim=-1)
    assert_close(reduced.concentration, concentration)
    assert_close(reduced.total_count, Tensor(numeric_array(1.)))

    # we need lazy expression for Beta to draw samples from it
    with interpretation(funsor.terms.lazy):
        lazy_latent = dist.Beta(concentration1, concentration0, value=prior)
    obs = Tensor(rand(batch_shape).round(), inputs)
    _assert_conjugate_density_ok(latent, conditional, obs, lazy_latent=lazy_latent)
Example no. 30
    def log_prob(self, value):
        ndims = max(len(self.batch_shape), value.dim() - 2)
        time = Variable("time", Bint[self.event_shape[0]])
        value = tensor_to_funsor(value, ("time", ), 1)

        seq_sum_prod = naive_sequential_sum_product if self.exact else sequential_sum_product
        with interpretation(eager if self.exact else moment_matching):
            result = self._trans + self._obs(value=value)
            result = seq_sum_prod(ops.logaddexp, ops.add, result, time, {
                "class": "class(time=1)",
                "state": "state(time=1)"
            })
            result += self._init
            result = result.reduce(
                ops.logaddexp,
                frozenset(["class", "state", "class(time=1)",
                           "state(time=1)"]))

            result = funsor_to_tensor(result, ndims=ndims)
            return result