# Method of a funsor-based ELBO class (note the use of self.num_particles and
# self.max_plate_nesting); assumes contextlib, funsor, and Pyro's funsor-backend
# handlers (plate, enum, trace, replay) and helpers (terms_from_trace, to_data)
# are in scope, as in pyro.contrib.funsor.
def differentiable_loss(self, model, guide, *args, **kwargs):
        with plate(size=self.num_particles) if self.num_particles > 1 else contextlib.ExitStack(), \
                enum(first_available_dim=(-self.max_plate_nesting-1) if self.max_plate_nesting else None):
            guide_tr = trace(guide).get_trace(*args, **kwargs)
            model_tr = trace(replay(model, trace=guide_tr)).get_trace(
                *args, **kwargs)

        model_terms = terms_from_trace(model_tr)
        guide_terms = terms_from_trace(guide_tr)

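        # ELBO integrand: model log-densities enter with a positive sign and
        # guide log-densities with a negative sign, i.e. log p - log q.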
        log_measures = guide_terms["log_measures"] + model_terms["log_measures"]
        log_factors = model_terms["log_factors"] + [
            -f for f in guide_terms["log_factors"]
        ]
        plate_vars = model_terms["plate_vars"] | guide_terms["plate_vars"]
        measure_vars = model_terms["measure_vars"] | guide_terms["measure_vars"]

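        # Contract out all variables lazily: enumerated variables are
        # eliminated with logaddexp and plate dimensions with add, leaving a
        # single scalar ELBO funsor.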
        with funsor.interpreter.interpretation(funsor.terms.lazy):
            elbo = funsor.sum_product.sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                log_measures + log_factors,
                eliminate=measure_vars | plate_vars,
                plates=plate_vars,
            )

        return -to_data(funsor.optimizer.apply_optimizer(elbo))
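
A minimal sketch of how a differentiable loss like this is typically consumed, assuming an ELBO object exposing the method above together with an ordinary PyTorch optimizer; the training-step helper and its names are illustrative, not part of the original code:

def train_step(elbo, model, guide, optimizer, *args, **kwargs):
    # differentiable_loss returns the negative ELBO as a torch scalar,
    # so it can be backpropagated directly.
    optimizer.zero_grad()
    loss = elbo.differentiable_loss(model, guide, *args, **kwargs)
    loss.backward()
    optimizer.step()
    return loss.item()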
Example #2
# Draw a joint sample from the posterior over enumerated discrete latents:
# temperature=1 gives posterior sampling, temperature=0 gives MAP. Assumes the
# same funsor-backend handlers plus Trace, block, site_is_subsample, and
# _get_support_value are in scope.
def _sample_posterior(model, first_available_dim, temperature, *args,
                      **kwargs):

    if temperature == 0:
        sum_op, prod_op = funsor.ops.max, funsor.ops.add
        approx = funsor.approximations.argmax_approximate
    elif temperature == 1:
        sum_op, prod_op = funsor.ops.logaddexp, funsor.ops.add
        approx = funsor.montecarlo.MonteCarlo()
    else:
        raise ValueError("temperature must be 0 (map) or 1 (sample) for now")

    with block(), enum(first_available_dim=first_available_dim):
        # XXX replay against an empty Trace to ensure densities are not double-counted
        model_tr = trace(replay(model,
                                trace=Trace())).get_trace(*args, **kwargs)

    terms = terms_from_trace(model_tr)
    # terms["log_factors"] = [log p(x) for each observed or latent sample site x]
    # terms["log_measures"] = [log p(z) or other Dice factor
    #                          for each latent sample site z]

    with funsor.interpretations.lazy:
        log_prob = funsor.sum_product.sum_product(
            sum_op,
            prod_op,
            terms["log_factors"] + terms["log_measures"],
            eliminate=terms["measure_vars"] | terms["plate_vars"],
            plates=terms["plate_vars"],
        )
        log_prob = funsor.optimizer.apply_optimizer(log_prob)

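    # Run funsor's adjoint algorithm over the lazy sum-product expression;
    # under the chosen approximation this yields an approximate posterior
    # factor for each eliminated sample site.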
    with approx:
        approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)

    # construct a result trace to replay against the model
    sample_tr = model_tr.copy()
    sample_subs = {}
    for name, node in sample_tr.nodes.items():
        if node["type"] != "sample" or site_is_subsample(node):
            continue
        if node["is_observed"]:
            # "observed" values may be collapsed samples that depend on enumerated
            # values, so we have to slice them down
            # TODO this should really be handled entirely under the hood by adjoint
            node["funsor"] = {"value": node["funsor"]["value"](**sample_subs)}
        else:
            node["funsor"]["log_measure"] = approx_factors[node["funsor"]
                                                           ["log_measure"]]
            node["funsor"]["value"] = _get_support_value(
                node["funsor"]["log_measure"], name)
            sample_subs[name] = node["funsor"]["value"]

    with replay(trace=sample_tr):
        return model(*args, **kwargs)
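
For context, a sampler like this is usually exposed through a small wrapper in the style of Pyro's infer_discrete; a hedged sketch of such a wrapper, assuming _sample_posterior as defined above:

import functools

def infer_discrete(model, first_available_dim=None, temperature=1):
    # Bind the posterior sampler to a model; the returned callable runs the
    # model and returns its value with discrete sites filled in by posterior
    # samples (temperature=1) or MAP values (temperature=0).
    return functools.partial(_sample_posterior, model, first_available_dim,
                             temperature)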
Example #3
# Check that per-timestep log-prob factors and markov step structure of a
# sequential trace agree with the corresponding vectorized (history-sliced)
# trace. Assumes pyroapi-style shims (pyro_backend, handlers, infer) and test
# helpers (_guide_from_model, assert_close, terms_from_trace) are in scope.
def test_enumeration_multi(model, weeks_data, days_data, vars1, vars2, history,
                           use_replay):
    pyro.clear_param_store()

    with pyro_backend("contrib.funsor"):
        with handlers.enum():
            enum_model = infer.config_enumerate(model, default="parallel")
            # sequential factors
            trace = handlers.trace(enum_model).get_trace(
                weeks_data, days_data, history, False)

            # vectorized trace
            if use_replay:
                guide_trace = handlers.trace(
                    _guide_from_model(model)).get_trace(
                        weeks_data, days_data, history, True)
                vectorized_trace = handlers.trace(
                    handlers.replay(model, trace=guide_trace)).get_trace(
                        weeks_data, days_data, history, True)
            else:
                vectorized_trace = handlers.trace(enum_model).get_trace(
                    weeks_data, days_data, history, True)

        factors = list()
        # sequential weeks factors
        for i in range(len(weeks_data)):
            for v in vars1:
                factors.append(trace.nodes["{}_{}".format(
                    v, i)]["funsor"]["log_prob"])
        # sequential days factors
        for j in range(len(days_data)):
            for v in vars2:
                factors.append(trace.nodes["{}_{}".format(
                    v, j)]["funsor"]["log_prob"])

        vectorized_factors = list()
        # vectorized weeks factors
        for i in range(history):
            for v in vars1:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, i)]["funsor"]["log_prob"])
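        # Timesteps past the initial history window are stored once as a
        # sliced site; substitute the vectorized time index and rename the
        # sliced markov variables to recover each per-step factor (days are
        # handled the same way below).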
        for i in range(history, len(weeks_data)):
            for v in vars1:
                sliced = "{}_{}".format(v, slice(history, len(weeks_data)))
                subs = {"weeks": i - history}
                subs.update({
                    "{}_{}".format(k, slice(history - j, len(weeks_data) - j)):
                    "{}_{}".format(k, i - j)
                    for j in range(history + 1) for k in vars1
                })
                vectorized_factors.append(
                    vectorized_trace.nodes[sliced]["funsor"]["log_prob"](**subs))
        # vectorized days factors
        for i in range(history):
            for v in vars2:
                vectorized_factors.append(
                    vectorized_trace.nodes["{}_{}".format(
                        v, i)]["funsor"]["log_prob"])
        for i in range(history, len(days_data)):
            for v in vars2:
                sliced = "{}_{}".format(v, slice(history, len(days_data)))
                subs = {"days": i - history}
                subs.update({
                    "{}_{}".format(k, slice(history - j, len(days_data) - j)):
                    "{}_{}".format(k, i - j)
                    for j in range(history + 1) for k in vars2
                })
                vectorized_factors.append(
                    vectorized_trace.nodes[sliced]["funsor"]["log_prob"](**subs))

        # assert correct factors
        for f1, f2 in zip(factors, vectorized_factors):
            assert_close(f2, f1.align(tuple(f2.inputs)))

        # assert correct step: a markov "step" is the set of tuples of site
        # names linking adjacent time slices in the vectorized trace

        expected_measure_vars = frozenset()
        actual_weeks_step = vectorized_trace.nodes["weeks"]["value"]
        # expected step: assume all vars except the last are markov
        expected_weeks_step = frozenset()
        for v in vars1[:-1]:
            v_step = tuple("{}_{}".format(v, i) for i in range(history)) \
                     + tuple("{}_{}".format(v, slice(j, len(weeks_data)-history+j)) for j in range(history+1))
            expected_weeks_step |= frozenset({v_step})
            # grab measure_vars, found only at sites that are not replayed
            if not use_replay:
                expected_measure_vars |= frozenset(v_step)

        actual_days_step = vectorized_trace.nodes["days"]["value"]
        # expected step: assume all vars except the last are markov
        expected_days_step = frozenset()
        for v in vars2[:-1]:
            v_step = tuple("{}_{}".format(v, i) for i in range(history)) \
                     + tuple("{}_{}".format(v, slice(j, len(days_data)-history+j)) for j in range(history+1))
            expected_days_step |= frozenset({v_step})
            # grab measure_vars, found only at sites that are not replayed
            if not use_replay:
                expected_measure_vars |= frozenset(v_step)

        assert actual_weeks_step == expected_weeks_step
        assert actual_days_step == expected_days_step

        # check measure_vars
        actual_measure_vars = terms_from_trace(
            vectorized_trace)["measure_vars"]
        assert actual_measure_vars == expected_measure_vars
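
The fixtures above (model, weeks_data, days_data, vars1, vars2) are supplied by the surrounding test suite; a hypothetical parametrization showing how such a test would typically be driven, with the model and site names purely illustrative:

import pytest
import torch

@pytest.mark.parametrize("use_replay", [False, True])
@pytest.mark.parametrize("history", [1, 2])
def test_enumeration_multi_example(history, use_replay):
    # my_markov_model, ("w",), and ("d",) stand in for a real markov model
    # and its weekly/daily site-name groups.
    test_enumeration_multi(my_markov_model, torch.ones(5), torch.ones(10),
                           vars1=("w",), vars2=("d",),
                           history=history, use_replay=use_replay)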