def differentiable_loss(self, model, guide, *args, **kwargs):
    """Return a differentiable estimate of the negative ELBO.

    The guide is traced first, then the model is replayed against that
    trace.  Both traces are converted to funsor terms and contracted in a
    single lazy sum-product expression, which is optimized before evaluation.

    :param model: the model callable.
    :param guide: the guide (variational distribution) callable.
    :returns: the negated ELBO converted to backend data via ``to_data``.
    """
    # Optionally vectorize over particles; enumeration dims are allocated
    # just below the plate dims.
    if self.num_particles > 1:
        particle_ctx = plate(size=self.num_particles)
    else:
        particle_ctx = contextlib.ExitStack()
    first_dim = (-self.max_plate_nesting - 1) if self.max_plate_nesting else None
    with particle_ctx, enum(first_available_dim=first_dim):
        guide_trace = trace(guide).get_trace(*args, **kwargs)
        model_trace = trace(replay(model, trace=guide_trace)).get_trace(*args, **kwargs)

    model_terms = terms_from_trace(model_trace)
    guide_terms = terms_from_trace(guide_trace)

    # Dice-style factors: measures from both traces; model factors enter
    # positively (logp), guide factors negatively (-logq).
    log_measures = guide_terms["log_measures"] + model_terms["log_measures"]
    log_factors = list(model_terms["log_factors"])
    for guide_factor in guide_terms["log_factors"]:
        log_factors.append(-guide_factor)
    plate_vars = model_terms["plate_vars"] | guide_terms["plate_vars"]
    measure_vars = model_terms["measure_vars"] | guide_terms["measure_vars"]

    # Build the ELBO lazily so the optimizer can choose a contraction order.
    with funsor.interpreter.interpretation(funsor.terms.lazy):
        elbo = funsor.sum_product.sum_product(
            funsor.ops.logaddexp,
            funsor.ops.add,
            log_measures + log_factors,
            eliminate=measure_vars | plate_vars,
            plates=plate_vars,
        )
    return -to_data(funsor.optimizer.apply_optimizer(elbo))
def differentiable_loss(self, model, guide, *args, **kwargs):
    """Return a differentiable -ELBO estimate using flat guide enumeration.

    The guide is sampled under ``config_enumerate(default="flat")`` so its
    sites carry Dice factors; the model is replayed against the guide trace.
    The ELBO is then an Integrate of the summed costs against the summed
    measures, reduced over plates.

    :param model: the model callable.
    :param guide: the guide callable.
    :returns: the negated ELBO converted to backend data via ``to_data``.
    """
    particle_ctx = (
        plate(size=self.num_particles)
        if self.num_particles > 1
        else contextlib.ExitStack()
    )
    with enum(), particle_ctx:
        flat_guide = config_enumerate(default="flat")(guide)
        guide_tr = trace(flat_guide).get_trace(*args, **kwargs)
        model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs)

    model_terms = terms_from_trace(model_tr)
    guide_terms = terms_from_trace(guide_tr)

    # Measures come from both traces; costs are logp - logq.
    log_measures = guide_terms["log_measures"] + model_terms["log_measures"]
    negated_guide_factors = [-f for f in guide_terms["log_factors"]]
    log_factors = model_terms["log_factors"] + negated_guide_factors
    plate_vars = model_terms["plate_vars"] | guide_terms["plate_vars"]
    measure_vars = model_terms["measure_vars"] | guide_terms["measure_vars"]

    # Accumulate each list into a single funsor (left-associated, like
    # ``sum`` with a funsor start value), then integrate out the latents
    # and finally sum over all plates.
    total_measure = to_funsor(0.0)
    for term in log_measures:
        total_measure = total_measure + term
    total_cost = to_funsor(0.0)
    for term in log_factors:
        total_cost = total_cost + term
    elbo = funsor.Integrate(total_measure, total_cost, measure_vars)
    elbo = elbo.reduce(funsor.ops.add, plate_vars)
    return -to_data(elbo)
def differentiable_loss(self, model, guide, *args, **kwargs):
    """Compute a differentiable surrogate for -ELBO with enumeration support.

    Traces the guide, replays the model against it, converts both traces to
    funsor terms, contracts out enumerated model-side auxiliary variables,
    and computes the Dice-weighted expected cost of each remaining factor.

    :param model: the model callable.
    :param guide: the guide (variational distribution) callable.
    :returns: the negated ELBO as backend data (via ``to_data``), presumably
        suitable for autograd — the surrounding class context is not visible
        here to confirm the backend.
    """
    # get batched, enumerated, to_funsor-ed traces from the guide and model
    with plate(size=self.num_particles) if self.num_particles > 1 else contextlib.ExitStack(), \
            enum(first_available_dim=(-self.max_plate_nesting-1) if self.max_plate_nesting else None):
        guide_tr = trace(guide).get_trace(*args, **kwargs)
        model_tr = trace(replay(model, trace=guide_tr)).get_trace(
            *args, **kwargs)

    # extract from traces all metadata that we will need to compute the elbo
    guide_terms = terms_from_trace(guide_tr)
    model_terms = terms_from_trace(model_tr)

    # build up a lazy expression for the elbo
    with funsor.interpreter.interpretation(funsor.terms.lazy):
        # identify and contract out auxiliary variables in the model with partial_sum_product
        contracted_factors, uncontracted_factors = [], []
        for f in model_terms["log_factors"]:
            # a factor is "contracted" iff it mentions an enumerated variable
            if model_terms["measure_vars"].intersection(f.inputs):
                contracted_factors.append(f)
            else:
                uncontracted_factors.append(f)
        # incorporate the effects of subsampling and handlers.scale through a common scale factor
        contracted_costs = [
            model_terms["scale"] * f
            for f in funsor.sum_product.partial_sum_product(
                funsor.ops.logaddexp, funsor.ops.add,
                model_terms["log_measures"] + contracted_factors,
                plates=model_terms["plate_vars"],
                eliminate=model_terms["measure_vars"])
        ]

        costs = contracted_costs + uncontracted_factors  # model costs: logp
        costs += [-f for f in guide_terms["log_factors"]]  # guide costs: -logq

        # finally, integrate out guide variables in the elbo and all plates
        plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]
        elbo = to_funsor(0, output=funsor.Real)
        for cost in costs:
            # compute the marginal logq in the guide corresponding to this cost term
            log_prob = funsor.sum_product.sum_product(
                funsor.ops.logaddexp, funsor.ops.add,
                guide_terms["log_measures"],
                plates=plate_vars,
                eliminate=(plate_vars | guide_terms["measure_vars"]) - frozenset(cost.inputs))
            # compute the expected cost term E_q[logp] or E_q[-logq] using the marginal logq for q
            elbo_term = funsor.Integrate(
                log_prob, cost,
                guide_terms["measure_vars"] & frozenset(cost.inputs))
            # plates shared with this cost term are summed out of the ELBO
            elbo += elbo_term.reduce(funsor.ops.add,
                                     plate_vars & frozenset(cost.inputs))

    # evaluate the elbo, using memoize to share tensor computation where possible
    with funsor.memoize.memoize():
        return -to_data(funsor.optimizer.apply_optimizer(elbo))
def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs):
    """Draw a posterior sample (or MAP estimate) of the model's enumerated latents.

    Traces the model under enumeration, builds a lazy log-probability
    contraction, computes per-site posterior factors via a funsor adjoint
    (backward) pass, substitutes the resulting support values into a copy of
    the trace, and replays the model against it.

    :param model: the model callable to sample from.
    :param first_available_dim: leftmost tensor dim available for enumeration.
    :param temperature: 0 for MAP (max/add semiring with argmax
        approximation), 1 for sampling (logaddexp/add semiring with a
        Monte Carlo approximation).
    :raises ValueError: if temperature is neither 0 nor 1.
    :returns: the return value of ``model(*args, **kwargs)`` replayed
        against the sampled trace.
    """
    if temperature == 0:
        # MAP: max-plus semiring; adjoint picks argmax values.
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        # Sampling: log-sum-exp semiring; adjoint draws Monte Carlo samples.
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        approx = funsor.montecarlo.MonteCarlo()
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    with block(), enum(first_available_dim=first_available_dim):
        # XXX replay against an empty Trace to ensure densities are not double-counted
        model_tr = trace(replay(model, trace=Trace())).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_tr)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]

    # Build the full contraction lazily, then optimize before evaluation.
    with funsor.interpretations.lazy:
        log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            terms["log_factors"] + terms["log_measures"],
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(log_prob)

    # Backward pass under the chosen approximation yields a factor per site.
    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.nodes.items():
        if node["type"] != "sample" or site_is_subsample(node):
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            node["funsor"] = {"value": node["funsor"]["value"](**sample_subs)}
        else:
            # Replace the enumerated measure by its approximate posterior
            # factor and read off the chosen support value for this site.
            node["funsor"]["log_measure"] = approx_factors[node["funsor"]
                                                           ["log_measure"]]
            node["funsor"]["value"] = _get_support_value(
                node["funsor"]["log_measure"], name)
            sample_subs[name] = node["funsor"]["value"]

    with replay(trace=sample_tr):
        return model(*args, **kwargs)
def differentiable_loss(self, model, guide, *args, **kwargs):
    """Compute a differentiable surrogate for -ELBO with enumeration support.

    Like the classic TraceEnum implementation, but computes each cost term's
    marginal ``logq`` via a single tape-recorded sum-product followed by an
    adjoint pass against per-shape zero "target" tensors, instead of one
    sum-product per cost term.

    :param model: the model callable.
    :param guide: the guide (variational distribution) callable.
    :returns: the negated ELBO as backend data (via ``to_data``).
    """
    # get batched, enumerated, to_funsor-ed traces from the guide and model
    with plate(
        size=self.num_particles
    ) if self.num_particles > 1 else contextlib.ExitStack(), enum(
        first_available_dim=(-self.max_plate_nesting - 1)
        if self.max_plate_nesting else None):
        guide_tr = trace(guide).get_trace(*args, **kwargs)
        model_tr = trace(replay(model, trace=guide_tr)).get_trace(
            *args, **kwargs)

    # extract from traces all metadata that we will need to compute the elbo
    guide_terms = terms_from_trace(guide_tr)
    model_terms = terms_from_trace(model_tr)

    # build up a lazy expression for the elbo
    with funsor.terms.lazy:
        # identify and contract out auxiliary variables in the model with partial_sum_product
        contracted_factors, uncontracted_factors = [], []
        for f in model_terms["log_factors"]:
            # a factor is "contracted" iff it mentions an enumerated variable
            if model_terms["measure_vars"].intersection(f.inputs):
                contracted_factors.append(f)
            else:
                uncontracted_factors.append(f)
        # incorporate the effects of subsampling and handlers.scale through a common scale factor
        contracted_costs = [
            model_terms["scale"] * f
            for f in funsor.sum_product.partial_sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                model_terms["log_measures"] + contracted_factors,
                plates=model_terms["plate_vars"],
                eliminate=model_terms["measure_vars"],
            )
        ]

        # accumulate costs from model (logp) and guide (-logq)
        costs = contracted_costs + uncontracted_factors  # model costs: logp
        costs += [-f for f in guide_terms["log_factors"]]  # guide costs: -logq

        # compute expected cost
        # Cf. pyro.infer.util.Dice.compute_expectation()
        # https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/util.py#L212
        # TODO Replace this with funsor.Expectation
        plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]
        # compute the marginal logq in the guide corresponding to each cost term
        # one zero-valued "target" funsor per distinct cost-input signature;
        # its adjoint below is the guide marginal over exactly those inputs
        targets = dict()
        for cost in costs:
            input_vars = frozenset(cost.inputs)
            if input_vars not in targets:
                targets[input_vars] = funsor.Tensor(
                    funsor.ops.new_zeros(
                        funsor.tensor.get_default_prototype(),
                        tuple(v.size for v in cost.inputs.values()),
                    ),
                    cost.inputs,
                    cost.dtype,
                )
        with AdjointTape() as tape:
            logzq = funsor.sum_product.sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                guide_terms["log_measures"] + list(targets.values()),
                plates=plate_vars,
                eliminate=(plate_vars | guide_terms["measure_vars"]),
            )
        marginals = tape.adjoint(funsor.ops.logaddexp, funsor.ops.add,
                                 logzq, tuple(targets.values()))

        # finally, integrate out guide variables in the elbo and all plates
        elbo = to_funsor(0, output=funsor.Real)
        for cost in costs:
            target = targets[frozenset(cost.inputs)]
            # normalize the marginal locally over this cost's non-plate inputs
            logzq_local = marginals[target].reduce(
                funsor.ops.logaddexp, frozenset(cost.inputs) - plate_vars)
            log_prob = marginals[target] - logzq_local
            elbo_term = funsor.Integrate(
                log_prob,
                cost,
                guide_terms["measure_vars"] & frozenset(log_prob.inputs),
            )
            elbo += elbo_term.reduce(funsor.ops.add,
                                     plate_vars & frozenset(cost.inputs))

    # evaluate the elbo, using memoize to share tensor computation where possible
    with funsor.interpretations.memoize():
        return -to_data(apply_optimizer(elbo))