예제 #1
0
    def __call__(cls, *args, **kwargs):
        """Create a distribution funsor from positional and keyword parameters.

        Positional ``args`` are paired with ``cls._ast_fields`` and merged into
        ``kwargs``. ``cls._fill_defaults`` then produces the full argument
        list, which ``numbers_to_tensors`` converts before construction.

        If a ``value`` argument was given, the instance is built under the
        current interpretation; otherwise it is built under ``lazy``.
        """
        kwargs.update(zip(cls._ast_fields, args))
        params = numbers_to_tensors(*cls._fill_defaults(**kwargs))
        build = super(DistributionMeta, cls).__call__

        if 'value' not in kwargs:
            # No value supplied: construct lazily, which keeps building
            # observations cheap in minipyro.
            with interpretation(lazy):
                return build(*params)

        # Explicit value: evaluate under whatever interpretation is active.
        return build(*params)
예제 #2
0
def test_sequential_sum_product_adjoint(impl, sum_op, prod_op, batch_inputs,
                                        state_domain, num_steps):
    """Check the forward and adjoint (backward) passes of a sequential
    sum-product ``impl`` against an unrolled ``sum_product`` contraction.

    The arguments are presumably supplied by pytest parametrization (not
    visible here). ``batch_inputs`` must support ``.copy()`` and compare
    equal to a plain dict of input domains.
    """
    # test mostly copied from test_sum_product.py
    inputs = OrderedDict(batch_inputs)
    inputs.update(prev=state_domain, curr=state_domain)
    inputs["time"] = bint(num_steps)
    # Random transition factor: Gaussian for real-valued states, otherwise a
    # random tensor.
    if state_domain.dtype == "real":
        trans = random_gaussian(inputs)
    else:
        trans = random_tensor(inputs)
    time = Variable("time", bint(num_steps))

    # Record the sequential implementation on a tape so its adjoint can be
    # computed later.
    with AdjointTape() as actual_tape:
        actual = impl(sum_op, prod_op, trans, time, {"prev": "curr"})

    # The time dimension must be eliminated; only batch, prev and curr remain.
    expected_inputs = batch_inputs.copy()
    expected_inputs.update(prev=state_domain, curr=state_domain)
    assert dict(actual.inputs) == expected_inputs

    # Check against contract.
    # Unroll into one operand per step; "t_{t}" names the state between
    # step t-1 and step t.
    operands = tuple(
        trans(time=t, prev="t_{}".format(t), curr="t_{}".format(t + 1))
        for t in range(num_steps))
    # Interior states t_1 .. t_{num_steps-1} are summed out; t_0 and
    # t_{num_steps} stay free.
    reduce_vars = frozenset("t_{}".format(t) for t in range(1, num_steps))
    with AdjointTape() as expected_tape:
        with interpretation(reflect):
            expected = sum_product(sum_op, prod_op, operands, reduce_vars)
        expected = apply_optimizer(expected)
        # Rename the free boundary states to match actual's prev/curr inputs.
        expected = expected(**{
            "t_0": "prev",
            "t_{}".format(num_steps): "curr"
        })
        expected = expected.align(tuple(actual.inputs.keys()))

    # check forward pass (sanity check)
    assert_close(actual, expected, rtol=5e-4 * num_steps)

    # perform backward passes only after the sanity check
    expected_bwds = expected_tape.adjoint(sum_op, prod_op, expected, operands)
    actual_bwd = actual_tape.adjoint(sum_op, prod_op, actual, (trans, ))[trans]

    # check backward pass
    # The actual adjoint w.r.t. trans keeps a "time" input; slice it per step
    # and rename prev/curr to match the corresponding unrolled operand.
    for t, operand in enumerate(operands):
        actual_bwd_t = actual_bwd(time=t,
                                  prev="t_{}".format(t),
                                  curr="t_{}".format(t + 1))
        expected_bwd = expected_bwds[operand].align(
            tuple(actual_bwd_t.inputs.keys()))
        check_funsor(actual_bwd_t, expected_bwd.inputs, expected_bwd.output)
        assert_close(actual_bwd_t, expected_bwd, rtol=5e-4 * num_steps)
예제 #3
0
def main(args):
    """Fit a Normal guide to a conjugate Normal model with SVI.

    Reads ``args.backend`` (Pyro backend name), ``args.jit`` (use
    ``JitTrace_ELBO`` instead of ``Trace_ELBO``), ``args.learning_rate``,
    ``args.num_steps`` and ``args.verbose``. Raises ``AssertionError`` if the
    fitted ``guide_loc`` is not within 0.1 of 3.0.
    """
    funsor.set_backend("torch")

    # Define a basic model with a single Normal latent random variable `loc`
    # and a batch of Normally distributed observations.
    def model(data):
        loc = pyro.sample("loc", dist.Normal(0., 1.))
        with pyro.plate("data", len(data), dim=-1):
            pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

    # Define a guide (i.e. variational distribution) with a Normal
    # distribution over the latent random variable `loc`.
    def guide(data):
        guide_loc = pyro.param("guide_loc", torch.tensor(0.))
        guide_scale = pyro.param("guide_scale", torch.tensor(1.),
                                 constraint=constraints.positive)
        pyro.sample("loc", dist.Normal(guide_loc, guide_scale))

    # Generate some data.
    # 100 standard-normal draws shifted by +3.0, seeded for reproducibility.
    torch.manual_seed(0)
    data = torch.randn(100) + 3.0

    # Because the API in minipyro matches that of Pyro proper,
    # training code works with generic Pyro implementations.
    with pyro_backend(args.backend), interpretation(MonteCarlo()):
        # Construct an SVI object so we can do variational inference on our
        # model/guide pair.
        Elbo = infer.JitTrace_ELBO if args.jit else infer.Trace_ELBO
        elbo = Elbo()
        adam = optim.Adam({"lr": args.learning_rate})
        svi = infer.SVI(model, guide, adam, elbo)

        # Basic training loop
        # Start from an empty param store so parameters left over from an
        # earlier run are not reused.
        pyro.get_param_store().clear()
        for step in range(args.num_steps):
            loss = svi.step(data)
            if args.verbose and step % 100 == 0:
                print("step {} loss = {}".format(step, loss))

        # Report the final values of the variational parameters
        # in the guide after training.
        if args.verbose:
            for name in pyro.get_param_store():
                value = pyro.param(name).data
                print("{} = {}".format(name, value.detach().cpu().numpy()))

        # For this simple (conjugate) model we know the exact posterior. In
        # particular we know that the variational distribution should be
        # centered near 3.0. So let's check this explicitly.
        assert (pyro.param("guide_loc") - 3.0).abs() < 0.1
예제 #4
0
def einsum(eqn, *terms, **kwargs):
    r"""
    Top-level interface for optimized tensor variable elimination.

    The naive plated einsum is first built under the ``lazy`` interpretation,
    and the resulting expression is then passed through ``apply_optimizer``.

    :param str eqn: An einsum equation.
    :param funsor.terms.Funsor \*terms: One or more operands.
    :param set plates: Optional keyword argument denoting which funsor
        dimensions are plate dimensions. Among all input dimensions (from
        terms): dimensions in plates but not in outputs are product-reduced;
        dimensions in neither plates nor outputs are sum-reduced.
    :param \**kwargs: Keyword arguments (including ``plates``) forwarded
        unchanged to ``naive_plated_einsum``.
    :return: The expression returned by ``apply_optimizer``.
    """
    with interpretation(lazy):
        naive_ast = naive_plated_einsum(eqn, *terms, **kwargs)
    return apply_optimizer(naive_ast)
예제 #5
0
파일: hmm.py 프로젝트: ordabayevy/funsor
    def __init__(self,
                 initial_dist,
                 transition_matrix,
                 transition_dist,
                 observation_matrix,
                 observation_dist,
                 validate_args=None):
        """Build a Gaussian HMM whose joint density is a lazily built funsor.

        :param initial_dist: ``torch.distributions.MultivariateNormal`` over
            the initial hidden state, with event shape ``(hidden_dim,)``.
        :param transition_matrix: ``torch.Tensor`` whose trailing shape is
            ``(hidden_dim, hidden_dim)``.
        :param transition_dist: ``MultivariateNormal`` with event shape
            ``(hidden_dim,)``.
        :param observation_matrix: ``torch.Tensor`` whose trailing shape is
            ``(hidden_dim, obs_dim)``; ``obs_dim`` must be at least
            ``hidden_dim // 2``.
        :param observation_dist: ``MultivariateNormal`` with event shape
            ``(obs_dim,)``.
        :param validate_args: Forwarded to the base-class constructor.

        All batch shapes are broadcast together, and the last broadcast
        dimension is treated as time. The resulting event shape is therefore
        ``(num_steps, obs_dim)``. Violations of the above are reported via
        ``assert``.
        """
        assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
        assert isinstance(transition_matrix, torch.Tensor)
        assert isinstance(transition_dist,
                          torch.distributions.MultivariateNormal)
        assert isinstance(observation_matrix, torch.Tensor)
        assert isinstance(observation_dist,
                          torch.distributions.MultivariateNormal)
        hidden_dim, obs_dim = observation_matrix.shape[-2:]
        assert obs_dim >= hidden_dim // 2, "obs_dim must be at least half of hidden_dim"
        assert initial_dist.event_shape == (hidden_dim, )
        assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
        assert transition_dist.event_shape == (hidden_dim, )
        assert observation_dist.event_shape == (obs_dim, )
        # The initial distribution has no time dimension; appending (1,) lets
        # it broadcast against the time-indexed parameters.
        shape = broadcast_shape(initial_dist.batch_shape + (1, ),
                                transition_matrix.shape[:-2],
                                transition_dist.batch_shape,
                                observation_matrix.shape[:-2],
                                observation_dist.batch_shape)
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + (obs_dim, )

        # Convert distributions to funsors.
        init = dist_to_funsor(initial_dist)(value="state")
        trans = matrix_and_mvn_to_funsor(transition_matrix, transition_dist,
                                         ("time", ), "state", "state(time=1)")
        obs = matrix_and_mvn_to_funsor(observation_matrix, observation_dist,
                                       ("time", ), "state(time=1)", "value")
        dtype = "real"

        # Construct the joint funsor.
        with interpretation(lazy):
            # A free "value" of shape (num_steps, obs_dim); indexing it by the
            # "time" input gives each step its own observation.
            value = Variable("value", Reals[time_shape[0], obs_dim])
            result = trans + obs(value=value["time"])
            # Presumably eliminates "time" by chaining "state" -> "state(time=1)"
            # across steps (logaddexp to marginalize, add to combine).
            result = MarkovProduct(ops.logaddexp, ops.add, result, "time",
                                   {"state": "state(time=1)"})
            result = init + result.reduce(ops.logaddexp, "state(time=1)")
            funsor_dist = result.reduce(ops.logaddexp, "state")

        super(GaussianHMM, self).__init__(funsor_dist, batch_shape,
                                          event_shape, dtype, validate_args)
        self.hidden_dim = hidden_dim
        self.obs_dim = obs_dim
예제 #6
0
def test_sequential_sum_product_multi(impl, x_domain, y_domain, batch_inputs,
                                      num_steps):
    """Check a two-variable sequential sum-product ``impl`` against an
    unrolled ``sum_product`` contraction, under ``moment_matching``.

    Each time step chains two state variables, ``x`` and ``y``. When
    ``num_steps`` is ``None``, the transition factor gets no ``"time"`` input
    and a single step is used.

    The arguments are presumably supplied by pytest parametrization (not
    visible here).
    """
    sum_op = ops.logaddexp
    prod_op = ops.add
    inputs = OrderedDict(batch_inputs)
    inputs.update(x_prev=x_domain,
                  x_curr=x_domain,
                  y_prev=y_domain,
                  y_curr=y_domain)
    if num_steps is None:
        num_steps = 1
    else:
        inputs["time"] = bint(num_steps)
    # Any real-valued input requires a Gaussian factor.
    if any(v.dtype == "real" for v in inputs.values()):
        trans = random_gaussian(inputs)
    else:
        trans = random_tensor(inputs)
    time = Variable("time", bint(num_steps))
    step = {"x_prev": "x_curr", "y_prev": "y_curr"}

    with interpretation(moment_matching):
        actual = impl(sum_op, prod_op, trans, time, step)
        expected_inputs = batch_inputs.copy()
        expected_inputs.update(x_prev=x_domain,
                               x_curr=x_domain,
                               y_prev=y_domain,
                               y_curr=y_domain)
        assert dict(actual.inputs) == expected_inputs

        # Check against contract: unroll into one operand per step, where
        # "x_{t}" / "y_{t}" name the states between step t-1 and step t.
        operands = tuple(
            trans(time=t,
                  x_prev="x_{}".format(t),
                  x_curr="x_{}".format(t + 1),
                  y_prev="y_{}".format(t),
                  y_curr="y_{}".format(t + 1)) for t in range(num_steps))
        # Interior states are summed out; the first and last stay free.
        reduce_vars = frozenset("x_{}".format(t)
                                for t in range(1, num_steps)).union(
                                    "y_{}".format(t)
                                    for t in range(1, num_steps))
        expected = sum_product(sum_op, prod_op, operands, reduce_vars)
        # Rename the free boundary states to match actual's inputs.
        expected = expected(
            **{
                "x_0": "x_prev",
                "x_{}".format(num_steps): "x_curr",
                "y_0": "y_prev",
                "y_{}".format(num_steps): "y_curr"
            })
        expected = expected.align(tuple(actual.inputs.keys()))
        # Compare the sequential result with the unrolled contraction
        # (tolerance scales with num_steps).
        assert_close(actual, expected, rtol=5e-4 * num_steps)
예제 #7
0
def main(args):
    """Train the noise parameters of a 1-D Gaussian HMM by maximum likelihood.

    Reads ``args.lazy`` (build the model lazily and optimize it before
    evaluation, instead of marginalizing eagerly), ``args.time_steps``,
    ``args.learning_rate``, ``args.train_steps`` and ``args.verbose``.
    """
    funsor.set_backend("torch")

    # Declare parameters.
    trans_noise = torch.tensor(0.1, requires_grad=True)
    emit_noise = torch.tensor(0.5, requires_grad=True)
    params = [trans_noise, emit_noise]

    # A Gaussian HMM model.
    def model(data):
        log_prob = funsor.to_funsor(0.)

        # The state before the first step is the constant 0.
        x_curr = funsor.Tensor(torch.tensor(0.))
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t), funsor.Real)
            log_prob += dist.Normal(1 + x_prev / 2., trans_noise, value=x_curr)

            # Optionally marginalize out the previous state.
            # At t == 0, x_prev is the constant Tensor above, which has no
            # name to reduce over, hence the t > 0 guard.
            if t > 0 and not args.lazy:
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            # An observe statement.
            log_prob += dist.Normal(0.5 + 3 * x_curr, emit_noise, value=y)

        # Marginalize out all remaining delayed variables.
        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob

    # Train model parameters.
    torch.manual_seed(0)
    data = torch.randn(args.time_steps)
    optim = torch.optim.Adam(params, lr=args.learning_rate)
    for step in range(args.train_steps):
        optim.zero_grad()
        if args.lazy:
            # Build the whole expression lazily, optimize it, then evaluate.
            with interpretation(lazy):
                log_prob = apply_optimizer(model(data))
            log_prob = reinterpret(log_prob)
        else:
            log_prob = model(data)
        assert not log_prob.inputs, 'free variables remain'
        loss = -log_prob.data
        loss.backward()
        optim.step()
        if args.verbose and step % 10 == 0:
            print('step {} loss = {}'.format(step, loss.item()))
예제 #8
0
def test_mvn_sample(with_lazy, batch_shape, sample_inputs, event_shape):
    """Check sampling from a batched ``dist.MultivariateNormal`` funsor,
    constructed either lazily or eagerly depending on ``with_lazy``."""
    # Name up to three batch dims 'i', 'j', 'k' as bounded-integer inputs.
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

    loc = Tensor(torch.randn(batch_shape + event_shape), inputs)
    # Assumes event_shape is a tuple: `event_shape * 2` repeats it, e.g.
    # (3,) -> (3, 3), giving a square scale_tril per batch element.
    scale_tril = Tensor(_random_scale_tril(batch_shape + event_shape * 2),
                        inputs)
    with interpretation(lazy if with_lazy else eager):
        funsor_dist = dist.MultivariateNormal(loc, scale_tril)

    _check_sample(funsor_dist,
                  sample_inputs,
                  inputs,
                  atol=5e-2,
                  num_samples=200000)
예제 #9
0
def test_normal_sample(with_lazy, batch_shape, sample_inputs, reparametrized):
    """Check sampling from a batched Normal funsor, reparameterized or not,
    constructed either lazily or eagerly depending on ``with_lazy``."""
    # Name up to three batch dims 'i', 'j', 'k' as bounded-integer inputs.
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

    loc = Tensor(torch.randn(batch_shape), inputs)
    # NOTE(review): torch.rand samples from [0, 1), so a zero (invalid) scale
    # is possible in principle, though unlikely.
    scale = Tensor(torch.rand(batch_shape), inputs)
    with interpretation(lazy if with_lazy else eager):
        funsor_dist = (dist.Normal if reparametrized else
                       dist.NonreparameterizedNormal)(loc, scale)

    # The non-reparameterized variant is checked with a looser tolerance.
    _check_sample(funsor_dist,
                  sample_inputs,
                  inputs,
                  num_samples=200000,
                  atol=1e-2 if reparametrized else 1e-1)
예제 #10
0
def substitute(expr, subs):
    """Rebuild ``expr`` bottom-up, applying ``subs`` to every rebuilt term
    whose ``fresh`` set contains a substituted name.

    :param expr: Funsor expression to rebuild.
    :param subs: Either a dict/OrderedDict mapping names to values, or a tuple
        of ``(name, value)`` pairs.
    :return: The expression returned by ``interpreter.reinterpret`` under the
        substituting interpreter.
    """
    if isinstance(subs, (dict, OrderedDict)):
        pairs = tuple(subs.items())
    else:
        pairs = subs
    assert isinstance(pairs, tuple)

    # The interpretation captured from ``interpreter._INTERPRETATION`` here
    # (the "base" one) is what the nested interpreter runs under.
    @interpreter.interpretation(interpreter._INTERPRETATION)
    def subs_interpreter(cls, *args):
        term = cls(*args)
        applicable = tuple((name, value) for name, value in pairs
                           if name in term.fresh)
        if not applicable:
            return term
        return interpreter.debug_logged(term.eager_subs)(applicable)

    with interpreter.interpretation(subs_interpreter):
        return interpreter.reinterpret(expr)
예제 #11
0
def test_normalize_einsum(equation, plates, backend, einsum_impl):
    """Check that the ``normalize`` interpretation turns a reflected einsum
    into a single ``Contraction`` of simple terms. Also checks that
    normalization is idempotent, that evaluation is unchanged, and that the
    quoted form round-trips through ``eval``."""
    # Bring torch into the local namespace so that eval(quote(...)) at the
    # end can resolve torch names.
    if get_backend() == "torch":
        import torch  # noqa: F401

    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(
        equation)

    with interpretation(reflect):
        expr = einsum_impl(equation,
                           *funsor_operands,
                           backend=backend,
                           plates=plates)

    with interpretation(normalize):
        transformed_expr = reinterpret(expr)

    assert isinstance(transformed_expr, Contraction)
    check_funsor(transformed_expr, expr.inputs, expr.output)

    assert all(
        isinstance(v, (Number, Tensor, Contraction))
        for v in transformed_expr.terms)

    # Normalizing an already-normalized expression must return the very same
    # object.
    with interpretation(normalize):
        transformed_expr2 = reinterpret(transformed_expr)

    assert transformed_expr2 is transformed_expr  # check normalization

    with interpretation(eager):
        actual = reinterpret(transformed_expr)
        expected = reinterpret(expr)

    assert_close(actual, expected, rtol=1e-4)

    # eval is applied only to text generated by quote() in this test.
    actual = eval(quote(expected))  # requires torch, bint
    assert_close(actual, expected)
예제 #12
0
def test_reduce_moment_matching_finite():
    """Smoke test: moment-matching reduction of a Delta + discrete + Gaussian
    joint, where most discrete entries have weight -inf, must run without
    error. The result is not inspected."""
    delta = Delta('x', random_tensor(OrderedDict([('h', bint(7))])))
    discrete = random_tensor(
        OrderedDict([('i', bint(6)), ('j', bint(5)), ('k', bint(3))]))
    gaussian = random_gaussian(
        OrderedDict([('k', bint(3)), ('l', bint(2)), ('y', reals()),
                     ('z', reals(2))]))

    # Assuming data axes follow input order (i, j, k): keep only the
    # i == 0, j == 0 slice finite and set everything else to log-weight -inf.
    discrete.data[1:, :] = -float('inf')
    discrete.data[:, 1:] = -float('inf')

    reduced_vars = frozenset(['j', 'k'])
    joint = delta + discrete + gaussian
    with interpretation(moment_matching):
        joint.reduce(ops.logaddexp, reduced_vars)
예제 #13
0
def test_einsum_complete_sharing(equation, plates, backend, einsum_impl, same_lazy):
    """Check that ``memoize`` shares results between identical einsum
    evaluations inside one context but not across contexts."""
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)

    # With same_lazy, both names refer to one reflected expression; otherwise
    # a second, separately constructed expression is used.
    with interpretation(reflect):
        lazy_expr1 = einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)
        lazy_expr2 = lazy_expr1 if same_lazy else \
            einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)

    # Within one memoize() context, identical evaluations return the same
    # object; expr3 is evaluated outside it, so it must be a new object.
    with memoize():
        expr1 = reinterpret(lazy_expr1)
        expr2 = reinterpret(lazy_expr2)
    expr3 = reinterpret(lazy_expr1)

    assert expr1 is expr2
    assert expr1 is not expr3
예제 #14
0
def test_quote(interp):
    """Run ``check_quote`` on a range of expression kinds built under
    ``interp``: variables, indexing, Stack, substitution, Slice, reduce, Cat
    and Lambda."""
    with interpretation(interp):
        x = Variable('x', bint(8))
        check_quote(x)

        y = Variable('y', reals(8, 3, 3))
        check_quote(y)
        check_quote(y[x])

        z = Stack('i', (Number(0), Variable('z', reals())))
        check_quote(z)
        check_quote(z(i=0))
        check_quote(z(i=Slice('i', 0, 1, 1, 2)))
        check_quote(z.reduce(ops.add, 'i'))
        check_quote(Cat('i', (z, z, z)))
        check_quote(Lambda(Variable('i', bint(2)), z))
예제 #15
0
def test_cat_slice_tensor(start, stop, step):
    """Check that slicing a Cat of tensors along 't' gives the same result
    eagerly and lazily (the lazy path exercises ``Cat.eager_subs``)."""

    terms = tuple(
        random_tensor(OrderedDict(t=bint(t), a=bint(2)))
        for t in [2, 1, 3, 4, 1, 3])
    # Each bint(n) input has dtype n, so this sum is the total length of the
    # concatenated 't' axis.
    dtype = sum(term.inputs['t'].dtype for term in terms)
    sub = Slice('t', start, stop, step, dtype)

    # eager
    expected = Cat('t', terms)(t=sub)

    # lazy - exercise Cat.eager_subs
    with interpretation(lazy):
        actual = Cat('t', terms)(t=sub)
    actual = reinterpret(actual)

    assert_close(actual, expected)
예제 #16
0
def memoize(cache=None):
    """
    Exploit cons-hashing to do implicit common subexpression elimination

    While active, every construction ``cls(*args)`` is looked up in ``cache``
    first, so repeated identical constructions return the same object.

    :param dict cache: Optional dict used as the memo table; a new empty dict
        is created when omitted.
    :yields: The cache dict, which is populated while the context is active.

    NOTE(review): this is a generator function, and it is used as
    ``with memoize():``. That only works if it is wrapped by
    ``contextlib.contextmanager``, and no such decorator is visible here.
    Confirm it is applied where this function is defined.
    """
    if cache is None:
        cache = {}

    # Presumably makes calls to this interpreter run under the interpretation
    # captured from ``interpreter._INTERPRETATION`` at this point.
    @interpreter.interpretation(interpreter._INTERPRETATION)  # use base
    def memoize_interpretation(cls, *args):
        # Unhashable args are keyed by identity instead of by value.
        # NOTE(review): isinstance(arg, Hashable) is also True for tuples that
        # contain unhashable items, so building/hashing the key can still raise
        # TypeError. In addition, id() values may be reused once an object has
        # been garbage-collected.
        key = (cls, ) + tuple(
            id(arg) if not isinstance(arg, Hashable) else arg for arg in args)
        if key not in cache:
            cache[key] = cls(*args)
        return cache[key]

    with interpreter.interpretation(memoize_interpretation):
        yield cache
예제 #17
0
def test_nested_complete_sharing_direct():
    """Check that ``memoize`` shares results between two separately reflected
    but structurally identical nested contractions."""

    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example("ab,bc,cd->d")
    ab, bc, cd = funsor_operands

    # avoids the complicated internal interpreter usage of the nested optimized einsum tests above
    with interpretation(reflect):
        c1 = (ab * bc).reduce(ops.add, frozenset({"a", "b"}))
        d1 = (c1 * cd).reduce(ops.add, frozenset({"c"}))

        # this does not trigger a second alpha-renaming
        c2 = (ab * bc).reduce(ops.add, frozenset({"a", "b"}))
        d2 = (c2 * cd).reduce(ops.add, frozenset({"c"}))

    with memoize():
        assert reinterpret(c1) is reinterpret(c2)
        assert reinterpret(d1) is reinterpret(d2)
예제 #18
0
def test_lognormal_distribution(moment):
    """Compare a Monte Carlo ``Integrate`` estimate of E[x**moment] under a
    batched LogNormal funsor with a direct sampling estimate from the backend
    distribution."""
    num_samples = 100000
    inputs = OrderedDict(batch=Bint[10])
    loc = random_tensor(inputs)
    # exp() keeps the scale strictly positive.
    scale = random_tensor(inputs).exp()

    log_measure = dist.LogNormal(loc, scale)(value='x')
    probe = Variable('x', Real)**moment
    # Marks the test xfail if the integration is not implemented.
    with interpretation(MonteCarlo(particle=Bint[num_samples])):
        with xfail_if_not_implemented():
            actual = Integrate(log_measure, probe, frozenset(['x']))

    # Reference: sample the backend LogNormal directly and average over the
    # sample axis 0.
    _, (loc_data, scale_data) = align_tensors(loc, scale)
    samples = backend_dist.LogNormal(loc_data, scale_data).sample(
        (num_samples, ))
    expected = (samples**moment).mean(0)
    assert_close(actual.data, expected, atol=1e-2, rtol=1e-2)
예제 #19
0
def test_beta_sample(with_lazy, batch_shape, sample_inputs, reparametrized):
    """Check sampling from a batched Beta funsor, reparameterized or not,
    constructed either lazily or eagerly depending on ``with_lazy``. The
    variance statistic is compared."""
    # Name up to three batch dims 'i', 'j', 'k' as bounded-integer inputs.
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))

    # exp() keeps both concentrations strictly positive.
    concentration1 = Tensor(torch.randn(batch_shape).exp(), inputs)
    concentration0 = Tensor(torch.randn(batch_shape).exp(), inputs)
    with interpretation(lazy if with_lazy else eager):
        funsor_dist = (dist.Beta if reparametrized else
                       dist.NonreparameterizedBeta)(concentration1,
                                                    concentration0)

    _check_sample(funsor_dist,
                  sample_inputs,
                  inputs,
                  atol=1e-2 if reparametrized else 1e-1,
                  statistic="variance",
                  num_samples=100000)
예제 #20
0
def test_reduce_moment_matching_shape(interp):
    """Check that reducing discrete vars of a Delta + discrete + Gaussian joint
    under ``interp`` removes exactly those inputs and preserves the total mass
    once the real vars are also summed out."""
    delta = Delta('x', random_tensor(OrderedDict([('h', bint(7))])))
    discrete = random_tensor(
        OrderedDict([('h', bint(7)), ('i', bint(6)), ('j', bint(5)),
                     ('k', bint(4))]))
    gaussian = random_gaussian(
        OrderedDict([('k', bint(4)), ('l', bint(3)), ('m', bint(2)),
                     ('y', reals()), ('z', reals(2))]))
    reduced_vars = frozenset(['i', 'k', 'l'])
    real_vars = frozenset(k for k, d in gaussian.inputs.items()
                          if d.dtype == "real")
    joint = delta + discrete + gaussian
    with interpretation(interp):
        actual = joint.reduce(ops.logaddexp, reduced_vars)
    assert set(actual.inputs) == set(joint.inputs) - reduced_vars
    # After also integrating out the real vars, the approximation must match
    # the exact reduction.
    assert_close(actual.reduce(ops.logaddexp, real_vars),
                 joint.reduce(ops.logaddexp, real_vars | reduced_vars))
예제 #21
0
def test_reduce_moment_matching_moments():
    """Check that moment-matching out 'j' from a Gaussian mixture preserves
    the zeroth, first and second moments of 'x'. The check uses Monte Carlo
    integration with 100000 samples."""
    x = Variable('x', reals(2))
    gaussian = random_gaussian(
        OrderedDict([('i', bint(2)), ('j', bint(3)), ('x', reals(2))]))
    with interpretation(moment_matching):
        approx = gaussian.reduce(ops.logaddexp, 'j')
    with monte_carlo_interpretation(s=bint(100000)):
        # Zeroth moment (total mass).
        actual = Integrate(approx, Number(1.), 'x')
        expected = Integrate(gaussian, Number(1.), {'j', 'x'})
        assert_close(actual, expected, atol=1e-3, rtol=1e-3)

        # First moment.
        actual = Integrate(approx, x, 'x')
        expected = Integrate(gaussian, x, {'j', 'x'})
        assert_close(actual, expected, atol=1e-2, rtol=1e-2)

        # Elementwise second moment.
        actual = Integrate(approx, x * x, 'x')
        expected = Integrate(gaussian, x * x, {'j', 'x'})
        assert_close(actual, expected, atol=1e-2, rtol=1e-2)
예제 #22
0
def test_optimized_einsum(equation, backend, einsum_impl):
    """Compare an optimized funsor einsum (normalized, then optimized, then
    evaluated) against ``pyro_einsum`` on the same operands."""
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = pyro_einsum(equation, *operands, backend=backend)[0]
    with interpretation(normalize):
        naive_ast = einsum_impl(equation, *funsor_operands, backend=backend)
    optimized_ast = apply_optimizer(naive_ast)
    actual = reinterpret(optimized_ast)  # eager by default

    # Only single-output equations are supported by this check.
    assert isinstance(actual, funsor.Tensor) and len(outputs) == 1
    # Put the funsor's named dims in the equation's output order.
    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    # Every output dim must appear as an input whose dtype is its size.
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
예제 #23
0
def test_distribute_reduce(lhs_vars, rhs_vars):
    """Check that multiplying two reflected sum-reductions equals a single
    sum-reduction of the product, with the reduced vars renamed apart."""

    lhs_vars, rhs_vars = frozenset(lhs_vars), frozenset(rhs_vars)
    lhs = random_tensor(OrderedDict([('i', bint(3)), ('j', bint(2))]), reals())
    rhs = random_tensor(OrderedDict([('i', bint(3)), ('j', bint(2))]), reals())

    with interpretation(reflect):
        actual_lhs = lhs.reduce(ops.add, lhs_vars) if lhs_vars else lhs
        actual_rhs = rhs.reduce(ops.add, rhs_vars) if rhs_vars else rhs

    actual = actual_lhs * actual_rhs

    # Give each side's reduced vars fresh names (gensym) so the two
    # reductions stay independent in the combined product.
    lhs_subs = {v: gensym(v) for v in lhs_vars}
    rhs_subs = {v: gensym(v) for v in rhs_vars}
    expected = (lhs(**lhs_subs) * rhs(**rhs_subs)).reduce(
        ops.add, frozenset(lhs_subs.values()) | frozenset(rhs_subs.values()))

    assert_close(actual, expected)
예제 #24
0
def test_binary(symbol, data1, data2):
    """Check a binary operator on ``Number`` funsors against the same operator
    on raw Python values. Also checks that the ``normalize`` interpretation
    yields the same output domain."""
    dtype = 'real'
    # Boolean ops work on bint(2)-typed Numbers with bool data.
    if symbol in BOOLEAN_OPS:
        dtype = 2
        data1 = bool(data1)
        data2 = bool(data2)
    try:
        expected_data = binary_eval(symbol, data1, data2)
    except ZeroDivisionError:
        # No reference value exists, so the case returns as passing.
        # NOTE(review): pytest.skip() would make this visible in reports.
        return

    x1 = Number(data1, dtype)
    x2 = Number(data2, dtype)
    actual = binary_eval(symbol, x1, x2)
    check_funsor(actual, {}, Domain((), dtype), expected_data)
    with interpretation(normalize):
        actual_reflect = binary_eval(symbol, x1, x2)
    assert actual.output == actual_reflect.output
예제 #25
0
def test_advanced_indexing_lazy(output_shape):
    """Check lazy advanced indexing of a Tensor by integer-valued expressions.

    ``x`` has inputs ``i``, ``j`` and ``k``. These are substituted by lazy
    expressions in the free variables ``u`` and ``v``, using every grouping
    and ordering of substitutions. Each result is compared with a reference
    gathered explicitly from ``x.data``.
    """
    x = Tensor(randn((2, 3, 4) + output_shape),
               OrderedDict([
                   ('i', bint(2)),
                   ('j', bint(3)),
                   ('k', bint(4)),
               ]))
    u = Variable('u', bint(2))
    v = Variable('v', bint(3))
    with interpretation(lazy):
        i = Number(1, 2) - u
        j = Number(2, 3) - v
        k = u + v

    # Reference: gather x.data at the materialized index values. The loop
    # indices are named ui/vi so they do not shadow the Variables u and v.
    expected_data = empty((2, 3) + output_shape)
    i_data = x.materialize(i).data
    j_data = x.materialize(j).data
    k_data = x.materialize(k).data
    for ui in range(2):
        for vi in range(3):
            expected_data[ui, vi] = x.data[i_data[ui], j_data[vi], k_data[ui, vi]]
    expected = Tensor(expected_data,
                      OrderedDict([
                          ('u', bint(2)),
                          ('v', bint(3)),
                      ]))

    # All three substitutions at once, positionally and by name.
    assert_equiv(expected, x(i, j, k))
    assert_equiv(expected, x(i=i, j=j, k=k))

    # Two substitutions, then the remaining one.
    assert_equiv(expected, x(i=i, j=j)(k=k))
    assert_equiv(expected, x(j=j, k=k)(i=i))
    assert_equiv(expected, x(k=k, i=i)(j=j))

    # One substitution, then the remaining two.
    assert_equiv(expected, x(i=i)(j=j, k=k))
    assert_equiv(expected, x(j=j)(k=k, i=i))
    assert_equiv(expected, x(k=k)(i=i, j=j))

    # One at a time, in all six orders.
    assert_equiv(expected, x(i=i)(j=j)(k=k))
    assert_equiv(expected, x(i=i)(k=k)(j=j))
    assert_equiv(expected, x(j=j)(i=i)(k=k))
    assert_equiv(expected, x(j=j)(k=k)(i=i))
    assert_equiv(expected, x(k=k)(i=i)(j=j))
    assert_equiv(expected, x(k=k)(j=j)(i=i))
예제 #26
0
def test_advanced_indexing_lazy(output_shape):
    """Check lazy advanced indexing of a numpy-backed Array by integer-valued
    expressions.

    ``x`` has inputs ``i``, ``j`` and ``k``. These are substituted by lazy
    expressions in the free variables ``u`` and ``v``, using every grouping
    and ordering of substitutions. Each result is compared with a reference
    gathered explicitly from ``x.data``.
    """
    x = Array(np.random.normal(size=(2, 3, 4) + output_shape),
              OrderedDict([
                  ('i', bint(2)),
                  ('j', bint(3)),
                  ('k', bint(4)),
              ]))
    u = Variable('u', bint(2))
    v = Variable('v', bint(3))
    with interpretation(lazy):
        i = Number(1, 2) - u
        j = Number(2, 3) - v
        k = u + v

    # Reference: gather x.data at the materialized index values, cast to
    # int64 so they can be used as numpy indices. The loop indices are named
    # ui/vi so they do not shadow the Variables u and v.
    expected_data = np.empty((2, 3) + output_shape)
    i_data = funsor.numpy.materialize(i).data.astype(np.int64)
    j_data = funsor.numpy.materialize(j).data.astype(np.int64)
    k_data = funsor.numpy.materialize(k).data.astype(np.int64)
    for ui in range(2):
        for vi in range(3):
            expected_data[ui, vi] = x.data[i_data[ui], j_data[vi], k_data[ui, vi]]
    expected = Array(expected_data,
                     OrderedDict([
                         ('u', bint(2)),
                         ('v', bint(3)),
                     ]))

    # All three substitutions at once, positionally and by name.
    assert_equiv(expected, x(i, j, k))
    assert_equiv(expected, x(i=i, j=j, k=k))

    # Two substitutions, then the remaining one.
    assert_equiv(expected, x(i=i, j=j)(k=k))
    assert_equiv(expected, x(j=j, k=k)(i=i))
    assert_equiv(expected, x(k=k, i=i)(j=j))

    # One substitution, then the remaining two.
    assert_equiv(expected, x(i=i)(j=j, k=k))
    assert_equiv(expected, x(j=j)(k=k, i=i))
    assert_equiv(expected, x(k=k)(i=i, j=j))

    # One at a time, in all six orders.
    assert_equiv(expected, x(i=i)(j=j)(k=k))
    assert_equiv(expected, x(i=i)(k=k)(j=j))
    assert_equiv(expected, x(j=j)(i=i)(k=k))
    assert_equiv(expected, x(j=j)(k=k)(i=i))
    assert_equiv(expected, x(k=k)(i=i)(j=j))
    assert_equiv(expected, x(k=k)(j=j)(i=i))
예제 #27
0
def test_reduce_all(op):
    """Check that a full ``reduce(op)`` under the ``sequential``
    interpretation matches ``functools``-style reduction over all
    substituted values of f = x * y + z."""
    x = Variable('x', bint(2))
    y = Variable('y', bint(3))
    z = Variable('z', bint(4))
    f = x * y + z
    dtype = f.dtype
    check_funsor(f, {'x': bint(2), 'y': bint(3), 'z': bint(4)}, Domain((), dtype))
    # logaddexp is skipped for this check (the reason is not stated here).
    if op is ops.logaddexp:
        pytest.skip()

    with interpretation(sequential):
        actual = f.reduce(op)

    # Presumably iterating a bint output domain yields its integer values.
    values = [f(x=i, y=j, z=k)
              for i in x.output
              for j in y.output
              for k in z.output]
    expected = reduce(op, values)
    assert actual == expected
예제 #28
0
def test_einsum_adjoint_unary_marginals(einsum_impl, equation, backend):
    """Check that the adjoint of a fully reduced einsum w.r.t. each single
    variable equals the contraction that keeps only that variable (its unary
    marginal)."""
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(
        equation)
    # Replace the equation's outputs with an empty output (full reduction).
    equation = ",".join(inputs) + "->"

    targets = [Variable(k, bint(sizes[k])) for k in set(sizes)]
    with interpretation(reflect):
        fwd_expr = einsum_impl(equation, *funsor_operands, backend=backend)
    actuals = adjoint(fwd_expr, targets)

    for target in targets:
        actual = actuals[target]

        # Reference: the same contraction with the target as the sole output.
        expected = opt_einsum.contract(equation + target.name,
                                       *operands,
                                       backend=backend)
        assert isinstance(actual, funsor.Tensor)
        assert expected.shape == actual.data.shape
        assert torch.allclose(expected, actual.data, atol=1e-7)
예제 #29
0
파일: hmm.py 프로젝트: ordabayevy/funsor
    def __init__(self,
                 initial_logits,
                 transition_logits,
                 observation_dist,
                 validate_args=None):
        """Build a discrete-state HMM from logits and an observation distribution.

        :param initial_logits: ``torch.Tensor`` (at least 1-D) of initial-state
            logits over the last dim.
        :param transition_logits: ``torch.Tensor`` (at least 2-D) of
            transition logits, normalized over the last dim.
        :param observation_dist: ``torch.distributions.Distribution`` with at
            least one batch dim.
        :param validate_args: Forwarded to the base-class constructor.

        The funsor passed to the base class is a placeholder ``Variable``. The
        factors are stored on ``self._init``, ``self._trans`` and
        ``self._obs`` for hand-computation (see the comment below).
        """
        assert isinstance(initial_logits, torch.Tensor)
        assert isinstance(transition_logits, torch.Tensor)
        assert isinstance(observation_dist, torch.distributions.Distribution)
        assert initial_logits.dim() >= 1
        assert transition_logits.dim() >= 2
        assert len(observation_dist.batch_shape) >= 1
        # (1,) gives the initial logits a singleton time dim. The trailing
        # batch dim of observation_dist is excluded here; it is presumably the
        # state dim mapped to "state(time=1)" below.
        shape = broadcast_shape(initial_logits.shape[:-1] + (1, ),
                                transition_logits.shape[:-2],
                                observation_dist.batch_shape[:-1])
        batch_shape, time_shape = shape[:-1], shape[-1:]
        event_shape = time_shape + observation_dist.event_shape
        self._has_rsample = observation_dist.has_rsample

        # Normalize.
        # Subtract logsumexp over the last dim so each row is a normalized
        # log-probability vector.
        initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
        transition_logits = transition_logits - transition_logits.logsumexp(
            -1, True)

        # Convert tensors and distributions to funsors.
        init = tensor_to_funsor(initial_logits, ("state", ))
        trans = tensor_to_funsor(transition_logits,
                                 ("time", "state", "state(time=1)"))
        obs = dist_to_funsor(observation_dist, ("time", "state(time=1)"))
        dtype = obs.inputs["value"].dtype

        # Construct the joint funsor.
        with interpretation(lazy):
            # TODO perform math here once sequential_sum_product has been
            #   implemented as a first-class funsor.
            funsor_dist = Variable("value",
                                   obs.inputs["value"])  # a bogus value
            # Until funsor_dist is defined, we save factors for hand-computation in .log_prob().
            self._init = init
            self._trans = trans
            self._obs = obs

        super(DiscreteHMM, self).__init__(funsor_dist, batch_shape,
                                          event_shape, dtype, validate_args)
예제 #30
0
def test_match_binary():
    """Check that ``match_vars`` and ``match`` rewrite a lazy expression of the
    form a + 2 * b into a + b + b without changing its value."""
    # Build both the pattern and the expression lazily so they stay symbolic.
    with interpretation(lazy):
        pattern = Variable('a', reals()) + Number(2.) * Variable('b', reals())
        expr = Number(1.) + Number(2.) * (Number(3.) - Number(4.))

    # Rewrite using the matched sub-expressions bound by variable name.
    @match_vars(pattern)
    def expand_2_vars(a, b):
        return a + b + b

    # Same rewrite, navigating the matched expression via lhs/rhs.
    @match(pattern)
    def expand_2_walk(x):
        return x.lhs + x.rhs.rhs + x.rhs.rhs

    eager_val = reinterpret(expr)
    lazy_val = expand_2_vars(expr)
    assert eager_val == reinterpret(lazy_val)

    lazy_val_2 = expand_2_walk(expr)
    assert eager_val == reinterpret(lazy_val_2)