Example #1
def test_nested_einsum(eqn1, eqn2, optimize1, optimize2, backend1, backend2, einsum_impl):
    """Compare two chained funsor einsums against ``pyro_einsum`` ground truth.

    The single output of ``eqn1`` is prepended as the first operand of
    ``eqn2``. Each stage's naive funsor expression is optionally rewritten by
    ``apply_optimizer`` (controlled by ``optimize1`` / ``optimize2``); both
    stage results are then reinterpreted and compared with the reference
    tensors.
    """
    inputs1, outputs1, sizes1, operands1, _ = make_einsum_example(eqn1, sizes=(3,))
    inputs2, outputs2, sizes2, operands2, funsor_operands2 = make_einsum_example(eqn2, sizes=(3,))

    # normalize the probs for ground-truth comparison
    # (absolute values divided by their sum along the last axis)
    operands1 = [operand.abs() / operand.abs().sum(-1, keepdim=True)
                 for operand in operands1]

    # Reference values: stage 1 via pyro_einsum, then stage 2 with the stage-1
    # result as its leading operand.
    expected1 = pyro_einsum(eqn1, *operands1, backend=backend1, modulo_total=True)[0]
    expected2 = pyro_einsum(outputs1[0] + "," + eqn2, *([expected1] + operands2),
                            backend=backend2, modulo_total=True)[0]

    with interpretation(normalize):
        # NOTE(review): this list comprehension is truncated in this copy. The
        # opening of each element (presumably a distribution built from a funsor
        # Tensor wrapping ``operand``) and the closing ``]`` are missing, so the
        # function does not parse as shown. Restore from the upstream source.
        funsor_operands1 = [
                inputs=OrderedDict([(d, bint(sizes1[d])) for d in inp[:-1]])
            ))(value=Variable(inp[-1], bint(sizes1[inp[-1]]))).exp()
            for inp, operand in zip(inputs1, operands1)

        output1_naive = einsum_impl(eqn1, *funsor_operands1, backend=backend1)
        # Optimization runs under ``reflect`` (presumably so the rewritten term
        # stays lazy until the final reinterpret below).
        with interpretation(reflect):
            output1 = apply_optimizer(output1_naive) if optimize1 else output1_naive
        # The stage-1 result becomes the leading operand of the second einsum.
        output2_naive = einsum_impl(outputs1[0] + "," + eqn2, *([output1] + funsor_operands2), backend=backend2)
        with interpretation(reflect):
            output2 = apply_optimizer(output2_naive) if optimize2 else output2_naive

    # Evaluate both stages under the interpretation active at this point.
    actual1 = reinterpret(output1)
    actual2 = reinterpret(output2)

    assert torch.allclose(expected1, actual1.data)
    assert torch.allclose(expected2, actual2.data)
Example #2
def test_normalize_einsum(equation, plates, backend, einsum_impl):
    """Normalizing a lazily built einsum yields a stable Contraction of equal value.

    The reflected einsum expression is reinterpreted under ``normalize``. The
    result must:

    * be a ``Contraction`` checked against the original inputs/output,
    * have only ``Number``/``Tensor``/``Contraction`` terms,
    * be returned unchanged (same object) by a second normalization,
    * evaluate eagerly to the same value as the raw expression,
    * round-trip through ``quote``/``eval``.
    """
    _, _, _, _, funsor_operands = make_einsum_example(equation)

    with interpretation(reflect):
        lazy_expr = einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)

    with interpretation(normalize):
        normal_form = reinterpret(lazy_expr)

    assert isinstance(normal_form, Contraction)
    check_funsor(normal_form, lazy_expr.inputs, lazy_expr.output)
    allowed_term_types = (Number, Tensor, Contraction)
    assert all(isinstance(term, allowed_term_types) for term in normal_form.terms)

    # A second normalization pass must hand back the very same object.
    with interpretation(normalize):
        renormalized = reinterpret(normal_form)
    assert renormalized is normal_form  # check normalization

    with interpretation(eager):
        actual = reinterpret(normal_form)
        expected = reinterpret(lazy_expr)
    assert_close(actual, expected, rtol=1e-4)

    # quote() must emit source that eval() turns back into an equal funsor.
    round_tripped = eval(quote(expected))  # requires torch, bint
    assert_close(round_tripped, expected)
Example #3
def apply_optimizer(x):
    """Rewrite ``x`` into an optimized form in two passes.

    First ``x`` is reinterpreted under ``unfold``. The unfolded term is then
    reinterpreted with an interpreter that applies the ``optimize`` rule
    dispatched for each term. When that rule returns ``None``, the interpreter
    falls back to building the term directly with ``cls(*args)``.
    """
    def nested_optimize_interpreter(cls, *args):
        rule = optimize.dispatch(cls, *args)
        optimized = rule(*args)
        return cls(*args) if optimized is None else optimized

    with interpreter.interpretation(unfold):
        unfolded = interpreter.reinterpret(x)

    with interpreter.interpretation(nested_optimize_interpreter):
        return interpreter.reinterpret(unfolded)
Example #4
def test_einsum_complete_sharing(equation, plates, backend, einsum_impl, same_lazy):
    """Equal lazy einsums reinterpreted inside one ``memoize()`` share their result.

    The two lazy expressions are the same object when ``same_lazy`` is set,
    otherwise two independently built copies. Inside a single ``memoize()``
    context both must reinterpret to the identical object. Reinterpreting
    outside that context must produce a distinct object.
    """
    _, _, _, _, funsor_operands = make_einsum_example(equation)

    def build_lazy():
        return einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)

    with interpretation(reflect):
        first_lazy = build_lazy()
        if same_lazy:
            second_lazy = first_lazy
        else:
            second_lazy = build_lazy()

    with memoize():
        memo_first = reinterpret(first_lazy)
        memo_second = reinterpret(second_lazy)
    unmemoized = reinterpret(first_lazy)

    assert memo_first is memo_second
    assert memo_first is not unmemoized
Example #5
def apply_optimizer(x):
    """Run ``x`` through a fixed pipeline of rewriting interpretations.

    The passes ``associate``, ``distribute``, ``optimize`` and ``desugar`` are
    applied in that order, each reinterpreting the previous pass's output.
    The final term is then reinterpreted once more under whatever
    interpretation is active at the call site.
    """
    for rewrite_pass in (associate, distribute, optimize, desugar):
        with interpretation(rewrite_pass):
            x = reinterpret(x)

    return reinterpret(x)  # use previous interpretation
Example #6
def naive_contract_einsum(eqn, *terms, **kwargs):
    """
    Use for testing Contract against einsum

    Evaluates ``eqn`` (e.g. ``"ab,bc->ac"``) over funsor ``terms`` by building a
    single ``Contract``. Its reduced variables are every input dimension that
    does not appear in the output. The expression is built under the
    ``optimize`` interpretation and then reinterpreted under the caller's
    active interpretation.

    :param str eqn: einsum equation with exactly one output (no plates).
    :param terms: one ``Funsor`` per comma-separated input in ``eqn``.
    :param str backend: ``'torch'`` (sum-product with ``ops.add``/``ops.mul``)
        or ``'pyro.ops.einsum.torch_log'`` (``ops.logaddexp``/``ops.add``).
        Defaults to ``'torch'``.
    :raises ValueError: for any other backend.
    """
    assert "plates" not in kwargs

    backend = kwargs.pop('backend', 'torch')
    if backend == 'torch':
        sum_op, prod_op = ops.add, ops.mul
    # NOTE(review): this tuple was truncated in the copy being repaired; only
    # the visible entry is kept. Confirm whether further log-space backends
    # belong here.
    elif backend in ('pyro.ops.einsum.torch_log',):
        sum_op, prod_op = ops.logaddexp, ops.add
    else:
        # Previously unconditional (missing ``else``), which rejected every backend.
        raise ValueError("{} backend not implemented".format(backend))

    assert isinstance(eqn, str)
    assert all(isinstance(term, Funsor) for term in terms)
    inputs, output = eqn.split('->')
    inputs = inputs.split(',')
    assert len(inputs) == len(terms)
    assert len(output.split(',')) == 1
    input_dims = frozenset(d for inp in inputs for d in inp)
    output_dims = frozenset(output)
    # Dimensions that are summed out: present in some input but not the output.
    reduced_vars = input_dims - output_dims

    with interpretation(optimize):
        rhs = Finitary(prod_op, tuple(terms))
        lhs = _make_base_lhs(prod_op, rhs, reduced_vars, normalized=False)
        assert frozenset(lhs.inputs) == reduced_vars
        result = Contract(sum_op, prod_op, lhs, rhs, reduced_vars)

    return reinterpret(result)
Example #7
def test_einsum(equation, backend):
    """Check ``naive_einsum`` and its optimized form against ``opt_einsum``.

    The naive funsor einsum is built under ``reflect``, rewritten with
    ``apply_optimizer`` and evaluated. It must match the directly evaluated
    naive einsum, which in turn must match ``opt_einsum.contract`` in shape,
    values and output-dimension sizes.
    """
    # Previously the call was left unclosed (a syntax error); it now takes just
    # ``equation``, like the other einsum tests.
    _, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = opt_einsum.contract(equation, *operands, backend=backend)

    with interpretation(reflect):
        naive_ast = naive_einsum(equation, *funsor_operands, backend=backend)
        optimized_ast = apply_optimizer(naive_ast)
    print("Naive expression: {}".format(naive_ast))
    print("Optimized expression: {}".format(optimized_ast))
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_einsum(equation, *funsor_operands, backend=backend)

    assert isinstance(actual, funsor.Tensor) and len(outputs) == 1
    if len(outputs[0]) > 0:
        # Align both results to the equation's output order before comparing.
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    assert_close(actual, actual_optimized, atol=1e-4)
    assert expected.shape == actual.data.shape
    assert_close(expected, actual.data, rtol=1e-5, atol=1e-8)
    for output in outputs:
        for output_dim in output:
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
Example #8
def test_plated_einsum(equation, plates, backend):
    """Check the plated naive einsum and its optimized form against ``pyro_einsum``.

    NOTE(review): several calls in this copy are truncated (their arguments and
    closing parentheses are missing), so the function does not parse as shown.
    Restore from the upstream source.
    """
    # NOTE(review): the two calls below are truncated after their first argument.
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(
    expected = pyro_einsum(equation,
    with interpretation(reflect):
        # NOTE(review): truncated call; remaining arguments are missing.
        naive_ast = naive_plated_einsum(equation,
        optimized_ast = apply_optimizer(naive_ast)
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    # NOTE(review): truncated call; remaining arguments are missing.
    actual = naive_plated_einsum(equation,

    if len(outputs[0]) > 0:
        # Align both results to the equation's output order before comparing.
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    # NOTE(review): the start of this statement is missing; it looks like the
    # tail of a closeness check between ``actual`` and ``actual_optimized``.
                 atol=1e-3 if backend == 'torch' else 1e-4)

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
def _check_mvn_affine(d1, data):
    """Assert that a ``MultivariateNormal`` reinterprets to an equivalent ``GaussianMixture``.

    ``d1`` is reinterpreted under the active interpretation. The result's type
    must be ``GaussianMixture`` or a subclass of it, and substituting ``data``
    into it must give the same value as substituting ``data`` into ``d1``.
    """
    assert isinstance(d1, dist.MultivariateNormal)
    mixture = reinterpret(d1)
    assert issubclass(type(mixture), GaussianMixture)
    mixture_value = mixture(**data)
    assert_close(mixture_value, d1(**data))
Example #10
def _check_mvn_affine(d1, data):
    """Backend-agnostic check that a ``MultivariateNormal`` reinterprets to a ``GaussianMixture``.

    The distributions module is resolved from the active backend through
    ``BACKEND_TO_DISTRIBUTIONS_BACKEND``. ``d1`` must be that module's
    ``MultivariateNormal``. After reinterpretation its type must be
    ``GaussianMixture`` or a subclass of it, and substituting ``data`` into it
    must match substituting ``data`` into ``d1``.
    """
    distributions = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    assert isinstance(d1, distributions.MultivariateNormal)
    mixture = reinterpret(d1)
    assert issubclass(type(mixture), GaussianMixture)
    mixture_value = mixture(**data)
    assert_close(mixture_value, d1(**data))
Example #11
def test_nested_complete_sharing_direct():
    """Structurally identical lazy reductions reinterpret to the same object under ``memoize``.

    Two copies of the chain "``ab * bc`` summed over ``a, b``, times ``cd``,
    summed over ``c``" are built independently under ``reflect``. Their
    intermediate and final terms must be shared once reinterpreted inside one
    ``memoize()`` context.
    """
    _, _, _, _, (ab, bc, cd) = make_einsum_example("ab,bc,cd->d")

    def build_chain():
        partial = (ab * bc).reduce(ops.add, frozenset({"a", "b"}))
        total = (partial * cd).reduce(ops.add, frozenset({"c"}))
        return partial, total

    # avoids the complicated internal interpreter usage of the nested optimized einsum tests above
    with interpretation(reflect):
        c1, d1 = build_chain()
        # this does not trigger a second alpha-renaming
        c2, d2 = build_chain()

    with memoize():
        assert reinterpret(c1) is reinterpret(c2)
        assert reinterpret(d1) is reinterpret(d2)
Example #12
def test_match_binary():
    """Rewriting a matched binary expression preserves its eager value.

    ``expr`` has the shape of ``pattern`` (``a + 2 * b``) with ``a = 1`` and
    ``b = 3 - 4``. Both rewrites below compute ``a + b + b``, which equals
    ``a + 2 * b``, so reinterpreting either one must give the same value as
    evaluating ``expr`` itself.
    """
    with interpretation(lazy):
        # Documents the structure ``expr`` is destructured against below.
        pattern = Variable('a', reals()) + Number(2.) * Variable('b', reals())
        expr = Number(1.) + Number(2.) * (Number(3.) - Number(4.))

    def expand_2_vars(a, b):
        return a + b + b

    def expand_2_walk(x):
        return x.lhs + x.rhs.rhs + x.rhs.rhs

    eager_val = reinterpret(expr)

    # Bind the pattern variables explicitly. ``a`` is the left summand and
    # ``b`` is the right factor of the ``2 * b`` product. Calling
    # ``expand_2_vars(expr)`` with a single argument raises TypeError.
    lazy_val = expand_2_vars(expr.lhs, expr.rhs.rhs)
    assert eager_val == reinterpret(lazy_val)

    lazy_val_2 = expand_2_walk(expr)
    assert eager_val == reinterpret(lazy_val_2)
Example #13
def main(args):
    """Train the transition/emission probabilities of a discrete HMM (funsor example).

    ``args`` is read for ``hidden_dim``, ``lazy``, ``time_steps``,
    ``learning_rate`` and ``train_steps``.

    NOTE(review): this copy is truncated in several places (see the inline
    notes), so the function does not parse as shown.
    """

    # Declare parameters.
    trans_probs = torch.tensor([[0.2, 0.8], [0.7, 0.3]], requires_grad=True)
    emit_probs = torch.tensor([[0.4, 0.6], [0.1, 0.9]], requires_grad=True)
    params = [trans_probs, emit_probs]

    # A discrete HMM model.
    def model(data):
        log_prob = funsor.to_funsor(0.)

        # NOTE(review): the two constructor calls below are truncated in this
        # copy. The probability tensor argument (presumably ``trans_probs`` /
        # ``emit_probs``) and the closing parentheses are missing.
        trans = dist.Categorical(probs=funsor.Tensor(
            inputs=OrderedDict([('prev', funsor.Bint[args.hidden_dim])]),

        emit = dist.Categorical(probs=funsor.Tensor(
            inputs=OrderedDict([('latent', funsor.Bint[args.hidden_dim])]),

        x_curr = funsor.Number(0, args.hidden_dim)
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            # NOTE(review): the domain argument and closing ``)`` are missing here.
            x_curr = funsor.Variable('x_{}'.format(t),
            log_prob += trans(prev=x_prev, value=x_curr)

            # In eager mode, sum out the previous state once it is a free variable.
            if not args.lazy and isinstance(x_prev, funsor.Variable):
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            log_prob += emit(latent=x_curr, value=funsor.Tensor(y, dtype=2))

        # Sum out all remaining free variables.
        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob

    # Train model parameters.
    data = torch.ones(args.time_steps, dtype=torch.long)
    optim = torch.optim.Adam(params, lr=args.learning_rate)
    for step in range(args.train_steps):
        if args.lazy:
            with interpretation(lazy):
                log_prob = apply_optimizer(model(data))
            log_prob = reinterpret(log_prob)
            # NOTE(review): an ``else:`` appears to be missing before the next
            # line. As written, the eager model overwrites the lazy result.
            log_prob = model(data)
        assert not log_prob.inputs, 'free variables remain'
        loss = -log_prob.data
        # NOTE(review): no ``loss.backward()`` / ``optim.step()`` is visible, so
        # ``params`` are never updated as shown.
Example #14
def main(args):
    """Fit the noise scales of a 1-D Gaussian HMM by maximizing its marginal likelihood.

    ``args`` is read for ``time_steps``, ``learning_rate``, ``train_steps``,
    ``lazy`` and ``verbose``.

    * With ``args.lazy``, the whole model is built lazily, rewritten with
      ``apply_optimizer`` and then evaluated.
    * Otherwise, each previous latent state is summed out eagerly inside the
      model loop.
    """
    # Declare parameters.
    trans_noise = torch.tensor(0.1, requires_grad=True)
    emit_noise = torch.tensor(0.5, requires_grad=True)
    params = [trans_noise, emit_noise]

    # A Gaussian HMM model.
    def model(data):
        log_prob = funsor.to_funsor(0.)

        x_curr = funsor.Tensor(torch.tensor(0.))
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t), funsor.Real)
            log_prob += dist.Normal(1 + x_prev / 2., trans_noise, value=x_curr)

            # Optionally marginalize out the previous state.
            if t > 0 and not args.lazy:
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            # An observe statement.
            log_prob += dist.Normal(0.5 + 3 * x_curr, emit_noise, value=y)

        # Marginalize out all remaining delayed variables.
        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob

    # Train model parameters.
    data = torch.randn(args.time_steps)
    optim = torch.optim.Adam(params, lr=args.learning_rate)
    for step in range(args.train_steps):
        if args.lazy:
            with interpretation(lazy):
                log_prob = apply_optimizer(model(data))
            log_prob = reinterpret(log_prob)
        else:
            # Previously unconditional, which discarded the lazy result.
            log_prob = model(data)
        assert not log_prob.inputs, 'free variables remain'
        loss = -log_prob.data
        if args.verbose and step % 10 == 0:
            print('step {} loss = {}'.format(step, loss.item()))

        # Update model parameters (previously the optimizer was never stepped).
        optim.zero_grad()
        loss.backward()
        optim.step()
Example #15
def substitute(expr, subs):
    """Substitute ``subs`` into ``expr`` by rebuilding it term by term.

    ``subs`` may be a dict/OrderedDict or a tuple of ``(name, value)`` pairs.
    Each rebuilt term is constructed under the interpretation that is active
    when ``substitute`` is called. Only the pairs whose name appears in that
    term's ``fresh`` set are then applied, via its ``eager_subs`` method
    (wrapped by ``interpreter.debug_logged``).
    """
    if isinstance(subs, (dict, OrderedDict)):
        subs = tuple(subs.items())
    assert isinstance(subs, tuple)

    base_interpretation = interpreter._INTERPRETATION

    @interpreter.interpretation(base_interpretation)  # use base
    def subs_interpreter(cls, *args):
        term = cls(*args)
        applicable = tuple((name, value) for name, value in subs if name in term.fresh)
        if not applicable:
            return term
        return interpreter.debug_logged(term.eager_subs)(applicable)

    with interpreter.interpretation(subs_interpreter):
        return interpreter.reinterpret(expr)
Example #16
def test_cat_slice_tensor(start, stop, step):
    """Slicing a ``Cat`` of random tensors along ``t`` agrees eagerly and lazily.

    The lazy path goes through ``Cat.eager_subs`` when the lazy term is
    reinterpreted.
    """
    chunk_sizes = [2, 1, 3, 4, 1, 3]
    terms = tuple(random_tensor(OrderedDict(t=bint(size), a=bint(2))) for size in chunk_sizes)
    total = sum(term.inputs['t'].dtype for term in terms)
    t_slice = Slice('t', start, stop, step, total)

    # eager
    expected = Cat('t', terms)(t=t_slice)

    # lazy - exercise Cat.eager_subs
    with interpretation(lazy):
        lazy_result = Cat('t', terms)(t=t_slice)

    assert_close(reinterpret(lazy_result), expected)
Example #17
def test_optimized_einsum(equation, backend, einsum_impl):
    """An optimized normalized einsum evaluates to the same tensor as ``pyro_einsum``.

    The einsum is built under ``normalize``, rewritten by ``apply_optimizer``
    and evaluated. It must match the reference in type, shape, values and
    output-dimension sizes.
    """
    _, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = pyro_einsum(equation, *operands, backend=backend)[0]

    with interpretation(normalize):
        normalized_ast = einsum_impl(equation, *funsor_operands, backend=backend)
    result = reinterpret(apply_optimizer(normalized_ast))  # eager by default

    assert isinstance(result, funsor.Tensor)
    assert len(outputs) == 1
    output_names = outputs[0]
    if len(output_names) > 0:
        result = result.align(tuple(output_names))

    assert expected.shape == result.data.shape
    assert torch.allclose(expected, result.data)
    for output in outputs:
        for dim_name in output:
            assert dim_name in result.inputs
            assert result.inputs[dim_name].dtype == sizes[dim_name]
Example #18
def test_distribute_reduce(lhs_vars, rhs_vars):
    """A product of separately summed tensors equals the sum of their renamed product.

    ``actual`` multiplies ``lhs`` summed over ``lhs_vars`` by ``rhs`` summed
    over ``rhs_vars``; both reductions are built under ``reflect`` and the
    product is then reinterpreted. ``expected`` renames each side's reduced
    variables via ``gensym``, multiplies, and sums over all renamed variables
    at once.
    """
    lhs_vars, rhs_vars = frozenset(lhs_vars), frozenset(rhs_vars)
    lhs = random_tensor(OrderedDict([('i', bint(3)), ('j', bint(2))]), reals())
    rhs = random_tensor(OrderedDict([('i', bint(3)), ('j', bint(2))]), reals())

    with interpretation(reflect):
        actual_lhs = lhs.reduce(ops.add, lhs_vars) if lhs_vars else lhs
        actual_rhs = rhs.reduce(ops.add, rhs_vars) if rhs_vars else rhs

    actual = reinterpret(actual_lhs * actual_rhs)

    lhs_subs = {v: gensym(v) for v in lhs_vars}
    rhs_subs = {v: gensym(v) for v in rhs_vars}
    renamed_vars = frozenset(lhs_subs.values()) | frozenset(rhs_subs.values())
    # ``reduce`` takes the reduction op first. Previously the variable set was
    # passed in the op position and no op was given.
    expected = (lhs(**lhs_subs) * rhs(**rhs_subs)).reduce(ops.add, renamed_vars)

    assert_close(actual, expected)
Example #19
def test_einsum_categorical(equation):
    """Check einsums over categorical-probability funsors against ``opt_einsum``.

    NOTE(review): this copy is truncated in several places (see the inline
    notes), so the function does not parse as shown.
    """
    if get_backend() == "jax":
        from funsor.jax.distributions import Categorical
        # NOTE(review): an ``else:`` appears to be missing before the next line.
        # As written, the torch import always runs and rebinds ``Categorical``.
        from funsor.torch.distributions import Categorical

    inputs, outputs, sizes, operands, _ = make_einsum_example(equation)
    # Absolute values divided by their sum along the last axis.
    operands = [ops.abs(operand) / ops.abs(operand).sum(-1)[..., None]
                for operand in operands]

    # NOTE(review): this call is truncated; remaining arguments are missing.
    expected = opt_einsum.contract(equation, *operands,

    with interpretation(reflect):
        # NOTE(review): this comprehension is truncated. The opening of each
        # element (presumably ``Categorical`` over a funsor Tensor wrapping
        # ``operand``) and the closing ``]`` are missing.
        funsor_operands = [
                inputs=OrderedDict([(d, Bint[sizes[d]]) for d in inp[:-1]])
            ))(value=Variable(inp[-1], Bint[sizes[inp[-1]]])).exp()
            for inp, operand in zip(inputs, operands)

        naive_ast = naive_einsum(equation, *funsor_operands)
        optimized_ast = apply_optimizer(naive_ast)

    print("Naive expression: {}".format(naive_ast))
    print("Optimized expression: {}".format(optimized_ast))
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    # The operands were built under ``reflect``, so reinterpret them before the naive einsum.
    actual = naive_einsum(equation, *map(reinterpret, funsor_operands))

    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    assert_close(actual, actual_optimized, atol=1e-4)

    assert expected.shape == actual.data.shape
    assert_close(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
Example #20
def adjoint(expr, targets, start=Number(0.)):
    """Propagate adjoints from ``expr`` back to ``targets`` through a recorded tape.

    ``expr`` is reinterpreted under an ``AdjointTape`` recorder, and the
    resulting term is seeded with adjoint ``start``. Tape entries are then
    popped one at a time. For each entry, ``adjoint_ops`` maps the output's
    adjoint to per-input adjoints. These are accumulated with ``+`` (a product
    in log space), and the number of contributions per input is counted.

    Each target's accumulated adjoint is divided by its contribution count.
    Targets that are not ``Variable`` additionally have their own value added.

    Returns a dict mapping each target to its adjoint.
    """
    adjoint_values = defaultdict(lambda: Number(0.))  # 1 in logspace
    multiplicities = defaultdict(int)

    recorder = AdjointTape()
    with interpretation(recorder):
        adjoint_values[reinterpret(expr)] = start

    while recorder.tape:
        output, fn, inputs = recorder.tape.pop()
        output_adjoint = adjoint_values[output]
        contributions = adjoint_ops(fn, output_adjoint, output, *inputs)
        for node, contribution in contributions.items():
            multiplicities[node] += 1
            adjoint_values[node] = adjoint_values[node] + contribution  # product in logspace

    result = {}
    for target in targets:
        scaled = adjoint_values[target] / multiplicities[target]
        result[target] = scaled if isinstance(target, Variable) else scaled + target
    return result
Example #21
def test_einsum_categorical(equation):
    """Check einsums over categorical-probability funsors against ``opt_einsum`` (torch).

    NOTE(review): the ``funsor_operands`` comprehension is truncated in this
    copy, so the function does not parse as shown.
    """
    inputs, outputs, sizes, operands, _ = make_einsum_example(equation)
    # Absolute values divided by their sum along the last axis.
    operands = [operand.abs() / operand.abs().sum(-1, keepdim=True)
                for operand in operands]

    expected = opt_einsum.contract(equation, *operands, backend='torch')

    with interpretation(reflect):
        # NOTE(review): this comprehension is truncated. The opening of each
        # element (presumably a categorical distribution over a funsor Tensor
        # wrapping ``operand``) and the closing ``]`` are missing.
        funsor_operands = [
                inputs=OrderedDict([(d, bint(sizes[d])) for d in inp[:-1]])
            ))(value=Variable(inp[-1], bint(sizes[inp[-1]]))).exp()
            for inp, operand in zip(inputs, operands)

        naive_ast = naive_einsum(equation, *funsor_operands)
        optimized_ast = apply_optimizer(naive_ast)

    print("Naive expression: {}".format(naive_ast))
    print("Optimized expression: {}".format(optimized_ast))
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    # The operands were built under ``reflect``, so reinterpret them before the naive einsum.
    actual = naive_einsum(equation, *map(reinterpret, funsor_operands))

    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    assert_close(actual, actual_optimized, atol=1e-4)

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
Example #22
def einsum(eqn, *terms, **kwargs):
    """Evaluate a (possibly plated) einsum equation over funsor ``terms``.

    The naive expression from ``naive_plated_einsum`` is built under
    ``reflect`` and rewritten by ``apply_optimizer``. The optimized term is
    then reinterpreted under the caller's active interpretation.
    """
    with interpretation(reflect):
        optimized_ast = apply_optimizer(naive_plated_einsum(eqn, *terms, **kwargs))
    return reinterpret(optimized_ast)  # eager by default