示例#1
0
    def differentiable_loss(self, model, guide, *args, **kwargs):
        """Return the negated ELBO of ``model`` and ``guide`` as a backend value.

        The guide is traced with ``config_enumerate(default="flat")`` applied,
        the model is traced while replayed against that guide trace, and the
        terms extracted from both traces are combined in a single
        ``funsor.Integrate`` that is then sum-reduced over all plate variables.

        :param model: callable traced while replayed against the guide trace.
        :param guide: callable traced first, wrapped by ``config_enumerate``.
        :param args: positional arguments forwarded to ``get_trace`` for both.
        :param kwargs: keyword arguments forwarded to ``get_trace`` for both.
        :return: ``-to_data(elbo)``.
        """
        with enum():
            # a particle plate is only opened when several particles are requested;
            # otherwise an empty ExitStack stands in as a no-op context manager
            with (plate(size=self.num_particles)
                  if self.num_particles > 1 else contextlib.ExitStack()):
                enumerated_guide = config_enumerate(default="flat")(guide)
                guide_tr = trace(enumerated_guide).get_trace(*args, **kwargs)
                replayed_model = replay(model, trace=guide_tr)
                model_tr = trace(replayed_model).get_trace(*args, **kwargs)

        p_terms = terms_from_trace(model_tr)
        q_terms = terms_from_trace(guide_tr)

        # measures: guide first, then model; factors: log p followed by -log q
        all_measures = [*q_terms["log_measures"], *p_terms["log_measures"]]
        all_factors = list(p_terms["log_factors"])
        all_factors.extend(-factor for factor in q_terms["log_factors"])
        product_vars = p_terms["plate_vars"] | q_terms["plate_vars"]
        sum_vars = p_terms["measure_vars"] | q_terms["measure_vars"]

        total_log_measure = sum(all_measures, to_funsor(0.0))
        total_log_factor = sum(all_factors, to_funsor(0.0))
        integrated = funsor.Integrate(total_log_measure, total_log_factor,
                                      sum_vars)
        reduced = integrated.reduce(funsor.ops.add, product_vars)

        return -to_data(reduced)
示例#2
0
    def differentiable_loss(self, model, guide, *args, **kwargs):
        """Return the negated ELBO of ``model`` and ``guide`` as a backend value.

        Steps, in order:

        1. trace the guide, then the model replayed against the guide trace,
           inside an optional particle plate and an ``enum`` handler;
        2. extract the ELBO terms of both traces with ``terms_from_trace``;
        3. under the lazy funsor interpretation, contract the model factors
           that mention model measure variables, gather the model and guide
           costs, and accumulate one integrated term per cost;
        4. optimize and evaluate the lazy expression under ``memoize``.

        :param model: callable traced while replayed against the guide trace.
        :param guide: callable traced first.
        :param args: positional arguments forwarded to ``get_trace`` for both.
        :param kwargs: keyword arguments forwarded to ``get_trace`` for both.
        :return: ``-to_data(funsor.optimizer.apply_optimizer(elbo))``.
        """
        # batched, enumerated traces of the guide and of the replayed model
        with (plate(size=self.num_particles)
              if self.num_particles > 1 else contextlib.ExitStack()):
            # NOTE(review): the truthiness test maps max_plate_nesting == 0 to
            # None as well as an unset value -- confirm this is intended.
            with enum(first_available_dim=(-self.max_plate_nesting - 1)
                      if self.max_plate_nesting else None):
                guide_tr = trace(guide).get_trace(*args, **kwargs)
                replayed_model = replay(model, trace=guide_tr)
                model_tr = trace(replayed_model).get_trace(*args, **kwargs)

        # metadata needed to assemble the elbo
        q_terms = terms_from_trace(guide_tr)
        p_terms = terms_from_trace(model_tr)

        # assemble a lazy expression for the elbo
        with funsor.interpreter.interpretation(funsor.terms.lazy):
            # split model factors by whether they mention a model measure variable
            to_contract, passthrough = [], []
            for factor in p_terms["log_factors"]:
                if p_terms["measure_vars"].intersection(factor.inputs):
                    bucket = to_contract
                else:
                    bucket = passthrough
                bucket.append(factor)

            # eliminate the model's measure variables from the first group
            contracted = funsor.sum_product.partial_sum_product(
                funsor.ops.logaddexp,
                funsor.ops.add,
                p_terms["log_measures"] + to_contract,
                plates=p_terms["plate_vars"],
                eliminate=p_terms["measure_vars"])

            # model costs (logp): contracted terms carry the common scale factor
            costs = [p_terms["scale"] * term for term in contracted]
            costs.extend(passthrough)
            # guide costs (-logq)
            costs.extend(-factor for factor in q_terms["log_factors"])

            # integrate out guide variables and all plates, one cost at a time
            all_plates = q_terms["plate_vars"] | p_terms["plate_vars"]
            guide_and_plate_vars = all_plates | q_terms["measure_vars"]
            elbo = to_funsor(0, output=funsor.Real)
            for cost in costs:
                cost_vars = frozenset(cost.inputs)
                # marginal logq of the guide over the variables this cost uses
                marginal_log_q = funsor.sum_product.sum_product(
                    funsor.ops.logaddexp,
                    funsor.ops.add,
                    q_terms["log_measures"],
                    plates=all_plates,
                    eliminate=guide_and_plate_vars - cost_vars)
                # expected cost E_q[logp] or E_q[-logq] under that marginal
                expected_cost = funsor.Integrate(
                    marginal_log_q, cost,
                    q_terms["measure_vars"] & cost_vars)
                elbo += expected_cost.reduce(funsor.ops.add,
                                             all_plates & cost_vars)

        # evaluate under memoize so shared tensor computation can be reused
        with funsor.memoize.memoize():
            loss = -to_data(funsor.optimizer.apply_optimizer(elbo))
            return loss
示例#3
0
def terms_from_trace(tr):
    """Collect the ELBO ingredients recorded at the sites of an execution trace.

    ``tr.nodes`` is walked in order. ``"markov_chain"`` nodes contribute their
    ``"value"`` to ``plate_to_step`` and may add measure variables; every
    scored ``"sample"`` node contributes plate variables, optionally a
    log-measure with its measure variables, a scale, and a log-density.

    Side effect: for each scored site that does not set the common scale,
    ``node["funsor"]["log_prob"]`` is replaced in place by its product with
    ``node["funsor"]["scale"]``.

    :param tr: trace object exposing a ``nodes`` mapping of site dictionaries.
    :return: dict with keys ``"log_factors"``, ``"log_measures"``, ``"scale"``,
        ``"plate_vars"``, ``"measure_vars"`` and ``"plate_to_step"``.
    """
    log_factors = []
    log_measures = []
    common_scale = to_funsor(1.0)
    plate_vars = frozenset()  # product (plate) variables
    measure_vars = frozenset()  # sum (measure) variables
    plate_to_step = {}

    for name, node in tr.nodes.items():
        if node["type"] == "markov_chain":
            # register the markov dimension and its steps
            steps = node["value"]
            plate_to_step[node["name"]] = steps
            # interior entries of each step that carry a log-measure are
            # treated as measure variables too
            for step in steps:
                measure_vars |= frozenset(
                    var for var in step[1:-1]
                    if tr.nodes[var]["funsor"].get("log_measure") is not None)

        # only scored sample sites contribute anything below
        if node["type"] != "sample":
            continue
        if type(node["fn"]).__name__ == "_Subsample":
            continue
        if node["infer"].get("_do_not_score", False):
            continue

        site = node["funsor"]

        # plate dimensions come from the vectorized frames of cond_indep_stack
        plate_vars |= frozenset(frame.name
                                for frame in node["cond_indep_stack"]
                                if frame.vectorized)

        # a log-measure is only present at sites that are not replayed or observed
        log_measure = site.get("log_measure")
        if log_measure is not None:
            log_measures.append(log_measure)
            # the fresh non-plate variables at this site are measure variables
            fresh_vars = frozenset(site["value"].inputs) | {name}
            measure_vars |= fresh_vars - plate_vars

        # a replayed site that depends on a measure variable and has a
        # non-unit scale defines the common scale; any other site has its
        # scale folded directly into its log-density
        sets_common_scale = (
            node.get("replay_active", False)
            and set(site["log_prob"].inputs) & measure_vars
            and float(to_data(site["scale"])) != 1.0)
        if sets_common_scale:
            common_scale = site["scale"]
        else:
            site["log_prob"] = site["log_prob"] * site["scale"]

        # log-densities are kept at all sites except those that are not replayed
        if node["is_observed"] or not node.get("replay_skipped", False):
            log_factors.append(site["log_prob"])

    # every plate gets an entry; plates without recorded steps map to {}
    for plate_name in plate_vars:
        plate_to_step.setdefault(plate_name, {})

    return {
        "log_factors": log_factors,
        "log_measures": log_measures,
        "scale": common_scale,
        "plate_vars": plate_vars,
        "measure_vars": measure_vars,
        "plate_to_step": plate_to_step,
    }
示例#4
0
def terms_from_trace(tr):
    """Collect the ELBO ingredients recorded at the sites of an execution trace.

    Every ``"sample"`` node of ``tr.nodes`` whose ``fn`` is not a
    ``_Subsample`` instance contributes plate variables, optionally a
    log-measure with its measure variables, a scale, and a log-density.

    Side effect: for each visited site that does not set the common scale,
    ``node["funsor"]["log_prob"]`` is replaced in place by its product with
    ``node["funsor"]["scale"]``.

    :param tr: trace object exposing a ``nodes`` mapping of site dictionaries.
    :return: dict with keys ``"log_factors"``, ``"log_measures"``, ``"scale"``,
        ``"plate_vars"`` and ``"measure_vars"``.
    """
    densities = []
    measures = []
    shared_scale = to_funsor(1.)
    product_vars = frozenset()  # plate variables
    sum_vars = frozenset()  # measure variables

    for name, node in tr.nodes.items():
        # skip everything that is not an ordinary sample site
        if node["type"] != "sample":
            continue
        if type(node["fn"]).__name__ == "_Subsample":
            continue

        funsor_site = node["funsor"]

        # plate dimensions come from the vectorized frames of cond_indep_stack
        product_vars |= frozenset(frame.name
                                  for frame in node["cond_indep_stack"]
                                  if frame.vectorized)

        # a log-measure is only present at sites that are not replayed or observed
        site_measure = funsor_site.get("log_measure")
        if site_measure is not None:
            measures.append(site_measure)
            # the fresh non-plate variables at this site are measure variables
            fresh_vars = frozenset(funsor_site["value"].inputs) | {name}
            sum_vars |= fresh_vars - product_vars

        # a replayed site that depends on a measure variable and has a
        # non-unit scale defines the common scale; any other site has its
        # scale folded directly into its log-density
        defines_shared_scale = (
            node.get("replay_active", False)
            and set(funsor_site["log_prob"].inputs) & sum_vars
            and float(to_data(funsor_site["scale"])) != 1.)
        if defines_shared_scale:
            shared_scale = funsor_site["scale"]
        else:
            funsor_site["log_prob"] = (funsor_site["log_prob"] *
                                       funsor_site["scale"])

        # log-densities are kept at all sites except those that are not replayed
        if node["is_observed"] or not node.get("replay_skipped", False):
            densities.append(funsor_site["log_prob"])

    return {
        "log_factors": densities,
        "log_measures": measures,
        "scale": shared_scale,
        "plate_vars": product_vars,
        "measure_vars": sum_vars,
    }
示例#5
0
    def differentiable_loss(self, model, guide, *args, **kwargs):
        """Return the negated ELBO of ``model`` and ``guide`` as a backend value.

        The guide is traced first and the model is traced while replayed
        against the guide trace. The terms extracted from both traces are
        turned into a lazy funsor expression: model factors that mention model
        measure variables are contracted with ``partial_sum_product``, the
        guide marginals for each cost term are obtained from a single
        ``sum_product`` recorded on an ``AdjointTape``, and one integrated term
        per cost is accumulated into the ELBO.

        :param model: callable traced while replayed against the guide trace.
        :param guide: callable traced first.
        :param args: positional arguments forwarded to ``get_trace`` for both.
        :param kwargs: keyword arguments forwarded to ``get_trace`` for both.
        :return: ``-to_data(apply_optimizer(elbo))``.
        """

        # get batched, enumerated, to_funsor-ed traces from the guide and model;
        # the particle plate is only opened when num_particles > 1, otherwise an
        # empty ExitStack acts as a no-op context manager.
        # NOTE(review): the truthiness test on max_plate_nesting maps 0 to None
        # as well as an unset value -- confirm this is intended.
        with plate(
                size=self.num_particles
        ) if self.num_particles > 1 else contextlib.ExitStack(), enum(
                first_available_dim=(-self.max_plate_nesting -
                                     1) if self.max_plate_nesting else None):
            guide_tr = trace(guide).get_trace(*args, **kwargs)
            model_tr = trace(replay(model, trace=guide_tr)).get_trace(
                *args, **kwargs)

        # extract from traces all metadata that we will need to compute the elbo
        guide_terms = terms_from_trace(guide_tr)
        model_terms = terms_from_trace(model_tr)

        # build up a lazy expression for the elbo
        with funsor.terms.lazy:
            # identify and contract out auxiliary variables in the model with partial_sum_product:
            # a factor is contracted when its inputs overlap the model's measure variables
            contracted_factors, uncontracted_factors = [], []
            for f in model_terms["log_factors"]:
                if model_terms["measure_vars"].intersection(f.inputs):
                    contracted_factors.append(f)
                else:
                    uncontracted_factors.append(f)
            # incorporate the effects of subsampling and handlers.scale through a common scale factor
            contracted_costs = [
                model_terms["scale"] * f
                for f in funsor.sum_product.partial_sum_product(
                    funsor.ops.logaddexp,
                    funsor.ops.add,
                    model_terms["log_measures"] + contracted_factors,
                    plates=model_terms["plate_vars"],
                    eliminate=model_terms["measure_vars"],
                )
            ]

            # accumulate costs from model (logp) and guide (-logq)
            costs = contracted_costs + uncontracted_factors  # model costs: logp
            costs += [-f for f in guide_terms["log_factors"]
                      ]  # guide costs: -logq

            # compute expected cost
            # Cf. pyro.infer.util.Dice.compute_expectation()
            # https://github.com/pyro-ppl/pyro/blob/0.3.0/pyro/infer/util.py#L212
            # TODO Replace this with funsor.Expectation
            plate_vars = guide_terms["plate_vars"] | model_terms["plate_vars"]
            # compute the marginal logq in the guide corresponding to each cost term;
            # one target tensor (built with funsor.ops.new_zeros) is created per
            # distinct set of cost inputs, so costs sharing inputs share a target
            targets = dict()
            for cost in costs:
                input_vars = frozenset(cost.inputs)
                if input_vars not in targets:
                    targets[input_vars] = funsor.Tensor(
                        funsor.ops.new_zeros(
                            funsor.tensor.get_default_prototype(),
                            tuple(v.size for v in cost.inputs.values()),
                        ),
                        cost.inputs,
                        cost.dtype,
                    )
            # the sum-product over the guide measures plus the targets is recorded
            # on the tape so that it can be differentiated with respect to the targets
            with AdjointTape() as tape:
                logzq = funsor.sum_product.sum_product(
                    funsor.ops.logaddexp,
                    funsor.ops.add,
                    guide_terms["log_measures"] + list(targets.values()),
                    plates=plate_vars,
                    eliminate=(plate_vars | guide_terms["measure_vars"]),
                )
            # marginals is indexed by target below; presumably each entry is the
            # guide marginal over that target's inputs -- confirm against the
            # funsor adjoint documentation
            marginals = tape.adjoint(funsor.ops.logaddexp, funsor.ops.add,
                                     logzq, tuple(targets.values()))
            # finally, integrate out guide variables in the elbo and all plates
            elbo = to_funsor(0, output=funsor.Real)
            for cost in costs:
                target = targets[frozenset(cost.inputs)]
                # normalize the marginal over the non-plate inputs of this cost
                logzq_local = marginals[target].reduce(
                    funsor.ops.logaddexp,
                    frozenset(cost.inputs) - plate_vars)
                log_prob = marginals[target] - logzq_local
                # integrate the cost against log_prob over the guide measure
                # variables that log_prob depends on
                elbo_term = funsor.Integrate(
                    log_prob,
                    cost,
                    guide_terms["measure_vars"] & frozenset(log_prob.inputs),
                )
                # sum the term over the plates that the cost depends on
                elbo += elbo_term.reduce(funsor.ops.add,
                                         plate_vars & frozenset(cost.inputs))

        # evaluate the elbo, using memoize to share tensor computation where possible
        with funsor.interpretations.memoize():
            return -to_data(apply_optimizer(elbo))