def __call__(cls, *args, **kwargs):
    kwargs.update(zip(cls._ast_fields, args))
    args = cls._fill_defaults(**kwargs)
    args = numbers_to_tensors(*args)

    # If value was explicitly specified, evaluate under current interpretation.
    if 'value' in kwargs:
        return super(DistributionMeta, cls).__call__(*args)

    # Otherwise lazily construct a distribution instance.
    # This makes it cheaper to construct observations in minipyro.
    with interpretation(lazy):
        return super(DistributionMeta, cls).__call__(*args)

def test_sequential_sum_product_adjoint(impl, sum_op, prod_op, batch_inputs,
                                        state_domain, num_steps):
    # test mostly copied from test_sum_product.py
    inputs = OrderedDict(batch_inputs)
    inputs.update(prev=state_domain, curr=state_domain)
    inputs["time"] = bint(num_steps)
    if state_domain.dtype == "real":
        trans = random_gaussian(inputs)
    else:
        trans = random_tensor(inputs)
    time = Variable("time", bint(num_steps))

    with AdjointTape() as actual_tape:
        actual = impl(sum_op, prod_op, trans, time, {"prev": "curr"})

    expected_inputs = batch_inputs.copy()
    expected_inputs.update(prev=state_domain, curr=state_domain)
    assert dict(actual.inputs) == expected_inputs

    # Check against contract.
    operands = tuple(
        trans(time=t, prev="t_{}".format(t), curr="t_{}".format(t + 1))
        for t in range(num_steps))
    reduce_vars = frozenset("t_{}".format(t) for t in range(1, num_steps))
    with AdjointTape() as expected_tape:
        with interpretation(reflect):
            expected = sum_product(sum_op, prod_op, operands, reduce_vars)
        expected = apply_optimizer(expected)
        expected = expected(**{"t_0": "prev", "t_{}".format(num_steps): "curr"})
        expected = expected.align(tuple(actual.inputs.keys()))

    # check forward pass (sanity check)
    assert_close(actual, expected, rtol=5e-4 * num_steps)

    # perform backward passes only after the sanity check
    expected_bwds = expected_tape.adjoint(sum_op, prod_op, expected, operands)
    actual_bwd = actual_tape.adjoint(sum_op, prod_op, actual, (trans,))[trans]

    # check backward pass
    for t, operand in enumerate(operands):
        actual_bwd_t = actual_bwd(time=t, prev="t_{}".format(t),
                                  curr="t_{}".format(t + 1))
        expected_bwd = expected_bwds[operand].align(tuple(actual_bwd_t.inputs.keys()))
        check_funsor(actual_bwd_t, expected_bwd.inputs, expected_bwd.output)
        assert_close(actual_bwd_t, expected_bwd, rtol=5e-4 * num_steps)

def main(args):
    funsor.set_backend("torch")

    # Define a basic model with a single Normal latent random variable `loc`
    # and a batch of Normally distributed observations.
    def model(data):
        loc = pyro.sample("loc", dist.Normal(0., 1.))
        with pyro.plate("data", len(data), dim=-1):
            pyro.sample("obs", dist.Normal(loc, 1.), obs=data)

    # Define a guide (i.e. variational distribution) with a Normal
    # distribution over the latent random variable `loc`.
    def guide(data):
        guide_loc = pyro.param("guide_loc", torch.tensor(0.))
        guide_scale = pyro.param("guide_scale", torch.tensor(1.),
                                 constraint=constraints.positive)
        pyro.sample("loc", dist.Normal(guide_loc, guide_scale))

    # Generate some data.
    torch.manual_seed(0)
    data = torch.randn(100) + 3.0

    # Because the API in minipyro matches that of Pyro proper,
    # training code works with generic Pyro implementations.
    with pyro_backend(args.backend), interpretation(MonteCarlo()):
        # Construct an SVI object so we can do variational inference on our
        # model/guide pair.
        Elbo = infer.JitTrace_ELBO if args.jit else infer.Trace_ELBO
        elbo = Elbo()
        adam = optim.Adam({"lr": args.learning_rate})
        svi = infer.SVI(model, guide, adam, elbo)

        # Basic training loop
        pyro.get_param_store().clear()
        for step in range(args.num_steps):
            loss = svi.step(data)
            if args.verbose and step % 100 == 0:
                print("step {} loss = {}".format(step, loss))

        # Report the final values of the variational parameters
        # in the guide after training.
        if args.verbose:
            for name in pyro.get_param_store():
                value = pyro.param(name).data
                print("{} = {}".format(name, value.detach().cpu().numpy()))

        # For this simple (conjugate) model we know the exact posterior. In
        # particular we know that the variational distribution should be
        # centered near 3.0. So let's check this explicitly.
        assert (pyro.param("guide_loc") - 3.0).abs() < 0.1

def einsum(eqn, *terms, **kwargs):
    r"""
    Top-level interface for optimized tensor variable elimination.

    :param str eqn: An einsum equation.
    :param funsor.terms.Funsor \*terms: One or more operands.
    :param set plates: Optional keyword argument denoting which funsor
        dimensions are plate dimensions. Among all input dimensions (from
        ``terms``): dimensions in ``plates`` but not in outputs are
        product-reduced; dimensions in neither ``plates`` nor outputs are
        sum-reduced.
    """
    with interpretation(lazy):
        naive_ast = naive_plated_einsum(eqn, *terms, **kwargs)
    return apply_optimizer(naive_ast)

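# A hedged usage sketch for the plated einsum interface above (illustration only,
# not library source). It reuses the `make_einsum_example` helper from
# funsor.testing and the torch backend already used elsewhere in this file; the
# particular equation and plate letter are chosen for illustration, and import
# paths may differ across funsor versions.
def _example_plated_einsum():
    import funsor
    from funsor.testing import make_einsum_example

    funsor.set_backend("torch")

    # In "ai->" with plates="i": "a" is sum-reduced (not a plate, not in the
    # output) and "i" is product-reduced (a plate that is not in the output).
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example("ai->")
    return einsum("ai->", *funsor_operands, backend="torch", plates="i")
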
def __init__(self, initial_dist, transition_matrix, transition_dist,
             observation_matrix, observation_dist, validate_args=None):
    assert isinstance(initial_dist, torch.distributions.MultivariateNormal)
    assert isinstance(transition_matrix, torch.Tensor)
    assert isinstance(transition_dist, torch.distributions.MultivariateNormal)
    assert isinstance(observation_matrix, torch.Tensor)
    assert isinstance(observation_dist, torch.distributions.MultivariateNormal)
    hidden_dim, obs_dim = observation_matrix.shape[-2:]
    assert obs_dim >= hidden_dim // 2, "obs_dim must be at least half of hidden_dim"
    assert initial_dist.event_shape == (hidden_dim,)
    assert transition_matrix.shape[-2:] == (hidden_dim, hidden_dim)
    assert transition_dist.event_shape == (hidden_dim,)
    assert observation_dist.event_shape == (obs_dim,)
    shape = broadcast_shape(initial_dist.batch_shape + (1,),
                            transition_matrix.shape[:-2],
                            transition_dist.batch_shape,
                            observation_matrix.shape[:-2],
                            observation_dist.batch_shape)
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + (obs_dim,)

    # Convert distributions to funsors.
    init = dist_to_funsor(initial_dist)(value="state")
    trans = matrix_and_mvn_to_funsor(transition_matrix, transition_dist,
                                     ("time",), "state", "state(time=1)")
    obs = matrix_and_mvn_to_funsor(observation_matrix, observation_dist,
                                   ("time",), "state(time=1)", "value")
    dtype = "real"

    # Construct the joint funsor.
    with interpretation(lazy):
        value = Variable("value", Reals[time_shape[0], obs_dim])
        result = trans + obs(value=value["time"])
        result = MarkovProduct(ops.logaddexp, ops.add,
                               result, "time", {"state": "state(time=1)"})
        result = init + result.reduce(ops.logaddexp, "state(time=1)")
        funsor_dist = result.reduce(ops.logaddexp, "state")

    super(GaussianHMM, self).__init__(funsor_dist, batch_shape, event_shape,
                                      dtype, validate_args)
    self.hidden_dim = hidden_dim
    self.obs_dim = obs_dim

def test_sequential_sum_product_multi(impl, x_domain, y_domain, batch_inputs, num_steps):
    sum_op = ops.logaddexp
    prod_op = ops.add
    inputs = OrderedDict(batch_inputs)
    inputs.update(x_prev=x_domain, x_curr=x_domain, y_prev=y_domain, y_curr=y_domain)
    if num_steps is None:
        num_steps = 1
    else:
        inputs["time"] = bint(num_steps)
    if any(v.dtype == "real" for v in inputs.values()):
        trans = random_gaussian(inputs)
    else:
        trans = random_tensor(inputs)
    time = Variable("time", bint(num_steps))
    step = {"x_prev": "x_curr", "y_prev": "y_curr"}

    with interpretation(moment_matching):
        actual = impl(sum_op, prod_op, trans, time, step)
        expected_inputs = batch_inputs.copy()
        expected_inputs.update(x_prev=x_domain, x_curr=x_domain,
                               y_prev=y_domain, y_curr=y_domain)
        assert dict(actual.inputs) == expected_inputs

        # Check against contract.
        operands = tuple(
            trans(time=t,
                  x_prev="x_{}".format(t), x_curr="x_{}".format(t + 1),
                  y_prev="y_{}".format(t), y_curr="y_{}".format(t + 1))
            for t in range(num_steps))
        reduce_vars = frozenset("x_{}".format(t) for t in range(1, num_steps)).union(
            "y_{}".format(t) for t in range(1, num_steps))
        expected = sum_product(sum_op, prod_op, operands, reduce_vars)
        expected = expected(**{"x_0": "x_prev", "x_{}".format(num_steps): "x_curr",
                               "y_0": "y_prev", "y_{}".format(num_steps): "y_curr"})
        expected = expected.align(tuple(actual.inputs.keys()))

def main(args):
    funsor.set_backend("torch")

    # Declare parameters.
    trans_noise = torch.tensor(0.1, requires_grad=True)
    emit_noise = torch.tensor(0.5, requires_grad=True)
    params = [trans_noise, emit_noise]

    # A Gaussian HMM model.
    def model(data):
        log_prob = funsor.to_funsor(0.)

        x_curr = funsor.Tensor(torch.tensor(0.))
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t), funsor.Real)
            log_prob += dist.Normal(1 + x_prev / 2., trans_noise, value=x_curr)

            # Optionally marginalize out the previous state.
            if t > 0 and not args.lazy:
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            # An observe statement.
            log_prob += dist.Normal(0.5 + 3 * x_curr, emit_noise, value=y)

        # Marginalize out all remaining delayed variables.
        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob

    # Train model parameters.
    torch.manual_seed(0)
    data = torch.randn(args.time_steps)
    optim = torch.optim.Adam(params, lr=args.learning_rate)
    for step in range(args.train_steps):
        optim.zero_grad()
        if args.lazy:
            with interpretation(lazy):
                log_prob = apply_optimizer(model(data))
            log_prob = reinterpret(log_prob)
        else:
            log_prob = model(data)
        assert not log_prob.inputs, 'free variables remain'
        loss = -log_prob.data
        loss.backward()
        optim.step()
        if args.verbose and step % 10 == 0:
            print('step {} loss = {}'.format(step, loss.item()))

def test_mvn_sample(with_lazy, batch_shape, sample_inputs, event_shape):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
    loc = Tensor(torch.randn(batch_shape + event_shape), inputs)
    scale_tril = Tensor(_random_scale_tril(batch_shape + event_shape * 2), inputs)

    # Construct the distribution under either the lazy or the eager interpretation.
    with interpretation(lazy if with_lazy else eager):
        funsor_dist = dist.MultivariateNormal(loc, scale_tril)

    # Draw samples and compare empirical statistics against expected values.
    _check_sample(funsor_dist, sample_inputs, inputs, atol=5e-2, num_samples=200000)

def test_normal_sample(with_lazy, batch_shape, sample_inputs, reparametrized):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
    loc = Tensor(torch.randn(batch_shape), inputs)
    scale = Tensor(torch.rand(batch_shape), inputs)

    # Construct the distribution under either the lazy or the eager interpretation.
    with interpretation(lazy if with_lazy else eager):
        funsor_dist = (dist.Normal if reparametrized else dist.NonreparameterizedNormal)(loc, scale)

    # Draw samples and compare empirical statistics against expected values.
    _check_sample(funsor_dist, sample_inputs, inputs, num_samples=200000,
                  atol=1e-2 if reparametrized else 1e-1)

def substitute(expr, subs):
    if isinstance(subs, (dict, OrderedDict)):
        subs = tuple(subs.items())
    assert isinstance(subs, tuple)

    @interpreter.interpretation(interpreter._INTERPRETATION)  # use base
    def subs_interpreter(cls, *args):
        expr = cls(*args)
        fresh_subs = tuple((k, v) for k, v in subs if k in expr.fresh)
        if fresh_subs:
            expr = interpreter.debug_logged(expr.eager_subs)(fresh_subs)
        return expr

    with interpreter.interpretation(subs_interpreter):
        return interpreter.reinterpret(expr)

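# A hedged usage sketch for `substitute` (illustration only, not library source;
# import paths may differ across funsor versions). A lazy expression is rebuilt
# under the base interpretation while `subs` is applied at each node that binds
# a matching fresh variable.
def _example_substitute():
    from funsor.domains import reals
    from funsor.interpreter import interpretation
    from funsor.terms import Number, Variable, lazy

    with interpretation(lazy):
        expr = Variable('x', reals()) + Number(1.)

    # Replacing the free variable 'x' by 2 evaluates the sum to 3.
    assert substitute(expr, {'x': Number(2.)}) == Number(3.)
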
def test_normalize_einsum(equation, plates, backend, einsum_impl):
    if get_backend() == "torch":
        import torch  # noqa: F401

    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)

    with interpretation(reflect):
        expr = einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)

    with interpretation(normalize):
        transformed_expr = reinterpret(expr)

    assert isinstance(transformed_expr, Contraction)
    check_funsor(transformed_expr, expr.inputs, expr.output)

    assert all(isinstance(v, (Number, Tensor, Contraction))
               for v in transformed_expr.terms)

    with interpretation(normalize):
        transformed_expr2 = reinterpret(transformed_expr)
    assert transformed_expr2 is transformed_expr  # check normalization

    with interpretation(eager):
        actual = reinterpret(transformed_expr)
        expected = reinterpret(expr)

    assert_close(actual, expected, rtol=1e-4)

    actual = eval(quote(expected))  # requires torch, bint
    assert_close(actual, expected)

def test_reduce_moment_matching_finite():
    delta = Delta('x', random_tensor(OrderedDict([('h', bint(7))])))
    discrete = random_tensor(OrderedDict([('i', bint(6)),
                                          ('j', bint(5)),
                                          ('k', bint(3))]))
    gaussian = random_gaussian(OrderedDict([('k', bint(3)),
                                            ('l', bint(2)),
                                            ('y', reals()),
                                            ('z', reals(2))]))

    discrete.data[1:, :] = -float('inf')
    discrete.data[:, 1:] = -float('inf')

    reduced_vars = frozenset(['j', 'k'])
    joint = delta + discrete + gaussian

    # The result is discarded; this only checks that moment-matching reduction
    # succeeds despite the -inf entries in the discrete factor.
    with interpretation(moment_matching):
        joint.reduce(ops.logaddexp, reduced_vars)

def test_einsum_complete_sharing(equation, plates, backend, einsum_impl, same_lazy):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)

    with interpretation(reflect):
        lazy_expr1 = einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)
        lazy_expr2 = lazy_expr1 if same_lazy else \
            einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)

    with memoize():
        expr1 = reinterpret(lazy_expr1)
        expr2 = reinterpret(lazy_expr2)
    expr3 = reinterpret(lazy_expr1)

    assert expr1 is expr2
    assert expr1 is not expr3

def test_quote(interp):
    with interpretation(interp):
        x = Variable('x', bint(8))
        check_quote(x)

        y = Variable('y', reals(8, 3, 3))
        check_quote(y)
        check_quote(y[x])

        z = Stack('i', (Number(0), Variable('z', reals())))
        check_quote(z)
        check_quote(z(i=0))
        check_quote(z(i=Slice('i', 0, 1, 1, 2)))
        check_quote(z.reduce(ops.add, 'i'))
        check_quote(Cat('i', (z, z, z)))
        check_quote(Lambda(Variable('i', bint(2)), z))

def test_cat_slice_tensor(start, stop, step):
    terms = tuple(random_tensor(OrderedDict(t=bint(t), a=bint(2)))
                  for t in [2, 1, 3, 4, 1, 3])
    dtype = sum(term.inputs['t'].dtype for term in terms)
    sub = Slice('t', start, stop, step, dtype)

    # eager
    expected = Cat('t', terms)(t=sub)

    # lazy - exercise Cat.eager_subs
    with interpretation(lazy):
        actual = Cat('t', terms)(t=sub)
    actual = reinterpret(actual)

    assert_close(actual, expected)

@contextmanager  # contextlib.contextmanager: memoize() is used as a `with` context
def memoize(cache=None):
    """
    Exploit cons-hashing to do implicit common subexpression elimination
    """
    if cache is None:
        cache = {}

    @interpreter.interpretation(interpreter._INTERPRETATION)  # use base
    def memoize_interpretation(cls, *args):
        key = (cls,) + tuple(id(arg) if not isinstance(arg, Hashable) else arg
                             for arg in args)
        if key not in cache:
            cache[key] = cls(*args)
        return cache[key]

    with interpreter.interpretation(memoize_interpretation):
        yield cache

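# A hedged usage sketch for `memoize` (illustration only, not library source;
# import paths may differ across funsor versions). Reinterpreting the same lazy
# expression twice inside one `memoize()` block returns the identical object,
# because every node construction is looked up in the shared cache, while a
# reinterpretation outside the block builds a fresh object.
def _example_memoize_cse():
    from collections import OrderedDict

    from funsor.domains import bint
    from funsor.interpreter import interpretation, reinterpret
    from funsor.terms import reflect
    from funsor.testing import random_tensor

    x = random_tensor(OrderedDict(i=bint(3)))
    with interpretation(reflect):
        lazy_expr = x + x  # built as a lazy AST, not evaluated

    with memoize():
        expr1 = reinterpret(lazy_expr)
        expr2 = reinterpret(lazy_expr)
    expr3 = reinterpret(lazy_expr)

    assert expr1 is expr2      # shared via the memoization cache
    assert expr1 is not expr3  # rebuilt outside the cache
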
def test_nested_complete_sharing_direct():
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example("ab,bc,cd->d")
    ab, bc, cd = funsor_operands

    # avoids the complicated internal interpreter usage of the nested optimized einsum tests above
    with interpretation(reflect):
        c1 = (ab * bc).reduce(ops.add, frozenset({"a", "b"}))
        d1 = (c1 * cd).reduce(ops.add, frozenset({"c"}))

        # this does not trigger a second alpha-renaming
        c2 = (ab * bc).reduce(ops.add, frozenset({"a", "b"}))
        d2 = (c2 * cd).reduce(ops.add, frozenset({"c"}))

    with memoize():
        assert reinterpret(c1) is reinterpret(c2)
        assert reinterpret(d1) is reinterpret(d2)

def test_lognormal_distribution(moment):
    num_samples = 100000
    inputs = OrderedDict(batch=Bint[10])
    loc = random_tensor(inputs)
    scale = random_tensor(inputs).exp()

    log_measure = dist.LogNormal(loc, scale)(value='x')
    probe = Variable('x', Real) ** moment
    with interpretation(MonteCarlo(particle=Bint[num_samples])):
        with xfail_if_not_implemented():
            actual = Integrate(log_measure, probe, frozenset(['x']))

    _, (loc_data, scale_data) = align_tensors(loc, scale)
    samples = backend_dist.LogNormal(loc_data, scale_data).sample((num_samples,))
    expected = (samples ** moment).mean(0)
    assert_close(actual.data, expected, atol=1e-2, rtol=1e-2)

def test_beta_sample(with_lazy, batch_shape, sample_inputs, reparametrized):
    batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
    inputs = OrderedDict((k, bint(v)) for k, v in zip(batch_dims, batch_shape))
    concentration1 = Tensor(torch.randn(batch_shape).exp(), inputs)
    concentration0 = Tensor(torch.randn(batch_shape).exp(), inputs)

    # Construct the distribution under either the lazy or the eager interpretation.
    with interpretation(lazy if with_lazy else eager):
        funsor_dist = (dist.Beta if reparametrized else dist.NonreparameterizedBeta)(
            concentration1, concentration0)

    # Draw samples and compare the empirical variance against expected values.
    _check_sample(funsor_dist, sample_inputs, inputs,
                  atol=1e-2 if reparametrized else 1e-1,
                  statistic="variance", num_samples=100000)

def test_reduce_moment_matching_shape(interp):
    delta = Delta('x', random_tensor(OrderedDict([('h', bint(7))])))
    discrete = random_tensor(OrderedDict([('h', bint(7)),
                                          ('i', bint(6)),
                                          ('j', bint(5)),
                                          ('k', bint(4))]))
    gaussian = random_gaussian(OrderedDict([('k', bint(4)),
                                            ('l', bint(3)),
                                            ('m', bint(2)),
                                            ('y', reals()),
                                            ('z', reals(2))]))
    reduced_vars = frozenset(['i', 'k', 'l'])
    real_vars = frozenset(k for k, d in gaussian.inputs.items() if d.dtype == "real")
    joint = delta + discrete + gaussian

    with interpretation(interp):
        actual = joint.reduce(ops.logaddexp, reduced_vars)

    assert set(actual.inputs) == set(joint.inputs) - reduced_vars
    assert_close(actual.reduce(ops.logaddexp, real_vars),
                 joint.reduce(ops.logaddexp, real_vars | reduced_vars))

def test_reduce_moment_matching_moments():
    x = Variable('x', reals(2))
    gaussian = random_gaussian(OrderedDict([('i', bint(2)),
                                            ('j', bint(3)),
                                            ('x', reals(2))]))
    with interpretation(moment_matching):
        approx = gaussian.reduce(ops.logaddexp, 'j')

    with monte_carlo_interpretation(s=bint(100000)):
        actual = Integrate(approx, Number(1.), 'x')
        expected = Integrate(gaussian, Number(1.), {'j', 'x'})
        assert_close(actual, expected, atol=1e-3, rtol=1e-3)

        actual = Integrate(approx, x, 'x')
        expected = Integrate(gaussian, x, {'j', 'x'})
        assert_close(actual, expected, atol=1e-2, rtol=1e-2)

        actual = Integrate(approx, x * x, 'x')
        expected = Integrate(gaussian, x * x, {'j', 'x'})
        assert_close(actual, expected, atol=1e-2, rtol=1e-2)

def test_optimized_einsum(equation, backend, einsum_impl):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = pyro_einsum(equation, *operands, backend=backend)[0]

    with interpretation(normalize):
        naive_ast = einsum_impl(equation, *funsor_operands, backend=backend)
    optimized_ast = apply_optimizer(naive_ast)
    actual = reinterpret(optimized_ast)  # eager by default

    assert isinstance(actual, funsor.Tensor) and len(outputs) == 1
    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]

def test_distribute_reduce(lhs_vars, rhs_vars):
    lhs_vars, rhs_vars = frozenset(lhs_vars), frozenset(rhs_vars)
    lhs = random_tensor(OrderedDict([('i', bint(3)), ('j', bint(2))]), reals())
    rhs = random_tensor(OrderedDict([('i', bint(3)), ('j', bint(2))]), reals())

    with interpretation(reflect):
        actual_lhs = lhs.reduce(ops.add, lhs_vars) if lhs_vars else lhs
        actual_rhs = rhs.reduce(ops.add, rhs_vars) if rhs_vars else rhs
    actual = actual_lhs * actual_rhs

    lhs_subs = {v: gensym(v) for v in lhs_vars}
    rhs_subs = {v: gensym(v) for v in rhs_vars}
    expected = (lhs(**lhs_subs) * rhs(**rhs_subs)).reduce(
        ops.add, frozenset(lhs_subs.values()) | frozenset(rhs_subs.values()))

    assert_close(actual, expected)

def test_binary(symbol, data1, data2):
    dtype = 'real'
    if symbol in BOOLEAN_OPS:
        dtype = 2
        data1 = bool(data1)
        data2 = bool(data2)
    try:
        expected_data = binary_eval(symbol, data1, data2)
    except ZeroDivisionError:
        return

    x1 = Number(data1, dtype)
    x2 = Number(data2, dtype)
    actual = binary_eval(symbol, x1, x2)
    check_funsor(actual, {}, Domain((), dtype), expected_data)

    with interpretation(normalize):
        actual_reflect = binary_eval(symbol, x1, x2)
    assert actual.output == actual_reflect.output

def test_advanced_indexing_lazy(output_shape):
    x = Tensor(randn((2, 3, 4) + output_shape), OrderedDict([
        ('i', bint(2)),
        ('j', bint(3)),
        ('k', bint(4)),
    ]))
    u = Variable('u', bint(2))
    v = Variable('v', bint(3))
    with interpretation(lazy):
        i = Number(1, 2) - u
        j = Number(2, 3) - v
        k = u + v

    expected_data = empty((2, 3) + output_shape)
    i_data = x.materialize(i).data
    j_data = x.materialize(j).data
    k_data = x.materialize(k).data
    for u in range(2):
        for v in range(3):
            expected_data[u, v] = x.data[i_data[u], j_data[v], k_data[u, v]]
    expected = Tensor(expected_data, OrderedDict([
        ('u', bint(2)),
        ('v', bint(3)),
    ]))

    assert_equiv(expected, x(i, j, k))
    assert_equiv(expected, x(i=i, j=j, k=k))

    assert_equiv(expected, x(i=i, j=j)(k=k))
    assert_equiv(expected, x(j=j, k=k)(i=i))
    assert_equiv(expected, x(k=k, i=i)(j=j))

    assert_equiv(expected, x(i=i)(j=j, k=k))
    assert_equiv(expected, x(j=j)(k=k, i=i))
    assert_equiv(expected, x(k=k)(i=i, j=j))

    assert_equiv(expected, x(i=i)(j=j)(k=k))
    assert_equiv(expected, x(i=i)(k=k)(j=j))
    assert_equiv(expected, x(j=j)(i=i)(k=k))
    assert_equiv(expected, x(j=j)(k=k)(i=i))
    assert_equiv(expected, x(k=k)(i=i)(j=j))
    assert_equiv(expected, x(k=k)(j=j)(i=i))

def test_advanced_indexing_lazy(output_shape):
    x = Array(np.random.normal(size=(2, 3, 4) + output_shape), OrderedDict([
        ('i', bint(2)),
        ('j', bint(3)),
        ('k', bint(4)),
    ]))
    u = Variable('u', bint(2))
    v = Variable('v', bint(3))
    with interpretation(lazy):
        i = Number(1, 2) - u
        j = Number(2, 3) - v
        k = u + v

    expected_data = np.empty((2, 3) + output_shape)
    i_data = funsor.numpy.materialize(i).data.astype(np.int64)
    j_data = funsor.numpy.materialize(j).data.astype(np.int64)
    k_data = funsor.numpy.materialize(k).data.astype(np.int64)
    for u in range(2):
        for v in range(3):
            expected_data[u, v] = x.data[i_data[u], j_data[v], k_data[u, v]]
    expected = Array(expected_data, OrderedDict([
        ('u', bint(2)),
        ('v', bint(3)),
    ]))

    assert_equiv(expected, x(i, j, k))
    assert_equiv(expected, x(i=i, j=j, k=k))

    assert_equiv(expected, x(i=i, j=j)(k=k))
    assert_equiv(expected, x(j=j, k=k)(i=i))
    assert_equiv(expected, x(k=k, i=i)(j=j))

    assert_equiv(expected, x(i=i)(j=j, k=k))
    assert_equiv(expected, x(j=j)(k=k, i=i))
    assert_equiv(expected, x(k=k)(i=i, j=j))

    assert_equiv(expected, x(i=i)(j=j)(k=k))
    assert_equiv(expected, x(i=i)(k=k)(j=j))
    assert_equiv(expected, x(j=j)(i=i)(k=k))
    assert_equiv(expected, x(j=j)(k=k)(i=i))
    assert_equiv(expected, x(k=k)(i=i)(j=j))
    assert_equiv(expected, x(k=k)(j=j)(i=i))

def test_reduce_all(op):
    x = Variable('x', bint(2))
    y = Variable('y', bint(3))
    z = Variable('z', bint(4))
    f = x * y + z
    dtype = f.dtype
    check_funsor(f, {'x': bint(2), 'y': bint(3), 'z': bint(4)}, Domain((), dtype))

    if op is ops.logaddexp:
        pytest.skip()

    with interpretation(sequential):
        actual = f.reduce(op)

    values = [f(x=i, y=j, z=k)
              for i in x.output
              for j in y.output
              for k in z.output]
    expected = reduce(op, values)
    assert actual == expected

def test_einsum_adjoint_unary_marginals(einsum_impl, equation, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    equation = ",".join(inputs) + "->"
    targets = [Variable(k, bint(sizes[k])) for k in set(sizes)]

    with interpretation(reflect):
        fwd_expr = einsum_impl(equation, *funsor_operands, backend=backend)
    actuals = adjoint(fwd_expr, targets)

    for target in targets:
        actual = actuals[target]
        expected = opt_einsum.contract(equation + target.name, *operands,
                                       backend=backend)
        assert isinstance(actual, funsor.Tensor)
        assert expected.shape == actual.data.shape
        assert torch.allclose(expected, actual.data, atol=1e-7)

def __init__(self, initial_logits, transition_logits, observation_dist, validate_args=None):
    assert isinstance(initial_logits, torch.Tensor)
    assert isinstance(transition_logits, torch.Tensor)
    assert isinstance(observation_dist, torch.distributions.Distribution)
    assert initial_logits.dim() >= 1
    assert transition_logits.dim() >= 2
    assert len(observation_dist.batch_shape) >= 1
    shape = broadcast_shape(initial_logits.shape[:-1] + (1,),
                            transition_logits.shape[:-2],
                            observation_dist.batch_shape[:-1])
    batch_shape, time_shape = shape[:-1], shape[-1:]
    event_shape = time_shape + observation_dist.event_shape
    self._has_rsample = observation_dist.has_rsample

    # Normalize.
    initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
    transition_logits = transition_logits - transition_logits.logsumexp(-1, True)

    # Convert tensors and distributions to funsors.
    init = tensor_to_funsor(initial_logits, ("state",))
    trans = tensor_to_funsor(transition_logits, ("time", "state", "state(time=1)"))
    obs = dist_to_funsor(observation_dist, ("time", "state(time=1)"))
    dtype = obs.inputs["value"].dtype

    # Construct the joint funsor.
    with interpretation(lazy):
        # TODO perform math here once sequential_sum_product has been
        # implemented as a first-class funsor.
        funsor_dist = Variable("value", obs.inputs["value"])  # a bogus value
        # Until funsor_dist is defined, we save factors for hand-computation in .log_prob().
        self._init = init
        self._trans = trans
        self._obs = obs

    super(DiscreteHMM, self).__init__(funsor_dist, batch_shape, event_shape,
                                      dtype, validate_args)

def test_match_binary():
    with interpretation(lazy):
        pattern = Variable('a', reals()) + Number(2.) * Variable('b', reals())
        expr = Number(1.) + Number(2.) * (Number(3.) - Number(4.))

    @match_vars(pattern)
    def expand_2_vars(a, b):
        return a + b + b

    @match(pattern)
    def expand_2_walk(x):
        return x.lhs + x.rhs.rhs + x.rhs.rhs

    eager_val = reinterpret(expr)
    lazy_val = expand_2_vars(expr)
    assert eager_val == reinterpret(lazy_val)

    lazy_val_2 = expand_2_walk(expr)
    assert eager_val == reinterpret(lazy_val_2)