# Imports assumed by the three test excerpts below (old funsor API with
# bint/interpretation); the exact module paths are an assumption, and
# BACKEND_ADJOINT_OPS plus the pytest fixtures (impl, einsum_impl, sum_op,
# prod_op, ...) are defined elsewhere in the test module:
from collections import OrderedDict

import opt_einsum
import torch

import funsor
from funsor.adjoint import AdjointTape
from funsor.domains import bint
from funsor.interpreter import interpretation
from funsor.optimizer import apply_optimizer
from funsor.sum_product import sum_product
from funsor.terms import Variable, reflect
from funsor.testing import (assert_close, check_funsor, make_einsum_example,
                            random_gaussian, random_tensor)
from pyro.ops.contract import einsum as pyro_einsum
from pyro.ops.einsum.adjoint import require_backward as pyro_require_backward


def test_sequential_sum_product_adjoint(impl, sum_op, prod_op, batch_inputs,
                                        state_domain, num_steps):
    # test mostly copied from test_sum_product.py
    inputs = OrderedDict(batch_inputs)
    inputs.update(prev=state_domain, curr=state_domain)
    inputs["time"] = bint(num_steps)
    if state_domain.dtype == "real":
        trans = random_gaussian(inputs)
    else:
        trans = random_tensor(inputs)
    time = Variable("time", bint(num_steps))

    with AdjointTape() as actual_tape:
        actual = impl(sum_op, prod_op, trans, time, {"prev": "curr"})

    expected_inputs = batch_inputs.copy()
    expected_inputs.update(prev=state_domain, curr=state_domain)
    assert dict(actual.inputs) == expected_inputs

    # Check against contract.
    operands = tuple(
        trans(time=t, prev="t_{}".format(t), curr="t_{}".format(t + 1))
        for t in range(num_steps))
    reduce_vars = frozenset("t_{}".format(t) for t in range(1, num_steps))
    with AdjointTape() as expected_tape:
        with interpretation(reflect):
            expected = sum_product(sum_op, prod_op, operands, reduce_vars)
        expected = apply_optimizer(expected)
        expected = expected(**{"t_0": "prev",
                               "t_{}".format(num_steps): "curr"})
        expected = expected.align(tuple(actual.inputs.keys()))

    # check forward pass (sanity check)
    assert_close(actual, expected, rtol=5e-4 * num_steps)

    # perform backward passes only after the sanity check
    expected_bwds = expected_tape.adjoint(sum_op, prod_op, expected, operands)
    actual_bwd = actual_tape.adjoint(sum_op, prod_op, actual, (trans,))[trans]

    # check backward pass
    for t, operand in enumerate(operands):
        actual_bwd_t = actual_bwd(time=t, prev="t_{}".format(t),
                                  curr="t_{}".format(t + 1))
        expected_bwd = expected_bwds[operand].align(
            tuple(actual_bwd_t.inputs.keys()))
        check_funsor(actual_bwd_t, expected_bwd.inputs, expected_bwd.output)
        assert_close(actual_bwd_t, expected_bwd, rtol=5e-4 * num_steps)
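
# A standalone sketch (not part of the original tests) of the AdjointTape
# contract exercised above: record a log-space contraction forward, then read
# one adjoint funsor per requested target off the tape. The factors and names
# below are hypothetical, and funsor.set_backend assumes funsor >= 0.3.
def _sketch_adjoint_tape_contract():
    funsor.set_backend("torch")
    f = funsor.Tensor(torch.randn(2, 3))["prev", "mid"]  # factor on (prev, mid)
    g = funsor.Tensor(torch.randn(3, 2))["mid", "curr"]  # factor on (mid, curr)

    with AdjointTape() as tape:  # record the forward pass
        logz = (f + g).reduce(funsor.ops.logaddexp)  # eliminate all variables

    # backward pass: a dict with one adjoint funsor per requested target;
    # bwd[f] acts as an unnormalized log-marginal over f's own variables
    bwd = tape.adjoint(funsor.ops.logaddexp, funsor.ops.add, logz, (f, g))
    assert set(bwd[f].inputs) <= {"prev", "mid"}
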
def test_plated_einsum_adjoint(einsum_impl, equation, plates, backend):
    inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(
        equation)
    sum_op, prod_op = BACKEND_ADJOINT_OPS[backend]

    with AdjointTape() as tape:  # interpretation(reflect):
        fwd_expr = einsum_impl(equation, *funsor_operands,
                               plates=plates, backend=backend)
    actuals = tape.adjoint(sum_op, prod_op, fwd_expr, funsor_operands)

    # reference: pyro's einsum adjoint on the raw tensors
    for operand in operands:
        pyro_require_backward(operand)
    expected_out = pyro_einsum(equation, *operands,
                               modulo_total=False, plates=plates,
                               backend=backend)[0]
    expected_out._pyro_backward()

    for inp, tv, fv in zip(inputs, operands, funsor_operands):
        actual = actuals[fv]
        expected = tv._pyro_backward_result
        if inp:
            actual = actual.align(tuple(inp))
        assert isinstance(actual, funsor.Tensor)
        assert expected.shape == actual.data.shape
        assert torch.allclose(expected, actual.data, atol=1e-7)
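
# A sketch (not from the test suite) of the plated contraction the test above
# drives through einsum_impl: variables in `plates` are product-reduced while
# the rest are sum-reduced. Factor shapes and names here are made up.
def _sketch_plated_adjoint():
    funsor.set_backend("torch")
    x = funsor.Tensor(torch.randn(2, 3))["i", "a"]  # "i" is a plate, "a" a sum var
    y = funsor.Tensor(torch.randn(3))["a"]

    with AdjointTape() as tape:
        z = sum_product(
            funsor.ops.logaddexp, funsor.ops.add, [x, y],
            eliminate=frozenset({"i", "a"}), plates=frozenset({"i"}),
        )
    # per-operand adjoints, analogous to pyro's einsum backward results above
    bwd = tape.adjoint(funsor.ops.logaddexp, funsor.ops.add, z, (x, y))
    assert set(bwd[x].inputs) <= {"i", "a"}
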
def test_einsum_adjoint_unary_marginals(einsum_impl, equation, backend):
    sum_op, prod_op = BACKEND_ADJOINT_OPS[backend]
    with AdjointTape() as tape:  # interpretation(reflect):
        inputs, outputs, sizes, operands, funsor_operands = make_einsum_example(
            equation)
        equation = ",".join(inputs) + "->"
        targets = [Variable(k, bint(sizes[k])) for k in set(sizes)]
        fwd_expr = einsum_impl(equation, *funsor_operands, backend=backend)
    actuals = tape.adjoint(sum_op, prod_op, fwd_expr, targets)

    for target in targets:
        actual = actuals[target]
        expected = opt_einsum.contract(equation + target.name, *operands,
                                       backend=backend)
        assert isinstance(actual, funsor.Tensor)
        assert expected.shape == actual.data.shape
        assert torch.allclose(expected, actual.data, atol=1e-7)
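
# The reference computation above, in isolation: contracting everything out
# ("ab,b->") yields the partition function, while appending a variable name
# ("ab,b->a") yields that variable's unnormalized marginal; the adjoint pass
# is asserted to reproduce exactly this. Standalone sketch, made-up operands.
def _sketch_opt_einsum_marginal():
    x = torch.randn(2, 3).exp()
    y = torch.randn(3).exp()
    total = opt_einsum.contract("ab,b->", x, y, backend="torch")
    marginal_a = opt_einsum.contract("ab,b->a", x, y, backend="torch")
    assert torch.allclose(marginal_a.sum(), total)
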
# Method of an ELBO class in pyro.contrib.funsor (note `self`), shown here
# standalone; it relies on its module's imports: contextlib, funsor, the
# funsor-based handlers (plate, enum, trace, replay), terms_from_trace, and
# to_funsor/to_data.
def differentiable_loss(self, model, guide, *args, **kwargs):
    # get batched, enumerated, to_funsor-ed traces from the guide and model
    with plate(
        size=self.num_particles
    ) if self.num_particles > 1 else contextlib.ExitStack(), enum(
        first_available_dim=(-self.max_plate_nesting - 1)
        if self.max_plate_nesting else None
    ):
        guide_tr = trace(guide).get_trace(*args, **kwargs)
        model_tr = trace(replay(model, trace=guide_tr)).get_trace(
            *args, **kwargs)

    # extract from traces all metadata that we will need to compute the elbo
    guide_terms = terms_from_trace(guide_tr)
    model_terms = terms_from_trace(model_tr)

    # build up a lazy expression for the elbo
    with funsor.terms.lazy:
        # identify and contract out auxiliary variables in the model
        # with partial_sum_product
        contracted_factors, uncontracted_factors = [], []
        for f in model_terms["log_factors"]:
            if model_terms["measure_vars"].intersection(f.inputs):
                contracted_factors.append(f)
            else:
                uncontracted_factors.append(f)
        # incorporate the effects of subsampling and handlers.scale
        # through a common scale factor
        contracted_costs = [
            model_terms["scale"] * f
            for f in funsor.sum_product.partial_sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                model_terms["log_measures"] + contracted_factors,
                plates=model_terms["plate_vars"],
                eliminate=model_terms["measure_vars"],
            )
        ]

        # accumulate costs from model (logp) and guide (-logq)
        costs = contracted_costs + uncontracted_factors  # model costs: logp
        costs += [-f for f in guide_terms["log_factors"]]  # guide costs: -logq

        # compute expected cost
        # Cf. pyro.infer.util.Dice.compute_expectation()
        # https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/util.py#L212
        # TODO Replace this with funsor.Expectation
        plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]
        # compute the marginal logq in the guide corresponding to each cost term
        targets = dict()
        for cost in costs:
            input_vars = frozenset(cost.inputs)
            if input_vars not in targets:
                targets[input_vars] = funsor.Tensor(
                    funsor.ops.new_zeros(
                        funsor.tensor.get_default_prototype(),
                        tuple(v.size for v in cost.inputs.values()),
                    ),
                    cost.inputs,
                    cost.dtype,
                )
        with AdjointTape() as tape:
            logzq = funsor.sum_product.sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                guide_terms["log_measures"] + list(targets.values()),
                plates=plate_vars,
                eliminate=(plate_vars | guide_terms["measure_vars"]),
            )
        marginals = tape.adjoint(funsor.ops.logaddexp, funsor.ops.add,
                                 logzq, tuple(targets.values()))

        # finally, integrate out guide variables in the elbo and all plates
        elbo = to_funsor(0, output=funsor.Real)
        for cost in costs:
            target = targets[frozenset(cost.inputs)]
            logzq_local = marginals[target].reduce(
                funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars)
            log_prob = marginals[target] - logzq_local
            elbo_term = funsor.Integrate(
                log_prob,
                cost,
                guide_terms["measure_vars"] & frozenset(log_prob.inputs),
            )
            elbo += elbo_term.reduce(funsor.ops.add,
                                     plate_vars & frozenset(cost.inputs))

    # evaluate the elbo, using memoize to share tensor computation where possible
    with funsor.interpretations.memoize():
        return -to_data(apply_optimizer(elbo))
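
# A standalone sketch (hypothetical tensors) of the Dice-style expected cost
# computed above: funsor.Integrate(log_q, cost, reduced_vars) evaluates
# E_{q(z)}[cost(z)] = sum_z exp(log_q(z)) * cost(z); in differentiable_loss
# the log_q terms come from the adjoint of the guide's log-partition function.
def _sketch_dice_expectation():
    import torch
    import funsor

    funsor.set_backend("torch")
    logits = torch.randn(3)
    log_q = funsor.Tensor(logits - logits.logsumexp(-1))["z"]  # normalized log q(z)
    cost = funsor.Tensor(torch.randn(3))["z"]  # a per-state cost term

    expected_cost = funsor.Integrate(log_q, cost, frozenset(["z"]))
    reference = (logits.softmax(-1) * cost.data).sum()
    assert torch.allclose(funsor.to_data(expected_cost), reference)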