def test_nested_einsum(eqn1, eqn2, optimize1, optimize2, backend1, backend2, einsum_impl):
    inputs1, outputs1, sizes1, operands1, _ = make_einsum_example(eqn1, sizes=(3,))
    inputs2, outputs2, sizes2, operands2, funsor_operands2 = make_einsum_example(eqn2, sizes=(3,))

    # normalize the probs for ground-truth comparison
    operands1 = [operand.abs() / operand.abs().sum(-1, keepdim=True)
                 for operand in operands1]

    expected1 = pyro_einsum(eqn1, *operands1, backend=backend1, modulo_total=True)[0]
    expected2 = pyro_einsum(outputs1[0] + "," + eqn2, *([expected1] + operands2),
                            backend=backend2, modulo_total=True)[0]

    with interpretation(normalize):
        funsor_operands1 = [
            Categorical(probs=Tensor(
                operand,
                inputs=OrderedDict([(d, bint(sizes1[d])) for d in inp[:-1]])
            ))(value=Variable(inp[-1], bint(sizes1[inp[-1]]))).exp()
            for inp, operand in zip(inputs1, operands1)
        ]

        output1_naive = einsum_impl(eqn1, *funsor_operands1, backend=backend1)
        with interpretation(reflect):
            output1 = apply_optimizer(output1_naive) if optimize1 else output1_naive

        output2_naive = einsum_impl(outputs1[0] + "," + eqn2,
                                    *([output1] + funsor_operands2), backend=backend2)
        with interpretation(reflect):
            output2 = apply_optimizer(output2_naive) if optimize2 else output2_naive

    actual1 = reinterpret(output1)
    actual2 = reinterpret(output2)

    assert torch.allclose(expected1, actual1.data)
    assert torch.allclose(expected2, actual2.data)
def test_normalize_einsum(equation, plates, backend, einsum_impl):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)

    with interpretation(reflect):
        expr = einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)

    with interpretation(normalize):
        transformed_expr = reinterpret(expr)

    assert isinstance(transformed_expr, Contraction)
    check_funsor(transformed_expr, expr.inputs, expr.output)

    assert all(isinstance(v, (Number, Tensor, Contraction))
               for v in transformed_expr.terms)

    with interpretation(normalize):
        transformed_expr2 = reinterpret(transformed_expr)

    assert transformed_expr2 is transformed_expr  # check normalization

    with interpretation(eager):
        actual = reinterpret(transformed_expr)
        expected = reinterpret(expr)

    assert_close(actual, expected, rtol=1e-4)

    actual = eval(quote(expected))  # requires torch, bint
    assert_close(actual, expected)
def apply_optimizer(x):

    @interpreter.interpretation(interpreter._INTERPRETATION)
    def nested_optimize_interpreter(cls, *args):
        result = optimize.dispatch(cls, *args)(*args)
        if result is None:
            result = cls(*args)
        return result

    with interpreter.interpretation(unfold):
        expr = interpreter.reinterpret(x)

    with interpreter.interpretation(nested_optimize_interpreter):
        return interpreter.reinterpret(expr)
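# A minimal usage sketch for apply_optimizer (not part of the source; assumes
# the helpers used by the surrounding tests: random_tensor, bint, reals,
# interpretation, reflect, reinterpret, and ops). Build a term lazily,
# optimize it, then evaluate it under the default eager interpretation:
def _example_apply_optimizer():
    x = random_tensor(OrderedDict([('i', bint(3)), ('j', bint(2))]), reals())
    y = random_tensor(OrderedDict([('j', bint(2)), ('k', bint(4))]), reals())
    with interpretation(reflect):
        lazy_expr = (x * y).reduce(ops.add, frozenset({'j'}))
    optimized_ast = apply_optimizer(lazy_expr)
    return reinterpret(optimized_ast)  # a concrete funsor Tensor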
def test_einsum_complete_sharing(equation, plates, backend, einsum_impl, same_lazy):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)

    with interpretation(reflect):
        lazy_expr1 = einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)
        lazy_expr2 = lazy_expr1 if same_lazy else \
            einsum_impl(equation, *funsor_operands, backend=backend, plates=plates)

    with memoize():
        expr1 = reinterpret(lazy_expr1)
        expr2 = reinterpret(lazy_expr2)
    expr3 = reinterpret(lazy_expr1)  # outside memoize, so no sharing with expr1

    assert expr1 is expr2
    assert expr1 is not expr3
def apply_optimizer(x):
    with interpretation(associate):
        x = reinterpret(x)
    with interpretation(distribute):
        x = reinterpret(x)
    with interpretation(optimize):
        x = reinterpret(x)
    with interpretation(desugar):
        x = reinterpret(x)
    return reinterpret(x)  # use previous interpretation
def naive_contract_einsum(eqn, *terms, **kwargs):
    """
    Use for testing Contract against einsum
    """
    assert "plates" not in kwargs

    backend = kwargs.pop('backend', 'torch')
    if backend == 'torch':
        sum_op, prod_op = ops.add, ops.mul
    elif backend in ('pyro.ops.einsum.torch_log', 'pyro.ops.einsum.torch_marginal'):
        sum_op, prod_op = ops.logaddexp, ops.add
    else:
        raise ValueError("{} backend not implemented".format(backend))

    assert isinstance(eqn, str)
    assert all(isinstance(term, Funsor) for term in terms)
    inputs, output = eqn.split('->')
    inputs = inputs.split(',')
    assert len(inputs) == len(terms)
    assert len(output.split(',')) == 1
    input_dims = frozenset(d for inp in inputs for d in inp)
    output_dims = frozenset(d for d in output)
    reduced_vars = input_dims - output_dims

    with interpretation(optimize):
        rhs = Finitary(prod_op, tuple(terms))
        lhs = _make_base_lhs(prod_op, rhs, reduced_vars, normalized=False)
        assert frozenset(lhs.inputs) == reduced_vars
        result = Contract(sum_op, prod_op, lhs, rhs, reduced_vars)

    return reinterpret(result)
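# A hypothetical usage sketch for naive_contract_einsum (not part of the
# source; operand names are invented, helpers as in the surrounding tests).
# Contract out the shared 'b' dimension of two funsor tensors:
#
#     x = random_tensor(OrderedDict([('a', bint(2)), ('b', bint(3))]), reals())
#     y = random_tensor(OrderedDict([('b', bint(3)), ('c', bint(4))]), reals())
#     z = naive_contract_einsum('ab,bc->ac', x, y, backend='torch')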
def test_einsum(equation, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = opt_einsum.contract(equation, *operands, backend=backend)

    with interpretation(reflect):
        naive_ast = naive_einsum(equation, *funsor_operands, backend=backend)
        optimized_ast = apply_optimizer(naive_ast)

    print("Naive expression: {}".format(naive_ast))
    print("Optimized expression: {}".format(optimized_ast))

    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_einsum(equation, *funsor_operands, backend=backend)

    assert isinstance(actual, funsor.Tensor) and len(outputs) == 1
    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    assert_close(actual, actual_optimized, atol=1e-4)
    assert expected.shape == actual.data.shape
    assert_close(expected, actual.data, rtol=1e-5, atol=1e-8)

    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
def test_plated_einsum(equation, plates, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = pyro_einsum(equation, *operands, plates=plates, backend=backend,
                           modulo_total=False)[0]

    with interpretation(reflect):
        naive_ast = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend)
        optimized_ast = apply_optimizer(naive_ast)

    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend)

    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    assert_close(actual, actual_optimized, atol=1e-3 if backend == 'torch' else 1e-4)

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)

    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
def _check_mvn_affine(d1, data):
    assert isinstance(d1, dist.MultivariateNormal)
    d2 = reinterpret(d1)
    assert issubclass(type(d2), GaussianMixture)
    actual = d2(**data)
    expected = d1(**data)
    assert_close(actual, expected)
def _check_mvn_affine(d1, data):
    backend_module = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
    assert isinstance(d1, backend_module.MultivariateNormal)
    d2 = reinterpret(d1)
    assert issubclass(type(d2), GaussianMixture)
    actual = d2(**data)
    expected = d1(**data)
    assert_close(actual, expected)
def test_nested_complete_sharing_direct():
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example("ab,bc,cd->d")
    ab, bc, cd = funsor_operands

    # avoids the complicated internal interpreter usage of the nested optimized einsum tests above
    with interpretation(reflect):
        c1 = (ab * bc).reduce(ops.add, frozenset({"a", "b"}))
        d1 = (c1 * cd).reduce(ops.add, frozenset({"c"}))

        # this does not trigger a second alpha-renaming
        c2 = (ab * bc).reduce(ops.add, frozenset({"a", "b"}))
        d2 = (c2 * cd).reduce(ops.add, frozenset({"c"}))

    with memoize():
        assert reinterpret(c1) is reinterpret(c2)
        assert reinterpret(d1) is reinterpret(d2)
def test_match_binary():
    with interpretation(lazy):
        pattern = Variable('a', reals()) + Number(2.) * Variable('b', reals())
        expr = Number(1.) + Number(2.) * (Number(3.) - Number(4.))

    @match_vars(pattern)
    def expand_2_vars(a, b):
        return a + b + b

    @match(pattern)
    def expand_2_walk(x):
        return x.lhs + x.rhs.rhs + x.rhs.rhs

    eager_val = reinterpret(expr)
    lazy_val = expand_2_vars(expr)
    assert eager_val == reinterpret(lazy_val)

    lazy_val_2 = expand_2_walk(expr)
    assert eager_val == reinterpret(lazy_val_2)
def main(args):
    funsor.set_backend("torch")

    # Declare parameters.
    trans_probs = torch.tensor([[0.2, 0.8], [0.7, 0.3]], requires_grad=True)
    emit_probs = torch.tensor([[0.4, 0.6], [0.1, 0.9]], requires_grad=True)
    params = [trans_probs, emit_probs]

    # A discrete HMM model.
    def model(data):
        log_prob = funsor.to_funsor(0.)

        trans = dist.Categorical(probs=funsor.Tensor(
            trans_probs,
            inputs=OrderedDict([('prev', funsor.Bint[args.hidden_dim])]),
        ))

        emit = dist.Categorical(probs=funsor.Tensor(
            emit_probs,
            inputs=OrderedDict([('latent', funsor.Bint[args.hidden_dim])]),
        ))

        x_curr = funsor.Number(0, args.hidden_dim)
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t), funsor.Bint[args.hidden_dim])
            log_prob += trans(prev=x_prev, value=x_curr)

            if not args.lazy and isinstance(x_prev, funsor.Variable):
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            log_prob += emit(latent=x_curr, value=funsor.Tensor(y, dtype=2))

        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob

    # Train model parameters.
    data = torch.ones(args.time_steps, dtype=torch.long)
    optim = torch.optim.Adam(params, lr=args.learning_rate)
    for step in range(args.train_steps):
        optim.zero_grad()
        if args.lazy:
            with interpretation(lazy):
                log_prob = apply_optimizer(model(data))
            log_prob = reinterpret(log_prob)
        else:
            log_prob = model(data)
        assert not log_prob.inputs, 'free variables remain'
        loss = -log_prob.data
        loss.backward()
        optim.step()
def main(args):
    funsor.set_backend("torch")

    # Declare parameters.
    trans_noise = torch.tensor(0.1, requires_grad=True)
    emit_noise = torch.tensor(0.5, requires_grad=True)
    params = [trans_noise, emit_noise]

    # A Gaussian HMM model.
    def model(data):
        log_prob = funsor.to_funsor(0.)

        x_curr = funsor.Tensor(torch.tensor(0.))
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t), funsor.Real)
            log_prob += dist.Normal(1 + x_prev / 2., trans_noise, value=x_curr)

            # Optionally marginalize out the previous state.
            if t > 0 and not args.lazy:
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            # An observe statement.
            log_prob += dist.Normal(0.5 + 3 * x_curr, emit_noise, value=y)

        # Marginalize out all remaining delayed variables.
        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob

    # Train model parameters.
    torch.manual_seed(0)
    data = torch.randn(args.time_steps)
    optim = torch.optim.Adam(params, lr=args.learning_rate)
    for step in range(args.train_steps):
        optim.zero_grad()
        if args.lazy:
            with interpretation(lazy):
                log_prob = apply_optimizer(model(data))
            log_prob = reinterpret(log_prob)
        else:
            log_prob = model(data)
        assert not log_prob.inputs, 'free variables remain'
        loss = -log_prob.data
        loss.backward()
        optim.step()
        if args.verbose and step % 10 == 0:
            print('step {} loss = {}'.format(step, loss.item()))
def substitute(expr, subs):
    if isinstance(subs, (dict, OrderedDict)):
        subs = tuple(subs.items())
    assert isinstance(subs, tuple)

    @interpreter.interpretation(interpreter._INTERPRETATION)  # use base
    def subs_interpreter(cls, *args):
        expr = cls(*args)
        fresh_subs = tuple((k, v) for k, v in subs if k in expr.fresh)
        if fresh_subs:
            expr = interpreter.debug_logged(expr.eager_subs)(fresh_subs)
        return expr

    with interpreter.interpretation(subs_interpreter):
        return interpreter.reinterpret(expr)
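# A minimal usage sketch for substitute (not part of the source; assumes
# Variable, Number, and reals from the surrounding codebase). Substitutions
# are keyed by free-variable name:
def _example_substitute():
    x = Variable('x', reals())
    expr = x * x + Number(1.)
    # should evaluate to Number(5.) under the default eager base interpretation
    return substitute(expr, {'x': Number(2.)})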
def test_cat_slice_tensor(start, stop, step):
    terms = tuple(
        random_tensor(OrderedDict(t=bint(t), a=bint(2)))
        for t in [2, 1, 3, 4, 1, 3])
    dtype = sum(term.inputs['t'].dtype for term in terms)
    sub = Slice('t', start, stop, step, dtype)

    # eager
    expected = Cat('t', terms)(t=sub)

    # lazy - exercise Cat.eager_subs
    with interpretation(lazy):
        actual = Cat('t', terms)(t=sub)
    actual = reinterpret(actual)

    assert_close(actual, expected)
def test_optimized_einsum(equation, backend, einsum_impl):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = pyro_einsum(equation, *operands, backend=backend)[0]

    with interpretation(normalize):
        naive_ast = einsum_impl(equation, *funsor_operands, backend=backend)
    optimized_ast = apply_optimizer(naive_ast)
    actual = reinterpret(optimized_ast)  # eager by default

    assert isinstance(actual, funsor.Tensor) and len(outputs) == 1
    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)

    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
def test_distribute_reduce(lhs_vars, rhs_vars):
    lhs_vars, rhs_vars = frozenset(lhs_vars), frozenset(rhs_vars)
    lhs = random_tensor(OrderedDict([('i', bint(3)), ('j', bint(2))]), reals())
    rhs = random_tensor(OrderedDict([('i', bint(3)), ('j', bint(2))]), reals())

    with interpretation(reflect):
        actual_lhs = lhs.reduce(ops.add, lhs_vars) if lhs_vars else lhs
        actual_rhs = rhs.reduce(ops.add, rhs_vars) if rhs_vars else rhs

    actual = reinterpret(actual_lhs * actual_rhs)

    lhs_subs = {v: gensym(v) for v in lhs_vars}
    rhs_subs = {v: gensym(v) for v in rhs_vars}
    expected = (lhs(**lhs_subs) * rhs(**rhs_subs)).reduce(
        ops.add, frozenset(lhs_subs.values()) | frozenset(rhs_subs.values()))

    assert_close(actual, expected)
def test_einsum_categorical(equation):
    if get_backend() == "jax":
        from funsor.jax.distributions import Categorical
    else:
        from funsor.torch.distributions import Categorical

    inputs, outputs, sizes, operands, _ = make_einsum_example(equation)
    operands = [ops.abs(operand) / ops.abs(operand).sum(-1)[..., None]
                for operand in operands]

    expected = opt_einsum.contract(equation, *operands,
                                   backend=BACKEND_TO_EINSUM_BACKEND[get_backend()])

    with interpretation(reflect):
        funsor_operands = [
            Categorical(probs=Tensor(
                operand,
                inputs=OrderedDict([(d, Bint[sizes[d]]) for d in inp[:-1]])
            ))(value=Variable(inp[-1], Bint[sizes[inp[-1]]])).exp()
            for inp, operand in zip(inputs, operands)
        ]

        naive_ast = naive_einsum(equation, *funsor_operands)
        optimized_ast = apply_optimizer(naive_ast)

    print("Naive expression: {}".format(naive_ast))
    print("Optimized expression: {}".format(optimized_ast))

    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_einsum(equation, *map(reinterpret, funsor_operands))

    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    assert_close(actual, actual_optimized, atol=1e-4)

    assert expected.shape == actual.data.shape
    assert_close(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
def adjoint(expr, targets, start=Number(0.)):
    adjoint_values = defaultdict(lambda: Number(0.))  # 1 in logspace
    multiplicities = defaultdict(lambda: 0)

    tape_recorder = AdjointTape()
    with interpretation(tape_recorder):
        adjoint_values[reinterpret(expr)] = start

    while tape_recorder.tape:
        output, fn, inputs = tape_recorder.tape.pop()
        in_adjs = adjoint_ops(fn, adjoint_values[output], output, *inputs)
        for v, adjv in in_adjs.items():
            multiplicities[v] += 1
            adjoint_values[v] = adjoint_values[v] + adjv  # product in logspace

    target_adjs = {}
    for v in targets:
        target_adjs[v] = adjoint_values[v] / multiplicities[v]
        if not isinstance(v, Variable):
            target_adjs[v] = target_adjs[v] + v

    return target_adjs
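# A hedged sketch of how adjoint is typically driven (not part of the source;
# the factor names f and g are invented). Given a lazily built log-sum-product
# expression, the returned dict maps each target to its adjoint funsor:
#
#     with interpretation(reflect):
#         logZ = (f + g).reduce(ops.logaddexp, frozenset({'x'}))
#     marginals = adjoint(logZ, [f, g])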
def test_einsum_categorical(equation):
    inputs, outputs, sizes, operands, _ = make_einsum_example(equation)
    operands = [operand.abs() / operand.abs().sum(-1, keepdim=True)
                for operand in operands]

    expected = opt_einsum.contract(equation, *operands, backend='torch')

    with interpretation(reflect):
        funsor_operands = [
            Categorical(probs=Tensor(
                operand,
                inputs=OrderedDict([(d, bint(sizes[d])) for d in inp[:-1]])
            ))(value=Variable(inp[-1], bint(sizes[inp[-1]]))).exp()
            for inp, operand in zip(inputs, operands)
        ]

        naive_ast = naive_einsum(equation, *funsor_operands)
        optimized_ast = apply_optimizer(naive_ast)

    print("Naive expression: {}".format(naive_ast))
    print("Optimized expression: {}".format(optimized_ast))

    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_einsum(equation, *map(reinterpret, funsor_operands))

    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    assert_close(actual, actual_optimized, atol=1e-4)

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
def einsum(eqn, *terms, **kwargs):
    with interpretation(reflect):
        naive_ast = naive_plated_einsum(eqn, *terms, **kwargs)
        optimized_ast = apply_optimizer(naive_ast)
    return reinterpret(optimized_ast)  # eager by default
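# A usage sketch for einsum (not part of the source; operands as produced by
# make_einsum_example in the surrounding tests):
#
#     inputs, outputs, sizes, operands, funsor_operands = make_einsum_example('ab,bc->c')
#     result = einsum('ab,bc->c', *funsor_operands, backend='torch')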