def differentiable_loss(self, model, guide, *args, **kwargs):
    # run the guide, then replay the model against it, with enumeration enabled
    with plate(size=self.num_particles) if self.num_particles > 1 else contextlib.ExitStack(), \
            enum(first_available_dim=(-self.max_plate_nesting - 1) if self.max_plate_nesting else None):
        guide_tr = trace(guide).get_trace(*args, **kwargs)
        model_tr = trace(replay(model, trace=guide_tr)).get_trace(*args, **kwargs)

    # extract from both traces the funsor terms needed to compute the ELBO
    model_terms = terms_from_trace(model_tr)
    guide_terms = terms_from_trace(guide_tr)

    # build a lazy expression for the ELBO: model factors minus guide factors,
    # weighted by the combined log measures
    log_measures = guide_terms["log_measures"] + model_terms["log_measures"]
    log_factors = model_terms["log_factors"] + [
        -f for f in guide_terms["log_factors"]
    ]
    plate_vars = model_terms["plate_vars"] | guide_terms["plate_vars"]
    measure_vars = model_terms["measure_vars"] | guide_terms["measure_vars"]
    with funsor.interpreter.interpretation(funsor.terms.lazy):
        elbo = funsor.sum_product.sum_product(
            funsor.ops.logaddexp, funsor.ops.add,
            log_measures + log_factors,
            eliminate=measure_vars | plate_vars,
            plates=plate_vars)

    # evaluate the optimized expression and return the negative ELBO as the loss
    return -to_data(funsor.optimizer.apply_optimizer(elbo))
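# A minimal usage sketch, not part of this module: exercising the loss above through
# the funsor-based TraceEnum_ELBO exposed by pyro.contrib.funsor.infer. The toy
# model/guide, parameter values, and site names below are illustrative assumptions.
import torch
import pyro
import pyro.distributions as dist
from pyro.contrib.funsor.infer import TraceEnum_ELBO


def toy_model(data):
    probs = pyro.param("probs", torch.tensor([0.3, 0.7]))
    locs = pyro.param("locs", torch.tensor([-1.0, 1.0]))
    # discrete latent site marked for parallel enumeration, so it is summed out
    # exactly by the sum_product call in differentiable_loss
    z = pyro.sample("z", dist.Categorical(probs), infer={"enumerate": "parallel"})
    pyro.sample("x", dist.Normal(locs[z], 1.0), obs=data)


def toy_guide(data):
    pass  # nothing to sample: "z" is enumerated and "x" is observed


elbo = TraceEnum_ELBO(max_plate_nesting=0)
loss = elbo.differentiable_loss(toy_model, toy_guide, torch.tensor(0.2))
loss.backward()  # gradients flow to the pyro.param sites through the negative ELBO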
def _sample_posterior(model, first_available_dim, temperature, *args, **kwargs):
    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        approx = funsor.montecarlo.MonteCarlo()
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    with block(), enum(first_available_dim=first_available_dim):
        # XXX replay against an empty Trace to ensure densities are not double-counted
        model_tr = trace(replay(model, trace=Trace())).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_tr)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]

    with funsor.interpretations.lazy:
        log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            terms["log_factors"] + terms["log_measures"],
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(log_prob)

    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.nodes.items():
        if node["type"] != "sample" or site_is_subsample(node):
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            node["funsor"] = {"value": node["funsor"]["value"](**sample_subs)}
        else:
            node["funsor"]["log_measure"] = approx_factors[node["funsor"]["log_measure"]]
            node["funsor"]["value"] = _get_support_value(
                node["funsor"]["log_measure"], name)
            sample_subs[name] = node["funsor"]["value"]

    with replay(trace=sample_tr):
        return model(*args, **kwargs)
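# A minimal usage sketch, not part of this module: calling _sample_posterior directly
# on a toy model with a single enumerated discrete site. In Pyro this code path is
# normally reached through the funsor-based infer_discrete wrapper; the model, data,
# and site names below are illustrative assumptions.
import torch
import pyro
import pyro.distributions as dist


def toy_model(data):
    probs = torch.tensor([0.4, 0.6])
    locs = torch.tensor([-1.0, 1.0])
    z = pyro.sample("z", dist.Categorical(probs), infer={"enumerate": "parallel"})
    pyro.sample("x", dist.Normal(locs[z], 1.0), obs=data)
    return z


# temperature=0 returns the MAP assignment of "z" given x; temperature=1 draws a
# sample from the exact posterior over "z"
z_map = _sample_posterior(toy_model, -1, 0, torch.tensor(0.9))
z_sample = _sample_posterior(toy_model, -1, 1, torch.tensor(0.9))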
def test_enumeration_multi(model, weeks_data, days_data, vars1, vars2, history, use_replay):
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        with handlers.enum():
            enum_model = infer.config_enumerate(model, default="parallel")
            # sequential factors
            trace = handlers.trace(enum_model).get_trace(
                weeks_data, days_data, history, False)

            # vectorized trace
            if use_replay:
                guide_trace = handlers.trace(_guide_from_model(model)).get_trace(
                    weeks_data, days_data, history, True)
                vectorized_trace = handlers.trace(
                    handlers.replay(model, trace=guide_trace)).get_trace(
                        weeks_data, days_data, history, True)
            else:
                vectorized_trace = handlers.trace(enum_model).get_trace(
                    weeks_data, days_data, history, True)

        factors = list()
        # sequential weeks factors
        for i in range(len(weeks_data)):
            for v in vars1:
                factors.append(trace.nodes["{}_{}".format(v, i)]["funsor"]["log_prob"])
        # sequential days factors
        for j in range(len(days_data)):
            for v in vars2:
                factors.append(trace.nodes["{}_{}".format(v, j)]["funsor"]["log_prob"])

        vectorized_factors = list()
        # vectorized weeks factors
        for i in range(history):
            for v in vars1:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(v, i)]["funsor"]["log_prob"])
        for i in range(history, len(weeks_data)):
            for v in vars1:
                vectorized_factors.append(
                    vectorized_trace.nodes[
                        "{}_{}".format(v, slice(history, len(weeks_data)))
                    ]["funsor"]["log_prob"](
                        **{"weeks": i - history},
                        **{
                            "{}_{}".format(k, slice(history - j, len(weeks_data) - j)):
                                "{}_{}".format(k, i - j)
                            for j in range(history + 1)
                            for k in vars1
                        },
                    ))
        # vectorized days factors
        for i in range(history):
            for v in vars2:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(v, i)]["funsor"]["log_prob"])
        for i in range(history, len(days_data)):
            for v in vars2:
                vectorized_factors.append(
                    vectorized_trace.nodes[
                        "{}_{}".format(v, slice(history, len(days_data)))
                    ]["funsor"]["log_prob"](
                        **{"days": i - history},
                        **{
                            "{}_{}".format(k, slice(history - j, len(days_data) - j)):
                                "{}_{}".format(k, i - j)
                            for j in range(history + 1)
                            for k in vars2
                        },
                    ))

        # assert correct factors
        for f1, f2 in zip(factors, vectorized_factors):
            assert_close(f2, f1.align(tuple(f2.inputs)))

        # assert correct step
        expected_measure_vars = frozenset()
        actual_weeks_step = vectorized_trace.nodes["weeks"]["value"]
        # expected step: assume that all but the last var is markov
        expected_weeks_step = frozenset()
        for v in vars1[:-1]:
            v_step = tuple("{}_{}".format(v, i) for i in range(history)) \
                + tuple("{}_{}".format(v, slice(j, len(weeks_data) - history + j))
                        for j in range(history + 1))
            expected_weeks_step |= frozenset({v_step})
            # grab measure_vars, found only at sites that are not replayed
            if not use_replay:
                expected_measure_vars |= frozenset(v_step)

        actual_days_step = vectorized_trace.nodes["days"]["value"]
        # expected step: assume that all but the last var is markov
        expected_days_step = frozenset()
        for v in vars2[:-1]:
            v_step = tuple("{}_{}".format(v, i) for i in range(history)) \
                + tuple("{}_{}".format(v, slice(j, len(days_data) - history + j))
                        for j in range(history + 1))
            expected_days_step |= frozenset({v_step})
            # grab measure_vars, found only at sites that are not replayed
            if not use_replay:
                expected_measure_vars |= frozenset(v_step)

        assert actual_weeks_step == expected_weeks_step
        assert actual_days_step == expected_days_step

        # check measure_vars
        actual_measure_vars = terms_from_trace(vectorized_trace)["measure_vars"]
        assert actual_measure_vars == expected_measure_vars
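# Illustrative note, not part of the test: in the vectorized trace, each Markov
# variable beyond the first `history` steps lives at a single site whose name embeds
# a Python slice (e.g. "w_slice(1, 5, None)" for history=1 and five weeks), and whose
# funsor log_prob carries the vectorized "weeks" input. The substitutions above pick
# out one step and rename the slice-named inputs so each factor can be compared with
# its sequential counterpart. A self-contained sketch of that renaming, using a
# stand-in factor and made-up names:
from collections import OrderedDict

import torch
import funsor

funsor.set_backend("torch")

inputs = OrderedDict([
    ("weeks", funsor.Bint[4]),                 # vectorized time dimension
    ("w_slice(1, 5, None)", funsor.Bint[2]),   # slice-named Markov variable
])
f = funsor.Tensor(torch.randn(4, 2), inputs)   # stand-in for one vectorized log_prob

# substitute a concrete index for "weeks" and rename the slice-named input, as the
# test does with f(**{"weeks": i - history}, **{...})
g = f(**{"weeks": 2, "w_slice(1, 5, None)": "w_3"})
assert set(g.inputs) == {"w_3"}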