def main(args):
    # Declare parameters.
    trans_noise = torch.tensor(0.1, requires_grad=True)
    emit_noise = torch.tensor(0.5, requires_grad=True)
    params = [trans_noise, emit_noise]

    # A Gaussian HMM model.
    def model(data):
        log_prob = funsor.to_funsor(0.)

        x_curr = funsor.Tensor(torch.tensor(0.))
        for t, y in enumerate(data):
            x_prev = x_curr

            # A delayed sample statement.
            x_curr = funsor.Variable('x_{}'.format(t), funsor.reals())
            log_prob += dist.Normal(1 + x_prev / 2., trans_noise, value=x_curr)

            # Optionally marginalize out the previous state.
            if t > 0 and not args.lazy:
                log_prob = log_prob.reduce(ops.logaddexp, x_prev.name)

            # An observe statement.
            log_prob += dist.Normal(0.5 + 3 * x_curr, emit_noise, value=y)

        # Marginalize out all remaining delayed variables.
        log_prob = log_prob.reduce(ops.logaddexp)
        return log_prob

    # Train model parameters.
    torch.manual_seed(0)
    data = torch.randn(args.time_steps)
    optim = torch.optim.Adam(params, lr=args.learning_rate)
    for step in range(args.train_steps):
        optim.zero_grad()
        if args.lazy:
            with interpretation(lazy):
                log_prob = apply_optimizer(model(data))
            log_prob = reinterpret(log_prob)
        else:
            log_prob = model(data)
        assert not log_prob.inputs, 'free variables remain'
        loss = -log_prob.data
        loss.backward()
        optim.step()
        if args.verbose and step % 10 == 0:
            print('step {} loss = {}'.format(step, loss.item()))
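# A minimal sketch of the command-line entry point that would drive main(args)
# above. The flag names and defaults below (--time-steps, --train-steps,
# --learning-rate, --lazy, --verbose) are assumptions inferred from the
# attributes accessed on `args`, not taken verbatim from the original script.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description="Gaussian HMM with delayed sampling")
    parser.add_argument("-t", "--time-steps", default=10, type=int)
    parser.add_argument("-n", "--train-steps", default=101, type=int)
    parser.add_argument("-lr", "--learning-rate", default=0.05, type=float)
    parser.add_argument("--lazy", action="store_true")
    parser.add_argument("-v", "--verbose", action="store_true")
    main(parser.parse_args())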
def test_optimized_einsum(equation, backend, einsum_impl):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = pyro_einsum(equation, *operands, backend=backend)[0]

    with interpretation(normalize):
        naive_ast = einsum_impl(equation, *funsor_operands, backend=backend)
        optimized_ast = apply_optimizer(naive_ast)
    actual = reinterpret(optimized_ast)  # eager by default

    assert isinstance(actual, funsor.Tensor) and len(outputs) == 1
    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
def test_sequential_sum_product_adjoint(impl, sum_op, prod_op, batch_inputs, state_domain, num_steps):
    # test mostly copied from test_sum_product.py
    inputs = OrderedDict(batch_inputs)
    inputs.update(prev=state_domain, curr=state_domain)
    inputs["time"] = bint(num_steps)
    if state_domain.dtype == "real":
        trans = random_gaussian(inputs)
    else:
        trans = random_tensor(inputs)
    time = Variable("time", bint(num_steps))

    with AdjointTape() as actual_tape:
        actual = impl(sum_op, prod_op, trans, time, {"prev": "curr"})

    expected_inputs = batch_inputs.copy()
    expected_inputs.update(prev=state_domain, curr=state_domain)
    assert dict(actual.inputs) == expected_inputs

    # Check against contract.
    operands = tuple(
        trans(time=t, prev="t_{}".format(t), curr="t_{}".format(t + 1))
        for t in range(num_steps)
    )
    reduce_vars = frozenset("t_{}".format(t) for t in range(1, num_steps))
    with AdjointTape() as expected_tape:
        with interpretation(reflect):
            expected = sum_product(sum_op, prod_op, operands, reduce_vars)
        expected = apply_optimizer(expected)
        expected = expected(**{"t_0": "prev", "t_{}".format(num_steps): "curr"})
        expected = expected.align(tuple(actual.inputs.keys()))

    # check forward pass (sanity check)
    assert_close(actual, expected, rtol=5e-4 * num_steps)

    # perform backward passes only after the sanity check
    expected_bwds = expected_tape.adjoint(sum_op, prod_op, expected, operands)
    actual_bwd = actual_tape.adjoint(sum_op, prod_op, actual, (trans,))[trans]

    # check backward pass
    for t, operand in enumerate(operands):
        actual_bwd_t = actual_bwd(time=t, prev="t_{}".format(t), curr="t_{}".format(t + 1))
        expected_bwd = expected_bwds[operand].align(tuple(actual_bwd_t.inputs.keys()))
        check_funsor(actual_bwd_t, expected_bwd.inputs, expected_bwd.output)
        assert_close(actual_bwd_t, expected_bwd, rtol=5e-4 * num_steps)
def test_einsum_categorical(equation):
    if get_backend() == "jax":
        from funsor.jax.distributions import Categorical
    else:
        from funsor.torch.distributions import Categorical

    inputs, outputs, sizes, operands, _ = make_einsum_example(equation)
    operands = [ops.abs(operand) / ops.abs(operand).sum(-1)[..., None] for operand in operands]

    expected = opt_einsum.contract(equation, *operands,
                                   backend=BACKEND_TO_EINSUM_BACKEND[get_backend()])

    with interpretation(reflect):
        funsor_operands = [
            Categorical(probs=Tensor(
                operand,
                inputs=OrderedDict([(d, Bint[sizes[d]]) for d in inp[:-1]])
            ))(value=Variable(inp[-1], Bint[sizes[inp[-1]]])).exp()
            for inp, operand in zip(inputs, operands)
        ]

        naive_ast = naive_einsum(equation, *funsor_operands)
        optimized_ast = apply_optimizer(naive_ast)

    print("Naive expression: {}".format(naive_ast))
    print("Optimized expression: {}".format(optimized_ast))
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_einsum(equation, *map(reinterpret, funsor_operands))

    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    assert_close(actual, actual_optimized, atol=1e-4)
    assert expected.shape == actual.data.shape
    assert_close(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
def parallel_loss_fn(model, guide, parallel=True):
    # We're doing exact inference, so we don't use the guide here.
    factors = model()
    t_term, new_factors = factors[0], factors[1:]
    t = to_funsor("t", t_term.inputs["t"])
    if parallel:
        result = MarkovProduct(ops.logaddexp, ops.add, t_term, t, {"y(t=1)": "y"})
    else:
        result = naive_sequential_sum_product(ops.logaddexp, ops.add, t_term, t, {"y(t=1)": "y"})

    new_factors = [result] + new_factors
    plates = frozenset(['g', 'i'])
    eliminate = frozenset().union(*(f.inputs for f in new_factors))
    with interpretation(lazy):
        loss = sum_product(ops.logaddexp, ops.add, new_factors, eliminate, plates)
    loss = apply_optimizer(loss)
    assert not loss.inputs
    return -loss.data
def test_plated_einsum(equation, plates, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = pyro_einsum(equation, *operands,
                           plates=plates, backend=backend, modulo_total=False)[0]

    with interpretation(reflect):
        naive_ast = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend)
        optimized_ast = apply_optimizer(naive_ast)
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_plated_einsum(equation, *funsor_operands, plates=plates, backend=backend)

    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    assert_close(actual, actual_optimized, atol=1e-3 if backend == 'torch' else 1e-4)

    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
def test_einsum_categorical(equation):
    inputs, outputs, sizes, operands, _ = make_einsum_example(equation)
    operands = [operand.abs() / operand.abs().sum(-1, keepdim=True) for operand in operands]

    expected = opt_einsum.contract(equation, *operands, backend='torch')

    with interpretation(reflect):
        funsor_operands = [
            Categorical(probs=Tensor(
                operand,
                inputs=OrderedDict([(d, bint(sizes[d])) for d in inp[:-1]])
            ))(value=Variable(inp[-1], bint(sizes[inp[-1]]))).exp()
            for inp, operand in zip(inputs, operands)
        ]

        naive_ast = naive_einsum(equation, *funsor_operands)
        optimized_ast = apply_optimizer(naive_ast)

    print("Naive expression: {}".format(naive_ast))
    print("Optimized expression: {}".format(optimized_ast))
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_einsum(equation, *map(reinterpret, funsor_operands))

    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    assert_close(actual, actual_optimized, atol=1e-4)
    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
def test_einsum(equation, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(equation)
    expected = opt_einsum.contract(equation, *operands, backend=backend)

    with interpretation(reflect):
        naive_ast = naive_einsum(equation, *funsor_operands, backend=backend)
        optimized_ast = apply_optimizer(naive_ast)

    print("Naive expression: {}".format(naive_ast))
    print("Optimized expression: {}".format(optimized_ast))
    actual_optimized = reinterpret(optimized_ast)  # eager by default
    actual = naive_einsum(equation, *funsor_operands, backend=backend)

    assert isinstance(actual, funsor.Tensor) and len(outputs) == 1
    if len(outputs[0]) > 0:
        actual = actual.align(tuple(outputs[0]))
        actual_optimized = actual_optimized.align(tuple(outputs[0]))

    assert_close(actual, actual_optimized, atol=1e-4)
    assert expected.shape == actual.data.shape
    assert torch.allclose(expected, actual.data)
    for output in outputs:
        for i, output_dim in enumerate(output):
            assert output_dim in actual.inputs
            assert actual.inputs[output_dim].dtype == sizes[output_dim]
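# The tests above take `equation` (and sometimes `backend`, `plates`, or an
# `einsum_impl`) as arguments, so they are presumably driven by
# pytest.mark.parametrize. The sketch below shows how test_einsum might be
# parametrized; the example equations and backend names are illustrative
# assumptions, not copied from the original test module.
import pytest

EXAMPLE_EQUATIONS = [
    "a,b->",
    "ab,a->",
    "a,a->a",
    "ab,bc->ac",
]


@pytest.mark.parametrize("equation", EXAMPLE_EQUATIONS)
@pytest.mark.parametrize("backend", ["torch", "pyro.ops.einsum.torch_log"])
def test_einsum_parametrized(equation, backend):
    # Delegates to the test body defined above.
    test_einsum(equation, backend)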
def einsum(eqn, *terms, **kwargs):
    with interpretation(reflect):
        naive_ast = naive_plated_einsum(eqn, *terms, **kwargs)
        optimized_ast = apply_optimizer(naive_ast)
    return reinterpret(optimized_ast)  # eager by default
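# A hypothetical usage sketch of the wrapper above: build the expression
# lazily, optimize the contraction order, then evaluate eagerly. The operands
# are funsor Tensors whose named inputs correspond to the letters of the
# equation; the dimension sizes, import paths, and backend argument here are
# illustrative assumptions rather than part of the original module.
import torch
from collections import OrderedDict

import funsor
from funsor.domains import Bint
from funsor.tensor import Tensor

funsor.set_backend("torch")

x = Tensor(torch.randn(3, 4), OrderedDict(a=Bint[3], b=Bint[4]))
y = Tensor(torch.randn(4, 5), OrderedDict(b=Bint[4], c=Bint[5]))
xy = einsum("ab,bc->ac", x, y, backend="torch")
assert set(xy.inputs) == {"a", "c"}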